From 07faa68fe6290bc62661605e0c89410e9b9eaf60 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Wed, 2 Oct 2024 15:27:52 -0700 Subject: [PATCH] Support pushing with custom dockerfile --- cog_safe_push/cog.py | 7 +++++-- cog_safe_push/config.py | 1 + cog_safe_push/main.py | 8 +++++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/cog_safe_push/cog.py b/cog_safe_push/cog.py index f77b4c8..144d053 100644 --- a/cog_safe_push/cog.py +++ b/cog_safe_push/cog.py @@ -6,11 +6,14 @@ from . import log -def push(model: Model) -> str: +def push(model: Model, dockerfile: str | None) -> str: url = f"r8.im/{model.owner}/{model.name}" log.info(f"Pushing to {url}") + cmd = ["cog", "push", url] + if dockerfile: + cmd += ["--dockerfile", dockerfile] process = subprocess.Popen( - ["cog", "push", url], + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True, diff --git a/cog_safe_push/config.py b/cog_safe_push/config.py index 902b753..129f484 100644 --- a/cog_safe_push/config.py +++ b/cog_safe_push/config.py @@ -67,6 +67,7 @@ class Config(BaseModel): test_hardware: str = "cpu" predict: PredictConfig | None = None train: TrainConfig | None = None + dockerfile: str | None = None def override(self, field: str, args: argparse.Namespace, arg: str): if hasattr(args, arg) and getattr(args, arg) is not None: diff --git a/cog_safe_push/main.py b/cog_safe_push/main.py index e3e0b13..3822457 100644 --- a/cog_safe_push/main.py +++ b/cog_safe_push/main.py @@ -237,6 +237,7 @@ def run_config(config: Config, no_push: bool): fuzz_seconds=fuzz.duration, fuzz_iterations=fuzz.iterations, reuse_test_model=reuse_test_model, + dockerfile=config.dockerfile, ) @@ -259,6 +260,7 @@ def cog_safe_push( fuzz_seconds: int = 30, fuzz_iterations: int | None = None, reuse_test_model: Model | None = None, + dockerfile: str | None = None, ): if model_owner == test_model_owner and model_name == test_model_name: raise ArgumentError("Can't use the same model as test model") @@ -300,7 +302,7 @@ def cog_safe_push( if not reuse_test_model: log.info("Pushing test model") - pushed_version_id = cog.push(test_model) + pushed_version_id = cog.push(test_model, dockerfile) test_model.reload() try: assert ( @@ -387,7 +389,7 @@ def cog_safe_push( if not no_push: log.info("Pushing model...") - cog.push(model) + cog.push(model, dockerfile) return test_model # for reuse @@ -453,7 +455,7 @@ def get_model(owner, name) -> Model | None: def parse_model(model_owner_name: str) -> tuple[str, str]: - pattern = r"^([a-z0-9_-]+)/([a-z0-9-]+)$" + pattern = r"^([a-z0-9_-]+)/([a-z0-9-.]+)$" match = re.match(pattern, model_owner_name) if not match: raise ArgumentError(f"Invalid model URL format: {model_owner_name}")