Skip to content

Commit

Permalink
Config files and test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasjansson committed Sep 11, 2024
1 parent b3bba40 commit 895e39c
Show file tree
Hide file tree
Showing 14 changed files with 1,377 additions and 602 deletions.
146 changes: 113 additions & 33 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,17 @@ Both the creation of model inputs and comparison of model outputs is handled by
### Full help text

```
usage: cog-safe-push [-h] [--test-hardware TEST_HARDWARE]
[--test-model TEST_MODEL] [--test-only]
[-i INPUTS] [-x DISABLED_INPUTS]
usage: cog-safe-push [-h] [--config CONFIG] [--help-config]
[--test-model TEST_MODEL] [--no-push]
[--test-hardware TEST_HARDWARE]
[--no-compare-outputs]
[--fuzz-seconds FUZZ_SECONDS]
[--no-fuzz-user-inputs] [-v]
model
[--predict-timeout PREDICT_TIMEOUT]
[--test-case TEST_CASES]
[--fuzz-fixed-inputs FUZZ_FIXED_INPUTS]
[--fuzz-disabled-inputs FUZZ_DISABLED_INPUTS]
[--fuzz-duration FUZZ_DURATION]
[--fuzz-iterations FUZZ_ITERATIONS] [-v]
[model]
Safely push a Cog model, with tests
Expand All @@ -46,41 +50,117 @@ positional arguments:
options:
-h, --help show this help message and exit
--test-hardware TEST_HARDWARE
Hardware to run the test model on. Only
used when creating the test model, if it
doesn't already exist.
--config CONFIG Path to the YAML config file. If --config is not
passed, ./cog-safe-push.yaml will be used, if it
exists. Any arguments you pass in will override
fields on the predict configuration stanza.
--help-config Print a default cog-safe-push.yaml config to
stdout.
--test-model TEST_MODEL
Replicate model to test on, in the format
<username>/<model-name>. If omitted,
<model>-test will be used. The test model
is created automatically if it doesn't
exist already
--test-only Only test the model, don't push it to
<model>
-i INPUTS, --input INPUTS
Input key-value pairs in the format
<key>=<value>. These will be used when
comparing outputs, as well as during
fuzzing (unless --no-fuzz-user-inputs is
specified)
-x DISABLED_INPUTS, --disable-input DISABLED_INPUTS
Don't pass values to these inputs when
comparing outputs or fuzzing
<model>-test will be used. The test model is
created automatically if it doesn't exist
already
--no-push Only test the model, don't push it to <model>
--test-hardware TEST_HARDWARE
Hardware to run the test model on. Only used
when creating the test model, if it doesn't
already exist.
--no-compare-outputs Don't make predictions to compare that
prediction outputs match
--fuzz-seconds FUZZ_SECONDS
Number of seconds to run fuzzing. Set to
0 for no fuzzing
--no-fuzz-user-inputs
Don't use -i/--input values when fuzzing,
instead use AI-generated values for every
input
prediction outputs match the current version
--predict-timeout PREDICT_TIMEOUT
Timeout (in seconds) for predictions. Default:
300
--test-case TEST_CASES
Inputs and expected output that will be used for
testing, you can provide multiple --test-case
options for multiple test cases. The first test
case will be used when comparing outputs to the
current version. Each --test-case is semicolon-
separated key-value pairs in the format
'<key1>=<value1>;<key2=value2>[<output-
checker>]'. <output-checker> can either be
'==<exact-string-or-url>' or '~=<ai-prompt>'. If
you use '==<exact-string-or-url>' then the
output of the model must match exactly the
string or url you specify. If you use '~=<ai-
prompt>' then the AI will verify your output
based on <ai-prompt>. If you omit <output-
checker>, it will just verify that the
prediction doesn't throw an error.
--fuzz-fixed-inputs FUZZ_FIXED_INPUTS
Inputs that should have fixed values during
fuzzing. All other non-disabled input values
will be generated by AI. If no test cases are
specified, these will also be used when
comparing outputs to the current version.
Semicolon-separated key-value pairs in the
format '<key1>=<value1>;<key2=value2>' (etc.)
--fuzz-disabled-inputs FUZZ_DISABLED_INPUTS
Don't pass values for these inputs during
fuzzing. Semicolon-separated keys in the format
'<key1>;<key2>' (etc.). If no test cases are
specified, these will also be disabled when
comparing outputs to the current version.
--fuzz-duration FUZZ_DURATION
Number of seconds to run fuzzing. Set to 0 for
no fuzzing. Default: 300
--fuzz-iterations FUZZ_ITERATIONS
Maximum number of iterations to run fuzzing.
Leave blank to run for the full --fuzz-seconds
-v, --verbose Increase verbosity level (max 3)
```

### Using a configuration file

You can use a configuration file instead of passing all arguments on the command line. If you create a file called `cog-safe-push.yaml` in your Cog directory, it will be used. Any command line arguments you pass will override the values in the config file.

```
$ cog-safe-push --help-config
model: <model>
predict:
compare_outputs: true
fuzz:
disabled_inputs: []
duration: 300
fixed_inputs: {}
predict_timeout: 300
test_cases:
- exact_string: <exact string match>
inputs:
<input1>: <value1>
- inputs:
<input2>: <value2>
match_url: <match output image against url>
- inputs:
<input3>: <value3>
match_prompt: <match output using AI prompt, e.g. 'an image of a cat'>
test_hardware: <hardware, e.g. cpu>
test_model: <test model, or empty to append '-test' to model>
train:
destination: <generated prediction model, e.g. andreasjansson/test-predict. leave
blank to append '-dest' to the test model>
destination_hardware: <hardware for the created prediction model, e.g. cpu>
fuzz:
disabled_inputs: []
duration: 300
fixed_inputs: {}
test_cases:
- exact_string: <exact string match>
inputs:
<input1>: <value1>
- inputs:
<input2>: <value2>
match_url: <match output image against url>
- inputs:
<input3>: <value3>
match_prompt: <match output using AI prompt, e.g. 'an image of a cat'>
train_timeout: 300
# values between < and > should be edited
```

## Nota bene

* If you can't figure out the right name to use for `--test-hardware`, create the test model manually (setting hardware in the UI), leave `--test-hardware` blank, and set `--test-model=<test-username>/<test-model-name>` instead
* This is alpha software. If you find a bug, please open an issue!
37 changes: 32 additions & 5 deletions cog_safe_push/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,28 @@
import json
import mimetypes
import os
import subprocess
from pathlib import Path
from typing import cast

import anthropic

from . import log
from .exceptions import AIError
from .exceptions import AIError, ArgumentError
from .retry import retry


@retry(3)
def boolean(prompt: str, files: list[Path] | None = None) -> bool:
def boolean(
prompt: str, files: list[Path] | None = None, include_file_metadata: bool = False
) -> bool:
system_prompt = "You only answer YES or NO, and absolutely nothing else. Your outputs will be used in a programmatic context so it's important that you only ever answer with either the string YES or the string NO."
output = call(system_prompt=system_prompt, prompt=prompt.strip(), files=files)
output = call(
system_prompt=system_prompt,
prompt=prompt.strip(),
files=files,
include_file_metadata=include_file_metadata,
)
if output == "YES":
return True
if output == "NO":
Expand All @@ -33,17 +41,29 @@ def json_object(prompt: str, files: list[Path] | None = None) -> dict:
raise AIError(f"Failed to parse output as JSON: {output}")


def call(system_prompt: str, prompt: str, files: list[Path] | None = None) -> str:
def call(
system_prompt: str,
prompt: str,
files: list[Path] | None = None,
include_file_metadata: bool = False,
) -> str:
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
raise ValueError("ANTHROPIC_API_KEY is not defined")
raise ArgumentError("ANTHROPIC_API_KEY is not defined")

model = "claude-3-5-sonnet-20240620"
client = anthropic.Anthropic(api_key=api_key)

if files:
content = create_content_list(files)

if include_file_metadata:
prompt += "\n\nMetadata about attached files:\n"
for i, path in enumerate(files):
prompt += f"{i}) " + file_info(path) + "\n"

content.append({"type": "text", "text": prompt})

log.vvv(f"Claude prompt with {len(files)} files: {prompt}")
else:
content = prompt
Expand Down Expand Up @@ -91,3 +111,10 @@ def create_content_list(
)

return content


def file_info(p: Path) -> str:
result = subprocess.run(
["file", "-b", str(p)], capture_output=True, text=True, check=True
)
return result.stdout.strip()
85 changes: 85 additions & 0 deletions cog_safe_push/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import argparse

from pydantic import BaseModel, model_validator

from .exceptions import ArgumentError

DEFAULT_PREDICT_TIMEOUT = 300
DEFAULT_FUZZ_DURATION = 300

InputScalar = bool | int | float | str | list[int] | list[str] | list[float]


class TestCase(BaseModel):
inputs: dict[str, InputScalar]
exact_string: str | None = None
match_url: str | None = None
match_prompt: str | None = None

@model_validator(mode="after")
def check_mutually_exclusive(self):
set_fields = sum(
getattr(self, field) is not None
for field in ["exact_string", "match_url", "match_prompt"]
)
if set_fields > 1:
raise ArgumentError(
"At most one of 'exact_string', 'match_url', or 'match_prompt' must be set"
)
return self


class FuzzConfig(BaseModel):
fixed_inputs: dict[str, InputScalar] = {}
disabled_inputs: list[str] = []
duration: int = DEFAULT_FUZZ_DURATION
iterations: int | None = None


class PredictConfig(BaseModel):
compare_outputs: bool = True
predict_timeout: int = DEFAULT_PREDICT_TIMEOUT
test_cases: list[TestCase] = []
fuzz: FuzzConfig | None = None


class TrainConfig(BaseModel):
destination: str | None = None
destination_hardware: str = "cpu"
train_timeout: int = DEFAULT_PREDICT_TIMEOUT
test_cases: list[TestCase] = []
fuzz: FuzzConfig | None = None


class Config(BaseModel):
model: str
test_model: str | None = None
test_hardware: str = "cpu"
predict: PredictConfig | None = None
train: TrainConfig | None = None

def override(self, field: str, args: argparse.Namespace, arg: str):
if hasattr(args, arg) and getattr(args, arg) is not None:
setattr(self, field, getattr(args, arg))

def predict_override(self, field: str, args: argparse.Namespace, arg: str):
if not hasattr(args, arg):
return
if not self.predict:
raise ArgumentError(
f"--config is used but is missing a predict section and you are overriding predict {field} in the command line arguments."
)
setattr(self.predict, field, getattr(args, arg))

def predict_fuzz_override(self, field: str, args: argparse.Namespace, arg: str):
if not hasattr(args, arg):
return
if not self.predict:
raise ArgumentError(
f"--config is used but is missing a predict section and you are overriding fuzz {field} in the command line arguments."
)
if not self.predict.fuzz:
raise ArgumentError(
f"--config is used but is missing a predict.fuzz section and you are overriding fuzz {field} in the command line arguments."
)
setattr(self.predict.fuzz, field, getattr(args, arg))
26 changes: 19 additions & 7 deletions cog_safe_push/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,40 @@
class CodeLintError(Exception):
class CogSafePushError(Exception):
pass


class SchemaLintError(Exception):
class ArgumentError(CogSafePushError):
pass


class IncompatibleSchemaError(Exception):
class CodeLintError(CogSafePushError):
pass


class OutputsDontMatchError(Exception):
class SchemaLintError(CogSafePushError):
pass


class FuzzError(Exception):
class IncompatibleSchemaError(CogSafePushError):
pass


class PredictionTimeoutError(Exception):
class OutputsDontMatchError(CogSafePushError):
pass


class PredictionFailedError(Exception):
class FuzzError(CogSafePushError):
pass


class PredictionTimeoutError(CogSafePushError):
pass


class PredictionFailedError(CogSafePushError):
pass


class TestCaseFailedError(CogSafePushError):
pass


Expand Down
Loading

0 comments on commit 895e39c

Please sign in to comment.