Commit 3b5ce7c (parent: 14582a3)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 8 changed files with 13,036 additions and 241 deletions.
@@ -0,0 +1,135 @@
from dataclasses import dataclass, field

import logging
import os
import sys
from typing import Optional

import datasets
import numpy as np
import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    T5ForConditionalGeneration,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    DataCollatorWithPadding,
    DataCollatorForLanguageModeling,
    is_torch_tpu_available,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset
from utils.copied_utils import (
    compute_input_and_target_lengths,
    DataCollatorForT5MLM,
    tokenize_function,
    DataCollatorForNI,
)

logger = logging.getLogger(__name__)


class UInt16(Encoding):
    """Custom MDS encoding so pre-tokenized token ids can be stored as raw uint16 arrays."""

    def encode(self, obj) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes):
        return np.frombuffer(data, np.uint16)


_encodings['uint16'] = UInt16


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )


@dataclass
class DataTrainingArguments:
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The input training data file (a text file)."},
    )


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level
        # at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Resume from the latest checkpoint in output_dir if one already exists.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    model = T5ForConditionalGeneration.from_pretrained(model_args.model_name_or_path)

    # Work out how long the raw (pre-masking) inputs and the targets must be so that,
    # after span corruption, the encoder inputs are exactly 512 tokens long.
    before_mask_input_length, target_length = compute_input_and_target_lengths(
        inputs_length=512,
        noise_density=0.15,
        mean_noise_span_length=3.0,
    )

    # T5 span-corruption (MLM) collator: masks 15% of tokens in spans of ~3 tokens
    # and builds encoder inputs / decoder targets with sentinel tokens.
    data_collator = DataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=0.15,
        mean_noise_span_length=3.0,
        input_length=512,
        target_length=target_length,
        pad_token_id=config.pad_token_id,
    )
    # DatasetFixed is expected to wrap streaming.LocalDataset over MDS shards of
    # pre-tokenized uint16 token ids; see the sketch after this listing.
    dataset = DatasetFixed(data_args.train_file)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=None,
        preprocess_logits_for_metrics=None,
    )

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()


if __name__ == "__main__":
    main()
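
The script reads its training data through `DatasetFixed`, whose definition is not visible in this excerpt, and the `UInt16` encoding it registers only matters if the MDS shards were written with that same encoding. Below is a minimal sketch of both sides, assuming the shards hold a single pre-tokenized `input_ids` column and that `DatasetFixed` is a thin wrapper over `streaming.LocalDataset`; the column name, directory layout, and wrapper behaviour are assumptions, not taken from the repository.

```python
# Sketch only: how uint16 MDS shards could be written and read back for this trainer.
import numpy as np
import torch
from streaming import MDSWriter, LocalDataset
from streaming.base.format.mds.encodings import Encoding, _encodings


class UInt16(Encoding):
    # Same custom encoding as registered in the training script above.
    def encode(self, obj) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes):
        return np.frombuffer(data, np.uint16)


_encodings['uint16'] = UInt16

# Write side (hypothetical data-prep step): one uint16 'input_ids' column per sample.
with MDSWriter(out='mds-train', columns={'input_ids': 'uint16'}, compression=None) as writer:
    for ids in ([37, 423, 9, 1], [112, 7, 1]):  # toy pre-tokenized samples
        writer.write({'input_ids': np.array(ids, dtype=np.uint16)})


class DatasetFixed(torch.utils.data.Dataset):
    """Hypothetical stand-in for the DatasetFixed the script instantiates:
    indexes a local MDS directory and hands int64 token ids to the collator."""

    def __init__(self, local):
        self.dataset = LocalDataset(local=local)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        return {'input_ids': sample['input_ids'].astype(np.int64)}

    def __len__(self):
        return len(self.dataset)


print(DatasetFixed('mds-train')[0])
```

If the training file were saved as, say, `pretrain_t5.py` (its actual name is not shown in this diff), a run would then be launched along the lines of `python pretrain_t5.py --model_name_or_path <model-or-checkpoint-dir> --train_file mds-train --output_dir ./out --do_train --per_device_train_batch_size 32`, where everything other than `--model_name_or_path` and `--train_file` is a standard `TrainingArguments` flag.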
pretrained-model/nanoT5/prepare-tokenizer-base-model.ipynb: 215 changes (204 additions, 11 deletions). Large diffs are not rendered by default.
pretrained-model/nanoT5/prepare-tokenizer-large-model.ipynb: 210 changes (141 additions, 69 deletions). Large diffs are not rendered by default.
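
The two notebooks above are not rendered in this diff, so their exact steps are unknown. As a purely illustrative sketch of what preparing a T5 tokenizer usually involves (the corpus path, vocabulary size, and output directory below are assumptions, not values from the notebooks):

```python
# Hypothetical sketch: train a unigram SentencePiece model and package it as a T5 tokenizer.
import sentencepiece as spm
from transformers import T5Tokenizer

spm.SentencePieceTrainer.train(
    input='corpus.txt',        # assumed plain-text training corpus
    model_prefix='spm-t5',     # writes spm-t5.model and spm-t5.vocab
    vocab_size=32000,
    model_type='unigram',
)

# extra_ids=100 adds the <extra_id_*> sentinel tokens that T5 span corruption relies on.
tokenizer = T5Tokenizer('spm-t5.model', extra_ids=100)
tokenizer.save_pretrained('prepared-tokenizer')
```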