[70B-Part3] FSDP Training (#109)
* use_fsdp option

* re-add the move to(device) when not using FSDP
farzadab authored Sep 16, 2024
1 parent b12be46 commit be8ee6b
Showing 3 changed files with 66 additions and 21 deletions.
24 changes: 18 additions & 6 deletions ultravox/model/ultravox_model.py
@@ -51,7 +51,7 @@ def __init__(self, config: UltravoxConfig):
self.vocab_size = config.vocab_size

self.audio_tower = self._create_audio_tower(config)
- self.multi_modal_projector = UltravoxProjector(config)
+ self.multi_modal_projector = self._create_multi_modal_projector(config)
self.language_model = self._create_language_model(config)

# Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
@@ -195,7 +195,7 @@ def forward(

# B x A/3200 x D
audio_tower_output = self.audio_tower.forward(
- audio_values
+ audio_values.to(self.audio_tower.dtype)
).last_hidden_state
audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)

@@ -272,18 +272,26 @@ def prepare_inputs_for_generation(

return model_input

+ @classmethod
+ def _create_multi_modal_projector(
+     cls, config: UltravoxConfig
+ ) -> "UltravoxProjector":
+     projector = UltravoxProjector(config)
+     projector.to(config.torch_dtype)
+     return projector

@classmethod
def _create_audio_tower(
cls, config: UltravoxConfig
) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
if config.audio_model_id is not None:
if "whisper" in config.audio_model_id is not None:
audio_tower = ModifiedWhisperEncoder.from_pretrained(
- config.audio_model_id
+ config.audio_model_id, torch_dtype=config.torch_dtype
)
else:
audio_tower = transformers.AutoModel.from_pretrained(
- config.audio_model_id
+ config.audio_model_id, torch_dtype=config.torch_dtype
)
else:
if "whisper" in config.audio_config._name_or_path:
@@ -314,14 +322,18 @@ def _create_language_model(
) -> transformers.LlamaForCausalLM:
if config.text_model_id is not None:
language_model = transformers.AutoModelForCausalLM.from_pretrained(
- config.text_model_id, attn_implementation=config._attn_implementation
+ config.text_model_id,
+ attn_implementation=config._attn_implementation,
+ torch_dtype=config.torch_dtype,
)
else:
with transformers.modeling_utils.no_init_weights():
# we only ever use from_config if the weights are retrained, hence initializing is not
# required. This makes model creation quite a bit faster since init on CPU is quite slow.
language_model = transformers.AutoModelForCausalLM.from_config(
- config.text_config, attn_implementation=config._attn_implementation
+ config.text_config,
+ attn_implementation=config._attn_implementation,
+ torch_dtype=config.torch_dtype,
)

language_model = apply_lora(language_model, config.text_model_lora_config)
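The dtype changes in this file follow one pattern: each submodule is created or loaded directly in config.torch_dtype, and inputs are cast to the receiving submodule's dtype right at the call site (audio_values.to(self.audio_tower.dtype)), so the model ends up in the intended precision even when the global model.to(dtype) call is skipped, as it is under FSDP. A minimal, self-contained sketch of that pattern with a toy module (my own illustration, not the repository's code):

import torch
import torch.nn as nn

class TinyProjector(nn.Module):
    """Stand-in for UltravoxProjector: a single linear layer."""

    def __init__(self, dim_in: int = 16, dim_out: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    @property
    def dtype(self) -> torch.dtype:
        # HF models expose .dtype; for a plain nn.Module we read it off the parameters.
        return next(self.parameters()).dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

def create_projector(torch_dtype: torch.dtype = torch.bfloat16) -> TinyProjector:
    # Mirrors _create_multi_modal_projector: build the module, then cast it to the config dtype.
    projector = TinyProjector()
    projector.to(torch_dtype)
    return projector

projector = create_projector()
audio_values = torch.randn(2, 16)  # e.g. float32 features coming out of the data pipeline
# Cast the input to the module's dtype at the call site, as the forward() change does.
hidden = projector(audio_values.to(projector.dtype))
print(hidden.dtype)  # torch.bfloat16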
19 changes: 18 additions & 1 deletion ultravox/training/config_base.py
@@ -59,6 +59,9 @@ class TrainConfig:

device: str = "cuda"
data_type: str = "bfloat16"
+ # Whether to use FSDP (Fully Sharded Data Parallelism) for training
+ # needed for large model training (e.g. 70B+)
+ use_fsdp: bool = False
# Path to load the model from. Can be local path, HF hub model_id, or W&B artifact
model_load_dir: Optional[str] = None
text_model_lora_config: Optional[ultravox_config.LoraConfigSimplified] = None
@@ -72,7 +75,9 @@ class TrainConfig:
optimizer: str = "adamw_torch"
num_epochs: int = 1
max_steps: int = 0
- val_steps: Optional[int] = None
+ # Run an evaluation every X steps. If smaller than 1, it is interpreted as a ratio of total training steps.
+ val_steps: Optional[float] = None
+ # Save a checkpoint every X steps. If smaller than 1, it is interpreted as a ratio of total training steps.
save_steps: float = 0
logging_steps: int = 1
grad_accum_steps: int = 1
@@ -137,6 +142,18 @@ def __post_init__(self):
)
self.disable_layerdrop = True

+ if self.use_fsdp and self.save_steps:
+     logging.warning(
+         "FSDP is enabled: Saving checkpoints is going to be extremely slow and will result in a full save."
+         " Consider setting save_steps=0."
+     )
+
+ if self.use_fsdp and self.do_eval:
+     logging.warning(
+         "FSDP is enabled: Evaluation is not supported with FSDP. Disabling evaluation."
+     )
+     self.do_eval = False


def fix_hyphens(arg: str):
return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg)
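The new comments give val_steps and save_steps ratio semantics: an integer (>= 1) is an absolute step interval, while a float strictly between 0 and 1 is interpreted as a fraction of the total number of training steps, matching how eval_steps/save_steps behave in transformers.TrainingArguments. A small sketch of that interpretation (my own illustration; resolve_interval is a hypothetical helper, not part of this repo):

def resolve_interval(value: float, total_training_steps: int) -> int:
    """Turn a step-interval setting into an absolute number of steps.

    Values strictly between 0 and 1 are treated as a fraction of the total
    training steps, mirroring eval_steps/save_steps in TrainingArguments.
    """
    if 0 < value < 1:
        return max(1, int(value * total_training_steps))
    return int(value)

total_steps = 2000
print(resolve_interval(0.25, total_steps))  # -> 500: evaluate/save every quarter of training
print(resolve_interval(500, total_steps))   # -> 500: an absolute step count is used as-is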
44 changes: 30 additions & 14 deletions ultravox/training/train.py
@@ -123,6 +123,8 @@ def train(args: config_base.TrainConfig):
text_model_id=args.text_model,
text_model_lora_config=args.text_model_lora_config,
audio_model_lora_config=args.audio_model_lora_config,
+ torch_dtype=args.data_type,
+ pad_token_id=text_tokenizer.eos_token_id,
)

logging.info("Instantiating model...")
@@ -178,13 +180,10 @@ def train(args: config_base.TrainConfig):

model.print_trainable_parameters()

- # Move the model to GPU and enable bfloat16
- dtype = getattr(torch, args.data_type)
- device = torch.device(args.device, index=local_rank)
- logging.info(
-     f"Using dtype and device (world_size): {dtype}, {device} ({world_size})"
- )
- model.to(device=device, dtype=dtype)
+ if not args.use_fsdp:
+     # Moving to device in FSDP is handled by the Trainer
+     model.to(device=torch.device(args.device, index=local_rank))
+ logging.info(f"Using device (world_size): {model.device} ({world_size})")

# Prepare dataset, subsetting if needed
train_dataset: data.IterableDataset
@@ -270,9 +269,9 @@ def train(args: config_base.TrainConfig):
optim=args.optimizer,
num_train_epochs=args.num_epochs,
max_steps=args.max_steps,
- evaluation_strategy="steps",
+ eval_strategy="steps" if args.val_steps else "no",
eval_steps=args.val_steps,
- save_strategy="steps",
+ save_strategy="steps" if args.save_steps else "no",
save_steps=args.save_steps,
logging_first_step=True,
logging_dir=args.logs_dir,
@@ -289,14 +288,19 @@
lr_scheduler_type=args.lr_scheduler,
warmup_steps=args.lr_warmup_steps,
weight_decay=args.weight_decay,
- fp16=dtype == torch.float16,
- bf16=dtype == torch.bfloat16,
+ # fp16=dtype == torch.float16,
+ # bf16=dtype == torch.bfloat16,
use_cpu=args.device == "cpu",
seed=args.seed + local_rank,
report_to=args.report_logs_to,
# torch_compile=True,
- # fsdp="full_shard auto_wrap",
- # fsdp_transformer_layer_cls_to_wrap='LlamaDecoderLayer',
+ fsdp="full_shard auto_wrap" if args.use_fsdp else "",
+ fsdp_config={
+     "backward_prefetch": "backward_pre",
+     "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+     "state_dict_type": "SHARDED_STATE_DICT",
+     "sync_module_states": "true",
+ },
),
)

@@ -305,13 +309,25 @@
logging.info("Starting training...")
t_start = datetime.now()
logging.info(f"train start time: {t_start}")

if args.val_steps:
-     trainer.evaluate()
+     if args.use_fsdp:
+         logging.warning(
+             "FSDP is enabled: Skipping initial validation since the model is not initialized."
+         )
+     else:
+         trainer.evaluate()

trainer.train()
t_end = datetime.now()
logging.info(f"train end time: {t_end}")
logging.info(f"elapsed: {t_end - t_start}")

+ if args.use_fsdp:
+     # For training checkpoints, we want to use SHARDED_STATE_DICT which should be faster,
+     # but for the final save we want FULL_STATE_DICT so it can be serialized properly.
+     trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

# We use both pipeline.save_pretrained and trainer.save_model to save everything.
# This is because pipeline.save_pretrained knows how to save the pipeline (code and config),
# but it doesn't know how to save FSDP models correctly (the final tensors could be flattened).
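Isolated from the rest of train(), the FSDP wiring amounts to enabling full sharding with transformer-based auto-wrapping and sharded checkpoints through TrainingArguments, then switching the accelerate FSDP plugin to FULL_STATE_DICT just before the final save so the weights are serialized in a normal, un-flattened form. A hedged sketch of just the TrainingArguments part (dummy output_dir and batch size; assumes a transformers version that accepts the same fsdp_config keys used in the diff):

import transformers

use_fsdp = True  # stands in for args.use_fsdp
training_args = transformers.TrainingArguments(
    output_dir="/tmp/fsdp_sketch",   # dummy path for illustration
    per_device_train_batch_size=1,   # dummy value
    fsdp="full_shard auto_wrap" if use_fsdp else "",
    fsdp_config={
        "backward_prefetch": "backward_pre",
        "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",  # wrap per transformer block (the model's no_split_modules, per the comment in ultravox_model.py)
        "state_dict_type": "SHARDED_STATE_DICT",       # fast, sharded checkpoints during training
        "sync_module_states": "true",                  # broadcast weights from rank 0 at startup
    },
)
print(training_args.fsdp)  # e.g. [FULL_SHARD, AUTO_WRAP] once parsed by TrainingArguments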
