diff --git a/awscrt/_test.py b/awscrt/_test.py index 06016f955..9c4eab0d6 100644 --- a/awscrt/_test.py +++ b/awscrt/_test.py @@ -10,6 +10,8 @@ import time import types +from awscrt.io import ClientBootstrap, DefaultHostResolver, EventLoopGroup + def native_memory_usage() -> int: """ @@ -80,6 +82,10 @@ def check_for_leaks(*, timeout_sec=10.0): if the test results will be made public as it may result in secrets being leaked. """ + ClientBootstrap.release_static_default() + EventLoopGroup.release_static_default() + DefaultHostResolver.release_static_default() + if os.getenv('AWS_CRT_MEMORY_TRACING') != '2': raise RuntimeError("environment variable AWS_CRT_MEMORY_TRACING=2 must be set for accurate leak checks") diff --git a/awscrt/auth.py b/awscrt/auth.py index a3f193045..c4262dc16 100644 --- a/awscrt/auth.py +++ b/awscrt/auth.py @@ -125,7 +125,7 @@ def __init__(self, binding): self._binding = binding @classmethod - def new_default_chain(cls, client_bootstrap): + def new_default_chain(cls, client_bootstrap=None): """ Create the default provider chain used by most AWS SDKs. @@ -137,12 +137,16 @@ def new_default_chain(cls, client_bootstrap): 4. (conditional, on by default) EC2 Instance Metadata Args: - client_bootstrap (ClientBootstrap): Client bootstrap to use when initiating socket connection. + client_bootstrap (Optional[ClientBootstrap]): Client bootstrap to use when initiating socket connection. + If not set, uses the default static ClientBootstrap instead. Returns: AwsCredentialsProvider: """ - assert isinstance(client_bootstrap, ClientBootstrap) + assert isinstance(client_bootstrap, ClientBootstrap) or client_bootstrap is None + + if client_bootstrap is None: + client_bootstrap = ClientBootstrap.get_or_create_static_default() binding = _awscrt.credentials_provider_new_chain_default(client_bootstrap) return cls(binding) @@ -170,7 +174,7 @@ def new_static(cls, access_key_id, secret_access_key, session_token=None): @classmethod def new_profile( cls, - client_bootstrap, + client_bootstrap=None, profile_name=None, config_filepath=None, credentials_filepath=None): @@ -179,7 +183,8 @@ def new_profile( loaded from the aws credentials file. Args: - client_bootstrap (ClientBootstrap): Client bootstrap to use when initiating socket connection. + client_bootstrap (Optional[ClientBootstrap]): Client bootstrap to use when initiating socket connection. + If not set, uses the static default ClientBootstrap instead. profile_name (Optional[str]): Name of profile to use. If not set, uses value from AWS_PROFILE environment variable. @@ -196,11 +201,14 @@ def new_profile( Returns: AwsCredentialsProvider: """ - assert isinstance(client_bootstrap, ClientBootstrap) + assert isinstance(client_bootstrap, ClientBootstrap) or client_bootstrap is None assert isinstance(profile_name, str) or profile_name is None assert isinstance(config_filepath, str) or config_filepath is None assert isinstance(credentials_filepath, str) or credentials_filepath is None + if client_bootstrap is None: + client_bootstrap = ClientBootstrap.get_or_create_static_default() + binding = _awscrt.credentials_provider_new_profile( client_bootstrap, profile_name, config_filepath, credentials_filepath) return cls(binding) diff --git a/awscrt/eventstream/rpc.py b/awscrt/eventstream/rpc.py index 29974b694..6c80268ac 100644 --- a/awscrt/eventstream/rpc.py +++ b/awscrt/eventstream/rpc.py @@ -258,7 +258,7 @@ def connect( handler: ClientConnectionHandler, host_name: str, port: int, - bootstrap: ClientBootstrap, + bootstrap: ClientBootstrap = None, socket_options: Optional[SocketOptions] = None, tls_connection_options: Optional[TlsConnectionOptions] = None) -> Future: """Asynchronously establish a new ClientConnection. @@ -271,6 +271,7 @@ def connect( port: Connect to port. bootstrap: Client bootstrap to use when initiating socket connection. + If None is provided, the default singleton is used. socket_options: Optional socket options. If None is provided, then default options are used. @@ -297,6 +298,9 @@ def connect( # Connection is not made available to user until setup callback fires connection = cls(host_name, port, handler) + if not bootstrap: + bootstrap = ClientBootstrap.get_or_create_static_default() + # connection._binding is set within the following call */ _awscrt.event_stream_rpc_client_connection_connect( host_name, diff --git a/awscrt/http.py b/awscrt/http.py index 5ffe9a6f3..258cdaaa8 100644 --- a/awscrt/http.py +++ b/awscrt/http.py @@ -11,7 +11,7 @@ from concurrent.futures import Future from awscrt import NativeResource import awscrt.exceptions -from awscrt.io import ClientBootstrap, EventLoopGroup, DefaultHostResolver, InputStream, TlsConnectionOptions, SocketOptions +from awscrt.io import ClientBootstrap, InputStream, TlsConnectionOptions, SocketOptions from enum import IntEnum @@ -82,7 +82,7 @@ class HttpClientConnection(HttpConnectionBase): def new(cls, host_name, port, - bootstrap, + bootstrap=None, socket_options=None, tls_connection_options=None, proxy_options=None): @@ -94,7 +94,8 @@ def new(cls, port (int): Connect to port. - bootstrap (ClientBootstrap): Client bootstrap to use when initiating socket connection. + bootstrap (Optional [ClientBootstrap]): Client bootstrap to use when initiating socket connection. + If None is provided, the default singleton is used. socket_options (Optional[SocketOptions]): Optional socket options. If None is provided, then default options are used. @@ -124,9 +125,7 @@ def new(cls, socket_options = SocketOptions() if not bootstrap: - event_loop_group = EventLoopGroup(1) - host_resolver = DefaultHostResolver(event_loop_group) - bootstrap = ClientBootstrap(event_loop_group, host_resolver) + bootstrap = ClientBootstrap.get_or_create_static_default() connection = cls() connection._host_name = host_name diff --git a/awscrt/io.py b/awscrt/io.py index 616e1ada9..c3e1f1916 100644 --- a/awscrt/io.py +++ b/awscrt/io.py @@ -60,6 +60,8 @@ class EventLoopGroup(NativeResource): EventLoopGroup object is destroyed. """ + _static_event_loop_group = None + _static_event_loop_group_lock = threading.Lock() __slots__ = ('shutdown_event') def __init__(self, num_threads=None, cpu_group=None): @@ -83,6 +85,18 @@ def on_shutdown(): self.shutdown_event = shutdown_event self._binding = _awscrt.event_loop_group_new(num_threads, is_pinned, cpu_group, on_shutdown) + @staticmethod + def get_or_create_static_default(): + with EventLoopGroup._static_event_loop_group_lock: + if EventLoopGroup._static_event_loop_group is None: + EventLoopGroup._static_event_loop_group = EventLoopGroup() + return EventLoopGroup._static_event_loop_group + + @staticmethod + def release_static_default(): + with EventLoopGroup._static_event_loop_group_lock: + EventLoopGroup._static_event_loop_group = None + class HostResolverBase(NativeResource): """DNS host resolver.""" @@ -96,6 +110,9 @@ class DefaultHostResolver(HostResolverBase): event_loop_group (EventLoopGroup): EventLoopGroup to use. max_hosts(int): Max host names to cache. """ + + _static_host_resolver = None + _static_host_resolver_lock = threading.Lock() __slots__ = () def __init__(self, event_loop_group, max_hosts=16): @@ -104,6 +121,19 @@ def __init__(self, event_loop_group, max_hosts=16): super().__init__() self._binding = _awscrt.host_resolver_new_default(max_hosts, event_loop_group) + @staticmethod + def get_or_create_static_default(): + with DefaultHostResolver._static_host_resolver_lock: + if DefaultHostResolver._static_host_resolver is None: + DefaultHostResolver._static_host_resolver = DefaultHostResolver( + EventLoopGroup.get_or_create_static_default()) + return DefaultHostResolver._static_host_resolver + + @staticmethod + def release_static_default(): + with DefaultHostResolver._static_host_resolver_lock: + DefaultHostResolver._static_host_resolver = None + class ClientBootstrap(NativeResource): """Handles creation and setup of client socket connections. @@ -117,6 +147,9 @@ class ClientBootstrap(NativeResource): internal resources finish shutting down. Shutdown begins when the ClientBootstrap object is destroyed. """ + + _static_client_bootstrap = None + _static_client_bootstrap_lock = threading.Lock() __slots__ = ('shutdown_event') def __init__(self, event_loop_group, host_resolver): @@ -133,6 +166,20 @@ def on_shutdown(): self.shutdown_event = shutdown_event self._binding = _awscrt.client_bootstrap_new(event_loop_group, host_resolver, on_shutdown) + @staticmethod + def get_or_create_static_default(): + with ClientBootstrap._static_client_bootstrap_lock: + if ClientBootstrap._static_client_bootstrap is None: + ClientBootstrap._static_client_bootstrap = ClientBootstrap( + EventLoopGroup.get_or_create_static_default(), + DefaultHostResolver.get_or_create_static_default()) + return ClientBootstrap._static_client_bootstrap + + @staticmethod + def release_static_default(): + with ClientBootstrap._static_client_bootstrap_lock: + ClientBootstrap._static_client_bootstrap = None + def _read_binary_file(filepath): with open(filepath, mode='rb') as fh: diff --git a/awscrt/mqtt.py b/awscrt/mqtt.py index 39c9b0b96..17b5a5e25 100644 --- a/awscrt/mqtt.py +++ b/awscrt/mqtt.py @@ -130,7 +130,8 @@ class Client(NativeResource): """MQTT client. Args: - bootstrap (ClientBootstrap): Client bootstrap to use when initiating new socket connections. + bootstrap (Optional [ClientBootstrap]): Client bootstrap to use when initiating new socket connections. + If None is provided, the default singleton is used. tls_ctx (Optional[ClientTlsContext]): TLS context for secure socket connections. If None is provided, then an unencrypted connection is used. @@ -138,12 +139,14 @@ class Client(NativeResource): __slots__ = ('tls_ctx') - def __init__(self, bootstrap, tls_ctx=None): - assert isinstance(bootstrap, ClientBootstrap) + def __init__(self, bootstrap=None, tls_ctx=None): + assert isinstance(bootstrap, ClientBootstrap) or bootstrap is None assert tls_ctx is None or isinstance(tls_ctx, ClientTlsContext) super().__init__() self.tls_ctx = tls_ctx + if not bootstrap: + bootstrap = ClientBootstrap.get_or_create_static_default() self._binding = _awscrt.mqtt_client_new(bootstrap, tls_ctx) @@ -435,6 +438,7 @@ def on_disconnect(): try: _awscrt.mqtt_client_connection_disconnect(self._binding, on_disconnect) + except Exception as e: future.set_exception(e) diff --git a/awscrt/s3.py b/awscrt/s3.py index e13459851..73c008ba4 100644 --- a/awscrt/s3.py +++ b/awscrt/s3.py @@ -54,7 +54,8 @@ class S3Client(NativeResource): """S3 client Keyword Args: - bootstrap (ClientBootstrap): Client bootstrap to use when initiating socket connection. + bootstrap (Optional [ClientBootstrap]): Client bootstrap to use when initiating socket connection. + If None is provided, the default singleton is used. region (str): Region that the S3 bucket lives in. @@ -94,7 +95,7 @@ def __init__( tls_connection_options=None, part_size=None, throughput_target_gbps=None): - assert isinstance(bootstrap, ClientBootstrap) + assert isinstance(bootstrap, ClientBootstrap) or bootstrap is None assert isinstance(region, str) assert isinstance(credential_provider, AwsCredentialsProvider) or credential_provider is None assert isinstance(tls_connection_options, TlsConnectionOptions) or tls_connection_options is None @@ -113,6 +114,9 @@ def on_shutdown(): shutdown_event.set() self._region = region self.shutdown_event = shutdown_event + + if not bootstrap: + bootstrap = ClientBootstrap.get_or_create_static_default() s3_client_core = _S3ClientCore(bootstrap, credential_provider, tls_connection_options) # C layer uses 0 to indicate defaults diff --git a/test/test_io.py b/test/test_io.py index a967c0d26..a5219a591 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -25,12 +25,35 @@ def test_shutdown_complete(self): del event_loop_group self.assertTrue(shutdown_event.wait(TIMEOUT)) + def test_init_defaults_singleton(self): + event_loop_group = EventLoopGroup.get_or_create_static_default() + + def test_init_defaults_singleton_is_singleton(self): + event_loop_group_one = EventLoopGroup.get_or_create_static_default() + event_loop_group_two = EventLoopGroup.get_or_create_static_default() + self.assertTrue(event_loop_group_one == event_loop_group_two) + + def test_shutdown_complete_singleton(self): + event_loop_group = EventLoopGroup.get_or_create_static_default() + shutdown_event = event_loop_group.shutdown_event + del event_loop_group + EventLoopGroup.release_static_default() + self.assertTrue(shutdown_event.wait(TIMEOUT)) + class DefaultHostResolverTest(NativeResourceTest): def test_init(self): event_loop_group = EventLoopGroup() host_resolver = DefaultHostResolver(event_loop_group) + def test_init_singleton(self): + host_resolver = DefaultHostResolver.get_or_create_static_default() + + def test_init_singleton_is_singleton(self): + host_resolver_one = DefaultHostResolver.get_or_create_static_default() + host_resolver_two = DefaultHostResolver.get_or_create_static_default() + self.assertTrue(host_resolver_one == host_resolver_two) + class ClientBootstrapTest(NativeResourceTest): def test_create_destroy(self): @@ -43,6 +66,20 @@ def test_create_destroy(self): del bootstrap self.assertTrue(bootstrap_shutdown_event.wait(TIMEOUT)) + def test_create_destroy_singleton(self): + bootstrap = ClientBootstrap.get_or_create_static_default() + + # ensure shutdown_event fires + bootstrap_shutdown_event = bootstrap.shutdown_event + del bootstrap + ClientBootstrap.release_static_default() + self.assertTrue(bootstrap_shutdown_event.wait(TIMEOUT)) + + def test_init_singleton_is_singleton(self): + client_one = ClientBootstrap.get_or_create_static_default() + client_two = ClientBootstrap.get_or_create_static_default() + self.assertTrue(client_one == client_two) + class ClientTlsContextTest(NativeResourceTest): def test_init_defaults(self): diff --git a/test/test_mqtt.py b/test/test_mqtt.py index 25a49d74c..566f90512 100644 --- a/test/test_mqtt.py +++ b/test/test_mqtt.py @@ -57,11 +57,8 @@ class MqttConnectionTest(NativeResourceTest): TEST_TOPIC = '/test/me/senpai' TEST_MSG = 'NOTICE ME!'.encode('utf8') - def _create_connection(self, auth_type=AuthType.CERT_AND_KEY): + def _create_connection(self, auth_type=AuthType.CERT_AND_KEY, use_static_singletons=False): config = Config(auth_type) - elg = EventLoopGroup() - resolver = DefaultHostResolver(elg) - bootstrap = ClientBootstrap(elg, resolver) if auth_type == AuthType.CERT_AND_KEY: tls_opts = TlsContextOptions.create_client_with_mtls_from_path(config.cert_path, config.key_path) @@ -89,7 +86,14 @@ def _create_connection(self, auth_type=AuthType.CERT_AND_KEY): # re-raise exception raise - client = Client(bootstrap, tls) + if use_static_singletons: + client = Client(tls_ctx=tls) + else: + elg = EventLoopGroup() + resolver = DefaultHostResolver(elg) + bootstrap = ClientBootstrap(elg, resolver) + client = Client(bootstrap, tls) + connection = Connection( client=client, client_id=create_client_id(), @@ -212,6 +216,16 @@ def on_sub_message(topic, payload): # disconnect connection.disconnect().result(TIMEOUT) + def test_connect_disconnect_with_default_singletons(self): + connection = self._create_connection(use_static_singletons=True) + connection.connect().result(TIMEOUT) + connection.disconnect().result(TIMEOUT) + + # free singletons + ClientBootstrap.release_static_default() + EventLoopGroup.release_static_default() + DefaultHostResolver.release_static_default() + if __name__ == 'main': unittest.main()