-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_lm.sh
68 lines (60 loc) · 1.59 KB
/
train_lm.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/bin/bash
# Train an Arabic nano-GPT language model end to end:
# stage 1 preprocesses the Wikipedia dump into a CSV, stage 2 builds a
# tokenizer from it, stage 3 trains the causal LM and logs to W&B.
#
# Abort on the first failed command, on unset variables, and on failures
# anywhere in a pipeline — without this, a failed preprocessing or
# tokenizer stage would not stop the (expensive) training stage from
# running on missing/stale data.
set -euo pipefail

# data config
DATA_CKPT=wikimedia/wikipedia
SUB_DATA=20231101.ar
PROCESSED_DATA_PATH="data/${DATA_CKPT}.csv"
# tokenizer config
BASE_MODEL=openai-community/gpt2
MODEL_NAME=arabic-nano-gpt-v2
MODEL_PATH="models/${MODEL_NAME}"
MODEL_MAX_LENGTH=1024
VOCAB_SIZE=16384
# model config
EMBED_SIZE=384
NUM_ATT_HEAD=6
NUM_ATT_LAYERS=8
# training config
NUM_EPOCHS=8
BATCH_SIZE=32
ACCUM_STEPS=8
EVAL_STEPS=5000
LOG_STEPS=2000
LR=0.0001
WD=0.000001
WARMUP=0.01
# weights & biases config
# NOTE(review): PROJECT_NAME and JOB_TYPE are defined but never passed to
# any stage below — presumably picked up as env vars by the training
# script, or dead config; verify against src/train_causal_lm.py.
PROJECT_NAME=Arabic-Nano-GPT
JOB_TYPE=LM-Modeling
RUN_NAME=Arabic-NanoGPT-LM-on-Wikipedia-Docs-23-V2
NOTES="LM Training on Arabic Data using Nano GPT2 Model Architecture"
TAGS=Modeling,Transformers,GPT2,Language-Modeling,Arabic-Wikipedia
# Stage 1: preprocess the raw dataset into the CSV consumed by the
# tokenizer and training stages. All expansions quoted (SC2086) so the
# command stays correct if a config value ever contains spaces or globs.
python src/preprocess_data.py \
  --data_ckpt="$DATA_CKPT" \
  --sub_data="$SUB_DATA" \
  --split_name=train \
  --processed_data_file_path="$PROCESSED_DATA_PATH"
# Stage 2: build a tokenizer (derived from the base model's) on the
# processed corpus and save it under the target model path. Quoted
# expansions guard against word-splitting (SC2086).
python src/build_tokenizer.py \
  --model_ckpt="$BASE_MODEL" \
  --data_ckpt="$DATA_CKPT" \
  --processed_data_file_path="$PROCESSED_DATA_PATH" \
  --model_max_length="$MODEL_MAX_LENGTH" \
  --vocab_size="$VOCAB_SIZE" \
  --model_name="$MODEL_NAME" \
  --target_model_path="$MODEL_PATH"
# Stage 3: train the causal language model with the configured
# architecture/optimizer settings; run metadata (name, notes, tags) is
# forwarded for experiment tracking. Every expansion is quoted — the
# original quoted only $NOTES, but $TAGS (and the rest) deserve the same
# treatment (SC2086).
python src/train_causal_lm.py \
  --n_embd="$EMBED_SIZE" \
  --n_head="$NUM_ATT_HEAD" \
  --n_layer="$NUM_ATT_LAYERS" \
  --num_epochs="$NUM_EPOCHS" \
  --lr="$LR" \
  --wd="$WD" \
  --warmup="$WARMUP" \
  --batch_size="$BATCH_SIZE" \
  --accum_steps="$ACCUM_STEPS" \
  --eval_steps="$EVAL_STEPS" \
  --log_steps="$LOG_STEPS" \
  --torch_compile \
  --model_name="$MODEL_NAME" \
  --run_name="$RUN_NAME" \
  --notes="$NOTES" \
  --tags="$TAGS"