Skip to content

Commit

Permalink
[ds_tool] Tools with Audio (#62)
Browse files Browse the repository at this point in the history
* allow ds_tool to TTS multiple audios with the same voice

* add resolve_voice to tts client providers

* rename provider -> implementation
  • Loading branch information
farzadab authored Aug 5, 2024
1 parent ecd58c4 commit f937487
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 34 deletions.
13 changes: 12 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ gradio-client = "~1.0.1"
gradio = "~3.40.1"
gpustat = "~1.1.1"
types-requests = "~2.26.0"
types-pyyaml = "^6.0.12.20240724"

[build-system]
requires = ["poetry-core"]
Expand Down
6 changes: 6 additions & 0 deletions scripts/dataset_creation/tools_audio.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash

# Given the tools dataset (fixie-ai/tools), create the audio for all user
# messages and upload the result as the private dataset fixie-ai/tools-audio.
#
# The text to synthesize comes from the user_messages.jinja template; -j turns
# on JSON mode so the rendered template is parsed as a list of strings, and the
# resulting audios are stored in the `user_message_audios` column.
#
# Requires HF_WRITE_TOKEN to be set; the expansion below aborts with a clear
# message instead of passing a missing/empty token through to the tool.
#
# NOTE(review): ds_tool.py in this change exposes the TTS template as `-T`;
# confirm that `-c` is still the intended flag for the template here.
just tts -d fixie-ai/tools -u fixie-ai/tools-audio --private \
-c @ultravox/tools/ds_tool/user_messages.jinja -j -a user_message_audios \
-V random --num_workers 20 -i eleven --token "${HF_WRITE_TOKEN:?HF_WRITE_TOKEN must be set}"
25 changes: 19 additions & 6 deletions ultravox/tools/ds_tool/caching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import hashlib
import json
import os
from typing import Optional
from typing import List, Optional, Union, overload

import openai

Expand Down Expand Up @@ -37,14 +37,27 @@ def chat_completion(self, **kwargs) -> str:


class CachingTtsWrapper:
def __init__(self, client: tts.Client, provider: str):
def __init__(self, client: tts.Client, implementation: str):
super().__init__()
self._client = client
self._base_path = os.path.join(".cache/ds_tool/tts", provider)
self._base_path = os.path.join(".cache/ds_tool/tts", implementation)

def tts(self, text: str, voice: Optional[str] = None) -> bytes:
path = os.path.join(self._base_path, voice or "default")
text_hash = hashlib.sha256(text.encode()).hexdigest()
@overload
def tts(self, text: str, voice: Optional[str] = None) -> bytes: ...

@overload
def tts(self, text: List[str], voice: Optional[str] = None) -> List[bytes]: ...

def tts(
self, text: Union[str, List[str]], voice: Optional[str] = None
) -> Union[bytes, List[bytes]]:
text_hash = hashlib.sha256(str(text).encode()).hexdigest()
voice = self._client.resolve_voice(voice)

if isinstance(text, list):
return [self.tts(t, voice) for t in text]

path = os.path.join(self._base_path, voice)
os.makedirs(path, exist_ok=True)

cache_path = os.path.join(path, f"{text_hash}.wav")
Expand Down
49 changes: 32 additions & 17 deletions ultravox/tools/ds_tool/ds_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jinja2
import openai
import simple_parsing
import yaml

from ultravox.data import text_proc
from ultravox.tools.ds_tool import caching
Expand All @@ -20,7 +21,8 @@
class TtsTask:
implementation: str = simple_parsing.field(default="azure", alias="-i")
# Column name containing the text to convert to audio. It can be a Jinja variable expression.
column_name: str = simple_parsing.field(default="question", alias="-c")
template: str = simple_parsing.field(alias="-T")
json_mode: bool = simple_parsing.field(default=False, alias="-j")
audio_column_name: Optional[str] = simple_parsing.field(default=None, alias="-a")
voice: Optional[str] = simple_parsing.field(default=None, alias="-V")
sample_rate: int = simple_parsing.field(default=16000, alias="-r")
Expand All @@ -32,35 +34,48 @@ def __post_init__(self):
self.audio_column_name = f"{self.column_name}_audio"
tts_client = caching.CachingTtsWrapper(
tts.create_client(self.implementation, self.sample_rate),
provider=self.implementation,
implementation=self.implementation,
)

if self.template.startswith("@"):
with open(self.template[1:], "r") as template_file:
self.template = template_file.read()

def map_split(
self, ds_split: datasets.Dataset, num_proc: int, writer_batch_size: int
) -> datasets.Dataset:
print(f'TTS mapping "{self.column_name}" to "{self.audio_column_name}"...')
return ds_split.map(
print(f'TTS mapping "{self.template}" to "{self.audio_column_name}"...')
ds_split = ds_split.map(
self._map_sample, num_proc=num_proc, writer_batch_size=writer_batch_size
).cast_column(
self.audio_column_name, datasets.Audio(sampling_rate=self.sample_rate)
)
column_type = datasets.Audio(sampling_rate=self.sample_rate)
if self.json_mode and isinstance(
ds_split.features[self.audio_column_name], datasets.Sequence
):
column_type = datasets.Sequence(column_type)
return ds_split.cast_column(self.audio_column_name, column_type)

def _map_sample(self, sample):
# using a Jinja template for some added flexibility
# The {{ var }} syntax is how Jinja denotes variables
# using a Jinja template for some added flexibility, template can include variables and functions
# e.g., {{ text }} or {{ text_proc.format_asr_text(text) }}
try:
text = jinja2.Template(
"{{" + self.column_name + "}}", undefined=jinja2.StrictUndefined
).render(**sample)
text_or_texts = jinja2.Template(
self.template, undefined=jinja2.StrictUndefined
).render(**sample, json_dump=json.dumps, text_proc=text_proc)
except jinja2.TemplateError as e:
print(f"Error rendering template: {e}")
print(f"column_name: {self.column_name}")
print(f"template: {self.template}")
print(f"sample keys: {list(sample.keys())}")
raise ValueError(
f"Template rendering failed. Make sure column_name exists in the sample."
) from e

sample[self.audio_column_name] = tts_client.tts(text, self.voice)
if self.json_mode:
text_or_texts = yaml.safe_load(text_or_texts)
assert isinstance(text_or_texts, list)
assert all(isinstance(turn, str) for turn in text_or_texts)

sample[self.audio_column_name] = tts_client.tts(text_or_texts, self.voice)
return sample


Expand Down Expand Up @@ -112,7 +127,7 @@ def _map_sample(self, sample):
) from e

if self.json_mode:
turns = json.loads(rendered)
turns = yaml.safe_load(rendered)
assert isinstance(turns, list)
assert all(isinstance(turn, dict) for turn in turns)
assert len(turns) > 0
Expand All @@ -130,9 +145,9 @@ def _map_sample(self, sample):


# This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model.
# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -c question -a audio --token $HF_WRITE_TOKEN
# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -c explanation -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
# just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt
# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -T {{question}} -a audio --token $HF_WRITE_TOKEN
# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -T {{explanation}} -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
# just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt
# just ds_tool textgen --new_column_name continuation --dataset_name openslr/librispeech_asr --dataset_subset clean --dataset_split train.360 \
# --shuffle --upload_name fixie-ai/librispeech_asr --private --base_url https://api.fireworks.ai/inference/v1 \
# --api_key $FIREWORKS_API_KEY --token $HF_TOKEN --language_model accounts/fireworks/models/llama-v3-8b-instruct \
Expand Down
29 changes: 19 additions & 10 deletions ultravox/tools/ds_tool/tts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import io
import os
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from xml.sax import saxutils

import numpy as np
Expand All @@ -23,7 +23,15 @@ def _make_ssml(voice: str, text: str):


class Client(abc.ABC):
DEFAULT_VOICE: str
ALL_VOICES: List[str]

def __init__(self, sample_rate: int = 16000):
if not hasattr(self, "DEFAULT_VOICE"):
raise ValueError("DEFAULT_VOICE must be defined in subclasses.")
if not hasattr(self, "ALL_VOICES"):
raise ValueError("ALL_VOICES must be defined in subclasses.")

self._session = requests.Session()
retries = requests.adapters.Retry(total=NUM_RETRIES)
self._session.mount(
Expand All @@ -48,6 +56,14 @@ def _handle_pcm_response(self, response: requests.Response) -> bytes:
sf.write(wav_bytes, pcm_array, self._sample_rate, format="wav")
return wav_bytes.getvalue()

def resolve_voice(self, voice: Optional[str]) -> str:
    """Turn a possibly-missing or "random" voice request into a concrete voice.

    Args:
        voice: Requested voice. A falsy value (``None`` or ``""``) selects
            ``self.DEFAULT_VOICE``; a value equal to ``RANDOM_VOICE_KEY``
            selects an entry of ``self.ALL_VOICES``.

    Returns:
        The voice to use: the requested one, the default, or a picked entry
        of ``self.ALL_VOICES``.
    """
    requested = voice if voice else self.DEFAULT_VOICE
    if requested != RANDOM_VOICE_KEY:
        return requested

    candidates = self.ALL_VOICES
    # Every process has same random seed, so we mix in the PID here for more variation.
    offset = np.random.randint(len(candidates)) + os.getpid()
    return candidates[offset % len(candidates)]


class AzureTts(Client):
DEFAULT_VOICE = "en-US-JennyNeural"
Expand Down Expand Up @@ -80,10 +96,7 @@ class AzureTts(Client):
]

def tts(self, text: str, voice: Optional[str] = None):
voice = voice or self.DEFAULT_VOICE
if voice == RANDOM_VOICE_KEY:
voice = np.random.choice(self.ALL_VOICES)
assert voice
voice = self.resolve_voice(voice)
region = "westus"
api_key = os.environ.get("AZURE_TTS_API_KEY") or os.environ.get(
"AZURE_WESTUS_TTS_API_KEY"
Expand Down Expand Up @@ -134,11 +147,7 @@ class ElevenTts(Client):
]

def tts(self, text: str, voice: Optional[str] = None):
voice = voice or self.DEFAULT_VOICE
if voice == RANDOM_VOICE_KEY:
# Every process has same random seed, so we mix in the PID here for more variation.
i = np.random.randint(len(self.ALL_VOICES)) + os.getpid()
voice = self.ALL_VOICES[i % len(self.ALL_VOICES)]
voice = self.resolve_voice(voice)
url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice}/stream?output_format=pcm_16000"
headers = {"xi-api-key": os.environ["ELEVEN_API_KEY"]}
body = {
Expand Down
7 changes: 7 additions & 0 deletions ultravox/tools/ds_tool/user_messages.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{#- Renders a bracketed, comma-separated list containing one json_dump()-encoded
    entry for every message in `messages` whose role is 'user', skipping the
    final message (messages[:-1]).
    Each entry is followed by a comma, so a non-empty list ends with a trailing
    comma before the closing bracket.
    NOTE(review): a trailing comma is not strict JSON; presumably the caller
    parses this output with yaml.safe_load, which accepts it. Confirm before
    switching the parser. -#}
[
{%- for message in messages[:-1] -%}
{%- if message['role'] == 'user' -%}
{{ json_dump(message['content']) }},
{%- endif -%}
{%- endfor -%}
]

0 comments on commit f937487

Please sign in to comment.