Skip to content

Commit

Permalink
validate yaml for forbidden extra fields
Browse files Browse the repository at this point in the history
Also includes a small optimization: reuse the test model instead of re-pushing it when both train and predict are set.
  • Loading branch information
andreasjansson committed Sep 11, 2024
1 parent 94f0972 commit 726a1b3
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 9 deletions.
12 changes: 11 additions & 1 deletion cog_safe_push/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse

from pydantic import BaseModel, model_validator
from pydantic import BaseModel, ConfigDict, model_validator

from .exceptions import ArgumentError

Expand All @@ -11,6 +11,8 @@


class TestCase(BaseModel):
model_config = ConfigDict(extra="forbid")

inputs: dict[str, InputScalar]
exact_string: str | None = None
match_url: str | None = None
Expand All @@ -30,20 +32,26 @@ def check_mutually_exclusive(self):


# Pydantic model holding the fuzzing settings: inputs pinned to fixed values,
# inputs that are disabled, a duration, and an optional iteration count.
class FuzzConfig(BaseModel):
# extra="forbid" makes validation fail when a field not declared below is
# supplied (per the commit message: validate yaml for forbidden extra fields).
model_config = ConfigDict(extra="forbid")

# Input name -> scalar value.
fixed_inputs: dict[str, InputScalar] = {}
# Names of inputs listed as disabled.
disabled_inputs: list[str] = []
# run_config forwards this value as `fuzz_seconds`, so presumably seconds.
duration: int = DEFAULT_FUZZ_DURATION
# Defaults to None when no iteration count is configured.
iterations: int | None = None


# Pydantic model holding the prediction-testing settings: an output-comparison
# flag, a timeout, explicit test cases, and optional nested fuzz settings.
class PredictConfig(BaseModel):
# extra="forbid" makes validation fail when a field not declared below is
# supplied (per the commit message: validate yaml for forbidden extra fields).
model_config = ConfigDict(extra="forbid")

compare_outputs: bool = True
# NOTE(review): unit is not shown here; presumably seconds — confirm.
predict_timeout: int = DEFAULT_PREDICT_TIMEOUT
# Empty by default; each entry is validated as a TestCase.
test_cases: list[TestCase] = []
# Nested fuzz settings; None when the section is omitted.
fuzz: FuzzConfig | None = None


class TrainConfig(BaseModel):
model_config = ConfigDict(extra="forbid")

destination: str | None = None
destination_hardware: str = "cpu"
train_timeout: int = DEFAULT_PREDICT_TIMEOUT
Expand All @@ -52,6 +60,8 @@ class TrainConfig(BaseModel):


class Config(BaseModel):
model_config = ConfigDict(extra="forbid")

model: str
test_model: str | None = None
test_hardware: str = "cpu"
Expand Down
29 changes: 21 additions & 8 deletions cog_safe_push/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ def run_config(config: Config, no_push: bool):
model_owner, model_name = parse_model(config.model)
test_model_owner, test_model_name = parse_model(config.test_model)

# small optimization
reuse_test_model = None

if config.train:
# Don't push twice in case train and predict are both defined
has_predict = config.predict is not None
Expand All @@ -191,7 +194,7 @@ def run_config(config: Config, no_push: bool):
fuzz = FuzzConfig(
fixed_inputs={}, disabled_inputs=[], duration=0, iterations=0
)
cog_safe_push(
reuse_test_model = cog_safe_push(
model_owner=model_owner,
model_name=model_name,
test_model_owner=test_model_owner,
Expand Down Expand Up @@ -232,6 +235,7 @@ def run_config(config: Config, no_push: bool):
fuzz_disabled_inputs=fuzz.disabled_inputs,
fuzz_seconds=fuzz.duration,
fuzz_iterations=fuzz.iterations,
reuse_test_model=reuse_test_model,
)


Expand All @@ -253,6 +257,7 @@ def cog_safe_push(
fuzz_disabled_inputs: list = [],
fuzz_seconds: int = 30,
fuzz_iterations: int | None = None,
reuse_test_model: Model | None = None,
):
if model_owner == test_model_owner and model_name == test_model_name:
raise ArgumentError("Can't use the same model as test model")
Expand All @@ -278,7 +283,12 @@ def cog_safe_push(
f"You need to create the model {model_owner}/{model_name} before running this script"
)

test_model = get_or_create_model(test_model_owner, test_model_name, test_hardware)
if reuse_test_model:
test_model = reuse_test_model
else:
test_model = get_or_create_model(
test_model_owner, test_model_name, test_hardware
)

if train:
train_destination = get_or_create_model(
Expand All @@ -287,12 +297,13 @@ def cog_safe_push(
else:
train_destination = None

log.info("Pushing test model")
pushed_version_id = cog.push(test_model)
test_model.reload()
assert (
test_model.versions.list()[0].id == pushed_version_id
), f"Pushed version ID {pushed_version_id} doesn't match latest version on {test_model_owner}/{test_model_name}: {test_model.versions.list()[0].id}"
if not reuse_test_model:
log.info("Pushing test model")
pushed_version_id = cog.push(test_model)
test_model.reload()
assert (
test_model.versions.list()[0].id == pushed_version_id
), f"Pushed version ID {pushed_version_id} doesn't match latest version on {test_model_owner}/{test_model_name}: {test_model.versions.list()[0].id}"

log.info("Linting test model schema")
schema.lint(test_model, train=train)
Expand Down Expand Up @@ -359,6 +370,8 @@ def cog_safe_push(
log.info("Pushing model...")
cog.push(model)

return test_model # for reuse


def parse_inputs(inputs_list: list[str]) -> dict[str, Any]:
inputs = {}
Expand Down
2 changes: 2 additions & 0 deletions cog_safe_push/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,8 @@ def predict(
version=model.versions.list()[0].id, input=inputs
)

log.vv(f"Prediction URL: https://replicate.com/p/{prediction.id}")

start_time = time.time()
while prediction.status not in ["succeeded", "failed", "canceled"]:
time.sleep(0.5)
Expand Down

0 comments on commit 726a1b3

Please sign in to comment.