From f937487878769ffbe3426a8f9217b38e0d8513d7 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Mon, 5 Aug 2024 14:16:24 -0700 Subject: [PATCH] [ds_tool] Tools with Audio (#62) * allow ds_tool to TTS multiple audios with the same voice * add resolve_voice to tts client providers * rename provider -> implementation --- poetry.lock | 13 +++++- pyproject.toml | 1 + scripts/dataset_creation/tools_audio.sh | 6 +++ ultravox/tools/ds_tool/caching.py | 25 ++++++++--- ultravox/tools/ds_tool/ds_tool.py | 49 ++++++++++++++-------- ultravox/tools/ds_tool/tts.py | 29 ++++++++----- ultravox/tools/ds_tool/user_messages.jinja | 7 ++++ 7 files changed, 96 insertions(+), 34 deletions(-) create mode 100644 scripts/dataset_creation/tools_audio.sh create mode 100644 ultravox/tools/ds_tool/user_messages.jinja diff --git a/poetry.lock b/poetry.lock index a3cb3348..56edc005 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5902,6 +5902,17 @@ files = [ {file = "types_python_dateutil-2.9.0.20240316-py3-none-any.whl", hash = "sha256:6b8cb66d960771ce5ff974e9dd45e38facb81718cc1e208b10b1baccbfdbee3b"}, ] +[[package]] +name = "types-pyyaml" +version = "6.0.12.20240724" +description = "Typing stubs for PyYAML" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-PyYAML-6.0.12.20240724.tar.gz", hash = "sha256:cf7b31ae67e0c5b2919c703d2affc415485099d3fe6666a6912f040fd05cb67f"}, + {file = "types_PyYAML-6.0.12.20240724-py3-none-any.whl", hash = "sha256:e5becec598f3aa3a2ddf671de4a75fa1c6856fbf73b2840286c9d50fae2d5d48"}, +] + [[package]] name = "types-requests" version = "2.26.3" @@ -6685,4 +6696,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "cacd27da473431681d3550b5d13bf790a7b447287885d2b2e8cb5d5b25d7e95e" +content-hash = "d790cd25bda3d667cdafe935a78d7aaf82229194d40444fce67ba77f19c11469" diff --git a/pyproject.toml b/pyproject.toml index b335e3fc..449146dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ gradio-client 
= "~1.0.1" gradio = "~3.40.1" gpustat = "~1.1.1" types-requests = "~2.26.0" +types-pyyaml = "^6.0.12.20240724" [build-system] requires = ["poetry-core"] diff --git a/scripts/dataset_creation/tools_audio.sh b/scripts/dataset_creation/tools_audio.sh new file mode 100644 index 00000000..1a020957 --- /dev/null +++ b/scripts/dataset_creation/tools_audio.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +# Given the tools dataset, we want to create the audio for all user messages +just tts -d fixie-ai/tools -u fixie-ai/tools-audio --private \ + -c @ultravox/tools/ds_tool/user_messages.jinja -j -a user_message_audios \ + -V random --num_workers 20 -i eleven --token $HF_WRITE_TOKEN \ No newline at end of file diff --git a/ultravox/tools/ds_tool/caching.py b/ultravox/tools/ds_tool/caching.py index 9f76f4e2..30ad87d0 100644 --- a/ultravox/tools/ds_tool/caching.py +++ b/ultravox/tools/ds_tool/caching.py @@ -1,7 +1,7 @@ import hashlib import json import os -from typing import Optional +from typing import List, Optional, Union, overload import openai @@ -37,14 +37,27 @@ def chat_completion(self, **kwargs) -> str: class CachingTtsWrapper: - def __init__(self, client: tts.Client, provider: str): + def __init__(self, client: tts.Client, implementation: str): super().__init__() self._client = client - self._base_path = os.path.join(".cache/ds_tool/tts", provider) + self._base_path = os.path.join(".cache/ds_tool/tts", implementation) - def tts(self, text: str, voice: Optional[str] = None) -> bytes: - path = os.path.join(self._base_path, voice or "default") - text_hash = hashlib.sha256(text.encode()).hexdigest() + @overload + def tts(self, text: str, voice: Optional[str] = None) -> bytes: ... + + @overload + def tts(self, text: List[str], voice: Optional[str] = None) -> List[bytes]: ... 
+ + def tts( + self, text: Union[str, List[str]], voice: Optional[str] = None + ) -> Union[bytes, List[bytes]]: + text_hash = hashlib.sha256(str(text).encode()).hexdigest() + voice = self._client.resolve_voice(voice) + + if isinstance(text, list): + return [self.tts(t, voice) for t in text] + + path = os.path.join(self._base_path, voice) os.makedirs(path, exist_ok=True) cache_path = os.path.join(path, f"{text_hash}.wav") diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index 71c4f1f3..9ad122bc 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -7,6 +7,7 @@ import jinja2 import openai import simple_parsing +import yaml from ultravox.data import text_proc from ultravox.tools.ds_tool import caching @@ -20,7 +21,8 @@ class TtsTask: implementation: str = simple_parsing.field(default="azure", alias="-i") # Column name containing the text to convert to audio. It can be a Jinja variable expression. - column_name: str = simple_parsing.field(default="question", alias="-c") + template: str = simple_parsing.field(alias="-T") + json_mode: bool = simple_parsing.field(default=False, alias="-j") audio_column_name: Optional[str] = simple_parsing.field(default=None, alias="-a") voice: Optional[str] = simple_parsing.field(default=None, alias="-V") sample_rate: int = simple_parsing.field(default=16000, alias="-r") @@ -32,35 +34,48 @@ def __post_init__(self): self.audio_column_name = f"{self.column_name}_audio" tts_client = caching.CachingTtsWrapper( tts.create_client(self.implementation, self.sample_rate), - provider=self.implementation, + implementation=self.implementation, ) + if self.template.startswith("@"): + with open(self.template[1:], "r") as template_file: + self.template = template_file.read() + def map_split( self, ds_split: datasets.Dataset, num_proc: int, writer_batch_size: int ) -> datasets.Dataset: - print(f'TTS mapping "{self.column_name}" to "{self.audio_column_name}"...') - return ds_split.map( 
+ print(f'TTS mapping "{self.template}" to "{self.audio_column_name}"...') + ds_split = ds_split.map( self._map_sample, num_proc=num_proc, writer_batch_size=writer_batch_size - ).cast_column( - self.audio_column_name, datasets.Audio(sampling_rate=self.sample_rate) ) + column_type = datasets.Audio(sampling_rate=self.sample_rate) + if self.json_mode and isinstance( + ds_split.features[self.audio_column_name], datasets.Sequence + ): + column_type = datasets.Sequence(column_type) + return ds_split.cast_column(self.audio_column_name, column_type) def _map_sample(self, sample): - # using a Jinja template for some added flexibility - # The {{ var }} syntax is how Jinja denotes variables + # using a Jinja template for some added flexibility, template can include variables and functions + # e.g., {{ text }} or {{ text_proc.format_asr_text(text) }} try: - text = jinja2.Template( - "{{" + self.column_name + "}}", undefined=jinja2.StrictUndefined - ).render(**sample) + text_or_texts = jinja2.Template( + self.template, undefined=jinja2.StrictUndefined + ).render(**sample, json_dump=json.dumps, text_proc=text_proc) except jinja2.TemplateError as e: print(f"Error rendering template: {e}") - print(f"column_name: {self.column_name}") + print(f"template: {self.template}") print(f"sample keys: {list(sample.keys())}") raise ValueError( f"Template rendering failed. Make sure column_name exists in the sample." 
) from e - sample[self.audio_column_name] = tts_client.tts(text, self.voice) + if self.json_mode: + text_or_texts = yaml.safe_load(text_or_texts) + assert isinstance(text_or_texts, list) + assert all(isinstance(turn, str) for turn in text_or_texts) + + sample[self.audio_column_name] = tts_client.tts(text_or_texts, self.voice) return sample @@ -112,7 +127,7 @@ def _map_sample(self, sample): ) from e if self.json_mode: - turns = json.loads(rendered) + turns = yaml.safe_load(rendered) assert isinstance(turns, list) assert all(isinstance(turn, dict) for turn in turns) assert len(turns) > 0 @@ -130,9 +145,9 @@ def _map_sample(self, sample): # This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model. -# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -c question -a audio --token $HF_WRITE_TOKEN -# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -c explanation -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct -# just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt +# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -T {{question}} -a audio --token $HF_WRITE_TOKEN +# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -T {{explanation}} -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct +# just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt # just ds_tool textgen --new_column_name continuation --dataset_name openslr/librispeech_asr --dataset_subset clean --dataset_split train.360 \ # --shuffle --upload_name fixie-ai/librispeech_asr --private --base_url https://api.fireworks.ai/inference/v1 \ # --api_key $FIREWORKS_API_KEY --token $HF_TOKEN --language_model accounts/fireworks/models/llama-v3-8b-instruct \ diff --git
a/ultravox/tools/ds_tool/tts.py b/ultravox/tools/ds_tool/tts.py index 78e11c1c..67f8f622 100644 --- a/ultravox/tools/ds_tool/tts.py +++ b/ultravox/tools/ds_tool/tts.py @@ -1,7 +1,7 @@ import abc import io import os -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from xml.sax import saxutils import numpy as np @@ -23,7 +23,15 @@ def _make_ssml(voice: str, text: str): class Client(abc.ABC): + DEFAULT_VOICE: str + ALL_VOICES: List[str] + def __init__(self, sample_rate: int = 16000): + if not hasattr(self, "DEFAULT_VOICE"): + raise ValueError("DEFAULT_VOICE must be defined in subclasses.") + if not hasattr(self, "ALL_VOICES"): + raise ValueError("ALL_VOICES must be defined in subclasses.") + self._session = requests.Session() retries = requests.adapters.Retry(total=NUM_RETRIES) self._session.mount( @@ -48,6 +56,14 @@ def _handle_pcm_response(self, response: requests.Response) -> bytes: sf.write(wav_bytes, pcm_array, self._sample_rate, format="wav") return wav_bytes.getvalue() + def resolve_voice(self, voice: Optional[str]) -> str: + voice = voice or self.DEFAULT_VOICE + if voice == RANDOM_VOICE_KEY: + # Every process has same random seed, so we mix in the PID here for more variation. 
+ i = np.random.randint(len(self.ALL_VOICES)) + os.getpid() + voice = self.ALL_VOICES[i % len(self.ALL_VOICES)] + return voice + class AzureTts(Client): DEFAULT_VOICE = "en-US-JennyNeural" @@ -80,10 +96,7 @@ class AzureTts(Client): ] def tts(self, text: str, voice: Optional[str] = None): - voice = voice or self.DEFAULT_VOICE - if voice == RANDOM_VOICE_KEY: - voice = np.random.choice(self.ALL_VOICES) - assert voice + voice = self.resolve_voice(voice) region = "westus" api_key = os.environ.get("AZURE_TTS_API_KEY") or os.environ.get( "AZURE_WESTUS_TTS_API_KEY" @@ -134,11 +147,7 @@ class ElevenTts(Client): ] def tts(self, text: str, voice: Optional[str] = None): - voice = voice or self.DEFAULT_VOICE - if voice == RANDOM_VOICE_KEY: - # Every process has same random seed, so we mix in the PID here for more variation. - i = np.random.randint(len(self.ALL_VOICES)) + os.getpid() - voice = self.ALL_VOICES[i % len(self.ALL_VOICES)] + voice = self.resolve_voice(voice) url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice}/stream?output_format=pcm_16000" headers = {"xi-api-key": os.environ["ELEVEN_API_KEY"]} body = { diff --git a/ultravox/tools/ds_tool/user_messages.jinja b/ultravox/tools/ds_tool/user_messages.jinja new file mode 100644 index 00000000..b786ea72 --- /dev/null +++ b/ultravox/tools/ds_tool/user_messages.jinja @@ -0,0 +1,7 @@ +[ + {%- for message in messages[:-1] -%} + {%- if message['role'] == 'user' -%} + {{ json_dump(message['content']) }}, + {%- endif -%} + {%- endfor -%} +]