Merge commit b723dc6
liPatrick committed Sep 19, 2024
2 parents: d427b38 + 51f21e9
Showing 11 changed files with 299 additions and 68 deletions.
9 changes: 7 additions & 2 deletions mcloud.yaml
@@ -9,9 +9,14 @@ integrations:
git_repo: fixie-ai/ultravox
git_branch: $UV_BRANCH
pip_install: poetry==1.7.1
scheduling:
max_duration: 6 # 6 hours max to avoid hanging jobs
command: >-
cd ultravox && poetry install --no-dev && poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS
cd ultravox &&
poetry install --no-dev &&
poetry run python -m ultravox.training.helpers.prefetch_weights $TRAIN_ARGS &&
poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS
env_variables:
MLFLOW_TRACKING_URI: databricks
UV_BRANCH: main
TRAIN_ARGS: --config_path ultravox/training/configs/release_config.yaml
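The revised command prefetches model weights in a single process before torchrun spawns its 8 workers, so the download happens once up front instead of in every rank. A minimal sketch of the idea, mirroring the new ultravox/training/helpers/prefetch_weights.py added later in this diff (the model IDs below are just examples):

import huggingface_hub

# Only fetch weight shards plus config/tokenizer files, as the new helper does.
ALLOW_PATTERNS = ["*.safetensors", "*.json", "*.txt"]

for model_id in ["meta-llama/Meta-Llama-3.1-70B-Instruct", "openai/whisper-medium"]:
    # Populate the local HF cache so later from_pretrained() calls are cache hits.
    huggingface_hub.snapshot_download(repo_id=model_id, allow_patterns=ALLOW_PATTERNS)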
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ tensorboardx = "~2.6.2.2"
wandb = "~0.17.1"
sacrebleu = "^2.4.2"
tenacity = "^9.0.0"
evals = {git = "https://github.com/fixie-ai/evals"}
evals = {git = "https://github.com/fixie-ai/evals", rev = "0c66bf85df7a4b903ecb202b23c2a826b749fd71"}

[tool.poetry.group.dev.dependencies]
black = "~24.4.2"
13 changes: 11 additions & 2 deletions ultravox/model/ultravox_config.py
@@ -99,7 +99,6 @@ def __init__(
audio_model_id: Optional[str] = None,
text_model_id: Optional[str] = None,
ignore_index: int = -100,
audio_token_index: int = 32000,
hidden_size: int = 4096,
stack_factor: int = 8,
norm_init: float = 0.4,
@@ -112,7 +111,6 @@

self.audio_model_id = audio_model_id
self.text_model_id = text_model_id
self.audio_token_index = audio_token_index

self.hidden_size = hidden_size
self.stack_factor = stack_factor
@@ -155,3 +153,14 @@ def __init__(
self.initializer_range = self.text_config.initializer_range

super().__init__(**kwargs)

def to_diff_dict(self) -> Dict[str, Any]:
diff_dict = super().to_diff_dict()

# remove text_config and audio_config if text_model_id and audio_model_id are present
if self.text_model_id is not None:
diff_dict.pop("text_config", None)
if self.audio_model_id is not None:
diff_dict.pop("audio_config", None)

return diff_dict
37 changes: 37 additions & 0 deletions ultravox/model/ultravox_config_test.py
@@ -0,0 +1,37 @@
import pytest
import transformers

from ultravox.model import ultravox_config


@pytest.mark.parametrize(
"model_id",
["fixie-ai/ultravox-v0_2", "fixie-ai/ultravox-v0_3", "fixie-ai/ultravox-v0_4"],
)
def test_can_load_release(model_id: str):
orig_config: transformers.PretrainedConfig = (
transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=True)
)
config_from_dict = ultravox_config.UltravoxConfig(**orig_config.to_dict())
config_from_diff_dict = ultravox_config.UltravoxConfig(**orig_config.to_diff_dict())

assert config_from_dict.to_dict() == orig_config.to_dict()
assert config_from_diff_dict.to_dict() == orig_config.to_dict()

assert config_from_dict.text_config.to_dict() == orig_config.text_config.to_dict()
assert config_from_dict.audio_config.to_dict() == orig_config.audio_config.to_dict()

config_reloaded = ultravox_config.UltravoxConfig(**config_from_dict.to_dict())
config_reloaded_diff = ultravox_config.UltravoxConfig(
**config_from_dict.to_diff_dict()
)
assert config_reloaded.to_dict() == orig_config.to_dict()
assert config_reloaded_diff.to_dict() == orig_config.to_dict()


def test_no_config_when_id_present():
config = ultravox_config.UltravoxConfig(audio_model_id="openai/whisper-small")
assert "audio_config" not in config.to_diff_dict()

config = ultravox_config.UltravoxConfig(text_model_id="microsoft/phi-2")
assert "text_config" not in config.to_diff_dict()
53 changes: 36 additions & 17 deletions ultravox/model/ultravox_model.py
@@ -34,7 +34,6 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):

config_class = UltravoxConfig
config: UltravoxConfig # for type hinting
_no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]
# We minimize the weights in state_dict in order to reduce the size of the checkpoint.
# The issue is that from_pretrained() uses state_dict() keys to know which keys are expected,
# so we have to tell it to ignore some keys that are not always in the model.
@@ -46,6 +45,7 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):

def __init__(self, config: UltravoxConfig):
super().__init__(config)
self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)

self.keep_params: Set[str] = set()
self.vocab_size = config.vocab_size
@@ -55,9 +55,16 @@ def __init__(self, config: UltravoxConfig):
if config.audio_model_id is not None and "whisper" in config.audio_model_id:
self.audio_tower_context_length = 3000

self.multi_modal_projector = UltravoxProjector(config)
self.multi_modal_projector = self._create_multi_modal_projector(config)
self.language_model = self._create_language_model(config)

# Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
# FSDP throws an error if some of the layer types are not found in the model.
# This would be something like ["LlamaDecoderLayer", "WhisperEncoderLayer"]
self._no_split_modules = (self.language_model._no_split_modules or []) + (
self.audio_tower._no_split_modules or []
)

self.loss_config = LossConfig()
self.post_init()

@@ -197,7 +204,7 @@ def forward(
), "audio_token_start_idx and audio_token_len and audio_batch_size must have the same batch size."

audio_tower_output = self.audio_tower.forward(
audio_values
audio_values.to(self.audio_tower.dtype)
).last_hidden_state
audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
audio_embeds = self.multi_modal_projector.forward(audio_tower_output)
@@ -281,18 +288,26 @@ def prepare_inputs_for_generation(

return model_input

@classmethod
def _create_multi_modal_projector(
cls, config: UltravoxConfig
) -> "UltravoxProjector":
projector = UltravoxProjector(config)
projector.to(config.torch_dtype)
return projector

@classmethod
def _create_audio_tower(
cls, config: UltravoxConfig
) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
if config.audio_model_id is not None:
if "whisper" in config.audio_model_id is not None:
audio_tower = ModifiedWhisperEncoder.from_pretrained(
config.audio_model_id
config.audio_model_id, torch_dtype=config.torch_dtype
)
else:
audio_tower = transformers.AutoModel.from_pretrained(
config.audio_model_id
config.audio_model_id, torch_dtype=config.torch_dtype
)
else:
if "whisper" in config.audio_config._name_or_path:
@@ -323,14 +338,18 @@ def _create_language_model(
) -> transformers.LlamaForCausalLM:
if config.text_model_id is not None:
language_model = transformers.AutoModelForCausalLM.from_pretrained(
config.text_model_id, attn_implementation=config._attn_implementation
config.text_model_id,
attn_implementation=config._attn_implementation,
torch_dtype=config.torch_dtype,
)
else:
with transformers.modeling_utils.no_init_weights():
# we only ever use from_config if the weights are retrained, hence initializing is not
# required. This makes model creation much faster since init on CPU is quite slow.
language_model = transformers.AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
config.text_config,
attn_implementation=config._attn_implementation,
torch_dtype=config.torch_dtype,
)

language_model = apply_lora(language_model, config.text_model_lora_config)
@@ -372,26 +391,25 @@ def push_to_hub(self, *args, **kwargs):
self.to(self.language_model.dtype)
return super().push_to_hub(*args, **kwargs)

def state_dict(self, *args, **kwargs):
def save_pretrained(
self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
):
if state_dict is None:
state_dict = super().state_dict()

named_params = dict(self.named_parameters())
state_dict = super().state_dict(*args, **kwargs)

state_dict = {
k: v
for k, v in state_dict.items()
if k in self.keep_params
or (k in named_params and named_params[k].requires_grad)
}
return state_dict

def load_state_dict(
self,
state_dict: Dict[str, Any],
*args,
**kwargs,
):
super().save_pretrained(*args, state_dict=state_dict, **kwargs)

def _pre_load_state_dict_hook(self, state_dict: Dict[str, Any], *args, **kwargs):
self.keep_params.update(set(state_dict.keys()))
return super().load_state_dict(state_dict, *args, **kwargs)

def print_trainable_parameters(self):
"""
@@ -526,6 +544,7 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder):
"""

base_model_prefix = "model.encoder"
_no_split_modules = ["WhisperEncoderLayer"]

def forward(
self,
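Together, the new _pre_load_state_dict_hook and the save_pretrained override replace the old state_dict()/load_state_dict() overrides: keys that arrive from a loaded checkpoint are remembered in keep_params, and on save only those keys plus the currently trainable parameters are written, which keeps checkpoints small when most of the model is frozen. A self-contained sketch of that pattern on a toy module (ToyModel and its frozen/adapter split are illustrative, not from the repo):

from typing import Any, Dict, Set

import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.frozen = nn.Linear(4, 4)   # stands in for a frozen pretrained backbone
        self.adapter = nn.Linear(4, 4)  # stands in for the trainable projector / LoRA
        self.frozen.requires_grad_(False)
        self.keep_params: Set[str] = set()
        # Remember which keys a loaded checkpoint actually provided.
        self._register_load_state_dict_pre_hook(self._pre_load_hook)

    def _pre_load_hook(self, state_dict: Dict[str, Any], *args, **kwargs):
        self.keep_params.update(state_dict.keys())

    def minimal_state_dict(self) -> Dict[str, Any]:
        named = dict(self.named_parameters())
        # Keep only keys that came from a checkpoint or that are trainable.
        return {
            k: v
            for k, v in self.state_dict().items()
            if k in self.keep_params or (k in named and named[k].requires_grad)
        }


model = ToyModel()
print(sorted(model.minimal_state_dict()))  # ['adapter.bias', 'adapter.weight']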
47 changes: 46 additions & 1 deletion ultravox/training/config_base.py
@@ -2,6 +2,8 @@
import datetime
import logging
import os
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

@@ -57,6 +59,9 @@ class TrainConfig:

device: str = "cuda"
data_type: str = "bfloat16"
# Whether to use FSDP (Fully Sharded Data Parallelism) for training
# needed for large model training (e.g. 70B+)
use_fsdp: bool = False
# Path to load the model from. Can be local path, HF hub model_id, or W&B artifact
model_load_dir: Optional[str] = None
text_model_lora_config: Optional[ultravox_config.LoraConfigSimplified] = None
@@ -70,7 +75,9 @@
optimizer: str = "adamw_torch"
num_epochs: int = 1
max_steps: int = 0
val_steps: Optional[int] = None
# Run an evaluation every X steps. If smaller than 1, will be interpreted as ratio of total training steps.
val_steps: Optional[float] = None
# Save checkpoint every X steps. If smaller than 1, will be interpreted as ratio of total training steps.
save_steps: float = 0
logging_steps: int = 1
grad_accum_steps: int = 1
@@ -117,6 +124,10 @@ def __post_init__(self):
self.exp_name = datetime.datetime.now().strftime("exp--%Y-%m-%d--%H-%M-%S")
if self.output_dir is None:
self.output_dir = Path("runs") / self.exp_name

# HF Pipeline gets tripped up if the path has a "." in it
self.output_dir = Path(str(self.output_dir).replace(".", "--"))

if self.logs_dir is None:
self.logs_dir = self.output_dir / "logs"

@@ -130,3 +141,37 @@ def __post_init__(self):
"LayerDrop cannot be used in DDP when encoder is not frozen. Disabling LayerDrop."
)
self.disable_layerdrop = True

if self.use_fsdp and self.save_steps:
logging.warning(
"FSDP is enabled: Saving checkpoints is going to be extremely slow and results in a full save."
" Consider setting save_steps=0."
)

if self.use_fsdp and self.do_eval:
logging.warning(
"FSDP is enabled: Evaluation is not supported with FSDP. Disabling evaluation."
)
self.do_eval = False


def fix_hyphens(arg: str):
return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg)


def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig:
"""
Parse the command line arguments and return a TrainConfig object.
Args:
override_sys_args: The command line arguments. If None, sys.argv[1:] is used.
This is mainly useful for testing.
"""
args = override_sys_args or sys.argv[1:]

return simple_parsing.parse(
config_class=TrainConfig,
config_path=os.path.join(os.path.dirname(__file__), "configs/meta_config.yaml"),
add_config_path_arg=True,
args=[fix_hyphens(arg) for arg in args],
)
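fix_hyphens lets the CLI accept hyphenated flags (e.g. --text-model) by rewriting only the flag name, never the value, into the underscore form that simple_parsing expects. A quick sketch of the behavior (the example arguments are hypothetical):

import re


def fix_hyphens(arg: str):
    # Rewrite hyphens only in the flag name, i.e. the part before any "=".
    return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg)


args = ["--text-model=meta-llama/Meta-Llama-3.1-8B-Instruct", "--max-steps", "100"]
print([fix_hyphens(a) for a in args])
# ['--text_model=meta-llama/Meta-Llama-3.1-8B-Instruct', '--max_steps', '100']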
16 changes: 16 additions & 0 deletions ultravox/training/configs/llama_70b.yaml
@@ -0,0 +1,16 @@
# Training configuration for Ultravox with Meta-Llama-3.1-70B-Instruct and Whisper-Medium
# This configuration is used alongside release_config.yaml

# TODO: make sure to increase max_duration in mcloud.yaml to 18 hours instead of 6

exp_name: "ultravox-v0_4-llama3.1-70B-whisper_m"

text_model: "meta-llama/Meta-Llama-3.1-70B-Instruct"
audio_model: "openai/whisper-medium"

batch_size: 5
# We increase the number of steps by 2x, but with the lower batch_size we still won't train on as many samples as the 8B model.
# We may revisit this later on.
max_steps: 28800 # 28800 steps x 8 GPUs x batch_size 5 = 1,152,000 samples

do_eval: False # evaluation doesn't support FSDP yet
53 changes: 53 additions & 0 deletions ultravox/training/helpers/prefetch_weights.py
@@ -0,0 +1,53 @@
from datetime import datetime
from typing import List, Optional

import huggingface_hub
import transformers

from ultravox.model import wandb_utils
from ultravox.training import config_base

ALLOW_PATTERNS = ["*.safetensors", "*.json", "*.txt"]


def main(override_sys_args: Optional[List[str]] = None):
start = datetime.now()
print("Downloading weights ...")

args = config_base.get_train_args(override_sys_args)

download_weights([args.text_model, args.audio_model], args.model_load_dir)

end = datetime.now()
print(f"Weights downloaded in {end - start} seconds")


def download_weights(model_ids: List[str], model_load_dir: Optional[str] = None):
for model_id in model_ids:
try:
# Download all model files that match ALLOW_PATTERNS
# This is faster than .from_pretrained due to parallel downloads
huggingface_hub.snapshot_download(
repo_id=model_id, allow_patterns=ALLOW_PATTERNS
)
except huggingface_hub.utils.GatedRepoError as e:
raise e
except huggingface_hub.utils.RepositoryNotFoundError as e:
# We assume that the model is local if it's not found on HF Hub.
# The `.from_pretrained` call will verify the local case.
print(
f"Model {model_id} not found on HF Hub. Skipping download. Error: {e}"
)

# A backstop to make sure the model is fully downloaded. Scenarios to consider:
# - ALLOW_PATTERNS is not enough to download all files needed
# - The model is local, this will verify that everything is in order
# Using `device_map="meta"` to avoid loading the weights into memory or device
transformers.AutoModel.from_pretrained(model_id, device_map="meta")

if model_load_dir and wandb_utils.is_wandb_url(model_load_dir):
wandb_utils.download_model_from_wandb(model_load_dir)


if __name__ == "__main__":
main()
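Because the helper reuses get_train_args, it accepts the same arguments as the trainer, which is why mcloud.yaml can pass it the unchanged $TRAIN_ARGS. A small usage sketch (run programmatically; the config path mirrors TRAIN_ARGS above):

from ultravox.training.helpers import prefetch_weights

# Downloads the text and audio model weights named in the config into the local HF cache.
prefetch_weights.main(["--config_path", "ultravox/training/configs/release_config.yaml"])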