[70B-Part4] Config and init_empty_weights (#117)
* add .txt to prefetch_weights

* init_empty_weights for non-main procs in fsdp

* 70b config
farzadab authored Sep 18, 2024
1 parent be8ee6b commit 51f21e9
Showing 3 changed files with 30 additions and 2 deletions.
16 changes: 16 additions & 0 deletions ultravox/training/configs/llama_70b.yaml
@@ -0,0 +1,16 @@
# Training configuration for Ultravox with Meta-Llama-3.1-70B-Instruct and Whisper-Medium
# This configuration is used alongside release_config.yaml

# TODO: make sure to increase max_duration in mcloud.yaml to 18 hours instead of 6

exp_name: "ultravox-v0_4-llama3.1-70B-whisper_m"

text_model: "meta-llama/Meta-Llama-3.1-70B-Instruct"
audio_model: "openai/whisper-medium"

batch_size: 5
# We increase the number of steps by 2x, but with the lower batch_size we still won't train on as many samples as the 8B model.
# We may revisit this later.
max_steps: 28800 # 28,800 steps x 8 GPUs x batch_size 5 = 1,152,000 samples

do_eval: False # evaluation doesn't support FSDP yet
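The step-count comment above reads most easily as a quick calculation. A minimal sketch of that arithmetic in Python, assuming the "x8" refers to an 8-GPU run and batch_size is the per-device value from this config:

```python
# Assumed: 8 GPUs (the "x8" in the comment) and the per-device batch_size of 5.
num_gpus = 8
batch_size = 5
max_steps = 28_800

samples_seen = max_steps * num_gpus * batch_size
print(f"{samples_seen:,}")  # 1,152,000
```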
2 changes: 1 addition & 1 deletion ultravox/training/helpers/prefetch_weights.py
@@ -7,7 +7,7 @@
from ultravox.model import wandb_utils
from ultravox.training import config_base

ALLOW_PATTERNS = ["*.safetensors", "*.json"]
ALLOW_PATTERNS = ["*.safetensors", "*.json", "*.txt"]


def main(override_sys_args: Optional[List[str]] = None):
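For context, ALLOW_PATTERNS is the kind of filter typically passed to huggingface_hub's snapshot_download; adding "*.txt" lets plain-text files (e.g. a tokenizer's merges.txt) be prefetched as well. A minimal sketch of that usage, assuming the 70B repo id from the config above; the actual prefetch_weights.py may wire this differently:

```python
# Sketch only: prefetch just the files matching ALLOW_PATTERNS, skipping
# everything else in the model repo to keep the download small.
from huggingface_hub import snapshot_download

ALLOW_PATTERNS = ["*.safetensors", "*.json", "*.txt"]

snapshot_download(
    repo_id="meta-llama/Meta-Llama-3.1-70B-Instruct",  # model from llama_70b.yaml
    allow_patterns=ALLOW_PATTERNS,
)
```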
14 changes: 13 additions & 1 deletion ultravox/training/train.py
@@ -1,3 +1,4 @@
import contextlib
import copy
import dataclasses
import gc
@@ -8,6 +9,7 @@
from datetime import datetime
from typing import Dict, List, Optional

import accelerate
import datasets as hf_datasets
import pandas as pd
import safetensors.torch
@@ -128,7 +130,17 @@ def train(args: config_base.TrainConfig):
)

logging.info("Instantiating model...")
model = ultravox_model.UltravoxModel(config)

model_load_context = (
accelerate.init_empty_weights()
if args.use_fsdp and not is_master
else contextlib.nullcontext()
)
# If we're using FSDP, we can just initialize the model on the main process
# and use sync_model_states to distribute the weights to the other processes.
# Otherwise we'd be loading the model on every process, which uses too much CPU memory.
with model_load_context:
model = ultravox_model.UltravoxModel(config)

assert model.get_input_embeddings().num_embeddings == len(
text_tokenizer
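The diff above allocates the model on the meta device for non-main FSDP ranks; the other half of the story is how those empty copies get real weights. A minimal sketch of the general pattern, using a plain PyTorch FSDP wrap rather than the repo's actual trainer/sync_model_states helper: with sync_module_states=True, rank 0's fully loaded parameters are broadcast to the other ranks at wrap time.

```python
# Sketch only, assuming `model` is the module built under the load context
# above (real weights on rank 0, meta tensors elsewhere) and that
# torch.distributed is already initialized.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,  # broadcast rank 0's weights to all ranks
    # Materialize meta tensors on GPU so the broadcast has real storage to fill.
    param_init_fn=lambda m: m.to_empty(device=torch.cuda.current_device(), recurse=False),
)
```

Because the broadcast happens once at wrap time, only rank 0 ever pays the CPU memory cost of loading the full 70B checkpoint.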
