diff --git a/cog_safe_push/config.py b/cog_safe_push/config.py index 66c9565..902b753 100644 --- a/cog_safe_push/config.py +++ b/cog_safe_push/config.py @@ -1,6 +1,6 @@ import argparse -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, ConfigDict, model_validator from .exceptions import ArgumentError @@ -11,6 +11,8 @@ class TestCase(BaseModel): + model_config = ConfigDict(extra="forbid") + inputs: dict[str, InputScalar] exact_string: str | None = None match_url: str | None = None @@ -30,6 +32,8 @@ def check_mutually_exclusive(self): class FuzzConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + fixed_inputs: dict[str, InputScalar] = {} disabled_inputs: list[str] = [] duration: int = DEFAULT_FUZZ_DURATION @@ -37,6 +41,8 @@ class FuzzConfig(BaseModel): class PredictConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + compare_outputs: bool = True predict_timeout: int = DEFAULT_PREDICT_TIMEOUT test_cases: list[TestCase] = [] @@ -44,6 +50,8 @@ class PredictConfig(BaseModel): class TrainConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + destination: str | None = None destination_hardware: str = "cpu" train_timeout: int = DEFAULT_PREDICT_TIMEOUT @@ -52,6 +60,8 @@ class TrainConfig(BaseModel): class Config(BaseModel): + model_config = ConfigDict(extra="forbid") + model: str test_model: str | None = None test_hardware: str = "cpu" diff --git a/cog_safe_push/main.py b/cog_safe_push/main.py index afbd4f6..15f9fb0 100644 --- a/cog_safe_push/main.py +++ b/cog_safe_push/main.py @@ -177,6 +177,9 @@ def run_config(config: Config, no_push: bool): model_owner, model_name = parse_model(config.model) test_model_owner, test_model_name = parse_model(config.test_model) + # small optimization + reuse_test_model = None + if config.train: # Don't push twice in case train and predict are both defined has_predict = config.predict is not None @@ -191,7 +194,7 @@ def run_config(config: Config, no_push: bool): fuzz = FuzzConfig( fixed_inputs={}, disabled_inputs=[], duration=0, iterations=0 ) - cog_safe_push( + reuse_test_model = cog_safe_push( model_owner=model_owner, model_name=model_name, test_model_owner=test_model_owner, @@ -232,6 +235,7 @@ def run_config(config: Config, no_push: bool): fuzz_disabled_inputs=fuzz.disabled_inputs, fuzz_seconds=fuzz.duration, fuzz_iterations=fuzz.iterations, + reuse_test_model=reuse_test_model, ) @@ -253,6 +257,7 @@ def cog_safe_push( fuzz_disabled_inputs: list = [], fuzz_seconds: int = 30, fuzz_iterations: int | None = None, + reuse_test_model: Model | None = None, ): if model_owner == test_model_owner and model_name == test_model_name: raise ArgumentError("Can't use the same model as test model") @@ -278,7 +283,12 @@ def cog_safe_push( f"You need to create the model {model_owner}/{model_name} before running this script" ) - test_model = get_or_create_model(test_model_owner, test_model_name, test_hardware) + if reuse_test_model: + test_model = reuse_test_model + else: + test_model = get_or_create_model( + test_model_owner, test_model_name, test_hardware + ) if train: train_destination = get_or_create_model( @@ -287,12 +297,13 @@ def cog_safe_push( else: train_destination = None - log.info("Pushing test model") - pushed_version_id = cog.push(test_model) - test_model.reload() - assert ( - test_model.versions.list()[0].id == pushed_version_id - ), f"Pushed version ID {pushed_version_id} doesn't match latest version on {test_model_owner}/{test_model_name}: {test_model.versions.list()[0].id}" + if not reuse_test_model: + log.info("Pushing test model") + pushed_version_id = cog.push(test_model) + test_model.reload() + assert ( + test_model.versions.list()[0].id == pushed_version_id + ), f"Pushed version ID {pushed_version_id} doesn't match latest version on {test_model_owner}/{test_model_name}: {test_model.versions.list()[0].id}" log.info("Linting test model schema") schema.lint(test_model, train=train) @@ -359,6 +370,8 @@ def cog_safe_push( log.info("Pushing model...") cog.push(model) + return test_model # for reuse + def parse_inputs(inputs_list: list[str]) -> dict[str, Any]: inputs = {} diff --git a/cog_safe_push/predict.py b/cog_safe_push/predict.py index 3822a54..2c29295 100644 --- a/cog_safe_push/predict.py +++ b/cog_safe_push/predict.py @@ -339,6 +339,8 @@ def predict( version=model.versions.list()[0].id, input=inputs ) + log.vv(f"Prediction URL: https://replicate.com/p/{prediction.id}") + start_time = time.time() while prediction.status not in ["succeeded", "failed", "canceled"]: time.sleep(0.5)