diff --git a/gradio/f5-tts/app.py b/gradio/f5-tts/app.py
index ff006b2..82e8eb5 100644
--- a/gradio/f5-tts/app.py
+++ b/gradio/f5-tts/app.py
@@ -49,13 +49,15 @@ overlap_length = None
 speech_enhancement_hop_length = None
 
+MODEL_NAME = os.environ.get('MODEL_NAME', 'mesolitica/Malaysian-F5-TTS')
+VOCODER_NAME = os.environ.get('VOCODER_NAME', 'charactr/vocos-mel-24khz')
 maxlen_text = int(os.environ.get('MAXLEN_TEXT', '1000'))
 maxlen = 20000
 maxlen_str = f'{maxlen // 1000} seconds'
 
 examples = []
 for f in glob('*.mp3'):
-    examples.append([f, '', 'Model Text to Speech TTS ini dibangunkan seratus peratus oleh Mesolitica, syarikat pemula di Malaysia yang membangunkan juga Malaysia Large Language Model mallam.', False, True, 0.15, 1.0])
+    examples.append([f, '', 'Model Text to Speech TTS ini dibangunkan seratus peratus oleh Mesolitica, syarikat pemula di Malaysia yang membangunkan juga Malaysia Large Language Model mallam.', False, 0.15, 1.0])
 
 def load_speech_enhancement():
     global speech_enhancement, hp, speech_enhancement_sr, chunk_length, overlap_length, speech_enhancement_hop_length
@@ -85,8 +87,8 @@ def load_asr_pipe():
 def load_tts():
     global model, vocoder
     gr.Info('Loading TTS model.')
-    model = load_f5_tts('mesolitica/Malaysian-F5-TTS', device = device, dtype = torch.float16)
-    vocoder = load_vocoder('mesolitica/malaysian-vocos-mel-24khz', device = device)
+    model = load_f5_tts(MODEL_NAME, device = device, dtype = torch.float16)
+    vocoder = load_vocoder(VOCODER_NAME, device = device)
     convert_char_to_pinyin(['helo'])
@@ -163,7 +165,6 @@ def basic_tts(
     ref_text_input,
     gen_text_input,
     reference_enhancement,
-    output_enhancement,
     cross_fade_duration_slider,
     speed_slider,
 ):
@@ -308,13 +309,10 @@ def basic_tts(
             final_wave = new_wave
     y = final_wave
 
-    if output_enhancement:
-        y = speech_enhancement_func(torch.tensor(y), sr, resample_back = False)
-        sr = speech_enhancement_sr
-        y = y.numpy()
+    e_y = speech_enhancement_func(torch.tensor(y), sr, resample_back = False)
+    e_y = e_y.numpy()
 
-    audio = (sr, y)
-    return [audio, ref_text_input]
+    return [(sr, y), (speech_enhancement_sr, e_y), ref_text_input]
 
 with gr.Blocks(theme=theme) as demo:
     gr.Markdown(
@@ -348,11 +346,6 @@ def basic_tts(
                 info="Apply Speech Enhancement to reduce noise for reference audio, this will also increase generation time.",
                 value=False,
             )
-            output_enhancement = gr.Checkbox(
-                label="Output Enhancement",
-                info="Apply Speech Enhancement to reduce noise for generated audio, this will also increase generation time.",
-                value=True,
-            )
             speed_slider = gr.Slider(
                 label="Speed",
                 minimum=0.3,
@@ -370,6 +363,7 @@ def basic_tts(
                 info="Set the duration of the cross-fade between audio clips.",
             )
             audio_output = gr.Audio(label="Synthesized Audio", show_download_button = True)
+            enhanced_audio_output = gr.Audio(label="Enhanced Synthesized Audio", show_download_button = True)
             generate_btn = gr.Button("Synthesize", variant="primary")
 
     generate_btn.click(
@@ -379,11 +373,10 @@ def basic_tts(
             ref_text_input,
             gen_text_input,
             reference_enhancement,
-            output_enhancement,
             cross_fade_duration_slider,
             speed_slider,
         ],
-        outputs=[audio_output, ref_text_input],
+        outputs=[audio_output, enhanced_audio_output, ref_text_input],
     )
     examples = gr.Examples(
         examples=examples,
@@ -392,7 +385,6 @@ def basic_tts(
             ref_text_input,
             gen_text_input,
             reference_enhancement,
-            output_enhancement,
             cross_fade_duration_slider,
             speed_slider,
         ],
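The checkpoint IDs above are now environment-driven. A minimal sketch of how the overrides resolve, assuming the same defaults as in the hunk; export the variables before launching the app to swap checkpoints:

```python
import os

# Mirrors the lookup added to app.py: unset variables fall back to these defaults.
MODEL_NAME = os.environ.get('MODEL_NAME', 'mesolitica/Malaysian-F5-TTS')
VOCODER_NAME = os.environ.get('VOCODER_NAME', 'charactr/vocos-mel-24khz')
print(MODEL_NAME, VOCODER_NAME)
```

With the `output_enhancement` checkbox removed, `basic_tts` now always computes the enhanced waveform and returns `[(sr, y), (speech_enhancement_sr, e_y), ref_text_input]`, so the UI shows the raw and enhanced synthesis side by side.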
diff --git a/gradio/f5-tts/kj.mp3 b/gradio/f5-tts/kj.mp3
new file mode 100644
index 0000000..9921bb2
Binary files /dev/null and b/gradio/f5-tts/kj.mp3 differ
diff --git a/gradio/f5-tts/p-ramlee.mp3 b/gradio/f5-tts/p-ramlee.mp3
index 45fffbf..f32666e 100644
Binary files a/gradio/f5-tts/p-ramlee.mp3 and b/gradio/f5-tts/p-ramlee.mp3 differ
diff --git a/session/f5-tts/README.md b/session/f5-tts/README.md
index ee63692..a364dc1 100644
--- a/session/f5-tts/README.md
+++ b/session/f5-tts/README.md
@@ -1,6 +1,6 @@
 # F5-TTS
 
-## how to
+## how to: Speech Enhancement
 
 1. Download dataset,
 
@@ -11,19 +11,15 @@ tar -xf 7z2301-linux-x64.tar.xz
 pip3 install huggingface-hub wandb
 python3 -c "
 from huggingface_hub import snapshot_download
-snapshot_download(repo_id='mesolitica/Malaysian-Emilia-annotated', repo_type='dataset', allow_patterns = 'filtered-24k_processed_24k.z*', local_dir = './')
+snapshot_download(repo_id='mesolitica/Malaysian-Emilia', repo_type='dataset', allow_patterns = 'filtered-24k_processed.z*', local_dir = './')
 "
 python3 -c "
 from huggingface_hub import snapshot_download
-snapshot_download(repo_id='mesolitica/Malaysian-Emilia-annotated', repo_type='dataset', allow_patterns = 'malaysian-podcast_processed_24k.z*', local_dir = './')
+snapshot_download(repo_id='mesolitica/Malaysian-Emilia', repo_type='dataset', allow_patterns = 'malaysian-podcast-processed.z*', local_dir = './')
 "
 python3 -c "
 from huggingface_hub import snapshot_download
-snapshot_download(repo_id='mesolitica/Malaysian-Emilia-annotated', repo_type='dataset', allow_patterns = 'sg-podcast_processed_24k.zip', local_dir = './')
-"
-python3 -c "
-from huggingface_hub import snapshot_download
-snapshot_download(repo_id='mesolitica/Malaysian-Emilia-annotated', repo_type='dataset', allow_patterns = 'parlimen-24k-chunk_processed_24k.z*', local_dir = './')
+snapshot_download(repo_id='mesolitica/Malaysian-Emilia', repo_type='dataset', allow_patterns = 'sg-podcast_processed.zip', local_dir = './')
 "
 python3 -c "
 from huggingface_hub import snapshot_download
@@ -33,12 +29,11 @@ python3 -c "
 from huggingface_hub import snapshot_download
 snapshot_download(repo_id='mesolitica/Malaysian-Emilia', repo_type='dataset', allow_patterns = 'parlimen-24k-chunk_processed.z*', local_dir = './')
 "
-/workspace/7zz x filtered-24k_processed_24k.zip -y -mmt40
-/workspace/7zz x malaysian-podcast_processed_24k.zip -y -mmt40
-/workspace/7zz x sg-podcast_processed_24k.zip -y -mmt40
-/workspace/7zz x parlimen-24k-chunk_processed_24k.zip -y -mmt40
-/workspace/7zz x malaysian-cartoon.zip -y -mmt40
+/workspace/7zz x filtered-24k_processed.zip -y -mmt40
+/workspace/7zz x malaysian-podcast-processed.zip -y -mmt40
+/workspace/7zz x sg-podcast_processed.zip -y -mmt40
 /workspace/7zz x parlimen-24k-chunk_processed.zip -y -mmt40
+/workspace/7zz x malaysian-cartoon.zip -y -mmt40
 ```
 
 2. Install libraries,
 
 ```bash
 git clone https://github.com/mesolitica/F5-TTS
 cd F5-TTS
 pip3 install -e .
-pip3 install torchdiffeq x-transformers jieba pypinyin ema_pytorch accelerate==1.1.1
+pip3 install torchdiffeq x-transformers jieba pypinyin ema_pytorch accelerate==1.1.1 torch==2.5.1 torchaudio==2.5.1
 python3 -c "
 from huggingface_hub import snapshot_download
 snapshot_download(repo_id='mesolitica/Malaysian-Voice-Conversion', repo_type='dataset', allow_patterns = 'data/Emilia_Malaysian_pinyin/*', local_dir = './')
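The `snapshot_download` one-liners above fetch one archive pattern each; the same downloads can be collapsed into a single loop. A sketch assuming the repo and patterns listed in the README:

```python
from huggingface_hub import snapshot_download

# Same dataset repo and archive patterns as the one-liners above.
patterns = [
    'filtered-24k_processed.z*',
    'malaysian-podcast-processed.z*',
    'sg-podcast_processed.zip',
    'parlimen-24k-chunk_processed.z*',
]
for pattern in patterns:
    snapshot_download(
        repo_id='mesolitica/Malaysian-Emilia',
        repo_type='dataset',
        allow_patterns=pattern,
        local_dir='./',
    )
```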
diff --git a/session/f5-tts/default_config.yaml b/session/f5-tts/default_config.yaml
new file mode 100644
index 0000000..63b81a7
--- /dev/null
+++ b/session/f5-tts/default_config.yaml
@@ -0,0 +1,17 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/session/f5-tts/train.py b/session/f5-tts/train.py
index 0ae3f23..55f4399 100644
--- a/session/f5-tts/train.py
+++ b/session/f5-tts/train.py
@@ -34,7 +34,7 @@ epochs = 100 # use linear decay, thus epochs control the slope
 num_warmup_updates = 2000 # warmup steps
 save_per_updates = 50000 # save checkpoint per steps
-last_per_steps = 500 # save last checkpoint per steps
+last_per_steps = 2000 # save last checkpoint per steps
 
 # model params
 if exp_name == "F5TTS_Base":
@@ -84,7 +84,7 @@ def main():
         max_samples=max_samples,
         grad_accumulation_steps=grad_accumulation_steps,
         max_grad_norm=max_grad_norm,
-        wandb_project="CFM-TTS",
+        wandb_project="CFM-TTS-original",
         wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
         last_per_steps=last_per_steps,
diff --git a/session/smollm2-speech-semantics/README.md b/session/smollm2-speech-semantics/README.md
new file mode 100644
index 0000000..1b26b11
--- /dev/null
+++ b/session/smollm2-speech-semantics/README.md
@@ -0,0 +1,19 @@
+# Finetune SmolLM2 for speech semantic tokens
+
+## how to
+
+1. Download the dataset,
+
+```bash
+python3 -c "
+from huggingface_hub import snapshot_download
+snapshot_download(repo_id='mesolitica/smollm2-speech-semantic-multipack-2048', repo_type='dataset', local_dir = './smollm2-speech-semantic-multipack-2048')
+"
+```
+
+2. Finetune,
+
+```bash
+bash smollm2-135m-speech.sh
+bash smollm2-360m-speech.sh
+```
\ No newline at end of file
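Step 1 of the README above materializes a MosaicML `streaming` dataset; before launching the finetune scripts it can be sanity-checked locally. A minimal sketch, assuming the `local_dir` used in the README:

```python
from streaming import LocalDataset

# Each sample holds packed token ids plus per-document lengths in 'attention_mask'.
dataset = LocalDataset(local='./smollm2-speech-semantic-multipack-2048')
print(len(dataset), sorted(dataset[0].keys()))
```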
diff --git a/session/smollm2-speech-semantics/run-instruction-speech-multipack.py b/session/smollm2-speech-semantics/run-instruction-speech-multipack.py
new file mode 100644
index 0000000..41ed794
--- /dev/null
+++ b/session/smollm2-speech-semantics/run-instruction-speech-multipack.py
@@ -0,0 +1,450 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
+
+Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
+https://huggingface.co/models?filter=text-generation
+"""
+# You can also adapt this script on your own causal language modeling
+# task. Pointers for this are left as comments.
+
+import logging
+import math
+import os
+import sys
+import warnings
+from dataclasses import dataclass, field
+from itertools import chain
+from typing import Optional
+
+import datasets
+import evaluate
+import torch
+
+torch._dynamo.config.optimize_ddp = False
+
+from datasets import load_dataset
+
+import transformers
+import random
+from transformers import (
+    CONFIG_MAPPING,
+    MODEL_FOR_CAUSAL_LM_MAPPING,
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    HfArgumentParser,
+    Trainer,
+    TrainingArguments,
+    default_data_collator,
+    DataCollatorWithPadding,
+    DataCollatorForLanguageModeling,
+    is_torch_tpu_available,
+    set_seed,
+)
+from transformers.testing_utils import CaptureLogger
+from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import check_min_version, send_example_telemetry
+from transformers.utils.versions import require_version
+from transformers import AddedToken
+import streaming
+import json
+import numpy as np
+from streaming import LocalDataset
+from streaming.base.format.mds.encodings import Encoding, _encodings
+from peft import LoraConfig, get_peft_model
+
+require_version(
+    "datasets>=1.8.0",
+    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
+
+logger = logging.getLogger(__name__)
+
+
+MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+
+    rank: int = field(
+        default=256,
+        metadata={
+            "help": "lora rank"
+        },
+    )
+
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
+            )
+        },
+    )
+    model_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "If training from scratch, pass a model type from the list: " +
+            ", ".join(MODEL_TYPES)},
+    )
+    config_overrides: Optional[str] = field(
+        default=None, metadata={
+            "help": (
+                "Override some existing default config settings when a model is trained from scratch. Example: "
+                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index")}, )
+    config_name: Optional[str] = field(
+        default=None, metadata={
+            "help": "Pretrained config name or path if not the same as model_name"})
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={
+            "help": "Pretrained tokenizer name or path if not the same as model_name"})
+    cache_dir: Optional[str] = field(
+        default=None, metadata={
+            "help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, )
+    use_fast_tokenizer: bool = field(
+        default=True, metadata={
+            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, )
+    model_revision: str = field(
+        default="main", metadata={
+            "help": "The specific model version to use (can be a branch name, tag name or commit id)."}, )
+    token: str = field(
+        default=None,
+        metadata={
+            "help": (
+                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
+                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
+            )
+        },
+    )
+    use_auth_token: bool = field(
+        default=None,
+        metadata={
+            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
+        },
+    )
+    trust_remote_code: bool = field(
+        default=False, metadata={
+            "help": (
+                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
+                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
+                "execute code present on the Hub on your local machine.")}, )
+    torch_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
+                "dtype will be automatically derived from the model's weights."),
+            "choices": [
+                "auto",
+                "bfloat16",
+                "float16",
+                "float32"],
+        },
+    )
+    low_cpu_mem_usage: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
+                "Setting it to True will benefit LLM loading time and RAM consumption."
+            )
+        },
+    )
+
+    def __post_init__(self):
+        if self.config_overrides is not None and (
+                self.config_name is not None or self.model_name_or_path is not None):
+            raise ValueError(
+                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
+            )
+
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+    """
+
+    dataset_name: Optional[str] = field(
+        default=None, metadata={
+            "help": "The name of the dataset to use (via the datasets library)."})
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={
+            "help": "The configuration name of the dataset to use (via the datasets library)."})
+    train_file: Optional[str] = field(
+        default=None, metadata={
+            "help": "The input training data file (a text file)."})
+    validation_file: Optional[str] = field(
+        default=None, metadata={
+            "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, )
+    max_train_samples: Optional[int] = field(
+        default=None, metadata={
+            "help": (
+                "For debugging purposes or quicker training, truncate the number of training examples to this "
+                "value if set.")}, )
+    max_eval_samples: Optional[int] = field(
+        default=None, metadata={
+            "help": (
+                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+                "value if set.")}, )
+    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
+ ) + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + validation_split_percentage: Optional[int] = field( + default=5, + metadata={ + "help": "The percentage of the train set used as validation set in case there's no validation split" + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + keep_linebreaks: bool = field( + default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} + ) + +def print_trainable_parameters(model): + """ + Prints the number of trainable parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) + +def block_diagonal_concat_inverted(*masks, dtype=torch.bfloat16): + total_size = sum(mask.size(0) for mask in masks) + combined_mask = torch.zeros(total_size, total_size, dtype=dtype) + + current_pos = 0 + + for mask in masks: + size = mask.size(0) + combined_mask[current_pos:current_pos + size, current_pos:current_pos + size] = mask + current_pos += size + + min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min + inverted_mask = torch.where(combined_mask == 1, torch.tensor(0, dtype=dtype), min_value) + return inverted_mask.unsqueeze(0) + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if model_args.use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v4.34.", + FutureWarning) + if model_args.token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + model_args.token = model_args.use_auth_token + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_clm", model_args, data_args) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level + # at info here to have that default. 
+def main():
+    # See all possible arguments in src/transformers/training_args.py
+    # or by passing the --help flag to this script.
+    # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    if model_args.use_auth_token is not None:
+        warnings.warn(
+            "The `use_auth_token` argument is deprecated and will be removed in v4.34.",
+            FutureWarning)
+        if model_args.token is not None:
+            raise ValueError(
+                "`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
+        model_args.token = model_args.use_auth_token
+
+    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+    # information sent is the one passed as arguments along with your Python/PyTorch versions.
+    send_example_telemetry("run_clm", model_args, data_args)
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+
+    if training_args.should_log:
+        # The default of training_args.log_level is passive, so we set log level
+        # at info here to have that default.
+        transformers.utils.logging.set_verbosity_info()
+
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+
+    print(model_args)
+
+    # Log on each process the small summary:
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu} "
+        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}")
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    # Detecting last checkpoint.
+    last_checkpoint = None
+    if os.path.isdir(
+            training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+
+    # Set seed before initializing model.
+    set_seed(training_args.seed)
+
+    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+    # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+    # Load pretrained model and tokenizer
+    #
+    # Distributed training:
+    # The .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
+
+    config_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "token": model_args.token,
+        "trust_remote_code": model_args.trust_remote_code,
+    }
+
+    if model_args.config_name:
+        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
+    elif model_args.model_name_or_path:
+        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+    else:
+        config = CONFIG_MAPPING[model_args.model_type]()
+        logger.warning("You are instantiating a new config instance from scratch.")
+        if model_args.config_overrides is not None:
+            logger.info(f"Overriding config: {model_args.config_overrides}")
+            config.update_from_string(model_args.config_overrides)
+            logger.info(f"New config: {config}")
+
+    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
+    new = ['<|speech_start|>', '<|speech_end|>', '<|text_start|>', '<|text_end|>']
+    new = [AddedToken(t) for t in new]
+    tokenizer.add_tokens(new)
+    speech_tokens = [AddedToken(f"<|{i}|>") for i in range(1024)]
+    tokenizer.add_tokens(speech_tokens)
+
+    class UInt32(Encoding):
+        def encode(self, obj) -> bytes:
+            return obj.tobytes()
+
+        def decode(self, data: bytes):
+            return np.frombuffer(data, np.uint32)
+
+    _encodings['uint32'] = UInt32
+
+    class DatasetFixed(torch.utils.data.Dataset):
+        def __init__(self, local):
+            self.dataset = LocalDataset(local=local)
+
+        def __getitem__(self, idx):
+            data = self.dataset[idx]
+            data['labels'] = data["input_ids"].copy()
+            masking = data.pop('attention_mask')
+
+            data.pop('token_type_ids', None)
+            for k in data.keys():
+                data[k] = data[k].astype(np.int64)
+
+            masks = []
+            for m in masking:
+                masks.append(torch.tril(torch.ones(m, m)))
+            attention_mask = block_diagonal_concat_inverted(*masks)
+            data['attention_mask'] = attention_mask
+
+            return data
+
+        def __len__(self):
+            return len(self.dataset)
+
+    dataset = DatasetFixed(data_args.train_file)
+    print(len(dataset), dataset[0]['attention_mask'].shape)
+
+    torch_dtype = (
+        model_args.torch_dtype
+        if model_args.torch_dtype in ["auto", None]
+        else getattr(torch, model_args.torch_dtype)
+    )
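+
+    # The tokenizer gained 1,028 added tokens above (4 structural markers plus
+    # 1,024 <|i|> codebook tokens), so the embedding matrix must be resized
+    # after loading; mean_resizing=False initializes the new rows randomly
+    # instead of from the mean of the existing embeddings.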
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        from_tf=bool(".ckpt" in model_args.model_name_or_path),
+        config=config,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
+        torch_dtype=torch_dtype,
+        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
+        attn_implementation = 'sdpa',
+    )
+    model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
+    print(model)
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset,
+        eval_dataset=None,
+        tokenizer=tokenizer,
+        data_collator=default_data_collator,
+        compute_metrics=None,
+        preprocess_logits_for_metrics=None,
+    )
+
+    # Training
+    if training_args.do_train:
+        checkpoint = None
+        if training_args.resume_from_checkpoint is not None:
+            checkpoint = training_args.resume_from_checkpoint
+        elif last_checkpoint is not None:
+            checkpoint = last_checkpoint
+        trainer.train(resume_from_checkpoint=checkpoint)
+        trainer.save_model()
+        trainer.save_state()
+
+
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/session/smollm2-speech-semantics/smollm2-135m-speech.sh b/session/smollm2-speech-semantics/smollm2-135m-speech.sh
new file mode 100644
index 0000000..1aab6a3
--- /dev/null
+++ b/session/smollm2-speech-semantics/smollm2-135m-speech.sh
@@ -0,0 +1,23 @@
+WANDB_PROJECT="run-instruction-speech-multipack-HuggingFaceTB-SmolLM2-135M" \
+TORCH_DISTRIBUTED_DEBUG="info" \
+torchrun --nproc_per_node 2 \
+-m run-instruction-speech-multipack \
+--model_name_or_path HuggingFaceTB/SmolLM2-135M \
+--per_device_train_batch_size 20 \
+--gradient_accumulation_steps 2 \
+--output_dir run-instruction-speech-multipack-HuggingFaceTB-SmolLM2-135M \
+--bf16 --do_train --do_eval false --num_train_epochs 5 \
+--train_file /home/husein/ssd4/continue-training/smollm2-speech-semantic-multipack-2048 \
+--logging_steps 1 \
+--learning_rate 5e-4 \
+--warmup_steps 200 \
+--block_size 24576 \
+--save_steps 1000 \
+--save_total_limit 3 \
+--gradient_checkpointing true \
+--torch_dtype bfloat16 \
+--ddp_find_unused_parameters false \
+--dataloader_num_workers 3 \
+--dataloader_prefetch_factor 4 \
+--torch_compile \
+--torch_compile_backend inductor
\ No newline at end of file
diff --git a/session/smollm2-speech-semantics/smollm2-135m-tts-combine-annotated.sh b/session/smollm2-speech-semantics/smollm2-135m-tts-combine-annotated.sh
new file mode 100644
index 0000000..a9d4645
--- /dev/null
+++ b/session/smollm2-speech-semantics/smollm2-135m-tts-combine-annotated.sh
@@ -0,0 +1,23 @@
+WANDB_PROJECT="tts-combine-annotated-multipack-HuggingFaceTB-SmolLM2-135M" \
+TORCH_DISTRIBUTED_DEBUG="info" \
+torchrun --nproc_per_node 2 \
+-m tts-multipack \
+--model_name_or_path mesolitica/SmolLM2-135M-firefly-vqgan \
+--per_device_train_batch_size 20 \
+--gradient_accumulation_steps 2 \
+--output_dir tts-combine-annotated-multipack-HuggingFaceTB-SmolLM2-135M \
+--bf16 --do_train --do_eval false --num_train_epochs 5 \
+--train_file smollm2-speech-semantic-multipack-2048 \
+--logging_steps 1 \
+--learning_rate 1e-4 \
+--warmup_steps 200 \
+--block_size 24576 \
+--save_steps 1000 \
+--save_total_limit 3 \
+--gradient_checkpointing true \
+--torch_dtype bfloat16 \
+--ddp_find_unused_parameters false \
+--dataloader_num_workers 3 \
+--dataloader_prefetch_factor 4 \
+--torch_compile \
+--torch_compile_backend inductor
\ No newline at end of file
diff --git
a/session/smollm2-speech-semantics/smollm2-360m-speech.sh b/session/smollm2-speech-semantics/smollm2-360m-speech.sh new file mode 100644 index 0000000..531c590 --- /dev/null +++ b/session/smollm2-speech-semantics/smollm2-360m-speech.sh @@ -0,0 +1,23 @@ +WANDB_PROJECT="run-instruction-speech-multipack-HuggingFaceTB-SmolLM2-360M" \ +TORCH_DISTRIBUTED_DEBUG="info" \ +torchrun --nproc_per_node 4 \ +-m run-instruction-speech-multipack \ +--model_name_or_path HuggingFaceTB/SmolLM2-360M \ +--per_device_train_batch_size 16 \ +--gradient_accumulation_steps 2 \ +--output_dir run-instruction-speech-multipack-HuggingFaceTB-SmolLM2-360M \ +--bf16 --do_train --do_eval false --num_train_epochs 5 \ +--train_file smollm2-speech-semantic-multipack-2048 \ +--logging_steps 1 \ +--learning_rate 5e-4 \ +--warmup_steps 200 \ +--block_size 24576 \ +--save_steps 200 \ +--save_total_limit 3 \ +--gradient_checkpointing true \ +--torch_dtype bfloat16 \ +--ddp_find_unused_parameters false \ +--dataloader_num_workers 3 \ +--dataloader_prefetch_factor 4 \ +--torch_compile \ +--torch_compile_backend inductor \ No newline at end of file diff --git a/session/smollm2-speech-semantics/smollm2-speech-semantic-multipack-2048.ipynb b/session/smollm2-speech-semantics/smollm2-speech-semantic-multipack-2048.ipynb new file mode 100644 index 0000000..25acdcf --- /dev/null +++ b/session/smollm2-speech-semantics/smollm2-speech-semantic-multipack-2048.ipynb @@ -0,0 +1,2089 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "96210ac2", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "1024" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from transformers import AutoTokenizer, AutoConfig\n", + "from transformers import AddedToken\n", + "import os\n", + "import numpy as np\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained('HuggingFaceTB/SmolLM2-135M')\n", + "new = ['<|speech_start|>', '<|speech_end|>', '<|text_start|>', '<|text_end|>']\n", + "new = [AddedToken(t) for t in new]\n", + "tokenizer.add_tokens(new)\n", + "speech_tokens = [AddedToken(f\"<|{i}|>\") for i in range(1024)]\n", + "tokenizer.add_tokens(speech_tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "078c35c9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2438225" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.read_parquet('/home/husein/ssd3/verify-text.parquet').to_dict(orient = 'records')\n", + "len(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "98d77100", + "metadata": {}, + "outputs": [], + "source": [ + "t = df[0]['transcription']\n", + "splitted = df[0]['audio'].split('/')\n", + "new_f = '/'.join([splitted[0] + '_vqgan'] + splitted[1:]).replace('.mp3', '.npy')\n", + "new_f = os.path.join('/home/husein/ssd3', new_f)\n", + "speech_t = np.load(new_f)\n", + "speech_t = ''.join([f'<|{t}|>' for t in speech_t.tolist()])\n", + "tts = f'<|text_start|>{t}<|text_end|><|speech_start|>{speech_t}<|speech_end|>'\n", + "stt = f'<|speech_start|>{speech_t}<|speech_end|><|text_start|>{t}<|text_end|>'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "19e849df", + "metadata": {}, + "outputs": [], + "source": [ + "from streaming import MDSWriter\n", + "from streaming.base.format.mds.encodings import Encoding, _encodings\n", + "from streaming import 
LocalDataset\n", + "import streaming\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "from glob import glob\n", + "import os\n", + "import json\n", + "\n", + "class UInt32(Encoding):\n", + " def encode(self, obj) -> bytes:\n", + " return obj.tobytes()\n", + "\n", + " def decode(self, data: bytes):\n", + " return np.frombuffer(data, np.uint32)\n", + "\n", + "_encodings['uint32'] = UInt32\n", + "\n", + "columns = {\n", + " 'input_ids': 'uint32',\n", + " 'position_ids': 'uint32',\n", + " 'attention_mask': 'uint32',\n", + "}\n", + "hashes = 'sha1', 'xxh64'" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0c6ee450", + "metadata": {}, + "outputs": [], + "source": [ + "import gc\n", + "\n", + "def collator(batch, batch_position_ids):\n", + " input_ids = []\n", + " position_ids = []\n", + " masks = []\n", + " for i in range(len(batch)):\n", + " l = len(batch[i])\n", + " input_ids.extend(batch[i])\n", + " position_ids.extend(batch_position_ids[i])\n", + " masks.append(l)\n", + " \n", + " return {\n", + " 'input_ids': np.array(input_ids).astype(np.uint32),\n", + " 'position_ids': np.array(position_ids).astype(np.uint32),\n", + " 'attention_mask': np.array(masks).astype(np.uint32),\n", + " }\n", + "\n", + "def slice_and_balance(nested_list, size):\n", + " first = []\n", + " balance = []\n", + " current_size = 0\n", + "\n", + " for sublist in nested_list:\n", + " if current_size < size:\n", + " remaining_space = size - current_size\n", + " if len(sublist) <= remaining_space:\n", + " first.append(sublist)\n", + " current_size += len(sublist)\n", + " else:\n", + " first.append(sublist[:remaining_space])\n", + " balance.append(sublist[remaining_space:])\n", + " current_size = size\n", + " else:\n", + " balance.append(sublist)\n", + " \n", + " return first, balance" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "14ba02fc", + "metadata": {}, + "outputs": [], + "source": [ + "!rm -rf tokenized-2048\n", + "!mkdir tokenized-2048" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9a3ef60d", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "def loop(files, block_size = 2048):\n", + " rows, index = files\n", + " out_root = f'tokenized-2048/tokenized-{index}'\n", + " os.system(f'rm -rf {out_root}')\n", + " count = 0\n", + " temp = []\n", + " position_ids = []\n", + " last_block, last_position_block = None, None\n", + " with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:\n", + " for row in tqdm(rows):\n", + " \n", + " t = row['transcription']\n", + " splitted = row['audio'].split('/')\n", + " new_f = '/'.join([splitted[0] + '_vqgan'] + splitted[1:]).replace('.mp3', '.npy')\n", + " new_f = os.path.join('/home/husein/ssd3', new_f)\n", + " if not os.path.exists(new_f):\n", + " continue\n", + " \n", + " speech_t = np.load(new_f)\n", + " speech_t = ''.join([f'<|{t}|>' for t in speech_t.tolist()])\n", + " tts = f'<|text_start|>{t}<|text_end|><|speech_start|>{speech_t}<|speech_end|>'\n", + " \n", + " outputs = tokenizer(tts, add_special_tokens = False)\n", + " temp.append(outputs['input_ids'])\n", + " position_ids.append(range(len(outputs['input_ids'])))\n", + " count += len(outputs['input_ids'])\n", + " \n", + " while count >= block_size:\n", + " block, temp = slice_and_balance(temp, block_size)\n", + " block_position, position_ids = slice_and_balance(position_ids, block_size)\n", + " count = count - block_size\n", + " o = collator(block, block_position)\n", + " last_block = 
block\n", + " last_position_block = block_position\n", + " out.write(o)\n", + " \n", + " block, _ = slice_and_balance(last_block, block_size - count)\n", + " block_position, _ = slice_and_balance(last_position_block, block_size - count)\n", + "\n", + " block.extend(temp)\n", + " block_position.extend(position_ids)\n", + "\n", + " o = collator(block, block_position)\n", + " if len(o['input_ids']) == block_size:\n", + " out.write(o)\n", + " return o" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a16d2dee", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████| 1000/1000 [00:01<00:00, 644.54it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input_ids': array([50107, 49691, 49166, ..., 49863, 49689, 49153], dtype=uint32),\n", + " 'position_ids': array([1782, 1783, 1784, ..., 852, 853, 854], dtype=uint32),\n", + " 'attention_mask': array([ 145, 1452, 47, 404], dtype=uint32)}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loop((df[:1000], 0))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7a7fb01d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "100%|██████████████████████████████████| 100000/100000 [08:30<00:00, 195.93it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:29<00:00, 196.26it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:27<00:00, 197.17it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:33<00:00, 194.91it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:32<00:00, 195.17it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:39<00:00, 192.43it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:41<00:00, 191.70it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:40<00:00, 192.02it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:44<00:00, 190.67it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [09:10<00:00, 181.66it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:44<00:00, 190.60it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:43<00:00, 191.12it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:58<00:00, 185.74it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:57<00:00, 186.19it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [08:57<00:00, 186.09it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [09:18<00:00, 179.15it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [09:47<00:00, 170.25it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [09:53<00:00, 168.46it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [09:49<00:00, 169.53it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [09:48<00:00, 170.03it/s]\n", + "100%|████████████████████████████████████| 38225/38225 [02:03<00:00, 310.43it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [05:15<00:00, 316.74it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [05:23<00:00, 309.53it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [05:11<00:00, 321.06it/s]\n", + "100%|██████████████████████████████████| 100000/100000 [05:15<00:00, 316.57it/s]\n" + ] + } + ], + "source": [ + "from 
multiprocess import Pool\n", + "import mp\n", + "\n", + "chunks = mp.chunks(df, 100000)\n", + "pool = Pool(10)\n", + "pooled = pool.map(loop, chunks)\n", + "pool.close()\n", + "pool.join()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "fc1ab01e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['tokenized-2048/tokenized-0',\n", + " 'tokenized-2048/tokenized-1',\n", + " 'tokenized-2048/tokenized-2',\n", + " 'tokenized-2048/tokenized-3',\n", + " 'tokenized-2048/tokenized-4',\n", + " 'tokenized-2048/tokenized-5',\n", + " 'tokenized-2048/tokenized-6',\n", + " 'tokenized-2048/tokenized-7',\n", + " 'tokenized-2048/tokenized-8',\n", + " 'tokenized-2048/tokenized-9',\n", + " 'tokenized-2048/tokenized-10',\n", + " 'tokenized-2048/tokenized-11',\n", + " 'tokenized-2048/tokenized-12',\n", + " 'tokenized-2048/tokenized-13',\n", + " 'tokenized-2048/tokenized-14',\n", + " 'tokenized-2048/tokenized-15',\n", + " 'tokenized-2048/tokenized-16',\n", + " 'tokenized-2048/tokenized-17',\n", + " 'tokenized-2048/tokenized-18',\n", + " 'tokenized-2048/tokenized-19',\n", + " 'tokenized-2048/tokenized-20',\n", + " 'tokenized-2048/tokenized-21',\n", + " 'tokenized-2048/tokenized-22',\n", + " 'tokenized-2048/tokenized-23',\n", + " 'tokenized-2048/tokenized-24']" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "folders = sorted(glob('tokenized-2048/tokenized-*'), key = lambda x: int(x.split('-')[-1]))\n", + "folders" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "82a70251", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = LocalDataset(folders[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "bf0732f4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + } + ], + "source": [ + "!rm -rf smollm2-speech-semantic-multipack-2048" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "81e3e590", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████| 90078/90078 [00:04<00:00, 21220.53it/s]\n", + "100%|██████████████████████████████████| 91347/91347 [00:04<00:00, 22195.75it/s]\n", + "100%|██████████████████████████████████| 89677/89677 [00:04<00:00, 21200.47it/s]\n", + "100%|██████████████████████████████████| 91201/91201 [00:07<00:00, 11595.93it/s]\n", + "100%|██████████████████████████████████| 90920/90920 [00:04<00:00, 19928.08it/s]\n", + "100%|██████████████████████████████████| 90876/90876 [00:04<00:00, 19986.43it/s]\n", + "100%|██████████████████████████████████| 91056/91056 [00:05<00:00, 16781.18it/s]\n", + "100%|███████████████████████████████████| 90230/90230 [00:11<00:00, 7993.23it/s]\n", + "100%|██████████████████████████████████| 90289/90289 [00:04<00:00, 20804.48it/s]\n", + "100%|██████████████████████████████████| 90723/90723 [00:04<00:00, 20546.19it/s]\n", + "100%|██████████████████████████████████| 91015/91015 [00:07<00:00, 11646.02it/s]\n", + "100%|███████████████████████████████████| 91343/91343 [00:15<00:00, 5742.13it/s]\n", + "100%|██████████████████████████████████| 90520/90520 [00:08<00:00, 10631.77it/s]\n", + "100%|██████████████████████████████████| 91116/91116 [00:04<00:00, 20448.20it/s]\n", + "100%|███████████████████████████████████| 91678/91678 [00:10<00:00, 8426.00it/s]\n", + "100%|███████████████████████████████████| 90756/90756 [00:14<00:00, 6101.54it/s]\n", + "100%|██████████████████████████████████| 85933/85933 [00:04<00:00, 20449.16it/s]\n", + "100%|██████████████████████████████████| 85723/85723 [00:04<00:00, 19082.55it/s]\n", + "100%|███████████████████████████████████| 85783/85783 [00:16<00:00, 5179.58it/s]\n", + "100%|███████████████████████████████████| 85144/85144 [00:10<00:00, 8451.29it/s]\n", + "100%|██████████████████████████████████| 85795/85795 [00:04<00:00, 20305.28it/s]\n", + "100%|██████████████████████████████████| 86378/86378 [00:05<00:00, 15215.02it/s]\n", + "100%|███████████████████████████████████| 86291/86291 [00:10<00:00, 8541.34it/s]\n", + "100%|███████████████████████████████████| 85956/85956 [00:16<00:00, 5270.23it/s]\n", + "100%|██████████████████████████████████| 32631/32631 [00:01<00:00, 20870.85it/s]\n" + ] + } + ], + "source": [ + "with MDSWriter(\n", + " out='smollm2-speech-semantic-multipack-2048', columns=columns, compression=None, hashes=hashes) as out:\n", + " for f in folders:\n", + " try:\n", + " dataset = LocalDataset(local=f)\n", + " for i in tqdm(range(len(dataset))):\n", + " out.write(dataset[i])\n", + " except Exception as e:\n", + " print(e)\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "72e2a21f", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = LocalDataset('smollm2-speech-semantic-multipack-2048')" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "2d3a66aa", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4.449196032" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"(len(dataset) * 2048) / 1e9" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "5ee8783a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'attention_mask': array([1364, 684], dtype=uint32),\n", + " 'input_ids': array([49154, 51, 4075, ..., 49385, 49840, 50075], dtype=uint32),\n", + " 'position_ids': array([ 0, 1, 2, ..., 681, 682, 683], dtype=uint32)}" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "beef0216", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'attention_mask': array([2048], dtype=uint32),\n", + " 'input_ids': array([49440, 49427, 49595, ..., 49697, 49837, 49491], dtype=uint32),\n", + " 'position_ids': array([ 684, 685, 686, ..., 2729, 2730, 2731], dtype=uint32)}" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "3f5cc713", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'attention_mask': array([ 544, 1504], dtype=uint32),\n", + " 'input_ids': array([49579, 49576, 49509, ..., 49647, 49995, 49401], dtype=uint32),\n", + " 'position_ids': array([4780, 4781, 4782, ..., 1501, 1502, 1503], dtype=uint32)}" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[3]" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "cb1438d8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RepoUrl('https://huggingface.co/datasets/mesolitica/smollm2-speech-semantic-multipack-2048', endpoint='https://huggingface.co', repo_type='dataset', repo_id='mesolitica/smollm2-speech-semantic-multipack-2048')" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from huggingface_hub import create_repo, delete_repo\n", + "\n", + "try:\n", + " delete_repo(repo_id=\"mesolitica/smollm2-speech-semantic-multipack-2048\", repo_type=\"dataset\")\n", + "except:\n", + " pass\n", + "create_repo(\"mesolitica/smollm2-speech-semantic-multipack-2048\", repo_type=\"dataset\", private = True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2796a548", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dedb145454d546e7b35835fc6aa89d8f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "shard.00002.mds: 0%| | 0.00/67.1M [00:00' for t in speech_t.tolist()])\n", + "tts = f'<|text_start|>{speaker}: {t}<|text_end|><|speech_start|>{speech_t}<|speech_end|>'" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0d5b1c02", + "metadata": {}, + "outputs": [], + "source": [ + "tokens = tts.split('<|speech_start|>')[1].split('<|speech_end|>')[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c2d0eb8f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(8, 153)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import re\n", + "\n", + "numbers = [int(t) for t in re.findall(r'<\\|(\\d+)\\|>', tokens)]\n", + "np.array(numbers).reshape((-1, 8)).T.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d374b439", + "metadata": { + "scrolled": true + }, + "outputs": [], + 
"source": [ + "# model = load_vqgan(device = 'cuda')\n", + "# i = torch.tensor(np.array(numbers).reshape((-1, 8)).T[None])\n", + "# y_, _ = model.decode(i.cuda(), torch.tensor([i.shape[-1]]).cuda())\n", + "# ipd.Audio(y_.detach().cpu().numpy()[0, 0], rate = model.spec_transform.sample_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "20384f6b", + "metadata": {}, + "outputs": [], + "source": [ + "import gc\n", + "\n", + "def collator(batch, batch_position_ids):\n", + " input_ids = []\n", + " position_ids = []\n", + " masks = []\n", + " for i in range(len(batch)):\n", + " l = len(batch[i])\n", + " input_ids.extend(batch[i])\n", + " position_ids.extend(batch_position_ids[i])\n", + " masks.append(l)\n", + " \n", + " return {\n", + " 'input_ids': np.array(input_ids).astype(np.uint32),\n", + " 'position_ids': np.array(position_ids).astype(np.uint32),\n", + " 'attention_mask': np.array(masks).astype(np.uint32),\n", + " }\n", + "\n", + "def slice_and_balance(nested_list, size):\n", + " first = []\n", + " balance = []\n", + " current_size = 0\n", + "\n", + " for sublist in nested_list:\n", + " if current_size < size:\n", + " remaining_space = size - current_size\n", + " if len(sublist) <= remaining_space:\n", + " first.append(sublist)\n", + " current_size += len(sublist)\n", + " else:\n", + " first.append(sublist[:remaining_space])\n", + " balance.append(sublist[remaining_space:])\n", + " current_size = size\n", + " else:\n", + " balance.append(sublist)\n", + " \n", + " return first, balance" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "76757131", + "metadata": {}, + "outputs": [], + "source": [ + "!rm -rf tokenized-2048\n", + "!mkdir tokenized-2048" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "cfcb4fc4", + "metadata": {}, + "outputs": [], + "source": [ + "from streaming import MDSWriter\n", + "from streaming.base.format.mds.encodings import Encoding, _encodings\n", + "from streaming import LocalDataset\n", + "import streaming\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "from glob import glob\n", + "import os\n", + "import json\n", + "\n", + "class UInt32(Encoding):\n", + " def encode(self, obj) -> bytes:\n", + " return obj.tobytes()\n", + "\n", + " def decode(self, data: bytes):\n", + " return np.frombuffer(data, np.uint32)\n", + "\n", + "_encodings['uint32'] = UInt32\n", + "\n", + "columns = {\n", + " 'input_ids': 'uint32',\n", + " 'position_ids': 'uint32',\n", + " 'attention_mask': 'uint32',\n", + "}\n", + "hashes = 'sha1', 'xxh64'" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a4ba4a94", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "8192" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.model_max_length" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "3268e801", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "def loop(files, block_size = 2048):\n", + " rows, index = files\n", + " out_root = f'tokenized-2048/tokenized-{index}'\n", + " os.system(f'rm -rf {out_root}')\n", + " count = 0\n", + " temp = []\n", + " position_ids = []\n", + " last_block, last_position_block = None, None\n", + " with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:\n", + " for row in tqdm(rows):\n", + " \n", + " speaker = row['speaker']\n", + " t = row['transcription']\n", + " splitted = row['audio_filename'].split('/')\n", + " 
new_f = '/'.join([splitted[0] + '_vqgan'] + splitted[1:]).replace('.mp3', '.npy')\n", + " speech_t = np.load(new_f)\n", + " speech_t = ''.join([f'<|{t}|>' for t in speech_t.tolist()])\n", + " tts = f'<|text_start|>{speaker}: {t}<|text_end|><|speech_start|>{speech_t}<|speech_end|>'\n", + " \n", + " outputs = tokenizer(tts, add_special_tokens = False)\n", + " if len(outputs['input_ids']) >= tokenizer.model_max_length:\n", + " continue\n", + " temp.append(outputs['input_ids'])\n", + " position_ids.append(range(len(outputs['input_ids'])))\n", + " count += len(outputs['input_ids'])\n", + " \n", + " while count >= block_size:\n", + " block, temp = slice_and_balance(temp, block_size)\n", + " block_position, position_ids = slice_and_balance(position_ids, block_size)\n", + " count = count - block_size\n", + " o = collator(block, block_position)\n", + " last_block = block\n", + " last_position_block = block_position\n", + " out.write(o)\n", + " \n", + " block, _ = slice_and_balance(last_block, block_size - count)\n", + " block_position, _ = slice_and_balance(last_position_block, block_size - count)\n", + "\n", + " block.extend(temp)\n", + " block_position.extend(position_ids)\n", + "\n", + " o = collator(block, block_position)\n", + " if len(o['input_ids']) == block_size:\n", + " out.write(o)\n", + " return o" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "329ea472", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████| 1000/1000 [00:01<00:00, 993.29it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input_ids': array([49595, 49955, 49696, ..., 49666, 49890, 49153], dtype=uint32),\n", + " 'position_ids': array([ 314, 315, 316, ..., 1812, 1813, 1814], dtype=uint32),\n", + " 'attention_mask': array([ 138, 95, 1815], dtype=uint32)}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loop((df[:1000], 0))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f4f25fce", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "586" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset = LocalDataset('tokenized-2048/tokenized-0')\n", + "len(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "12df9abf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'attention_mask': array([1274, 774], dtype=uint32),\n", + " 'input_ids': array([49154, 11062, 1483, ..., 49651, 49408, 49282], dtype=uint32),\n", + " 'position_ids': array([ 0, 1, 2, ..., 771, 772, 773], dtype=uint32)}" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "d5349fa7", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'<|text_start|>Osman: Sedangkan dalam bahasa Perancis , frira hanya bererti menggoreng di dalam minyak goreng yang banyak hingga terendam 
.<|text_end|><|speech_start|><|361|><|704|><|26|><|639|><|759|><|587|><|669|><|533|><|530|><|752|><|18|><|479|><|599|><|348|><|708|><|535|><|768|><|712|><|227|><|639|><|679|><|348|><|302|><|327|><|529|><|478|><|495|><|479|><|989|><|739|><|268|><|646|><|328|><|15|><|770|><|545|><|733|><|178|><|846|><|534|><|522|><|7|><|785|><|738|><|453|><|539|><|219|><|508|><|351|><|59|><|465|><|386|><|455|><|448|><|354|><|447|><|755|><|694|><|663|><|788|><|674|><|540|><|590|><|805|><|264|><|65|><|544|><|312|><|427|><|215|><|159|><|447|><|351|><|62|><|308|><|388|><|346|><|226|><|62|><|286|><|948|><|680|><|622|><|478|><|345|><|16|><|909|><|447|><|936|><|902|><|751|><|852|><|315|><|823|><|470|><|965|><|503|><|269|><|810|><|512|><|789|><|29|><|518|><|560|><|751|><|21|><|107|><|548|><|580|><|467|><|77|><|760|><|949|><|530|><|629|><|916|><|104|><|264|><|751|><|247|><|785|><|421|><|339|><|464|><|237|><|470|><|538|><|646|><|142|><|101|><|458|><|116|><|52|><|284|><|91|><|447|><|567|><|413|><|449|><|25|><|281|><|535|><|190|><|887|><|654|><|541|><|19|><|542|><|91|><|321|><|19|><|880|><|860|><|971|><|623|><|302|><|138|><|400|><|529|><|720|><|690|><|324|><|496|><|554|><|169|><|573|><|97|><|568|><|420|><|135|><|451|><|108|><|496|><|933|><|313|><|760|><|383|><|15|><|332|><|65|><|731|><|549|><|347|><|731|><|89|><|47|><|795|><|722|><|589|><|89|><|619|><|982|><|120|><|207|><|779|><|948|><|334|><|371|><|710|><|702|><|743|><|59|><|547|><|354|><|535|><|683|><|305|><|520|><|322|><|536|><|28|><|279|><|549|><|881|><|555|><|562|><|661|><|231|><|294|><|619|><|225|><|628|><|643|><|283|><|256|><|513|><|498|><|417|><|931|><|495|><|275|><|760|><|464|><|269|><|59|><|384|><|377|><|510|><|274|><|560|><|546|><|59|><|52|><|580|><|347|><|64|><|64|><|360|><|345|><|980|><|911|><|559|><|370|><|120|><|304|><|764|><|360|><|87|><|346|><|587|><|703|><|379|><|678|><|535|><|723|><|284|><|149|><|489|><|431|><|554|><|707|><|640|><|785|><|680|><|295|><|563|><|535|><|498|><|180|><|925|><|540|><|427|><|295|><|324|><|614|><|465|><|88|><|607|><|337|><|103|><|531|><|356|><|493|><|392|><|211|><|575|><|528|><|247|><|786|><|555|><|342|><|188|><|437|><|743|><|345|><|207|><|387|><|509|><|125|><|390|><|251|><|733|><|761|><|98|><|97|><|95|><|550|><|68|><|499|><|366|><|770|><|792|><|227|><|639|><|719|><|540|><|710|><|486|><|577|><|752|><|18|><|639|><|719|><|539|><|469|><|334|><|728|><|752|><|26|><|639|><|511|><|347|><|509|><|495|><|769|><|504|><|227|><|639|><|719|><|548|><|502|><|527|><|585|><|517|><|455|><|759|><|861|><|484|><|118|><|924|><|336|><|31|><|382|><|737|><|461|><|525|><|643|><|533|><|414|><|230|><|332|><|482|><|25|><|571|><|444|><|204|><|111|><|109|><|228|><|497|><|454|><|968|><|637|><|247|><|225|><|961|><|620|><|569|><|29|><|304|><|396|><|287|><|744|><|817|><|334|><|409|><|622|><|257|><|749|><|607|><|666|><|101|><|357|><|825|><|437|><|810|><|259|><|245|><|272|><|224|><|578|><|274|><|828|><|254|><|581|><|447|><|301|><|209|><|513|><|227|><|627|><|223|><|95|><|687|><|578|><|17|><|581|><|61|><|289|><|11|><|102|><|847|><|722|><|590|><|333|><|375|><|547|><|209|><|337|><|366|><|563|><|752|><|58|><|703|><|358|><|147|><|950|><|734|><|336|><|303|><|438|><|259|><|146|><|903|><|507|><|725|><|687|><|95|><|491|><|457|><|472|><|651|><|382|><|486|><|503|><|102|><|266|><|528|><|659|><|68|><|309|><|447|><|460|><|346|><|93|><|740|><|227|><|352|><|369|><|206|><|667|><|973|><|510|><|337|><|614|><|488|><|344|><|607|><|775|><|12|><|493|><|450|><|548|><|445|><|380|><|647|><|485|><|281|><|212|><|431|><|546|><|736|><|746|><|568|><|493|><|22|><|129|><|425|><|741|><|842|><|489|><|
368|><|228|><|102|><|303|><|259|><|637|><|483|><|307|><|680|><|69|><|69|><|132|><|114|><|542|><|657|><|304|><|368|><|306|><|500|><|549|><|62|><|162|><|337|><|304|><|970|><|722|><|96|><|18|><|431|><|519|><|330|><|749|><|734|><|575|><|85|><|83|><|345|><|52|><|144|><|260|><|283|><|950|><|854|><|55|><|489|><|173|><|88|><|348|><|921|><|818|><|370|><|230|><|676|><|266|><|936|><|658|><|560|><|850|><|659|><|510|><|265|><|386|><|283|><|705|><|520|><|408|><|138|><|21|><|708|><|107|><|448|><|467|><|760|><|665|><|936|><|414|><|510|><|563|><|40|><|721|><|882|><|450|><|46|><|502|><|459|><|409|><|899|><|245|><|42|><|244|><|820|><|421|><|513|><|365|><|522|><|58|><|168|><|374|><|813|><|100|><|50|><|572|><|685|><|305|><|721|><|503|><|94|><|150|><|409|><|479|><|652|><|550|><|523|><|359|><|502|><|621|><|568|><|214|><|25|><|110|><|447|><|652|><|492|><|254|><|579|><|25|><|216|><|170|><|407|><|943|><|698|><|742|><|238|><|8|><|288|><|720|><|887|><|711|><|737|><|468|><|271|><|378|><|56|><|651|><|721|><|456|><|918|><|671|><|955|><|546|><|697|><|627|><|963|><|457|><|541|><|12|><|378|><|180|><|192|><|498|><|760|><|674|><|737|><|109|><|167|><|323|><|305|><|378|><|529|><|771|><|46|><|339|><|417|><|493|><|667|><|536|><|120|><|59|><|61|><|258|><|555|><|739|><|711|><|353|><|560|><|646|><|295|><|510|><|473|><|556|><|483|><|139|><|721|><|452|><|92|><|11|><|115|><|378|><|661|><|73|><|760|><|834|><|739|><|222|><|917|><|154|><|488|><|899|><|760|><|630|><|751|><|270|><|505|><|492|><|704|><|913|><|760|><|703|><|567|><|221|><|579|><|132|><|336|><|373|><|322|><|791|><|573|><|623|><|230|><|370|><|400|><|332|><|842|><|942|><|703|><|142|><|768|><|542|><|58|><|515|><|560|><|737|><|212|><|18|><|283|><|336|><|62|><|392|><|520|><|745|><|249|><|311|><|302|><|66|><|403|><|91|><|685|><|379|><|984|><|469|><|679|><|747|><|136|><|740|><|845|><|289|><|412|><|53|><|284|><|448|><|373|><|203|><|49|><|320|><|888|><|461|><|473|><|414|><|333|><|290|><|520|><|241|><|628|><|499|><|306|><|972|><|803|><|267|><|360|><|479|><|143|><|311|><|410|><|477|><|692|><|516|><|681|><|351|><|535|><|460|><|368|><|351|><|58|><|70|><|325|><|661|><|653|><|214|><|572|><|10|><|256|><|364|><|207|><|934|><|699|><|782|><|430|><|8|><|480|><|482|><|647|><|742|><|694|><|423|><|68|><|179|><|41|><|427|><|921|><|898|><|693|><|99|><|768|><|788|><|50|><|154|><|360|><|729|><|217|><|28|><|100|><|137|><|291|><|105|><|681|><|896|><|760|><|430|><|662|><|137|><|80|><|570|><|683|><|449|><|731|><|911|><|716|><|874|><|740|><|310|><|883|><|640|><|457|><|289|><|97|><|416|><|534|><|178|><|921|><|264|><|141|><|68|><|377|><|379|><|467|><|274|><|760|><|468|><|139|><|52|><|741|><|348|><|152|><|297|><|360|><|663|><|935|><|631|><|300|><|748|><|480|><|469|><|523|><|743|><|631|><|291|><|770|><|754|><|53|><|399|><|287|><|773|><|285|><|99|><|60|><|376|><|45|><|306|><|724|><|513|><|59|><|270|><|463|><|249|><|260|><|98|><|444|><|329|><|739|><|95|><|575|><|572|><|498|><|346|><|164|><|329|><|727|><|343|><|542|><|126|><|716|><|492|><|684|><|724|><|326|><|54|><|417|><|147|><|456|><|297|><|760|><|863|><|983|><|219|><|509|><|147|><|248|><|388|><|721|><|543|><|748|><|614|><|438|><|178|><|208|><|692|><|882|><|660|><|703|><|743|><|562|><|499|><|490|><|611|><|522|><|143|><|271|><|490|><|489|><|783|><|307|><|559|><|526|><|532|><|339|><|499|><|95|><|308|><|539|><|107|><|733|><|507|><|223|><|53|><|738|><|427|><|636|><|654|><|244|><|277|><|109|><|292|><|347|><|420|><|523|><|105|><|645|><|672|><|313|><|332|><|544|><|229|><|460|><|146|><|487|><|537|><|137|><|493|><|507|><|273|><|96|><|501|><|207|><|742|><|771|><|702|><|420|
><|105|><|80|><|441|><|967|><|738|><|222|><|388|><|504|><|550|><|629|><|746|><|321|><|568|><|289|><|338|><|336|><|488|><|295|><|195|><|685|><|533|><|132|><|100|><|548|><|656|><|645|><|129|><|886|><|543|><|570|><|109|><|503|><|201|><|681|><|203|><|445|><|751|><|738|><|503|><|455|><|51|><|600|><|843|><|884|><|334|><|697|><|446|><|71|><|58|><|696|><|442|><|683|><|170|><|521|><|49|><|95|><|548|><|538|><|850|><|165|><|769|><|105|><|28|><|639|><|719|><|300|><|462|><|335|><|730|><|752|><|18|><|679|><|599|><|301|><|469|><|735|><|736|><|952|><|226|><|839|><|519|><|396|><|501|><|534|><|736|><|704|><|18|><|639|><|719|><|347|><|511|><|727|><|537|><|712|><|219|><|639|><|559|><|346|><|461|><|527|><|776|><|752|><|18|><|639|><|519|><|347|><|510|><|527|><|777|><|712|><|219|><|639|><|559|><|346|><|502|><|527|><|736|><|712|><|18|><|639|><|519|><|387|><|502|><|727|><|728|><|752|><|18|><|639|><|559|><|339|><|509|><|527|><|736|><|712|><|226|><|639|><|519|><|347|><|501|><|526|><|568|><|552|><|18|><|639|><|719|><|387|><|502|><|527|><|737|><|712|><|18|><|639|><|559|><|346|><|469|><|527|><|736|><|712|><|218|><|639|><|519|><|347|><|501|><|527|><|777|><|712|><|218|><|639|><|519|><|387|><|501|><|486|><|769|><|512|><|18|><|639|><|719|><|298|><|309|><|327|><|768|><|712|><|227|><|639|><|679|><|548|><|302|><|326|><|562|><|512|><|60|><|639|><|717|><|547|><|542|><|686|><|speech_end|><|text_start|>Osman: dikenali dengan timangan tomcat , ialah satu genus kumbang kecil dalam famili Staphylinidae ( \" kumbang rayau \" ) .<|text_end|><|speech_start|><|561|><|704|><|226|><|639|><|711|><|587|><|670|><|533|><|731|><|712|><|12|><|639|><|359|><|348|><|708|><|735|><|971|><|473|><|51|><|471|><|519|><|501|><|660|><|175|><|362|><|433|><|52|><|431|><|476|><|699|><|307|><|524|><|585|><|789|><|911|><|553|><|769|><|170|><|314|><|881|><|428|><|500|><|418|><|729|><|30|><|664|><|345|><|7|><|538|><|722|><|333|><|167|><|573|><|489|><|570|><|934|><|452|><|302|><|20|><|569|><|419|><|460|><|615|><|204|><|345|><|264|><|251|><|338|><|54|><|112|><|193|><|247|><|858|><|340|><|958|><|537|><|635|><|281|><|105|><|247|><|417|><|318|><|89|><|546|><|434|><|302|><|584|><|205|><|257|><|258|><|307|><|395|><|418|><|420|><|61|><|287|><|903|><|695|><|270|><|298|><|269|><|256|><|711|><|723|><|743|><|942|><|94|><|731|><|348|><|96|><|750|><|760|><|548|><|529|><|340|><|87|><|122|><|89|><|386|><|814|><|975|><|101|><|459|><|528|><|183|><|266|><|508|><|486|><|663|><|647|><|423|><|494|><|509|><|8|><|261|><|641|><|710|><|703|><|297|><|569|><|549|><|10|><|153|><|360|><|508|><|100|><|11|><|549|><|467|><|288|><|105|><|520|><|562|><|336|><|250|><|479|><|918|><|129|><|299|><|365|><|244|><|63|><|331|><|562|><|13|><|512|><|285|><|283|><|902|><|901|><|94|><|497|><|133|><|568|><|393|><|968|><|611|><|580|><|437|><|637|><|146|><|528|><|505|><|520|><|851|><|830|><|498|><|753|><|176|><|685|><|577|><|361|><|742|><|23|><|99|><|139|><|106|><|411|><|104|><|760|><|679|><|342|><|238|><|469|><|269|><|96|><|530|><|481|><|910|><|295|><|301|><|576|><|549|><|19|><|153|><|360|><|743|><|23|><|139|><|337|><|459|><|419|><|72|><|320|><|676|><|492|><|238|><|645|><|349|><|56|><|394|><|440|><|729|><|776|><|659|><|439|><|509|><|256|><|490|><|565|><|522|><|792|><|27|><|639|><|518|><|180|><|918|><|726|><|465|><|790|><|703|><|556|><|707|><|573|><|317|><|884|><|965|><|244|><|499|><|464|><|131|><|723|><|368|><|568|><|445|><|582|><|228|><|499|><|178|><|530|><|50|><|360|><|871|><|702|><|421|><|717|><|114|><|448|><|571|><|722|><|553|><|380|><|783|><|591|><|539|><|48|><|65|><|764|><|80|><|662|><|186|><|537|><|551|>
<|780|><|461|><|734|><|575|><|207|><|580|><|624|><|740|><|709|><|748|><|284|><|332|><|531|><|188|><|491|><|570|><|215|><|593|><|968|><|795|><|291|><|54|><|329|><|425|><|222|><|58|><|647|><|350|><|299|><|101|><|303|><|252|><|240|><|99|><|324|><|522|><|792|><|309|><|639|><|479|><|529|><|516|><|524|><|322|><|592|><|27|><|639|><|399|><|309|><|717|><|565|><|570|><|496|><|19|><|279|><|798|><|348|><|668|><|533|><|736|><|712|><|10|><|639|><|479|><|388|><|503|><|535|><|722|><|713|><|12|><|439|><|911|><|300|><|269|><|535|><|762|><|744|><|11|><|271|><|519|><|492|><|475|><|334|><|461|><|342|><|215|><|540|><|626|><|744|><|349|><|840|><|622|><|303|><|91|><|529|><|68|><|584|><|339|><|360|><|551|><|463|><|689|><|497|><|983|><|90|><|350|><|525|><|302|><|110|><|499|><|157|><|52|><|523|><|21|><|693|><|167|><|222|><|649|><|27|><|561|><|142|><|117|><|932|><|340|><|103|><|340|><|307|><|376|><|26|><|50|><|922|><|248|><|7|><|986|><|756|><|301|><|139|><|830|><|550|><|412|><|47|><|572|><|627|><|412|><|580|><|502|><|410|><|335|><|62|><|290|><|116|><|741|><|605|><|153|><|720|><|539|><|346|><|330|><|103|><|339|><|489|><|146|><|930|><|866|><|822|><|359|><|692|><|715|><|668|><|662|><|885|><|543|><|428|><|21|><|106|><|132|><|321|><|465|><|768|><|98|><|448|><|69|><|351|><|708|><|256|><|795|><|689|><|449|><|495|><|309|><|537|><|307|><|674|><|411|><|323|><|250|><|339|><|452|><|89|><|783|><|546|><|553|><|360|><|829|><|174|><|230|><|557|><|228|><|496|><|355|><|440|><|551|><|103|><|212|><|169|><|558|><|100|><|109|><|644|><|381|><|63|><|492|><|115|><|572|><|0|><|211|><|766|><|128|><|206|><|786|><|956|><|350|><|377|><|470|><|670|><|593|><|406|><|588|><|542|><|131|><|580|><|489|><|254|><|921|><|532|><|733|><|671|><|275|><|367|><|775|><|483|><|925|><|284|><|31|><|634|><|299|><|530|><|298|><|320|><|609|><|540|><|29|><|717|><|579|><|96|><|146|><|720|><|740|><|684|><|230|><|518|><|194|><|120|><|145|><|920|><|883|><|894|><|740|><|672|><|505|><|742|><|539|><|360|><|454|><|223|><|219|><|139|><|137|><|733|><|97|><|760|><|717|><|94|><|20|><|518|><|316|><|96|><|298|><|520|><|362|><|745|><|451|><|471|><|719|><|240|><|259|><|565|><|461|><|109|><|84|><|362|><|220|><|538|><|454|><|243|><|386|><|482|><|220|><|381|><|366|><|457|><|153|><|369|><|464|><|823|><|311|><|310|><|180|><|757|><|901|><|927|><|768|><|247|><|539|><|530|><|311|><|307|><|350|><|525|><|983|><|486|><|102|><|456|><|583|><|505|><|746|><|642|><|559|><|230|><|811|><|307|><|372|><|471|><|189|><|521|><|527|><|446|><|667|><|36|><|370|><|534|><|134|><|724|><|567|><|653|><|469|><|37|><|365|><|921|><|252|><|571|><|529|><|697|><|68|><|311|><|267|><|568|><|155|><|489|><|426|><|413|><|750|><|424|><|668|><|650|><|301|><|722|><|248|><|9|><|330|><|76|><|627|><|663|><|370|><|930|><|640|><|139|><|470|><|274|><|276|><|484|><|139|><|720|><|720|><|264|><|250|><|107|><|48|><|502|><|137|><|921|><|696|><|331|><|23|><|917|><|106|><|80|><|298|><|441|><|713|><|862|><|591|><|659|><|700|><|602|><|619|><|567|><|537|><|31|><|740|><|742|><|89|><|495|><|252|><|126|>'" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode(dataset[0]['input_ids'])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "4a04f32d", + "metadata": {}, + "outputs": [], + "source": [ + "# !wget https://gist.githubusercontent.com/huseinzol05/98974ae8c6c7a65d4bc0af9f5003786a/raw/2e06e71ef7349a57bc58cc9913ae6bae1f9f8447/mp.py" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "ba844846", + "metadata": {}, + "outputs": [ + { + 
"name": "stderr", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + " 4%|█▎ | 702/20000 [00:00<00:19, 1002.66it/s]TOKENIZERS_PARALLELISM=(true | false)\n", + " 5%|█▋ | 918/20000 [00:01<00:32, 580.05it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (9392 > 8192). Running this sequence through the model will result in indexing errors\n", + " 22%|████████▎ | 4478/20000 [00:07<00:35, 434.44it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (11680 > 8192). Running this sequence through the model will result in indexing errors\n", + " 37%|█████████████▊ | 7493/20000 [00:13<00:31, 398.08it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (8237 > 8192). Running this sequence through the model will result in indexing errors\n", + " 27%|██████████ | 5445/20000 [00:12<00:37, 393.23it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (8402 > 8192). Running this sequence through the model will result in indexing errors\n", + " 67%|████████████████████████▏ | 13459/20000 [00:38<00:19, 333.22it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (8371 > 8192). Running this sequence through the model will result in indexing errors\n", + "100%|████████████████████████████████████| 20000/20000 [00:44<00:00, 450.32it/s]\n", + " 76%|███████████████████████████▎ | 15197/20000 [00:41<00:13, 359.76it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (12335 > 8192). Running this sequence through the model will result in indexing errors\n", + "100%|████████████████████████████████████| 20000/20000 [00:44<00:00, 446.79it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [00:53<00:00, 372.76it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [00:52<00:00, 379.04it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [00:55<00:00, 359.49it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [01:00<00:00, 330.25it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [00:58<00:00, 339.56it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [01:02<00:00, 320.46it/s]\n", + " 1%|▌ | 265/20000 [00:00<00:55, 355.81it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (8913 > 8192). 
Running this sequence through the model will result in indexing errors\n", + "100%|████████████████████████████████████| 20000/20000 [01:10<00:00, 284.68it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [01:09<00:00, 286.66it/s]\n", + "100%|████████████████████████████████████████| 298/298 [00:00<00:00, 376.82it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [00:42<00:00, 476.07it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [00:39<00:00, 507.85it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [00:50<00:00, 396.09it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [00:38<00:00, 513.05it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [00:34<00:00, 574.28it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [00:38<00:00, 518.74it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [00:35<00:00, 563.49it/s]\n", + "100%|████████████████████████████████████| 20000/20000 [00:37<00:00, 536.86it/s]\n" + ] + } + ], + "source": [ + "from multiprocess import Pool\n", + "import mp\n", + "\n", + "chunks = mp.chunks(df, 20000)\n", + "pool = Pool(10)\n", + "pooled = pool.map(loop, chunks)\n", + "pool.close()\n", + "pool.join()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "3f390f43", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['tokenized-2048/tokenized-0',\n", + " 'tokenized-2048/tokenized-1',\n", + " 'tokenized-2048/tokenized-2',\n", + " 'tokenized-2048/tokenized-3',\n", + " 'tokenized-2048/tokenized-4',\n", + " 'tokenized-2048/tokenized-5',\n", + " 'tokenized-2048/tokenized-6',\n", + " 'tokenized-2048/tokenized-7',\n", + " 'tokenized-2048/tokenized-8',\n", + " 'tokenized-2048/tokenized-9',\n", + " 'tokenized-2048/tokenized-10',\n", + " 'tokenized-2048/tokenized-11',\n", + " 'tokenized-2048/tokenized-12',\n", + " 'tokenized-2048/tokenized-13',\n", + " 'tokenized-2048/tokenized-14',\n", + " 'tokenized-2048/tokenized-15',\n", + " 'tokenized-2048/tokenized-16',\n", + " 'tokenized-2048/tokenized-17',\n", + " 'tokenized-2048/tokenized-18']" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "folders = sorted(glob('tokenized-2048/tokenized-*'), key = lambda x: int(x.split('-')[-1]))\n", + "folders" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "11f09fd6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████| 11692/11692 [00:00<00:00, 28964.85it/s]\n", + "100%|██████████████████████████████████| 11617/11617 [00:00<00:00, 23850.61it/s]\n", + "100%|██████████████████████████████████| 12904/12904 [00:00<00:00, 25673.80it/s]\n", + "100%|██████████████████████████████████| 13956/13956 [00:00<00:00, 22783.10it/s]\n", + "100%|██████████████████████████████████| 13866/13866 [00:00<00:00, 26494.07it/s]\n", + "100%|██████████████████████████████████| 12705/12705 [00:00<00:00, 25609.02it/s]\n", + "100%|██████████████████████████████████| 12289/12289 [00:00<00:00, 25136.33it/s]\n", + "100%|██████████████████████████████████| 12951/12951 [00:00<00:00, 26063.83it/s]\n", + "100%|██████████████████████████████████| 14811/14811 [00:00<00:00, 22708.50it/s]\n", + "100%|██████████████████████████████████| 14884/14884 [00:00<00:00, 22629.24it/s]\n", + "100%|██████████████████████████████████| 13082/13082 [00:00<00:00, 25237.33it/s]\n", + "100%|██████████████████████████████████| 
10739/10739 [00:00<00:00, 22355.06it/s]\n", + "100%|██████████████████████████████████| 10992/10992 [00:00<00:00, 28660.18it/s]\n", + "100%|██████████████████████████████████| 11082/11082 [00:00<00:00, 22466.82it/s]\n", + "100%|██████████████████████████████████| 11404/11404 [00:00<00:00, 22900.30it/s]\n", + "100%|██████████████████████████████████| 10577/10577 [00:00<00:00, 28104.88it/s]\n", + "100%|██████████████████████████████████| 11486/11486 [00:00<00:00, 23251.71it/s]\n", + "100%|██████████████████████████████████| 12661/12661 [00:00<00:00, 24608.13it/s]\n", + "100%|██████████████████████████████████████| 183/183 [00:00<00:00, 53236.07it/s]\n" + ] + } + ], + "source": [ + "with MDSWriter(\n", + " out='smollm2-speech-semantic-multipack-2048', columns=columns, compression=None, hashes=hashes) as out:\n", + " for f in folders:\n", + " try:\n", + " dataset = LocalDataset(local=f)\n", + " for i in tqdm(range(len(dataset))):\n", + " out.write(dataset[i])\n", + " except Exception as e:\n", + " print(e)\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "3f55b311", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = LocalDataset('smollm2-speech-semantic-multipack-2048')" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "51d453a6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.458508288" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(len(dataset) * 2048) / 1e9" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3.10", + "language": "python", + "name": "python3.10" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/session/smollm2-speech-semantics/tts-multipack.py b/session/smollm2-speech-semantics/tts-multipack.py new file mode 100644 index 0000000..62a60d2 --- /dev/null +++ b/session/smollm2-speech-semantics/tts-multipack.py @@ -0,0 +1,443 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling +# task. Pointers for this are left as comments. 
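+#
+# This adaptation trains on pre-packed ("multipack") speech-semantic token
+# sequences stored as a MosaicML streaming (MDS) dataset: each sample carries
+# several concatenated sequences plus their segment lengths, and a
+# block-diagonal causal attention mask is built per sample so tokens cannot
+# attend across segment boundaries (see block_diagonal_concat_inverted and
+# DatasetFixed below).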
+
+import logging
+import math
+import os
+import sys
+import warnings
+from dataclasses import dataclass, field
+from itertools import chain
+from typing import Optional
+
+import datasets
+import evaluate
+import torch
+
+torch._dynamo.config.optimize_ddp = False
+
+from datasets import load_dataset
+
+import transformers
+import random
+from transformers import (
+    CONFIG_MAPPING,
+    MODEL_FOR_CAUSAL_LM_MAPPING,
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    HfArgumentParser,
+    Trainer,
+    TrainingArguments,
+    default_data_collator,
+    DataCollatorWithPadding,
+    DataCollatorForLanguageModeling,
+    is_torch_tpu_available,
+    set_seed,
+)
+from transformers.testing_utils import CaptureLogger
+from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import check_min_version, send_example_telemetry
+from transformers.utils.versions import require_version
+from transformers import AddedToken
+import streaming
+import json
+import numpy as np
+from streaming import LocalDataset
+from streaming.base.format.mds.encodings import Encoding, _encodings
+from peft import LoraConfig, get_peft_model
+
+require_version(
+    "datasets>=1.8.0",
+    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
+
+logger = logging.getLogger(__name__)
+
+
+MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+
+    rank: int = field(
+        default=256,
+        metadata={
+            "help": "lora rank"
+        },
+    )
+
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
+            )
+        },
+    )
+    model_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "If training from scratch, pass a model type from the list: " +
+            ", ".join(MODEL_TYPES)},
+    )
+    config_overrides: Optional[str] = field(
+        default=None, metadata={
+            "help": (
+                "Override some existing default config settings when a model is trained from scratch. Example: "
+                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index")}, )
+    config_name: Optional[str] = field(
+        default=None, metadata={
+            "help": "Pretrained config name or path if not the same as model_name"})
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={
+            "help": "Pretrained tokenizer name or path if not the same as model_name"})
+    cache_dir: Optional[str] = field(
+        default=None, metadata={
+            "help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, )
+    use_fast_tokenizer: bool = field(
+        default=True, metadata={
+            "help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."}, )
+    model_revision: str = field(
+        default="main", metadata={
+            "help": "The specific model version to use (can be a branch name, tag name or commit id)."}, )
+    token: str = field(
+        default=None,
+        metadata={
+            "help": (
+                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
+                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
+            )
+        },
+    )
+    use_auth_token: bool = field(
+        default=None,
+        metadata={
+            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
+        },
+    )
+    trust_remote_code: bool = field(
+        default=False, metadata={
+            "help": (
+                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
+                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
+                "execute code present on the Hub on your local machine.")}, )
+    torch_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
+                "dtype will be automatically derived from the model's weights."),
+            "choices": [
+                "auto",
+                "bfloat16",
+                "float16",
+                "float32"],
+        },
+    )
+    low_cpu_mem_usage: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
+                "Setting it to True will benefit LLM loading time and RAM consumption."
+            )
+        },
+    )
+
+    def __post_init__(self):
+        if self.config_overrides is not None and (
+                self.config_name is not None or self.model_name_or_path is not None):
+            raise ValueError(
+                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
+            )
+
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+    """
+
+    dataset_name: Optional[str] = field(
+        default=None, metadata={
+            "help": "The name of the dataset to use (via the datasets library)."})
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={
+            "help": "The configuration name of the dataset to use (via the datasets library)."})
+    train_file: Optional[str] = field(
+        default=None, metadata={
+            "help": "The input training data file (a text file)."})
+    validation_file: Optional[str] = field(
+        default=None, metadata={
+            "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, )
+    max_train_samples: Optional[int] = field(
+        default=None, metadata={
+            "help": (
+                "For debugging purposes or quicker training, truncate the number of training examples to this "
+                "value if set.")}, )
+    max_eval_samples: Optional[int] = field(
+        default=None, metadata={
+            "help": (
+                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+                "value if set.")}, )
+    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
+    block_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Optional input sequence length after tokenization. "
+                "The training dataset will be truncated in block of this size for training. "
+                "Default to the model max input length for single sentence inputs (take into account special tokens)."
+            )
+        },
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
+    keep_linebreaks: bool = field(
+        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
+    )
+
+def print_trainable_parameters(model):
+    """
+    Prints the number of trainable parameters in the model.
+    """
+    trainable_params = 0
+    all_param = 0
+    for _, param in model.named_parameters():
+        all_param += param.numel()
+        if param.requires_grad:
+            trainable_params += param.numel()
+    print(
+        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
+    )
+
+def block_diagonal_concat_inverted(*masks, dtype=torch.bfloat16):
+    # Combine per-segment attention masks into one block-diagonal additive mask:
+    # 0 where attention is allowed, dtype-min where it is blocked, so packed
+    # segments cannot attend to each other.
+    total_size = sum(mask.size(0) for mask in masks)
+    combined_mask = torch.zeros(total_size, total_size, dtype=dtype)
+
+    current_pos = 0
+
+    for mask in masks:
+        size = mask.size(0)
+        combined_mask[current_pos:current_pos + size, current_pos:current_pos + size] = mask
+        current_pos += size
+
+    min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
+    inverted_mask = torch.where(combined_mask == 1, torch.tensor(0, dtype=dtype), min_value)
+    return inverted_mask.unsqueeze(0)
+
+def main():
+    # See all possible arguments in src/transformers/training_args.py
+    # or by passing the --help flag to this script.
+    # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    if model_args.use_auth_token is not None:
+        warnings.warn(
+            "The `use_auth_token` argument is deprecated and will be removed in v4.34.",
+            FutureWarning)
+        if model_args.token is not None:
+            raise ValueError(
+                "`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
+        model_args.token = model_args.use_auth_token
+
+    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+    # information sent is the one passed as arguments along with your Python/PyTorch versions.
+    send_example_telemetry("run_clm", model_args, data_args)
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+
+    if training_args.should_log:
+        # The default of training_args.log_level is passive, so we set log level
+        # at info here to have that default.
+        transformers.utils.logging.set_verbosity_info()
+
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+
+    print(model_args)
+
+    # Log on each process the small summary:
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}")
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    # Detecting last checkpoint.
+    last_checkpoint = None
+    if os.path.isdir(
+            training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+
+    # Set seed before initializing model.
+    set_seed(training_args.seed)
+
+    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+    # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+    # Load pretrained model and tokenizer
+    #
+    # Distributed training:
+    # The .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
+
+    config_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "token": model_args.token,
+        "trust_remote_code": model_args.trust_remote_code,
+    }
+
+    if model_args.config_name:
+        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
+    elif model_args.model_name_or_path:
+        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+    else:
+        config = CONFIG_MAPPING[model_args.model_type]()
+        logger.warning("You are instantiating a new config instance from scratch.")
+        if model_args.config_overrides is not None:
+            logger.info(f"Overriding config: {model_args.config_overrides}")
+            config.update_from_string(model_args.config_overrides)
+            logger.info(f"New config: {config}")
+
+    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
+
+    # Register a custom MDS encoding so uint32 arrays round-trip through streaming.
+    class UInt32(Encoding):
+        def encode(self, obj) -> bytes:
+            return obj.tobytes()
+
+        def decode(self, data: bytes):
+            return np.frombuffer(data, np.uint32)
+
+    _encodings['uint32'] = UInt32
+
+    class DatasetFixed(torch.utils.data.Dataset):
+        def __init__(self, local):
+            self.dataset = LocalDataset(local=local)
+
+        def __getitem__(self, idx):
+            data = self.dataset[idx]
+            data['labels'] = data["input_ids"].copy()
+            # 'attention_mask' stores the packed segment lengths, not a real mask.
+            masking = data.pop('attention_mask')
+
+            data.pop('token_type_ids', None)
+            for k in data.keys():
+                data[k] = data[k].astype(np.int64)
+
+            # Build a causal mask per segment, then combine them into one
+            # block-diagonal additive mask for the whole packed sample.
+            masks = []
+            for m in masking:
+                masks.append(torch.tril(torch.ones(m, m)))
+            attention_mask = block_diagonal_concat_inverted(*masks)
+            data['attention_mask'] = attention_mask
+
+            return data
+
+        def __len__(self):
+            return len(self.dataset)
+
+    dataset = DatasetFixed(data_args.train_file)
+    print(len(dataset), dataset[0]['attention_mask'].shape)
+
+    torch_dtype = (
+        model_args.torch_dtype
+        if model_args.torch_dtype in ["auto", None]
+        else getattr(torch, model_args.torch_dtype)
+    )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        from_tf=bool(".ckpt" in model_args.model_name_or_path),
+        config=config,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
+        torch_dtype=torch_dtype,
+        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
+        attn_implementation='sdpa',
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset,
+        eval_dataset=None,
+        tokenizer=tokenizer,
+        data_collator=default_data_collator,
+        compute_metrics=None,
+        preprocess_logits_for_metrics=None,
+    )
+
+    # Training
+    if training_args.do_train:
+        checkpoint = None
+        if training_args.resume_from_checkpoint is not None:
+            checkpoint = training_args.resume_from_checkpoint
+        elif last_checkpoint is not None:
+            checkpoint = last_checkpoint
+        trainer.train(resume_from_checkpoint=checkpoint)
+        trainer.save_model()
+        trainer.save_state()
+
+
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
+if __name__ == "__main__":
+    main()
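
A minimal standalone sketch of what `block_diagonal_concat_inverted` produces; the function body is copied from the patch above, and the segment lengths (2 and 3) are illustrative only:

import torch

def block_diagonal_concat_inverted(*masks, dtype=torch.bfloat16):
    # Stack per-segment masks along the diagonal of one big square mask.
    total_size = sum(mask.size(0) for mask in masks)
    combined_mask = torch.zeros(total_size, total_size, dtype=dtype)
    current_pos = 0
    for mask in masks:
        size = mask.size(0)
        combined_mask[current_pos:current_pos + size, current_pos:current_pos + size] = mask
        current_pos += size
    # Convert to an additive mask: 0 where attention is allowed,
    # dtype-min where it is blocked.
    min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
    inverted_mask = torch.where(combined_mask == 1, torch.tensor(0, dtype=dtype), min_value)
    return inverted_mask.unsqueeze(0)

# Two packed sequences of lengths 2 and 3 give a 5x5 mask where each token
# attends causally within its own segment and never across segments.
masks = [torch.tril(torch.ones(m, m)) for m in (2, 3)]
mask = block_diagonal_concat_inverted(*masks, dtype=torch.float32)
print(mask.shape)            # torch.Size([1, 5, 5])
print((mask == 0).int()[0])  # 1 where attention is allowed:
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)

This per-sample mask is what `DatasetFixed` hands to the model in place of the usual 1/0 padding mask, which is why the script relies on `default_data_collator` rather than a padding collator.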