diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index b3015a2..ba82c86 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -24,6 +24,7 @@ from ultravox.data import text_proc +AUDIO_PLACEHOLDER = "<|audio|>" SAMPLE_RATE = 16000 # TODO(juberti): set these in the environment so they don't need to be hard-coded here. @@ -36,7 +37,6 @@ class DatasetSplit(str, enum.Enum): TRAIN = "train" - TEST = "test" VALIDATION = "validation" @@ -81,6 +81,9 @@ def __post_init__(self): @dataclasses.dataclass class DatasetConfig(helpers.Serializable): + # Note that subclasses can override any of these fields, but they currently can't + # extend structured fields like splits or user_template_args. + # See _merge_configs below for the current implementation. name: str """Name of the dataset.""" base: Optional[str] = None @@ -110,7 +113,7 @@ def __post_init__(self): """Set defaults only if this is a root config, so that said defaults in a subclass don't act as overrides.""" DEFAULTS = { "splits": [], - "user_template": "<|audio|>", + "user_template": AUDIO_PLACEHOLDER, "user_template_args": {}, "assistant_template": "{{text}}", "transcript_template": "{{text}}", @@ -486,7 +489,6 @@ def _get_sample(self, row) -> Optional[VoiceSample]: **row, text_proc=text_proc, dataset=self, - include_audio=self._args.include_audio, **self._config.user_template_args, ) assistant_content = jinja2.Template( @@ -504,11 +506,31 @@ def _get_sample(self, row) -> Optional[VoiceSample]: raise ValueError( "Template rendering failed. Make sure all keys in the template exist in the sample." 
) from e + if not self._args.include_audio: + user_content = user_content.replace(AUDIO_PLACEHOLDER, f'"{transcript}"') + messages = _get_messages(user_content, assistant_content) + audio = self._get_audio(row, self._config.audio_field) + return self._make_sample(messages, audio, audio_transcript=transcript) + + +class LibriSpeechDummyDataset(VoiceDataset): + def __init__(self, args: VoiceDatasetArgs) -> None: + super().__init__(args) + # This dataset doesn't support streaming. + dataset = self._load_hf_dataset( + "hf-internal-testing/librispeech_asr_dummy", + "clean", + split="validation", + streaming=False, + ) + self._init_dataset(dataset, 73) + def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: + text = text_proc.format_asr_text(row["text"]) return self._make_sample( - _get_messages(user_content, assistant_content), - self._get_audio(row, self._config.audio_field), - audio_transcript=transcript, + self._make_messages(f"Transcribe\n{AUDIO_PLACEHOLDER}", text), + self._get_audio(row, "audio"), + audio_transcript=text, ) @@ -635,10 +657,10 @@ def __len__(self): CONTINUATION_USER_TEMPLATE = ( - "Continue the following text using less than 50 words:\n\n<|audio|>" + f"Continue the following text using less than 50 words:\n\n{AUDIO_PLACEHOLDER}" ) CONTINUATION_ASSISTANT_TEMPLATE = "{{continuation}}" -TRANSCRIPTION_USER_TEMPLATE = "Transcribe\n<|audio|>" +TRANSCRIPTION_USER_TEMPLATE = f"Transcribe\n{AUDIO_PLACEHOLDER}" BOOLQ_CONFIG = DatasetConfig( name="boolq", @@ -647,7 +669,7 @@ def __len__(self): DatasetSplitConfig(name="train", num_samples=10000), DatasetSplitConfig(name="validation", num_samples=1000), ], - user_template="{{passage}}\n\n{{'<|audio|>' if include_audio else question}}", + user_template=f"{{{{passage}}}}\n\n{AUDIO_PLACEHOLDER}", assistant_template="{{'True' if answer else 'False'}}", transcript_template="{{question}}", ) @@ -770,6 +792,19 @@ def __len__(self): ], ) +# SODA_CONFIG = DatasetConfig( +# name="soda", +#
path="fixie-ai/soda-audio", +# splits=[ +# DatasetSplitConfig(name="train", num_samples=1_000_000), +# DatasetSplitConfig(name="validation", num_samples=10_000), +# ], +# # Need way to specify message history. +# audio_field="audio_second_last_turn", +# assistant_template="{{alt_last_turn}}", +# transcript_template="{{turns[-2]}}", +# ) + VP_EN_CONFIG = DatasetConfig( name="voxpopuli-en", path="facebook/voxpopuli", @@ -991,6 +1026,8 @@ def _merge_configs(configs: List[DatasetConfig]) -> DatasetConfig: def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset: + if name == "dummy": + return LibriSpeechDummyDataset(args) assert name in DATASET_MAP, f"Unknown dataset: {name}" # Make a list of configs from root->base. configs: List[DatasetConfig] = [] diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index 193453f..245aac3 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -308,6 +308,24 @@ def test_generic_dataset_custom_templates(): assert sample.audio_transcript == "0" +def test_generic_dataset_text_only(): + config = datasets.DatasetConfig( + name="fake_dataset", + path="fake_path", + splits=[datasets.DatasetSplitConfig(name="fake", num_samples=5)], + user_template="Transcribe\n<|audio|>", + ) + ds = FakeGenericDataset(5, config, datasets.VoiceDatasetArgs(include_audio=False)) + assert len(ds) == 5 + sample = next(iter(ds)) + assert isinstance(sample, datasets.VoiceSample) + assert sample.messages == [ + {"role": "user", "content": 'Transcribe\n"0"'}, + {"role": "assistant", "content": "0"}, + ] + assert sample.audio is None + + def test_generic_dataset_merge_configs(): base_config = datasets.DatasetConfig( name="fake_base",