diff --git a/mcloud.yaml b/mcloud.yaml
index 99788954..91e089c1 100644
--- a/mcloud.yaml
+++ b/mcloud.yaml
@@ -9,9 +9,14 @@ integrations:
     git_repo: fixie-ai/ultravox
     git_branch: $UV_BRANCH
     pip_install: poetry==1.7.1
+scheduling:
+  max_duration: 6 # 6 hours max for jobs to avoid hanging jobs
 command: >-
-  cd ultravox && poetry install --no-dev && poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS
+  cd ultravox &&
+  poetry install --no-dev &&
+  poetry run python -m ultravox.training.helpers.prefetch_weights $TRAIN_ARGS &&
+  poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS
 env_variables:
   MLFLOW_TRACKING_URI: databricks
   UV_BRANCH: main
-  TRAIN_ARGS: --config_path ultravox/training/configs/release_config.yaml
\ No newline at end of file
+  TRAIN_ARGS: --config_path ultravox/training/configs/release_config.yaml
diff --git a/poetry.lock b/poetry.lock
index 363dfe51..6ee9b992 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1610,8 +1610,8 @@ torch = ["torch"]
 [package.source]
 type = "git"
 url = "https://github.com/fixie-ai/evals"
-reference = "HEAD"
-resolved_reference = "908a814c0881d99d4876d7000878ec0760a22311"
+reference = "0c66bf85df7a4b903ecb202b23c2a826b749fd71"
+resolved_reference = "0c66bf85df7a4b903ecb202b23c2a826b749fd71"

 [[package]]
 name = "evaluate"
@@ -8883,4 +8883,4 @@ files = [
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<4.0"
-content-hash = "70cda1d7eb287a6cd00d2cf8a285c5f2613fcdea125caf03951258ecad79289c"
+content-hash = "798d26eeecb0625e6e6b655f7209286319924de573c9cdd9a30416593a492cb5"
diff --git a/pyproject.toml b/pyproject.toml
index 307f8a75..21c6734c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,7 @@ tensorboardx = "~2.6.2.2"
 wandb = "~0.17.1"
 sacrebleu = "^2.4.2"
 tenacity = "^9.0.0"
-evals = {git = "https://github.com/fixie-ai/evals"}
+evals = {git = "https://github.com/fixie-ai/evals", rev = "0c66bf85df7a4b903ecb202b23c2a826b749fd71"}

 [tool.poetry.group.dev.dependencies]
 black = "~24.4.2"
diff --git a/ultravox/model/ultravox_config.py b/ultravox/model/ultravox_config.py
index ed294e35..fa77b142 100644
--- a/ultravox/model/ultravox_config.py
+++ b/ultravox/model/ultravox_config.py
@@ -99,7 +99,6 @@ def __init__(
         audio_model_id: Optional[str] = None,
         text_model_id: Optional[str] = None,
         ignore_index: int = -100,
-        audio_token_index: int = 32000,
         hidden_size: int = 4096,
         stack_factor: int = 8,
         norm_init: float = 0.4,
@@ -112,7 +111,6 @@ def __init__(

         self.audio_model_id = audio_model_id
         self.text_model_id = text_model_id
-        self.audio_token_index = audio_token_index

         self.hidden_size = hidden_size
         self.stack_factor = stack_factor
@@ -155,3 +153,14 @@ def __init__(
             self.initializer_range = self.text_config.initializer_range

         super().__init__(**kwargs)
+
+    def to_diff_dict(self) -> Dict[str, Any]:
+        diff_dict = super().to_diff_dict()
+
+        # remove text_config and audio_config if text_model_id and audio_model_id are present
+        if self.text_model_id is not None:
+            diff_dict.pop("text_config", None)
+        if self.audio_model_id is not None:
+            diff_dict.pop("audio_config", None)
+
+        return diff_dict
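For context, a rough sketch (not part of the patch) of what the to_diff_dict() override is for: when the towers are referenced by model id, the nested sub-configs can be re-derived from those ids, so dropping them keeps the serialized config.json small. The model ids are the same ones exercised by the new test below and are assumed to be reachable on the HF Hub.

    from ultravox.model import ultravox_config

    config = ultravox_config.UltravoxConfig(
        text_model_id="microsoft/phi-2",
        audio_model_id="openai/whisper-small",
    )
    diff = config.to_diff_dict()
    # The nested configs are omitted; loading from the ids restores them.
    assert "text_config" not in diff and "audio_config" not in diff
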
"fixie-ai/ultravox-v0_4"], +) +def test_can_load_release(model_id: str): + orig_config: transformers.PretrainedConfig = ( + transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=True) + ) + config_from_dict = ultravox_config.UltravoxConfig(**orig_config.to_dict()) + config_from_diff_dict = ultravox_config.UltravoxConfig(**orig_config.to_diff_dict()) + + assert config_from_dict.to_dict() == orig_config.to_dict() + assert config_from_diff_dict.to_dict() == orig_config.to_dict() + + assert config_from_dict.text_config.to_dict() == orig_config.text_config.to_dict() + assert config_from_dict.audio_config.to_dict() == orig_config.audio_config.to_dict() + + config_reloaded = ultravox_config.UltravoxConfig(**config_from_dict.to_dict()) + config_reloaded_diff = ultravox_config.UltravoxConfig( + **config_from_dict.to_diff_dict() + ) + assert config_reloaded.to_dict() == orig_config.to_dict() + assert config_reloaded_diff.to_dict() == orig_config.to_dict() + + +def test_no_config_when_id_present(): + config = ultravox_config.UltravoxConfig(audio_model_id="openai/whisper-small") + assert "audio_config" not in config.to_diff_dict() + + config = ultravox_config.UltravoxConfig(text_model_id="microsoft/phi-2") + assert "text_config" not in config.to_diff_dict() diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index 97b6c3d1..57719c1d 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -34,7 +34,6 @@ class UltravoxModel(transformers.LlamaPreTrainedModel): config_class = UltravoxConfig config: UltravoxConfig # for type hinting - _no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"] # We minimize the weights in state_dict in order to reduce the size of the checkpoint # The issue is that load_pretrained() uses state_dict() keys to know what keys are expected # As such we have to tell is to ignore some keys that are not always in the model @@ -46,6 +45,7 @@ class UltravoxModel(transformers.LlamaPreTrainedModel): def __init__(self, config: UltravoxConfig): super().__init__(config) + self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook) self.keep_params: Set[str] = set() self.vocab_size = config.vocab_size @@ -55,9 +55,16 @@ def __init__(self, config: UltravoxConfig): if config.audio_model_id is not None and "whisper" in config.audio_model_id: self.audio_tower_context_length = 3000 - self.multi_modal_projector = UltravoxProjector(config) + self.multi_modal_projector = self._create_multi_modal_projector(config) self.language_model = self._create_language_model(config) + # Determine no_split_modules dynamically to use with FSDP auto_wrap policy. + # FSDP throws an error if some of the layer types are not found in the model. + # This would be something like ["LlamaDecoderLayer", "WhisperEncoderLayer"] + self._no_split_modules = (self.language_model._no_split_modules or []) + ( + self.audio_tower._no_split_modules or [] + ) + self.loss_config = LossConfig() self.post_init() @@ -197,7 +204,7 @@ def forward( ), "audio_token_start_idx and audio_token_len and audio_batch_size must have the same batch size." 
@@ -197,7 +204,7 @@ def forward(
             ), "audio_token_start_idx and audio_token_len and audio_batch_size must have the same batch size."

             audio_tower_output = self.audio_tower.forward(
-                audio_values
+                audio_values.to(self.audio_tower.dtype)
             ).last_hidden_state
             audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
             audio_embeds = self.multi_modal_projector.forward(audio_tower_output)
@@ -281,6 +288,14 @@ def prepare_inputs_for_generation(

         return model_input

+    @classmethod
+    def _create_multi_modal_projector(
+        cls, config: UltravoxConfig
+    ) -> "UltravoxProjector":
+        projector = UltravoxProjector(config)
+        projector.to(config.torch_dtype)
+        return projector
+
     @classmethod
     def _create_audio_tower(
         cls, config: UltravoxConfig
@@ -288,11 +303,11 @@ def _create_audio_tower(
         if config.audio_model_id is not None:
             if "whisper" in config.audio_model_id is not None:
                 audio_tower = ModifiedWhisperEncoder.from_pretrained(
-                    config.audio_model_id
+                    config.audio_model_id, torch_dtype=config.torch_dtype
                 )
             else:
                 audio_tower = transformers.AutoModel.from_pretrained(
-                    config.audio_model_id
+                    config.audio_model_id, torch_dtype=config.torch_dtype
                 )
         else:
             if "whisper" in config.audio_config._name_or_path:
@@ -323,14 +338,18 @@ def _create_language_model(
     ) -> transformers.LlamaForCausalLM:
         if config.text_model_id is not None:
             language_model = transformers.AutoModelForCausalLM.from_pretrained(
-                config.text_model_id, attn_implementation=config._attn_implementation
+                config.text_model_id,
+                attn_implementation=config._attn_implementation,
+                torch_dtype=config.torch_dtype,
             )
         else:
             with transformers.modeling_utils.no_init_weights():
                 # we only ever use from_config if the weights are retrained, hence initializing is not
                 # required. This makes the model quite creation faster since init on CPU is quite slow.
                 language_model = transformers.AutoModelForCausalLM.from_config(
-                    config.text_config, attn_implementation=config._attn_implementation
+                    config.text_config,
+                    attn_implementation=config._attn_implementation,
+                    torch_dtype=config.torch_dtype,
                 )

         language_model = apply_lora(language_model, config.text_model_lora_config)
@@ -372,9 +391,13 @@ def push_to_hub(self, *args, **kwargs):
         self.to(self.language_model.dtype)
         return super().push_to_hub(*args, **kwargs)

-    def state_dict(self, *args, **kwargs):
+    def save_pretrained(
+        self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
+    ):
+        if state_dict is None:
+            state_dict = super().state_dict()
+
         named_params = dict(self.named_parameters())
-        state_dict = super().state_dict(*args, **kwargs)

         state_dict = {
             k: v
@@ -382,16 +405,11 @@
             if k in self.keep_params
             or (k in named_params and named_params[k].requires_grad)
         }
-        return state_dict

-    def load_state_dict(
-        self,
-        state_dict: Dict[str, Any],
-        *args,
-        **kwargs,
-    ):
+        super().save_pretrained(*args, state_dict=state_dict, **kwargs)
+
+    def _pre_load_state_dict_hook(self, state_dict: Dict[str, Any], *args, **kwargs):
         self.keep_params.update(set(state_dict.keys()))
-        return super().load_state_dict(state_dict, *args, **kwargs)

     def print_trainable_parameters(self):
         """
@@ -526,6 +544,7 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder):
     """

     base_model_prefix = "model.encoder"
+    _no_split_modules = ["WhisperEncoderLayer"]

     def forward(
         self,
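The save_pretrained() override and _pre_load_state_dict_hook above implement a "minimal checkpoint" pattern: remember which keys were explicitly loaded, then save only those plus whatever is currently trainable. A self-contained sketch of the same pattern on a toy module (illustrative only, not the Ultravox implementation; the module and layer names are invented):

    from typing import Any, Dict, Set

    import torch.nn as nn


    class MinimalCheckpointModule(nn.Module):  # hypothetical toy module
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(8, 8)  # frozen; weights come from a base checkpoint
            self.adapter = nn.Linear(8, 8)   # trainable; must always be saved
            self.backbone.requires_grad_(False)
            self.keep_params: Set[str] = set()
            # Runs before load_state_dict() copies tensors in, mirroring the hook above.
            self._register_load_state_dict_pre_hook(self._pre_load_hook)

        def _pre_load_hook(self, state_dict: Dict[str, Any], *args, **kwargs):
            # Any key that was explicitly loaded should survive the next save.
            self.keep_params.update(state_dict.keys())

        def minimal_state_dict(self) -> Dict[str, Any]:
            named = dict(self.named_parameters())
            return {
                k: v
                for k, v in self.state_dict().items()
                if k in self.keep_params or (k in named and named[k].requires_grad)
            }

Here only the adapter.* keys land in a fresh save; after load_state_dict() restores a full checkpoint, the backbone keys are remembered and written out again as well.
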
diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py
index 232c0584..cfe2dc47 100644
--- a/ultravox/training/config_base.py
+++ b/ultravox/training/config_base.py
@@ -2,6 +2,8 @@
 import datetime
 import logging
 import os
+import re
+import sys
 from pathlib import Path
 from typing import Any, Dict, List, Optional

@@ -57,6 +59,9 @@ class TrainConfig:
     device: str = "cuda"
     data_type: str = "bfloat16"
+    # Whether to use FSDP (Fully Sharded Data Parallelism) for training.
+    # Needed for large model training (e.g. 70B+).
+    use_fsdp: bool = False
     # Path to load the model from. Can be local path, HF hub model_id, or W&B artifact
     model_load_dir: Optional[str] = None
     text_model_lora_config: Optional[ultravox_config.LoraConfigSimplified] = None
@@ -70,7 +75,9 @@ class TrainConfig:
     optimizer: str = "adamw_torch"
     num_epochs: int = 1
     max_steps: int = 0
-    val_steps: Optional[int] = None
+    # Run an evaluation every X steps. If smaller than 1, will be interpreted as a ratio of total training steps.
+    val_steps: Optional[float] = None
+    # Save a checkpoint every X steps. If smaller than 1, will be interpreted as a ratio of total training steps.
     save_steps: float = 0
     logging_steps: int = 1
     grad_accum_steps: int = 1
@@ -117,6 +124,10 @@ def __post_init__(self):
             self.exp_name = datetime.datetime.now().strftime("exp--%Y-%m-%d--%H-%M-%S")
         if self.output_dir is None:
             self.output_dir = Path("runs") / self.exp_name
+
+        # HF Pipeline gets tripped up if the path has a "." in it
+        self.output_dir = Path(str(self.output_dir).replace(".", "--"))
+
         if self.logs_dir is None:
             self.logs_dir = self.output_dir / "logs"

@@ -130,3 +141,37 @@ def __post_init__(self):
                 "LayerDrop cannot be used in DDP when encoder is not frozen. Disabling LayerDrop."
             )
             self.disable_layerdrop = True
+
+        if self.use_fsdp and self.save_steps:
+            logging.warning(
+                "FSDP is enabled: Saving checkpoints is going to be extremely slow and results in a full save."
+                " Consider setting save_steps=0."
+            )
+
+        if self.use_fsdp and self.do_eval:
+            logging.warning(
+                "FSDP is enabled: Evaluation is not supported with FSDP. Disabling evaluation."
+            )
+            self.do_eval = False
+
+
+def fix_hyphens(arg: str):
+    return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg)
+
+
+def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig:
+    """
+    Parse the command line arguments and return a TrainConfig object.
+
+    Args:
+        override_sys_args: The command line arguments. If None, sys.argv[1:] is used.
+            This is mainly useful for testing.
+    """
+    args = override_sys_args or sys.argv[1:]
+
+    return simple_parsing.parse(
+        config_class=TrainConfig,
+        config_path=os.path.join(os.path.dirname(__file__), "configs/meta_config.yaml"),
+        add_config_path_arg=True,
+        args=[fix_hyphens(arg) for arg in args],
+    )
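A quick illustration of the two helpers added above. It mirrors the invocation used by the new prefetch test (the tiny hf-internal-testing checkpoints) and assumes a repo checkout so that configs/meta_config.yaml resolves next to config_base.py; the "foo/bar" value is just a placeholder.

    from ultravox.training import config_base

    # Only the flag name (before "=") is rewritten; the value is left untouched.
    assert config_base.fix_hyphens("--text-model=foo/bar") == "--text_model=foo/bar"

    # Parses CLI-style args on top of configs/meta_config.yaml and returns a TrainConfig.
    args = config_base.get_train_args(
        ["--text-model", "hf-internal-testing/tiny-random-LlamaForCausalLM",
         "--audio-model", "hf-internal-testing/tiny-random-WhisperForCausalLM"]
    )
    print(args.text_model, args.use_fsdp)  # use_fsdp defaults to False
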
+ """ + args = override_sys_args or sys.argv[1:] + + return simple_parsing.parse( + config_class=TrainConfig, + config_path=os.path.join(os.path.dirname(__file__), "configs/meta_config.yaml"), + add_config_path_arg=True, + args=[fix_hyphens(arg) for arg in args], + ) diff --git a/ultravox/training/configs/llama_70b.yaml b/ultravox/training/configs/llama_70b.yaml new file mode 100644 index 00000000..4e3a8432 --- /dev/null +++ b/ultravox/training/configs/llama_70b.yaml @@ -0,0 +1,16 @@ +# Training configuration for Ultravox with Meta-Llama-3.1-70B-Instruct and Whisper-Medium +# This configuration is used alongside release_config.yaml + +# TODO: make sure to increase max_duration in mcloud.yaml to 18 hours instead of 6 + +exp_name: "ultravox-v0_4-llama3.1-70B-whisper_m" + +text_model: "meta-llama/Meta-Llama-3.1-70B-Instruct" +audio_model: "openai/whisper-medium" + +batch_size: 5 +# We increase the number of steps by 2x, but with a lower batch_size, we still won't be training on as many samples as the 8B model +# We would revisit this later on when +max_steps: 28800 # x8x5 = 1,152,000 + +do_eval: False # evaluation doesn't support FSDP yet diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py new file mode 100644 index 00000000..87106de0 --- /dev/null +++ b/ultravox/training/helpers/prefetch_weights.py @@ -0,0 +1,53 @@ +from datetime import datetime +from typing import List, Optional + +import huggingface_hub +import transformers + +from ultravox.model import wandb_utils +from ultravox.training import config_base + +ALLOW_PATTERNS = ["*.safetensors", "*.json", "*.txt"] + + +def main(override_sys_args: Optional[List[str]] = None): + start = datetime.now() + print("Downloading weights ...") + + args = config_base.get_train_args(override_sys_args) + + download_weights([args.text_model, args.audio_model], args.model_load_dir) + + end = datetime.now() + print(f"Weights downloaded in {end - start} seconds") + + +def download_weights(model_ids: List[str], model_load_dir: Optional[str] = None): + for model_id in model_ids: + try: + # Download all model files that match ALLOW_PATTERNS + # This is faster than .from_pretrained due to parallel downloads + huggingface_hub.snapshot_download( + repo_id=model_id, allow_patterns=ALLOW_PATTERNS + ) + except huggingface_hub.utils.GatedRepoError as e: + raise e + except huggingface_hub.utils.RepositoryNotFoundError as e: + # We assume that the model is local if it's not found on HF Hub. + # The `.from_pretrained` call will verify the local case. + print( + f"Model {model_id} not found on HF Hub. Skipping download. Error: {e}" + ) + + # A backstop to make sure the model is fully downloaded. 
diff --git a/ultravox/training/helpers/prefetch_weights_test.py b/ultravox/training/helpers/prefetch_weights_test.py
new file mode 100644
index 00000000..19468c6b
--- /dev/null
+++ b/ultravox/training/helpers/prefetch_weights_test.py
@@ -0,0 +1,14 @@
+import transformers
+
+from ultravox.training.helpers import prefetch_weights
+
+TEXT_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
+AUDIO_MODEL = "hf-internal-testing/tiny-random-WhisperForCausalLM"
+
+
+def test_prefetch_weights():
+    prefetch_weights.main(["--text-model", TEXT_MODEL, "--audio-model", AUDIO_MODEL])
+
+    # With local_files_only=True, from_pretrained will throw an error if the weights are not downloaded
+    transformers.AutoModel.from_pretrained(TEXT_MODEL, local_files_only=True)
+    transformers.AutoModel.from_pretrained(AUDIO_MODEL, local_files_only=True)
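train.py (below) wraps its backstop download in ddp_utils.run_on_master_first. That helper is not part of this patch; as a rough mental model only (an assumed barrier-based implementation, not the project's actual code), it behaves roughly like this:

    import contextlib

    import torch.distributed as dist


    @contextlib.contextmanager
    def run_on_master_first(is_master: bool):
        # Rank 0 runs the body first (e.g. downloads weights); the other ranks wait
        # at the barrier and only run the body once rank 0 has finished.
        if is_master:
            yield
            if dist.is_available() and dist.is_initialized():
                dist.barrier()
        else:
            if dist.is_available() and dist.is_initialized():
                dist.barrier()
            yield

This is also why the new comment in train.py warns that the barrier can time out for large models: if rank 0 spends longer than the collective timeout downloading a 70B checkpoint, the waiting ranks give up, hence the separate prefetch_weights step added to mcloud.yaml.
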
diff --git a/ultravox/training/train.py b/ultravox/training/train.py
index ee79934b..6b337752 100644
--- a/ultravox/training/train.py
+++ b/ultravox/training/train.py
@@ -1,19 +1,18 @@
+import contextlib
 import copy
 import dataclasses
 import gc
 import glob
 import logging
 import os
-import re
 import subprocess
-import sys
 from datetime import datetime
 from typing import Dict, List, Optional

+import accelerate
 import datasets as hf_datasets
 import pandas as pd
 import safetensors.torch
-import simple_parsing
 import torch
 import torch.distributed
 import transformers
@@ -30,15 +29,12 @@
 from ultravox.model import wandb_utils
 from ultravox.training import config_base
 from ultravox.training import ddp_utils
+from ultravox.training.helpers import prefetch_weights

 INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000}
 OUTPUT_EXAMPLE = {"text": "Hello, world!"}


-def fix_hyphens(arg: str):
-    return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg)
-
-
 def prepare_dataset(
     train_args: config_base.TrainConfig,
     dataset_names: List[str],
@@ -76,12 +72,7 @@ def main() -> None:
     os.environ["WANDB_LOG_MODEL"] = "checkpoint"
     os.environ["WANDB_PROJECT"] = "ultravox"

-    args = simple_parsing.parse(
-        config_class=config_base.TrainConfig,
-        config_path="ultravox/training/configs/meta_config.yaml",  # base config file
-        add_config_path_arg=True,
-        args=[fix_hyphens(arg) for arg in sys.argv[1:]],
-    )
+    args = config_base.get_train_args()

     transformers.set_seed(args.seed)

@@ -100,9 +91,7 @@ def train(args: config_base.TrainConfig):
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
     is_master = local_rank == 0
-
-    if world_size > 1:
-        torch.distributed.init_process_group(backend="nccl")
+    is_distributed = world_size > 1

     # DDP blows up logging, so this is an attempt to suppress it to only logs from the master process
     logging.basicConfig(level=logging.INFO if is_master else logging.ERROR)
@@ -110,6 +99,17 @@ def train(args: config_base.TrainConfig):
     transformers.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR)
     hf_datasets.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR)

+    if is_distributed:
+        torch.distributed.init_process_group(backend="nccl")
+
+    with ddp_utils.run_on_master_first(is_master):
+        # For larger models, we assume that the weights are already downloaded via prefetch_weights.py.
+        # Otherwise the barrier call can time out.
+        # This call is only here as a backstop in case prefetch_weights.py was not run, for example in a local/test run.
+        prefetch_weights.download_weights(
+            [args.text_model, args.audio_model], args.model_load_dir
+        )
+
     logging.info("Instantiating processor...")
     text_tokenizer: transformers.PreTrainedTokenizerFast = (
         transformers.AutoTokenizer.from_pretrained(args.text_model)
@@ -124,13 +124,21 @@ def train(args: config_base.TrainConfig):
         text_model_id=args.text_model,
         text_model_lora_config=args.text_model_lora_config,
         audio_model_lora_config=args.audio_model_lora_config,
+        torch_dtype=args.data_type,
+        pad_token_id=text_tokenizer.eos_token_id,
     )

     logging.info("Instantiating model...")

-    # Since the model downloads the language model and audio encoder weights, we want one process to finish up
-    # downloading before the others start in order to avoid race conditions.
-    with ddp_utils.run_on_master_first(is_master):
+    model_load_context = (
+        accelerate.init_empty_weights()
+        if args.use_fsdp and not is_master
+        else contextlib.nullcontext()
+    )
+    # If we're using FSDP, we can just initialize the model on the main process
+    # and use sync_model_states to distribute the weights to the other processes.
+    # Otherwise we'd be loading the model on every process, which uses too much CPU memory.
+    with model_load_context:
         model = ultravox_model.UltravoxModel(config)

     processor = ultravox_processing.UltravoxProcessor(
@@ -171,9 +179,10 @@ def train(args: config_base.TrainConfig):
         logging.info(f"Loading model state dict from {args.model_load_dir}")
         load_path = args.model_load_dir
         if wandb_utils.is_wandb_url(load_path):
-            # Download the model from W&B. The main process should do the download while the others wait.
-            with ddp_utils.run_on_master_first(is_master):
-                load_path = wandb_utils.download_model_from_wandb(load_path)
+            # We assume that the weights are already downloaded via prefetch_weights.py
+            # and hence this is just resolving the path. If the weights are not downloaded,
+            # we might see a race condition here when using DDP.
+            load_path = wandb_utils.download_model_from_wandb(load_path)
         if os.path.isdir(load_path):
             load_path = os.path.join(load_path, "model*.safetensors")
             paths = glob.glob(load_path)
@@ -188,13 +197,10 @@ def train(args: config_base.TrainConfig):

     model.print_trainable_parameters()

-    # Move the model to GPU and enable bfloat16
-    dtype = getattr(torch, args.data_type)
-    device = torch.device(args.device, index=local_rank)
-    logging.info(
-        f"Using dtype and device (world_size): {dtype}, {device} ({world_size})"
-    )
-    model.to(device=device, dtype=dtype)
+    if not args.use_fsdp:
+        # Moving to device in FSDP is handled by the Trainer
+        model.to(device=torch.device(args.device, index=local_rank))
+        logging.info(f"Using device (world_size): {model.device} ({world_size})")

     # Prepare dataset, subsetting if needed
     train_dataset: data.IterableDataset
@@ -280,9 +286,9 @@ def train(args: config_base.TrainConfig):
         optim=args.optimizer,
         num_train_epochs=args.num_epochs,
         max_steps=args.max_steps,
-        evaluation_strategy="steps",
+        eval_strategy="steps" if args.val_steps else "no",
         eval_steps=args.val_steps,
-        save_strategy="steps",
+        save_strategy="steps" if args.save_steps else "no",
         save_steps=args.save_steps,
         logging_first_step=True,
         logging_dir=args.logs_dir,
@@ -299,14 +305,19 @@ def train(args: config_base.TrainConfig):
         lr_scheduler_type=args.lr_scheduler,
         warmup_steps=args.lr_warmup_steps,
         weight_decay=args.weight_decay,
-        fp16=dtype == torch.float16,
-        bf16=dtype == torch.bfloat16,
+        # fp16=dtype == torch.float16,
+        # bf16=dtype == torch.bfloat16,
         use_cpu=args.device == "cpu",
         seed=args.seed + local_rank,
         report_to=args.report_logs_to,
         # torch_compile=True,
-        # fsdp="full_shard auto_wrap",
-        # fsdp_transformer_layer_cls_to_wrap='LlamaDecoderLayer',
+        fsdp="full_shard auto_wrap" if args.use_fsdp else "",
+        fsdp_config={
+            "backward_prefetch": "backward_pre",
+            "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+            "state_dict_type": "SHARDED_STATE_DICT",
+            "sync_module_states": "true",
+        },
     ),
 )

@@ -315,19 +326,41 @@ def train(args: config_base.TrainConfig):
     logging.info("Starting training...")
     t_start = datetime.now()
     logging.info(f"train start time: {t_start}")
+
     if args.val_steps:
-        trainer.evaluate()
+        if args.use_fsdp:
+            logging.warning(
+                "FSDP is enabled: Skipping initial validation since model is not initialized."
+            )
+        else:
+            trainer.evaluate()
+
     trainer.train()
+
     t_end = datetime.now()
     logging.info(f"train end time: {t_end}")
     logging.info(f"elapsed: {t_end - t_start}")

-    if is_master:
-        # Saving the model using pipeline to ensure its code is saved
-        pipeline = ultravox_pipeline.UltravoxPipeline(
-            model, tokenizer=text_tokenizer, device=device
-        )
-        pipeline.save_pretrained(args.output_dir)
+    if args.use_fsdp:
+        # For training checkpoints, we want to use SHARDED_STATE_DICT which should be faster,
+        # but for the final save we want FULL_STATE_DICT so it can be serialized properly.
+        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
+
+    # We use both pipeline.save_pretrained and trainer.save_model to save everything.
+    # This is because pipeline.save_pretrained knows how to save the pipeline (code and config),
+    # but it doesn't know how to save FSDP models correctly (the final tensors could be flattened).
+    # On the other hand, trainer.save_model knows how to save FSDP models correctly, but it won't save the pipeline.
+    # Saving FSDP models is already quite slow though, so we don't want to save the model twice.
+    pipeline = ultravox_pipeline.UltravoxPipeline(
+        model, tokenizer=text_tokenizer, device=model.device
+    )
+    old_save_pretrained = model.save_pretrained
+    model.save_pretrained = lambda *_, **__: None  # type: ignore[method-assign]
+    # saves the pipeline code and populates the config
+    pipeline.save_pretrained(args.output_dir)
+    model.save_pretrained = old_save_pretrained  # type: ignore[method-assign]
+
+    # saves the model weights correctly (FSDP or otherwise)
+    trainer.save_model(args.output_dir)


 def evaluate(args: config_base.TrainConfig):