Skip to content

Commit

Permalink
Store Exception for on headers and onbody callback (#523)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael Graeb <[email protected]>
  • Loading branch information
waahm7 and graebm authored Nov 10, 2023
1 parent cc5af86 commit 185872e
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 16 deletions.
43 changes: 29 additions & 14 deletions awscrt/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,8 @@ def __init__(
on_done=None,
on_progress=None):

# Stores exception raised in on_headers or on_body callback so that we can rethrow it in the on_done callback
self._python_callback_exception = None
self._request = request
self._signing_config = signing_config
self._credential_provider = credential_provider
Expand All @@ -529,11 +531,21 @@ def __init__(

def _on_headers(self, status_code, headers):
if self._on_headers_cb:
self._on_headers_cb(status_code=status_code, headers=headers)
try:
self._on_headers_cb(status_code=status_code, headers=headers)
return True
except BaseException as e:
self._python_callback_exception = e
return False

def _on_body(self, chunk, offset):
if self._on_body_cb:
self._on_body_cb(chunk=chunk, offset=offset)
try:
self._on_body_cb(chunk=chunk, offset=offset)
return True
except BaseException as e:
self._python_callback_exception = e
return False

def _on_shutdown(self):
self._shutdown_event.set()
Expand All @@ -547,18 +559,21 @@ def _on_finish(self, error_code, status_code, error_headers, error_body):
if error_code:
error = awscrt.exceptions.from_code(error_code)

# If the failure was due to a response, make it into an S3ResponseError.
# When failure is due to a response, its headers are always included.
if isinstance(error, awscrt.exceptions.AwsCrtError) \
and status_code is not None \
and error_headers is not None:
error = S3ResponseError(
code=error.code,
name=error.name,
message=error.message,
status_code=status_code,
headers=error_headers,
body=error_body)
if isinstance(error, awscrt.exceptions.AwsCrtError):
if (error.name == "AWS_ERROR_CRT_CALLBACK_EXCEPTION"
and self._python_callback_exception is not None):
error = self._python_callback_exception
# If the failure was due to a response, make it into an S3ResponseError.
# When failure is due to a response, its headers are always included.
elif status_code is not None \
and error_headers is not None:
error = S3ResponseError(
code=error.code,
name=error.name,
message=error.message,
status_code=status_code,
headers=error_headers,
body=error_body)
self._finished_future.set_exception(error)
else:
self._finished_future.set_result(None)
Expand Down
8 changes: 6 additions & 2 deletions source/s3_meta_request.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,10 @@ static int s_s3_request_on_headers(
PyErr_WriteUnraisable(request_binding->py_core);
goto done;
}
/* If user's callback raises an exception, _S3RequestCore._on_headers
* stores it to throw later and returns False */
error = (result == Py_False);
Py_DECREF(result);
error = false;
done:
Py_XDECREF(header_list);
PyGILState_Release(state);
Expand Down Expand Up @@ -177,8 +179,10 @@ static int s_s3_request_on_body(
PyErr_WriteUnraisable(request_binding->py_core);
goto done;
}
/* If user's callback raises an exception, _S3RequestCore._on_body
* stores it to throw later and returns False */
error = (result == Py_False);
Py_DECREF(result);
error = false;
done:
PyGILState_Release(state);
/*************** GIL RELEASE ***************/
Expand Down
46 changes: 46 additions & 0 deletions test/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,52 @@ def test_multipart_upload_with_invalid_request(self):

put_body_stream.close()

def test_on_headers_callback_failure(self):
def _explode(**kwargs):
raise RuntimeError("Error in on_headers callback")

request = self._get_object_request(self.get_test_object_path)
s3_client = s3_client_new(False, self.region, 5 * MB)
s3_request = s3_client.make_request(
request=request,
type=S3RequestType.GET_OBJECT,
on_headers=_explode,
on_body=self._on_request_body,
)

finished_future = s3_request.finished_future
shutdown_event = s3_request.shutdown_event
s3_request = None
self.assertTrue(shutdown_event.wait(self.timeout))

e = finished_future.exception()
# check that data from on_done callback came through correctly
self.assertIsInstance(e, RuntimeError)
self.assertEqual(str(e), "Error in on_headers callback")

def test_on_body_callback_failure(self):
def _explode(**kwargs):
raise RuntimeError("Error in on_body callback")

request = self._get_object_request(self.get_test_object_path)
s3_client = s3_client_new(False, self.region, 5 * MB)
s3_request = s3_client.make_request(
request=request,
type=S3RequestType.GET_OBJECT,
on_headers=self._on_request_headers,
on_body=_explode,
)

finished_future = s3_request.finished_future
shutdown_event = s3_request.shutdown_event
s3_request = None
self.assertTrue(shutdown_event.wait(self.timeout))

e = finished_future.exception()
# check that data from on_done callback came through correctly
self.assertIsInstance(e, RuntimeError)
self.assertEqual(str(e), "Error in on_body callback")

def test_special_filepath_upload(self):
# remove the input file when request done
content_length = 10 * MB
Expand Down

0 comments on commit 185872e

Please sign in to comment.