From 3a30ba8b4453433f5a12deadd30c8a878a189e5a Mon Sep 17 00:00:00 2001 From: Dengke Tang Date: Thu, 7 Sep 2023 15:06:46 -0700 Subject: [PATCH] Fix: Get object with checksum leak when retry happens (#346) Co-authored-by: Michael Graeb --- .github/workflows/ci.yml | 3 +- source/s3_meta_request.c | 6 +- tests/CMakeLists.txt | 1 + .../GetObject/get_object_checksum_retry.json | 12 + tests/mock_s3_server/mock_s3_server.py | 293 +++++++++++------- tests/s3_mock_server_tests.c | 44 +++ 6 files changed, 235 insertions(+), 124 deletions(-) create mode 100644 tests/mock_s3_server/GetObject/get_object_checksum_retry.json diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 01119bda9..338ba3ed0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,7 +6,7 @@ on: - 'main' env: - BUILDER_VERSION: v0.9.43 + BUILDER_VERSION: v0.9.48 BUILDER_SOURCE: releases BUILDER_HOST: https://d19elf31gohf1l.cloudfront.net PACKAGE_NAME: aws-c-s3 @@ -16,6 +16,7 @@ env: AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }} AWS_REGION: us-east-1 + CTEST_PARALLEL_LEVEL: 2 jobs: linux-compat: diff --git a/source/s3_meta_request.c b/source/s3_meta_request.c index 9b9761116..f2216c59f 100644 --- a/source/s3_meta_request.c +++ b/source/s3_meta_request.c @@ -982,12 +982,12 @@ static void s_get_response_part_finish_checksum_helper(struct aws_s3_connection request->validation_algorithm = request->request_level_running_response_sum->algorithm; aws_byte_buf_clean_up(&response_body_sum); aws_byte_buf_clean_up(&encoded_response_body_sum); - aws_checksum_destroy(request->request_level_running_response_sum); - aws_byte_buf_clean_up(&request->request_level_response_header_checksum); - request->request_level_running_response_sum = NULL; } else { request->did_validate = false; } + aws_checksum_destroy(request->request_level_running_response_sum); + aws_byte_buf_clean_up(&request->request_level_response_header_checksum); 
+ request->request_level_running_response_sum = NULL; } static int s_s3_meta_request_incoming_headers( diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index ac06da68a..e4137e7a3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -254,6 +254,7 @@ add_net_test_case(test_s3_list_bucket_valid) if(ENABLE_MOCK_SERVER_TESTS) add_net_test_case(multipart_upload_mock_server) add_net_test_case(multipart_upload_checksum_with_retry_mock_server) + add_net_test_case(multipart_download_checksum_with_retry_mock_server) add_net_test_case(async_internal_error_from_complete_multipart_mock_server) add_net_test_case(async_access_denied_from_complete_multipart_mock_server) add_net_test_case(get_object_modified_mock_server) diff --git a/tests/mock_s3_server/GetObject/get_object_checksum_retry.json b/tests/mock_s3_server/GetObject/get_object_checksum_retry.json new file mode 100644 index 000000000..1910e495a --- /dev/null +++ b/tests/mock_s3_server/GetObject/get_object_checksum_retry.json @@ -0,0 +1,12 @@ +{ + "status": 200, + "headers": { + "ETag": "b54357faf0632cce46e942fa68356b38", + "Date": "Thu, 12 Jan 2023 00:04:21 GMT", + "Last-Modified": "Tue, 10 Jan 2023 23:39:32 GMT", + "Accept-Ranges": "bytes", + "Content-Range": "bytes 0-65535/65536", + "Content-Type": "binary/octet-stream", + "x-amz-checksum-crc32": "q1875w==" + } +} diff --git a/tests/mock_s3_server/mock_s3_server.py b/tests/mock_s3_server/mock_s3_server.py index 9e8c6ccb3..73de24922 100644 --- a/tests/mock_s3_server/mock_s3_server.py +++ b/tests/mock_s3_server/mock_s3_server.py @@ -3,10 +3,12 @@ # # S3 Mock server logic starts from handle_mock_s3_request +from dataclasses import dataclass import json from itertools import count from urllib.parse import parse_qs, urlparse import os +from typing import Optional from enum import Enum import trio @@ -17,7 +19,13 @@ TIMEOUT = 120 # this must be higher than any response's "delay" setting VERBOSE = False + +# Flags to keep between requests SHOULD_THROTTLE = 
True +RETRY_REQUEST_COUNT = 0 + + +base_dir = os.path.dirname(os.path.realpath(__file__)) class S3Opts(Enum): @@ -29,7 +37,84 @@ class S3Opts(Enum): ListParts = 6 -base_dir = os.path.dirname(os.path.realpath(__file__)) +@dataclass +class Response: + status_code: int + delay: int + headers: any + data: str + chunked: bool + head_request: bool + + +@dataclass +class ResponseConfig: + path: str + disconnect_after_headers = False + generate_body_size: Optional[int] = None + json_path: str = None + throttle: bool = False + force_retry: bool = False + + def _resolve_file_path(self, wrapper, request_type): + global SHOULD_THROTTLE + if self.json_path is None: + response_file = os.path.join( + base_dir, request_type.name, f"{self.path[1:]}.json") + if os.path.exists(response_file) == False: + wrapper.info( + response_file, "not exist, using the default response") + response_file = os.path.join( + base_dir, request_type.name, f"default.json") + if "throttle" in response_file: + # We throttle the request half the time to make sure it succeeds after a retry + if SHOULD_THROTTLE is False: + wrapper.info("Skipping throttling") + response_file = os.path.join( + base_dir, request_type.name, f"default.json") + else: + wrapper.info("Throttling") + # Flip the flag + SHOULD_THROTTLE = not SHOULD_THROTTLE + self.json_path = response_file + + def resolve_response(self, wrapper, request_type, chunked=False, head_request=False): + self._resolve_file_path(wrapper, request_type) + wrapper.info("resolving response from json file: ", self.json_path, + ".\n generate_body_size: ", self.generate_body_size) + with open(self.json_path, 'r') as f: + data = json.load(f) + + # if response has delay, then sleep before sending it + delay = data.get('delay', 0) + status_code = data['status'] + if self.generate_body_size is not None: + # generate body with a specific size instead + body = "a" * self.generate_body_size + else: + body = "\n".join(data['body']) + + headers = wrapper.basic_headers() + 
content_length_set = False + for header in data['headers'].items(): + headers.append((header[0], str(header[1]))) + if header[0].lower() == "content-length": + content_length_set = True + + if chunked: + headers.append(('Transfer-Encoding', "chunked")) + else: + if self.force_retry: + # Use a long `content-length` header to trigger error when we try to send EOM. + # so that the server will close connection after we send the header. + headers.append(("Content-Length", str(123456))) + elif content_length_set is False: + headers.append(("Content-Length", str(len(body)))) + + response = Response(status_code=status_code, delay=delay, headers=headers, + data=body, chunked=chunked, head_request=head_request) + + return response class TrioHTTPWrapper: @@ -37,7 +122,6 @@ class TrioHTTPWrapper: def __init__(self, stream): self.stream = stream - self.should_throttle = SHOULD_THROTTLE self.conn = h11.Connection(h11.SERVER) # A unique id for this connection, to include in debugging output # (useful for understanding what's going on if there are multiple @@ -58,7 +142,7 @@ async def send(self, event): async def _read_from_peer(self): if self.conn.they_are_waiting_for_100_continue: self.info("Sending 100 Continue") - go_ahead = h11.InformationalResponse( + go_ahead = h11.InformationalResponse( status_code=100, headers=self.basic_headers() ) await self.send(go_ahead) @@ -112,6 +196,38 @@ def info(self, *args): # Server main loop ################################################################ + +async def send_simple_response(wrapper, status_code, content_type, body): + wrapper.info("Sending", status_code, "response with", len(body), "bytes") + headers = wrapper.basic_headers() + headers.append(("Content-Type", content_type)) + headers.append(("Content-Length", str(len(body)))) + res = h11.Response(status_code=status_code, headers=headers) + await wrapper.send(res) + await wrapper.send(h11.Data(data=body)) + await wrapper.send(h11.EndOfMessage()) + + +async def 
maybe_send_error_response(wrapper, exc): + if wrapper.conn.our_state not in {h11.IDLE, h11.SEND_RESPONSE}: + wrapper.info("...but I can't, because our state is", + wrapper.conn.our_state) + return + try: + if isinstance(exc, h11.RemoteProtocolError): + status_code = exc.error_status_hint + elif isinstance(exc, trio.TooSlowError): + status_code = 408 # Request Timeout + else: + status_code = 500 + body = str(exc).encode("utf-8") + await send_simple_response( + wrapper, status_code, "text/plain; charset=utf-8", body + ) + except Exception as exc: + wrapper.info("error while sending error response:", exc) + + async def http_serve(stream): wrapper = TrioHTTPWrapper(stream) wrapper.info("Got new connection") @@ -153,104 +269,29 @@ async def http_serve(stream): ################################################################ # Helper function +async def send_response(wrapper, response): + if response.delay > 0: + assert response.delay < TIMEOUT + await trio.sleep(response.delay) -def parse_request_path(request_path): - parsed_path = urlparse(request_path) - parsed_query = parse_qs(parsed_path.query) - return parsed_path, parsed_query - - -async def send_simple_response(wrapper, status_code, content_type, body): - wrapper.info("Sending", status_code, "response with", len(body), "bytes") - headers = wrapper.basic_headers() - headers.append(("Content-Type", content_type)) - headers.append(("Content-Length", str(len(body)))) - res = h11.Response(status_code=status_code, headers=headers) - await wrapper.send(res) - await wrapper.send(h11.Data(data=body)) - await wrapper.send(h11.EndOfMessage()) - - -async def send_response_from_json(wrapper, response_json_path, chunked=False, generate_body=False, generate_body_size=0, head_request=False): - wrapper.info("sending response from json file: ", response_json_path, - ".\n generate_body: ", generate_body, "generate_body_size: ", generate_body_size) - with open(response_json_path, 'r') as f: - data = json.load(f) + 
wrapper.info("Sending", response.status_code, + "response with", len(response.data), "bytes") - # if response has delay, then sleep before sending it - delay = data.get('delay', 0) - if delay > 0: - assert delay < TIMEOUT - await trio.sleep(delay) + res = h11.Response(status_code=response.status_code, + headers=response.headers) - status_code = data['status'] - if generate_body: - # generate body with a specific size instead - body = "a" * generate_body_size - else: - body = "\n".join(data['body']) - wrapper.info("Sending", status_code, - "response with", len(body), "bytes") - - headers = wrapper.basic_headers() - for header in data['headers'].items(): - headers.append((header[0], header[1])) - - if chunked: - headers.append(('Transfer-Encoding', "chunked")) - res = h11.Response(status_code=status_code, headers=headers) - await wrapper.send(res) - await wrapper.send(h11.Data(data=b"%X\r\n%s\r\n" % (len(body), body.encode()))) - else: - headers.append(("Content-Length", str(len(body)))) - res = h11.Response(status_code=status_code, headers=headers) + try: await wrapper.send(res) - if head_request: - await wrapper.send(h11.EndOfMessage()) - return - await wrapper.send(h11.Data(data=body.encode())) - - await wrapper.send(h11.EndOfMessage()) + except Exception as e: + print(e) - -async def send_mock_s3_response(wrapper, request_type, path, generate_body=False, generate_body_size=0, head_request=False): - response_file = os.path.join( - base_dir, request_type.name, f"{path[1:]}.json") - if os.path.exists(response_file) == False: - wrapper.info(response_file, "not exist, using the default response") - response_file = os.path.join( - base_dir, request_type.name, f"default.json") - if "throttle" in response_file: - # We throttle the request half the time to make sure it succeeds after a retry - if wrapper.should_throttle is False: - wrapper.info("Skipping throttling") - response_file = os.path.join( - base_dir, request_type.name, f"default.json") + if not 
response.head_request: + if response.chunked: + await wrapper.send(h11.Data(data=b"%X\r\n%s\r\n" % (len(response.data), response.data.encode()))) else: - wrapper.info("Throttling") - # Flip the flag - wrapper.should_throttle = not wrapper.should_throttle - await send_response_from_json(wrapper, response_file, generate_body=generate_body, generate_body_size=generate_body_size) + await wrapper.send(h11.Data(data=response.data.encode())) - -async def maybe_send_error_response(wrapper, exc): - if wrapper.conn.our_state not in {h11.IDLE, h11.SEND_RESPONSE}: - wrapper.info("...but I can't, because our state is", - wrapper.conn.our_state) - return - try: - if isinstance(exc, h11.RemoteProtocolError): - status_code = exc.error_status_hint - elif isinstance(exc, trio.TooSlowError): - status_code = 408 # Request Timeout - else: - status_code = 500 - body = str(exc).encode("utf-8") - await send_simple_response( - wrapper, status_code, "text/plain; charset=utf-8", body - ) - except Exception as exc: - wrapper.info("error while sending error response:", exc) + await wrapper.send(h11.EndOfMessage()) def get_request_header_value(request, header_name): @@ -264,7 +305,7 @@ def handle_get_object_modified(start_range, end_range, request): data_length = end_range - start_range if start_range == 0: - return "/get_object_modified_first_part", data_length, True + return ResponseConfig("/get_object_modified_first_part", generate_body_size=data_length) else: # Check the request header to make sure "If-Match" is set etag = get_request_header_value(request, "if-match") @@ -275,11 +316,21 @@ def handle_get_object_modified(start_range, end_range, request): with open(response_file, 'r') as f: data = json.load(f) if data['headers']['ETag'] == etag: - return "/get_object_modified_success", data_length, False - return "/get_object_modified_failure", data_length, False + return ResponseConfig("/get_object_modified_success") + return ResponseConfig("/get_object_modified_failure") + +def 
handle_get_object(wrapper, request, parsed_path, head_request=False): + global RETRY_REQUEST_COUNT + response_config = ResponseConfig(parsed_path.path) + if parsed_path.path == "/get_object_checksum_retry" and not head_request: + RETRY_REQUEST_COUNT = RETRY_REQUEST_COUNT + 1 -def handle_get_object(request, parsed_path): + if RETRY_REQUEST_COUNT == 1: + wrapper.info("Force retry on the request") + response_config.force_retry = True + else: + RETRY_REQUEST_COUNT = 0 body_range_value = get_request_header_value(request, "range") @@ -296,30 +347,27 @@ def handle_get_object(request, parsed_path): if parsed_path.path == "/get_object_modified": return handle_get_object_modified(start_range, end_range, request) - elif parsed_path.path == "/get_object_invalid_response_missing_content_range": - return "/get_object_invalid_response_missing_content_range", data_length, False - elif parsed_path.path == "/get_object_invalid_response_missing_etags": - return "/get_object_invalid_response_missing_etags", data_length, False + elif parsed_path.path == "/get_object_invalid_response_missing_content_range" or parsed_path.path == "/get_object_invalid_response_missing_etags": + # Don't generate the body for those requests + return response_config - return parsed_path.path, data_length, True + response_config.generate_body_size = data_length + return response_config def handle_list_parts(parsed_path): if parsed_path.path == "/multiple_list_parts": if parsed_path.query.find("part-number-marker") != -1: - return "/multiple_list_parts_2" + return ResponseConfig("/multiple_list_parts_2") else: - return "/multiple_list_parts_1" - return parsed_path.path + return ResponseConfig("/multiple_list_parts_1") + return ResponseConfig(parsed_path.path) async def handle_mock_s3_request(wrapper, request): - parsed_path, parsed_query = parse_request_path( - request.target.decode("ascii")) - response_path = parsed_path.path - generate_body = False - generate_body_size = 0 + parsed_path = 
urlparse(request.target.decode("ascii")) method = request.method.decode("utf-8") + response_config = None if method == "POST": if parsed_path.query == "uploads": @@ -336,11 +384,11 @@ async def handle_mock_s3_request(wrapper, request): if parsed_path.query.find("uploadId") != -1: # GET /Key+?max-parts=MaxParts&part-number-marker=PartNumberMarker&uploadId=UploadId HTTP/1.1 -- List Parts request_type = S3Opts.ListParts - response_path = handle_list_parts(parsed_path) + response_config = handle_list_parts(parsed_path) else: request_type = S3Opts.GetObject - response_path, generate_body_size, generate_body = handle_get_object( - request, parsed_path) + response_config = handle_get_object( + wrapper, request, parsed_path, head_request=method == "HEAD") else: # TODO: support more type. wrapper.info("unsupported request:", request) @@ -352,9 +400,18 @@ async def handle_mock_s3_request(wrapper, request): break assert type(event) is h11.Data - await send_mock_s3_response( - wrapper, request_type, response_path, generate_body=generate_body, generate_body_size=generate_body_size, head_request=method == "HEAD") + if response_config is None: + response_config = ResponseConfig(parsed_path.path) + + response = response_config.resolve_response( + wrapper, request_type, head_request=method == "HEAD") + await send_response(wrapper, response) + + +################################################################ +# Run the server +################################################################ async def serve(port): print("listening on http://localhost:{}".format(port)) @@ -363,9 +420,5 @@ async def serve(port): except KeyboardInterrupt: print("KeyboardInterrupt - shutting down") - -################################################################ -# Run the server -################################################################ if __name__ == "__main__": trio.run(serve, 8080) diff --git a/tests/s3_mock_server_tests.c b/tests/s3_mock_server_tests.c index bda02a354..81c5b4a7a 100644 
--- a/tests/s3_mock_server_tests.c +++ b/tests/s3_mock_server_tests.c @@ -194,6 +194,50 @@ TEST_CASE(multipart_upload_checksum_with_retry_mock_server) { return AWS_OP_SUCCESS; } +TEST_CASE(multipart_download_checksum_with_retry_mock_server) { + (void)ctx; + /** + * We had a memory leak when the request failed after its headers had been received successfully: + * the memory allocated by then was never freed. + */ + struct aws_s3_tester tester; + ASSERT_SUCCESS(aws_s3_tester_init(allocator, &tester)); + struct aws_s3_tester_client_options client_options = { + .part_size = MB_TO_BYTES(5), + .tls_usage = AWS_S3_TLS_DISABLED, + }; + + struct aws_s3_client *client = NULL; + ASSERT_SUCCESS(aws_s3_tester_client_new(&tester, &client_options, &client)); + /* Mock server will respond without fake checksum for the body */ + struct aws_byte_cursor object_path = aws_byte_cursor_from_c_str("/get_object_checksum_retry"); + + struct aws_s3_tester_meta_request_options get_options = { + .allocator = allocator, + .meta_request_type = AWS_S3_META_REQUEST_TYPE_GET_OBJECT, + .client = client, + .expected_validate_checksum_alg = AWS_SCA_CRC32, + .validate_get_response_checksum = true, + .get_options = + { + .object_path = object_path, + }, + .default_type_options = + { + .mode = AWS_S3_TESTER_DEFAULT_TYPE_MODE_GET, + }, + .mock_server = true, + .validate_type = AWS_S3_TESTER_VALIDATE_TYPE_EXPECT_FAILURE, + }; + + ASSERT_SUCCESS(aws_s3_tester_send_meta_request_with_options(&tester, &get_options, NULL)); + + aws_s3_client_release(client); + aws_s3_tester_clean_up(&tester); + + return AWS_OP_SUCCESS; +} + TEST_CASE(async_internal_error_from_complete_multipart_mock_server) { (void)ctx;