Skip to content

Commit

Permalink
Add interleave config
Browse files Browse the repository at this point in the history
  • Loading branch information
liPatrick committed Sep 20, 2024
1 parent 9deaa7f commit 2a1bcc5
Show file tree
Hide file tree
Showing 7 changed files with 319 additions and 291 deletions.
25 changes: 20 additions & 5 deletions ultravox/data/dataset_config.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,38 @@
import dataclasses
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel


class StopStrategy(str, Enum):
FIRST_EXHAUSTED = "first_exhausted"
LAST_EXHAUSTED = "last_exhausted"
NEVER_STOP = "never_stop"


class DataDictConfig(BaseModel):
# Path to the dataset, or huggingface dataset id
path: str
# Name of the dataset, or huggingface dataset config/subset
path: str # Name of the dataset, or huggingface dataset config/subset
name: Optional[str] = None
splits: List[str] = dataclasses.field(default_factory=list)
num_samples: Optional[int] = None
total_samples: int = 1
multiplier: float = 1.0
streaming: bool = True
user_template: str = "<|audio|>"
assistant_template: str = "{{text}}"
transcript_template: str = "{{text}}"

def __post_init__(self):
def post_init(self):
if not self.splits:
raise ValueError("At least one split must be provided")


class DatasetMultiplier(BaseModel):
dataset: DataDictConfig
multiplier: float


class InterleaveDataConfig(BaseModel):
# In InterleaveDataset, when to stop interleave: choose from last_exhausted (default), first_exhausted, or never_stop
stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED
datasets_with_multiplier: List[DatasetMultiplier]
44 changes: 21 additions & 23 deletions ultravox/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import os
import tempfile
import warnings
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence

import datasets
Expand Down Expand Up @@ -296,17 +295,12 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
self._args = args
self._session: Optional[requests.Session] = None
self._rng = np.random.default_rng(self._args.shuffle_seed)
self._multiplier = 1.0

def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> None:
self._dataset = dataset
# Only required when using epochs when training dataset.
self._estimated_length = estimated_length

@property
def multiplier(self) -> float:
return self._multiplier

def _load_audio_dataset(
self,
path: str,
Expand Down Expand Up @@ -1049,8 +1043,6 @@ def __init__(
if self._args.shuffle:
dataset = dataset.shuffle(seed=self._args.shuffle_seed)

self._multiplier = config.multiplier

self.user_template = config.user_template
self.assistant_template = config.assistant_template
self.transcript_template = config.transcript_template
Expand Down Expand Up @@ -1089,7 +1081,9 @@ def _get_sample(self, row) -> VoiceSample:
)


def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset:
def create_dataset(
name: str | dataset_config.DataDictConfig, args: VoiceDatasetArgs
) -> SizedIterableDataset:
DATASET_MAP: Dict[str, Any] = {
"anyinstruct": AnyInstructAnswerDataset,
"anyinstruct_in": AnyInstructInputDataset,
Expand All @@ -1115,21 +1109,16 @@ def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset:
return DATASET_MAP[name](args, *ext)


class StopStrategy(str, Enum):
FIRST_EXHAUSTED = "first_exhausted"
LAST_EXHAUSTED = "last_exhausted"
NEVER_STOP = "never_stop"


class InterleaveDataset(SizedIterableDataset):
"""Interleaves multiple IterableDataset objects based on multiplier."""

def __init__(
self,
datasets: Sequence[SizedIterableDataset],
stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED,
stop_strategy: dataset_config.StopStrategy = dataset_config.StopStrategy.LAST_EXHAUSTED,
seed: Optional[int] = 42,
static: bool = False,
multipliers: Optional[List[float]] = None,
) -> None:
"""
Args:
Expand All @@ -1143,10 +1132,15 @@ def __init__(
self._static = static

self._stop_strategy = stop_strategy
if multipliers is None:
self._multipliers = [1.0] * len(datasets)
else:
self._multipliers = multipliers

relative_frequencies = [
int(getattr(ds, "multiplier", 1.0) * float(len(ds))) for ds in datasets
int(float(len(ds)) * multiple)
for ds, multiple in zip(datasets, self._multipliers)
]

total_frequency = sum(relative_frequencies)
self._normalized_probs = [f / total_frequency for f in relative_frequencies]

Expand All @@ -1173,9 +1167,13 @@ def __iter__(self):
exhausted[iter_index] = True

# Check if stopping condition is met
if self._stop_strategy == StopStrategy.FIRST_EXHAUSTED or (
self._stop_strategy == StopStrategy.LAST_EXHAUSTED
and all(exhausted)
if (
self._stop_strategy == dataset_config.StopStrategy.FIRST_EXHAUSTED
or (
self._stop_strategy
== dataset_config.StopStrategy.LAST_EXHAUSTED
and all(exhausted)
)
):
break

Expand All @@ -1186,8 +1184,8 @@ def __iter__(self):
def __len__(self) -> int:
# TODO: Implement the length method for different stop strategies
return sum(
int(getattr(ds, "multiplier", 1.0) * float(len(ds)))
for ds in self._datasets
int(float(len(ds)) * multiple)
for ds, multiple in zip(datasets, self._multipliers)
)


Expand Down
21 changes: 7 additions & 14 deletions ultravox/data/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,10 @@
class FakeSizedIterableDataset(datasets.SizedIterableDataset):
"""Fake version of datasets.SizedIterableDataset"""

def __init__(self, n, start=0, multiplier=1, estimated_length=1):
def __init__(self, n, start=0, estimated_length=1):
self.data = range(start, start + n)
self._multiplier = multiplier
self._estimated_length = estimated_length

@property
def multiplier(self) -> float:
return self._multiplier

def __iter__(self):
for sample in self.data:
yield sample
Expand Down Expand Up @@ -98,7 +93,7 @@ def test_interleaved_first_exhausted():
ds3 = FakeSizedIterableDataset(3)
s = datasets.InterleaveDataset(
[ds1, ds2, ds3],
stop_strategy=datasets.StopStrategy.FIRST_EXHAUSTED,
stop_strategy=dataset_config.StopStrategy.FIRST_EXHAUSTED,
static=True,
)
# static=True disables random sampling of datasets, so the order is deterministic
Expand All @@ -113,7 +108,7 @@ def test_interleaved_last_exhausted():
ds2 = FakeSizedIterableDataset(2, start=10)
s = datasets.InterleaveDataset(
[ds1, ds2],
stop_strategy=datasets.StopStrategy.LAST_EXHAUSTED,
stop_strategy=dataset_config.StopStrategy.LAST_EXHAUSTED,
static=True,
)
# static=True disables random sampling of datasets, so the order is deterministic
Expand All @@ -126,7 +121,7 @@ def test_interleaved_never_stop():
ds2 = FakeSizedIterableDataset(2, start=10)
s = datasets.InterleaveDataset(
[ds1, ds2],
stop_strategy=datasets.StopStrategy.NEVER_STOP,
stop_strategy=dataset_config.StopStrategy.NEVER_STOP,
static=True,
)
# static=True disables random sampling of datasets, so the order is deterministic
Expand All @@ -135,11 +130,9 @@ def test_interleaved_never_stop():


def test_interleaved_random():
ds1 = FakeSizedIterableDataset(4, multiplier=10)
ds2 = FakeSizedIterableDataset(2, start=10, multiplier=1)
s = datasets.InterleaveDataset(
[ds1, ds2],
)
ds1 = FakeSizedIterableDataset(4)
ds2 = FakeSizedIterableDataset(2, start=10)
s = datasets.InterleaveDataset([ds1, ds2], multipliers=[10, 1])
# stop_strategy=last_exhausted will stop interleaving when the last dataset is exhausted (attempted after exhaustion)
assert list(s) == [
0,
Expand Down
18 changes: 2 additions & 16 deletions ultravox/training/config_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import List, Optional

import simple_parsing
import torch

from ultravox.data import dataset_config
from ultravox.data import datasets
from ultravox.model import ultravox_config


@dataclasses.dataclass
class TrainConfig:
data_sets: List[str]
val_sets: List[str]
# language model to use
text_model: str
Expand All @@ -29,13 +27,11 @@ class TrainConfig:
# we first parse the data_dicts as a list of dictionaries. After parsing,
# we convert these dictionaries to DataDictConfig objects using Pydantic
# to enforce type constraints and validation, in the __post_init__ method.
data_dicts: Optional[List[Dict[str, Any]]] = None
interleave_datasets: dataset_config.InterleaveDataConfig

do_train: bool = True
do_eval: bool = True

# In InterleaveDataset, when to stop interleave: choose from last_exhausted (default), first_exhausted, or never_stop
stop_strategy: datasets.StopStrategy = datasets.StopStrategy.LAST_EXHAUSTED
data_dir: Optional[str] = None
mds: bool = False
num_samples: Optional[int] = None
Expand Down Expand Up @@ -91,16 +87,6 @@ class TrainConfig:
loss_config: Optional[ultravox_config.LossConfig] = None

def __post_init__(self):
if self.data_dicts:
self.data_dicts = [
dataset_config.DataDictConfig(**data_dict)
for data_dict in self.data_dicts
]
# For now, self.data_dicts is a hack to allow for the inclusion of new datasets using the
# GenericVoiceDataset class, without changing how existing datasets are specified in
# self.data_sets. In the future, all datasets will be updated to use the DataDictConfig class.
self.data_sets.extend(self.data_dicts)

assert self.data_type in ["bfloat16", "float16", "float32"]
if self.device == "cuda" and not torch.cuda.is_available():
self.device = "mps" if torch.backends.mps.is_available() else "cpu"
Expand Down
2 changes: 0 additions & 2 deletions ultravox/training/configs/meta_config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
text_model: "meta-llama/Meta-Llama-3-8B-Instruct"
audio_model: "facebook/wav2vec2-base-960h"

data_sets: [] # Can't use datasets with "multiplier" because it doesn't have a length attribute.
val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"]
stop_strategy: "LAST_EXHAUSTED"

train_on_inputs: False
shuffle_data: True
Expand Down
Loading

0 comments on commit 2a1bcc5

Please sign in to comment.