diff --git a/Justfile b/Justfile index bafe39db..269bdbba 100644 --- a/Justfile +++ b/Justfile @@ -3,6 +3,7 @@ export WANDB_LOG_MODEL:="checkpoint" export PROJECT_DIR:="ultravox" export MCLOUD_CLUSTER:="r7z22p1" export MCLOUD_INSTANCE:="oci.bm.gpu.b4.8" +export MFA_ENV_NAME:="aligner" default: format check test @@ -62,3 +63,33 @@ run *FLAGS: mcloud *FLAGS: poetry run mcli interactive {{FLAGS}} --cluster ${MCLOUD_CLUSTER} --instance ${MCLOUD_INSTANCE} --name `whoami` --command "bash -c \"$(cat setup.sh)\"" + +@check_conda: + if ! command -v conda &> /dev/null; then \ + echo "Conda is not installed."; \ + mkdir -p ~/miniconda3; \ + if [ "$(uname)" = "Darwin" ]; then \ + echo "Downloading MacOS Miniconda."; \ + curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh; \ + elif [ "$(uname)" = "Linux" ]; then \ + echo "Downloading Linux Miniconda."; \ + wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh \ + else \ + echo "Unknown operating system."; \ + fi; \ + bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3; \ + rm ~/miniconda3/miniconda.sh; \ + else \ + echo "Conda is installed."; \ + fi + +@install_mfa: check_conda + if conda env list | grep -q "$MFA_ENV_NAME"; then \ + echo "Environment '$MFA_ENV_NAME' already exists."; \ + else \ + echo "Creating environment '$MFA_ENV_NAME'."; \ + conda create --name "$MFA_ENV_NAME" python=3.8 -y; \ + conda create -n "$MFA_ENV_NAME" -c conda-forge montreal-forced-aligner; \ + conda run -n "$MFA_ENV_NAME" mfa model download acoustic english_mfa; \ + conda run -n "$MFA_ENV_NAME" mfa model download dictionary english_mfa; \ + fi \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 6ee9b992..7465e884 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5124,6 +5124,20 @@ docs = ["sphinx (>=1.7.1)"] redis = ["redis"] tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)", "types-redis"] +[[package]] +name = "praatio" +version = "6.2.0" +description = "A library for working with praat, textgrids, time aligned audio transcripts, and audio files." +optional = false +python-versions = ">3.6.0" +files = [ + {file = "praatio-6.2.0-py3-none-any.whl", hash = "sha256:6541018791a3f0b087a8168d1746a165937c3fff1f94c7a6883b3f470e0cf405"}, + {file = "praatio-6.2.0.tar.gz", hash = "sha256:7d2a7f8135a3e0691743ada0af84308b64e637f07038cea77d814b8aa2fa2e40"}, +] + +[package.dependencies] +typing-extensions = "*" + [[package]] name = "preshed" version = "3.0.9" @@ -8883,4 +8897,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "798d26eeecb0625e6e6b655f7209286319924de573c9cdd9a30416593a492cb5" +content-hash = "f1d462cee8239c355f81406ff7ee88e42fd85b955abec54aac629ef2cd4a4cce" diff --git a/pyproject.toml b/pyproject.toml index 21c6734c..d11c699c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ wandb = "~0.17.1" sacrebleu = "^2.4.2" tenacity = "^9.0.0" evals = {git = "https://github.com/fixie-ai/evals", rev = "0c66bf85df7a4b903ecb202b23c2a826b749fd71"} +praatio = "^6.2.0" [tool.poetry.group.dev.dependencies] black = "~24.4.2" diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index 750f62e4..50d7e88e 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -1,14 +1,22 @@ import dataclasses +import glob import json import math import os -from typing import Any, Dict, List, Optional, Tuple, Union +import subprocess +import tempfile +import traceback +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple, Union import datasets import jinja2 +import librosa import openai import simple_parsing +import soundfile as sf import yaml +from praatio import textgrid from tenacity import retry from tenacity import stop_after_attempt from tenacity import wait_fixed @@ -21,6 +29,41 @@ tts_client: caching.CachingTtsWrapper chat_client: caching.CachingChatWrapper +MFA_ENV_NAME = "aligner" + + +def apply_jinja_template( + template: str, sample: Dict[str, Any], exclude_fields: Optional[Set[str]] = None +): + """ + Apply a Jinja template to a sample, rendering it into text. + Jinja template allows for added flexibility as template can include variables and functions. + + Args: + template: The Jinja template to apply. It can include variables, functions, and control structures. + Example: + {{ text }} + {{ text_proc.format_asr_text(text) }} + sample: The sample to apply the template to. + exclude_fields: Fields to exclude from the sample before rendering the template, to avoid loading large fields into memory. + """ + if exclude_fields: + # Filter out big fields like audio before the sample is passed into the jinja template + # otherwise it can lead to unnecessary memory usage. + sample = {k: sample[k] for k in sample.keys() if k not in exclude_fields} + + try: + return jinja2.Template(template, undefined=jinja2.StrictUndefined).render( + **sample, json_dump=json.dumps, text_proc=text_proc + ) + except jinja2.TemplateError as e: + print(f"Error rendering template: {e}") + print(f"template: {template}") + print(f"sample keys: {list(sample.keys())}, excluded keys: {exclude_fields}") + raise ValueError( + f"Template rendering failed. Make sure all keys in the template exist in the sample." + ) from e + @dataclasses.dataclass class TtsTask: @@ -55,7 +98,10 @@ def map_split( ) -> datasets.Dataset: print(f'TTS mapping "{self.template}" to "{self.audio_column_name}"...') ds_split = ds_split.map( - self._map_sample, num_proc=num_proc, writer_batch_size=writer_batch_size + self._map_sample, + num_proc=num_proc, + writer_batch_size=writer_batch_size, + fn_kwargs={"exclude_fields": exclude_fields}, ) column_type = datasets.Audio(sampling_rate=self.sample_rate) if self.json_mode and isinstance( @@ -64,20 +110,10 @@ def map_split( column_type = datasets.Sequence(column_type) return ds_split.cast_column(self.audio_column_name, column_type) - def _map_sample(self, sample): + def _map_sample(self, sample, exclude_fields: Set[str]): # using a Jinja template for some added flexibility, template can include variables and functions # e.g., {{ text }} or {{ text_proc.format_asr_text(text) }} - try: - text_or_texts = jinja2.Template( - self.template, undefined=jinja2.StrictUndefined - ).render(**sample, json_dump=json.dumps, text_proc=text_proc) - except jinja2.TemplateError as e: - print(f"Error rendering template: {e}") - print(f"template: {self.template}") - print(f"sample keys: {list(sample.keys())}") - raise ValueError( - f"Template rendering failed. Make sure column_name exists in the sample." - ) from e + text_or_texts = apply_jinja_template(self.template, sample, exclude_fields) if self.json_mode: text_or_texts = yaml.safe_load(text_or_texts) @@ -137,24 +173,11 @@ def _map_sample(self, sample, exclude_fields): # using a Jinja template for some added flexibility, template can include variables and functions # e.g., {{ text }} or {{ text_proc.format_asr_text(text) }} try: - # Filter out the audio before the sample is passed into the jinja template, or it will get loaded into memory. - filtered_sample = { - k: sample[k] for k in sample.keys() if k not in exclude_fields - } - rendered = jinja2.Template( - self.template, undefined=jinja2.StrictUndefined - ).render(**filtered_sample, json_dump=json.dumps, text_proc=text_proc) + rendered = apply_jinja_template(self.template, sample, exclude_fields) except text_proc.FormatASRError as e: print(f"Format ASR Error {e}. Filtering out sample.") sample[self.new_column_name] = None return sample - except jinja2.TemplateError as e: - print(f"Error rendering template: {e}") - print(f"template: {self.template}") - print(f"sample keys: {list(filtered_sample.keys())}") - raise ValueError( - f"Template rendering failed. Make sure all keys in the template exist in the sample." - ) from e if self.json_mode: turns = yaml.safe_load(rendered) @@ -175,6 +198,163 @@ def _map_sample(self, sample, exclude_fields): return sample +@dataclasses.dataclass +class TimestampGenerationTask: + """ + This task is used to generate timestamps for the text transcription. + It uses the Montreal Forced Aligner (MFA) to align the text with the audio. The result is a + list of timestamps for each word in the text transcription. The timestamps are stored in a new + column, in a list of dict format: + [ {"start": float in seconds, "end": float in seconds, "text": first word str}, ... ] + """ + + # Jinja template for the text transcription that needs to be aligned + template: str = simple_parsing.field(alias="-T") + # The accoustic model to use for MFA alignment. + # Make sure the dictionary and acoustic model are installed. See just install_mfa for an example (English). + # Model index: https://mfa-models.readthedocs.io/en/latest/acoustic/index.html + # For many languages there exists a {language}_mfa model that you can use, e.g. "english_mfa" + mfa_acoustic_model: str = simple_parsing.field(alias="-m") + # The dictionary to use for MFA alignment. Defaults to the same name as the acoustic model. + mfa_dictionary: Optional[str] = simple_parsing.field(default=None, alias="-d") + audio_column_name: str = simple_parsing.field(default="audio", alias="-a") + sample_rate: int = simple_parsing.field(default=16000, alias="-r") + # The column name to store the timestamps in + timestamp_column_name: str = simple_parsing.field(default="timestamps", alias="-ts") + aligned_ratio_check: float = simple_parsing.field(default=0.95, alias="-ar") + + def __post_init__(self): + if self.mfa_dictionary is None: + self.mfa_dictionary = self.mfa_acoustic_model + + try: + # Make sure the MFA environment is installed + subprocess.run(["conda", "run", "-n", MFA_ENV_NAME, "echo"], check=True) + except subprocess.CalledProcessError: + raise Exception( + "Please install the MFA environment using `just install_mfa` first." + ) + + if self.template.startswith("@"): + with open(self.template[1:], "r") as template_file: + self.template = template_file.read() + + def map_split( + self, + ds_split: datasets.Dataset, + num_proc: int, + writer_batch_size: int, + exclude_fields: List[str], + ) -> datasets.Dataset: + # 0. create a temp directory to store the audio and text files + # The files will be deleted when the with block ends or when an exception is raised + with tempfile.TemporaryDirectory() as temp_dir: + # 1. copy all audio-text pairs into the temp directory + ds_split.map( + self._store_sample_as_files, + num_proc=num_proc, + fn_kwargs={"exclude_fields": set(exclude_fields), "temp_dir": temp_dir}, + ) + + count_wavs = len(glob.glob(os.path.join(temp_dir, "*.wav"))) + assert count_wavs == len( + ds_split + ), "Not all samples were stored as files. The id is likely not unique." + + # 2. run the alignment + self._run_alignment(temp_dir, num_proc=num_proc) + + # 3. retrieve the timestamps + ds_mapped = ds_split.map( + self._retrieve_timestamps, + num_proc=num_proc, + writer_batch_size=writer_batch_size, + fn_kwargs={"temp_dir": temp_dir}, + ) + + # 4. filter out samples without timestamps (should be a small number) + ds_mapped = ds_mapped.filter( + lambda sample: sample[self.timestamp_column_name] is not None, + num_proc=num_proc, + writer_batch_size=writer_batch_size, + ) + + # 5. make sure most samples have timestamps + if len(ds_split) * self.aligned_ratio_check > len(ds_mapped): + raise Exception( + f"Found too many samples without timestamps: {len(ds_mapped)}/{len(ds_split)} aligned." + ) + + return ds_mapped + + def _retrieve_timestamps(self, sample, temp_dir: str): + # find the timestamps for the audio and populate the timestamps column + sample_id = self.get_id(sample) + text_path = os.path.join(temp_dir, f"{sample_id}.TextGrid") + if not os.path.exists(text_path): + sample[self.timestamp_column_name] = None + return sample + + tg = textgrid.openTextgrid(text_path, False) + timestamps = tg.getTier("words").entries + sample[self.timestamp_column_name] = [ + {"start": entry.start, "end": entry.end, "text": entry.label} + for entry in timestamps + ] + return sample + + @staticmethod + def get_id(sample): + for key in ["id", "segment_id"]: + if key in sample and isinstance(sample[key], str): + return str(sample[key]) + for key in ["file", "path", "audio_file"]: + if key in sample and isinstance(sample[key], str): + return Path(sample[key]).stem + raise ValueError("Could not find an ID in the sample") + + def _store_sample_as_files(self, sample, temp_dir: str, exclude_fields: Set[str]): + sample_id = self.get_id(sample) + audio_path = os.path.join(temp_dir, f"{sample_id}.wav") + with open(audio_path, "wb") as f: + audio = sample[self.audio_column_name] + if audio["sampling_rate"] != self.sample_rate: + audio["array"] = librosa.resample( + audio["array"], + orig_sr=audio["sampling_rate"], + target_sr=self.sample_rate, + ) + sf.write(f, audio["array"], 16000, format="WAV", subtype="PCM_16") + + text_path = os.path.join(temp_dir, f"{sample_id}.txt") + text = apply_jinja_template(self.template, sample, exclude_fields) + with open(text_path, "w") as f: + f.write(text) + + def _run_alignment(self, temp_dir: str, num_proc: int = 16) -> None: + subprocess.run( + [ + "conda", + "run", + "--no-capture-output", + "-n", + MFA_ENV_NAME, + "mfa", + "align", + "--clean", + "--single_speaker", + "--use_mp", + "-j", + str(num_proc), + temp_dir, + self.mfa_acoustic_model, + str(self.mfa_dictionary), + temp_dir, + ], + check=True, + ) + + # This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model. # just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -T {{question}} -a audio --token $HF_WRITE_TOKEN # just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -T {{explanation}} -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct @@ -183,6 +363,8 @@ def _map_sample(self, sample, exclude_fields): # --shuffle --upload_name fixie-ai/librispeech_asr --private --base_url https://api.fireworks.ai/inference/v1 \ # --api_key $FIREWORKS_API_KEY --token $HF_TOKEN --language_model accounts/fireworks/models/llama-v3-8b-instruct \ # --template @ultravox/tools/ds_tool/continuation.jinja --max_tokens 64 --num_workers 30 --writer_batch_size 30 +# just ds_tool timestamp -d fixie-ai/common_voice_17_0 -S en --upload_name fixie-ai/cv_ts \ +# -m english_mfa -T "\"{{text_proc.format_asr_text(sentence)}}\"" @dataclasses.dataclass class DatasetToolArgs: # HF source dataset parameters @@ -218,10 +400,12 @@ class DatasetToolArgs: default_factory=lambda: ["audio"] ) - task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups( - {"tts": TtsTask, "textgen": TextGenerationTask}, # type: ignore - default_factory=TtsTask, - positional=True, + task: Union[TtsTask, TextGenerationTask, TimestampGenerationTask] = ( + simple_parsing.subgroups( + {"tts": TtsTask, "textgen": TextGenerationTask, "timestamp": TimestampGenerationTask}, # type: ignore + default_factory=TtsTask, + positional=True, + ) ) def __post_init__(self): @@ -232,6 +416,11 @@ def __post_init__(self): if self.dataset_split and not self.upload_split: self.upload_split = self.dataset_split + if self.upload_name == self.dataset_name: + raise ValueError( + "Updating datasets in-place is not well-supported and hence frowned upon." + ) + class DatasetChunkProcessor: args: DatasetToolArgs @@ -301,6 +490,7 @@ def process_and_upload_split_rescursive( # then the huggingface README needs to be updated to have the # download_size, and dataset_size fields present under dataset_info (could be initalized to 0) print(f"Failed to upload chunk {ds_chunk_name}: {e}. Retrying later.") + print(traceback.format_exc()) if total_chunks == 1: print( f"Finished processing and uploading 0/1 chunks for range [{start_index}, {end_index})"