Skip to content

Commit

Permalink
text-only
Browse files Browse the repository at this point in the history
  • Loading branch information
juberti committed Oct 17, 2024
1 parent 76fec6e commit 8a9fb9b
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 9 deletions.
55 changes: 46 additions & 9 deletions ultravox/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from ultravox.data import text_proc

AUDIO_PLACEHOLDER = "<|audio|>"
SAMPLE_RATE = 16000

# TODO(juberti): set these in the environment so they don't need to be hard-coded here.
Expand All @@ -36,7 +37,6 @@

class DatasetSplit(str, enum.Enum):
TRAIN = "train"
TEST = "test"
VALIDATION = "validation"


Expand Down Expand Up @@ -81,6 +81,9 @@ def __post_init__(self):

@dataclasses.dataclass
class DatasetConfig(helpers.Serializable):
# Note that subclasses can override any of these fields, but they currently can't
# extend structured fields like splits or user_template_args.
# See _merge_configs below for the current implementation.
name: str
"""Name of the dataset."""
base: Optional[str] = None
Expand Down Expand Up @@ -110,7 +113,7 @@ def __post_init__(self):
"""Set defaults only if this is a root config, so that said defaults in a subclass don't act as overrides."""
DEFAULTS = {
"splits": [],
"user_template": "<|audio|>",
"user_template": AUDIO_PLACEHOLDER,
"user_template_args": {},
"assistant_template": "{{text}}",
"transcript_template": "{{text}}",
Expand Down Expand Up @@ -486,7 +489,6 @@ def _get_sample(self, row) -> Optional[VoiceSample]:
**row,
text_proc=text_proc,
dataset=self,
include_audio=self._args.include_audio,
**self._config.user_template_args,
)
assistant_content = jinja2.Template(
Expand All @@ -504,11 +506,31 @@ def _get_sample(self, row) -> Optional[VoiceSample]:
raise ValueError(
"Template rendering failed. Make sure all keys in the template exist in the sample."
) from e
if not self._args.include_audio:
user_content = user_content.replace(AUDIO_PLACEHOLDER, f'"{transcript}"')
messages = _get_messages(user_content, assistant_content)
audio = self._get_audio(row, self._config.audio_field)
return self._make_sample(messages, audio, audio_transcript=transcript)


class LibriSpeechDummyDataset(VoiceDataset):
def __init__(self, args: VoiceDatasetArgs) -> None:
super().__init__(args)
# This dataset doesn't support streaming.
dataset = self._load_hf_dataset(
"hf-internal-testing/librispeech_asr_dummy",
"clean",
split="validation",
streaming=False,
)
self._init_dataset(dataset, 73)

def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
text = text_proc.format_asr_text(row["text"])
return self._make_sample(
_get_messages(user_content, assistant_content),
self._get_audio(row, self._config.audio_field),
audio_transcript=transcript,
self._make_messages(f"Transcribe\n{AUDIO_PLACEHOLDER}", text),
self._get_audio(row, "audio"),
audio_transcript=text,
)


Expand Down Expand Up @@ -635,10 +657,10 @@ def __len__(self):


CONTINUATION_USER_TEMPLATE = (
"Continue the following text using less than 50 words:\n\n<|audio|>"
f"Continue the following text using less than 50 words:\n\n{AUDIO_PLACEHOLDER}"
)
CONTINUATION_ASSISTANT_TEMPLATE = "{{continuation}}"
TRANSCRIPTION_USER_TEMPLATE = "Transcribe\n<|audio|>"
TRANSCRIPTION_USER_TEMPLATE = f"Transcribe\n{AUDIO_PLACEHOLDER}"

BOOLQ_CONFIG = DatasetConfig(
name="boolq",
Expand All @@ -647,7 +669,7 @@ def __len__(self):
DatasetSplitConfig(name="train", num_samples=10000),
DatasetSplitConfig(name="validation", num_samples=1000),
],
user_template="{{passage}}\n\n{{'<|audio|>' if include_audio else question}}",
user_template="{{passage}}\n\n{AUDIO_PLACEHOLDER}",
assistant_template="{{'True' if answer else 'False'}}",
transcript_template="{{question}}",
)
Expand Down Expand Up @@ -770,6 +792,19 @@ def __len__(self):
],
)

# SODA_CONFIG = DatasetConfig(
# name="soda",
# path="fixie-ai/soda-audio",
# splits=[
# DatasetSplitConfig(name="train", num_samples=1_000_000),
# DatasetSplitConfig(name="validation", num_samples=10_000),
# ],
# # Need way to specify message history.
# audio_field="audio_second_last_turn",
# assistant_template="{{alt_last_turn}}",
# transcript_template="{{turns[-2]}}",
# )

VP_EN_CONFIG = DatasetConfig(
name="voxpopuli-en",
path="facebook/voxpopuli",
Expand Down Expand Up @@ -991,6 +1026,8 @@ def _merge_configs(configs: List[DatasetConfig]) -> DatasetConfig:


def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset:
if name == "dummy":
return LibriSpeechDummyDataset(args)
assert name in DATASET_MAP, f"Unknown dataset: {name}"
# Make a list of configs from root->base.
configs: List[DatasetConfig] = []
Expand Down
18 changes: 18 additions & 0 deletions ultravox/data/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,24 @@ def test_generic_dataset_custom_templates():
assert sample.audio_transcript == "0"


def test_generic_dataset_text_only():
config = datasets.DatasetConfig(
name="fake_dataset",
path="fake_path",
splits=[datasets.DatasetSplitConfig(name="fake", num_samples=5)],
user_template="Transcribe\n<|audio|>",
)
ds = FakeGenericDataset(5, config, datasets.VoiceDatasetArgs(include_audio=False))
assert len(ds) == 5
sample = next(iter(ds))
assert isinstance(sample, datasets.VoiceSample)
assert sample.messages == [
{"role": "user", "content": 'Transcribe\n"0"'},
{"role": "assistant", "content": "0"},
]
assert sample.audio is None


def test_generic_dataset_merge_configs():
base_config = datasets.DatasetConfig(
name="fake_base",
Expand Down

0 comments on commit 8a9fb9b

Please sign in to comment.