Skip to content

Commit

Permalink
Fix a couple of bugs in the base64 file_encoding_strategy (#398)
Browse files Browse the repository at this point in the history
This commit adds tests for the `file_encoding_strategy` argument for
`replicate.run()` and fixes two bugs that surfaced:

1. `replicate.run()` would convert the file provided into base64 encoded
data but not a valid data URL. We now use the `base64_encode_file`
function used for outputs.
2. `replicate.async_run()` accepted but did not use the
`file_encoding_strategy` flag at all. This is fixed, though it is worth
noting that `base64_encode_file` is not optimized for async workflows
and will block. This might be okay as the file sizes expected for data
URL payloads should be very small.
  • Loading branch information
aron authored Nov 15, 2024
1 parent 4fdd78f commit 07c8fbb
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 3 deletions.
10 changes: 7 additions & 3 deletions replicate/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def encode_json(
return encode_json(file, client, file_encoding_strategy)
if isinstance(obj, io.IOBase):
if file_encoding_strategy == "base64":
return base64.b64encode(obj.read()).decode("utf-8")
return base64_encode_file(obj)
else:
return client.files.create(obj).urls["get"]
if HAS_NUMPY:
Expand Down Expand Up @@ -77,9 +77,13 @@ async def async_encode_json(
]
if isinstance(obj, Path):
with obj.open("rb") as file:
return encode_json(file, client, file_encoding_strategy)
return await async_encode_json(file, client, file_encoding_strategy)
if isinstance(obj, io.IOBase):
return (await client.files.async_create(obj)).urls["get"]
if file_encoding_strategy == "base64":
# TODO: This should ideally use an async based file reader path.
return base64_encode_file(obj)
else:
return (await client.files.async_create(obj)).urls["get"]
if HAS_NUMPY:
if isinstance(obj, np.integer): # type: ignore
return int(obj)
Expand Down
129 changes: 129 additions & 0 deletions tests/test_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import asyncio
import io
import json
import sys
from email.message import EmailMessage
from email.parser import BytesParser
from email.policy import HTTP
from typing import AsyncIterator, Iterator, Optional, cast

import httpx
Expand Down Expand Up @@ -581,6 +586,130 @@ async def test_run_with_model_error(mock_replicate_api_token):
assert excinfo.value.prediction.status == "failed"


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_run_with_file_input_files_api(async_flag, mock_replicate_api_token):
router = respx.Router(base_url="https://api.replicate.com/v1")
mock_predictions_create = router.route(method="POST", path="/predictions").mock(
return_value=httpx.Response(
201,
json=_prediction_with_status("processing"),
)
)
router.route(
method="GET",
path="/models/test/example/versions/v1",
).mock(
return_value=httpx.Response(
200,
json=_version_with_schema(),
)
)
mock_files_create = router.route(
method="POST",
path="/files",
).mock(
return_value=httpx.Response(
200,
json={
"id": "file1",
"name": "file.png",
"content_type": "image/png",
"size": 10,
"etag": "123",
"checksums": {},
"metadata": {},
"created_at": "",
"expires_at": "",
"urls": {"get": "https://api.replicate.com/files/file.txt"},
},
)
)
router.route(host="api.replicate.com").pass_through()

client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)
if async_flag:
await client.async_run(
"test/example:v1",
input={"file": io.BytesIO(initial_bytes=b"hello world")},
)
else:
client.run(
"test/example:v1",
input={"file": io.BytesIO(initial_bytes=b"hello world")},
)

assert mock_predictions_create.called
prediction_payload = json.loads(mock_predictions_create.calls[0].request.content)
assert (
prediction_payload.get("input", {}).get("file")
== "https://api.replicate.com/files/file.txt"
)

# Validate the Files API request
req = mock_files_create.calls[0].request
body = req.content
content_type = req.headers["Content-Type"]

# Parse the multipart data
parser = BytesParser(EmailMessage, policy=HTTP)
headers = f"Content-Type: {content_type}\n\n".encode()
parsed_message_generator = parser.parsebytes(headers + body).walk()
next(parsed_message_generator) # wrapper
input_file = next(parsed_message_generator)
assert mock_files_create.called
assert input_file.get_content() == b"hello world"
assert input_file.get_content_type() == "application/octet-stream"


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_run_with_file_input_data_url(async_flag, mock_replicate_api_token):
router = respx.Router(base_url="https://api.replicate.com/v1")
mock_predictions_create = router.route(method="POST", path="/predictions").mock(
return_value=httpx.Response(
201,
json=_prediction_with_status("processing"),
)
)
router.route(
method="GET",
path="/models/test/example/versions/v1",
).mock(
return_value=httpx.Response(
200,
json=_version_with_schema(),
)
)
router.route(host="api.replicate.com").pass_through()

client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)

if async_flag:
await client.async_run(
"test/example:v1",
input={"file": io.BytesIO(initial_bytes=b"hello world")},
file_encoding_strategy="base64",
)
else:
client.run(
"test/example:v1",
input={"file": io.BytesIO(initial_bytes=b"hello world")},
file_encoding_strategy="base64",
)

assert mock_predictions_create.called
prediction_payload = json.loads(mock_predictions_create.calls[0].request.content)
assert (
prediction_payload.get("input", {}).get("file")
== "data:application/octet-stream;base64,aGVsbG8gd29ybGQ="
)


@pytest.mark.asyncio
async def test_run_with_file_output(mock_replicate_api_token):
router = respx.Router(base_url="https://api.replicate.com/v1")
Expand Down

0 comments on commit 07c8fbb

Please sign in to comment.