Learning rate restart broken with Nanoset? #233
Hey, thanks for opening the issue. Can you add the error message that you get and the log?
Here they are:
(switched to .txt due to GitHub constraints)
Hi! About the "0 steps remaining" in this issue, see here:

nanotron/src/nanotron/helpers.py, lines 694 to 698 in 97c13b0
cc @zzhhjjj, maybe you can take a look at this (I attached a screenshot of the part that shows two different LRs despite having the same lr_schedule in the config and resuming from a checkpoint).
I think you are correct. I'll take a look. I remember seeing the same issue before. A temporary bypass would be to modify the metafile by hand.
Running into the same issue. After resuming training with Nanoset, the LR has the wrong value and stays constant.
Added some extra logging to see whether the LR scheduler is initialized correctly:
The LR scheduler is initialized correctly before the first `self.lr_scheduler.step()` is called. However, this gets overwritten in the next iteration by a number that doesn't make sense. Could you be explicit about what your temporary bypass would look like, @zzhhjjj? What needs to be edited? Regarding the other bug with the number-of-remaining-steps calculation: I don't think it has any effect, as that information is not used anywhere by Nanoset as far as I can see.
On our side the problem resolved itself (it hasn't happened again for weeks now). As far as I remember, we re-updated the repo just before, so it may be worth trying.

On 1 Nov 2024 at 11:29, Faton ***@***.***> wrote:
Added some extra logging to see whether the LR scheduler is initialized correctly:
[0434:0]:11/01/2024 11:05:18 [INFO|DP=0|PP=0|TP=0|lrdn0434]: Learning rate scheduler state: {'base_lrs': [0.0003, ... repeated for each param group ...], 'last_epoch': 12000, 'verbose': False, '_step_count': 12001, '_get_lr_called_within_step': False, '_last_lr': [3.11549437145356e-05, ... repeated ...], 'lr_lambdas': [{}, ... repeated ...]}
[0434:0]:11/01/2024 11:05:18 [INFO|DP=0|PP=0|TP=0|lrdn0434]: iteration: 12001 / 17500 | consumed_tokens: 25.2G | elapsed_time_per_iteration_ms: 28.2K | tokens_per_sec: 74.5K | tokens_per_sec_per_gpu: 2.33K | global_batch_size: 1.02K | lm_loss: 2.3 | lr: 0.000289 | model_tflops_per_gpu: 26.7 | hardware_tflops_per_gpu: 26.7 | grad_norm: 0.15 | cuda_memory_allocated: 8.86G | cuda_max_memory_reserved: 34.1G | hd_total_memory_tb: 9.25G | hd_used_memory_tb: 8.21G | hd_free_memory_tb: 1.03G
[0434:0]:11/01/2024 11:05:26 [INFO|DP=0|PP=0|TP=0|lrdn0434]: Learning rate scheduler state: {'base_lrs': [0.0003, ... repeated ...], 'last_epoch': 12001, 'verbose': False, '_step_count': 12002, '_get_lr_called_within_step': False, '_last_lr': [0.00028892609384716615, ... repeated ...], 'lr_lambdas': [{}, ... repeated ...]}
I'm resuming from a checkpoint saved at step 12000.
Training has only a single data stage.
Total train steps in config is set to 17500.
The LR scheduler is initialized correctly before the first `self.lr_scheduler.step()` is called. It has the correct `_last_lr` from the previous training run.
However, this gets overwritten in the next iteration by a number that doesn't make sense. I'm having a hard time identifying where this would be happening in the code. I'll try and step through it line by line some other day and take a closer look.
Could you be explicit about what your temporary bypass would look like @zzhhjjj ? What needs to be edited?
Regarding the other bug with number of remaining steps calculation: I don't think it has any effect as that information is not used anywhere by nanoset as far as I can see.
I have been using the latest commit.
I have been debugging this issue and have found what I believe is causing the faulty learning rates when restarting training. It happens when the LR scheduler is being built, due to what I think is a mistaken assumption in Nanotron. The issue:
```python
def lr_lambda(current_step: int, initial_lr: float):
    """
    current_step: current training step
    initial_lr: the learning rate of a parameter group
    """
```
However, it should instead say "the *initial* learning rate of a parameter group". The `initial_lr` parameter is used as a direct substitute for what most standard-parametrization LR scheduler formulas refer to as η_max (the maximum learning rate during training). The issue, however, is that the optimizer's `param_group["lr"]` (which gets passed as `initial_lr`) is only equal to the initial `learning_rate` when you start a training run from scratch. Whenever you resume a training run, the `lr` key in the optimizer's parameter groups holds the learning rate from the iteration at which the checkpoint was saved, not the `learning_rate` specified in the YAML config.
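To make the failure mode concrete, here is a minimal pure-Python sketch (no torch, and a hypothetical linear-decay schedule rather than Nanotron's actual formula) of what goes wrong when the lambda is rebuilt from the checkpointed `param_group["lr"]`:

```python
# Pure-Python sketch of the bug (no torch). `lr_lambda` here is a
# hypothetical linear decay standing in for Nanotron's actual schedule;
# the point is only where `initial_lr` is sourced from.

def lr_lambda(current_step: int, initial_lr: float, total_steps: int = 17500) -> float:
    """Decay linearly from initial_lr at step 0 down to 0 at total_steps."""
    return initial_lr * (1 - current_step / total_steps)

config_lr = 3e-4  # learning_rate from the YAML config

# Fresh run: param_group["lr"] == config_lr, so the schedule is correct.
correct_lr = lr_lambda(12001, initial_lr=config_lr)

# Resume at step 12000: the checkpointed optimizer's param_group["lr"]
# holds the already-decayed LR; rebuilding the lambda from it applies
# the decay twice.
checkpointed_lr = lr_lambda(12000, initial_lr=config_lr)
wrong_lr = lr_lambda(12001, initial_lr=checkpointed_lr)

print(f"correct: {correct_lr:.3e}  wrong: {wrong_lr:.3e}")
```

With this toy decay the resumed LR comes out too low; depending on where the resume lands in a real warmup/decay schedule, the value can be off in either direction, which is consistent with the schedule simply no longer matching the configured one.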
nanotron/src/nanotron/helpers.py, lines 167 to 168 in 51ca40b:

```python
for param_group in optimizer.get_base_optimizer().param_groups:
    lr_lambdas.append(get_lr_lambda_for_param_group(lr=param_group["lr"]))
```
Whenever a user restarts a training run, `initial_lr` is therefore initialized incorrectly.

To patch this and resume correctly from a checkpoint, we need to pass `lr_scheduler_args.learning_rate` or `param_groups["initial_lr"]` as the argument instead. Since `param_groups["initial_lr"]` does not exist when the optimizer is first initialized, we pass `lr_scheduler_args.learning_rate`:
```python
for param_group in optimizer.get_base_optimizer().param_groups:
    lr_lambdas.append(get_lr_lambda_for_param_group(lr=lr_scheduler_args.learning_rate))
```
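A self-contained sketch (again pure Python with a stand-in linear-decay schedule, not Nanotron's exact formula) showing why sourcing the base LR from the config value makes the resumed schedule match an uninterrupted run:

```python
# Sketch of the fix: rebuild the lambda from the configured learning rate
# (what lr_scheduler_args.learning_rate holds), never from param_group["lr"].
# The schedule below is a hypothetical linear decay used only for illustration.

def lr_lambda(current_step: int, initial_lr: float, total_steps: int = 17500) -> float:
    return initial_lr * (1 - current_step / total_steps)

config_lr = 3e-4                               # lr_scheduler_args.learning_rate
checkpointed_lr = lr_lambda(12000, config_lr)  # what param_group["lr"] holds at resume

uninterrupted = lr_lambda(12001, initial_lr=config_lr)   # never-stopped run
resumed_fixed = lr_lambda(12001, initial_lr=config_lr)   # rebuild from config
resumed_buggy = lr_lambda(12001, initial_lr=checkpointed_lr)  # rebuild from ckpt lr

# The fixed rebuild continues the original schedule seamlessly.
assert resumed_fixed == uninterrupted
assert resumed_buggy != uninterrupted
```

Note that the `param_groups["initial_lr"]` alternative mentioned above relies on the same idea: PyTorch schedulers stash the original LR under `initial_lr` in each param group at construction time, precisely so it survives later mutations of `lr`.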
This issue is quite subtle and hard to detect if you restart training from an early checkpoint. It becomes more and more noticeable as the gap grows between the learning rate at a given checkpoint and the initial learning rate.
I'd advise anyone who has done large training runs with lots of resumptions from checkpoints to double-check their logs and see whether their learning rates actually corresponded to what they expected and specified. @eliebak (assuming you haven't been using µTransfer)
I thought it was fixed, but discovered it's still not. If you are training with standard parametrization as opposed to µ-parametrization (as most of us are), I would apply my hacky fix while awaiting a proper one.
Ok, thanks a lot! cc @NouamaneTazi or @xrsrke (as it's working with spectral µP?) if you can take a look!
Resuming training from a checkpoint works perfectly with tokenization on the fly, but breaks when using Nanoset: training restarts with a different LR, which is not the same as the one in lr_schedule.pt.
We also have two additional issues that are likely connected:
Training tested with this configuration: