From 7a334c1a56b8e9fcdd87034b69f699e20b72aea7 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Sat, 27 Jul 2024 21:42:21 +0000
Subject: [PATCH 1/4] add custom lr schedule

---
 open_diloco/train_fsdp.py | 26 ++++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index fdb1c5f..402b4c0 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -7,6 +7,7 @@
 """
 
 from functools import partial
+import math
 import os
 import time
 from contextlib import nullcontext
@@ -28,7 +29,6 @@
     DataCollatorForLanguageModeling,
     LlamaConfig,
     LlamaForCausalLM,
-    get_cosine_schedule_with_warmup,
 )
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -39,7 +39,7 @@
 from torch.distributed import broadcast_object_list
 from open_diloco.ckpt_utils import load_checkpoint, save_checkpoint
 from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
-
+from torch.optim.lr_scheduler import LambdaLR
 from hivemind.dht.dht import DHT
 from hivemind.utils.networking import log_visible_maddrs
 
@@ -189,6 +189,27 @@ def get_model(config: Config) -> LlamaForCausalLM:
     return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
 
 
+def _get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps):
+    lambda_lr = partial(
+        _get_cosine_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=0.5,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
@@ -282,6 +303,7 @@ def scheduler_fn(opt):
             opt,
             num_warmup_steps=config.warmup_steps,
             num_training_steps=config.total_steps,
+            num_inner_steps=config.hv.local_steps,
         )
 
     if config.hv is not None:

From 852b3c533f44bb90e112ae7fc90d46355c3d8f7c Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Sat, 27 Jul 2024 22:03:37 +0000
Subject: [PATCH 2/4] add warmup steps

---
 open_diloco/train_fsdp.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index 402b4c0..82fd0bf 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -105,6 +105,7 @@ class HvConfig(BaseConfig):
     skip_load_from_peers: bool = False
     world_rank: int
     galaxy_size: int
+    warmup_outerstep: int = 10
 
     @model_validator(mode="before")
     def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -190,8 +191,18 @@ def get_model(config: Config) -> LlamaForCausalLM:
 
 
 def _get_cosine_schedule_with_warmup_lr_lambda(
-    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_inner_steps: int,
+    warmup_outerstep: int | None,
+    num_cycles: float,
+    min_lr_rate: float = 0.0,
 ):
+    if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
+        return 0
+
     if current_step < num_warmup_steps:
         return float(current_step) / float(max(1, num_warmup_steps))
     progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
@@ -200,11 +211,13 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
     return max(0, factor)
 
 
-def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps):
+def get_cosine_schedule_with_warmup(optimizer, config: Config):
     lambda_lr = partial(
         _get_cosine_schedule_with_warmup_lr_lambda,
-        num_warmup_steps=num_warmup_steps,
-        num_training_steps=num_training_steps,
+        num_warmup_steps=config.warmup_steps,
+        num_training_steps=config.total_steps,
+        num_inner_steps=config.hv.local_steps,
+        warmup_outerstep=config.hv.warmup_outerstep,
         num_cycles=0.5,
     )
     return LambdaLR(optimizer, lambda_lr, -1)
@@ -301,9 +314,7 @@ def train(config: Config):
     def scheduler_fn(opt):
         return get_cosine_schedule_with_warmup(
             opt,
-            num_warmup_steps=config.warmup_steps,
-            num_training_steps=config.total_steps,
-            num_inner_steps=config.hv.local_steps,
+            config=config,
         )
 
     if config.hv is not None:

From ad82e143cdc2f6b22a418f78a32b65bcd30fa6a2 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Fri, 2 Aug 2024 14:29:34 +0000
Subject: [PATCH 3/4] do not update lr scheduler during warmup

---
 open_diloco/train_fsdp.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index 82fd0bf..a6699e5 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -200,7 +200,11 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
     num_cycles: float,
     min_lr_rate: float = 0.0,
 ):
-    if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
+    if (
+        warmup_outerstep is not None
+        and current_step > num_warmup_steps
+        and current_step % num_inner_steps < warmup_outerstep
+    ):
         return 0
 
     if current_step < num_warmup_steps:

From 1869090440f13e95c3f2e4b4f7266dfe015f8093 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Fri, 2 Aug 2024 14:38:58 +0000
Subject: [PATCH 4/4] do not update lr scheduler during warmup

---
 open_diloco/train_fsdp.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index a6699e5..bc96950 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -200,15 +200,12 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
     num_cycles: float,
     min_lr_rate: float = 0.0,
 ):
-    if (
-        warmup_outerstep is not None
-        and current_step > num_warmup_steps
-        and current_step % num_inner_steps < warmup_outerstep
-    ):
-        return 0
-
     if current_step < num_warmup_steps:
         return float(current_step) / float(max(1, num_warmup_steps))
+
+    if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
+        return 0
+
     progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
     factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
     factor = factor * (1 - min_lr_rate) + min_lr_rate
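
Illustration (not part of the patches above): the standalone sketch below re-implements the PATCH 4 version of the lambda outside the training script to show the schedule it produces with get_cosine_schedule_with_warmup. The hyperparameter values (warmup_steps=10, total_steps=100, local_steps=20, warmup_outerstep=3) are illustrative assumptions, not values taken from the repository's configs.

# Sketch of the learning-rate schedule after PATCH 4 (illustrative values only).
import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def lr_lambda(current_step, *, num_warmup_steps, num_training_steps,
              num_inner_steps, warmup_outerstep, num_cycles=0.5, min_lr_rate=0.0):
    # 1) Linear warmup over the first num_warmup_steps optimizer steps.
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    # 2) Hold the LR at zero for the first warmup_outerstep inner steps
    #    after every outer (DiLoCo) synchronization.
    if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
        return 0
    # 3) Otherwise, standard cosine decay toward min_lr_rate.
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    return max(0, factor * (1 - min_lr_rate) + min_lr_rate)


param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=4e-4)
scheduler = LambdaLR(
    opt,
    partial(lr_lambda, num_warmup_steps=10, num_training_steps=100,
            num_inner_steps=20, warmup_outerstep=3),
)

for step in range(100):
    opt.step()
    scheduler.step()
    if step % 20 < 4:  # print around each outer-step boundary
        print(f"step {step:3d}  lr {scheduler.get_last_lr()[0]:.2e}")

Note the ordering in PATCH 4: the linear-warmup branch is checked before the outer-step hold, so the hold only applies once the initial warmup is over; PATCH 3 achieved the same effect with an explicit current_step > num_warmup_steps guard inside the condition.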