diff --git a/source/s3_meta_request.c b/source/s3_meta_request.c index e471c64e5..c57efde20 100644 --- a/source/s3_meta_request.c +++ b/source/s3_meta_request.c @@ -30,8 +30,6 @@ struct s3_meta_request_binding { **/ FILE *recv_file; - struct aws_http_message *copied_message; - /* Batch up the transferred size in one sec. */ uint64_t size_transferred; /* The time stamp when the progress reported */ @@ -47,9 +45,6 @@ static void s_destroy(struct s3_meta_request_binding *meta_request) { if (meta_request->recv_file) { fclose(meta_request->recv_file); } - if (meta_request->copied_message) { - aws_http_message_release(meta_request->copied_message); - } Py_XDECREF(meta_request->py_core); aws_mem_release(aws_py_get_allocator(), meta_request); } @@ -122,6 +117,7 @@ static int s_s3_request_on_headers( } } +/* To avoid reporting progress to python too often. We cache it up and only report to python after at least 1 sec. */ static int s_record_progress(struct s3_meta_request_binding *request_binding, uint64_t length, bool *report_progress) { if (aws_add_u64_checked(request_binding->size_transferred, length, &request_binding->size_transferred)) { /* Wow */ @@ -151,10 +147,6 @@ static int s_s3_request_on_body( (void)meta_request; struct s3_meta_request_binding *request_binding = user_data; - bool report_progress; - if (s_record_progress(request_binding, (uint64_t)body->len, &report_progress)) { - return AWS_OP_ERR; - } if (request_binding->recv_file) { /* The callback will be invoked with the right order, so we don't need to seek first. */ if (fwrite((void *)body->ptr, body->len, 1, request_binding->recv_file) < 1) { @@ -168,9 +160,7 @@ static int s_s3_request_on_body( aws_error_name(aws_last_error())); return AWS_OP_ERR; } - if (!report_progress) { - return AWS_OP_SUCCESS; - } + return AWS_OP_SUCCESS; } bool error = true; /*************** GIL ACQUIRE ***************/ @@ -179,32 +169,15 @@ static int s_s3_request_on_body( if (aws_py_gilstate_ensure(&state)) { return AWS_OP_ERR; /* Python has shut down. Nothing matters anymore, but don't crash */ } - if (!request_binding->recv_file) { - result = PyObject_CallMethod( - request_binding->py_core, - "_on_body", - "(y#K)", - (const char *)(body->ptr), - (Py_ssize_t)body->len, - range_start); - if (!result) { - PyErr_WriteUnraisable(request_binding->py_core); - goto done; - } - Py_DECREF(result); - } - if (report_progress) { - /* Hold the GIL before enterring here */ - result = - PyObject_CallMethod(request_binding->py_core, "_on_progress", "(K)", request_binding->size_transferred); - if (!result) { - PyErr_WriteUnraisable(request_binding->py_core); - } else { - Py_DECREF(result); - } - request_binding->size_transferred = 0; + result = PyObject_CallMethod( + request_binding->py_core, "_on_body", "(y#K)", (const char *)(body->ptr), (Py_ssize_t)body->len, range_start); + + if (!result) { + PyErr_WriteUnraisable(request_binding->py_core); + goto done; } + Py_DECREF(result); error = false; done: PyGILState_Release(state); @@ -252,8 +225,6 @@ static void s_s3_request_on_finish( PyObject *header_list = NULL; PyObject *result = NULL; - request_binding->copied_message = aws_http_message_release(request_binding->copied_message); - if (request_binding->size_transferred && (error_code == 0)) { /* report the remaining progress */ result = @@ -343,39 +314,21 @@ static void s_s3_request_on_shutdown(void *user_data) { /*************** GIL RELEASE ***************/ } -/* - * file-based python input stream for reporting the progress - */ -struct aws_input_py_stream_file_impl { - struct aws_input_stream base; - struct aws_input_stream *actual_stream; - struct s3_meta_request_binding *binding; -}; - -static int s_aws_input_stream_file_read(struct aws_input_stream *stream, struct aws_byte_buf *dest) { - struct aws_input_py_stream_file_impl *impl = AWS_CONTAINER_OF(stream, struct aws_input_py_stream_file_impl, base); - size_t pre_len = dest->len; - - if (aws_input_stream_read(impl->actual_stream, dest)) { - return AWS_OP_ERR; - } +static void s_s3_request_on_progress( + struct aws_s3_meta_request *meta_request, + const struct aws_s3_meta_request_progress *progress, + void *user_data) { - size_t actually_read = 0; - if (aws_sub_size_checked(dest->len, pre_len, &actually_read)) { - return AWS_OP_ERR; - } + struct s3_meta_request_binding *request_binding = user_data; - bool report_progress; - struct s3_meta_request_binding *request_binding = impl->binding; - if (s_record_progress(request_binding, (uint64_t)actually_read, &report_progress)) { - return AWS_OP_ERR; - } + bool report_progress = false; + s_record_progress(request_binding, progress->bytes_transferred, &report_progress); if (report_progress) { /*************** GIL ACQUIRE ***************/ PyGILState_STATE state; if (aws_py_gilstate_ensure(&state)) { - return AWS_OP_ERR; /* Python has shut down. Nothing matters anymore, but don't crash */ + return; /* Python has shut down. Nothing matters anymore, but don't crash */ } PyObject *result = PyObject_CallMethod(request_binding->py_core, "_on_progress", "(K)", request_binding->size_transferred); @@ -385,113 +338,7 @@ static int s_aws_input_stream_file_read(struct aws_input_stream *stream, struct request_binding->size_transferred = 0; PyGILState_Release(state); /*************** GIL RELEASE ***************/ - if (!result) { - return aws_py_raise_error(); - } - } - return AWS_OP_SUCCESS; -} -static int s_aws_input_stream_file_seek( - struct aws_input_stream *stream, - int64_t offset, - enum aws_stream_seek_basis basis) { - struct aws_input_py_stream_file_impl *impl = AWS_CONTAINER_OF(stream, struct aws_input_py_stream_file_impl, base); - return aws_input_stream_seek(impl->actual_stream, offset, basis); -} - -static int s_aws_input_stream_file_get_status(struct aws_input_stream *stream, struct aws_stream_status *status) { - struct aws_input_py_stream_file_impl *impl = AWS_CONTAINER_OF(stream, struct aws_input_py_stream_file_impl, base); - return aws_input_stream_get_status(impl->actual_stream, status); -} - -static int s_aws_input_stream_file_get_length(struct aws_input_stream *stream, int64_t *length) { - struct aws_input_py_stream_file_impl *impl = AWS_CONTAINER_OF(stream, struct aws_input_py_stream_file_impl, base); - return aws_input_stream_get_length(impl->actual_stream, length); -} - -static void s_aws_input_stream_file_destroy(struct aws_input_py_stream_file_impl *impl) { - struct aws_allocator *allocator = aws_py_get_allocator(); - aws_input_stream_release(impl->actual_stream); - aws_mem_release(allocator, impl); -} - -static struct aws_input_stream_vtable s_aws_input_stream_file_vtable = { - .seek = s_aws_input_stream_file_seek, - .read = s_aws_input_stream_file_read, - .get_status = s_aws_input_stream_file_get_status, - .get_length = s_aws_input_stream_file_get_length, -}; - -static struct aws_input_stream *s_input_stream_new_from_file( - struct aws_allocator *allocator, - const char *file_name, - struct s3_meta_request_binding *request_binding) { - struct aws_input_py_stream_file_impl *impl = - aws_mem_calloc(allocator, 1, sizeof(struct aws_input_py_stream_file_impl)); - - impl->base.vtable = &s_aws_input_stream_file_vtable; - aws_ref_count_init(&impl->base.ref_count, impl, (aws_simple_completion_callback *)s_aws_input_stream_file_destroy); - - impl->actual_stream = aws_input_stream_new_from_file(allocator, file_name); - if (!impl->actual_stream) { - aws_mem_release(allocator, impl); - return NULL; - } - impl->binding = request_binding; - - return &impl->base; -} - -/* Copy an existing HTTP message without body. */ -struct aws_http_message *s_copy_http_message(struct aws_allocator *allocator, struct aws_http_message *base_message) { - AWS_PRECONDITION(allocator); - AWS_PRECONDITION(base_message); - - struct aws_http_message *message = aws_http_message_new_request(allocator); - - if (message == NULL) { - return NULL; } - - struct aws_byte_cursor request_method; - if (aws_http_message_get_request_method(base_message, &request_method)) { - goto error_clean_up; - } - - if (aws_http_message_set_request_method(message, request_method)) { - goto error_clean_up; - } - - struct aws_byte_cursor request_path; - if (aws_http_message_get_request_path(base_message, &request_path)) { - goto error_clean_up; - } - - if (aws_http_message_set_request_path(message, request_path)) { - goto error_clean_up; - } - - size_t num_headers = aws_http_message_get_header_count(base_message); - for (size_t header_index = 0; header_index < num_headers; ++header_index) { - struct aws_http_header header; - if (aws_http_message_get_header(base_message, &header, header_index)) { - goto error_clean_up; - } - if (aws_http_message_add_header(message, header)) { - goto error_clean_up; - } - } - - return message; - -error_clean_up: - - if (message != NULL) { - aws_http_message_release(message); - message = NULL; - } - - return NULL; } PyObject *aws_py_s3_client_make_meta_request(PyObject *self, PyObject *args) { @@ -579,37 +426,24 @@ PyObject *aws_py_s3_client_make_meta_request(PyObject *self, PyObject *args) { Py_INCREF(meta_request->py_core); if (recv_filepath) { - meta_request->recv_file = aws_fopen(recv_filepath, "wb+"); + meta_request->recv_file = aws_fopen(recv_filepath, "wb"); if (!meta_request->recv_file) { aws_translate_and_raise_io_error(errno); PyErr_SetAwsLastError(); goto error; } } - if (send_filepath) { - if (type == AWS_S3_META_REQUEST_TYPE_PUT_OBJECT) { - /* Copy the http request from python object and replace the old pointer with new pointer */ - meta_request->copied_message = s_copy_http_message(allocator, http_request); - struct aws_input_stream *input_body = s_input_stream_new_from_file(allocator, send_filepath, meta_request); - if (!input_body) { - PyErr_SetAwsLastError(); - goto error; - } - /* rewrite the input stream of the original request */ - aws_http_message_set_body_stream(meta_request->copied_message, input_body); - /* Input body is owned by copied message */ - aws_input_stream_release(input_body); - } - } struct aws_s3_meta_request_options s3_meta_request_opt = { .type = type, - .message = meta_request->copied_message ? meta_request->copied_message : http_request, + .message = http_request, .signing_config = signing_config, + .send_filepath = aws_byte_cursor_from_c_str(send_filepath), .headers_callback = s_s3_request_on_headers, .body_callback = s_s3_request_on_body, .finish_callback = s_s3_request_on_finish, .shutdown_callback = s_s3_request_on_shutdown, + .progress_callback = s_s3_request_on_progress, .user_data = meta_request, }; diff --git a/test/test_s3.py b/test/test_s3.py index 0d1a0c0ef..736a1ecf3 100644 --- a/test/test_s3.py +++ b/test/test_s3.py @@ -1,6 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0. +from io import BytesIO import unittest import os import tempfile @@ -147,8 +148,6 @@ def setUp(self): self.data_len = 0 self.progress_invoked = 0 - self.put_body_stream = None - self.files = FileCreator() self.temp_put_obj_file_path = self.files.create_file_with_size("temp_put_obj_10mb", 10 * MB) super().setUp() @@ -165,16 +164,15 @@ def _get_object_request(self, object_path): request = HttpRequest("GET", object_path, headers) return request - def _put_object_request(self, file_name, path=None): + def _put_object_request(self, input_stream, content_len, path=None, unknown_content_length=False): # if send file path is set, the body_stream of http request will be ignored (using file handler from C instead) - self.put_body_stream = open(file_name, "r+b") - file_stats = os.stat(file_name) - self.data_len = file_stats.st_size headers = HttpHeaders([("host", self._build_endpoint_string(self.region, self.bucket_name)), - ("Content-Type", "text/plain"), ("Content-Length", str(self.data_len))]) + ("Content-Type", "text/plain")]) + if unknown_content_length is False: + headers.add("Content-Length", str(content_len)) if path is None: path = self.put_test_object_path - request = HttpRequest("PUT", path, headers, self.put_body_stream) + request = HttpRequest("PUT", path, headers, input_stream) return request def _on_request_headers(self, status_code, headers, **kargs): @@ -187,12 +185,12 @@ def _on_request_body(self, chunk, offset, **kargs): def _on_progress(self, progress): self.transferred_len += progress - def _validate_successful_get_response(self, put_object): + def _validate_successful_response(self, is_put_object): self.assertEqual(self.response_status_code, 200, "status code is not 200") headers = HttpHeaders(self.response_headers) self.assertIsNone(headers.get("Content-Range")) body_length = headers.get("Content-Length") - if not put_object: + if not is_put_object: self.assertIsNotNone(body_length, "Content-Length is missing from headers") if body_length: self.assertEqual( @@ -200,12 +198,21 @@ def _validate_successful_get_response(self, put_object): self.received_body_len, "Received body length does not match the Content-Length header") - def _test_s3_put_get_object(self, request, request_type, exception_name=None): + def _test_s3_put_get_object( + self, + request, + request_type, + exception_name=None, + send_filepath=None, + recv_filepath=None): + s3_client = s3_client_new(False, self.region, 5 * MB) s3_request = s3_client.make_request( request=request, type=request_type, on_headers=self._on_request_headers, + send_filepath=send_filepath, + recv_filepath=recv_filepath, on_body=self._on_request_body) finished_future = s3_request.finished_future try: @@ -213,7 +220,7 @@ def _test_s3_put_get_object(self, request, request_type, exception_name=None): except Exception as e: self.assertEqual(e.name, exception_name) else: - self._validate_successful_get_response(request_type is S3RequestType.PUT_OBJECT) + self._validate_successful_response(request_type is S3RequestType.PUT_OBJECT) shutdown_event = s3_request.shutdown_event s3_request = None @@ -224,9 +231,25 @@ def test_get_object(self): self._test_s3_put_get_object(request, S3RequestType.GET_OBJECT) def test_put_object(self): - request = self._put_object_request(self.temp_put_obj_file_path) + put_body_stream = open(self.temp_put_obj_file_path, "rb") + content_length = os.stat(self.temp_put_obj_file_path).st_size + request = self._put_object_request(put_body_stream, content_length) + self._test_s3_put_get_object(request, S3RequestType.PUT_OBJECT) + put_body_stream.close() + + def test_put_object_unknown_content_length(self): + put_body_stream = open(self.temp_put_obj_file_path, "rb") + content_length = os.stat(self.temp_put_obj_file_path).st_size + request = self._put_object_request(put_body_stream, content_length, unknown_content_length=True) self._test_s3_put_get_object(request, S3RequestType.PUT_OBJECT) - self.put_body_stream.close() + put_body_stream.close() + + def test_put_object_unknown_content_length_single_part(self): + data_bytes = "test crt python single part upload".encode(encoding='utf-8') + put_body_stream = BytesIO(data_bytes) + request = self._put_object_request(put_body_stream, len(data_bytes), unknown_content_length=True) + self._test_s3_put_get_object(request, S3RequestType.PUT_OBJECT) + put_body_stream.close() def test_put_object_multiple_times(self): s3_client = s3_client_new(False, self.region, 5 * MB) @@ -234,8 +257,8 @@ def test_put_object_multiple_times(self): for i in range(3): tempfile = self.files.create_file_with_size("temp_file_{}".format(str(i)), 10 * MB) path = "/put_object_test_py_10MB_{}.txt".format(str(i)) - request = self._put_object_request(tempfile, path) - self.put_body_stream.close() + content_length = os.stat(tempfile).st_size + request = self._put_object_request(None, content_length, path=path) s3_request = s3_client.make_request( request=request, type=S3RequestType.PUT_OBJECT, @@ -255,9 +278,8 @@ def test_put_object_multiple_times(self): client_shutdown_event = s3_client.shutdown_event del s3_client self.assertTrue(client_shutdown_event.wait(self.timeout)) - self.put_body_stream.close() - def test_get_object_file_object(self): + def test_get_object_filepath(self): request = self._get_object_request(self.get_test_object_path) request_type = S3RequestType.GET_OBJECT s3_client = s3_client_new(False, self.region, 5 * MB) @@ -295,33 +317,21 @@ def test_get_object_file_object(self): # TODO verify the content of written file os.remove(file.name) - def test_put_object_file_object(self): - request = self._put_object_request(self.temp_put_obj_file_path) - request_type = S3RequestType.PUT_OBJECT - # close the stream, to test if the C FILE pointer as the input stream working well. - self.put_body_stream.close() - s3_client = s3_client_new(False, self.region, 5 * MB) - s3_request = s3_client.make_request( - request=request, - type=request_type, - send_filepath=self.temp_put_obj_file_path, - on_headers=self._on_request_headers, - on_progress=self._on_progress) - finished_future = s3_request.finished_future - finished_future.result(self.timeout) + def test_put_object_filepath(self): + content_length = os.stat(self.temp_put_obj_file_path).st_size + request = self._put_object_request(None, content_length) + self._test_s3_put_get_object(request, S3RequestType.PUT_OBJECT, send_filepath=self.temp_put_obj_file_path) - # check result - self.assertEqual( - self.data_len, - self.transferred_len, - "the transferred length reported does not match body we sent") - self._validate_successful_get_response(request_type is S3RequestType.PUT_OBJECT) + def test_put_object_filepath_unknown_content_length(self): + content_length = os.stat(self.temp_put_obj_file_path).st_size + request = self._put_object_request(None, content_length, unknown_content_length=True) + self._test_s3_put_get_object(request, S3RequestType.PUT_OBJECT, send_filepath=self.temp_put_obj_file_path) - def test_put_object_file_object_move(self): + def test_put_object_filepath_move(self): # remove the input file when request done tempfile = self.files.create_file_with_size("temp_file", 10 * MB) - request = self._put_object_request(tempfile) - self.put_body_stream.close() + content_length = os.stat(tempfile).st_size + request = self._put_object_request(None, content_length) s3_client = s3_client_new(False, self.region, 5 * MB) request_type = S3RequestType.PUT_OBJECT done_future = Future() @@ -341,10 +351,10 @@ def on_done_remove_file(**kwargs): # check result self.assertEqual( - self.data_len, + content_length, self.transferred_len, "the transferred length reported does not match body we sent") - self._validate_successful_get_response(request_type is S3RequestType.PUT_OBJECT) + self._validate_successful_response(request_type is S3RequestType.PUT_OBJECT) def _on_progress_cancel_after_first_chunk(self, progress): self.transferred_len += progress @@ -442,17 +452,20 @@ def test_put_object_quick_cancel(self): return self._put_object_cancel_helper(False) def test_multipart_upload_with_invalid_request(self): - request = self._put_object_request(self.temp_put_obj_file_path) + put_body_stream = open(self.temp_put_obj_file_path, "r+b") + content_length = os.stat(self.temp_put_obj_file_path).st_size + request = self._put_object_request(put_body_stream, content_length) request.headers.set("Content-MD5", "something") self._test_s3_put_get_object(request, S3RequestType.PUT_OBJECT, "AWS_ERROR_S3_INVALID_RESPONSE_STATUS") - self.put_body_stream.close() + put_body_stream.close() def test_special_filepath_upload(self): # remove the input file when request done with open(self.special_path, 'wb') as file: file.write(b"a" * 10 * MB) - request = self._put_object_request(self.special_path) - self.put_body_stream.close() + + content_length = os.stat(self.special_path).st_size + request = self._put_object_request(None, content_length) s3_client = s3_client_new(False, self.region, 5 * MB) request_type = S3RequestType.PUT_OBJECT @@ -485,18 +498,19 @@ def test_special_filepath_upload(self): # check result self.assertEqual( - self.data_len, + content_length, self.transferred_len, "the transferred length reported does not match body we sent") - self._validate_successful_get_response(request_type is S3RequestType.PUT_OBJECT) + self._validate_successful_response(request_type is S3RequestType.PUT_OBJECT) os.remove(self.special_path) def test_non_ascii_filepath_upload(self): # remove the input file when request done with open(self.non_ascii_file_name, 'wb') as file: file.write(b"a" * 10 * MB) - request = self._put_object_request(self.non_ascii_file_name) - self.put_body_stream.close() + + content_length = os.stat(self.non_ascii_file_name).st_size + request = self._put_object_request(None, content_length) s3_client = s3_client_new(False, self.region, 5 * MB) request_type = S3RequestType.PUT_OBJECT @@ -511,10 +525,10 @@ def test_non_ascii_filepath_upload(self): # check result self.assertEqual( - self.data_len, + content_length, self.transferred_len, "the transferred length reported does not match body we sent") - self._validate_successful_get_response(request_type is S3RequestType.PUT_OBJECT) + self._validate_successful_response(request_type is S3RequestType.PUT_OBJECT) os.remove(self.non_ascii_file_name) def test_non_ascii_filepath_download(self):