Skip to content

Commit

Permalink
Updates
Browse files: browse the repository at this point in the history
  • Loading branch information
chiragjn committed Dec 8, 2024
1 parent 4e2c037 commit 25d052d
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Dockerfile-notebook
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ RUN mkdir -p /packages && \
# Install axolotl_truefoundry plugin with our requirements overrides over axolotl
COPY --chown=jovyan:users plugins/axolotl_truefoundry /packages/axolotl_truefoundry
RUN cd /packages/axolotl_truefoundry/ && \
pip install --no-cache-dir -U -r /tmp/llm-finetune/notebook-requirements.txt -e .
pip install --no-cache-dir -e .

# Add source code for finetuning
COPY --chown=jovyan:users . /tmp_home/jovyan/llm-finetune/
Expand Down
3 changes: 3 additions & 0 deletions config-base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ use_tensorboard: True
cleanup_output_dir_on_start: False
dataset_type: chat # Can be completion | chat
drop_long_sequences: False
extra_hf_training_args:
average_tokens_across_devices: True
eval_on_start: True
logging_dir: ./tensorboard_logs
merge_adapters_post_train: True
save_model_on_interrupt: False
Expand Down
12 changes: 6 additions & 6 deletions plugins/axolotl_truefoundry/axolotl_truefoundry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks import GPUStatsCallback
from axolotl.utils.distributed import is_main_process
from pydantic import BaseModel, ConfigDict, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import Trainer, TrainerCallback
from transformers.integrations import rewrite_logs
from transformers.integrations.integration_utils import TensorBoardCallback
Expand Down Expand Up @@ -92,7 +92,8 @@ def _add_perplexity(self, logs):
logs[perplexity_key] = perplexity

# noinspection PyMethodOverriding
def on_log(self, args, state, control, logs, model=None, **kwargs):
def on_log(self, args, state, control, logs=None, model=None, **kwargs):
logs = logs or {}
# TODO (chiragjn): Hack for now; this needs to be moved to `compute_metrics`.
# Unfortunately, `compute_metrics` does not give us already-computed metrics like eval_loss.
if not state.is_world_process_zero:
Expand Down Expand Up @@ -121,7 +122,8 @@ def __init__(
logger.warning("checkpoint_artifact_name not passed. Checkpoints will not be logged to MLFoundry")

# noinspection PyMethodOverriding
def on_log(self, args, state, control, logs, model=None, **kwargs):
def on_log(self, args, state, control, logs=None, model=None, **kwargs):
logs = logs or {}
if not state.is_world_process_zero:
return

Expand Down Expand Up @@ -183,17 +185,15 @@ class LongSequenceStrategy(str, enum.Enum):

class TruefoundryMLPluginArgs(BaseModel):
model_config = ConfigDict(use_enum_values=True)

cleanup_output_dir_on_start: bool = False
logging_dir: str = "./tensorboard_logs"

dataset_type: DatasetType = DatasetType.chat
train_data_uri: Optional[str]
val_data_uri: Optional[str] = None
val_set_size: float = 0.1

long_sequences_strategy: LongSequenceStrategy = LongSequenceStrategy.error
merge_adapters_post_train: bool = True
extra_hf_training_args: Dict[str, Any] = Field(default_factory=dict)

truefoundry_ml_enable_reporting: bool = False
truefoundry_ml_repo: Optional[str] = None
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
--extra-index-url https://download.pytorch.org/whl/cu121
-r base-requirements.txt
axolotl[deepspeed,flash-attn,mamba-ssm,optimizers,lion-pytorch,galore] @ git+https://github.com/truefoundry/axolotl@bfcb37836b13712afae9d48dc4c6187b1eecb3d5
axolotl[deepspeed,flash-attn,mamba-ssm,optimizers,lion-pytorch,galore] @ git+https://github.com/truefoundry/axolotl@c7fc338e67c4313ec82fcca304733c9ececae5c0
15 changes: 8 additions & 7 deletions sample_run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,31 @@ accelerate launch \
train.py \
config-base.yaml \
--deepspeed ./deepspeed_configs/3_ds_z2_config.json \
--base_model HuggingFaceTB/SmolLM2-135M-Instruct \
--base_model Qwen/Qwen2.5-0.5B-Instruct \
--dataset_type chat \
--train_data_uri ./sample_data/chatalpaca-openai-1k.jsonl \
--train_data_uri ./sample_data/multiply-1k.jsonl \
--val_data_uri None \
--val_set_size 0.2 \
--sequence_len 4096 \
--eval_sample_packing False \
--sequence_len 2048 \
--max_steps 0 \
--micro_batch_size 4 \
--eval_batch_size 4 \
--num_epochs 10 \
--gradient_accumulation_steps 4 \
--gradient_checkpointing unsloth \
--learning_rate 0.0002 \
--learning_rate 0.0001 \
--output_dir ./outputs \
--train_on_inputs False \
--logging_steps 1 \
--save_strategy steps \
--save_steps 0.2 \
--evaluation_strategy steps \
--eval_strategy steps \
--eval_steps 0.2 \
--adapter qlora \
--lora_target_linear True \
--lora_r 16 \
--lora_alpha 64 \
--lora_r 64 \
--lora_alpha 128 \
--truefoundry_ml_enable_reporting $TRUEFOUNDRY_ML_ENABLE_REPORTING \
--truefoundry_ml_repo $TRUEFOUNDRY_ML_REPO \
--truefoundry_ml_run_name $TRUEFOUNDRY_ML_RUN_NAME \
Expand Down

0 comments on commit 25d052d

Please sign in to comment.