Commit 84c5380: update files
lintangsutawika committed Feb 26, 2024
1 parent 4d246a1
Showing 2 changed files with 119 additions and 2 deletions.
113 changes: 113 additions & 0 deletions configs/coord_check.yml
@@ -0,0 +1,113 @@
{
  # parallelism settings
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,

  # model settings
  "num_layers": 8,
  "num_attention_heads": 8,
  "seq_length": 128,
  "max_position_embeddings": 128,
  "pos_emb": "rotary",
  "rotary_pct": 0.25,
  "no_weight_tying": true,
  "gpt_j_residual": true,
  "output_layer_parallelism": "column",

  # "attention_config": [[["flash"], 8]],

  # these should provide some speedup, but take a while to build; set to true if desired
  "scaled_upper_triang_masked_softmax_fusion": true,
  "bias_gelu_fusion": true,

  # init methods
  "init_method": "normal",
  "output_layer_init_method": "scaled_normal",

  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.006,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
    }
  },
  # "min_lr": 0.006,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": true,
    "allgather_bucket_size": 1260000000,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 1260000000,
    "contiguous_gradients": true,
    "cpu_offload": false
  },

  # batch / data settings
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 8,
  "data_impl": "mmap",
  "num_workers": 1,

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.0,
  "hidden_dropout": 0,
  "attention_dropout": 0,

  # precision settings
  "precision": "fp32",
  # "fp16": {
  #   "fp16": true,
  #   "enabled": true,
  #   "loss_scale": 0,
  #   "loss_scale_window": 1000,
  #   "initial_scale_power": 12,
  #   "hysteresis": 2,
  #   "min_loss_scale": 1,
  # },

  # misc. training settings
  "train_iters": 300,
  "lr_decay_iters": 300,
  "distributed_backend": "nccl",
  "lr_decay_style": "cosine",
  "warmup": 0.01,
  "checkpoint_factor": 300,
  "eval_interval": 300,
  "eval_iters": 10,

  # logging
  "log_interval": 10,
  "steps_per_print": 10,
  "wall_clock_breakdown": true,

  "tokenizer_type": "HFTokenizer",
  "vocab-file": "/weka/lintangsutawika/09-mup-neox/20B_tokenizer.json",

  "coord_check": true,
  "use_mup": true,
  # sigma_base
  "init_method_std": 0.08,
  # "mup_embedding_multiplier": 5,
  # "mup_output_multiplier": 1,
  # "mup_width_multiplier": 1,
  "mup_d_model_base": 128,
  "hidden_size": 128,

  "data-path": "/weka/lintangsutawika/09-mup-neox/data/enwik8/enwik8_text_document",

  # "launcher": "slurm",
  # "deepspeed_slurm": true,
}
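Note on the muP block above: `mup_d_model_base` pins the base width, so the width multiplier is `hidden_size / mup_d_model_base` (1.0 in this base config), and `init_method_std` plays the role of sigma_base. A minimal sketch of the standard muP scaling rules (Yang et al., "Tensor Programs V") that this multiplier feeds into; the function name is illustrative, not NeoX source:

def mup_factors(hidden_size: int, d_model_base: int = 128, sigma_base: float = 0.08):
    """Return (width_mult, init_std, adam_lr_scale) under standard muP rules
    for matrix-like (hidden) parameters."""
    width_mult = hidden_size / d_model_base        # 1.0 at the base width
    init_std = sigma_base / width_mult ** 0.5      # hidden-weight init std shrinks with width
    adam_lr_scale = 1.0 / width_mult               # per-parameter Adam LR scaling
    return width_mult, init_std, adam_lr_scale

# the coord check sweeps hidden_size upward while everything else stays fixed:
for hidden in (128, 256, 512, 1024):
    print(hidden, mup_factors(hidden))

The check itself would go through the usual NeoX entry point, e.g. `python ./deepy.py train.py configs/coord_check.yml` (assuming the standard launcher).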
8 changes: 6 additions & 2 deletions megatron/mup_substitute.py
@@ -2,6 +2,7 @@
 Helper functions for performing coord check.
 """
 import os
+import gc
 from copy import copy
 from itertools import product

@@ -30,7 +31,7 @@ def _get_coord_data(
     filter_module_by_name=None,
     fix_data=True,
     cuda=True,
-    nseeds=2,
+    nseeds=10,
     output_fdict=None,
     input_fdict=None,
     param_fdict=None,
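The default seed count rises from 2 to 10 here: coord-check curves plot per-width activation statistics, and with only two seeds those statistics are noisy enough to obscure whether activations stay flat across widths. A hedged sketch of the seed-by-width sweep this parameter controls (the helper arguments are hypothetical stand-ins, not the actual `_get_coord_data` body):

import pandas as pd

def sweep(models, train_one_step, nseeds=10, nsteps=3):
    """models: dict mapping width -> zero-arg model constructor;
    train_one_step: callable(model) -> activation statistic for that step."""
    records = []
    for seed in range(nseeds):
        for width, make_model in models.items():
            model = make_model()  # fresh init per (seed, width) pair
            for step in range(nsteps):
                records.append({"width": width, "step": step,
                                "act_abs_mean": train_one_step(model)})
    # the coord-check plot averages each width's statistic over seeds
    return (pd.DataFrame(records)
              .groupby(["width", "step"])["act_abs_mean"].mean())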
@@ -131,12 +132,15 @@ def output_logits_coord_check_hook(module, input, output):
             df["output_logits_act_abs_mean"].append(output_logits_act_abs_mean)
             df["width"].append(width)

-        import gc
         del model, optimizer
         gc.collect()
         with torch.no_grad():
             torch.cuda.empty_cache()

+    gc.collect()
+    with torch.no_grad():
+        torch.cuda.empty_cache()
+
     return pd.DataFrame(df)
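This hunk moves `import gc` to module scope and adds a final cleanup pass after the sweep loop, presumably so that each model and optimizer is freed before the next width is built and peak GPU memory stays bounded across the sweep. The pattern in isolation, as a hedged sketch (`build_model` and `run_check` are hypothetical stand-ins for the per-width work):

import gc
import torch

def checked_sweep(widths, build_model, run_check):
    """build_model: width -> nn.Module on GPU; run_check: (model, opt) -> None."""
    for width in widths:
        model = build_model(width)
        optimizer = torch.optim.Adam(model.parameters())
        run_check(model, optimizer)

        # drop the last references, then force collection and hand cached
        # CUDA blocks back before the next (larger) model is allocated
        del model, optimizer
        gc.collect()
        with torch.no_grad():
            torch.cuda.empty_cache()

    # one final sweep after the loop, mirroring the hunk above
    gc.collect()
    with torch.no_grad():
        torch.cuda.empty_cache()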

