Skip to content

Commit

Permalink
Support pushing with custom dockerfile
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasjansson committed Oct 2, 2024
1 parent 568da71 commit 07faa68
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
7 changes: 5 additions & 2 deletions cog_safe_push/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
from . import log


def push(model: Model) -> str:
def push(model: Model, dockerfile: str | None) -> str:
url = f"r8.im/{model.owner}/{model.name}"
log.info(f"Pushing to {url}")
cmd = ["cog", "push", url]
if dockerfile:
cmd += ["--dockerfile", dockerfile]
process = subprocess.Popen(
["cog", "push", url],
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
Expand Down
1 change: 1 addition & 0 deletions cog_safe_push/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class Config(BaseModel):
test_hardware: str = "cpu"
predict: PredictConfig | None = None
train: TrainConfig | None = None
dockerfile: str | None = None

def override(self, field: str, args: argparse.Namespace, arg: str):
if hasattr(args, arg) and getattr(args, arg) is not None:
Expand Down
8 changes: 5 additions & 3 deletions cog_safe_push/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def run_config(config: Config, no_push: bool):
fuzz_seconds=fuzz.duration,
fuzz_iterations=fuzz.iterations,
reuse_test_model=reuse_test_model,
dockerfile=config.dockerfile,
)


Expand All @@ -259,6 +260,7 @@ def cog_safe_push(
fuzz_seconds: int = 30,
fuzz_iterations: int | None = None,
reuse_test_model: Model | None = None,
dockerfile: str | None = None,
):
if model_owner == test_model_owner and model_name == test_model_name:
raise ArgumentError("Can't use the same model as test model")
Expand Down Expand Up @@ -300,7 +302,7 @@ def cog_safe_push(

if not reuse_test_model:
log.info("Pushing test model")
pushed_version_id = cog.push(test_model)
pushed_version_id = cog.push(test_model, dockerfile)
test_model.reload()
try:
assert (
Expand Down Expand Up @@ -387,7 +389,7 @@ def cog_safe_push(

if not no_push:
log.info("Pushing model...")
cog.push(model)
cog.push(model, dockerfile)

return test_model # for reuse

Expand Down Expand Up @@ -453,7 +455,7 @@ def get_model(owner, name) -> Model | None:


def parse_model(model_owner_name: str) -> tuple[str, str]:
pattern = r"^([a-z0-9_-]+)/([a-z0-9-]+)$"
pattern = r"^([a-z0-9_-]+)/([a-z0-9-.]+)$"
match = re.match(pattern, model_owner_name)
if not match:
raise ArgumentError(f"Invalid model URL format: {model_owner_name}")
Expand Down

0 comments on commit 07faa68

Please sign in to comment.