
Commit

Updates
chiragjn committed Dec 7, 2024
1 parent 35b1459 commit 4e2c037
Showing 5 changed files with 20 additions and 20 deletions.
base-requirements.txt (2 additions & 3 deletions)

@@ -3,11 +3,10 @@ cloud-files==4.29.0
 fsspec==2024.9.0
 hf-transfer<0.2.0
 pyarrow>=15.0.0,<19.0.0
-pynvml<12.0.0
 rich>=13.0.0,<14
 s3fs==2024.9.0
 snowflake-connector-python[pandas]==3.12.3
 torch==2.3.1+cu121
 torchao==0.6.1+cu121
-truefoundry==0.5.1rc3
-unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git@8558bc92b06f9128499484ef737fa71b966ffc23
+truefoundry==0.5.1rc6
+unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git@9dc399a6b6625ee40835c5eab361426d3c5d4abb
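
To verify that an environment actually picked up these pins, a hypothetical post-install sanity check (not part of the repository) can query installed versions:

```python
# Hypothetical check that the environment matches the pins in
# base-requirements.txt; not part of the repository.
from importlib.metadata import version

assert version("truefoundry") == "0.5.1rc6"
print(version("torch"))  # expected: 2.3.1+cu121
```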
config-base.yaml (6 additions & 6 deletions)

@@ -62,7 +62,7 @@ early_stopping_patience: 10
 eval_sample_packing: False
 eval_steps: 0.1
 eval_strategy: steps
-fix_untrained_tokens: true
+fix_untrained_tokens: True
 gradient_accumulation_steps: 4
 gradient_checkpointing: True
 gradient_checkpointing_kwargs:
@@ -132,8 +132,8 @@ truefoundry_ml_run_name: auto # type: string
 val_data_uri: null
 
 ## Liger
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true
+liger_rope: True
+liger_rms_norm: True
+liger_glu_activation: True
+liger_layer_norm: True
+liger_fused_linear_cross_entropy: True
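
Both spellings parse to the same boolean under a standard YAML loader, so this change only normalizes casing to match the rest of the file (e.g. `gradient_checkpointing: True`). A quick check, assuming the config is read with PyYAML:

```python
# "true" and "True" resolve to the same Python bool in YAML 1.1 loaders such
# as PyYAML, so the casing change above is purely cosmetic.
import yaml

print(yaml.safe_load("liger_rope: true"))  # {'liger_rope': True}
print(yaml.safe_load("liger_rope: True"))  # {'liger_rope': True}
```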
plugins/axolotl_truefoundry/pyproject.toml (1 addition & 1 deletion)

@@ -9,7 +9,7 @@ description = "TrueFoundry plugin for Axolotl"
 requires-python = ">=3.8.1,<4.0"
 dependencies = [
     "transformers>=4.0.0,<5",
-    "truefoundry>=0.4.4,<0.5.0",
+    "truefoundry==0.5.1rc6",
     "pynvml>=11.0.0,<12",
     "torch>=2.3.0,<2.4.0",
     "pydantic>=2.0.0,<3",
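
The exact pin is needed because the old range can never match a 0.5.x pre-release; an illustration (not from the repo) using the packaging library:

```python
# Why "truefoundry>=0.4.4,<0.5.0" had to become an exact pin: 0.5.1rc6 lies
# above the <0.5.0 ceiling, and pre-releases only match when named explicitly.
from packaging.specifiers import SpecifierSet

old_spec = SpecifierSet(">=0.4.4,<0.5.0")
new_spec = SpecifierSet("==0.5.1rc6")

print(old_spec.contains("0.5.1rc6", prereleases=True))  # False
print(new_spec.contains("0.5.1rc6"))                    # True
```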
requirements.txt (1 addition & 1 deletion)

@@ -1,3 +1,3 @@
 --extra-index-url https://download.pytorch.org/whl/cu121
 -r base-requirements.txt
-axolotl[deepspeed,flash-attn,mamba-ssm,optimizers,lion-pytorch,galore] @ git+https://github.com/truefoundry/axolotl@405ef674a46b85376eebfdd5d1d20680cf2a429d
+axolotl[deepspeed,flash-attn,mamba-ssm,optimizers,lion-pytorch,galore] @ git+https://github.com/truefoundry/axolotl@bfcb37836b13712afae9d48dc4c6187b1eecb3d5
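
The `--extra-index-url` line is what lets pip resolve the `+cu121` builds of torch and torchao pinned in base-requirements.txt; the suffix is a PEP 440 local version label marking the CUDA 12.1 variant. An illustration with the packaging library (not from the repo):

```python
# A PEP 440 local version label ("+cu121") distinguishes build variants of the
# same release and participates in equality comparisons.
from packaging.version import Version

v = Version("2.3.1+cu121")
print(v.public, v.local)      # 2.3.1 cu121
print(v == Version("2.3.1"))  # False: the local label must match too
```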
train.py (10 additions & 9 deletions)

@@ -153,9 +153,10 @@ def make_axolotl_config(config_base, kwargs, timestamp=None):
     is_tf32_supported = is_ampere_or_newer and is_torch_tf32_available()
     is_bf16_supported = is_ampere_or_newer and is_torch_bf16_gpu_available()
 
-    single_gpu = torch.cuda.device_count() == 1
-    using_deepspeed = cfg.deepspeed is not None
-    use_unsloth_lora = False # single_gpu and not using_deepspeed and cfg.adapter in {"qlora", "lora"}
+    use_unsloth = False
+    # single_gpu = torch.cuda.device_count() == 1
+    # using_deepspeed = cfg.deepspeed is not None
+    # use_unsloth_lora = single_gpu and not using_deepspeed and cfg.adapter in {"qlora", "lora"}
 
     set_cfg_option_if_auto(cfg, "tf32", is_tf32_supported)
     # TODO: Axolotl doesn't seem to do anything differently even though it says setting bfloat16/float16 will disable AMP
@@ -167,12 +168,12 @@ set_cfg_option_if_auto(cfg, "load_in_4bit", cfg.adapter == "qlora")
     set_cfg_option_if_auto(cfg, "load_in_4bit", cfg.adapter == "qlora")
 
     # TODO (chiragjn): Add model arch condition
-    set_cfg_option_if_auto(cfg, "unsloth_cross_entropy_loss", single_gpu and not using_deepspeed)
-    set_cfg_option_if_auto(cfg, "unsloth_rms_norm", single_gpu and not using_deepspeed)
-    set_cfg_option_if_auto(cfg, "unsloth_rope", single_gpu and not using_deepspeed)
-    set_cfg_option_if_auto(cfg, "unsloth_lora_mlp", use_unsloth_lora)
-    set_cfg_option_if_auto(cfg, "unsloth_lora_qkv", use_unsloth_lora)
-    set_cfg_option_if_auto(cfg, "unsloth_lora_o", use_unsloth_lora)
+    set_cfg_option_if_auto(cfg, "unsloth_cross_entropy_loss", use_unsloth)
+    set_cfg_option_if_auto(cfg, "unsloth_rms_norm", use_unsloth)
+    set_cfg_option_if_auto(cfg, "unsloth_rope", use_unsloth)
+    set_cfg_option_if_auto(cfg, "unsloth_lora_mlp", use_unsloth)
+    set_cfg_option_if_auto(cfg, "unsloth_lora_qkv", use_unsloth)
+    set_cfg_option_if_auto(cfg, "unsloth_lora_o", use_unsloth)
 
     set_cfg_option_if_auto(cfg, "flash_attention", is_ampere_or_newer)
     set_cfg_option_if_auto(cfg, "flash_attn_cross_entropy", not cfg.unsloth_cross_entropy_loss)
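
`set_cfg_option_if_auto` itself is outside this diff; a minimal sketch of what such a helper plausibly does, assuming the config uses the string "auto" as a sentinel (as `truefoundry_ml_run_name: auto` in config-base.yaml suggests) and supports attribute access like `cfg.adapter`:

```python
# Hypothetical sketch, not the repository's actual implementation: apply a
# computed default only when the user left the option as "auto", so explicit
# user choices are never overridden.
def set_cfg_option_if_auto(cfg, key, value):
    if getattr(cfg, key, None) == "auto":
        setattr(cfg, key, value)
```

With `use_unsloth` hard-coded to False, every `unsloth_*` option left on auto now resolves to False, which in turn lets `flash_attn_cross_entropy` default to True.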
