Skip to content

Commit

Permalink
Remove mlflow dependency (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
juberti authored Jun 11, 2024
1 parent 680a730 commit 2709614
Show file tree
Hide file tree
Showing 7 changed files with 6 additions and 44 deletions.
4 changes: 2 additions & 2 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ install:
just python -m pip install types-requests
format:
. ./activate ${VENV_NAME} && autoflake ${PROJECT_DIR} --remove-all-unused-imports --quiet --in-place -r --exclude third_party --exclude ultravox/model/gazelle
. ./activate ${VENV_NAME} && autoflake ${PROJECT_DIR} --remove-all-unused-imports --quiet --in-place -r --exclude third_party
. ./activate ${VENV_NAME} && isort ${PROJECT_DIR} --force-single-line-imports
. ./activate ${VENV_NAME} && black ${PROJECT_DIR}
check:
. ./activate ${VENV_NAME} && black ${PROJECT_DIR} --check
. ./activate ${VENV_NAME} && isort ${PROJECT_DIR} --check --force-single-line-imports
. ./activate ${VENV_NAME} && autoflake ${PROJECT_DIR} --check --quiet --remove-all-unused-imports -r --exclude third_party --exclude ultravox/model/gazelle
. ./activate ${VENV_NAME} && autoflake ${PROJECT_DIR} --check --quiet --remove-all-unused-imports -r --exclude third_party
. ./activate ${VENV_NAME} && mypy ${PROJECT_DIR}
test *ARGS=".":
Expand Down
4 changes: 2 additions & 2 deletions mcloud.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Gazelle POC training configuration
# Ultravox POC training configuration

name: gazelle-poc
name: ultravox
image: mosaicml/composer:latest
compute:
gpus: 8
Expand Down
2 changes: 0 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
[mypy]
ignore_missing_imports = True

[mypy-ultravox/model/gazelle.*]
ignore_errors = True
3 changes: 0 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,3 @@ jiwer
# Monitoring
tensorboardx
wandb
neptune
mlflow

2 changes: 1 addition & 1 deletion ultravox/model/ultravox_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]:
sample.messages, tokenize=False
)

# Process audio and text using GazelleProcessor.
# Process audio and text using UltravoxProcessor.
# Audio is expanded to be a [C x M] array, although C=1 for mono audio.
audio = (
np.expand_dims(sample.audio, axis=0) if sample.audio is not None else None
Expand Down
2 changes: 1 addition & 1 deletion ultravox/training/configs/llama3_whisper.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SLM with gazelle & llama3
# SLM with ultravox & llama3
exp_name: "llama3_whisper_s"

# Make sure to accept the license agreement on huggingface hub
Expand Down
33 changes: 0 additions & 33 deletions ultravox/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import List, Optional

import datasets as hf_datasets
import mlflow
import safetensors.torch
import simple_parsing
import torch
Expand All @@ -20,7 +19,6 @@

from ultravox.data import datasets
from ultravox.inference import infer
from ultravox.inference import ultravox_infer
from ultravox.model import ultravox_config
from ultravox.model import ultravox_model
from ultravox.model import ultravox_processing
Expand All @@ -33,17 +31,6 @@
OUTPUT_EXAMPLE = {"text": "Hello, world!"}


class GazelleMlflowWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model that forwards predictions to an UltravoxInference.

    ``load_context`` creates the inference object and stores it on
    ``self.inference``; ``predict`` reads that attribute, so it relies on
    ``load_context`` having been called first.
    """

    def load_context(self, context):
        """Create the inference object from the ``model_id`` artifact."""
        model_id = context.artifacts["model_id"]
        self.inference = ultravox_infer.UltravoxInference(model_id)

    def predict(self, context, model_input):
        """Run inference on one input and return the result.

        ``model_input`` is indexed with the keys ``"text"`` and ``"audio"``;
        both values are passed to ``VoiceSample.from_prompt_and_buf`` and the
        resulting sample is handed to ``self.inference.infer``. ``context``
        is not used here.
        """
        prompt = model_input["text"]
        audio_buf = model_input["audio"]
        voice_sample = datasets.VoiceSample.from_prompt_and_buf(prompt, audio_buf)
        return self.inference.infer(voice_sample)


def fix_hyphens(arg: str):
    """Replace hyphens with underscores in the name part of a ``--option``.

    Only the text between a leading ``--`` and the first ``=`` (or the end
    of the string when there is no ``=``) is rewritten. Anything after the
    first ``=`` is left untouched, and strings that do not start with ``--``
    followed by at least one non-``=`` character are returned unchanged.
    """

    def _underscore_name(match):
        # Group 1 is the option name without the leading "--".
        option_name = match.group(1)
        return "--" + option_name.replace("-", "_")

    option_name_pattern = re.compile(r"^--([^=]+)")
    return option_name_pattern.sub(_underscore_name, arg)

Expand Down Expand Up @@ -143,12 +130,6 @@ def main() -> None:
dir="runs",
)

# Starting MLflow; we need to set the experiment name before training starts.
if "mlflow" in args.report_logs_to and is_master:
mlflow.set_tracking_uri("runs/mlruns")
db_exp_name = f"/Shared/{args.exp_name}"
mlflow.set_experiment(db_exp_name)

if args.model_load_dir:
logging.info(f"Loading model state dict from {args.model_load_dir}")
load_path = args.model_load_dir
Expand Down Expand Up @@ -274,20 +255,6 @@ def main() -> None:
)
trainer.train()
trainer.save_model(args.output_dir)
if "mlflow" in args.report_logs_to and is_master:
signature = mlflow.models.signature.infer_signature(
INPUT_EXAMPLE, OUTPUT_EXAMPLE
)
model_info = mlflow.pyfunc.log_model(
python_model=GazelleMlflowWrapper(),
artifact_path="model",
pip_requirements="requirements.txt",
registered_model_name="ultravox",
input_example=INPUT_EXAMPLE,
signature=signature,
)
logging.info(f"Model logged to MLflow: {model_info.model_uri}")

t_end = datetime.now()
logging.info(f"end time: {t_end}")
logging.info(f"elapsed: {t_end - t_start}")
Expand Down

0 comments on commit 2709614

Please sign in to comment.