Skip to content

Commit

Permalink
Defining block size in UltravoxConfig, and solving assertions (#157)
Browse files Browse the repository at this point in the history
* assertion mismatches

* Explicit definition of audio_latency_block_size

* Fix ultravox_config_test incompatibility between keys
  • Loading branch information
saeeddhqan authored Nov 25, 2024
1 parent 9fc2732 commit 6f8f255
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
4 changes: 4 additions & 0 deletions ultravox/model/ultravox_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class UltravoxConfig(transformers.PretrainedConfig):
The LoRA configuration for finetuning the text model.
audio_model_lora_config (`LoraConfigSimplified`, *optional*):
The LoRA configuration for finetuning the audio model.
audio_latency_block_size (`int`, *optional*, defaults to `None`):
The latency block size for simulating audio streaming.
Example:
Expand Down Expand Up @@ -105,6 +107,7 @@ def __init__(
projector_act: str = "swiglu",
text_model_lora_config: Optional[LoraConfigSimplified] = None,
audio_model_lora_config: Optional[LoraConfigSimplified] = None,
audio_latency_block_size: Optional[int] = None,
**kwargs,
):
self.ignore_index = ignore_index
Expand Down Expand Up @@ -147,6 +150,7 @@ def __init__(
if isinstance(audio_model_lora_config, dict)
else dataclasses.asdict(audio_model_lora_config or LoraConfigSimplified())
)
self.audio_latency_block_size = audio_latency_block_size

self.vocab_size = self.text_config.vocab_size

Expand Down
14 changes: 10 additions & 4 deletions ultravox/model/ultravox_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@ def test_can_load_release(model_id: str):
)
config_from_dict = ultravox_config.UltravoxConfig(**orig_config.to_dict())
config_from_diff_dict = ultravox_config.UltravoxConfig(**orig_config.to_diff_dict())
# To not inadvertently ignore other keys, we explicitly define keys we require to ignore.
keys_to_ignore = ("audio_latency_block_size",)
orig_values = {
**{k: None for k in keys_to_ignore},
**orig_config.to_dict(),
}

assert config_from_dict.to_dict() == orig_config.to_dict()
assert config_from_diff_dict.to_dict() == orig_config.to_dict()
assert config_from_dict.to_dict() == orig_values
assert config_from_diff_dict.to_dict() == orig_values

assert config_from_dict.text_config.to_dict() == orig_config.text_config.to_dict()
assert config_from_dict.audio_config.to_dict() == orig_config.audio_config.to_dict()
Expand All @@ -25,8 +31,8 @@ def test_can_load_release(model_id: str):
config_reloaded_diff = ultravox_config.UltravoxConfig(
**config_from_dict.to_diff_dict()
)
assert config_reloaded.to_dict() == orig_config.to_dict()
assert config_reloaded_diff.to_dict() == orig_config.to_dict()
assert config_reloaded.to_dict() == orig_values
assert config_reloaded_diff.to_dict() == orig_values


def test_no_config_when_id_present():
Expand Down
4 changes: 2 additions & 2 deletions ultravox/model/ultravox_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def _create_audio_tower(
config.audio_latency_block_size, dtype=config.torch_dtype
)
else:
assert config.audio_latency_block_size not in (
assert config.audio_latency_block_size in (
None,
0,
), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
Expand All @@ -305,7 +305,7 @@ def _create_audio_tower(
config.audio_latency_block_size, dtype=config.torch_dtype
)
else:
assert config.audio_latency_block_size not in (
assert config.audio_latency_block_size in (
None,
0,
), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
Expand Down

0 comments on commit 6f8f255

Please sign in to comment.