From 2a1bcc56e358347bf3cff4d1e3548d3b1f80c269 Mon Sep 17 00:00:00 2001
From: Patrick Li
Date: Fri, 20 Sep 2024 01:21:18 -0700
Subject: [PATCH] Add interleave config

---
 ultravox/data/dataset_config.py               |  25 +-
 ultravox/data/datasets.py                     |  44 +-
 ultravox/data/datasets_test.py                |  21 +-
 ultravox/training/config_base.py              |  18 +-
 ultravox/training/configs/meta_config.yaml    |   2 -
 ultravox/training/configs/release_config.yaml | 469 +++++++++---------
 ultravox/training/train.py                    |  31 +-
 7 files changed, 319 insertions(+), 291 deletions(-)

diff --git a/ultravox/data/dataset_config.py b/ultravox/data/dataset_config.py
index 38c2abd1..e1491a5a 100644
--- a/ultravox/data/dataset_config.py
+++ b/ultravox/data/dataset_config.py
@@ -1,23 +1,40 @@
 import dataclasses
+from enum import Enum
 from typing import List, Optional

 from pydantic import BaseModel


+class StopStrategy(str, Enum):
+    FIRST_EXHAUSTED = "first_exhausted"
+    LAST_EXHAUSTED = "last_exhausted"
+    NEVER_STOP = "never_stop"
+
+
 class DataDictConfig(BaseModel):
     # Path to the dataset, or huggingface dataset id
     path: str
     # Name of the dataset, or huggingface dataset config/subset
     name: Optional[str] = None
     splits: List[str] = dataclasses.field(default_factory=list)
     num_samples: Optional[int] = None
     total_samples: int = 1
-    multiplier: float = 1.0
     streaming: bool = True
     user_template: str = "<|audio|>"
     assistant_template: str = "{{text}}"
     transcript_template: str = "{{text}}"

     def __post_init__(self):
         if not self.splits:
             raise ValueError("At least one split must be provided")
+
+
+class DatasetMultiplier(BaseModel):
+    dataset: DataDictConfig
+    multiplier: float
+
+
+class InterleaveDataConfig(BaseModel):
+    # When to stop interleaving: first_exhausted, last_exhausted (default), or never_stop
+    stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED
+    datasets_with_multiplier: List[DatasetMultiplier]
diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py
index 105a574e..afdcee1e 100644
--- a/ultravox/data/datasets.py
+++ b/ultravox/data/datasets.py
@@ -8,7 +8,6 @@
 import os
 import tempfile
 import warnings
-from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Sequence

 import datasets
@@ -296,17 +295,12 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
         self._args = args
         self._session: Optional[requests.Session] = None
         self._rng = np.random.default_rng(self._args.shuffle_seed)
-        self._multiplier = 1.0

     def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> None:
         self._dataset = dataset
         # Only required when using epochs when training dataset.
         self._estimated_length = estimated_length

-    @property
-    def multiplier(self) -> float:
-        return self._multiplier
-
     def _load_audio_dataset(
         self,
         path: str,
@@ -1049,8 +1043,6 @@ def __init__(
         if self._args.shuffle:
             dataset = dataset.shuffle(seed=self._args.shuffle_seed)

-        self._multiplier = config.multiplier
-
         self.user_template = config.user_template
         self.assistant_template = config.assistant_template
         self.transcript_template = config.transcript_template
@@ -1089,7 +1081,9 @@ def _get_sample(self, row) -> VoiceSample:
     )


-def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset:
+def create_dataset(
+    name: str | dataset_config.DataDictConfig, args: VoiceDatasetArgs
+) -> SizedIterableDataset:
     DATASET_MAP: Dict[str, Any] = {
         "anyinstruct": AnyInstructAnswerDataset,
         "anyinstruct_in": AnyInstructInputDataset,
@@ -1115,21 +1109,16 @@ def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset:
     return DATASET_MAP[name](args, *ext)


-class StopStrategy(str, Enum):
-    FIRST_EXHAUSTED = "first_exhausted"
-    LAST_EXHAUSTED = "last_exhausted"
-    NEVER_STOP = "never_stop"
-
-
 class InterleaveDataset(SizedIterableDataset):
     """Interleaves multiple IterableDataset objects based on multiplier."""

     def __init__(
         self,
         datasets: Sequence[SizedIterableDataset],
-        stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED,
+        stop_strategy: dataset_config.StopStrategy = dataset_config.StopStrategy.LAST_EXHAUSTED,
         seed: Optional[int] = 42,
         static: bool = False,
+        multipliers: Optional[List[float]] = None,
     ) -> None:
         """
         Args:
@@ -1143,10 +1132,15 @@ def __init__(
         self._static = static
         self._stop_strategy = stop_strategy

+        if multipliers is None:
+            self._multipliers = [1.0] * len(datasets)
+        else:
+            self._multipliers = multipliers
+
         relative_frequencies = [
-            int(getattr(ds, "multiplier", 1.0) * float(len(ds))) for ds in datasets
+            int(float(len(ds)) * multiple)
+            for ds, multiple in zip(datasets, self._multipliers)
         ]
-
         total_frequency = sum(relative_frequencies)
         self._normalized_probs = [f / total_frequency for f in relative_frequencies]

@@ -1173,9 +1167,13 @@ def __iter__(self):
                 exhausted[iter_index] = True

                 # Check if stopping condition is met
-                if self._stop_strategy == StopStrategy.FIRST_EXHAUSTED or (
-                    self._stop_strategy == StopStrategy.LAST_EXHAUSTED
-                    and all(exhausted)
+                if (
+                    self._stop_strategy == dataset_config.StopStrategy.FIRST_EXHAUSTED
+                    or (
+                        self._stop_strategy
+                        == dataset_config.StopStrategy.LAST_EXHAUSTED
+                        and all(exhausted)
+                    )
                 ):
                     break

@@ -1186,8 +1184,8 @@
     def __len__(self) -> int:
         # TODO: Implement the length method for different stop strategies
         return sum(
-            int(getattr(ds, "multiplier", 1.0) * float(len(ds)))
-            for ds in self._datasets
+            int(float(len(ds)) * multiple)
+            for ds, multiple in zip(self._datasets, self._multipliers)
         )
diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py
index 63b04d44..791a378a 100644
--- a/ultravox/data/datasets_test.py
+++ b/ultravox/data/datasets_test.py
@@ -15,15 +15,10 @@
 class FakeSizedIterableDataset(datasets.SizedIterableDataset):
     """Fake version of datasets.SizedIterableDataset"""

-    def __init__(self, n, start=0, multiplier=1, estimated_length=1):
+    def __init__(self, n, start=0, estimated_length=1):
         self.data = range(start, start + n)
-        self._multiplier = multiplier
         self._estimated_length = estimated_length

-    @property
-    def multiplier(self) -> float:
-        return self._multiplier
-
     def __iter__(self):
         for sample in self.data:
             yield sample
@@ -98,7 +93,7 @@ def test_interleaved_first_exhausted():
     ds3 = FakeSizedIterableDataset(3)
     s = datasets.InterleaveDataset(
         [ds1, ds2, ds3],
-        stop_strategy=datasets.StopStrategy.FIRST_EXHAUSTED,
+        stop_strategy=dataset_config.StopStrategy.FIRST_EXHAUSTED,
         static=True,
     )
     # static=True disables random sampling of datasets, so the order is deterministic
@@ -113,7 +108,7 @@ def test_interleaved_last_exhausted():
     ds2 = FakeSizedIterableDataset(2, start=10)
     s = datasets.InterleaveDataset(
         [ds1, ds2],
-        stop_strategy=datasets.StopStrategy.LAST_EXHAUSTED,
+        stop_strategy=dataset_config.StopStrategy.LAST_EXHAUSTED,
         static=True,
     )
     # static=True disables random sampling of datasets, so the order is deterministic
@@ -126,7 +121,7 @@ def test_interleaved_never_stop():
     ds2 = FakeSizedIterableDataset(2, start=10)
     s = datasets.InterleaveDataset(
         [ds1, ds2],
-        stop_strategy=datasets.StopStrategy.NEVER_STOP,
+        stop_strategy=dataset_config.StopStrategy.NEVER_STOP,
         static=True,
     )
     # static=True disables random sampling of datasets, so the order is deterministic
@@ -135,11 +130,9 @@
 def test_interleaved_random():
-    ds1 = FakeSizedIterableDataset(4, multiplier=10)
-    ds2 = FakeSizedIterableDataset(2, start=10, multiplier=1)
-    s = datasets.InterleaveDataset(
-        [ds1, ds2],
-    )
+    ds1 = FakeSizedIterableDataset(4)
+    ds2 = FakeSizedIterableDataset(2, start=10)
+    s = datasets.InterleaveDataset([ds1, ds2], multipliers=[10, 1])
     # stop_strategy=last_exhausted will stop interleaving when the last dataset is exhausted (attempted after exhaustion)
     assert list(s) == [
         0,
diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py
index 232c0584..85de6f7b 100644
--- a/ultravox/training/config_base.py
+++ b/ultravox/training/config_base.py
@@ -3,19 +3,17 @@
 import logging
 import os
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import List, Optional

 import simple_parsing
 import torch

 from ultravox.data import dataset_config
-from ultravox.data import datasets
 from ultravox.model import ultravox_config


 @dataclasses.dataclass
 class TrainConfig:
-    data_sets: List[str]
     val_sets: List[str]
     # language model to use
     text_model: str
@@ -29,13 +27,10 @@ class TrainConfig:
-    # we first parse the data_dicts as a list of dictionaries. After parsing,
-    # we convert these dictionaries to DataDictConfig objects using Pydantic
-    # to enforce type constraints and validation, in the __post_init__ method.
-    data_dicts: Optional[List[Dict[str, Any]]] = None
+    # Interleaved training datasets: per-dataset configs with sampling multipliers,
+    # plus the strategy for when to stop interleaving.
+    interleave_datasets: dataset_config.InterleaveDataConfig

     do_train: bool = True
     do_eval: bool = True

-    # In InterleaveDataset, when to stop interleave: choose from last_exhausted (default), first_exhausted, or never_stop
-    stop_strategy: datasets.StopStrategy = datasets.StopStrategy.LAST_EXHAUSTED
     data_dir: Optional[str] = None
     mds: bool = False
     num_samples: Optional[int] = None
@@ -91,16 +87,6 @@ class TrainConfig:
     loss_config: Optional[ultravox_config.LossConfig] = None

     def __post_init__(self):
-        if self.data_dicts:
-            self.data_dicts = [
-                dataset_config.DataDictConfig(**data_dict)
-                for data_dict in self.data_dicts
-            ]
-            # For now, self.data_dicts is a hack to allow for the inclusion of new datasets using the
-            # GenericVoiceDataset class, without changing how existing datasets are specified in
-            # self.data_sets. In the future, all datasets will be updated to use the DataDictConfig class.
- self.data_sets.extend(self.data_dicts) - assert self.data_type in ["bfloat16", "float16", "float32"] if self.device == "cuda" and not torch.cuda.is_available(): self.device = "mps" if torch.backends.mps.is_available() else "cpu" diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index 3fad7142..fe84296a 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -1,9 +1,7 @@ text_model: "meta-llama/Meta-Llama-3-8B-Instruct" audio_model: "facebook/wav2vec2-base-960h" -data_sets: [] # Can't use datasets with "multiplier" because it doesn't have a length attribute. val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"] -stop_strategy: "LAST_EXHAUSTED" train_on_inputs: False shuffle_data: True diff --git a/ultravox/training/configs/release_config.yaml b/ultravox/training/configs/release_config.yaml index 4f8a9b1f..ab754a58 100644 --- a/ultravox/training/configs/release_config.yaml +++ b/ultravox/training/configs/release_config.yaml @@ -16,225 +16,250 @@ val_sets: ["anyinstruct", "soda", "peoplespeech"] batch_size: 24 max_steps: 14400 # x8x24 = 2,764,800 -data_sets: [] # Can't use datasets with "multiplier" because it doesn't have a length attribute. -data_dicts: -# continuation - - path: "fixie-ai/librispeech_asr" - name: "clean" - splits: - - "train.100" # 28_539 samples - - "train.360" # 104_014 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text }}" - multiplier: 1 - total_samples: 132553 - - path: "fixie-ai/librispeech_asr" - name: "other" - splits: - - "train.500" # 148_688 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text }}" - multiplier: 1 - total_samples: 148688 - - path: "fixie-ai/peoples_speech" - name: "clean" - splits: - - "train" # 1_501_271 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text_proc.format_asr_text(text) }}" - multiplier: 1 - total_samples: 1501271 - - path: "fixie-ai/common_voice_17_0" - name: "en" - splits: - - "train" # 1_101_170 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text_proc.format_asr_text(sentence) }}" - multiplier: 1 - total_samples: 1101170 - - path: "fixie-ai/common_voice_17_0" - name: "ar" - splits: - - "train" # 28_369 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 28369 - - path: "fixie-ai/common_voice_17_0" - name: "de" - splits: - - "train" # 589_100 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 589100 - - path: "fixie-ai/common_voice_17_0" - name: "es" - splits: - - "train" # 336_846 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 336846 - - path: "fixie-ai/common_voice_17_0" - name: "fr" 
- splits: - - "train" # 558_054 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 558054 - - path: "fixie-ai/common_voice_17_0" - name: "it" - splits: - - "train" # 169_771 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 169771 - - path: "fixie-ai/common_voice_17_0" - name: "ja" - splits: - - "train" # 10_039 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 10039 - - path: "fixie-ai/common_voice_17_0" - name: "pt" - splits: - - "train" # 21_968 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 21968 - - path: "fixie-ai/common_voice_17_0" - name: "ru" - splits: - - "train" # 26_377 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 26377 -# ASR task - - path: "fixie-ai/librispeech_asr" - name: "clean" - splits: - - "train.100" # 28_539 samples - - "train.360" # 104_014 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text }}" - transcript_template: "{{ text }}" - multiplier: 1 - total_samples: 132553 - - path: "fixie-ai/librispeech_asr" - name: "other" - splits: - - "train.500" # 148_688 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text }}" - transcript_template: "{{ text }}" - multiplier: 1 - total_samples: 148688 - - path: "fixie-ai/peoples_speech" - name: "clean" - splits: - - "train" # 1_501_271 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(text) }}" - transcript_template: "{{ text_proc.format_asr_text(text) }}" - multiplier: 1 - total_samples: 1501271 - - path: "fixie-ai/common_voice_17_0" - name: "en" - splits: - - "train" # 1_101_170 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ text_proc.format_asr_text(sentence) }}" - multiplier: 1 - total_samples: 1101170 - - path: "fixie-ai/common_voice_17_0" - name: "ar" - splits: - - "train" # 28_369 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 28369 - - path: "fixie-ai/common_voice_17_0" - name: "de" - splits: - - "train" # 589_100 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 589100 - - path: "fixie-ai/common_voice_17_0" - name: "es" - splits: - - "train" # 336_846 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 336846 
- - path: "fixie-ai/common_voice_17_0" - name: "fr" - splits: - - "train" # 558_054 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 558054 - - path: "fixie-ai/common_voice_17_0" - name: "it" - splits: - - "train" # 169_771 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 169771 - - path: "fixie-ai/common_voice_17_0" - name: "ja" - splits: - - "train" # 10_039 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 10039 - - path: "fixie-ai/common_voice_17_0" - name: "pt" - splits: - - "train" # 21_968 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 21968 - - path: "fixie-ai/common_voice_17_0" - name: "ru" - splits: - - "train" # 26_377 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 26377 +interleave_datasets: + stop_strategy: "LAST_EXHAUSTED" + datasets_with_multiplier: + # continuation + - dataset: + path: "fixie-ai/librispeech_asr" + name: "clean" + splits: + - "train.100" # 28_539 samples + - "train.360" # 104_014 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text }}" + total_samples: 132553 + multiplier: 1 + - dataset: + path: "fixie-ai/librispeech_asr" + name: "other" + splits: + - "train.500" # 148_688 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text }}" + total_samples: 148688 + multiplier: 1 + - dataset: + path: "fixie-ai/peoples_speech" + name: "clean" + splits: + - "train" # 1_501_271 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text_proc.format_asr_text(text) }}" + total_samples: 1501271 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "en" + splits: + - "train" # 1_101_170 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text_proc.format_asr_text(sentence) }}" + total_samples: 1101170 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "ar" + splits: + - "train" # 28_369 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 28369 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "de" + splits: + - "train" # 589_100 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 589100 + multiplier: 1 + - dataset: + 
path: "fixie-ai/common_voice_17_0" + name: "es" + splits: + - "train" # 336_846 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 336846 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "fr" + splits: + - "train" # 558_054 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 558054 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "it" + splits: + - "train" # 169_771 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 169771 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "ja" + splits: + - "train" # 10_039 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 10039 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "pt" + splits: + - "train" # 21_968 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 21968 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "ru" + splits: + - "train" # 26_377 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 2637 + multiplier: 1 + # ASR task + - dataset: + path: "fixie-ai/librispeech_asr" + name: "clean" + splits: + - "train.100" # 28_539 samples + - "train.360" # 104_014 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text }}" + transcript_template: "{{ text }}" + total_samples: 132553 + multiplier: 1 + - dataset: + path: "fixie-ai/librispeech_asr" + name: "other" + splits: + - "train.500" # 148_688 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text }}" + transcript_template: "{{ text }}" + total_samples: 148688 + multiplier: 1 + - dataset: + path: "fixie-ai/peoples_speech" + name: "clean" + splits: + - "train" # 1_501_271 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(text) }}" + transcript_template: "{{ text_proc.format_asr_text(text) }}" + total_samples: 1501271 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "en" + splits: + - "train" # 1_101_170 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ text_proc.format_asr_text(sentence) }}" + total_samples: 1101170 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "ar" + splits: + - "train" # 28_369 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 28369 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "de" + splits: + - "train" # 589_100 samples + user_template: "{{ 
dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 589100 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "es" + splits: + - "train" # 336_846 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 336846 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "fr" + splits: + - "train" # 558_054 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 558054 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "it" + splits: + - "train" # 169_771 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 169771 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "ja" + splits: + - "train" # 10_039 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 10039 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "pt" + splits: + - "train" # 21_968 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 21968 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "ru" + splits: + - "train" # 26_377 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 26377 + multiplier: 1 diff --git a/ultravox/training/train.py b/ultravox/training/train.py index a37cd248..fac8e530 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -21,6 +21,7 @@ import wandb.sdk from torch.utils import data +from ultravox.data import dataset_config from ultravox.data import datasets from ultravox.model import data_processing from ultravox.model import ultravox_config @@ -41,16 +42,30 @@ def fix_hyphens(arg: str): def prepare_dataset( train_args: config_base.TrainConfig, - dataset_names: List[str], + interleave_dataset: dataset_config.InterleaveDataConfig | List[str], data_args: datasets.VoiceDatasetArgs, processor: ultravox_processing.UltravoxProcessor, train_on_inputs: bool, - stop_strategy: datasets.StopStrategy, num_samples: Optional[int] = None, include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training) enforce_ds_len_epoch: bool = False, ) -> datasets.SizedIterableDataset: - data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names] + if isinstance(interleave_dataset, dataset_config.InterleaveDataConfig): + data_sets = [ + datasets.create_dataset(ds.dataset, data_args) + for ds in interleave_dataset.datasets_with_multiplier + ] + multipliers = [ + ds.multiplier for ds in interleave_dataset.datasets_with_multiplier + ] + stop_strategy = interleave_dataset.stop_strategy + else: + data_sets = [ + datasets.create_dataset(ds, data_args) for ds in interleave_dataset + ] + stop_strategy = 
dataset_config.StopStrategy.LAST_EXHAUSTED + multipliers = [1.0] * len(data_sets) + # If we're using epochs to train, validate the dataset length is appropriate. using_epochs = train_args.max_steps == 0 if using_epochs and enforce_ds_len_epoch: @@ -58,10 +73,10 @@ def prepare_dataset( assert ( len(ds) > 1 ), f"Dataset {ds} has length {len(ds)} which is too short for epoch training" - interleave = datasets.InterleaveDataset( data_sets, stop_strategy=stop_strategy, + multipliers=multipliers, ) ds_with_proc = data_processing.UltravoxDataproc( interleave, @@ -209,9 +224,8 @@ def train(args: config_base.TrainConfig): ) train_dataset = prepare_dataset( train_args=args, - dataset_names=args.data_sets, + interleave_dataset=args.interleave_datasets, train_on_inputs=args.train_on_inputs, - stop_strategy=args.stop_strategy, processor=processor, num_samples=args.num_samples, data_args=datasets.VoiceDatasetArgs( @@ -241,9 +255,8 @@ def train(args: config_base.TrainConfig): val_datasets = { k: prepare_dataset( train_args=args, - dataset_names=val_sets[k], + interleave_dataset=val_sets[k], train_on_inputs=args.train_on_inputs, - stop_strategy=args.stop_strategy, processor=processor, num_samples=args.val_num_samples, data_args=val_ds_args_text if k.startswith("text_") else val_ds_args, @@ -252,7 +265,7 @@ def train(args: config_base.TrainConfig): for k in val_sets } logging.info( - f"Loaded {args.data_sets} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})" + f"Loaded {args.interleave_datasets} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})" ) else: # When using DDP with split_batches=True, the primary process will distribute the batches to the workers
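
Reviewer note (editorial, not part of the patch): the standalone Python sketch
below mirrors the sampling scheme InterleaveDataset implements above. Each
dataset is drawn with probability proportional to len(ds) * multiplier, and an
exhausted dataset is restarted unless the stop strategy says to stop. The toy
interleave helper and its list-based inputs are illustrative only.

    # Minimal sketch of multiplier-weighted interleaving with stop strategies.
    from enum import Enum
    from typing import List, Sequence

    import numpy as np


    class StopStrategy(str, Enum):
        FIRST_EXHAUSTED = "first_exhausted"
        LAST_EXHAUSTED = "last_exhausted"
        NEVER_STOP = "never_stop"


    def interleave(
        datasets: Sequence[List[int]],
        multipliers: List[float],
        stop_strategy: StopStrategy,
        seed: int = 42,
        max_samples: int = 1000,  # safety cap so NEVER_STOP terminates in this demo
    ) -> List[int]:
        rng = np.random.default_rng(seed)
        # Sampling weight of each dataset: its length scaled by its multiplier.
        weights = [len(ds) * m for ds, m in zip(datasets, multipliers)]
        probs = [w / sum(weights) for w in weights]
        iterators = [iter(ds) for ds in datasets]
        exhausted = [False] * len(datasets)
        out: List[int] = []
        while len(out) < max_samples:
            i = int(rng.choice(len(datasets), p=probs))
            try:
                out.append(next(iterators[i]))
            except StopIteration:
                exhausted[i] = True
                if stop_strategy == StopStrategy.FIRST_EXHAUSTED or (
                    stop_strategy == StopStrategy.LAST_EXHAUSTED and all(exhausted)
                ):
                    break
                # Restart the exhausted dataset and keep sampling from it.
                iterators[i] = iter(datasets[i])
                out.append(next(iterators[i]))
        return out


    # With weights [4 * 10, 2 * 1], the first dataset is drawn ~95% of the time,
    # matching test_interleaved_random's multipliers=[10, 1].
    print(interleave([[0, 1, 2, 3], [10, 11]], [10.0, 1.0], StopStrategy.LAST_EXHAUSTED))

Because the weight is length times multiplier, setting every multiplier to 1
samples each dataset in proportion to its size; raising one multiplier scales
up that dataset's share without changing the others' relative frequencies.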
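
Reviewer note (editorial, not part of the patch): one entry of the new
interleave_datasets block maps onto the pydantic models as sketched below. The
values are copied from release_config.yaml; the sketch assumes pydantic's
standard coercion of nested dicts into models and of strings into str-valued
enums, which is why the YAML should spell the strategy as "last_exhausted" (the
enum's value) rather than "LAST_EXHAUSTED" (the enum's name).

    from ultravox.data import dataset_config

    cfg = dataset_config.InterleaveDataConfig(
        stop_strategy="last_exhausted",  # validated against StopStrategy values
        datasets_with_multiplier=[
            {
                "dataset": {
                    "path": "fixie-ai/librispeech_asr",
                    "name": "clean",
                    "splits": ["train.100", "train.360"],
                    "total_samples": 132553,
                },
                "multiplier": 1.0,
            }
        ],
    )
    assert cfg.stop_strategy == dataset_config.StopStrategy.LAST_EXHAUSTED
    assert cfg.datasets_with_multiplier[0].dataset.path == "fixie-ai/librispeech_asr"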