From 59212c8b1c1de94cb2cf15eedafb4b639295a51f Mon Sep 17 00:00:00 2001 From: Michael Graeb Date: Tue, 10 Dec 2019 15:47:39 -0800 Subject: [PATCH] bootstrap-shutdown-cb: python edition (#114) --- aws-c-auth | 2 +- aws-c-common | 2 +- aws-c-http | 2 +- aws-c-io | 2 +- aws-c-mqtt | 2 +- awscrt/io.py | 11 +++++-- source/io.c | 74 ++++++++++++++++++++++++++++++++++++----------- test/__init__.py | 7 +++-- test/test_auth.py | 10 +++---- test/test_io.py | 9 ++++-- test/test_mqtt.py | 28 +++++++++++------- 11 files changed, 105 insertions(+), 44 deletions(-) diff --git a/aws-c-auth b/aws-c-auth index 0ecb31397..27bc11de7 160000 --- a/aws-c-auth +++ b/aws-c-auth @@ -1 +1 @@ -Subproject commit 0ecb3139761c6ad75cc91d81fc952a4ef11a5848 +Subproject commit 27bc11de749ee82fa8d4c895d427d7d39f87ac02 diff --git a/aws-c-common b/aws-c-common index 17fd80daa..8612ad8d3 160000 --- a/aws-c-common +++ b/aws-c-common @@ -1 +1 @@ -Subproject commit 17fd80daa9536cb877caa8eb64b8b67b38e13163 +Subproject commit 8612ad8d33fef6f8308576a9d1e6ff2776347595 diff --git a/aws-c-http b/aws-c-http index aa3f83c8c..32f49a57a 160000 --- a/aws-c-http +++ b/aws-c-http @@ -1 +1 @@ -Subproject commit aa3f83c8cbc3b0fe07daa038e07f755e59b5ea95 +Subproject commit 32f49a57a5055fb626c799abb28a538286100e42 diff --git a/aws-c-io b/aws-c-io index e331d52cf..7ed71e3c4 160000 --- a/aws-c-io +++ b/aws-c-io @@ -1 +1 @@ -Subproject commit e331d52cf01e7eecd7d1870b22991a0f26ddcbf6 +Subproject commit 7ed71e3c4e6db9ec5676e01d082a90f6eb2b3bb2 diff --git a/aws-c-mqtt b/aws-c-mqtt index 393c1dc4b..27b29893c 160000 --- a/aws-c-mqtt +++ b/aws-c-mqtt @@ -1 +1 @@ -Subproject commit 393c1dc4bd73d39c14033b84a3eda63d1e223358 +Subproject commit 27b29893c46d388c34eacff1e6f6870df4decd46 diff --git a/awscrt/io.py b/awscrt/io.py index ea2bf8756..958ad7939 100644 --- a/awscrt/io.py +++ b/awscrt/io.py @@ -16,6 +16,7 @@ from awscrt import NativeResource, isinstance_str from enum import IntEnum import io +import threading class LogLevel(IntEnum): @@ -75,7 +76,7 @@ def __init__(self, event_loop_group, max_hosts=16): class ClientBootstrap(NativeResource): - __slots__ = () + __slots__ = ('shutdown_event') def __init__(self, event_loop_group, host_resolver): assert isinstance(event_loop_group, EventLoopGroup) @@ -83,7 +84,13 @@ def __init__(self, event_loop_group, host_resolver): super(ClientBootstrap, self).__init__() - self._binding = _awscrt.client_bootstrap_new(event_loop_group, host_resolver) + shutdown_event = threading.Event() + + def on_shutdown(): + shutdown_event.set() + + self.shutdown_event = shutdown_event + self._binding = _awscrt.client_bootstrap_new(event_loop_group, host_resolver, on_shutdown) def _read_binary_file(filepath): diff --git a/source/io.c b/source/io.c index c418c45eb..1e5fc6015 100644 --- a/source/io.c +++ b/source/io.c @@ -249,16 +249,43 @@ struct client_bootstrap_binding { /* Dependencies that must outlive this */ PyObject *event_loop_group; PyObject *host_resolver; + PyObject *shutdown_complete; }; -static void s_client_bootstrap_destructor(PyObject *bootstrap_capsule) { +/* Fires after the native client bootstrap finishes shutting down. */ +static void s_client_bootstrap_on_shutdown_complete(void *user_data) { + struct client_bootstrap_binding *bootstrap = user_data; + PyObject *shutdown_complete = bootstrap->shutdown_complete; + + /*************** GIL ACQUIRE ***************/ + PyGILState_STATE state = PyGILState_Ensure(); + + Py_XDECREF(bootstrap->host_resolver); + Py_XDECREF(bootstrap->event_loop_group); + + aws_mem_release(aws_py_get_allocator(), bootstrap); + + if (shutdown_complete) { + PyObject *result = PyObject_CallFunction(shutdown_complete, "()"); + if (result) { + Py_DECREF(result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + Py_DECREF(shutdown_complete); + } + + PyGILState_Release(state); + /*************** GIL RELEASE ***************/ +} + +/* Fires when python capsule is GC'd. + * Note that bootstrap shutdown is async, we can't release dependencies until it completes */ +static void s_client_bootstrap_capsule_destructor(PyObject *bootstrap_capsule) { struct client_bootstrap_binding *bootstrap = PyCapsule_GetPointer(bootstrap_capsule, s_capsule_name_client_bootstrap); - assert(bootstrap); - Py_DECREF(bootstrap->host_resolver); - Py_DECREF(bootstrap->event_loop_group); + aws_client_bootstrap_release(bootstrap->native); - aws_mem_release(aws_py_get_allocator(), bootstrap); } PyObject *aws_py_client_bootstrap_new(PyObject *self, PyObject *args) { @@ -268,8 +295,9 @@ PyObject *aws_py_client_bootstrap_new(PyObject *self, PyObject *args) { PyObject *elg_py; PyObject *host_resolver_py; + PyObject *shutdown_complete_py; - if (!PyArg_ParseTuple(args, "OO", &elg_py, &host_resolver_py)) { + if (!PyArg_ParseTuple(args, "OOO", &elg_py, &host_resolver_py, &shutdown_complete_py)) { return NULL; } @@ -291,15 +319,22 @@ PyObject *aws_py_client_bootstrap_new(PyObject *self, PyObject *args) { /* From hereon, we need to clean up if errors occur */ - bootstrap->native = aws_client_bootstrap_new(allocator, elg, host_resolver, NULL); - if (!bootstrap->native) { - PyErr_SetAwsLastError(); - goto bootstrap_new_failed; + PyObject *capsule = + PyCapsule_New(bootstrap, s_capsule_name_client_bootstrap, s_client_bootstrap_capsule_destructor); + if (!capsule) { + goto error; } - PyObject *capsule = PyCapsule_New(bootstrap, s_capsule_name_client_bootstrap, s_client_bootstrap_destructor); - if (!capsule) { - goto capsule_new_failed; + struct aws_client_bootstrap_options bootstrap_options = { + .event_loop_group = elg, + .host_resolver = host_resolver, + .on_shutdown_complete = s_client_bootstrap_on_shutdown_complete, + .user_data = bootstrap, + }; + bootstrap->native = aws_client_bootstrap_new(allocator, &bootstrap_options); + if (!bootstrap->native) { + PyErr_SetAwsLastError(); + goto error; } /* From hereon, nothing will fail */ @@ -310,12 +345,17 @@ PyObject *aws_py_client_bootstrap_new(PyObject *self, PyObject *args) { bootstrap->host_resolver = host_resolver_py; Py_INCREF(host_resolver_py); + bootstrap->shutdown_complete = shutdown_complete_py; + Py_INCREF(bootstrap->shutdown_complete); + return capsule; -capsule_new_failed: - aws_client_bootstrap_release(bootstrap->native); -bootstrap_new_failed: - aws_mem_release(allocator, bootstrap); +error: + if (capsule) { + Py_DECREF(capsule); + } else { + aws_mem_release(allocator, bootstrap); + } return NULL; } diff --git a/test/__init__.py b/test/__init__.py index bb2a1622c..a1f0625b7 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -20,6 +20,8 @@ import types import unittest +TIMEOUT = 10.0 + class NativeResourceTest(unittest.TestCase): """ @@ -33,8 +35,9 @@ def tearDown(self): gc.collect() # Native resources might need a few more ticks to finish cleaning themselves up. - if NativeResource._living: - time.sleep(1) + wait_until = time.time() + TIMEOUT + while NativeResource._living and time.time() < wait_until: + time.sleep(0.1) # Print out debugging info on leaking resources if NativeResource._living: diff --git a/test/test_auth.py b/test/test_auth.py index 71a001b68..048e6164d 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -16,7 +16,7 @@ import awscrt.io import datetime import os -from test import NativeResourceTest +from test import NativeResourceTest, TIMEOUT EXAMPLE_ACCESS_KEY_ID = 'example_access_key_id' EXAMPLE_SECRET_ACCESS_KEY = 'example_secret_access_key' @@ -65,7 +65,7 @@ def test_static_provider(self): EXAMPLE_SESSION_TOKEN) future = provider.get_credentials() - credentials = future.result() + credentials = future.result(TIMEOUT) self.assertEqual(EXAMPLE_ACCESS_KEY_ID, credentials.access_key_id) self.assertEqual(EXAMPLE_SECRET_ACCESS_KEY, credentials.secret_access_key) @@ -80,7 +80,7 @@ def test_static_provider(self): # self.example_secret_access_key) # future = provider.get_credentials() - # credentials = future.result() + # credentials = future.result(TIMEOUT) # self.assertEqual(self.example_access_key_id, credentials.access_key_id) # self.assertEqual(self.example_secret_access_key, credentials.secret_access_key) @@ -96,7 +96,7 @@ def test_default_provider(self): provider = awscrt.auth.AwsCredentialsProvider.new_default_chain(bootstrap) future = provider.get_credentials() - credentials = future.result() + credentials = future.result(TIMEOUT) self.assertEqual('credentials_test_access_key_id', credentials.access_key_id) self.assertEqual('credentials_test_secret_access_key', credentials.secret_access_key) @@ -234,7 +234,7 @@ def test_signing_sigv4_headers(self): signing_future = awscrt.auth.aws_sign_request(http_request, signing_config) - signing_result = signing_future.result(10) + signing_result = signing_future.result(TIMEOUT) self.assertIs(http_request, signing_result) # should be same object diff --git a/test/test_io.py b/test/test_io.py index 4f88f6276..c89c6159a 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -13,7 +13,7 @@ from __future__ import absolute_import from awscrt.io import ClientBootstrap, ClientTlsContext, DefaultHostResolver, EventLoopGroup, TlsConnectionOptions, TlsContextOptions -from test import NativeResourceTest +from test import NativeResourceTest, TIMEOUT import unittest @@ -32,11 +32,16 @@ def test_init(self): class ClientBootstrapTest(NativeResourceTest): - def test_init(self): + def test_create_destroy(self): event_loop_group = EventLoopGroup() host_resolver = DefaultHostResolver(event_loop_group) bootstrap = ClientBootstrap(event_loop_group, host_resolver) + # ensure shutdown_event fires + bootstrap_shutdown_event = bootstrap.shutdown_event + del bootstrap + self.assertTrue(bootstrap_shutdown_event.wait(TIMEOUT)) + class ClientTlsContextTest(NativeResourceTest): def test_init_defaults(self): diff --git a/test/test_mqtt.py b/test/test_mqtt.py index 9ebc7986f..012321eac 100644 --- a/test/test_mqtt.py +++ b/test/test_mqtt.py @@ -23,6 +23,7 @@ import unittest import boto3 import botocore.exceptions +import shutil import tempfile import time import uuid @@ -192,20 +193,25 @@ def test_mtls_from_path(self): bootstrap = ClientBootstrap(elg, resolver) # test "from path" builder by writing secrets to tempfiles - with tempfile.NamedTemporaryFile() as cert_file: - with tempfile.NamedTemporaryFile() as key_file: + tmp_dirpath = tempfile.mkdtemp() + try: + cert_filepath = os.path.join(tmp_dirpath, 'cert') + with open(cert_filepath, 'wb') as cert_file: cert_file.write(config.cert) - cert_file.flush() + key_filepath = os.path.join(tmp_dirpath, 'key') + with open(key_filepath, 'wb') as key_file: key_file.write(config.key) - key_file.flush() - - connection = awsiot_mqtt_connection_builder.mtls_from_path( - cert_filepath=cert_file.name, - pri_key_filepath=key_file.name, - endpoint=config.endpoint, - client_id=create_client_id(), - client_bootstrap=bootstrap) + + connection = awsiot_mqtt_connection_builder.mtls_from_path( + cert_filepath=cert_filepath, + pri_key_filepath=key_filepath, + endpoint=config.endpoint, + client_id=create_client_id(), + client_bootstrap=bootstrap) + + finally: + shutil.rmtree(tmp_dirpath) self._test_connection(connection)