diff --git a/awscrt/http.py b/awscrt/http.py index 73a5976a8..75f2bedfa 100644 --- a/awscrt/http.py +++ b/awscrt/http.py @@ -15,7 +15,6 @@ import _awscrt from concurrent.futures import Future from collections import defaultdict -from enum import Enum from io import IOBase from awscrt import NativeResource, isinstance_str from awscrt.io import ClientBootstrap, EventLoopGroup, DefaultHostResolver, TlsConnectionOptions, SocketOptions @@ -122,12 +121,12 @@ def request(self, request, on_response=None, on_body=None): class HttpStreamBase(NativeResource): - __slots__ = ('_connection', '_complete_future', '_on_body_cb') + __slots__ = ('_connection', '_completion_future', '_on_body_cb') def __init__(self, connection, on_body=None): super(HttpStreamBase, self).__init__() self._connection = connection - self._complete_future = Future() + self._completion_future = Future() self._on_body_cb = on_body @property @@ -135,8 +134,8 @@ def connection(self): return self._connection @property - def complete_future(self): - return self._complete_future + def completion_future(self): + return self._completion_future def _on_body(self, chunk): if self._on_body_cb: @@ -144,9 +143,9 @@ def _on_body(self, chunk): def _on_complete(self, error_code): if error_code == 0: - self._complete_future.set_result(None) + self._completion_future.set_result(None) else: - self._complete_future.set_exception(Exception(error_code)) # TODO: Actual exceptions for error_codes + self._completion_future.set_exception(Exception(error_code)) # TODO: Actual exceptions for error_codes class HttpClientStream(HttpStreamBase): diff --git a/builder.json b/builder.json index dbebdbf7b..837d8811a 100644 --- a/builder.json +++ b/builder.json @@ -61,7 +61,7 @@ ["{python}", "setup.py", "--verbose", "build_ext", "--include-dirs{openssl_include}", "--library-dirs{openssl_lib}", "install"] ], "test": [ - ["{python}", "-m", "unittest", "discover", "--buffer", "--verbose"], + ["{python}", "-m", "unittest", "discover", "--verbose"], ["{python}", "aws-c-http/integration-testing/http_client_test.py", "{python}", "elasticurl.py"], ["{python}", "-m", "pip", "install", "autopep8"], ["{python}", "-m", "autopep8", "--exit-code", "--diff", "--recursive", "awscrt", "test", "setup.py"] diff --git a/continuous-delivery/pull-pypirc.py b/continuous-delivery/pull-pypirc.py index 558075c50..5caae5332 100644 --- a/continuous-delivery/pull-pypirc.py +++ b/continuous-delivery/pull-pypirc.py @@ -15,7 +15,6 @@ import base64 import os import argparse -from botocore.exceptions import ClientError def get_secret(stage): @@ -49,4 +48,4 @@ def get_secret(stage): parser.add_argument('stage', help='Stage to deploy the pypi package to (e.g. alpha, prod, etc...)', type=str) args = parser.parse_args() get_secret(args.stage) - + diff --git a/elasticurl.py b/elasticurl.py index 8de4d8b70..d7bd2504d 100644 --- a/elasticurl.py +++ b/elasticurl.py @@ -192,7 +192,7 @@ def response_received_cb(http_stream, status_code, headers): stream = connection.request(request, response_received_cb, on_incoming_body) # wait until the full response is finished -stream.complete_future.result() +stream.completion_future.result() stream = None connection = None diff --git a/source/http_connection.c b/source/http_connection.c index 7ae8370b3..0f4f1b375 100644 --- a/source/http_connection.c +++ b/source/http_connection.c @@ -218,7 +218,6 @@ PyObject *aws_py_http_client_connection_new(PyObject *self, PyObject *args) { if (!connection->tls_ctx) { goto error; } - Py_INCREF(connection->tls_ctx); } struct aws_socket_options socket_options; diff --git a/test/__init__.py b/test/__init__.py index 502e69bf2..ac47e99ef 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -12,6 +12,9 @@ # permissions and limitations under the License. from awscrt import NativeResource +import gc +import sys +import types import unittest @@ -24,4 +27,33 @@ def setUp(self): NativeResource._track_lifetime = True def tearDown(self): + gc.collect() + + # Print out debugging info on leaking resources + if NativeResource._living: + print('Leaking NativeResources:') + for i in NativeResource._living: + print('-', i) + + # getrefcount(i) returns 4+ here, but 2 of those are due to debugging. + # Don't show: + # - 1 for WeakSet iterator due to this for-loop. + # - 1 for getrefcount(i)'s reference. + # But do show: + # - 1 for item's self-reference. + # - the rest are what's causing this leak. + refcount = sys.getrefcount(i) - 2 + + # The act of iterating a WeakSet creates a reference. Don't show that. + referrers = gc.get_referrers(i) + for r in referrers: + if isinstance(r, types.FrameType) and '_weakrefset.py' in str(r): + referrers.remove(r) + break + + print(' sys.getrefcount():', refcount) + print(' gc.referrers():', len(referrers)) + for r in referrers: + print(' -', r) + self.assertEqual(0, len(NativeResource._living)) diff --git a/test/test_http_client.py b/test/test_http_client.py new file mode 100644 index 000000000..ee11e110d --- /dev/null +++ b/test/test_http_client.py @@ -0,0 +1,229 @@ +# Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://aws.amazon.com/apache2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from __future__ import absolute_import +from awscrt.http import HttpClientConnection, HttpClientStream, HttpHeaders, HttpRequest +from awscrt.io import TlsContextOptions, ClientTlsContext, TlsConnectionOptions +from concurrent.futures import Future +import ssl +from test import NativeResourceTest +import threading + +# Use a built-in Python HTTP server to test the awscrt's HTTP client +try: + from http.server import HTTPServer, SimpleHTTPRequestHandler +except ImportError: + # Simple HTTP server lives in a different places in Python3 vs Python2: + # http.server.HTTPServer == SocketServer.TCPServer + # http.server.SimpleHTTPRequestHandler == SimpleHTTPServer.SimpleHTTPRequestHandler + from SimpleHTTPServer import SimpleHTTPRequestHandler + import SocketServer + HTTPServer = SocketServer.TCPServer + + +class Response(object): + """Holds contents of incoming response""" + + def __init__(self): + self.status_code = None + self.headers = None + self.body = bytearray() + + def on_response(self, stream, status_code, headers): + self.status_code = status_code + self.headers = HttpHeaders(headers) + + def on_body(self, stream, chunk): + self.body.extend(chunk) + + +class TestRequestHandler(SimpleHTTPRequestHandler): + """Request handler for test server""" + + # default was HTTP/1.0. + # specifying HTTP/1.1 keeps connection alive after handling 1 request + protocol_version = "HTTP/1.1" + + def do_PUT(self): + content_length = int(self.headers['Content-Length']) + # store put request on the server object + incoming_body_bytes = self.rfile.read(content_length) + self.server.put_requests[self.path] = incoming_body_bytes + self.send_response(200, 'OK') + self.end_headers() + + +class TestClient(NativeResourceTest): + hostname = 'localhost' + timeout = 10 # seconds + + def _start_server(self, secure): + self.server = HTTPServer((self.hostname, 0), TestRequestHandler) + if secure: + self.server.socket = ssl.wrap_socket(self.server.socket, + keyfile="test/resources/unittests.key", + certfile='test/resources/unittests.crt', + server_side=True) + self.port = self.server.server_address[1] + + # put requests are stored in this dict + self.server.put_requests = {} + + self.server_thread = threading.Thread(target=self.server.serve_forever, name='test_server') + self.server_thread.start() + + def _stop_server(self): + self.server.shutdown() + self.server.server_close() + self.server_thread.join() + + def _new_client_connection(self, secure): + if secure: + tls_ctx_opt = TlsContextOptions() + tls_ctx_opt.override_default_trust_store_from_path(None, 'test/resources/unittests.crt') + tls_ctx = ClientTlsContext(tls_ctx_opt) + tls_conn_opt = tls_ctx.new_connection_options() + tls_conn_opt.set_server_name(self.hostname) + else: + tls_conn_opt = None + + connection_future = HttpClientConnection.new(self.hostname, self.port, tls_connection_options=tls_conn_opt) + return connection_future.result(self.timeout) + + def _test_connect(self, secure): + self._start_server(secure) + connection = self._new_client_connection(secure) + + # register shutdown callback + shutdown_callback_results = [] + + def shutdown_callback(error_code): + shutdown_callback_results.append(error_code) + + connection.add_shutdown_callback(shutdown_callback) + + # close connection + shutdown_error_code_from_close_future = connection.close().result(self.timeout) + + # assert that error code was reported via close_future and shutdown callback + # error_code should be 0 (normal shutdown) + self.assertEqual(0, shutdown_error_code_from_close_future) + self.assertEqual(1, len(shutdown_callback_results)) + self.assertEqual(0, shutdown_callback_results[0]) + self.assertFalse(connection.is_open()) + + self._stop_server() + + def test_connect_http(self): + self._test_connect(secure=False) + + def test_connect_https(self): + self._test_connect(secure=True) + + # The connection should shut itself down cleanly when the GC collects the HttpClientConnection Python object. + def _test_connection_closes_on_zero_refcount(self, secure): + self._start_server(secure) + + connection = self._new_client_connection(secure) + + # Subscribing for the shutdown callback shouldn't affect the refcount of the HttpClientConnection. + close_future = Future() + + def on_close(error_code): + close_future.set_result(error_code) + + connection.add_shutdown_callback(on_close) + + # This should cause the GC to collect the HttpClientConnection + del connection + + close_code = close_future.result(self.timeout) + self.assertEqual(0, close_code) + self._stop_server() + + def test_connection_closes_on_zero_refcount_http(self): + self._test_connection_closes_on_zero_refcount(secure=False) + + def test_connection_closes_on_zero_refcount_https(self): + self._test_connection_closes_on_zero_refcount(secure=True) + + # GET request receives this very file from the server. Super meta. + def _test_get(self, secure): + self._start_server(secure) + connection = self._new_client_connection(secure) + + test_asset_path = 'test/test_http_client.py' + + request = HttpRequest('GET', '/' + test_asset_path) + response = Response() + stream = connection.request(request, response.on_response, response.on_body) + + # wait for stream to complete + stream.completion_future.result(self.timeout) + + self.assertEqual(200, response.status_code) + + with open(test_asset_path, 'rb') as test_asset: + test_asset_bytes = test_asset.read() + self.assertEqual(test_asset_bytes, response.body) + + self.assertEqual(0, connection.close().result(self.timeout)) + + self._stop_server() + + def test_get_http(self): + self._test_get(secure=False) + + def test_get_https(self): + self._test_get(secure=True) + + # PUT request sends this very file to the server. + def _test_put(self, secure): + self._start_server(secure) + connection = self._new_client_connection(secure) + test_asset_path = 'test/test_http_client.py' + with open(test_asset_path, 'rb') as outgoing_body_stream: + outgoing_body_bytes = outgoing_body_stream.read() + headers = HttpHeaders([ + ('Content-Length', str(len(outgoing_body_bytes))), + ]) + + # seek back to start of stream before trying to send it + outgoing_body_stream.seek(0) + + request = HttpRequest('PUT', '/' + test_asset_path, headers, outgoing_body_stream) + response = Response() + http_stream = connection.request(request, response.on_response, response.on_body) + + # wait for stream to complete + http_stream.completion_future.result(self.timeout) + + self.assertEqual(200, response.status_code) + + # compare what we sent against what the server received + server_received = self.server.put_requests.get('/' + test_asset_path) + self.assertIsNotNone(server_received) + self.assertEqual(server_received, outgoing_body_bytes) + + self.assertEqual(0, connection.close().result(self.timeout)) + self._stop_server() + + def test_put_http(self): + self._test_put(secure=False) + + def test_put_https(self): + self._test_put(secure=True) + + +if __name__ == '__main__': + unittest.main()