out of memory for continuing pretraining llama3-8B #161

Open · ckzbullbullet opened this issue May 3, 2024 · 5 comments

@ckzbullbullet
I am trying to use the framework to continue pretraining Llama 3 8B. I have converted the HF checkpoint into nanotron format, and the generated tokens look reasonable.

I use the following settings to train the model, but I get an OOM error with 8 GPUs. The same setup previously worked fine with HF accelerate plus DeepSpeed ZeRO-1 and flash-attention:

model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 128000
    eos_token_id: 128001
    hidden_act: silu
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 14336
    is_llama_config: true
    max_position_embeddings: 8192
    num_attention_heads: 32
    num_hidden_layers: 32
    num_key_value_heads: 8
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: false
    use_cache: true
    vocab_size: 128256
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 8.0e-07
    lr_decay_starting_step: null
    lr_decay_steps: 100
    lr_decay_style: cosine
    lr_warmup_steps: 100
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.999
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.1
  zero_stage: 1
parallelism:
  dp: 8
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  tp: 1
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 1
  sequence_length: 8192
  train_steps: 200
  val_check_interval: -1

@zzhhjjj (Collaborator)

zzhhjjj commented May 6, 2024

Hi,

Thanks for your question. Here is my solution:

general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: test
  run: llama
  seed: 42
  step: 1
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.2
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 128000
    eos_token_id: 128001
    hidden_act: silu
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 14336
    is_llama_config: true
    max_position_embeddings: 8192
    num_attention_heads: 32
    num_hidden_layers: 32
    num_key_value_heads: 8
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: false
    use_cache: true
    vocab_size: 128256
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_starting_step: null
    lr_decay_steps: 198
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 2
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  tp: 4
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 32
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 1
  sequence_length: 8192
  train_steps: 30
  val_check_interval: -1

You can also try tp=8, dp=1.

You should also set theta = 500000.0 for the RotaryEmbedding and interleaved = False.
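
For context, here is a minimal sketch of the rotary-frequency math and how theta enters it. This is generic RoPE, not nanotron's RotaryEmbedding class; the point is just that Llama 3 uses a base of 500000 instead of the usual 10000:

import torch

def rope_angles(head_dim: int, seq_len: int, theta: float = 500000.0):
    # Inverse frequencies: one per pair of head dimensions.
    # A larger theta stretches the wavelengths, which Llama 3 uses
    # for its 8192-token context.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float()
    angles = torch.outer(positions, inv_freq)   # (seq_len, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)

cos, sin = rope_angles(head_dim=128, seq_len=8192, theta=500000.0)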

Please let me know if you have any other questions.

@ckzbullbullet (Author)

ckzbullbullet commented May 6, 2024

Thanks for the reply.
I have implemented loading the pretrained Llama-3-8B weights and it seems to work.
My actual question is: when I use nanotron with pure DP (dp=8, pp=1, tp=1), I get OOM even with micro_batch_size=1. However, I did not get OOM when using accelerate + DeepSpeed ZeRO-1 with similar settings.
What could be the reason that nanotron uses more memory for pure DP? And if I wanted to implement that, where should I look in the code?

@zzhhjjj (Collaborator)

zzhhjjj commented May 6, 2024

If my understanding is correct, ZeRO-1 is not pure DP: it shards the optimizer states across the data-parallel ranks, whereas pure DP keeps a full copy on every GPU. That's why you got the error with pure DP but not with ZeRO-1.

Some links:
https://huggingface.co/transformers/v4.12.5/parallelism.html
https://blog.csdn.net/baoyan2015/article/details/136820078
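
To make the difference concrete, here is a rough back-of-envelope estimate in Python. This is my own sketch, not from nanotron's docs; it ignores activations, buffers, and fragmentation, and assumes the optimizer keeps fp32 master weights plus the two Adam moments:

# Very rough per-GPU memory for an 8B-parameter model with bf16 weights,
# fp32 gradient accumulation, and Adam.
params = 8e9
GiB = 1024**3

weights_bf16 = 2 * params            # ~15 GiB, replicated in both setups
grads_fp32   = 4 * params            # ~30 GiB with accumulate_grad_in_fp32
optim_fp32   = (4 + 4 + 4) * params  # ~89 GiB: fp32 master weights + Adam m + v

dp = 8
pure_dp = (weights_bf16 + grads_fp32 + optim_fp32) / GiB        # every rank holds everything
zero1   = (weights_bf16 + grads_fp32 + optim_fp32 / dp) / GiB   # optimizer states sharded over dp

print(f"pure DP: ~{pure_dp:.0f} GiB per GPU")   # ~134 GiB, far beyond an 80 GB card
print(f"ZeRO-1 : ~{zero1:.0f} GiB per GPU")     # ~56 GiB before activations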

@ckzbullbullet (Author)

Sorry for the confusion.
I was comparing nanotron with dp=8, pp=1, tp=1, zero_stage=1 against DeepSpeed ZeRO-1 without PP or TP.
In DeepSpeed I set gradient_checkpointing=True to save memory during training, but I could not find any notes about gradient checkpointing in nanotron.

In the docs I found "activation checkpointing" with the decorator @checkpoint_method(attr_name="do_checkpoint"). I applied it on LlamaModel and got the following error:

  File "/home/kezhen/Documents/nanotron/src/nanotron/parallel/pipeline_parallel/engine.py", line 56, in forward
  [rank3]: assert output["loss"].requires_grad

I have the following questions:

  1. Does nanotron support gradient checkpointing right now?
  2. Is activation checkpointing similar to gradient checkpointing?
  3. What is the correct way to use it?

@zzhhjjj (Collaborator)

zzhhjjj commented May 12, 2024

Hello,

1, 2. Yes, activation checkpointing and gradient checkpointing are the same technique, and @checkpoint_method is how nanotron exposes it.
3. I didn't have time to check the code, but you can set use_reentrant=False to fix that assertion error. Your observation is correct that even then there is still an OOM error; I will look into the reason when I have time.
https://pytorch.org/docs/stable/checkpoint.html
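
For reference, a minimal sketch of passing use_reentrant=False to PyTorch's checkpoint API. This calls torch.utils.checkpoint directly; where exactly nanotron's checkpoint_method forwards the flag is something to verify in its implementation:

import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_forward(block, hidden_states):
    # Recompute the block's activations during backward instead of storing
    # them. The non-reentrant variant propagates requires_grad correctly
    # and avoids the assertion shown above.
    return checkpoint(block, hidden_states, use_reentrant=False)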

Thanks a lot for your questions.
