Skip to content

Commit

Permalink
Merge pull request #5 from replicate/yaml-validation
Browse files Browse the repository at this point in the history
validate yaml for forbidden extra fields
  • Loading branch information
andreasjansson authored Sep 11, 2024
2 parents 94f0972 + 726a1b3 commit 6fb8710
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 9 deletions.
12 changes: 11 additions & 1 deletion cog_safe_push/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse

from pydantic import BaseModel, model_validator
from pydantic import BaseModel, ConfigDict, model_validator

from .exceptions import ArgumentError

Expand All @@ -11,6 +11,8 @@


class TestCase(BaseModel):
model_config = ConfigDict(extra="forbid")

inputs: dict[str, InputScalar]
exact_string: str | None = None
match_url: str | None = None
Expand All @@ -30,20 +32,26 @@ def check_mutually_exclusive(self):


# Settings for the fuzzing step, validated by pydantic.
class FuzzConfig(BaseModel):
# extra="forbid" makes pydantic reject unknown keys, so a misspelled or
# unsupported field in the config fails validation instead of being ignored.
model_config = ConfigDict(extra="forbid")

# Input values keyed by input name.
# NOTE(review): presumably held constant while other inputs are fuzzed — confirm.
fixed_inputs: dict[str, InputScalar] = {}
# Input names handed to cog_safe_push as fuzz_disabled_inputs.
# NOTE(review): presumably excluded from fuzzing — confirm.
disabled_inputs: list[str] = []
# Handed to cog_safe_push as fuzz_seconds, i.e. a duration in seconds.
duration: int = DEFAULT_FUZZ_DURATION
# Handed to cog_safe_push as fuzz_iterations; None when not set.
iterations: int | None = None


# Settings for the predict section of the config, validated by pydantic.
class PredictConfig(BaseModel):
# extra="forbid" makes pydantic reject unknown keys, so a misspelled or
# unsupported field in the config fails validation instead of being ignored.
model_config = ConfigDict(extra="forbid")

# Enabled by default.
# NOTE(review): presumably toggles comparing test-model and production-model outputs — confirm.
compare_outputs: bool = True
# NOTE(review): presumably a per-prediction timeout in seconds — confirm the unit.
predict_timeout: int = DEFAULT_PREDICT_TIMEOUT
# Explicit test cases; empty by default.
test_cases: list[TestCase] = []
# Optional fuzzing settings; None when the section is omitted.
fuzz: FuzzConfig | None = None


class TrainConfig(BaseModel):
model_config = ConfigDict(extra="forbid")

destination: str | None = None
destination_hardware: str = "cpu"
train_timeout: int = DEFAULT_PREDICT_TIMEOUT
Expand All @@ -52,6 +60,8 @@ class TrainConfig(BaseModel):


class Config(BaseModel):
model_config = ConfigDict(extra="forbid")

model: str
test_model: str | None = None
test_hardware: str = "cpu"
Expand Down
29 changes: 21 additions & 8 deletions cog_safe_push/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ def run_config(config: Config, no_push: bool):
model_owner, model_name = parse_model(config.model)
test_model_owner, test_model_name = parse_model(config.test_model)

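# Small optimization: when both train and predict are configured, keep the test
# model returned by the train run so the predict run reuses it instead of
# creating and pushing the test model a second time.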
reuse_test_model = None

if config.train:
# Don't push twice in case train and predict are both defined
has_predict = config.predict is not None
Expand All @@ -191,7 +194,7 @@ def run_config(config: Config, no_push: bool):
fuzz = FuzzConfig(
fixed_inputs={}, disabled_inputs=[], duration=0, iterations=0
)
cog_safe_push(
reuse_test_model = cog_safe_push(
model_owner=model_owner,
model_name=model_name,
test_model_owner=test_model_owner,
Expand Down Expand Up @@ -232,6 +235,7 @@ def run_config(config: Config, no_push: bool):
fuzz_disabled_inputs=fuzz.disabled_inputs,
fuzz_seconds=fuzz.duration,
fuzz_iterations=fuzz.iterations,
reuse_test_model=reuse_test_model,
)


Expand All @@ -253,6 +257,7 @@ def cog_safe_push(
fuzz_disabled_inputs: list = [],
fuzz_seconds: int = 30,
fuzz_iterations: int | None = None,
reuse_test_model: Model | None = None,
):
if model_owner == test_model_owner and model_name == test_model_name:
raise ArgumentError("Can't use the same model as test model")
Expand All @@ -278,7 +283,12 @@ def cog_safe_push(
f"You need to create the model {model_owner}/{model_name} before running this script"
)

test_model = get_or_create_model(test_model_owner, test_model_name, test_hardware)
if reuse_test_model:
test_model = reuse_test_model
else:
test_model = get_or_create_model(
test_model_owner, test_model_name, test_hardware
)

if train:
train_destination = get_or_create_model(
Expand All @@ -287,12 +297,13 @@ def cog_safe_push(
else:
train_destination = None

log.info("Pushing test model")
pushed_version_id = cog.push(test_model)
test_model.reload()
assert (
test_model.versions.list()[0].id == pushed_version_id
), f"Pushed version ID {pushed_version_id} doesn't match latest version on {test_model_owner}/{test_model_name}: {test_model.versions.list()[0].id}"
if not reuse_test_model:
log.info("Pushing test model")
pushed_version_id = cog.push(test_model)
test_model.reload()
assert (
test_model.versions.list()[0].id == pushed_version_id
), f"Pushed version ID {pushed_version_id} doesn't match latest version on {test_model_owner}/{test_model_name}: {test_model.versions.list()[0].id}"

log.info("Linting test model schema")
schema.lint(test_model, train=train)
Expand Down Expand Up @@ -359,6 +370,8 @@ def cog_safe_push(
log.info("Pushing model...")
cog.push(model)

return test_model # for reuse


def parse_inputs(inputs_list: list[str]) -> dict[str, Any]:
inputs = {}
Expand Down
2 changes: 2 additions & 0 deletions cog_safe_push/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,8 @@ def predict(
version=model.versions.list()[0].id, input=inputs
)

log.vv(f"Prediction URL: https://replicate.com/p/{prediction.id}")

start_time = time.time()
while prediction.status not in ["succeeded", "failed", "canceled"]:
time.sleep(0.5)
Expand Down

0 comments on commit 6fb8710

Please sign in to comment.