Skip to content

Commit

Permalink
Improve match_prompt and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasjansson committed Nov 29, 2024
1 parent 06dac7f commit 576ff18
Show file tree
Hide file tree
Showing 34 changed files with 123 additions and 13 deletions.
13 changes: 7 additions & 6 deletions cog_safe_push/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
def boolean(
prompt: str, files: list[Path] | None = None, include_file_metadata: bool = False
) -> bool:
system_prompt = "You only answer YES or NO, and absolutely nothing else. Your outputs will be used in a programmatic context so it's important that you only ever answer with either the string YES or the string NO."
system_prompt = "You only answer YES or NO, and absolutely nothing else. Your response will be used in a programmatic context so it's important that you only ever answer with either the string YES or the string NO."
#system_prompt = "You are a helpful assistant"
output = call(
system_prompt=system_prompt,
prompt=prompt.strip(),
Expand Down Expand Up @@ -51,16 +52,16 @@ def call(
if not api_key:
raise ArgumentError("ANTHROPIC_API_KEY is not defined")

model = "claude-3-5-sonnet-20240620"
model = "claude-3-5-sonnet-20241022"
client = anthropic.Anthropic(api_key=api_key)

if files:
content = create_content_list(files)

if include_file_metadata:
prompt += "\n\nMetadata about attached files:\n"
for i, path in enumerate(files):
prompt += f"{i}) " + file_info(path) + "\n"
prompt += "\n\nMetadata for the attached file(s):\n"
for path in files:
prompt += f"* " + file_info(path) + "\n"

content.append({"type": "text", "text": prompt})

Expand All @@ -79,7 +80,7 @@ def call(
system=system_prompt,
max_tokens=4096,
stream=False,
temperature=0.9,
temperature=1.0,
)
content = cast(anthropic.types.TextBlock, response.content[0])
output = content.text
Expand Down
32 changes: 26 additions & 6 deletions cog_safe_push/match_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,44 @@ def output_matches_prompt(output: Any, prompt: str) -> tuple[bool, str]:
urls = output if isinstance(output, list) else list(output.values())

with download_many(urls) as tmp_files:
claude_prompt = "Does this output match the following description?"
claude_prompt = """You are part of an automatic evaluation that compares media (text, audio, image, video, etc.) to captions. I want to know if the caption matches the text or file..
"""
if urls:
claude_prompt = "Does this output or the contents of the output file(s) match the following description? The contents of the output file(s) are attached."
claude_prompt += f"""Does this file(s) and the attached content of the file(s) match the description? Pay close attention to the metadata about the attached files which is included below, especially if the description mentions file type, image dimensions, or any other aspect that is described in the metadata. Do not infer file type or image dimensions from the image content, but from the attached metadata.
Description to evaluate: {prompt}
claude_prompt += f"""
Filename(s): {output}"""
else:
claude_prompt += f"""Do these outputs match the following description?
Output: {output}
Description: {prompt}"""
Description to evaluate: {prompt}"""

matches = ai.boolean(
claude_prompt,
files=tmp_files,
include_file_metadata=True,
)

if matches:
return True, ""
if matches:
return True, ""

# If it's not a match, do best of three to avoid flaky tests
multiple_matches = [matches]
for _ in range(2):
matches = ai.boolean(
claude_prompt,
files=tmp_files,
include_file_metadata=True,
)
multiple_matches.append(matches)

if sum(multiple_matches) >= 2:
return True, ""

return False, "AI determined that the output does not match the description"


Expand Down
2 changes: 1 addition & 1 deletion cog_safe_push/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def make_predict_inputs(
Generate a json payload for the {input_name} schema.
If inputs have format=uri, you should use one of the following media URLs:
If inputs have format=uri, you should use one of the following media URLs (pick an appropriate URL for the the input, e.g. one of the image examples below if the input expects an image):
Videos:
* https://storage.googleapis.com/cog-safe-push-public/harry-truman.webm
* https://storage.googleapis.com/cog-safe-push-public/mariner-launch.ogv
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added integration-test/assets/images/negative/horse.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added integration-test/assets/images/positive/car.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
89 changes: 89 additions & 0 deletions integration-test/test_output_matches_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pytest
from pathlib import Path
from cog_safe_push.match_outputs import output_matches_prompt
from cog_safe_push import log


# log.set_verbosity(3)

positive_images = {
"https://replicate.delivery/xezq/DyBtXhblvL7MApRBqeqiYnkw1xS9WpEf3nA7GRIlYFkQL31TA/out-0.webp": [
"A bird",
"A red bird",
"A webp image of a bird",
"A webp image of a red bird",
],
"https://replicate.delivery/czjl/QFrZ9RF8VroFM5Ml9MKt3rm0vP8ZHTWaqfO1oT6bouj0m76JA/tmpn888w5a8.jpg": [
"A jpg image of a formula one car",
"a jpg image of a car",
"A jpg image",
"Formula 1 car",
"car",
],
"https://replicate.delivery/czjl/8C4OJCR6w7rQEFeernSerHH5e3xe2f9cYYsGTW8k5Eob57d9E/tmpjwitpu7f.png": [
"480x320px png image",
"480x320px image of a formula one car",
],
"https://replicate.delivery/czjl/41MrDvJli4ZCAxeYMhEcKvAHNNcPaWJTicjqp7GYNFza476JA/tmpzs4y7hto.png": [
"an anime illustration of a lake",
"an anime illustration",
"a lake",
],
}

negative_images = {
"https://replicate.delivery/xezq/DyBtXhblvL7MApRBqeqiYnkw1xS9WpEf3nA7GRIlYFkQL31TA/out-0.webp": [
"A cat",
"A blue bird",
"A png image of a bird",
"A webp image of a blue bird",
],
"https://replicate.delivery/czjl/QFrZ9RF8VroFM5Ml9MKt3rm0vP8ZHTWaqfO1oT6bouj0m76JA/tmpn888w5a8.jpg": [
"A jpg image of a tractor",
"a webp image of a road",
"A webp image",
"motorcycle",
],
"https://replicate.delivery/czjl/8C4OJCR6w7rQEFeernSerHH5e3xe2f9cYYsGTW8k5Eob57d9E/tmpjwitpu7f.png": [
"100x100px png image",
"100x100px image of a formula one car",
],
"https://replicate.delivery/czjl/41MrDvJli4ZCAxeYMhEcKvAHNNcPaWJTicjqp7GYNFza476JA/tmpzs4y7hto.png": [
"an anime illustration of a cat",
"a 3d render",
"a potato patch",
],
}


def get_captioned_images(
image_dict: dict[str, list[str]], iterations_per_image=3
) -> list[tuple[str, str]]:
ret = []
for url, captions in image_dict.items():
for _ in range(iterations_per_image):
for caption in captions:
ret.append((url, caption))
return ret


@pytest.mark.parametrize(
"file_url,prompt",
get_captioned_images(positive_images),
ids=lambda x: Path(x[0]).name if isinstance(x, tuple) else x,
)
def test_image_output_matches_prompt_positive(file_url: str, prompt: str):
"""Test that images in the positive directory match their prompts."""
matches, message = output_matches_prompt(file_url, prompt)
assert matches, f"Image should match prompt '{prompt}'. Error: {message}"


@pytest.mark.parametrize(
"file_url,prompt",
get_captioned_images(negative_images),
ids=lambda x: Path(x[0]).name if isinstance(x, tuple) else x,
)
def test_image_output_matches_prompt_negative(file_url: str, prompt: str):
"""Test that images in the negative directory don't match their prompts."""
matches, _ = output_matches_prompt(file_url, prompt)
assert not matches, f"Image should not match prompt '{prompt}'"

0 comments on commit 576ff18

Please sign in to comment.