diff --git a/ultravox/training/configs/llama_70b.yaml b/ultravox/training/configs/llama_70b.yaml new file mode 100644 index 00000000..4e3a8432 --- /dev/null +++ b/ultravox/training/configs/llama_70b.yaml @@ -0,0 +1,16 @@ +# Training configuration for Ultravox with Meta-Llama-3.1-70B-Instruct and Whisper-Medium +# This configuration is used alongside release_config.yaml + +# TODO: make sure to increase max_duration in mcloud.yaml to 18 hours instead of 6 + +exp_name: "ultravox-v0_4-llama3.1-70B-whisper_m" + +text_model: "meta-llama/Meta-Llama-3.1-70B-Instruct" +audio_model: "openai/whisper-medium" + +batch_size: 5 +# We increase the number of steps by 2x, but with a lower batch_size, we still won't be training on as many samples as the 8B model +# We will revisit this trade-off later on. +max_steps: 28800 # x8x5 = 1,152,000 + +do_eval: False # evaluation doesn't support FSDP yet diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index 30449a38..87106de0 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -7,7 +7,7 @@ from ultravox.model import wandb_utils from ultravox.training import config_base -ALLOW_PATTERNS = ["*.safetensors", "*.json"] +ALLOW_PATTERNS = ["*.safetensors", "*.json", "*.txt"] def main(override_sys_args: Optional[List[str]] = None): diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 3b4e4287..5dde3f62 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -1,3 +1,4 @@ +import contextlib import copy import dataclasses import gc @@ -8,6 +9,7 @@ from datetime import datetime from typing import Dict, List, Optional +import accelerate import datasets as hf_datasets import pandas as pd import safetensors.torch @@ -128,7 +130,17 @@ def train(args: config_base.TrainConfig): ) logging.info("Instantiating model...") - model = ultravox_model.UltravoxModel(config) + + model_load_context = ( + 
accelerate.init_empty_weights() + if args.use_fsdp and not is_master + else contextlib.nullcontext() + ) + # If we're using FSDP, we can just initialize the model on the main process + # and use sync_model_states to distribute the weights to the other processes. + # Otherwise we'd be loading the model on every process, which uses too much CPU memory. + with model_load_context: + model = ultravox_model.UltravoxModel(config) assert model.get_input_embeddings().num_embeddings == len( text_tokenizer