Skip to content

Commit

Permalink
not uploading text_config when text_model_id is present (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
farzadab authored Sep 13, 2024
1 parent ae39709 commit a426890
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
13 changes: 11 additions & 2 deletions ultravox/model/ultravox_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def __init__(
audio_model_id: Optional[str] = None,
text_model_id: Optional[str] = None,
ignore_index: int = -100,
audio_token_index: int = 32000,
hidden_size: int = 4096,
stack_factor: int = 8,
norm_init: float = 0.4,
Expand All @@ -112,7 +111,6 @@ def __init__(

self.audio_model_id = audio_model_id
self.text_model_id = text_model_id
self.audio_token_index = audio_token_index

self.hidden_size = hidden_size
self.stack_factor = stack_factor
Expand Down Expand Up @@ -155,3 +153,14 @@ def __init__(
self.initializer_range = self.text_config.initializer_range

super().__init__(**kwargs)

def to_diff_dict(self) -> Dict[str, Any]:
diff_dict = super().to_diff_dict()

# remove text_config and audio_config if text_model_id and audio_model_id are present
if self.text_model_id is not None:
diff_dict.pop("text_config", None)
if self.audio_model_id is not None:
diff_dict.pop("audio_config", None)

return diff_dict
37 changes: 37 additions & 0 deletions ultravox/model/ultravox_config_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
import transformers

from ultravox.model import ultravox_config


@pytest.mark.parametrize(
"model_id",
["fixie-ai/ultravox-v0_2", "fixie-ai/ultravox-v0_3", "fixie-ai/ultravox-v0_4"],
)
def test_can_load_release(model_id: str):
orig_config: transformers.PretrainedConfig = (
transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=True)
)
config_from_dict = ultravox_config.UltravoxConfig(**orig_config.to_dict())
config_from_diff_dict = ultravox_config.UltravoxConfig(**orig_config.to_diff_dict())

assert config_from_dict.to_dict() == orig_config.to_dict()
assert config_from_diff_dict.to_dict() == orig_config.to_dict()

assert config_from_dict.text_config.to_dict() == orig_config.text_config.to_dict()
assert config_from_dict.audio_config.to_dict() == orig_config.audio_config.to_dict()

config_reloaded = ultravox_config.UltravoxConfig(**config_from_dict.to_dict())
config_reloaded_diff = ultravox_config.UltravoxConfig(
**config_from_dict.to_diff_dict()
)
assert config_reloaded.to_dict() == orig_config.to_dict()
assert config_reloaded_diff.to_dict() == orig_config.to_dict()


def test_no_config_when_id_present():
config = ultravox_config.UltravoxConfig(audio_model_id="openai/whisper-small")
assert "audio_config" not in config.to_diff_dict()

config = ultravox_config.UltravoxConfig(text_model_id="microsoft/phi-2")
assert "text_config" not in config.to_diff_dict()

0 comments on commit a426890

Please sign in to comment.