Skip to content

Commit

Permalink
Support official models
Browse files Browse the repository at this point in the history
Official models don't have a list of versions
  • Loading branch information
andreasjansson committed Sep 19, 2024
1 parent 49dccb6 commit a8cdb64
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 9 deletions.
26 changes: 22 additions & 4 deletions cog_safe_push/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,14 +302,32 @@ def cog_safe_push(
log.info("Pushing test model")
pushed_version_id = cog.push(test_model)
test_model.reload()
assert (
test_model.versions.list()[0].id == pushed_version_id
), f"Pushed version ID {pushed_version_id} doesn't match latest version on {test_model_owner}/{test_model_name}: {test_model.versions.list()[0].id}"
try:
assert (
test_model.versions.list()[0].id == pushed_version_id
), f"Pushed version ID {pushed_version_id} doesn't match latest version on {test_model_owner}/{test_model_name}: {test_model.versions.list()[0].id}"
except ReplicateError as e:
if e.status == 404:
# Assume it's an official model
# If it's an official model, can't check that the version matches
pass
else:
raise

log.info("Linting test model schema")
schema.lint(test_model, train=train)

if model.latest_version:
model_has_versions = False
try:
model_has_versions = bool(model.versions.list())
except ReplicateError as e:
if e.status == 404:
# Assume it's an official model
model_has_versions = bool(model.latest_version)
else:
raise

if model_has_versions:
log.info("Checking schema backwards compatibility")
test_model_schemas = schema.get_schemas(test_model, train=train)
model_schemas = schema.get_schemas(model, train=train)
Expand Down
14 changes: 11 additions & 3 deletions cog_safe_push/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, List

import replicate
from replicate.exceptions import ReplicateError
from replicate.model import Model

from . import ai, log, schema
Expand Down Expand Up @@ -335,9 +336,16 @@ def predict(
destination=f"{train_destination.owner}/{train_destination.name}",
)
else:
prediction = replicate.predictions.create(
version=model.versions.list()[0].id, input=inputs
)
try:
prediction = replicate.predictions.create(
version=model.versions.list()[0].id, input=inputs
)
except ReplicateError as e:
if e.status == 404:
# Assume it's an official model
prediction = replicate.predictions.create(model=model, input=inputs)
else:
raise

log.vv(f"Prediction URL: https://replicate.com/p/{prediction.id}")

Expand Down
16 changes: 14 additions & 2 deletions cog_safe_push/schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from replicate.exceptions import ReplicateError
from replicate.model import Model

from .exceptions import IncompatibleSchemaError, SchemaLintError
Expand All @@ -7,7 +8,7 @@ def lint(model: Model, train: bool):
errors = []

input_name = "TrainingInput" if train else "Input"
schema = model.versions.list()[0].openapi_schema
schema = get_openapi_schema(model)
properties = schema["components"]["schemas"][input_name]["properties"]
for name, spec in properties.items():
description = spec.get("description")
Expand Down Expand Up @@ -104,8 +105,19 @@ def check_backwards_compatible(
)


def get_openapi_schema(model: Model) -> dict:
try:
return model.versions.list()[0].openapi_schema
except ReplicateError as e:
if e.status == 404:
# Assume it's an official model
assert model.latest_version
return model.latest_version.openapi_schema
raise


def get_schemas(model, train: bool):
schemas = model.versions.list()[0].openapi_schema["components"]["schemas"]
schemas = get_openapi_schema(model)["components"]["schemas"]
unnecessary_keys = [
"HTTPValidationError",
"PredictionRequest",
Expand Down

0 comments on commit a8cdb64

Please sign in to comment.