From fe9266af9f33c48e12869ccf783d6ccf3b98afa9 Mon Sep 17 00:00:00 2001 From: TwistedTwigleg Date: Fri, 11 Mar 2022 09:25:39 -0500 Subject: [PATCH] Connection builder simplification (#334) These changes make it where it's not required to create and manage a DefaultHostResolver, EventLoopGroup, or ClientBootstrap to make a connection. Instead, default singletons have been added to handle this use case. Commit log * First pass at builder simplification for Python V2 SDK * Removed left over test code created when testing builder simplification, adjusted code to fit autopep8 * Additional changes to fix formatting in io.py after first pass at builder simplification. Also surpressed the warnings on lgtm about deleting the local variables, as it is needed to delete the global singletons correctly * Fixed more Pep8 style issues with io.py after first pass at builder simplification * Removed freeing singletons manually, as they are automatically freed when garbage collected * Updated MQTT builder simplification after review * Missed a super minor code format issue and didn't think to check. Fixed it * Updated MQTT singleton test to free memory at the end and removed unused imports * Make passing a ClientBootstrap optional in MqttClient, HttpClientConnection, and rpc.ClientConnect * Fixed passing wrong parameter in test, removed unused import * Code review changes: Adjusted lock use and cleaned up code * Fixed minor mistakes in last commit * Fixed another minor mistake in test --- awscrt/_test.py | 6 +++++ awscrt/auth.py | 20 ++++++++++++----- awscrt/eventstream/rpc.py | 6 ++++- awscrt/http.py | 11 +++++---- awscrt/io.py | 47 +++++++++++++++++++++++++++++++++++++++ awscrt/mqtt.py | 10 ++++++--- awscrt/s3.py | 8 +++++-- test/test_io.py | 37 ++++++++++++++++++++++++++++++ test/test_mqtt.py | 24 +++++++++++++++----- 9 files changed, 146 insertions(+), 23 deletions(-) diff --git a/awscrt/_test.py b/awscrt/_test.py index 06016f955..9c4eab0d6 100644 --- a/awscrt/_test.py +++ b/awscrt/_test.py @@ -10,6 +10,8 @@ import time import types +from awscrt.io import ClientBootstrap, DefaultHostResolver, EventLoopGroup + def native_memory_usage() -> int: """ @@ -80,6 +82,10 @@ def check_for_leaks(*, timeout_sec=10.0): if the test results will be made public as it may result in secrets being leaked. """ + ClientBootstrap.release_static_default() + EventLoopGroup.release_static_default() + DefaultHostResolver.release_static_default() + if os.getenv('AWS_CRT_MEMORY_TRACING') != '2': raise RuntimeError("environment variable AWS_CRT_MEMORY_TRACING=2 must be set for accurate leak checks") diff --git a/awscrt/auth.py b/awscrt/auth.py index a3f193045..c4262dc16 100644 --- a/awscrt/auth.py +++ b/awscrt/auth.py @@ -125,7 +125,7 @@ def __init__(self, binding): self._binding = binding @classmethod - def new_default_chain(cls, client_bootstrap): + def new_default_chain(cls, client_bootstrap=None): """ Create the default provider chain used by most AWS SDKs. @@ -137,12 +137,16 @@ def new_default_chain(cls, client_bootstrap): 4. (conditional, on by default) EC2 Instance Metadata Args: - client_bootstrap (ClientBootstrap): Client bootstrap to use when initiating socket connection. + client_bootstrap (Optional[ClientBootstrap]): Client bootstrap to use when initiating socket connection. + If not set, uses the default static ClientBootstrap instead. Returns: AwsCredentialsProvider: """ - assert isinstance(client_bootstrap, ClientBootstrap) + assert isinstance(client_bootstrap, ClientBootstrap) or client_bootstrap is None + + if client_bootstrap is None: + client_bootstrap = ClientBootstrap.get_or_create_static_default() binding = _awscrt.credentials_provider_new_chain_default(client_bootstrap) return cls(binding) @@ -170,7 +174,7 @@ def new_static(cls, access_key_id, secret_access_key, session_token=None): @classmethod def new_profile( cls, - client_bootstrap, + client_bootstrap=None, profile_name=None, config_filepath=None, credentials_filepath=None): @@ -179,7 +183,8 @@ def new_profile( loaded from the aws credentials file. Args: - client_bootstrap (ClientBootstrap): Client bootstrap to use when initiating socket connection. + client_bootstrap (Optional[ClientBootstrap]): Client bootstrap to use when initiating socket connection. + If not set, uses the static default ClientBootstrap instead. profile_name (Optional[str]): Name of profile to use. If not set, uses value from AWS_PROFILE environment variable. @@ -196,11 +201,14 @@ def new_profile( Returns: AwsCredentialsProvider: """ - assert isinstance(client_bootstrap, ClientBootstrap) + assert isinstance(client_bootstrap, ClientBootstrap) or client_bootstrap is None assert isinstance(profile_name, str) or profile_name is None assert isinstance(config_filepath, str) or config_filepath is None assert isinstance(credentials_filepath, str) or credentials_filepath is None + if client_bootstrap is None: + client_bootstrap = ClientBootstrap.get_or_create_static_default() + binding = _awscrt.credentials_provider_new_profile( client_bootstrap, profile_name, config_filepath, credentials_filepath) return cls(binding) diff --git a/awscrt/eventstream/rpc.py b/awscrt/eventstream/rpc.py index 29974b694..6c80268ac 100644 --- a/awscrt/eventstream/rpc.py +++ b/awscrt/eventstream/rpc.py @@ -258,7 +258,7 @@ def connect( handler: ClientConnectionHandler, host_name: str, port: int, - bootstrap: ClientBootstrap, + bootstrap: ClientBootstrap = None, socket_options: Optional[SocketOptions] = None, tls_connection_options: Optional[TlsConnectionOptions] = None) -> Future: """Asynchronously establish a new ClientConnection. @@ -271,6 +271,7 @@ def connect( port: Connect to port. bootstrap: Client bootstrap to use when initiating socket connection. + If None is provided, the default singleton is used. socket_options: Optional socket options. If None is provided, then default options are used. @@ -297,6 +298,9 @@ def connect( # Connection is not made available to user until setup callback fires connection = cls(host_name, port, handler) + if not bootstrap: + bootstrap = ClientBootstrap.get_or_create_static_default() + # connection._binding is set within the following call */ _awscrt.event_stream_rpc_client_connection_connect( host_name, diff --git a/awscrt/http.py b/awscrt/http.py index 5ffe9a6f3..258cdaaa8 100644 --- a/awscrt/http.py +++ b/awscrt/http.py @@ -11,7 +11,7 @@ from concurrent.futures import Future from awscrt import NativeResource import awscrt.exceptions -from awscrt.io import ClientBootstrap, EventLoopGroup, DefaultHostResolver, InputStream, TlsConnectionOptions, SocketOptions +from awscrt.io import ClientBootstrap, InputStream, TlsConnectionOptions, SocketOptions from enum import IntEnum @@ -82,7 +82,7 @@ class HttpClientConnection(HttpConnectionBase): def new(cls, host_name, port, - bootstrap, + bootstrap=None, socket_options=None, tls_connection_options=None, proxy_options=None): @@ -94,7 +94,8 @@ def new(cls, port (int): Connect to port. - bootstrap (ClientBootstrap): Client bootstrap to use when initiating socket connection. + bootstrap (Optional [ClientBootstrap]): Client bootstrap to use when initiating socket connection. + If None is provided, the default singleton is used. socket_options (Optional[SocketOptions]): Optional socket options. If None is provided, then default options are used. @@ -124,9 +125,7 @@ def new(cls, socket_options = SocketOptions() if not bootstrap: - event_loop_group = EventLoopGroup(1) - host_resolver = DefaultHostResolver(event_loop_group) - bootstrap = ClientBootstrap(event_loop_group, host_resolver) + bootstrap = ClientBootstrap.get_or_create_static_default() connection = cls() connection._host_name = host_name diff --git a/awscrt/io.py b/awscrt/io.py index 616e1ada9..c3e1f1916 100644 --- a/awscrt/io.py +++ b/awscrt/io.py @@ -60,6 +60,8 @@ class EventLoopGroup(NativeResource): EventLoopGroup object is destroyed. """ + _static_event_loop_group = None + _static_event_loop_group_lock = threading.Lock() __slots__ = ('shutdown_event') def __init__(self, num_threads=None, cpu_group=None): @@ -83,6 +85,18 @@ def on_shutdown(): self.shutdown_event = shutdown_event self._binding = _awscrt.event_loop_group_new(num_threads, is_pinned, cpu_group, on_shutdown) + @staticmethod + def get_or_create_static_default(): + with EventLoopGroup._static_event_loop_group_lock: + if EventLoopGroup._static_event_loop_group is None: + EventLoopGroup._static_event_loop_group = EventLoopGroup() + return EventLoopGroup._static_event_loop_group + + @staticmethod + def release_static_default(): + with EventLoopGroup._static_event_loop_group_lock: + EventLoopGroup._static_event_loop_group = None + class HostResolverBase(NativeResource): """DNS host resolver.""" @@ -96,6 +110,9 @@ class DefaultHostResolver(HostResolverBase): event_loop_group (EventLoopGroup): EventLoopGroup to use. max_hosts(int): Max host names to cache. """ + + _static_host_resolver = None + _static_host_resolver_lock = threading.Lock() __slots__ = () def __init__(self, event_loop_group, max_hosts=16): @@ -104,6 +121,19 @@ def __init__(self, event_loop_group, max_hosts=16): super().__init__() self._binding = _awscrt.host_resolver_new_default(max_hosts, event_loop_group) + @staticmethod + def get_or_create_static_default(): + with DefaultHostResolver._static_host_resolver_lock: + if DefaultHostResolver._static_host_resolver is None: + DefaultHostResolver._static_host_resolver = DefaultHostResolver( + EventLoopGroup.get_or_create_static_default()) + return DefaultHostResolver._static_host_resolver + + @staticmethod + def release_static_default(): + with DefaultHostResolver._static_host_resolver_lock: + DefaultHostResolver._static_host_resolver = None + class ClientBootstrap(NativeResource): """Handles creation and setup of client socket connections. @@ -117,6 +147,9 @@ class ClientBootstrap(NativeResource): internal resources finish shutting down. Shutdown begins when the ClientBootstrap object is destroyed. """ + + _static_client_bootstrap = None + _static_client_bootstrap_lock = threading.Lock() __slots__ = ('shutdown_event') def __init__(self, event_loop_group, host_resolver): @@ -133,6 +166,20 @@ def on_shutdown(): self.shutdown_event = shutdown_event self._binding = _awscrt.client_bootstrap_new(event_loop_group, host_resolver, on_shutdown) + @staticmethod + def get_or_create_static_default(): + with ClientBootstrap._static_client_bootstrap_lock: + if ClientBootstrap._static_client_bootstrap is None: + ClientBootstrap._static_client_bootstrap = ClientBootstrap( + EventLoopGroup.get_or_create_static_default(), + DefaultHostResolver.get_or_create_static_default()) + return ClientBootstrap._static_client_bootstrap + + @staticmethod + def release_static_default(): + with ClientBootstrap._static_client_bootstrap_lock: + ClientBootstrap._static_client_bootstrap = None + def _read_binary_file(filepath): with open(filepath, mode='rb') as fh: diff --git a/awscrt/mqtt.py b/awscrt/mqtt.py index 39c9b0b96..17b5a5e25 100644 --- a/awscrt/mqtt.py +++ b/awscrt/mqtt.py @@ -130,7 +130,8 @@ class Client(NativeResource): """MQTT client. Args: - bootstrap (ClientBootstrap): Client bootstrap to use when initiating new socket connections. + bootstrap (Optional [ClientBootstrap]): Client bootstrap to use when initiating new socket connections. + If None is provided, the default singleton is used. tls_ctx (Optional[ClientTlsContext]): TLS context for secure socket connections. If None is provided, then an unencrypted connection is used. @@ -138,12 +139,14 @@ class Client(NativeResource): __slots__ = ('tls_ctx') - def __init__(self, bootstrap, tls_ctx=None): - assert isinstance(bootstrap, ClientBootstrap) + def __init__(self, bootstrap=None, tls_ctx=None): + assert isinstance(bootstrap, ClientBootstrap) or bootstrap is None assert tls_ctx is None or isinstance(tls_ctx, ClientTlsContext) super().__init__() self.tls_ctx = tls_ctx + if not bootstrap: + bootstrap = ClientBootstrap.get_or_create_static_default() self._binding = _awscrt.mqtt_client_new(bootstrap, tls_ctx) @@ -435,6 +438,7 @@ def on_disconnect(): try: _awscrt.mqtt_client_connection_disconnect(self._binding, on_disconnect) + except Exception as e: future.set_exception(e) diff --git a/awscrt/s3.py b/awscrt/s3.py index e13459851..73c008ba4 100644 --- a/awscrt/s3.py +++ b/awscrt/s3.py @@ -54,7 +54,8 @@ class S3Client(NativeResource): """S3 client Keyword Args: - bootstrap (ClientBootstrap): Client bootstrap to use when initiating socket connection. + bootstrap (Optional [ClientBootstrap]): Client bootstrap to use when initiating socket connection. + If None is provided, the default singleton is used. region (str): Region that the S3 bucket lives in. @@ -94,7 +95,7 @@ def __init__( tls_connection_options=None, part_size=None, throughput_target_gbps=None): - assert isinstance(bootstrap, ClientBootstrap) + assert isinstance(bootstrap, ClientBootstrap) or bootstrap is None assert isinstance(region, str) assert isinstance(credential_provider, AwsCredentialsProvider) or credential_provider is None assert isinstance(tls_connection_options, TlsConnectionOptions) or tls_connection_options is None @@ -113,6 +114,9 @@ def on_shutdown(): shutdown_event.set() self._region = region self.shutdown_event = shutdown_event + + if not bootstrap: + bootstrap = ClientBootstrap.get_or_create_static_default() s3_client_core = _S3ClientCore(bootstrap, credential_provider, tls_connection_options) # C layer uses 0 to indicate defaults diff --git a/test/test_io.py b/test/test_io.py index a967c0d26..a5219a591 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -25,12 +25,35 @@ def test_shutdown_complete(self): del event_loop_group self.assertTrue(shutdown_event.wait(TIMEOUT)) + def test_init_defaults_singleton(self): + event_loop_group = EventLoopGroup.get_or_create_static_default() + + def test_init_defaults_singleton_is_singleton(self): + event_loop_group_one = EventLoopGroup.get_or_create_static_default() + event_loop_group_two = EventLoopGroup.get_or_create_static_default() + self.assertTrue(event_loop_group_one == event_loop_group_two) + + def test_shutdown_complete_singleton(self): + event_loop_group = EventLoopGroup.get_or_create_static_default() + shutdown_event = event_loop_group.shutdown_event + del event_loop_group + EventLoopGroup.release_static_default() + self.assertTrue(shutdown_event.wait(TIMEOUT)) + class DefaultHostResolverTest(NativeResourceTest): def test_init(self): event_loop_group = EventLoopGroup() host_resolver = DefaultHostResolver(event_loop_group) + def test_init_singleton(self): + host_resolver = DefaultHostResolver.get_or_create_static_default() + + def test_init_singleton_is_singleton(self): + host_resolver_one = DefaultHostResolver.get_or_create_static_default() + host_resolver_two = DefaultHostResolver.get_or_create_static_default() + self.assertTrue(host_resolver_one == host_resolver_two) + class ClientBootstrapTest(NativeResourceTest): def test_create_destroy(self): @@ -43,6 +66,20 @@ def test_create_destroy(self): del bootstrap self.assertTrue(bootstrap_shutdown_event.wait(TIMEOUT)) + def test_create_destroy_singleton(self): + bootstrap = ClientBootstrap.get_or_create_static_default() + + # ensure shutdown_event fires + bootstrap_shutdown_event = bootstrap.shutdown_event + del bootstrap + ClientBootstrap.release_static_default() + self.assertTrue(bootstrap_shutdown_event.wait(TIMEOUT)) + + def test_init_singleton_is_singleton(self): + client_one = ClientBootstrap.get_or_create_static_default() + client_two = ClientBootstrap.get_or_create_static_default() + self.assertTrue(client_one == client_two) + class ClientTlsContextTest(NativeResourceTest): def test_init_defaults(self): diff --git a/test/test_mqtt.py b/test/test_mqtt.py index 25a49d74c..566f90512 100644 --- a/test/test_mqtt.py +++ b/test/test_mqtt.py @@ -57,11 +57,8 @@ class MqttConnectionTest(NativeResourceTest): TEST_TOPIC = '/test/me/senpai' TEST_MSG = 'NOTICE ME!'.encode('utf8') - def _create_connection(self, auth_type=AuthType.CERT_AND_KEY): + def _create_connection(self, auth_type=AuthType.CERT_AND_KEY, use_static_singletons=False): config = Config(auth_type) - elg = EventLoopGroup() - resolver = DefaultHostResolver(elg) - bootstrap = ClientBootstrap(elg, resolver) if auth_type == AuthType.CERT_AND_KEY: tls_opts = TlsContextOptions.create_client_with_mtls_from_path(config.cert_path, config.key_path) @@ -89,7 +86,14 @@ def _create_connection(self, auth_type=AuthType.CERT_AND_KEY): # re-raise exception raise - client = Client(bootstrap, tls) + if use_static_singletons: + client = Client(tls_ctx=tls) + else: + elg = EventLoopGroup() + resolver = DefaultHostResolver(elg) + bootstrap = ClientBootstrap(elg, resolver) + client = Client(bootstrap, tls) + connection = Connection( client=client, client_id=create_client_id(), @@ -212,6 +216,16 @@ def on_sub_message(topic, payload): # disconnect connection.disconnect().result(TIMEOUT) + def test_connect_disconnect_with_default_singletons(self): + connection = self._create_connection(use_static_singletons=True) + connection.connect().result(TIMEOUT) + connection.disconnect().result(TIMEOUT) + + # free singletons + ClientBootstrap.release_static_default() + EventLoopGroup.release_static_default() + DefaultHostResolver.release_static_default() + if __name__ == 'main': unittest.main()