diff --git a/awscrt/s3.py b/awscrt/s3.py index f0af97490..7509d14da 100644 --- a/awscrt/s3.py +++ b/awscrt/s3.py @@ -515,6 +515,8 @@ def __init__( on_done=None, on_progress=None): + # Stores exception raised in on_headers or on_body callback so that we can rethrow it in the on_done callback + self._python_callback_exception = None self._request = request self._signing_config = signing_config self._credential_provider = credential_provider @@ -529,11 +531,21 @@ def __init__( def _on_headers(self, status_code, headers): if self._on_headers_cb: - self._on_headers_cb(status_code=status_code, headers=headers) + try: + self._on_headers_cb(status_code=status_code, headers=headers) + return True + except BaseException as e: + self._python_callback_exception = e + return False def _on_body(self, chunk, offset): if self._on_body_cb: - self._on_body_cb(chunk=chunk, offset=offset) + try: + self._on_body_cb(chunk=chunk, offset=offset) + return True + except BaseException as e: + self._python_callback_exception = e + return False def _on_shutdown(self): self._shutdown_event.set() @@ -547,18 +559,21 @@ def _on_finish(self, error_code, status_code, error_headers, error_body): if error_code: error = awscrt.exceptions.from_code(error_code) - # If the failure was due to a response, make it into an S3ResponseError. - # When failure is due to a response, its headers are always included. - if isinstance(error, awscrt.exceptions.AwsCrtError) \ - and status_code is not None \ - and error_headers is not None: - error = S3ResponseError( - code=error.code, - name=error.name, - message=error.message, - status_code=status_code, - headers=error_headers, - body=error_body) + if isinstance(error, awscrt.exceptions.AwsCrtError): + if (error.name == "AWS_ERROR_CRT_CALLBACK_EXCEPTION" + and self._python_callback_exception is not None): + error = self._python_callback_exception + # If the failure was due to a response, make it into an S3ResponseError. + # When failure is due to a response, its headers are always included. + elif status_code is not None \ + and error_headers is not None: + error = S3ResponseError( + code=error.code, + name=error.name, + message=error.message, + status_code=status_code, + headers=error_headers, + body=error_body) self._finished_future.set_exception(error) else: self._finished_future.set_result(None) diff --git a/source/s3_meta_request.c b/source/s3_meta_request.c index aacbd4cfb..0907a075c 100644 --- a/source/s3_meta_request.c +++ b/source/s3_meta_request.c @@ -104,8 +104,10 @@ static int s_s3_request_on_headers( PyErr_WriteUnraisable(request_binding->py_core); goto done; } + /* If user's callback raises an exception, _S3RequestCore._on_headers + * stores it to throw later and returns False */ + error = (result == Py_False); Py_DECREF(result); - error = false; done: Py_XDECREF(header_list); PyGILState_Release(state); @@ -177,8 +179,10 @@ static int s_s3_request_on_body( PyErr_WriteUnraisable(request_binding->py_core); goto done; } + /* If user's callback raises an exception, _S3RequestCore._on_body + * stores it to throw later and returns False */ + error = (result == Py_False); Py_DECREF(result); - error = false; done: PyGILState_Release(state); /*************** GIL RELEASE ***************/ diff --git a/test/test_s3.py b/test/test_s3.py index c7e7b3e8c..e3cb391a6 100644 --- a/test/test_s3.py +++ b/test/test_s3.py @@ -617,6 +617,52 @@ def test_multipart_upload_with_invalid_request(self): put_body_stream.close() + def test_on_headers_callback_failure(self): + def _explode(**kwargs): + raise RuntimeError("Error in on_headers callback") + + request = self._get_object_request(self.get_test_object_path) + s3_client = s3_client_new(False, self.region, 5 * MB) + s3_request = s3_client.make_request( + request=request, + type=S3RequestType.GET_OBJECT, + on_headers=_explode, + on_body=self._on_request_body, + ) + + finished_future = s3_request.finished_future + shutdown_event = s3_request.shutdown_event + s3_request = None + self.assertTrue(shutdown_event.wait(self.timeout)) + + e = finished_future.exception() + # check that data from on_done callback came through correctly + self.assertIsInstance(e, RuntimeError) + self.assertEqual(str(e), "Error in on_headers callback") + + def test_on_body_callback_failure(self): + def _explode(**kwargs): + raise RuntimeError("Error in on_body callback") + + request = self._get_object_request(self.get_test_object_path) + s3_client = s3_client_new(False, self.region, 5 * MB) + s3_request = s3_client.make_request( + request=request, + type=S3RequestType.GET_OBJECT, + on_headers=self._on_request_headers, + on_body=_explode, + ) + + finished_future = s3_request.finished_future + shutdown_event = s3_request.shutdown_event + s3_request = None + self.assertTrue(shutdown_event.wait(self.timeout)) + + e = finished_future.exception() + # check that data from on_done callback came through correctly + self.assertIsInstance(e, RuntimeError) + self.assertEqual(str(e), "Error in on_body callback") + def test_special_filepath_upload(self): # remove the input file when request done content_length = 10 * MB