[70B-Part1] Prefetch weights separately (#106)
* prefetch weights separately

* moving get_train_args to config

* add a test for prefetch_weights

* double check for prefetch_weights for local/test runs
farzadab authored Sep 13, 2024
1 parent a426890 commit 7bd4eeb
Showing 5 changed files with 114 additions and 25 deletions.
5 changes: 4 additions & 1 deletion mcloud.yaml
@@ -12,7 +12,10 @@ integrations:
scheduling:
max_duration: 6 # 6 hours max for jobs to avoid hanging jobs
command: >-
cd ultravox && poetry install --no-dev && poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS
cd ultravox &&
poetry install --no-dev &&
poetry run python -m ultravox.training.helpers.prefetch_weights $TRAIN_ARGS &&
poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS
env_variables:
MLFLOW_TRACKING_URI: databricks
UV_BRANCH: main
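The new job command warms the Hugging Face cache in a single process before torchrun spawns its eight workers, so every rank later reads the weights from local disk instead of racing to download them. The same prefetch step can be run directly for a one-off or local setup; a minimal sketch, using the tiny test checkpoints from prefetch_weights_test.py as stand-ins for the real models:

    from ultravox.training.helpers import prefetch_weights

    # Warm the HF cache before launching training; the hyphenated flags are
    # normalized to the underscore-style TrainConfig field names by fix_hyphens.
    prefetch_weights.main([
        "--text-model", "hf-internal-testing/tiny-random-LlamaForCausalLM",
        "--audio-model", "hf-internal-testing/tiny-random-WhisperForCausalLM",
    ])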
24 changes: 24 additions & 0 deletions ultravox/training/config_base.py
@@ -2,6 +2,8 @@
import datetime
import logging
import os
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

@@ -130,3 +132,25 @@ def __post_init__(self):
"LayerDrop cannot be used in DDP when encoder is not frozen. Disabling LayerDrop."
)
self.disable_layerdrop = True


def fix_hyphens(arg: str):
return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg)


def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig:
"""
Parse the command line arguments and return a TrainConfig object.
Args:
override_sys_args: The command line arguments. If None, sys.argv[1:] is used.
This is mainly useful for testing.
"""
args = override_sys_args or sys.argv[1:]

return simple_parsing.parse(
config_class=TrainConfig,
config_path=os.path.join(os.path.dirname(__file__), "configs/meta_config.yaml"),
add_config_path_arg=True,
args=[fix_hyphens(arg) for arg in args],
)
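fix_hyphens normalizes hyphenated CLI flags into the underscore form that the TrainConfig field names use, and it only rewrites the flag name before any "=", so values keep their hyphens. A small behavioral sketch based on the regex above (the flag values are placeholders):

    from ultravox.training.config_base import fix_hyphens

    # Only the flag name (before "=") is rewritten; the value is left untouched.
    assert fix_hyphens("--text-model=some-org/some-model") == "--text_model=some-org/some-model"
    assert fix_hyphens("--audio-model") == "--audio_model"
    # Arguments without a leading "--" are returned unchanged.
    assert fix_hyphens("some-positional-arg") == "some-positional-arg"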
53 changes: 53 additions & 0 deletions ultravox/training/helpers/prefetch_weights.py
@@ -0,0 +1,53 @@
from datetime import datetime
from typing import List, Optional

import huggingface_hub
import transformers

from ultravox.model import wandb_utils
from ultravox.training import config_base

ALLOW_PATTERNS = ["*.safetensors", "*.json"]


def main(override_sys_args: Optional[List[str]] = None):
start = datetime.now()
print("Downloading weights ...")

args = config_base.get_train_args(override_sys_args)

download_weights([args.text_model, args.audio_model], args.model_load_dir)

end = datetime.now()
print(f"Weights downloaded in {end - start} seconds")


def download_weights(model_ids: List[str], model_load_dir: Optional[str] = None):
for model_id in model_ids:
try:
# Download all model files that match ALLOW_PATTERNS
# This is faster than .from_pretrained due to parallel downloads
huggingface_hub.snapshot_download(
repo_id=model_id, allow_patterns=ALLOW_PATTERNS
)
except huggingface_hub.utils.GatedRepoError as e:
raise e
except huggingface_hub.utils.RepositoryNotFoundError as e:
# We assume that the model is local if it's not found on HF Hub.
# The `.from_pretrained` call will verify the local case.
print(
f"Model {model_id} not found on HF Hub. Skipping download. Error: {e}"
)

# A backstop to make sure the model is fully downloaded. Scenarios to consider:
# - ALLOW_PATTERNS is not enough to download all files needed
# - The model is local, this will verify that everything is in order
# Using `device_map="meta"` to avoid loading the weights into memory or device
transformers.AutoModel.from_pretrained(model_id, device_map="meta")

if model_load_dir and wandb_utils.is_wandb_url(model_load_dir):
wandb_utils.download_model_from_wandb(model_load_dir)


if __name__ == "__main__":
main()
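The final from_pretrained call acts as a backstop: it resolves, and if necessary downloads, any file that the ALLOW_PATTERNS snapshot missed, while device_map="meta" keeps the weights on the meta device so nothing is materialized in RAM or on a GPU. A minimal sketch of that property, assuming the meta-device load leaves every parameter as a data-free meta tensor:

    import transformers

    # Loads config and weight files (downloading them if needed) but never
    # materializes the tensors themselves.
    model = transformers.AutoModel.from_pretrained(
        "hf-internal-testing/tiny-random-LlamaForCausalLM", device_map="meta"
    )
    # Assumption: with device_map="meta" every parameter stays on the meta device.
    assert all(p.device.type == "meta" for p in model.parameters())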
14 changes: 14 additions & 0 deletions ultravox/training/helpers/prefetch_weights_test.py
@@ -0,0 +1,14 @@
import transformers

from ultravox.training.helpers import prefetch_weights

TEXT_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
AUDIO_MODEL = "hf-internal-testing/tiny-random-WhisperForCausalLM"


def test_prefetch_weights():
prefetch_weights.main(["--text-model", TEXT_MODEL, "--audio-model", AUDIO_MODEL])

# With local_files_only=True, from_pretrained will throw an error if the weights are not downloaded
transformers.AutoModel.from_pretrained(TEXT_MODEL, local_files_only=True)
transformers.AutoModel.from_pretrained(AUDIO_MODEL, local_files_only=True)
43 changes: 19 additions & 24 deletions ultravox/training/train.py
@@ -4,16 +4,13 @@
import glob
import logging
import os
import re
import subprocess
import sys
from datetime import datetime
from typing import Dict, List, Optional

import datasets as hf_datasets
import pandas as pd
import safetensors.torch
import simple_parsing
import torch
import torch.distributed
import transformers
@@ -30,15 +27,12 @@
from ultravox.model import wandb_utils
from ultravox.training import config_base
from ultravox.training import ddp_utils
from ultravox.training.helpers import prefetch_weights

INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000}
OUTPUT_EXAMPLE = {"text": "Hello, world!"}


def fix_hyphens(arg: str):
return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg)


def prepare_dataset(
train_args: config_base.TrainConfig,
dataset_names: List[str],
@@ -76,12 +70,7 @@ def main() -> None:
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ["WANDB_PROJECT"] = "ultravox"

args = simple_parsing.parse(
config_class=config_base.TrainConfig,
config_path="ultravox/training/configs/meta_config.yaml", # base config file
add_config_path_arg=True,
args=[fix_hyphens(arg) for arg in sys.argv[1:]],
)
args = config_base.get_train_args()

transformers.set_seed(args.seed)

@@ -100,16 +89,25 @@ def train(args: config_base.TrainConfig):
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
is_master = local_rank == 0

if world_size > 1:
torch.distributed.init_process_group(backend="nccl")
is_distributed = world_size > 1

# DDP makes logging very noisy, so this attempts to limit verbose logs to the master process
logging.basicConfig(level=logging.INFO if is_master else logging.ERROR)
# os.environ["TORCH_LOGS"] = "ERROR" if is_master else "WARNING"
transformers.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR)
hf_datasets.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR)

if is_distributed:
torch.distributed.init_process_group(backend="nccl")

with ddp_utils.run_on_master_first(is_master):
# For larger models, we assume that the weights are already downloaded via prefetch_weights.py
# Otherwise the barrier call can timeout.
# This call is only here as a backstop in case prefetch_weights.py was not run, for example in a local/test run.
prefetch_weights.download_weights(
[args.text_model, args.audio_model], args.model_load_dir
)

logging.info("Instantiating processor...")
text_tokenizer: transformers.PreTrainedTokenizerFast = (
transformers.AutoTokenizer.from_pretrained(args.text_model)
Expand All @@ -128,11 +126,7 @@ def train(args: config_base.TrainConfig):
)

logging.info("Instantiating model...")

# Since the model downloads the language model and audio encoder weights, we want one process to finish up
# downloading before the others start in order to avoid race conditions.
with ddp_utils.run_on_master_first(is_master):
model = ultravox_model.UltravoxModel(config)
model = ultravox_model.UltravoxModel(config)

assert model.get_input_embeddings().num_embeddings == len(
text_tokenizer
@@ -166,9 +160,10 @@
logging.info(f"Loading model state dict from {args.model_load_dir}")
load_path = args.model_load_dir
if wandb_utils.is_wandb_url(load_path):
# Download the model from W&B. The main process should do the download while the others wait.
with ddp_utils.run_on_master_first(is_master):
load_path = wandb_utils.download_model_from_wandb(load_path)
# We assume that the weights are already downloaded via prefetch_weights.py
# and hence this is just resolving the path. If the weights are not downloaded,
# we might see a race condition here when using DDP.
load_path = wandb_utils.download_model_from_wandb(load_path)
if os.path.isdir(load_path):
load_path = os.path.join(load_path, "model*.safetensors")
paths = glob.glob(load_path)
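The run_on_master_first guard now serves only as a backstop: with 70B-scale weights, downloading inside such a block would keep the non-master ranks parked at a collective barrier long enough to hit the distributed timeout, which is exactly what running prefetch_weights.py beforehand avoids. For context, a hedged sketch of the pattern such a context manager typically implements (an assumption about ddp_utils, not its actual code):

    import contextlib
    import torch.distributed as dist

    @contextlib.contextmanager
    def run_on_master_first(is_master: bool):
        """Let the master rank run the body first; the other ranks wait, then run it."""
        if is_master:
            yield  # e.g. the master populates the shared download cache
            if dist.is_available() and dist.is_initialized():
                dist.barrier()  # release the waiting ranks
        else:
            if dist.is_available() and dist.is_initialized():
                dist.barrier()  # block until the master has finished
            yield  # the files are now cached locally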
