Skip to content

Commit

Permalink
Merge pull request #12 from replicate/tasks
Browse files Browse the repository at this point in the history
Tasks
  • Loading branch information
andreasjansson authored Dec 1, 2024
2 parents 576ff18 + 453c283 commit bdd264e
Show file tree
Hide file tree
Showing 24 changed files with 816 additions and 616 deletions.
54 changes: 54 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,57 @@ jobs:
- name: Run pytest
run: |
./script/unit-test
integration-test:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'

- name: Install dependencies
run: |
pip install -r requirements-test.txt
pip install .
- name: Run pytest
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
run: |
./script/integration-test
end-to-end-test:
runs-on: ubuntu-latest-4-cores

steps:
- uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'

- name: Install dependencies
run: |
pip install -r requirements-test.txt
pip install .
- name: Install Cog
run: |
sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)"
sudo chmod +x /usr/local/bin/cog
- name: cog login
run: |
echo ${{ secrets.COG_TOKEN }} | cog login --token-stdin
- name: Run pytest
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
run: |
./script/end-to-end-test
104 changes: 65 additions & 39 deletions cog_safe_push/ai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import functools
import json
import mimetypes
import os
Expand All @@ -10,16 +11,36 @@

from . import log
from .exceptions import AIError, ArgumentError
from .retry import retry


@retry(3)
def boolean(
def async_retry(attempts=3):
def decorator_retry(func):
@functools.wraps(func)
async def wrapper_retry(*args, **kwargs):
for attempt in range(1, attempts + 1):
try:
return await func(*args, **kwargs)
except Exception as e:
log.warning(f"Exception occurred: {e}")
if attempt < attempts:
log.warning(f"Retrying attempt {attempt}/{attempts}")
else:
log.warning(f"Giving up after {attempts} attempts")
raise
return None

return wrapper_retry

return decorator_retry


@async_retry(3)
async def boolean(
prompt: str, files: list[Path] | None = None, include_file_metadata: bool = False
) -> bool:
system_prompt = "You only answer YES or NO, and absolutely nothing else. Your response will be used in a programmatic context so it's important that you only ever answer with either the string YES or the string NO."
#system_prompt = "You are a helpful assistant"
output = call(
# system_prompt = "You are a helpful assistant"
output = await call(
system_prompt=system_prompt,
prompt=prompt.strip(),
files=files,
Expand All @@ -32,17 +53,17 @@ def boolean(
raise AIError(f"Failed to parse output as YES/NO: {output}")


@retry(3)
def json_object(prompt: str, files: list[Path] | None = None) -> dict:
@async_retry(3)
async def json_object(prompt: str, files: list[Path] | None = None) -> dict:
system_prompt = "You always respond with valid JSON, and nothing else (no backticks, etc.). Your outputs will be used in a programmatic context."
output = call(system_prompt=system_prompt, prompt=prompt.strip(), files=files)
output = await call(system_prompt=system_prompt, prompt=prompt.strip(), files=files)
try:
return json.loads(output)
except json.JSONDecodeError:
raise AIError(f"Failed to parse output as JSON: {output}")


def call(
async def call(
system_prompt: str,
prompt: str,
files: list[Path] | None = None,
Expand All @@ -53,36 +74,41 @@ def call(
raise ArgumentError("ANTHROPIC_API_KEY is not defined")

model = "claude-3-5-sonnet-20241022"
client = anthropic.Anthropic(api_key=api_key)

if files:
content = create_content_list(files)

if include_file_metadata:
prompt += "\n\nMetadata for the attached file(s):\n"
for path in files:
prompt += f"* " + file_info(path) + "\n"

content.append({"type": "text", "text": prompt})

log.vvv(f"Claude prompt with {len(files)} files: {prompt}")
else:
content = prompt
log.vvv(f"Claude prompt: {prompt}")

messages: list[anthropic.types.MessageParam] = [
{"role": "user", "content": content}
]

response = client.messages.create(
model=model,
messages=messages,
system=system_prompt,
max_tokens=4096,
stream=False,
temperature=1.0,
)
content = cast(anthropic.types.TextBlock, response.content[0])
client = anthropic.AsyncAnthropic(api_key=api_key)

try:
if files:
content = create_content_list(files)

if include_file_metadata:
prompt += "\n\nMetadata for the attached file(s):\n"
for path in files:
prompt += "* " + file_info(path) + "\n"

content.append({"type": "text", "text": prompt})

log.vvv(f"Claude prompt with {len(files)} files: {prompt}")
else:
content = prompt
log.vvv(f"Claude prompt: {prompt}")

messages: list[anthropic.types.MessageParam] = [
{"role": "user", "content": content}
]

response = await client.messages.create(
model=model,
messages=messages,
system=system_prompt,
max_tokens=4096,
stream=False,
temperature=1.0,
)
content = cast(anthropic.types.TextBlock, response.content[0])

finally:
await client.close()

output = content.text
log.vvv(f"Claude response: {output}")
return output
Expand Down
4 changes: 2 additions & 2 deletions cog_safe_push/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ class FuzzConfig(BaseModel):

fixed_inputs: dict[str, InputScalar] = {}
disabled_inputs: list[str] = []
duration: int = DEFAULT_FUZZ_DURATION
iterations: int | None = None
iterations: int = 10


class PredictConfig(BaseModel):
Expand Down Expand Up @@ -68,6 +67,7 @@ class Config(BaseModel):
predict: PredictConfig | None = None
train: TrainConfig | None = None
dockerfile: str | None = None
parallel: int = 4

def override(self, field: str, args: argparse.Namespace, arg: str):
if hasattr(args, arg) and getattr(args, arg) is not None:
Expand Down
Loading

0 comments on commit bdd264e

Please sign in to comment.