Learning rate restart broken with Nanoset? #233

Open
Pclanglais opened this issue Sep 25, 2024 · 13 comments
Labels
bug Something isn't working

Comments

@Pclanglais

Resuming training from a checkpoint works perfectly with on-the-fly tokenization, but breaks when using Nanoset: training restarts with a different lr, which does not match the one stored in lr_schedule.pt.
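
As a quick sanity check, the lr stored in the checkpoint can be compared with the lr printed in the first resumed iteration. The snippet below is only a sketch: it assumes the scheduler state is a torch-saved state_dict and uses a hypothetical checkpoint path; the actual file name and layout may differ between Nanotron versions.

```python
import torch

# Hypothetical path: adjust the step directory and file name to the actual
# checkpoint layout used by your Nanotron version.
state = torch.load(
    "checkpoints_marianne_300m_pretrain_mini/1500/lr_scheduler.pt",
    map_location="cpu",
)

# Standard PyTorch LR-scheduler state_dict keys; compare _last_lr with the lr
# logged in the first iteration after restarting.
print(state["last_epoch"], state["_last_lr"])
```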

We also have two additional issues that are likely connected:

  • Loading the dataset of a different data stage produces several anomalous messages (same number of tokens as the previous one, 0 steps remaining). Based on this information, it is not clear whether the dataset is properly loaded at all.
  • Training continues even when there are no tokens remaining (probably looping over the past tokens?).

Training tested with this configuration:

```yaml
checkpoints:
  checkpoint_interval: 1500  # Adjusted to save checkpoints less frequently
  checkpoints_path: checkpoints_marianne_300m_pretrain_mini
  checkpoints_path_is_shared_file_system: false
  resume_checkpoint_path: checkpoints_marianne_300m_pretrain_mini
  save_initial_state: false
data_stages:
- data:
    dataset:
      dataset_folder: /lustre/fsn1/projects/rech/fmr/uft12cr/mini_corpus_base_300m_tokenized
    num_loading_workers: 96
    seed: 42
  name: Base corpus
  start_training_step: 1
- data:
    dataset:
      dataset_folder: /lustre/fsn1/projects/rech/fmr/uft12cr/mini_corpus_annealing_300m_tokenized
    num_loading_workers: 96
    seed: 42
  name: Annealing corpus
  start_training_step: 4501
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: pretrain
  run: marianne_3b_pretrain_%date_%jobid
  seed: 42
  step: null
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    hidden_size: 960
    initializer_range: 0.02
    intermediate_size: 2560
    is_llama_config: true
    max_position_embeddings: 4096
    num_attention_heads: 15
    num_hidden_layers: 32
    num_key_value_heads: 5
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 65536
    rope_theta: 500000 
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.003
    lr_decay_starting_step: 1501  # Start decay after warmup
    lr_decay_steps: 6000  # Decay over the remaining 80% of training
    lr_decay_style: cosine
    lr_warmup_steps: 1500  # 20% warmup
    lr_warmup_style: linear
    min_decay_lr: 0.0001
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 1
parallelism:
  dp: 32
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  tp: 1
  tp_linear_async_communication: false
  tp_mode: ALL_REDUCE
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: /lustre/fswork/projects/rech/fmr/uft12cr/tokenizer/pleias_300m_65k_tokenizer
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 2
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 8
  sequence_length: 4096
  train_steps: 7500
  val_check_interval: -1
```
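
For reference, a minimal sketch of the schedule this config describes (linear warmup followed by cosine decay). It is not Nanotron's code and the exact step indexing may differ by one, but it is enough to sanity-check what lr a resumed run should report:

```python
import math

# Values copied from the config above.
LR, MIN_LR = 0.003, 0.0001
WARMUP_STEPS, DECAY_START, DECAY_STEPS = 1500, 1501, 6000

def expected_lr(step: int) -> float:
    if step < WARMUP_STEPS:        # linear warmup
        return LR * step / WARMUP_STEPS
    if step < DECAY_START:         # brief plateau between warmup and decay
        return LR
    progress = min((step - DECAY_START) / DECAY_STEPS, 1.0)
    return MIN_LR + (LR - MIN_LR) * 0.5 * (1.0 + math.cos(math.pi * progress))

# A run interrupted after ~300 warmup steps should resume around:
print(expected_lr(300))   # 0.0006, still on the warmup ramp
```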
@eliebak
Contributor

eliebak commented Sep 25, 2024

Hey, thanks for opening the issue. Can you add the error message that you get and the logs?

@Pclanglais
Author

Pclanglais commented Sep 25, 2024

Here they are:

  • start_run.out: 300 initial steps interrupted.
  • restart_run.out: 100 more steps with the wrong lr.

(switched to .txt due to GitHub constraints)

start_run.txt

restart_run.txt

lr_scheduler.txt

@pchizhov

Hi! About the "0 steps remaining" in this issue — here is the relevant code:

```python
if metadata.last_train_step > stage.start_training_step:
    # NOTE: if the last_train_step is larger than the start_training_step of the current stage,
    # it means that the training has already passed this stage
    # so there is no remaining steps
    return 0
```
There seems to be a bug: it reports that 0 steps are remaining whenever the current training step is larger than the first step of the stage. However, the stage may not be finished yet: in our case, the stage has 100,000 steps in total and we are trying to restart from step 42501.
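
For illustration, here is a hedged sketch of what the check presumably intends (not actual Nanotron code; stage_end_step is a hypothetical argument standing in for the start_training_step of the next stage, or the total train_steps for the last stage):

```python
def remaining_train_steps_in_stage(metadata, stage, stage_end_step):
    # stage_end_step is hypothetical: the first step that no longer belongs to this stage
    if metadata.last_train_step >= stage_end_step - 1:
        # the whole stage has already been consumed
        return 0
    # resuming mid-stage (e.g. at step 42501 inside a 100,000-step stage):
    # count only the steps that are actually left instead of returning 0
    first_unseen_step = max(metadata.last_train_step + 1, stage.start_training_step)
    return stage_end_step - first_unseen_step
```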

@eliebak
Contributor

eliebak commented Sep 26, 2024

cc @zzhhjjj, maybe you can take a look at this? (I screenshotted the part that shows two different lrs despite having the same lr_schedule in the config and resuming from the checkpoint.)
Screenshot 2024-09-26 at 15 46 06
Screenshot 2024-09-26 at 15 45 58

@zzhhjjj
Collaborator

zzhhjjj commented Sep 26, 2024

I think you are correct. I'll take a look. I remember seeing the same issue before. A temporary bypass would be to modify the metafile by hand.

@Lauler

Lauler commented Nov 1, 2024

Running into the same issue. After resuming training with Nanoset, the LR has the wrong value and stays constant.

@Lauler

Lauler commented Nov 1, 2024

Added some extra logging to see whether the LR scheduler is initialized correctly:

```
[0434:0]:11/01/2024 11:05:18 [INFO|DP=0|PP=0|TP=0|lrdn0434]: Learning rate scheduler state: {'base_lrs': [0.0003, ... repeated for every param group ...], 'last_epoch': 12000, 'verbose': False, '_step_count': 12001, '_get_lr_called_within_step': False, '_last_lr': [3.11549437145356e-05, ... repeated ...], 'lr_lambdas': [{}, ... repeated ...]}
[0434:0]:11/01/2024 11:05:18 [INFO|DP=0|PP=0|TP=0|lrdn0434]: iteration: 12001 / 17500 | consumed_tokens: 25.2G | elapsed_time_per_iteration_ms: 28.2K | tokens_per_sec: 74.5K | tokens_per_sec_per_gpu: 2.33K | global_batch_size: 1.02K | lm_loss: 2.3 | lr: 0.000289 | model_tflops_per_gpu: 26.7 | hardware_tflops_per_gpu: 26.7 | grad_norm: 0.15 | cuda_memory_allocated: 8.86G | cuda_max_memory_reserved: 34.1G | hd_total_memory_tb: 9.25G | hd_used_memory_tb: 8.21G | hd_free_memory_tb: 1.03G
[0434:0]:11/01/2024 11:05:26 [INFO|DP=0|PP=0|TP=0|lrdn0434]: Learning rate scheduler state: {'base_lrs': [0.0003, ... repeated ...], 'last_epoch': 12001, 'verbose': False, '_step_count': 12002, '_get_lr_called_within_step': False, '_last_lr': [0.00028892609384716615, ... repeated ...], 'lr_lambdas': [{}, ... repeated ...]}
```
  • I'm resuming from a checkpoint saved at step 12000.
  • Training has only a single data stage.
  • Total train steps in config is set to 17500.

The LR scheduler is initialized correctly before the first self.lr_scheduler.step() is called. It has the correct _last_lr from the previous training run (3.11549437145356e-05).

However, this gets overwritten in the next iteration by a number that doesn't make sense (0.0002889). I'm having a hard time identifying where this would be happening in the code. I'll try and step through it line by line some other day and take a closer look.

Could you be explicit about what your temporary bypass would look like, @zzhhjjj? What needs to be edited?

Regarding the other bug with the remaining-steps calculation: I don't think it has any effect, as that information is not used anywhere by Nanoset as far as I can see.

@Pclanglais
Author

Pclanglais commented Nov 1, 2024 via email

@Lauler

Lauler commented Nov 1, 2024

I have been using the latest commit 51ca40bc5e1b1f5dcb55eaeb0b6f86dda03f3979 of the repo when running these experiments.

@Lauler

Lauler commented Nov 17, 2024

I have been debugging this issue and have found what I believe is causing the faulty learning rates when restarting training. It happens when the LR scheduler is built, due to what I think is a mistaken assumption in Nanotron.

The issue: initial_lr is incorrectly initialized in lr_lambda when restarting

LambdaLR in PyTorch sets the learning rate of an optimizer's parameter groups to the initial learning rate times a multiplicative factor (specified by a given function).

The function lr_lambda in Nanotron has an argument initial_lr which is described as "the learning rate of a parameter group".

```python
def lr_lambda(current_step: int, initial_lr: float):
    """
    current_step: current training step
    initial_lr: the learning rate of a parameter group
    """
```

However, it should instead be "the initial learning rate of a parameter group".

The initial_lr parameter is being used as a direct substitute for what most standard-parametrization LR scheduler formulas refer to as η_max (the maximum learning rate during training). The issue, however, is that the optimizer's param_group["lr"] (which gets passed as initial_lr) is only equal to the initial learning_rate when you start a training run from scratch. Whenever you resume a training run, the lr key in the optimizer's parameter groups corresponds to the learning rate at the iteration the checkpoint was saved, as opposed to the learning_rate specified in the yaml config.

```python
for param_group in optimizer.get_base_optimizer().param_groups:
    lr_lambdas.append(get_lr_lambda_for_param_group(lr=param_group["lr"]))
```

Whenever a user restarts a training run, initial_lr is therefore initialized incorrectly.
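
To make the mechanism concrete, here is a minimal, self-contained sketch (toy schedule and toy numbers, not Nanotron's actual code or config) of how a LambdaLR whose lambda closes over param_group["lr"] recomputes a wrong value on the first step after resuming, even though load_state_dict restored the correct state:

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR

PEAK_LR, MIN_LR, TOTAL_STEPS = 3e-4, 3e-5, 17_500  # toy values

def build_scheduler(optimizer):
    # Mirrors the pattern above: the lambda captures param_group["lr"] at build
    # time and divides by it so that LambdaLR can multiply by base_lr again.
    initial_lr = optimizer.param_groups[0]["lr"]

    def factor(step: int) -> float:
        cos = 0.5 * (1.0 + math.cos(math.pi * min(step / TOTAL_STEPS, 1.0)))
        target_lr = MIN_LR + (initial_lr - MIN_LR) * cos  # cosine decay
        return target_lr / initial_lr                     # normalization for LambdaLR

    return LambdaLR(optimizer, factor)

params = [torch.nn.Parameter(torch.zeros(1))]

# Original run: param_group["lr"] equals the configured peak when the scheduler is built.
opt = torch.optim.AdamW(params, lr=PEAK_LR)
sched = build_scheduler(opt)
for _ in range(12_000):
    opt.step()
    sched.step()
lr_at_checkpoint = opt.param_groups[0]["lr"]  # well below PEAK_LR by now
ckpt = {"optimizer": opt.state_dict(), "lr_scheduler": sched.state_dict()}

# Resume: the optimizer state (including its decayed lr) is restored *before*
# the scheduler is rebuilt, so the new lambda captures the checkpoint lr.
opt2 = torch.optim.AdamW(params, lr=PEAK_LR)
opt2.load_state_dict(ckpt["optimizer"])
sched2 = build_scheduler(opt2)              # wrong initial_lr baked into the lambda
sched2.load_state_dict(ckpt["lr_scheduler"])
opt2.step()
sched2.step()
print(lr_at_checkpoint, opt2.param_groups[0]["lr"])  # the resumed lr jumps off the original curve
```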

To patch this and resume correctly from a checkpoint, we need to pass lr_scheduler_args.learning_rate or param_group["initial_lr"] as an argument instead. Since param_group["initial_lr"] does not exist when the optimizer is first initialized, we pass lr_scheduler_args.learning_rate:

```python
for param_group in optimizer.get_base_optimizer().param_groups:
    lr_lambdas.append(get_lr_lambda_for_param_group(lr=lr_scheduler_args.learning_rate))
```

This issue is quite subtle and hard to detect if you restart training from an early checkpoint. However, it becomes more and more noticeable as the gap grows between the learning rate at the resumed checkpoint and the initial learning rate.

I'd advise anyone who has done large training runs with many resumptions from checkpoints to double-check their logs and see whether their learning rates actually corresponded to what they expected and specified. @eliebak

(assuming you haven't been using µTransfer)

@eliebak
Contributor

eliebak commented Nov 25, 2024

Hey, sorry, I just saw this issue. It seems to be fixed in #245 according to #243?

@Lauler

Lauler commented Nov 27, 2024

I thought it was fixed, but discovered it's still not. If you are training with standard parametrization as opposed to µ-parametrization (as most of us are), I would apply my hacky fix while awaiting a proper one.

@eliebak
Contributor

eliebak commented Nov 28, 2024

OK, thanks a lot! cc @NouamaneTazi or @xrsrke (since it seems to work with spectral µP?) if you can take a look!

@eliebak eliebak added the bug Something isn't working label Nov 28, 2024