From 4b99f5855cc5c2d6c0e60723a8f5d52a01a321d9 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Mon, 9 Sep 2024 17:14:30 -0700 Subject: [PATCH 01/11] First --- ultravox/data/dataset_config.py | 4 ++++ ultravox/data/datasets.py | 24 +++++++++++++++--------- ultravox/data/datasets_test.py | 2 +- ultravox/training/train.py | 7 +++++-- 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/ultravox/data/dataset_config.py b/ultravox/data/dataset_config.py index ea1f0a60..b223ef78 100644 --- a/ultravox/data/dataset_config.py +++ b/ultravox/data/dataset_config.py @@ -12,6 +12,10 @@ class DataDictConfig(BaseModel): splits: List[str] = dataclasses.field(default_factory=list) num_samples: Optional[int] = None total_samples: int = 1 + # epochs mode: + # Weight is the number of copies of the dataset + # max steps mode: + # Weight of the dataset is used to calculate the proportion of the total samples that comes from this dataset weight: float = 1.0 streaming: bool = True user_template: str = "<|audio|>" diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 1ea40c95..4c8c10bd 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -279,7 +279,7 @@ class SizedIterableDataset(abc.ABC, data.IterableDataset): """ @abc.abstractmethod - def __len__(self): + def __len__(self) -> int: pass @@ -373,8 +373,8 @@ def __iter__(self): f"Mismatch between estimated length ({self._estimated_length}) and actual length ({actual_length}) for dataset of type {type(self._dataset)}. Make sure to update." ) - def __len__(self): - return self._estimated_length + def __len__(self) -> int: + return int(self._estimated_length * self._weight) @abc.abstractmethod def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: @@ -493,7 +493,7 @@ def __init__(self, estimated_length: int = 1) -> None: def __iter__(self): return iter([]) - def __len__(self): + def __len__(self) -> int: return self._estimated_length @@ -1129,6 +1129,7 @@ def __init__( stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED, seed: Optional[int] = 42, static: bool = False, + using_epochs: bool = False, ) -> None: """ Args: @@ -1142,8 +1143,12 @@ def __init__( self._static = static self._stop_strategy = stop_strategy + self._using_epochs = using_epochs + if not self._using_epochs: + weights = [int(getattr(ds, "weight", 1) * len(ds)) for ds in datasets] + else: + weights = [getattr(ds, "weight", 1) for ds in datasets] - weights = [getattr(ds, "weight", 1) for ds in datasets] total_weight = sum(weights) self._normalized_probs = [w / total_weight for w in weights] @@ -1180,9 +1185,10 @@ def __iter__(self): iters[iter_index] = iter(self._datasets[iter_index]) yield next(iters[iter_index]) - def __len__(self): + # Only used when using_epochs is True + def __len__(self) -> int: # TODO: Implement the length method for different stop strategies - return sum(len(ds) for ds in self._datasets) + return sum(int(getattr(ds, "weight", 1) * len(ds)) for ds in datasets) class Dataproc(SizedIterableDataset): @@ -1198,7 +1204,7 @@ def _process(self, sample: VoiceSample) -> Dict[str, Any]: def __iter__(self): return (self._process(sample) for sample in self._dataset) - def __len__(self): + def __len__(self) -> int: return len(self._dataset) @@ -1234,7 +1240,7 @@ def __iter__(self): break yield sample - def __len__(self): + def __len__(self) -> int: return ( self._num_samples if self._num_samples is not None diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index faacd3fa..78263479
100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -15,7 +15,7 @@ class FakeSizedIterableDataset(datasets.SizedIterableDataset): """Fake version of datasets.SizedIterableDataset""" - def __init__(self, n, start=0, weight=1, estimated_length=0): + def __init__(self, n, start=0, weight=1, estimated_length=1): self.data = range(start, start + n) self._weight = weight self._estimated_length = estimated_length diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 05cea992..b21feb8b 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -51,13 +51,16 @@ def prepare_dataset( ) -> datasets.SizedIterableDataset: data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names] # If we're using epochs to train, validate the dataset length is appropriate. - if train_args.max_steps == 0: + using_epochs = train_args.max_steps == 0 + if using_epochs: for ds in data_sets: assert ( len(ds) > 1 ), f"Dataset {ds} has length {len(ds)} which is too short for epoch training" - interleave = datasets.InterleaveDataset(data_sets, stop_strategy=stop_strategy) + interleave = datasets.InterleaveDataset( + data_sets, stop_strategy=stop_strategy, using_epochs=using_epochs + ) ds_with_proc = data_processing.UltravoxDataproc( interleave, processor=processor, From 7e83d325edb800c6818a0825c091aec5b180076b Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Tue, 10 Sep 2024 14:44:18 -0700 Subject: [PATCH 02/11] Fix interleave dataset len --- ultravox/data/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 4c8c10bd..d6d042b5 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -1188,7 +1188,7 @@ def __iter__(self): # Only used when using_epochs is True def __len__(self) -> int: # TODO: Implement the length method for different stop strategies - return sum(int(getattr(ds, "weight", 1) * len(ds)) for ds in datasets) + return sum(int(getattr(ds, "weight", 1) * len(ds)) for ds in self._datasets) class Dataproc(SizedIterableDataset): From b83b8c979eb9bea5161a44d9c6dfb53362533b33 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Tue, 10 Sep 2024 15:45:00 -0700 Subject: [PATCH 03/11] Fixed length in voice datasets --- ultravox/data/datasets.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index d6d042b5..1d02159c 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -374,7 +374,7 @@ def __iter__(self): ) def __len__(self) -> int: - return int(self._estimated_length * self._weight) + return self._estimated_length @abc.abstractmethod def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: @@ -1049,16 +1049,17 @@ def __init__( if self._args.shuffle: dataset = dataset.shuffle(seed=self._args.shuffle_seed) - if config.num_samples: - dataset = Range(dataset, config.num_samples, config.total_samples) - self._weight = config.weight self.user_template = config.user_template self.assistant_template = config.assistant_template self.transcript_template = config.transcript_template - super()._init_dataset(dataset, config.total_samples) + if config.num_samples: + dataset = Range(dataset, config.num_samples, config.total_samples) + super()._init_dataset(dataset, len(dataset)) + else: + super()._init_dataset(dataset, config.total_samples) def _get_sample(self, row) -> VoiceSample: try: From 4c1ffe6a6ef8ddea5d58ac011c350844bb91243b Mon 
Sep 17 00:00:00 2001 From: Patrick Li Date: Tue, 10 Sep 2024 16:27:19 -0700 Subject: [PATCH 04/11] Don't assert length for val sets --- ultravox/data/datasets.py | 1 - ultravox/training/train.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 1d02159c..dcad0aca 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -1186,7 +1186,6 @@ def __iter__(self): iters[iter_index] = iter(self._datasets[iter_index]) yield next(iters[iter_index]) - # Only used when using_epochs is True def __len__(self) -> int: # TODO: Implement the length method for different stop strategies return sum(int(getattr(ds, "weight", 1) * len(ds)) for ds in self._datasets) diff --git a/ultravox/training/train.py b/ultravox/training/train.py index b21feb8b..10fca142 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -48,11 +48,12 @@ def prepare_dataset( stop_strategy: datasets.StopStrategy, num_samples: Optional[int] = None, include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training) + is_val_set: bool = False, ) -> datasets.SizedIterableDataset: data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names] # If we're using epochs to train, validate the dataset length is appropriate. using_epochs = train_args.max_steps == 0 - if using_epochs: + if using_epochs and not is_val_set: for ds in data_sets: assert ( len(ds) > 1 ), f"Dataset {ds} has length {len(ds)} which is too short for epoch training" - interleave = datasets.InterleaveDataset(data_sets, stop_strategy=stop_strategy) + interleave = datasets.InterleaveDataset( + data_sets, stop_strategy=stop_strategy, using_epochs=using_epochs + ) ds_with_proc = data_processing.UltravoxDataproc( interleave, processor=processor, @@ -245,6 +246,7 @@ def train(args: config_base.TrainConfig): num_samples=args.val_num_samples, data_args=val_ds_args_text if k.startswith("text_") else val_ds_args, include_alt_fields=model.loss_config.requires_alt_fields, + is_val_set=True, ) for k in val_sets } From 5f31d1c5a148df5a1b829119005340651426b181 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Tue, 10 Sep 2024 16:50:04 -0700 Subject: [PATCH 05/11] Remove gigaspeech from meta_config --- ultravox/training/configs/meta_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index d3764d29..bb5262bb 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -1,7 +1,7 @@ text_model: "meta-llama/Meta-Llama-3-8B-Instruct" audio_model: "facebook/wav2vec2-base-960h" -data_sets: ["gigaspeech"] +data_sets: [] val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"] stop_strategy: "LAST_EXHAUSTED" From 8697dcc0220f57e802da1618662232eb15448baa Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Wed, 11 Sep 2024 15:28:47 -0700 Subject: [PATCH 06/11] Addressing comments --- ultravox/data/dataset_config.py | 6 ++---- ultravox/data/datasets.py | 18 +++++++++++++----- ultravox/training/configs/meta_config.yaml | 2 +- ultravox/training/train.py | 10 ++++++---- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/ultravox/data/dataset_config.py b/ultravox/data/dataset_config.py index b223ef78..d1395cd9 100644 --- a/ultravox/data/dataset_config.py +++ b/ultravox/data/dataset_config.py @@ -12,11 +12,9 @@ class DataDictConfig(BaseModel): splits: List[str] = dataclasses.field(default_factory=list) num_samples: Optional[int] = None total_samples: int = 1 - # epochs mode: - # Weight is the number of copies of the dataset - # max steps mode: - # Weight of the dataset is used to calculate the proportion of the total samples that comes from this dataset weight: float = 1.0 + # This is used over weight when
epoch mode is set + dataset_multiplier: float = 1.0 streaming: bool = True user_template: str = "<|audio|>" assistant_template: str = "{{text}}" diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index dcad0aca..83f8faa5 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -1130,7 +1130,7 @@ def __init__( stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED, seed: Optional[int] = 42, static: bool = False, - using_epochs: bool = False, + use_dataset_multiplier: bool = False, ) -> None: """ Args: @@ -1144,9 +1144,11 @@ def __init__( self._static = static self._stop_strategy = stop_strategy - self._using_epochs = using_epochs - if not self._using_epochs: - weights = [int(getattr(ds, "weight", 1) * len(ds)) for ds in datasets] + self._use_dataset_multiplier = use_dataset_multiplier + if self._use_dataset_multiplier: + weights = [ + int(getattr(ds, "dataset_multiplier", 1) * len(ds)) for ds in datasets + ] else: weights = [getattr(ds, "weight", 1) for ds in datasets] @@ -1188,7 +1190,13 @@ def __iter__(self): def __len__(self) -> int: # TODO: Implement the length method for different stop strategies - return sum(int(getattr(ds, "weight", 1) * len(ds)) for ds in self._datasets) + if self._use_dataset_multiplier: + return sum( + int(getattr(ds, "dataset_multiplier", 1) * len(ds)) + for ds in self._datasets + ) + else: + return sum(len(ds) for ds in self._datasets) class Dataproc(SizedIterableDataset): diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index bb5262bb..d3764d29 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -1,7 +1,7 @@ text_model: "meta-llama/Meta-Llama-3-8B-Instruct" audio_model: "facebook/wav2vec2-base-960h" -data_sets: [] +data_sets: ["gigaspeech"] val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"] stop_strategy: "LAST_EXHAUSTED" diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 10fca142..aec78be2 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -48,19 +48,21 @@ def prepare_dataset( stop_strategy: datasets.StopStrategy, num_samples: Optional[int] = None, include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training) - is_val_set: bool = False, + enforce_ds_len_epoch: bool = False, ) -> datasets.SizedIterableDataset: data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names] # If we're using epochs to train, validate the dataset length is appropriate. 
using_epochs = train_args.max_steps == 0 - if using_epochs and not is_val_set: + if using_epochs and enforce_ds_len_epoch: for ds in data_sets: assert ( len(ds) > 1 ), f"Dataset {ds} has length {len(ds)} which is too short for epoch training" interleave = datasets.InterleaveDataset( - data_sets, stop_strategy=stop_strategy, using_epochs=using_epochs + data_sets, + stop_strategy=stop_strategy, + use_dataset_multiplier=using_epochs, ) ds_with_proc = data_processing.UltravoxDataproc( interleave, @@ -223,6 +225,7 @@ def train(args: config_base.TrainConfig): mds_batch_size=args.batch_size, ), include_alt_fields=model.loss_config.requires_alt_fields, + enforce_ds_len_epoch=True, ) if is_master: val_ds_args = datasets.VoiceDatasetArgs( @@ -246,7 +249,6 @@ def train(args: config_base.TrainConfig): num_samples=args.val_num_samples, data_args=val_ds_args_text if k.startswith("text_") else val_ds_args, include_alt_fields=model.loss_config.requires_alt_fields, - is_val_set=True, ) for k in val_sets } From 9deaa7f39a1c3702db1cbc0bdfdf06ca9a4a10bd Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Thu, 12 Sep 2024 09:43:52 -0700 Subject: [PATCH 07/11] Replace weight with multiplier --- ultravox/data/dataset_config.py | 4 +- ultravox/data/datasets.py | 36 ++++----- ultravox/data/datasets_test.py | 12 +-- ultravox/training/configs/meta_config.yaml | 2 +- ultravox/training/configs/release_config.yaml | 74 ++++++++++++------- ultravox/training/train.py | 1 - 6 files changed, 71 insertions(+), 58 deletions(-) diff --git a/ultravox/data/dataset_config.py b/ultravox/data/dataset_config.py index d1395cd9..38c2abd1 100644 --- a/ultravox/data/dataset_config.py +++ b/ultravox/data/dataset_config.py @@ -12,9 +12,7 @@ class DataDictConfig(BaseModel): splits: List[str] = dataclasses.field(default_factory=list) num_samples: Optional[int] = None total_samples: int = 1 - weight: float = 1.0 - # This is used over weight when epoch mode is set - dataset_multiplier: float = 1.0 + multiplier: float = 1.0 streaming: bool = True user_template: str = "<|audio|>" assistant_template: str = "{{text}}" diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 83f8faa5..105a574e 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -296,7 +296,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None: self._args = args self._session: Optional[requests.Session] = None self._rng = np.random.default_rng(self._args.shuffle_seed) - self._weight = 1.0 # the default weight for the dataset + self._multiplier = 1.0 def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> None: self._dataset = dataset @@ -304,8 +304,8 @@ def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> Non self._estimated_length = estimated_length @property - def weight(self) -> float: - return self._weight + def multiplier(self) -> float: + return self._multiplier def _load_audio_dataset( self, @@ -1049,7 +1049,7 @@ def __init__( if self._args.shuffle: dataset = dataset.shuffle(seed=self._args.shuffle_seed) - self._weight = config.weight + self._multiplier = config.multiplier self.user_template = config.user_template self.assistant_template = config.assistant_template @@ -1122,7 +1122,7 @@ class StopStrategy(str, Enum): class InterleaveDataset(SizedIterableDataset): - """Interleaves multiple IterableDataset objects based on normalized weights.""" + """Interleaves multiple IterableDataset objects based on multiplier.""" def __init__( self, @@ -1130,7 +1130,6 @@ def __init__( stop_strategy: 
StopStrategy = StopStrategy.LAST_EXHAUSTED, seed: Optional[int] = 42, static: bool = False, - use_dataset_multiplier: bool = False, ) -> None: """ Args: @@ -1144,16 +1143,12 @@ def __init__( self._static = static self._stop_strategy = stop_strategy - self._use_dataset_multiplier = use_dataset_multiplier - if self._use_dataset_multiplier: - weights = [ - int(getattr(ds, "dataset_multiplier", 1) * len(ds)) for ds in datasets - ] - else: - weights = [getattr(ds, "weight", 1) for ds in datasets] + relative_frequencies = [ + int(getattr(ds, "multiplier", 1.0) * float(len(ds))) for ds in datasets + ] - total_weight = sum(weights) - self._normalized_probs = [w / total_weight for w in weights] + total_frequency = sum(relative_frequencies) + self._normalized_probs = [f / total_frequency for f in relative_frequencies] def __iter__(self): # If no datasets are provided, return an empty iterator @@ -1190,13 +1185,10 @@ def __iter__(self): def __len__(self) -> int: # TODO: Implement the length method for different stop strategies - if self._use_dataset_multiplier: - return sum( - int(getattr(ds, "dataset_multiplier", 1) * len(ds)) - for ds in self._datasets - ) - else: - return sum(len(ds) for ds in self._datasets) + return sum( + int(getattr(ds, "multiplier", 1.0) * float(len(ds))) + for ds in self._datasets + ) class Dataproc(SizedIterableDataset): diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index 78263479..63b04d44 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -15,14 +15,14 @@ class FakeSizedIterableDataset(datasets.SizedIterableDataset): """Fake version of datasets.SizedIterableDataset""" - def __init__(self, n, start=0, weight=1, estimated_length=1): + def __init__(self, n, start=0, multiplier=1, estimated_length=1): self.data = range(start, start + n) - self._weight = weight + self._multiplier = multiplier self._estimated_length = estimated_length @property - def weight(self) -> float: - return self._weight + def multiplier(self) -> float: + return self._multiplier def __iter__(self): for sample in self.data: @@ -135,8 +135,8 @@ def test_interleaved_never_stop(): def test_interleaved_random(): - ds1 = FakeSizedIterableDataset(4, weight=10) - ds2 = FakeSizedIterableDataset(2, start=10, weight=1) + ds1 = FakeSizedIterableDataset(4, multiplier=10) + ds2 = FakeSizedIterableDataset(2, start=10, multiplier=1) s = datasets.InterleaveDataset( [ds1, ds2], ) diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index d3764d29..3fad7142 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -1,7 +1,7 @@ text_model: "meta-llama/Meta-Llama-3-8B-Instruct" audio_model: "facebook/wav2vec2-base-960h" -data_sets: ["gigaspeech"] +data_sets: [] # Can't use datasets with "multiplier" because it doesn't have a length attribute. val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"] stop_strategy: "LAST_EXHAUSTED" diff --git a/ultravox/training/configs/release_config.yaml b/ultravox/training/configs/release_config.yaml index 36b7a5f6..4f8a9b1f 100644 --- a/ultravox/training/configs/release_config.yaml +++ b/ultravox/training/configs/release_config.yaml @@ -16,7 +16,7 @@ val_sets: ["anyinstruct", "soda", "peoplespeech"] batch_size: 24 max_steps: 14400 # x8x24 = 2,764,800 -data_sets: ["anyinstruct"] +data_sets: [] # Can't use datasets with "multiplier" because it doesn't have a length attribute. 
data_dicts: # continuation - path: "fixie-ai/librispeech_asr" @@ -27,7 +27,8 @@ data_dicts: user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ text }}" - weight: 1 + multiplier: 1 + total_samples: 132553 - path: "fixie-ai/librispeech_asr" name: "other" splits: @@ -35,7 +36,8 @@ data_dicts: user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ text }}" - weight: 1 + multiplier: 1 + total_samples: 148688 - path: "fixie-ai/peoples_speech" name: "clean" splits: @@ -43,7 +45,8 @@ data_dicts: user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ text_proc.format_asr_text(text) }}" - weight: 8 + multiplier: 1 + total_samples: 1501271 - path: "fixie-ai/common_voice_17_0" name: "en" splits: @@ -51,7 +54,8 @@ data_dicts: user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ text_proc.format_asr_text(sentence) }}" - weight: 8 + multiplier: 1 + total_samples: 1101170 - path: "fixie-ai/common_voice_17_0" name: "ar" splits: @@ -59,7 +63,8 @@ data_dicts: user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" - weight: 0.2 + multiplier: 1 + total_samples: 28369 - path: "fixie-ai/common_voice_17_0" name: "de" splits: @@ -67,7 +72,8 @@ data_dicts: user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" - weight: 4 + multiplier: 1 + total_samples: 589100 - path: "fixie-ai/common_voice_17_0" name: "es" splits: @@ -75,7 +81,8 @@ data_dicts: user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" - weight: 3 + multiplier: 1 + total_samples: 336846 - path: "fixie-ai/common_voice_17_0" name: "fr" splits: @@ -83,7 +90,8 @@ data_dicts: user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" - weight: 4 + multiplier: 1 + total_samples: 558054 - path: "fixie-ai/common_voice_17_0" name: "it" splits: @@ -91,7 +99,8 @@ data_dicts: user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" - weight: 1.2 + multiplier: 1 + total_samples: 169771 - path: "fixie-ai/common_voice_17_0" name: "ja" splits: @@ -99,7 +108,8 @@ data_dicts: user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" - weight: 0.1 + multiplier: 1 + total_samples: 10039 - path: "fixie-ai/common_voice_17_0" name: "pt" splits: @@ -107,7 +117,8 @@ data_dicts: user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" - weight: 0.2 + multiplier: 1 + total_samples: 21968 - path: "fixie-ai/common_voice_17_0" name: "ru" splits: @@ -115,7 +126,8 @@ data_dicts: user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" 
assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" - weight: 0.2 + multiplier: 1 + total_samples: 26377 # ASR task - path: "fixie-ai/librispeech_asr" name: "clean" @@ -125,7 +137,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text }}" transcript_template: "{{ text }}" - weight: 0.1 + multiplier: 1 + total_samples: 132553 - path: "fixie-ai/librispeech_asr" name: "other" splits: @@ -133,7 +146,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text }}" transcript_template: "{{ text }}" - weight: 0.1 + multiplier: 1 + total_samples: 148688 - path: "fixie-ai/peoples_speech" name: "clean" splits: @@ -141,7 +155,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(text) }}" transcript_template: "{{ text_proc.format_asr_text(text) }}" - weight: 0.8 + multiplier: 1 + total_samples: 1501271 - path: "fixie-ai/common_voice_17_0" name: "en" splits: @@ -149,7 +164,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ text_proc.format_asr_text(sentence) }}" - weight: 0.8 + multiplier: 1 + total_samples: 1101170 - path: "fixie-ai/common_voice_17_0" name: "ar" splits: @@ -157,7 +173,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.02 + multiplier: 1 + total_samples: 28369 - path: "fixie-ai/common_voice_17_0" name: "de" splits: @@ -165,7 +182,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.4 + multiplier: 1 + total_samples: 589100 - path: "fixie-ai/common_voice_17_0" name: "es" splits: @@ -173,7 +191,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.3 + multiplier: 1 + total_samples: 336846 - path: "fixie-ai/common_voice_17_0" name: "fr" splits: @@ -181,7 +200,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.4 + multiplier: 1 + total_samples: 558054 - path: "fixie-ai/common_voice_17_0" name: "it" splits: @@ -189,7 +209,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.12 + multiplier: 1 + total_samples: 169771 - path: "fixie-ai/common_voice_17_0" name: "ja" splits: @@ -197,7 +218,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.01 + multiplier: 1 + total_samples: 10039 - path: "fixie-ai/common_voice_17_0" name: "pt" splits: @@ -205,7 +227,8 @@ data_dicts: user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.02 + multiplier: 1 + total_samples: 21968 - path: "fixie-ai/common_voice_17_0" name: "ru" splits: @@ -213,4 +236,5 @@ data_dicts: user_template: "{{ 
dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" - weight: 0.02 + multiplier: 1 + total_samples: 26377 diff --git a/ultravox/training/train.py b/ultravox/training/train.py index aec78be2..a37cd248 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -62,7 +62,6 @@ def prepare_dataset( interleave = datasets.InterleaveDataset( data_sets, stop_strategy=stop_strategy, - use_dataset_multiplier=using_epochs, ) ds_with_proc = data_processing.UltravoxDataproc( interleave, From 2a1bcc56e358347bf3cff4d1e3548d3b1f80c269 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Fri, 20 Sep 2024 01:21:18 -0700 Subject: [PATCH 08/11] Add interleave config --- ultravox/data/dataset_config.py | 25 +- ultravox/data/datasets.py | 44 +- ultravox/data/datasets_test.py | 21 +- ultravox/training/config_base.py | 18 +- ultravox/training/configs/meta_config.yaml | 2 - ultravox/training/configs/release_config.yaml | 469 +++++++++--------- ultravox/training/train.py | 31 +- 7 files changed, 319 insertions(+), 291 deletions(-) diff --git a/ultravox/data/dataset_config.py b/ultravox/data/dataset_config.py index 38c2abd1..e1491a5a 100644 --- a/ultravox/data/dataset_config.py +++ b/ultravox/data/dataset_config.py @@ -1,23 +1,38 @@ import dataclasses +from enum import Enum from typing import List, Optional from pydantic import BaseModel +class StopStrategy(str, Enum): + FIRST_EXHAUSTED = "first_exhausted" + LAST_EXHAUSTED = "last_exhausted" + NEVER_STOP = "never_stop" + + class DataDictConfig(BaseModel): - # Path to the dataset, or huggingface dataset id - path: str - # Name of the dataset, or huggingface dataset config/subset + path: str # Name of the dataset, or huggingface dataset config/subset name: Optional[str] = None splits: List[str] = dataclasses.field(default_factory=list) num_samples: Optional[int] = None total_samples: int = 1 - multiplier: float = 1.0 streaming: bool = True user_template: str = "<|audio|>" assistant_template: str = "{{text}}" transcript_template: str = "{{text}}" - def __post_init__(self): + def post_init(self): if not self.splits: raise ValueError("At least one split must be provided") + + +class DatasetMultiplier(BaseModel): + dataset: DataDictConfig + multiplier: float + + +class InterleaveDataConfig(BaseModel): + # In InterleaveDataset, when to stop interleave: choose from last_exhausted (default), first_exhausted, or never_stop + stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED + datasets_with_multiplier: List[DatasetMultiplier] diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 105a574e..afdcee1e 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -8,7 +8,6 @@ import os import tempfile import warnings -from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence import datasets @@ -296,17 +295,12 @@ def __init__(self, args: VoiceDatasetArgs) -> None: self._args = args self._session: Optional[requests.Session] = None self._rng = np.random.default_rng(self._args.shuffle_seed) - self._multiplier = 1.0 def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> None: self._dataset = dataset # Only required when using epochs when training dataset. 
self._estimated_length = estimated_length - @property - def multiplier(self) -> float: - return self._multiplier - def _load_audio_dataset( self, path: str, @@ -1049,8 +1043,6 @@ def __init__( if self._args.shuffle: dataset = dataset.shuffle(seed=self._args.shuffle_seed) - self._multiplier = config.multiplier - self.user_template = config.user_template self.assistant_template = config.assistant_template self.transcript_template = config.transcript_template @@ -1089,7 +1081,9 @@ def _get_sample(self, row) -> VoiceSample: ) -def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset: +def create_dataset( + name: str | dataset_config.DataDictConfig, args: VoiceDatasetArgs +) -> SizedIterableDataset: DATASET_MAP: Dict[str, Any] = { "anyinstruct": AnyInstructAnswerDataset, "anyinstruct_in": AnyInstructInputDataset, @@ -1115,21 +1109,16 @@ def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset: return DATASET_MAP[name](args, *ext) -class StopStrategy(str, Enum): - FIRST_EXHAUSTED = "first_exhausted" - LAST_EXHAUSTED = "last_exhausted" - NEVER_STOP = "never_stop" - - class InterleaveDataset(SizedIterableDataset): """Interleaves multiple IterableDataset objects based on multiplier.""" def __init__( self, datasets: Sequence[SizedIterableDataset], - stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED, + stop_strategy: dataset_config.StopStrategy = dataset_config.StopStrategy.LAST_EXHAUSTED, seed: Optional[int] = 42, static: bool = False, + multipliers: Optional[List[float]] = None, ) -> None: """ Args: @@ -1143,10 +1132,15 @@ def __init__( self._static = static self._stop_strategy = stop_strategy + if multipliers is None: + self._multipliers = [1.0] * len(datasets) + else: + self._multipliers = multipliers + relative_frequencies = [ - int(getattr(ds, "multiplier", 1.0) * float(len(ds))) for ds in datasets + int(float(len(ds)) * multiple) + for ds, multiple in zip(datasets, self._multipliers) ] - total_frequency = sum(relative_frequencies) self._normalized_probs = [f / total_frequency for f in relative_frequencies] @@ -1173,9 +1167,13 @@ def __iter__(self): exhausted[iter_index] = True # Check if stopping condition is met - if self._stop_strategy == StopStrategy.FIRST_EXHAUSTED or ( - self._stop_strategy == StopStrategy.LAST_EXHAUSTED - and all(exhausted) + if ( + self._stop_strategy == dataset_config.StopStrategy.FIRST_EXHAUSTED + or ( + self._stop_strategy + == dataset_config.StopStrategy.LAST_EXHAUSTED + and all(exhausted) + ) ): break @@ -1186,8 +1184,8 @@ def __iter__(self): def __len__(self) -> int: # TODO: Implement the length method for different stop strategies return sum( - int(getattr(ds, "multiplier", 1.0) * float(len(ds))) - for ds in self._datasets + int(float(len(ds)) * multiple) + for ds, multiple in zip(datasets, self._multipliers) ) diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index 63b04d44..791a378a 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -15,15 +15,10 @@ class FakeSizedIterableDataset(datasets.SizedIterableDataset): """Fake version of datasets.SizedIterableDataset""" - def __init__(self, n, start=0, multiplier=1, estimated_length=1): + def __init__(self, n, start=0, estimated_length=1): self.data = range(start, start + n) - self._multiplier = multiplier self._estimated_length = estimated_length - @property - def multiplier(self) -> float: - return self._multiplier - def __iter__(self): for sample in self.data: yield sample @@ -98,7 +93,7 @@ def 
test_interleaved_first_exhausted(): ds3 = FakeSizedIterableDataset(3) s = datasets.InterleaveDataset( [ds1, ds2, ds3], - stop_strategy=datasets.StopStrategy.FIRST_EXHAUSTED, + stop_strategy=dataset_config.StopStrategy.FIRST_EXHAUSTED, static=True, ) # static=True disables random sampling of datasets, so the order is deterministic @@ -113,7 +108,7 @@ def test_interleaved_last_exhausted(): ds2 = FakeSizedIterableDataset(2, start=10) s = datasets.InterleaveDataset( [ds1, ds2], - stop_strategy=datasets.StopStrategy.LAST_EXHAUSTED, + stop_strategy=dataset_config.StopStrategy.LAST_EXHAUSTED, static=True, ) # static=True disables random sampling of datasets, so the order is deterministic @@ -126,7 +121,7 @@ def test_interleaved_never_stop(): ds2 = FakeSizedIterableDataset(2, start=10) s = datasets.InterleaveDataset( [ds1, ds2], - stop_strategy=datasets.StopStrategy.NEVER_STOP, + stop_strategy=dataset_config.StopStrategy.NEVER_STOP, static=True, ) # static=True disables random sampling of datasets, so the order is deterministic @@ -135,11 +130,9 @@ def test_interleaved_never_stop(): def test_interleaved_random(): - ds1 = FakeSizedIterableDataset(4, multiplier=10) - ds2 = FakeSizedIterableDataset(2, start=10, multiplier=1) - s = datasets.InterleaveDataset( - [ds1, ds2], - ) + ds1 = FakeSizedIterableDataset(4) + ds2 = FakeSizedIterableDataset(2, start=10) + s = datasets.InterleaveDataset([ds1, ds2], multipliers=[10, 1]) # stop_strategy=last_exhausted will stop interleaving when the last dataset is exhausted (attempted after exhaustion) assert list(s) == [ 0, diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 232c0584..85de6f7b 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -3,19 +3,17 @@ import logging import os from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import List, Optional import simple_parsing import torch from ultravox.data import dataset_config -from ultravox.data import datasets from ultravox.model import ultravox_config @dataclasses.dataclass class TrainConfig: - data_sets: List[str] val_sets: List[str] # language model to use text_model: str @@ -29,13 +27,11 @@ class TrainConfig: # we first parse the data_dicts as a list of dictionaries. After parsing, # we convert these dictionaries to DataDictConfig objects using Pydantic # to enforce type constraints and validation, in the __post_init__ method. - data_dicts: Optional[List[Dict[str, Any]]] = None + interleave_datasets: dataset_config.InterleaveDataConfig do_train: bool = True do_eval: bool = True - # In InterleaveDataset, when to stop interleave: choose from last_exhausted (default), first_exhausted, or never_stop - stop_strategy: datasets.StopStrategy = datasets.StopStrategy.LAST_EXHAUSTED data_dir: Optional[str] = None mds: bool = False num_samples: Optional[int] = None @@ -91,16 +87,6 @@ class TrainConfig: loss_config: Optional[ultravox_config.LossConfig] = None def __post_init__(self): - if self.data_dicts: - self.data_dicts = [ - dataset_config.DataDictConfig(**data_dict) - for data_dict in self.data_dicts - ] - # For now, self.data_dicts is a hack to allow for the inclusion of new datasets using the - # GenericVoiceDataset class, without changing how existing datasets are specified in - # self.data_sets. In the future, all datasets will be updated to use the DataDictConfig class. 
- self.data_sets.extend(self.data_dicts) - assert self.data_type in ["bfloat16", "float16", "float32"] if self.device == "cuda" and not torch.cuda.is_available(): self.device = "mps" if torch.backends.mps.is_available() else "cpu" diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index 3fad7142..fe84296a 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -1,9 +1,7 @@ text_model: "meta-llama/Meta-Llama-3-8B-Instruct" audio_model: "facebook/wav2vec2-base-960h" -data_sets: [] # Can't use datasets with "multiplier" because it doesn't have a length attribute. val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"] -stop_strategy: "LAST_EXHAUSTED" train_on_inputs: False shuffle_data: True diff --git a/ultravox/training/configs/release_config.yaml b/ultravox/training/configs/release_config.yaml index 4f8a9b1f..ab754a58 100644 --- a/ultravox/training/configs/release_config.yaml +++ b/ultravox/training/configs/release_config.yaml @@ -16,225 +16,250 @@ val_sets: ["anyinstruct", "soda", "peoplespeech"] batch_size: 24 max_steps: 14400 # x8x24 = 2,764,800 -data_sets: [] # Can't use datasets with "multiplier" because it doesn't have a length attribute. -data_dicts: -# continuation - - path: "fixie-ai/librispeech_asr" - name: "clean" - splits: - - "train.100" # 28_539 samples - - "train.360" # 104_014 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text }}" - multiplier: 1 - total_samples: 132553 - - path: "fixie-ai/librispeech_asr" - name: "other" - splits: - - "train.500" # 148_688 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text }}" - multiplier: 1 - total_samples: 148688 - - path: "fixie-ai/peoples_speech" - name: "clean" - splits: - - "train" # 1_501_271 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text_proc.format_asr_text(text) }}" - multiplier: 1 - total_samples: 1501271 - - path: "fixie-ai/common_voice_17_0" - name: "en" - splits: - - "train" # 1_101_170 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text_proc.format_asr_text(sentence) }}" - multiplier: 1 - total_samples: 1101170 - - path: "fixie-ai/common_voice_17_0" - name: "ar" - splits: - - "train" # 28_369 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 28369 - - path: "fixie-ai/common_voice_17_0" - name: "de" - splits: - - "train" # 589_100 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 589100 - - path: "fixie-ai/common_voice_17_0" - name: "es" - splits: - - "train" # 336_846 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 336846 - - path: "fixie-ai/common_voice_17_0" - name: "fr" 
- splits: - - "train" # 558_054 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 558054 - - path: "fixie-ai/common_voice_17_0" - name: "it" - splits: - - "train" # 169_771 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 169771 - - path: "fixie-ai/common_voice_17_0" - name: "ja" - splits: - - "train" # 10_039 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 10039 - - path: "fixie-ai/common_voice_17_0" - name: "pt" - splits: - - "train" # 21_968 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 21968 - - path: "fixie-ai/common_voice_17_0" - name: "ru" - splits: - - "train" # 26_377 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 26377 -# ASR task - - path: "fixie-ai/librispeech_asr" - name: "clean" - splits: - - "train.100" # 28_539 samples - - "train.360" # 104_014 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text }}" - transcript_template: "{{ text }}" - multiplier: 1 - total_samples: 132553 - - path: "fixie-ai/librispeech_asr" - name: "other" - splits: - - "train.500" # 148_688 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text }}" - transcript_template: "{{ text }}" - multiplier: 1 - total_samples: 148688 - - path: "fixie-ai/peoples_speech" - name: "clean" - splits: - - "train" # 1_501_271 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(text) }}" - transcript_template: "{{ text_proc.format_asr_text(text) }}" - multiplier: 1 - total_samples: 1501271 - - path: "fixie-ai/common_voice_17_0" - name: "en" - splits: - - "train" # 1_101_170 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ text_proc.format_asr_text(sentence) }}" - multiplier: 1 - total_samples: 1101170 - - path: "fixie-ai/common_voice_17_0" - name: "ar" - splits: - - "train" # 28_369 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 28369 - - path: "fixie-ai/common_voice_17_0" - name: "de" - splits: - - "train" # 589_100 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 589100 - - path: "fixie-ai/common_voice_17_0" - name: "es" - splits: - - "train" # 336_846 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 336846 
- - path: "fixie-ai/common_voice_17_0" - name: "fr" - splits: - - "train" # 558_054 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 558054 - - path: "fixie-ai/common_voice_17_0" - name: "it" - splits: - - "train" # 169_771 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 169771 - - path: "fixie-ai/common_voice_17_0" - name: "ja" - splits: - - "train" # 10_039 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 10039 - - path: "fixie-ai/common_voice_17_0" - name: "pt" - splits: - - "train" # 21_968 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 21968 - - path: "fixie-ai/common_voice_17_0" - name: "ru" - splits: - - "train" # 26_377 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - multiplier: 1 - total_samples: 26377 +interleave_datasets: + stop_strategy: "LAST_EXHAUSTED" + datasets_with_multiplier: + # continuation + - dataset: + path: "fixie-ai/librispeech_asr" + name: "clean" + splits: + - "train.100" # 28_539 samples + - "train.360" # 104_014 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text }}" + total_samples: 132553 + multiplier: 1 + - dataset: + path: "fixie-ai/librispeech_asr" + name: "other" + splits: + - "train.500" # 148_688 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text }}" + total_samples: 148688 + multiplier: 1 + - dataset: + path: "fixie-ai/peoples_speech" + name: "clean" + splits: + - "train" # 1_501_271 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text_proc.format_asr_text(text) }}" + total_samples: 1501271 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "en" + splits: + - "train" # 1_101_170 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text_proc.format_asr_text(sentence) }}" + total_samples: 1101170 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "ar" + splits: + - "train" # 28_369 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 28369 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "de" + splits: + - "train" # 589_100 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 589100 + multiplier: 1 + - dataset: + 
path: "fixie-ai/common_voice_17_0" + name: "es" + splits: + - "train" # 336_846 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 336846 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "fr" + splits: + - "train" # 558_054 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 558054 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "it" + splits: + - "train" # 169_771 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 169771 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "ja" + splits: + - "train" # 10_039 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 10039 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "pt" + splits: + - "train" # 21_968 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 21968 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "ru" + splits: + - "train" # 26_377 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + total_samples: 2637 + multiplier: 1 + # ASR task + - dataset: + path: "fixie-ai/librispeech_asr" + name: "clean" + splits: + - "train.100" # 28_539 samples + - "train.360" # 104_014 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text }}" + transcript_template: "{{ text }}" + total_samples: 132553 + multiplier: 1 + - dataset: + path: "fixie-ai/librispeech_asr" + name: "other" + splits: + - "train.500" # 148_688 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text }}" + transcript_template: "{{ text }}" + total_samples: 148688 + multiplier: 1 + - dataset: + path: "fixie-ai/peoples_speech" + name: "clean" + splits: + - "train" # 1_501_271 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(text) }}" + transcript_template: "{{ text_proc.format_asr_text(text) }}" + total_samples: 1501271 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "en" + splits: + - "train" # 1_101_170 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ text_proc.format_asr_text(sentence) }}" + total_samples: 1101170 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "ar" + splits: + - "train" # 28_369 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 28369 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "de" + splits: + - "train" # 589_100 samples + user_template: "{{ 
dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 589100 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "es" + splits: + - "train" # 336_846 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 336846 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "fr" + splits: + - "train" # 558_054 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 558054 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "it" + splits: + - "train" # 169_771 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 169771 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "ja" + splits: + - "train" # 10_039 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 10039 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "pt" + splits: + - "train" # 21_968 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 21968 + multiplier: 1 + - dataset: + path: "fixie-ai/common_voice_17_0" + name: "ru" + splits: + - "train" # 26_377 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + total_samples: 26377 + multiplier: 1 diff --git a/ultravox/training/train.py b/ultravox/training/train.py index a37cd248..fac8e530 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -21,6 +21,7 @@ import wandb.sdk from torch.utils import data +from ultravox.data import dataset_config from ultravox.data import datasets from ultravox.model import data_processing from ultravox.model import ultravox_config @@ -41,16 +42,30 @@ def fix_hyphens(arg: str): def prepare_dataset( train_args: config_base.TrainConfig, - dataset_names: List[str], + interleave_dataset: dataset_config.InterleaveDataConfig | List[str], data_args: datasets.VoiceDatasetArgs, processor: ultravox_processing.UltravoxProcessor, train_on_inputs: bool, - stop_strategy: datasets.StopStrategy, num_samples: Optional[int] = None, include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training) enforce_ds_len_epoch: bool = False, ) -> datasets.SizedIterableDataset: - data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names] + if isinstance(interleave_dataset, dataset_config.InterleaveDataConfig): + data_sets = [ + datasets.create_dataset(ds.dataset, data_args) + for ds in interleave_dataset.datasets_with_multiplier + ] + multipliers = [ + ds.multiplier for ds in interleave_dataset.datasets_with_multiplier + ] + stop_strategy = interleave_dataset.stop_strategy + else: + data_sets = [ + datasets.create_dataset(ds, data_args) for ds in interleave_dataset + ] + stop_strategy = 
dataset_config.StopStrategy.LAST_EXHAUSTED + multipliers = [1.0] * len(data_sets) + # If we're using epochs to train, validate the dataset length is appropriate. using_epochs = train_args.max_steps == 0 if using_epochs and enforce_ds_len_epoch: @@ -58,10 +73,10 @@ def prepare_dataset( assert ( len(ds) > 1 ), f"Dataset {ds} has length {len(ds)} which is too short for epoch training" - interleave = datasets.InterleaveDataset( data_sets, stop_strategy=stop_strategy, + multipliers=multipliers, ) ds_with_proc = data_processing.UltravoxDataproc( interleave, @@ -209,9 +224,8 @@ def train(args: config_base.TrainConfig): ) train_dataset = prepare_dataset( train_args=args, - dataset_names=args.data_sets, + interleave_dataset=args.interleave_datasets, train_on_inputs=args.train_on_inputs, - stop_strategy=args.stop_strategy, processor=processor, num_samples=args.num_samples, data_args=datasets.VoiceDatasetArgs( @@ -241,9 +255,8 @@ def train(args: config_base.TrainConfig): val_datasets = { k: prepare_dataset( train_args=args, - dataset_names=val_sets[k], + interleave_dataset=val_sets[k], train_on_inputs=args.train_on_inputs, - stop_strategy=args.stop_strategy, processor=processor, num_samples=args.val_num_samples, data_args=val_ds_args_text if k.startswith("text_") else val_ds_args, @@ -252,7 +265,7 @@ def train(args: config_base.TrainConfig): for k in val_sets } logging.info( - f"Loaded {args.data_sets} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})" + f"Loaded {args.interleave_datasets} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})" ) else: # When using DDP with split_batches=True, the primary process will distribute the batches to the workers From c3cfab6b1652953a39d6ead8465f2a991fcf74c4 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Fri, 20 Sep 2024 11:06:16 -0700 Subject: [PATCH 09/11] Update --- ultravox/training/config_base.py | 3 +++ ultravox/training/configs/release_config.yaml | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 85de6f7b..57754fb0 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -87,6 +87,9 @@ class TrainConfig: loss_config: Optional[ultravox_config.LossConfig] = None def __post_init__(self): + self.interleave_datasets = dataset_config.InterleaveDataConfig( + **self.interleave_datasets + ) assert self.data_type in ["bfloat16", "float16", "float32"] if self.device == "cuda" and not torch.cuda.is_available(): self.device = "mps" if torch.backends.mps.is_available() else "cpu" diff --git a/ultravox/training/configs/release_config.yaml b/ultravox/training/configs/release_config.yaml index ab754a58..67b0a17e 100644 --- a/ultravox/training/configs/release_config.yaml +++ b/ultravox/training/configs/release_config.yaml @@ -5,7 +5,6 @@ exp_name: "ultravox-v0_4" text_model: "meta-llama/Meta-Llama-3.1-8B-Instruct" audio_model: "openai/whisper-medium" - loss_config: # Choose from ["KL_Divergence", "CrossEntropy"], default is "KL_Divergence" loss_function: "KL_Divergence" @@ -17,7 +16,7 @@ batch_size: 24 max_steps: 14400 # x8x24 = 2,764,800 interleave_datasets: - stop_strategy: "LAST_EXHAUSTED" + stop_strategy: "last_exhausted" datasets_with_multiplier: # continuation - dataset: From d5ee8f32990b2f094871ca5c25d2e666b2553e7c Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Fri, 20 Sep 2024 14:17:04 -0700 Subject: [PATCH 10/11] Revert "Update" This reverts commit 
c3cfab6b1652953a39d6ead8465f2a991fcf74c4. --- ultravox/training/config_base.py | 3 --- ultravox/training/configs/release_config.yaml | 3 ++- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index fb749527..754858cf 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -94,9 +94,6 @@ class TrainConfig: loss_config: Optional[ultravox_config.LossConfig] = None def __post_init__(self): - self.interleave_datasets = dataset_config.InterleaveDataConfig( - **self.interleave_datasets - ) assert self.data_type in ["bfloat16", "float16", "float32"] if self.device == "cuda" and not torch.cuda.is_available(): self.device = "mps" if torch.backends.mps.is_available() else "cpu" diff --git a/ultravox/training/configs/release_config.yaml b/ultravox/training/configs/release_config.yaml index 67b0a17e..ab754a58 100644 --- a/ultravox/training/configs/release_config.yaml +++ b/ultravox/training/configs/release_config.yaml @@ -5,6 +5,7 @@ exp_name: "ultravox-v0_4" text_model: "meta-llama/Meta-Llama-3.1-8B-Instruct" audio_model: "openai/whisper-medium" + loss_config: # Choose from ["KL_Divergence", "CrossEntropy"], default is "KL_Divergence" loss_function: "KL_Divergence" @@ -16,7 +17,7 @@ batch_size: 24 max_steps: 14400 # x8x24 = 2,764,800 interleave_datasets: - stop_strategy: "last_exhausted" + stop_strategy: "LAST_EXHAUSTED" datasets_with_multiplier: # continuation - dataset: From 5ab23303446619dac5c7b8f8fbffd9c9c0d4a981 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Fri, 20 Sep 2024 14:17:10 -0700 Subject: [PATCH 11/11] Revert "Add interleave config" This reverts commit 2a1bcc56e358347bf3cff4d1e3548d3b1f80c269. --- ultravox/data/dataset_config.py | 25 +- ultravox/data/datasets.py | 44 +- ultravox/data/datasets_test.py | 21 +- ultravox/training/config_base.py | 18 +- ultravox/training/configs/meta_config.yaml | 2 + ultravox/training/configs/release_config.yaml | 469 +++++++++--------- ultravox/training/train.py | 31 +- 7 files changed, 291 insertions(+), 319 deletions(-) diff --git a/ultravox/data/dataset_config.py b/ultravox/data/dataset_config.py index e1491a5a..38c2abd1 100644 --- a/ultravox/data/dataset_config.py +++ b/ultravox/data/dataset_config.py @@ -1,38 +1,23 @@ import dataclasses -from enum import Enum from typing import List, Optional from pydantic import BaseModel -class StopStrategy(str, Enum): - FIRST_EXHAUSTED = "first_exhausted" - LAST_EXHAUSTED = "last_exhausted" - NEVER_STOP = "never_stop" - - class DataDictConfig(BaseModel): - path: str # Name of the dataset, or huggingface dataset config/subset + # Path to the dataset, or huggingface dataset id + path: str + # Name of the dataset, or huggingface dataset config/subset name: Optional[str] = None splits: List[str] = dataclasses.field(default_factory=list) num_samples: Optional[int] = None total_samples: int = 1 + multiplier: float = 1.0 streaming: bool = True user_template: str = "<|audio|>" assistant_template: str = "{{text}}" transcript_template: str = "{{text}}" - def post_init(self): + def __post_init__(self): if not self.splits: raise ValueError("At least one split must be provided") - - -class DatasetMultiplier(BaseModel): - dataset: DataDictConfig - multiplier: float - - -class InterleaveDataConfig(BaseModel): - # In InterleaveDataset, when to stop interleave: choose from last_exhausted (default), first_exhausted, or never_stop - stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED - datasets_with_multiplier: 
List[DatasetMultiplier] diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index afdcee1e..105a574e 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -8,6 +8,7 @@ import os import tempfile import warnings +from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence import datasets @@ -295,12 +296,17 @@ def __init__(self, args: VoiceDatasetArgs) -> None: self._args = args self._session: Optional[requests.Session] = None self._rng = np.random.default_rng(self._args.shuffle_seed) + self._multiplier = 1.0 def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> None: self._dataset = dataset # Only required when using epochs when training dataset. self._estimated_length = estimated_length + @property + def multiplier(self) -> float: + return self._multiplier + def _load_audio_dataset( self, path: str, @@ -1043,6 +1049,8 @@ def __init__( if self._args.shuffle: dataset = dataset.shuffle(seed=self._args.shuffle_seed) + self._multiplier = config.multiplier + self.user_template = config.user_template self.assistant_template = config.assistant_template self.transcript_template = config.transcript_template @@ -1081,9 +1089,7 @@ def _get_sample(self, row) -> VoiceSample: ) -def create_dataset( - name: str | dataset_config.DataDictConfig, args: VoiceDatasetArgs -) -> SizedIterableDataset: +def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset: DATASET_MAP: Dict[str, Any] = { "anyinstruct": AnyInstructAnswerDataset, "anyinstruct_in": AnyInstructInputDataset, @@ -1109,16 +1115,21 @@ def create_dataset( return DATASET_MAP[name](args, *ext) +class StopStrategy(str, Enum): + FIRST_EXHAUSTED = "first_exhausted" + LAST_EXHAUSTED = "last_exhausted" + NEVER_STOP = "never_stop" + + class InterleaveDataset(SizedIterableDataset): """Interleaves multiple IterableDataset objects based on multiplier.""" def __init__( self, datasets: Sequence[SizedIterableDataset], - stop_strategy: dataset_config.StopStrategy = dataset_config.StopStrategy.LAST_EXHAUSTED, + stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED, seed: Optional[int] = 42, static: bool = False, - multipliers: Optional[List[float]] = None, ) -> None: """ Args: @@ -1132,15 +1143,10 @@ def __init__( self._static = static self._stop_strategy = stop_strategy - if multipliers is None: - self._multipliers = [1.0] * len(datasets) - else: - self._multipliers = multipliers - relative_frequencies = [ - int(float(len(ds)) * multiple) - for ds, multiple in zip(datasets, self._multipliers) + int(getattr(ds, "multiplier", 1.0) * float(len(ds))) for ds in datasets ] + total_frequency = sum(relative_frequencies) self._normalized_probs = [f / total_frequency for f in relative_frequencies] @@ -1167,13 +1173,9 @@ def __iter__(self): exhausted[iter_index] = True # Check if stopping condition is met - if ( - self._stop_strategy == dataset_config.StopStrategy.FIRST_EXHAUSTED - or ( - self._stop_strategy - == dataset_config.StopStrategy.LAST_EXHAUSTED - and all(exhausted) - ) + if self._stop_strategy == StopStrategy.FIRST_EXHAUSTED or ( + self._stop_strategy == StopStrategy.LAST_EXHAUSTED + and all(exhausted) ): break @@ -1184,8 +1186,8 @@ def __iter__(self): def __len__(self) -> int: # TODO: Implement the length method for different stop strategies return sum( - int(float(len(ds)) * multiple) - for ds, multiple in zip(datasets, self._multipliers) + int(getattr(ds, "multiplier", 1.0) * float(len(ds))) + for ds in self._datasets ) diff --git 
a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index 791a378a..63b04d44 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -15,10 +15,15 @@ class FakeSizedIterableDataset(datasets.SizedIterableDataset): """Fake version of datasets.SizedIterableDataset""" - def __init__(self, n, start=0, estimated_length=1): + def __init__(self, n, start=0, multiplier=1, estimated_length=1): self.data = range(start, start + n) + self._multiplier = multiplier self._estimated_length = estimated_length + @property + def multiplier(self) -> float: + return self._multiplier + def __iter__(self): for sample in self.data: yield sample @@ -93,7 +98,7 @@ def test_interleaved_first_exhausted(): ds3 = FakeSizedIterableDataset(3) s = datasets.InterleaveDataset( [ds1, ds2, ds3], - stop_strategy=dataset_config.StopStrategy.FIRST_EXHAUSTED, + stop_strategy=datasets.StopStrategy.FIRST_EXHAUSTED, static=True, ) # static=True disables random sampling of datasets, so the order is deterministic @@ -108,7 +113,7 @@ def test_interleaved_last_exhausted(): ds2 = FakeSizedIterableDataset(2, start=10) s = datasets.InterleaveDataset( [ds1, ds2], - stop_strategy=dataset_config.StopStrategy.LAST_EXHAUSTED, + stop_strategy=datasets.StopStrategy.LAST_EXHAUSTED, static=True, ) # static=True disables random sampling of datasets, so the order is deterministic @@ -121,7 +126,7 @@ def test_interleaved_never_stop(): ds2 = FakeSizedIterableDataset(2, start=10) s = datasets.InterleaveDataset( [ds1, ds2], - stop_strategy=dataset_config.StopStrategy.NEVER_STOP, + stop_strategy=datasets.StopStrategy.NEVER_STOP, static=True, ) # static=True disables random sampling of datasets, so the order is deterministic @@ -130,9 +135,11 @@ def test_interleaved_never_stop(): def test_interleaved_random(): - ds1 = FakeSizedIterableDataset(4) - ds2 = FakeSizedIterableDataset(2, start=10) - s = datasets.InterleaveDataset([ds1, ds2], multipliers=[10, 1]) + ds1 = FakeSizedIterableDataset(4, multiplier=10) + ds2 = FakeSizedIterableDataset(2, start=10, multiplier=1) + s = datasets.InterleaveDataset( + [ds1, ds2], + ) # stop_strategy=last_exhausted will stop interleaving when the last dataset is exhausted (attempted after exhaustion) assert list(s) == [ 0, diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 754858cf..cfe2dc47 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -5,17 +5,19 @@ import re import sys from pathlib import Path -from typing import List, Optional +from typing import Any, Dict, List, Optional import simple_parsing import torch from ultravox.data import dataset_config +from ultravox.data import datasets from ultravox.model import ultravox_config @dataclasses.dataclass class TrainConfig: + data_sets: List[str] val_sets: List[str] # language model to use text_model: str @@ -29,11 +31,13 @@ class TrainConfig: # we first parse the data_dicts as a list of dictionaries. After parsing, # we convert these dictionaries to DataDictConfig objects using Pydantic # to enforce type constraints and validation, in the __post_init__ method. 
- interleave_datasets: dataset_config.InterleaveDataConfig + data_dicts: Optional[List[Dict[str, Any]]] = None do_train: bool = True do_eval: bool = True + # In InterleaveDataset, when to stop interleave: choose from last_exhausted (default), first_exhausted, or never_stop + stop_strategy: datasets.StopStrategy = datasets.StopStrategy.LAST_EXHAUSTED data_dir: Optional[str] = None mds: bool = False num_samples: Optional[int] = None @@ -94,6 +98,16 @@ class TrainConfig: loss_config: Optional[ultravox_config.LossConfig] = None def __post_init__(self): + if self.data_dicts: + self.data_dicts = [ + dataset_config.DataDictConfig(**data_dict) + for data_dict in self.data_dicts + ] + # For now, self.data_dicts is a hack to allow for the inclusion of new datasets using the + # GenericVoiceDataset class, without changing how existing datasets are specified in + # self.data_sets. In the future, all datasets will be updated to use the DataDictConfig class. + self.data_sets.extend(self.data_dicts) + assert self.data_type in ["bfloat16", "float16", "float32"] if self.device == "cuda" and not torch.cuda.is_available(): self.device = "mps" if torch.backends.mps.is_available() else "cpu" diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index fe84296a..3fad7142 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -1,7 +1,9 @@ text_model: "meta-llama/Meta-Llama-3-8B-Instruct" audio_model: "facebook/wav2vec2-base-960h" +data_sets: [] # Can't use datasets with "multiplier" because it doesn't have a length attribute. val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"] +stop_strategy: "LAST_EXHAUSTED" train_on_inputs: False shuffle_data: True diff --git a/ultravox/training/configs/release_config.yaml b/ultravox/training/configs/release_config.yaml index ab754a58..4f8a9b1f 100644 --- a/ultravox/training/configs/release_config.yaml +++ b/ultravox/training/configs/release_config.yaml @@ -16,250 +16,225 @@ val_sets: ["anyinstruct", "soda", "peoplespeech"] batch_size: 24 max_steps: 14400 # x8x24 = 2,764,800 -interleave_datasets: - stop_strategy: "LAST_EXHAUSTED" - datasets_with_multiplier: - # continuation - - dataset: - path: "fixie-ai/librispeech_asr" - name: "clean" - splits: - - "train.100" # 28_539 samples - - "train.360" # 104_014 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text }}" - total_samples: 132553 - multiplier: 1 - - dataset: - path: "fixie-ai/librispeech_asr" - name: "other" - splits: - - "train.500" # 148_688 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text }}" - total_samples: 148688 - multiplier: 1 - - dataset: - path: "fixie-ai/peoples_speech" - name: "clean" - splits: - - "train" # 1_501_271 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text_proc.format_asr_text(text) }}" - total_samples: 1501271 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "en" - splits: - - "train" # 1_101_170 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text_proc.format_asr_text(sentence) }}" - total_samples: 
1101170 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "ar" - splits: - - "train" # 28_369 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - total_samples: 28369 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "de" - splits: - - "train" # 589_100 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - total_samples: 589100 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "es" - splits: - - "train" # 336_846 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - total_samples: 336846 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "fr" - splits: - - "train" # 558_054 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - total_samples: 558054 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "it" - splits: - - "train" # 169_771 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - total_samples: 169771 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "ja" - splits: - - "train" # 10_039 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - total_samples: 10039 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "pt" - splits: - - "train" # 21_968 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - total_samples: 21968 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "ru" - splits: - - "train" # 26_377 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" - total_samples: 2637 - multiplier: 1 - # ASR task - - dataset: - path: "fixie-ai/librispeech_asr" - name: "clean" - splits: - - "train.100" # 28_539 samples - - "train.360" # 104_014 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text }}" - transcript_template: "{{ text }}" - total_samples: 132553 - multiplier: 1 - - dataset: - path: "fixie-ai/librispeech_asr" - name: "other" - splits: - - "train.500" # 148_688 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text }}" - transcript_template: "{{ text }}" - total_samples: 148688 - multiplier: 1 - - dataset: - path: "fixie-ai/peoples_speech" - name: "clean" - splits: - - "train" # 1_501_271 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(text) }}" - transcript_template: "{{ text_proc.format_asr_text(text) }}" - total_samples: 1501271 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "en" - splits: - - "train" # 1_101_170 samples - 
user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ text_proc.format_asr_text(sentence) }}" - total_samples: 1101170 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "ar" - splits: - - "train" # 28_369 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - total_samples: 28369 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "de" - splits: - - "train" # 589_100 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - total_samples: 589100 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "es" - splits: - - "train" # 336_846 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - total_samples: 336846 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "fr" - splits: - - "train" # 558_054 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - total_samples: 558054 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "it" - splits: - - "train" # 169_771 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - total_samples: 169771 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "ja" - splits: - - "train" # 10_039 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - total_samples: 10039 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "pt" - splits: - - "train" # 21_968 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - total_samples: 21968 - multiplier: 1 - - dataset: - path: "fixie-ai/common_voice_17_0" - name: "ru" - splits: - - "train" # 26_377 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" - total_samples: 26377 - multiplier: 1 +data_sets: [] # Can't use datasets with "multiplier" because it doesn't have a length attribute. 
+data_dicts: +# continuation + - path: "fixie-ai/librispeech_asr" + name: "clean" + splits: + - "train.100" # 28_539 samples + - "train.360" # 104_014 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text }}" + multiplier: 1 + total_samples: 132553 + - path: "fixie-ai/librispeech_asr" + name: "other" + splits: + - "train.500" # 148_688 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text }}" + multiplier: 1 + total_samples: 148688 + - path: "fixie-ai/peoples_speech" + name: "clean" + splits: + - "train" # 1_501_271 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text_proc.format_asr_text(text) }}" + multiplier: 1 + total_samples: 1501271 + - path: "fixie-ai/common_voice_17_0" + name: "en" + splits: + - "train" # 1_101_170 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text_proc.format_asr_text(sentence) }}" + multiplier: 1 + total_samples: 1101170 + - path: "fixie-ai/common_voice_17_0" + name: "ar" + splits: + - "train" # 28_369 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 28369 + - path: "fixie-ai/common_voice_17_0" + name: "de" + splits: + - "train" # 589_100 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 589100 + - path: "fixie-ai/common_voice_17_0" + name: "es" + splits: + - "train" # 336_846 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 336846 + - path: "fixie-ai/common_voice_17_0" + name: "fr" + splits: + - "train" # 558_054 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 558054 + - path: "fixie-ai/common_voice_17_0" + name: "it" + splits: + - "train" # 169_771 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 169771 + - path: "fixie-ai/common_voice_17_0" + name: "ja" + splits: + - "train" # 10_039 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 10039 + - path: "fixie-ai/common_voice_17_0" + name: "pt" + splits: + - "train" # 21_968 samples + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 21968 + - path: "fixie-ai/common_voice_17_0" + name: "ru" + splits: + - "train" # 26_377 samples + user_template: "Continue the 
following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 26377 +# ASR task + - path: "fixie-ai/librispeech_asr" + name: "clean" + splits: + - "train.100" # 28_539 samples + - "train.360" # 104_014 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text }}" + transcript_template: "{{ text }}" + multiplier: 1 + total_samples: 132553 + - path: "fixie-ai/librispeech_asr" + name: "other" + splits: + - "train.500" # 148_688 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text }}" + transcript_template: "{{ text }}" + multiplier: 1 + total_samples: 148688 + - path: "fixie-ai/peoples_speech" + name: "clean" + splits: + - "train" # 1_501_271 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(text) }}" + transcript_template: "{{ text_proc.format_asr_text(text) }}" + multiplier: 1 + total_samples: 1501271 + - path: "fixie-ai/common_voice_17_0" + name: "en" + splits: + - "train" # 1_101_170 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ text_proc.format_asr_text(sentence) }}" + multiplier: 1 + total_samples: 1101170 + - path: "fixie-ai/common_voice_17_0" + name: "ar" + splits: + - "train" # 28_369 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 28369 + - path: "fixie-ai/common_voice_17_0" + name: "de" + splits: + - "train" # 589_100 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 589100 + - path: "fixie-ai/common_voice_17_0" + name: "es" + splits: + - "train" # 336_846 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 336846 + - path: "fixie-ai/common_voice_17_0" + name: "fr" + splits: + - "train" # 558_054 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 558054 + - path: "fixie-ai/common_voice_17_0" + name: "it" + splits: + - "train" # 169_771 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 169771 + - path: "fixie-ai/common_voice_17_0" + name: "ja" + splits: + - "train" # 10_039 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 10039 + - path: "fixie-ai/common_voice_17_0" + name: "pt" + splits: + - "train" # 21_968 samples + user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 21968 + - path: "fixie-ai/common_voice_17_0" + name: "ru" + splits: + - "train" # 26_377 samples 
+ user_template: "{{ dataset._get_transcribe_prompt() }}" + assistant_template: "{{ text_proc.format_asr_text(sentence) }}" + transcript_template: "{{ sentence }}" + multiplier: 1 + total_samples: 26377 diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 91aaa95f..3657522b 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -20,7 +20,6 @@ import wandb.sdk from torch.utils import data -from ultravox.data import dataset_config from ultravox.data import datasets from ultravox.model import data_processing from ultravox.model import ultravox_config @@ -38,30 +37,16 @@ def prepare_dataset( train_args: config_base.TrainConfig, - interleave_dataset: dataset_config.InterleaveDataConfig | List[str], + dataset_names: List[str], data_args: datasets.VoiceDatasetArgs, processor: ultravox_processing.UltravoxProcessor, train_on_inputs: bool, + stop_strategy: datasets.StopStrategy, num_samples: Optional[int] = None, include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training) enforce_ds_len_epoch: bool = False, ) -> datasets.SizedIterableDataset: - if isinstance(interleave_dataset, dataset_config.InterleaveDataConfig): - data_sets = [ - datasets.create_dataset(ds.dataset, data_args) - for ds in interleave_dataset.datasets_with_multiplier - ] - multipliers = [ - ds.multiplier for ds in interleave_dataset.datasets_with_multiplier - ] - stop_strategy = interleave_dataset.stop_strategy - else: - data_sets = [ - datasets.create_dataset(ds, data_args) for ds in interleave_dataset - ] - stop_strategy = dataset_config.StopStrategy.LAST_EXHAUSTED - multipliers = [1.0] * len(data_sets) - + data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names] # If we're using epochs to train, validate the dataset length is appropriate. using_epochs = train_args.max_steps == 0 if using_epochs and enforce_ds_len_epoch: @@ -69,10 +54,10 @@ def prepare_dataset( assert ( len(ds) > 1 ), f"Dataset {ds} has length {len(ds)} which is too short for epoch training" + interleave = datasets.InterleaveDataset( data_sets, stop_strategy=stop_strategy, - multipliers=multipliers, ) ds_with_proc = data_processing.UltravoxDataproc( interleave, @@ -230,8 +215,9 @@ def train(args: config_base.TrainConfig): ) train_dataset = prepare_dataset( train_args=args, - interleave_dataset=args.interleave_datasets, + dataset_names=args.data_sets, train_on_inputs=args.train_on_inputs, + stop_strategy=args.stop_strategy, processor=processor, num_samples=args.num_samples, data_args=datasets.VoiceDatasetArgs( @@ -261,8 +247,9 @@ def train(args: config_base.TrainConfig): val_datasets = { k: prepare_dataset( train_args=args, - interleave_dataset=val_sets[k], + dataset_names=val_sets[k], train_on_inputs=args.train_on_inputs, + stop_strategy=args.stop_strategy, processor=processor, num_samples=args.val_num_samples, data_args=val_ds_args_text if k.startswith("text_") else val_ds_args, @@ -271,7 +258,7 @@ def train(args: config_base.TrainConfig): for k in val_sets } logging.info( - f"Loaded {args.interleave_datasets} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})" + f"Loaded {args.data_sets} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})" ) else: # When using DDP with split_batches=True, the primary process will distribute the batches to the workers
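
Note (illustration only, not part of the patch series above): the reverted InterleaveDataset draws each sample from dataset i with probability proportional to multiplier_i * len(dataset_i), and stops according to the configured StopStrategy. The sketch below mirrors that sampling and stop logic under stated assumptions; the toy_interleave helper, the example dataset contents, and the example multipliers are hypothetical and exist only to show the behavior.

    # Illustrative sketch only -- mirrors the multiplier-weighted sampling and
    # stop strategies shown in the datasets.py diff above; not runnable against
    # the ultravox codebase as-is.
    from enum import Enum
    from typing import List, Sequence

    import numpy as np


    class StopStrategy(str, Enum):
        FIRST_EXHAUSTED = "first_exhausted"
        LAST_EXHAUSTED = "last_exhausted"
        NEVER_STOP = "never_stop"


    def toy_interleave(
        datasets: Sequence[List[int]],
        multipliers: Sequence[float],
        stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED,
        seed: int = 42,
    ):
        # Relative frequency of each dataset: multiplier * length, as in the patch.
        freqs = [int(m * len(ds)) for ds, m in zip(datasets, multipliers)]
        probs = [f / sum(freqs) for f in freqs]
        rng = np.random.default_rng(seed)

        iters = [iter(ds) for ds in datasets]
        exhausted = [False] * len(datasets)
        while True:
            i = rng.choice(len(iters), p=probs)
            try:
                yield next(iters[i])
            except StopIteration:
                exhausted[i] = True
                if stop_strategy == StopStrategy.FIRST_EXHAUSTED or (
                    stop_strategy == StopStrategy.LAST_EXHAUSTED and all(exhausted)
                ):
                    return
                # Restart the exhausted dataset and keep sampling from it.
                iters[i] = iter(datasets[i])
                yield next(iters[i])


    # Example: with multipliers [10, 1], the 4-sample dataset is drawn roughly
    # 20x as often as the 2-sample dataset (probabilities 40/42 vs 2/42).
    samples = list(toy_interleave([[0, 1, 2, 3], [10, 11]], multipliers=[10.0, 1.0]))

Under LAST_EXHAUSTED (the default restored in this revert), exhausted datasets are restarted and iteration only ends once every dataset has been exhausted at least once, which is why the reported __len__ is likewise the sum of multiplier * len(ds) over all datasets.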