From 2709614bab0460399f484c010fafa7593575f1c7 Mon Sep 17 00:00:00 2001 From: Justin Uberti Date: Tue, 11 Jun 2024 14:15:10 -0700 Subject: [PATCH] Remove mlflow dependency (#23) --- Justfile | 4 +-- mcloud.yaml | 4 +-- mypy.ini | 2 -- requirements.txt | 3 -- ultravox/model/ultravox_processing.py | 2 +- ultravox/training/configs/llama3_whisper.yaml | 2 +- ultravox/training/train.py | 33 ------------------- 7 files changed, 6 insertions(+), 44 deletions(-) diff --git a/Justfile b/Justfile index 5fdd9b0d..dae55daa 100644 --- a/Justfile +++ b/Justfile @@ -20,14 +20,14 @@ install: just python -m pip install types-requests format: - . ./activate ${VENV_NAME} && autoflake ${PROJECT_DIR} --remove-all-unused-imports --quiet --in-place -r --exclude third_party --exclude ultravox/model/gazelle + . ./activate ${VENV_NAME} && autoflake ${PROJECT_DIR} --remove-all-unused-imports --quiet --in-place -r --exclude third_party . ./activate ${VENV_NAME} && isort ${PROJECT_DIR} --force-single-line-imports . ./activate ${VENV_NAME} && black ${PROJECT_DIR} check: . ./activate ${VENV_NAME} && black ${PROJECT_DIR} --check . ./activate ${VENV_NAME} && isort ${PROJECT_DIR} --check --force-single-line-imports - . ./activate ${VENV_NAME} && autoflake ${PROJECT_DIR} --check --quiet --remove-all-unused-imports -r --exclude third_party --exclude ultravox/model/gazelle + . ./activate ${VENV_NAME} && autoflake ${PROJECT_DIR} --check --quiet --remove-all-unused-imports -r --exclude third_party . ./activate ${VENV_NAME} && mypy ${PROJECT_DIR} test *ARGS=".": diff --git a/mcloud.yaml b/mcloud.yaml index e02fa95c..8d4edb15 100644 --- a/mcloud.yaml +++ b/mcloud.yaml @@ -1,6 +1,6 @@ -# Gazelle POC training configuration +# Ultravox POC training configuration -name: gazelle-poc +name: ultravox image: mosaicml/composer:latest compute: gpus: 8 diff --git a/mypy.ini b/mypy.ini index 0183ffbd..ebcf395f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,3 @@ [mypy] ignore_missing_imports = True -[mypy-ultravox/model/gazelle.*] -ignore_errors = True diff --git a/requirements.txt b/requirements.txt index 0583b15d..8ad5ada1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,3 @@ jiwer # Monitoring tensorboardx wandb -neptune -mlflow - diff --git a/ultravox/model/ultravox_processing.py b/ultravox/model/ultravox_processing.py index 58808658..dbe1b453 100644 --- a/ultravox/model/ultravox_processing.py +++ b/ultravox/model/ultravox_processing.py @@ -212,7 +212,7 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]: sample.messages, tokenize=False ) - # Process audio and text using GazelleProcessor. + # Process audio and text using UltravoxProcessor. # Audio is expanded to be a [C x M] array, although C=1 for mono audio. audio = ( np.expand_dims(sample.audio, axis=0) if sample.audio is not None else None diff --git a/ultravox/training/configs/llama3_whisper.yaml b/ultravox/training/configs/llama3_whisper.yaml index 3a7193f0..48283745 100644 --- a/ultravox/training/configs/llama3_whisper.yaml +++ b/ultravox/training/configs/llama3_whisper.yaml @@ -1,4 +1,4 @@ -# SLM with gazelle & llama3 +# SLM with ultravox & llama3 exp_name: "llama3_whisper_s" # Make sure to accept the license agreement on huggingface hub diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 21991219..37fc79ac 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -8,7 +8,6 @@ from typing import List, Optional import datasets as hf_datasets -import mlflow import safetensors.torch import simple_parsing import torch @@ -20,7 +19,6 @@ from ultravox.data import datasets from ultravox.inference import infer -from ultravox.inference import ultravox_infer from ultravox.model import ultravox_config from ultravox.model import ultravox_model from ultravox.model import ultravox_processing @@ -33,17 +31,6 @@ OUTPUT_EXAMPLE = {"text": "Hello, world!"} -class GazelleMlflowWrapper(mlflow.pyfunc.PythonModel): - def predict(self, context, model_input): - sample = datasets.VoiceSample.from_prompt_and_buf( - model_input["text"], model_input["audio"] - ) - return self.inference.infer(sample) - - def load_context(self, context): - self.inference = ultravox_infer.UltravoxInference(context.artifacts["model_id"]) - - def fix_hyphens(arg: str): return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg) @@ -143,12 +130,6 @@ def main() -> None: dir="runs", ) - # Starting MLflow; we need to set the experiment name before training starts. - if "mlflow" in args.report_logs_to and is_master: - mlflow.set_tracking_uri("runs/mlruns") - db_exp_name = f"/Shared/{args.exp_name}" - mlflow.set_experiment(db_exp_name) - if args.model_load_dir: logging.info(f"Loading model state dict from {args.model_load_dir}") load_path = args.model_load_dir @@ -274,20 +255,6 @@ def main() -> None: ) trainer.train() trainer.save_model(args.output_dir) - if "mlflow" in args.report_logs_to and is_master: - signature = mlflow.models.signature.infer_signature( - INPUT_EXAMPLE, OUTPUT_EXAMPLE - ) - model_info = mlflow.pyfunc.log_model( - python_model=GazelleMlflowWrapper(), - artifact_path="model", - pip_requirements="requirements.txt", - registered_model_name="ultravox", - input_example=INPUT_EXAMPLE, - signature=signature, - ) - logging.info(f"Model logged to MLflow: {model_info.model_uri}") - t_end = datetime.now() logging.info(f"end time: {t_end}") logging.info(f"elapsed: {t_end - t_start}")