diff --git a/mcloud.yaml b/mcloud.yaml index 0fbdfd67..91e089c1 100644 --- a/mcloud.yaml +++ b/mcloud.yaml @@ -12,7 +12,10 @@ integrations: scheduling: max_duration: 6 # 6 hours max for jobs to avoid hanging jobs command: >- - cd ultravox && poetry install --no-dev && poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS + cd ultravox && + poetry install --no-dev && + poetry run python -m ultravox.training.helpers.prefetch_weights $TRAIN_ARGS && + poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS env_variables: MLFLOW_TRACKING_URI: databricks UV_BRANCH: main diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 232c0584..ec3ec58f 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -2,6 +2,8 @@ import datetime import logging import os +import re +import sys from pathlib import Path from typing import Any, Dict, List, Optional @@ -130,3 +132,25 @@ def __post_init__(self): "LayerDrop cannot be used in DDP when encoder is not frozen. Disabling LayerDrop." ) self.disable_layerdrop = True + + +def fix_hyphens(arg: str): + return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg) + + +def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig: + """ + Parse the command line arguments and return a TrainConfig object. + + Args: + override_sys_args: The command line arguments. If None, sys.argv[1:] is used. + This is mainly useful for testing. + """ + args = override_sys_args or sys.argv[1:] + + return simple_parsing.parse( + config_class=TrainConfig, + config_path=os.path.join(os.path.dirname(__file__), "configs/meta_config.yaml"), + add_config_path_arg=True, + args=[fix_hyphens(arg) for arg in args], + ) diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py new file mode 100644 index 00000000..30449a38 --- /dev/null +++ b/ultravox/training/helpers/prefetch_weights.py @@ -0,0 +1,53 @@ +from datetime import datetime +from typing import List, Optional + +import huggingface_hub +import transformers + +from ultravox.model import wandb_utils +from ultravox.training import config_base + +ALLOW_PATTERNS = ["*.safetensors", "*.json"] + + +def main(override_sys_args: Optional[List[str]] = None): + start = datetime.now() + print("Downloading weights ...") + + args = config_base.get_train_args(override_sys_args) + + download_weights([args.text_model, args.audio_model], args.model_load_dir) + + end = datetime.now() + print(f"Weights downloaded in {end - start} seconds") + + +def download_weights(model_ids: List[str], model_load_dir: Optional[str] = None): + for model_id in model_ids: + try: + # Download all model files that match ALLOW_PATTERNS + # This is faster than .from_pretrained due to parallel downloads + huggingface_hub.snapshot_download( + repo_id=model_id, allow_patterns=ALLOW_PATTERNS + ) + except huggingface_hub.utils.GatedRepoError as e: + raise e + except huggingface_hub.utils.RepositoryNotFoundError as e: + # We assume that the model is local if it's not found on HF Hub. + # The `.from_pretrained` call will verify the local case. + print( + f"Model {model_id} not found on HF Hub. Skipping download. Error: {e}" + ) + + # A backstop to make sure the model is fully downloaded. Scenarios to consider: + # - ALLOW_PATTERNS is not enough to download all files needed + # - The model is local, this will verify that everything is in order + # Using `device_map="meta"` to avoid loading the weights into memory or device + transformers.AutoModel.from_pretrained(model_id, device_map="meta") + + if model_load_dir and wandb_utils.is_wandb_url(model_load_dir): + wandb_utils.download_model_from_wandb(model_load_dir) + + +if __name__ == "__main__": + main() diff --git a/ultravox/training/helpers/prefetch_weights_test.py b/ultravox/training/helpers/prefetch_weights_test.py new file mode 100644 index 00000000..19468c6b --- /dev/null +++ b/ultravox/training/helpers/prefetch_weights_test.py @@ -0,0 +1,14 @@ +import transformers + +from ultravox.training.helpers import prefetch_weights + +TEXT_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM" +AUDIO_MODEL = "hf-internal-testing/tiny-random-WhisperForCausalLM" + + +def test_prefetch_weights(): + prefetch_weights.main(["--text-model", TEXT_MODEL, "--audio-model", AUDIO_MODEL]) + + # With local_files_only=True, from_pretrained will throw an error if the weights are not downloaded + transformers.AutoModel.from_pretrained(TEXT_MODEL, local_files_only=True) + transformers.AutoModel.from_pretrained(AUDIO_MODEL, local_files_only=True) diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 05cea992..d74620e2 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -4,16 +4,13 @@ import glob import logging import os -import re import subprocess -import sys from datetime import datetime from typing import Dict, List, Optional import datasets as hf_datasets import pandas as pd import safetensors.torch -import simple_parsing import torch import torch.distributed import transformers @@ -30,15 +27,12 @@ from ultravox.model import wandb_utils from ultravox.training import config_base from ultravox.training import ddp_utils +from ultravox.training.helpers import prefetch_weights INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000} OUTPUT_EXAMPLE = {"text": "Hello, world!"} -def fix_hyphens(arg: str): - return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg) - - def prepare_dataset( train_args: config_base.TrainConfig, dataset_names: List[str], @@ -76,12 +70,7 @@ def main() -> None: os.environ["WANDB_LOG_MODEL"] = "checkpoint" os.environ["WANDB_PROJECT"] = "ultravox" - args = simple_parsing.parse( - config_class=config_base.TrainConfig, - config_path="ultravox/training/configs/meta_config.yaml", # base config file - add_config_path_arg=True, - args=[fix_hyphens(arg) for arg in sys.argv[1:]], - ) + args = config_base.get_train_args() transformers.set_seed(args.seed) @@ -100,9 +89,7 @@ def train(args: config_base.TrainConfig): world_size = int(os.environ.get("WORLD_SIZE", 1)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) is_master = local_rank == 0 - - if world_size > 1: - torch.distributed.init_process_group(backend="nccl") + is_distributed = world_size > 1 # DDP blows up logging, so this is an attempt to suppress it to only logs from the master process logging.basicConfig(level=logging.INFO if is_master else logging.ERROR) @@ -110,6 +97,17 @@ def train(args: config_base.TrainConfig): transformers.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR) hf_datasets.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR) + if is_distributed: + torch.distributed.init_process_group(backend="nccl") + + with ddp_utils.run_on_master_first(is_master): + # For larger models, we assume that the weights are already downloaded via prefetch_weights.py + # Otherwise the barrier call can timeout. + # This call is only here as a backstop in case prefetch_weights.py was not run, for example in a local/test run. + prefetch_weights.download_weights( + [args.text_model, args.audio_model], args.model_load_dir + ) + logging.info("Instantiating processor...") text_tokenizer: transformers.PreTrainedTokenizerFast = ( transformers.AutoTokenizer.from_pretrained(args.text_model) @@ -128,11 +126,7 @@ def train(args: config_base.TrainConfig): ) logging.info("Instantiating model...") - - # Since the model downloads the language model and audio encoder weights, we want one process to finish up - # downloading before the others start in order to avoid race conditions. - with ddp_utils.run_on_master_first(is_master): - model = ultravox_model.UltravoxModel(config) + model = ultravox_model.UltravoxModel(config) assert model.get_input_embeddings().num_embeddings == len( text_tokenizer @@ -166,9 +160,10 @@ def train(args: config_base.TrainConfig): logging.info(f"Loading model state dict from {args.model_load_dir}") load_path = args.model_load_dir if wandb_utils.is_wandb_url(load_path): - # Download the model from W&B. The main process should do the download while the others wait. - with ddp_utils.run_on_master_first(is_master): - load_path = wandb_utils.download_model_from_wandb(load_path) + # We assume that the weights are already downloaded via prefetch_weights.py + # and hence this is just resolving the path. If the weights are not downloaded, + # we might see a race condition here when using DDP. + load_path = wandb_utils.download_model_from_wandb(load_path) if os.path.isdir(load_path): load_path = os.path.join(load_path, "model*.safetensors") paths = glob.glob(load_path)