Showing 11 changed files with 299 additions and 68 deletions.
@@ -0,0 +1,37 @@
import pytest
import transformers

from ultravox.model import ultravox_config


@pytest.mark.parametrize(
    "model_id",
    ["fixie-ai/ultravox-v0_2", "fixie-ai/ultravox-v0_3", "fixie-ai/ultravox-v0_4"],
)
def test_can_load_release(model_id: str):
    orig_config: transformers.PretrainedConfig = (
        transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    )
    config_from_dict = ultravox_config.UltravoxConfig(**orig_config.to_dict())
    config_from_diff_dict = ultravox_config.UltravoxConfig(**orig_config.to_diff_dict())

    assert config_from_dict.to_dict() == orig_config.to_dict()
    assert config_from_diff_dict.to_dict() == orig_config.to_dict()

    assert config_from_dict.text_config.to_dict() == orig_config.text_config.to_dict()
    assert config_from_dict.audio_config.to_dict() == orig_config.audio_config.to_dict()

    config_reloaded = ultravox_config.UltravoxConfig(**config_from_dict.to_dict())
    config_reloaded_diff = ultravox_config.UltravoxConfig(
        **config_from_dict.to_diff_dict()
    )
    assert config_reloaded.to_dict() == orig_config.to_dict()
    assert config_reloaded_diff.to_dict() == orig_config.to_dict()


def test_no_config_when_id_present():
    config = ultravox_config.UltravoxConfig(audio_model_id="openai/whisper-small")
    assert "audio_config" not in config.to_diff_dict()

    config = ultravox_config.UltravoxConfig(text_model_id="microsoft/phi-2")
    assert "text_config" not in config.to_diff_dict()
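Note on the round-trip assertions above: in transformers, `to_dict()` serializes every config field, while `to_diff_dict()` keeps only the fields that differ from the class defaults, so the test checks that both the full and the sparse form rebuild an identical config. A minimal standalone sketch of that distinction, using GPT2Config purely as an illustration:

import transformers

# Override a single default so it shows up in the diff dict.
config = transformers.GPT2Config(n_layer=2)

full = config.to_dict()        # every field, defaults included
diff = config.to_diff_dict()   # only fields that differ from the defaults

# Rebuilding from the sparse diff dict recovers the same full config.
assert transformers.GPT2Config(**diff).to_dict() == full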
@@ -0,0 +1,16 @@
# Training configuration for Ultravox with Meta-Llama-3.1-70B-Instruct and Whisper-Medium
# This configuration is used alongside release_config.yaml

# TODO: make sure to increase max_duration in mcloud.yaml to 18 hours instead of 6

exp_name: "ultravox-v0_4-llama3.1-70B-whisper_m"

text_model: "meta-llama/Meta-Llama-3.1-70B-Instruct"
audio_model: "openai/whisper-medium"

batch_size: 5
# We increase the number of steps by 2x, but with the lower batch_size we still
# won't be training on as many samples as the 8B model. We may revisit this later.
max_steps: 28800 # 28800 x 8 x 5 = 1,152,000 samples

do_eval: False # evaluation doesn't support FSDP yet
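For context, this file only overrides a handful of settings; release_config.yaml (referenced in the header comment) carries the rest. A minimal sketch of how such a base-plus-override pair can be merged — the helper and override file name below are hypothetical, not the project's actual loader:

import yaml

def load_with_overrides(base_path: str, override_path: str) -> dict:
    # Shallow merge: keys in the override file win over the base config.
    with open(base_path) as f:
        config = yaml.safe_load(f) or {}
    with open(override_path) as f:
        config.update(yaml.safe_load(f) or {})
    return config

# Hypothetical usage, mirroring the header comment:
# config = load_with_overrides("release_config.yaml", "llama_70b_override.yaml")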
@@ -0,0 +1,53 @@
from datetime import datetime
from typing import List, Optional

import huggingface_hub
import transformers

from ultravox.model import wandb_utils
from ultravox.training import config_base

ALLOW_PATTERNS = ["*.safetensors", "*.json", "*.txt"]


def main(override_sys_args: Optional[List[str]] = None):
    start = datetime.now()
    print("Downloading weights ...")

    args = config_base.get_train_args(override_sys_args)

    download_weights([args.text_model, args.audio_model], args.model_load_dir)

    end = datetime.now()
    print(f"Weights downloaded in {(end - start).total_seconds():.0f} seconds")


def download_weights(model_ids: List[str], model_load_dir: Optional[str] = None):
    for model_id in model_ids:
        try:
            # Download all model files that match ALLOW_PATTERNS.
            # This is faster than .from_pretrained due to parallel downloads.
            huggingface_hub.snapshot_download(
                repo_id=model_id, allow_patterns=ALLOW_PATTERNS
            )
        except huggingface_hub.utils.GatedRepoError:
            raise  # gated repos need credentials; re-raise instead of skipping
        except huggingface_hub.utils.RepositoryNotFoundError as e:
            # We assume that the model is local if it's not found on HF Hub.
            # The `.from_pretrained` call below will verify the local case.
            print(
                f"Model {model_id} not found on HF Hub. Skipping download. Error: {e}"
            )

        # A backstop to make sure the model is fully downloaded. Scenarios to consider:
        # - ALLOW_PATTERNS is not enough to download all files needed
        # - The model is local; this verifies that everything is in order
        # Using `device_map="meta"` avoids loading the weights into memory or onto a device.
        transformers.AutoModel.from_pretrained(model_id, device_map="meta")

    if model_load_dir and wandb_utils.is_wandb_url(model_load_dir):
        wandb_utils.download_model_from_wandb(model_load_dir)


if __name__ == "__main__":
    main()
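A note on the backstop: loading with `device_map="meta"` builds the module tree and checks that every checkpoint file is present and parseable, but allocates no real weight tensors, so the verification stays cheap even for a 70B model. A standalone sketch of the same trick, with a small model id chosen purely for illustration:

import transformers

# On the meta device the checkpoint is located and parsed, but parameters
# carry no storage, so nothing is materialized in RAM or on an accelerator.
model = transformers.AutoModel.from_pretrained("gpt2", device_map="meta")
print(next(model.parameters()).device)  # device(type='meta')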