Skip to content

Commit

Permalink
Tighter ruff rules
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasjansson committed Aug 27, 2024
1 parent 03e8395 commit 45edc4c
Show file tree
Hide file tree
Showing 23 changed files with 294 additions and 125 deletions.
10 changes: 3 additions & 7 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,9 @@ jobs:
run: |
pip install -r requirements-test.txt
- name: Run ruff
- name: Lint
run: |
ruff check --ignore=F403,F405
- name: Run black
run: |
black --check .
./script/lint
unit-test:
runs-on: ubuntu-latest
Expand All @@ -48,4 +44,4 @@ jobs:
- name: Run pytest
run: |
pytest test/
./script/unit-test
19 changes: 13 additions & 6 deletions cog_safe_push/ai.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import base64
import json
import mimetypes
from pathlib import Path
import os
import json
from pathlib import Path
from typing import cast

import anthropic

from . import log
from .exceptions import AIError
from .retry import retry
from . import log


@retry(3)
Expand Down Expand Up @@ -47,7 +49,9 @@ def call(system_prompt: str, prompt: str, files: list[Path] | None = None) -> st
content = prompt
log.vvv(f"Claude prompt: {prompt}")

messages = [{"role": "user", "content": content}]
messages: list[anthropic.types.MessageParam] = [
{"role": "user", "content": content}
]

response = client.messages.create(
model=model,
Expand All @@ -57,12 +61,15 @@ def call(system_prompt: str, prompt: str, files: list[Path] | None = None) -> st
stream=False,
temperature=0.9,
)
output = response.content[0].text
content = cast(anthropic.types.TextBlock, response.content[0])
output = content.text
log.vvv(f"Claude response: {output}")
return output


def create_content_list(files: list[Path]):
def create_content_list(
files: list[Path],
) -> list[anthropic.types.ImageBlockParam | anthropic.types.TextBlockParam]:
content = []
for path in files:
with path.open("rb") as f:
Expand Down
8 changes: 5 additions & 3 deletions cog_safe_push/cog.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import subprocess
import re
import replicate
import subprocess

from replicate.model import Model

from . import log


def push(model: replicate.model.Model) -> str:
def push(model: Model) -> str:
url = f"r8.im/{model.owner}/{model.name}"
log.info(f"Pushing to {url}")
process = subprocess.Popen(
Expand All @@ -16,6 +17,7 @@ def push(model: replicate.model.Model) -> str:
)

sha256_id = None
assert process.stdout
for line in process.stdout:
log.v(line.rstrip()) # Print output in real-time
if "latest: digest: sha256:" in line:
Expand Down
8 changes: 4 additions & 4 deletions cog_safe_push/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@ class SchemaLintError(Exception):
pass


class IncompatibleSchema(Exception):
class IncompatibleSchemaError(Exception):
pass


class OutputsDontMatch(Exception):
class OutputsDontMatchError(Exception):
pass


class FuzzError(Exception):
pass


class PredictionTimeout(Exception):
class PredictionTimeoutError(Exception):
pass


class PredictionFailed(Exception):
class PredictionFailedError(Exception):
pass


Expand Down
17 changes: 10 additions & 7 deletions cog_safe_push/lint.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from pathlib import Path
import subprocess
from pathlib import Path
from typing import Any

import yaml

from .exceptions import CodeLintError


def lint_predict():
with open("cog.yaml", "r") as f:
cog_config = yaml.safe_load(f)

cog_config = load_cog_config()
predict_config = cog_config.get("predict", "")
predict_filename = predict_config.split(":")[0]

Expand All @@ -19,9 +19,7 @@ def lint_predict():


def lint_train():
with open("cog.yaml", "r") as f:
cog_config = yaml.safe_load(f)

cog_config = load_cog_config()
train_config = cog_config.get("train", "")
train_filename = train_config.split(":")[0]

Expand All @@ -31,6 +29,11 @@ def lint_train():
lint_file(train_filename)


def load_cog_config() -> dict[str, Any]:
    """Load ``cog.yaml`` from the current working directory.

    Returns:
        The parsed top-level mapping. PyYAML parses an empty (or
        comment-only) document to ``None``; that case is normalised to an
        empty dict so the declared return type holds and ``.get(...)``
        lookups on the result keep working.

    Raises:
        FileNotFoundError: If ``cog.yaml`` does not exist.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Pin the encoding rather than relying on the platform's locale default.
    with Path("cog.yaml").open(encoding="utf-8") as f:
        return yaml.safe_load(f) or {}


def lint_file(filename: str):
if not Path(filename).exists():
raise CodeLintError(f"{filename} doesn't exist")
Expand Down
19 changes: 10 additions & 9 deletions cog_safe_push/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections import defaultdict
import re
import argparse
import re
from collections import defaultdict

import replicate
from replicate.exceptions import ReplicateException
from replicate.exceptions import ReplicateError
from replicate.model import Model

from . import cog, lint, schema, predict, log
from . import cog, lint, log, predict, schema


def main():
Expand Down Expand Up @@ -244,8 +246,7 @@ def parse_inputs(inputs_list: list[str]) -> dict[str, list[predict.WeightedInput
except ValueError:
raise ValueError(f"Invalid input format: {input_str}")

inputs = make_weighted_inputs(input_values, input_weights)
return inputs
return make_weighted_inputs(input_values, input_weights)


def make_weighted_inputs(
Expand Down Expand Up @@ -310,7 +311,7 @@ def parse_input_weight_percent(value_str: str) -> tuple[str, float | None]:
return value_str, None


def get_or_create_model(model_owner, model_name, hardware) -> replicate.model.Model:
def get_or_create_model(model_owner, model_name, hardware) -> Model:
model = get_model(model_owner, model_name)

if not model:
Expand All @@ -329,10 +330,10 @@ def get_or_create_model(model_owner, model_name, hardware) -> replicate.model.Mo
return model


def get_model(owner, name) -> replicate.model.Model:
def get_model(owner, name) -> Model | None:
try:
model = replicate.models.get(f"{owner}/{name}")
except ReplicateException as e:
except ReplicateError as e:
if e.status == 404:
return None
raise
Expand Down
67 changes: 34 additions & 33 deletions cog_safe_push/predict.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import json
import math
import random
import tempfile
import time
from contextlib import contextmanager
from dataclasses import dataclass
import math
from pathlib import Path
import json
import os
from urllib.parse import urlparse
from contextlib import contextmanager
import time
from typing import Any, Iterator
import tempfile
import requests
from urllib.parse import urlparse

import replicate
import requests
from PIL import Image
from replicate.model import Model

from . import ai, log, schema
from .exceptions import (
OutputsDontMatch,
PredictionTimeout,
PredictionFailed,
FuzzError,
AIError,
FuzzError,
OutputsDontMatchError,
PredictionFailedError,
PredictionTimeoutError,
)
from . import ai, schema, log


@dataclass
Expand All @@ -40,10 +41,10 @@ class WeightedInputValue:


def check_outputs_match(
test_model: replicate.model.Model,
model: replicate.model.Model,
test_model: Model,
model: Model,
train: bool,
train_destination: str | None,
train_destination: Model | None,
timeout_seconds: float,
inputs: dict[str, list[WeightedInputValue]],
disabled_inputs: list[str],
Expand Down Expand Up @@ -73,15 +74,15 @@ def check_outputs_match(
)
matches, error = outputs_match(test_output, output, is_deterministic)
if not matches:
raise OutputsDontMatch(
raise OutputsDontMatchError(
f"Outputs don't match:\n\ntest output:\n{test_output}\n\nmodel output:\n{output}\n\n{error}"
)


def fuzz_model(
model: replicate.model.Model,
model: Model,
train: bool,
train_destination: str | None,
train_destination: Model | None,
timeout_seconds: float,
max_iterations: int | None,
inputs: dict[str, list[WeightedInputValue]],
Expand Down Expand Up @@ -111,13 +112,13 @@ def fuzz_model(
inputs=predict_inputs,
timeout_seconds=predict_timeout,
)
except PredictionTimeout:
except PredictionTimeoutError:
if not successful_predictions:
log.warning(
f"No predictions succeeded in {timeout_seconds}, try increasing --fuzz-seconds"
)
return
except PredictionFailed as e:
except PredictionFailedError as e:
raise FuzzError(e)
if not output:
raise FuzzError("No output")
Expand Down Expand Up @@ -345,9 +346,9 @@ def make_predict_inputs(


def predict(
model: replicate.model.Model,
model: Model,
train: bool,
train_destination: str | None,
train_destination: Model | None,
inputs: dict,
timeout_seconds: float,
):
Expand All @@ -356,11 +357,12 @@ def predict(
)

if train:
assert train_destination
version_ref = f"{model.owner}/{model.name}:{model.versions.list()[0].id}"
prediction = replicate.trainings.create(
version=version_ref,
input=inputs,
destination=train_destination,
destination=f"{train_destination.owner}/{train_destination.name}",
)
else:
prediction = replicate.predictions.create(
Expand All @@ -371,11 +373,11 @@ def predict(
while prediction.status not in ["succeeded", "failed", "canceled"]:
time.sleep(0.5)
if time.time() - start_time > timeout_seconds:
raise PredictionTimeout()
raise PredictionTimeoutError()
prediction.reload()

if prediction.status == "failed":
raise PredictionFailed(prediction.error)
raise PredictionFailedError(prediction.error)

log.vv(f"Got output: {truncate(prediction.output)}")

Expand Down Expand Up @@ -413,7 +415,7 @@ def outputs_match(test_output, output, is_deterministic: bool) -> tuple[bool, st
if isinstance(output, dict):
if test_output.keys() != output.keys():
return False, "Dict keys don't match"
for key in output.keys():
for key in output:
matches, message = outputs_match(
test_output[key], output[key], is_deterministic
)
Expand Down Expand Up @@ -452,8 +454,7 @@ def strings_match(s1: str, s2: str, is_deterministic: bool) -> tuple[bool, str]:
)
if fuzzy_match:
return True, ""
else:
return False, "Strings aren't similar"
return False, "Strings aren't similar"


def urls_match(url1: str, url2: str, is_deterministic: bool) -> tuple[bool, str]:
Expand Down Expand Up @@ -490,8 +491,8 @@ def is_video(url: str) -> bool:


def extensions_match(url1: str, url2: str) -> bool:
    """Return whether the paths of two URLs end in the same file extension.

    Only the path component of each URL is inspected, so query strings and
    fragments do not affect the result. The comparison is case-insensitive,
    and two URLs without any extension are considered a match.
    """
    ext1 = Path(urlparse(url1).path).suffix
    ext2 = Path(urlparse(url2).path).suffix
    return ext1.lower() == ext2.lower()


Expand Down Expand Up @@ -565,7 +566,7 @@ def videos_match(url1: str, url2: str, is_deterministic: bool) -> tuple[bool, st

@contextmanager
def download(url: str) -> Iterator[Path]:
suffix = os.path.splitext(url)[1]
suffix = Path(url).suffix
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
response = requests.get(url)
response.raise_for_status()
Expand All @@ -575,7 +576,7 @@ def download(url: str) -> Iterator[Path]:
try:
yield Path(tmp_file.name)
finally:
os.unlink(tmp_file.name)
tmp_file.unlink()


def truncate(s, max_length=500) -> str:
Expand Down
1 change: 1 addition & 0 deletions cog_safe_push/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def wrapper_retry(*args, **kwargs):
else:
log.warning(f"Giving up after {attempts} attempts")
raise
return None

return wrapper_retry

Expand Down
Loading

0 comments on commit 45edc4c

Please sign in to comment.