diff --git a/ultravox/data/dataset_config.py b/ultravox/data/dataset_config.py
index ea1f0a60..38c2abd1 100644
--- a/ultravox/data/dataset_config.py
+++ b/ultravox/data/dataset_config.py
@@ -12,7 +12,7 @@ class DataDictConfig(BaseModel):
     splits: List[str] = dataclasses.field(default_factory=list)
     num_samples: Optional[int] = None
     total_samples: int = 1
-    weight: float = 1.0
+    multiplier: float = 1.0
     streaming: bool = True
     user_template: str = "<|audio|>"
     assistant_template: str = "{{text}}"
diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py
index 1ea40c95..105a574e 100644
--- a/ultravox/data/datasets.py
+++ b/ultravox/data/datasets.py
@@ -279,7 +279,7 @@ class SizedIterableDataset(abc.ABC, data.IterableDataset):
     """

     @abc.abstractmethod
-    def __len__(self):
+    def __len__(self) -> int:
         pass

@@ -296,7 +296,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
         self._args = args
         self._session: Optional[requests.Session] = None
         self._rng = np.random.default_rng(self._args.shuffle_seed)
-        self._weight = 1.0 # the default weight for the dataset
+        self._multiplier = 1.0

     def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> None:
         self._dataset = dataset
@@ -304,8 +304,8 @@ def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> Non
         self._estimated_length = estimated_length

     @property
-    def weight(self) -> float:
-        return self._weight
+    def multiplier(self) -> float:
+        return self._multiplier

     def _load_audio_dataset(
         self,
@@ -373,7 +373,7 @@ def __iter__(self):
                     f"Mismatch between estimated length ({self._estimated_length}) and actual length ({actual_length}) for dataset of type {type(self._dataset)}. Make sure to update."
                 )

-    def __len__(self):
+    def __len__(self) -> int:
         return self._estimated_length

     @abc.abstractmethod
@@ -493,7 +493,7 @@ def __init__(self, estimated_length: int = 1) -> None:
     def __iter__(self):
         return iter([])

-    def __len__(self):
+    def __len__(self) -> int:
         return self._estimated_length
@@ -1049,16 +1049,17 @@ def __init__(
         if self._args.shuffle:
             dataset = dataset.shuffle(seed=self._args.shuffle_seed)

-        if config.num_samples:
-            dataset = Range(dataset, config.num_samples, config.total_samples)
-
-        self._weight = config.weight
+        self._multiplier = config.multiplier

         self.user_template = config.user_template
         self.assistant_template = config.assistant_template
         self.transcript_template = config.transcript_template

-        super()._init_dataset(dataset, config.total_samples)
+        if config.num_samples:
+            dataset = Range(dataset, config.num_samples, config.total_samples)
+            super()._init_dataset(dataset, len(dataset))
+        else:
+            super()._init_dataset(dataset, config.total_samples)

     def _get_sample(self, row) -> VoiceSample:
         try:
@@ -1121,7 +1122,7 @@ class StopStrategy(str, Enum):


 class InterleaveDataset(SizedIterableDataset):
-    """Interleaves multiple IterableDataset objects based on normalized weights."""
+    """Interleaves multiple IterableDataset objects based on multiplier."""

     def __init__(
         self,
@@ -1142,10 +1143,12 @@ def __init__(
         self._static = static
         self._stop_strategy = stop_strategy

+        relative_frequencies = [
+            int(getattr(ds, "multiplier", 1.0) * float(len(ds))) for ds in datasets
+        ]
-        weights = [getattr(ds, "weight", 1) for ds in datasets]
-        total_weight = sum(weights)
-        self._normalized_probs = [w / total_weight for w in weights]
+        total_frequency = sum(relative_frequencies)
+        self._normalized_probs = [f / total_frequency for f in relative_frequencies]

     def __iter__(self):
         # If no datasets are provided, return an empty iterator
@@ -1180,9 +1183,12 @@ def __iter__(self):
                 iters[iter_index] = iter(self._datasets[iter_index])
             yield next(iters[iter_index])

-    def __len__(self):
+    def __len__(self) -> int:
         # TODO: Implement the length method for different stop strategies
-        return sum(len(ds) for ds in self._datasets)
+        return sum(
+            int(getattr(ds, "multiplier", 1.0) * float(len(ds)))
+            for ds in self._datasets
+        )


 class Dataproc(SizedIterableDataset):
@@ -1198,7 +1204,7 @@ def _process(self, sample: VoiceSample) -> Dict[str, Any]:
     def __iter__(self):
         return (self._process(sample) for sample in self._dataset)

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._dataset)
@@ -1234,7 +1240,7 @@ def __iter__(self):
                 break
             yield sample

-    def __len__(self):
+    def __len__(self) -> int:
         return (
             self._num_samples
             if self._num_samples is not None
diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py
index faacd3fa..63b04d44 100644
--- a/ultravox/data/datasets_test.py
+++ b/ultravox/data/datasets_test.py
@@ -15,14 +15,14 @@ class FakeSizedIterableDataset(datasets.SizedIterableDataset):
     """Fake version of datasets.SizedIterableDataset"""

-    def __init__(self, n, start=0, weight=1, estimated_length=0):
+    def __init__(self, n, start=0, multiplier=1, estimated_length=1):
         self.data = range(start, start + n)
-        self._weight = weight
+        self._multiplier = multiplier
         self._estimated_length = estimated_length

     @property
-    def weight(self) -> float:
-        return self._weight
+    def multiplier(self) -> float:
+        return self._multiplier

     def __iter__(self):
         for sample in self.data:
@@ -135,8 +135,8 @@ def test_interleaved_never_stop():


 def test_interleaved_random():
-    ds1 = FakeSizedIterableDataset(4, weight=10)
-    ds2 = FakeSizedIterableDataset(2, start=10, weight=1)
+    ds1 = FakeSizedIterableDataset(4, multiplier=10)
+    ds2 = FakeSizedIterableDataset(2, start=10, multiplier=1)
     s = datasets.InterleaveDataset(
         [ds1, ds2],
     )
diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml
index d3764d29..3fad7142 100644
--- a/ultravox/training/configs/meta_config.yaml
+++ b/ultravox/training/configs/meta_config.yaml
@@ -1,7 +1,7 @@
 text_model: "meta-llama/Meta-Llama-3-8B-Instruct"
 audio_model: "facebook/wav2vec2-base-960h"

-data_sets: ["gigaspeech"]
+data_sets: [] # Can't use datasets with "multiplier" because it doesn't have a length attribute.
 val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"]

 stop_strategy: "LAST_EXHAUSTED"
diff --git a/ultravox/training/configs/release_config.yaml b/ultravox/training/configs/release_config.yaml
index 36b7a5f6..4f8a9b1f 100644
--- a/ultravox/training/configs/release_config.yaml
+++ b/ultravox/training/configs/release_config.yaml
@@ -16,7 +16,7 @@ val_sets: ["anyinstruct", "soda", "peoplespeech"]
 batch_size: 24
 max_steps: 14400 # x8x24 = 2,764,800

-data_sets: ["anyinstruct"]
+data_sets: [] # Can't use datasets with "multiplier" because it doesn't have a length attribute.
 data_dicts:
   # continuation
   - path: "fixie-ai/librispeech_asr"
     name: "clean"
     splits:
@@ -27,7 +27,8 @@ data_dicts:
     user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
     assistant_template: "{{ continuation }}"
     transcript_template: "{{ text }}"
-    weight: 1
+    multiplier: 1
+    total_samples: 132553
   - path: "fixie-ai/librispeech_asr"
     name: "other"
     splits:
@@ -35,7 +36,8 @@ data_dicts:
     user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
     assistant_template: "{{ continuation }}"
     transcript_template: "{{ text }}"
-    weight: 1
+    multiplier: 1
+    total_samples: 148688
   - path: "fixie-ai/peoples_speech"
     name: "clean"
     splits:
@@ -43,7 +45,8 @@ data_dicts:
     user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
     assistant_template: "{{ continuation }}"
     transcript_template: "{{ text_proc.format_asr_text(text) }}"
-    weight: 8
+    multiplier: 1
+    total_samples: 1501271
   - path: "fixie-ai/common_voice_17_0"
     name: "en"
     splits:
@@ -51,7 +54,8 @@ data_dicts:
     user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
     assistant_template: "{{ continuation }}"
     transcript_template: "{{ text_proc.format_asr_text(sentence) }}"
-    weight: 8
+    multiplier: 1
+    total_samples: 1101170
   - path: "fixie-ai/common_voice_17_0"
     name: "ar"
     splits:
@@ -59,7 +63,8 @@ data_dicts:
     user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
     assistant_template: "{{ continuation }}"
     transcript_template: "{{ sentence }}"
-    weight: 0.2
+    multiplier: 1
+    total_samples: 28369
   - path: "fixie-ai/common_voice_17_0"
     name: "de"
     splits:
@@ -67,7 +72,8 @@ data_dicts:
     user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
     assistant_template: "{{ continuation }}"
     transcript_template: "{{ sentence }}"
-    weight: 4
+    multiplier: 1
+    total_samples: 589100
   - path: "fixie-ai/common_voice_17_0"
     name: "es"
     splits:
@@ -75,7 +81,8 @@ data_dicts:
     user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
     assistant_template: "{{ continuation }}"
     transcript_template: "{{ sentence }}"
-    weight: 3
+    multiplier: 1
+    total_samples: 336846
   - path: "fixie-ai/common_voice_17_0"
     name: "fr"
     splits:
@@ -83,7 +90,8 @@ data_dicts:
     user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
     assistant_template: "{{ continuation }}"
     transcript_template: "{{ sentence }}"
-    weight: 4
+    multiplier: 1
+    total_samples: 558054
   - path: "fixie-ai/common_voice_17_0"
     name: "it"
     splits:
@@ -91,7 +99,8 @@ data_dicts:
     user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
     assistant_template: "{{ continuation }}"
     transcript_template: "{{ sentence }}"
-    weight: 1.2
+    multiplier: 1
+    total_samples: 169771
   - path: "fixie-ai/common_voice_17_0"
     name: "ja"
     splits:
@@ -99,7 +108,8 @@ data_dicts:
     user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
     assistant_template: "{{ continuation }}"
     transcript_template: "{{ sentence }}"
-    weight: 0.1
+    multiplier: 1
+    total_samples: 10039
   - path: "fixie-ai/common_voice_17_0"
     name: "pt"
     splits:
@@ -107,7 +117,8 @@ data_dicts:
     user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
     assistant_template: "{{ continuation }}"
     transcript_template: "{{ sentence }}"
-    weight: 0.2
+    multiplier: 1
+    total_samples: 21968
   - path: "fixie-ai/common_voice_17_0"
     name: "ru"
     splits:
@@ -115,7 +126,8 @@ data_dicts:
     user_template: "Continue the following text using less than 50 words:\n\n<|audio|>"
assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" - weight: 0.2 + multiplier: 1 + total_samples: 26377 # ASR task - path: "fixie-ai/librispeech_asr" name: "clean" @@ -125,7 +137,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text }}" transcript_template: "{{ text }}" - weight: 0.1 + multiplier: 1 + total_samples: 132553 - path: "fixie-ai/librispeech_asr" name: "other" splits: @@ -133,7 +146,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text }}" transcript_template: "{{ text }}" - weight: 0.1 + multiplier: 1 + total_samples: 148688 - path: "fixie-ai/peoples_speech" name: "clean" splits: @@ -141,7 +155,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(text) }}" transcript_template: "{{ text_proc.format_asr_text(text) }}" - weight: 0.8 + multiplier: 1 + total_samples: 1501271 - path: "fixie-ai/common_voice_17_0" name: "en" splits: @@ -149,7 +164,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ text_proc.format_asr_text(sentence) }}" - weight: 0.8 + multiplier: 1 + total_samples: 1101170 - path: "fixie-ai/common_voice_17_0" name: "ar" splits: @@ -157,7 +173,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.02 + multiplier: 1 + total_samples: 28369 - path: "fixie-ai/common_voice_17_0" name: "de" splits: @@ -165,7 +182,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.4 + multiplier: 1 + total_samples: 589100 - path: "fixie-ai/common_voice_17_0" name: "es" splits: @@ -173,7 +191,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.3 + multiplier: 1 + total_samples: 336846 - path: "fixie-ai/common_voice_17_0" name: "fr" splits: @@ -181,7 +200,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.4 + multiplier: 1 + total_samples: 558054 - path: "fixie-ai/common_voice_17_0" name: "it" splits: @@ -189,7 +209,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.12 + multiplier: 1 + total_samples: 169771 - path: "fixie-ai/common_voice_17_0" name: "ja" splits: @@ -197,7 +218,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.01 + multiplier: 1 + total_samples: 10039 - path: "fixie-ai/common_voice_17_0" name: "pt" splits: @@ -205,7 +227,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.02 + multiplier: 1 + total_samples: 21968 - path: "fixie-ai/common_voice_17_0" name: "ru" splits: @@ -213,4 +236,5 @@ data_dicts: user_template: "{{ 
dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.02 + multiplier: 1 + total_samples: 26377 diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 5dde3f62..3657522b 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -44,16 +44,21 @@ def prepare_dataset( stop_strategy: datasets.StopStrategy, num_samples: Optional[int] = None, include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training) + enforce_ds_len_epoch: bool = False, ) -> datasets.SizedIterableDataset: data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names] # If we're using epochs to train, validate the dataset length is appropriate. - if train_args.max_steps == 0: + using_epochs = train_args.max_steps == 0 + if using_epochs and enforce_ds_len_epoch: for ds in data_sets: assert ( len(ds) > 1 ), f"Dataset {ds} has length {len(ds)} which is too short for epoch training" - interleave = datasets.InterleaveDataset(data_sets, stop_strategy=stop_strategy) + interleave = datasets.InterleaveDataset( + data_sets, + stop_strategy=stop_strategy, + ) ds_with_proc = data_processing.UltravoxDataproc( interleave, processor=processor, @@ -225,6 +230,7 @@ def train(args: config_base.TrainConfig): mds_batch_size=args.batch_size, ), include_alt_fields=model.loss_config.requires_alt_fields, + enforce_ds_len_epoch=True, ) if is_master: val_ds_args = datasets.VoiceDatasetArgs(