train.py
import os
import yaml
from transformers import AutoTokenizer
from transformers import EncoderDecoderModel
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from general_utils import *  # provides get_data_batch, compute_metrics and UploaderCallback

with open('./config.yaml') as f:
    configs = yaml.load(f, Loader=yaml.SafeLoader)
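# config.yaml is expected to define the keys read throughout this script:
#   batch_size, load_pretrained, output_dir, gcp_pretrained_path, gcp_path,
#   max_length, early_stopping, no_repeat_ngram_size, length_penalty, num_beams,
#   predict_with_generate, do_train, do_eval, logging_steps, save_steps, eval_steps,
#   warmup_steps, num_train_epochs, overwrite_output_dir, save_total_limit, fp16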
train_data_batch = get_data_batch(path='./data/train_tokenized/*', batch_size=configs['batch_size'])
val_data_batch = get_data_batch(path='./data/val_tokenized/*', batch_size=configs['batch_size'])
if configs['load_pretrained']:
    # Pull a previously trained checkpoint down from GCP and resume from it.
    os.makedirs(configs['output_dir'] + '/pretrained/', exist_ok=True)
    os.system('gsutil -m cp -r "{}/*" "{}"'.format(configs['gcp_pretrained_path'], configs['output_dir'] + '/pretrained/'))
    try:
        roberta_shared = EncoderDecoderModel.from_pretrained(configs['output_dir'] + '/pretrained/', tie_encoder_decoder=True)
    except Exception:
        print('Warning: no pretrained model found at the provided path. Initializing new model weights.')
        roberta_shared = EncoderDecoderModel.from_encoder_decoder_pretrained("vinai/phobert-base", "vinai/phobert-base", tie_encoder_decoder=True)
else:
    roberta_shared = EncoderDecoderModel.from_encoder_decoder_pretrained("vinai/phobert-base", "vinai/phobert-base", tie_encoder_decoder=True)
# Tokenizer assumed to be PhoBERT's, matching the checkpoints above.
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
# set special tokens from the tokenizer
roberta_shared.config.decoder_start_token_id = tokenizer.bos_token_id
roberta_shared.config.eos_token_id = tokenizer.eos_token_id
# set decoding params (sensible defaults for beam search)
roberta_shared.config.max_length = configs['max_length']
roberta_shared.config.early_stopping = configs['early_stopping']
roberta_shared.config.no_repeat_ngram_size = configs['no_repeat_ngram_size']
roberta_shared.config.length_penalty = configs['length_penalty']
roberta_shared.config.num_beams = configs['num_beams']
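# expose the encoder's vocab size on the top-level config (the shared EncoderDecoderConfig does not set one itself):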
roberta_shared.config.vocab_size = roberta_shared.config.encoder.vocab_size
# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
    output_dir=configs['output_dir'],
    per_device_train_batch_size=configs['batch_size'],
    per_device_eval_batch_size=configs['batch_size'],
    predict_with_generate=configs['predict_with_generate'],
    do_train=configs['do_train'],
    do_eval=configs['do_eval'],
    logging_steps=configs['logging_steps'],
    save_steps=configs['save_steps'],
    eval_steps=configs['eval_steps'],
    warmup_steps=configs['warmup_steps'],
    num_train_epochs=configs['num_train_epochs'],
    overwrite_output_dir=configs['overwrite_output_dir'],
    save_total_limit=configs['save_total_limit'],
    fp16=configs['fp16'],
)
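# Note: depending on the installed transformers version, periodic evaluation may
# additionally require evaluation_strategy='steps'; eval_steps alone is not always enough.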
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=roberta_shared,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data_batch,
    eval_dataset=val_data_batch,
    callbacks=[UploaderCallback(gcp=configs['gcp_path'],
                                output_dir=configs['output_dir'])],
)
trainer.train()
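# Checkpoints land in configs['output_dir'] every configs['save_steps'] steps;
# UploaderCallback is assumed to mirror them to configs['gcp_path'] as training runs.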
# if configs['saved_gcp']:
#     print('Warning: Model training done. Copying training files to GCP.')
#     os.system('gsutil -m cp -r "{}" "{}"'.format(configs['output_dir'], configs['gcp_path']))
# else:
#     print('Warning: Model training done. No saved folder on GCP.')