Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replacing weight with multiplier #105

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion ultravox/data/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class DataDictConfig(BaseModel):
splits: List[str] = dataclasses.field(default_factory=list)
num_samples: Optional[int] = None
total_samples: int = 1
weight: float = 1.0
multiplier: float = 1.0
streaming: bool = True
user_template: str = "<|audio|>"
assistant_template: str = "{{text}}"
Expand Down
44 changes: 25 additions & 19 deletions ultravox/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class SizedIterableDataset(abc.ABC, data.IterableDataset):
"""

@abc.abstractmethod
def __len__(self):
def __len__(self) -> int:
pass


Expand All @@ -296,16 +296,16 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
self._args = args
self._session: Optional[requests.Session] = None
self._rng = np.random.default_rng(self._args.shuffle_seed)
self._weight = 1.0 # the default weight for the dataset
self._multiplier = 1.0

def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> None:
self._dataset = dataset
        # Only required when using epochs during training.
self._estimated_length = estimated_length

@property
def weight(self) -> float:
return self._weight
def multiplier(self) -> float:
return self._multiplier

def _load_audio_dataset(
self,
Expand Down Expand Up @@ -373,7 +373,7 @@ def __iter__(self):
f"Mismatch between estimated length ({self._estimated_length}) and actual length ({actual_length}) for dataset of type {type(self._dataset)}. Make sure to update."
)

def __len__(self):
def __len__(self) -> int:
return self._estimated_length

@abc.abstractmethod
Expand Down Expand Up @@ -493,7 +493,7 @@ def __init__(self, estimated_length: int = 1) -> None:
def __iter__(self):
return iter([])

def __len__(self):
def __len__(self) -> int:
return self._estimated_length


Expand Down Expand Up @@ -1049,16 +1049,17 @@ def __init__(
if self._args.shuffle:
dataset = dataset.shuffle(seed=self._args.shuffle_seed)

if config.num_samples:
dataset = Range(dataset, config.num_samples, config.total_samples)

self._weight = config.weight
self._multiplier = config.multiplier

self.user_template = config.user_template
self.assistant_template = config.assistant_template
self.transcript_template = config.transcript_template

super()._init_dataset(dataset, config.total_samples)
if config.num_samples:
dataset = Range(dataset, config.num_samples, config.total_samples)
super()._init_dataset(dataset, len(dataset))
else:
super()._init_dataset(dataset, config.total_samples)

def _get_sample(self, row) -> VoiceSample:
try:
Expand Down Expand Up @@ -1121,7 +1122,7 @@ class StopStrategy(str, Enum):


class InterleaveDataset(SizedIterableDataset):
"""Interleaves multiple IterableDataset objects based on normalized weights."""
"""Interleaves multiple IterableDataset objects based on multiplier."""

def __init__(
self,
Expand All @@ -1142,10 +1143,12 @@ def __init__(
self._static = static

self._stop_strategy = stop_strategy
relative_frequencies = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think life might be easier if we didn't try to read the multiplier value out of each dataset class, and instead read it from a config passed to the interleave class, similar to how this is done with probabilities in HF's interleave_datasets: https://huggingface.co/docs/datasets/en/package_reference/main_classes#datasets.interleave_datasets

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we wouldn't have to worry about the interaction with len, etc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we'd need to calculate the probabilities by hand in the config? I think we'd need len either way because that's how hf determines the # of steps to take in epoch mode.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, the weights are being set by hand already, right? It just seems a bit strange to have to read the weight/multiplier property from the dataset class, especially when the dataset class doesn't use it internally.

int(getattr(ds, "multiplier", 1.0) * float(len(ds))) for ds in datasets
]

weights = [getattr(ds, "weight", 1) for ds in datasets]
total_weight = sum(weights)
self._normalized_probs = [w / total_weight for w in weights]
total_frequency = sum(relative_frequencies)
self._normalized_probs = [f / total_frequency for f in relative_frequencies]

def __iter__(self):
# If no datasets are provided, return an empty iterator
Expand Down Expand Up @@ -1180,9 +1183,12 @@ def __iter__(self):
iters[iter_index] = iter(self._datasets[iter_index])
yield next(iters[iter_index])

def __len__(self):
def __len__(self) -> int:
# TODO: Implement the length method for different stop strategies
return sum(len(ds) for ds in self._datasets)
return sum(
int(getattr(ds, "multiplier", 1.0) * float(len(ds)))
for ds in self._datasets
)


class Dataproc(SizedIterableDataset):
Expand All @@ -1198,7 +1204,7 @@ def _process(self, sample: VoiceSample) -> Dict[str, Any]:
def __iter__(self):
return (self._process(sample) for sample in self._dataset)

def __len__(self):
def __len__(self) -> int:
return len(self._dataset)


Expand Down Expand Up @@ -1234,7 +1240,7 @@ def __iter__(self):
break
yield sample

def __len__(self):
def __len__(self) -> int:
return (
self._num_samples
if self._num_samples is not None
Expand Down
12 changes: 6 additions & 6 deletions ultravox/data/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
class FakeSizedIterableDataset(datasets.SizedIterableDataset):
"""Fake version of datasets.SizedIterableDataset"""

def __init__(self, n, start=0, weight=1, estimated_length=0):
def __init__(self, n, start=0, multiplier=1, estimated_length=1):
self.data = range(start, start + n)
self._weight = weight
self._multiplier = multiplier
self._estimated_length = estimated_length

@property
def weight(self) -> float:
return self._weight
def multiplier(self) -> float:
return self._multiplier

def __iter__(self):
for sample in self.data:
Expand Down Expand Up @@ -135,8 +135,8 @@ def test_interleaved_never_stop():


def test_interleaved_random():
ds1 = FakeSizedIterableDataset(4, weight=10)
ds2 = FakeSizedIterableDataset(2, start=10, weight=1)
ds1 = FakeSizedIterableDataset(4, multiplier=10)
ds2 = FakeSizedIterableDataset(2, start=10, multiplier=1)
s = datasets.InterleaveDataset(
[ds1, ds2],
)
Expand Down
2 changes: 1 addition & 1 deletion ultravox/training/configs/meta_config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
text_model: "meta-llama/Meta-Llama-3-8B-Instruct"
audio_model: "facebook/wav2vec2-base-960h"

data_sets: ["gigaspeech"]
liPatrick marked this conversation as resolved.
Show resolved Hide resolved
data_sets: [] # Can't use these datasets with "multiplier" because they don't have a length attribute.
val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"]
stop_strategy: "LAST_EXHAUSTED"

Expand Down
Loading
Loading