Skip to content

Commit

Permalink
Connection builder simplification (#334)
Browse files Browse the repository at this point in the history
These changes make it where it's not required to create and manage
a DefaultHostResolver, EventLoopGroup, or ClientBootstrap to make
a connection. Instead, default singletons have been added to handle
this use case.

Commit log
* First pass at builder simplification for Python V2 SDK
* Removed left over test code created when testing builder simplification, adjusted code to fit autopep8
* Additional changes to fix formatting in io.py after first pass at builder simplification. Also surpressed the warnings on lgtm about deleting the local variables, as it is needed to delete the global singletons correctly
* Fixed more Pep8 style issues with io.py after first pass at builder simplification
* Removed freeing singletons manually, as they are automatically freed when garbage collected
* Updated MQTT builder simplification after review
* Missed a super minor code format issue and didn't think to check. Fixed it
* Updated MQTT singleton test to free memory at the end and removed unused imports
* Make passing a ClientBootstrap optional in MqttClient, HttpClientConnection, and rpc.ClientConnect
* Fixed passing wrong parameter in test, removed unused import
* Code review changes: Adjusted lock use and cleaned up code
* Fixed minor mistakes in last commit
* Fixed another minor mistake in test
  • Loading branch information
TwistedTwigleg authored Mar 11, 2022
1 parent cc65145 commit fe9266a
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 23 deletions.
6 changes: 6 additions & 0 deletions awscrt/_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import time
import types

from awscrt.io import ClientBootstrap, DefaultHostResolver, EventLoopGroup


def native_memory_usage() -> int:
"""
Expand Down Expand Up @@ -80,6 +82,10 @@ def check_for_leaks(*, timeout_sec=10.0):
if the test results will be made public as it may result in secrets
being leaked.
"""
ClientBootstrap.release_static_default()
EventLoopGroup.release_static_default()
DefaultHostResolver.release_static_default()

if os.getenv('AWS_CRT_MEMORY_TRACING') != '2':
raise RuntimeError("environment variable AWS_CRT_MEMORY_TRACING=2 must be set for accurate leak checks")

Expand Down
20 changes: 14 additions & 6 deletions awscrt/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(self, binding):
self._binding = binding

@classmethod
def new_default_chain(cls, client_bootstrap):
def new_default_chain(cls, client_bootstrap=None):
"""
Create the default provider chain used by most AWS SDKs.
Expand All @@ -137,12 +137,16 @@ def new_default_chain(cls, client_bootstrap):
4. (conditional, on by default) EC2 Instance Metadata
Args:
client_bootstrap (ClientBootstrap): Client bootstrap to use when initiating socket connection.
client_bootstrap (Optional[ClientBootstrap]): Client bootstrap to use when initiating socket connection.
If not set, uses the default static ClientBootstrap instead.
Returns:
AwsCredentialsProvider:
"""
assert isinstance(client_bootstrap, ClientBootstrap)
assert isinstance(client_bootstrap, ClientBootstrap) or client_bootstrap is None

if client_bootstrap is None:
client_bootstrap = ClientBootstrap.get_or_create_static_default()

binding = _awscrt.credentials_provider_new_chain_default(client_bootstrap)
return cls(binding)
Expand Down Expand Up @@ -170,7 +174,7 @@ def new_static(cls, access_key_id, secret_access_key, session_token=None):
@classmethod
def new_profile(
cls,
client_bootstrap,
client_bootstrap=None,
profile_name=None,
config_filepath=None,
credentials_filepath=None):
Expand All @@ -179,7 +183,8 @@ def new_profile(
loaded from the aws credentials file.
Args:
client_bootstrap (ClientBootstrap): Client bootstrap to use when initiating socket connection.
client_bootstrap (Optional[ClientBootstrap]): Client bootstrap to use when initiating socket connection.
If not set, uses the static default ClientBootstrap instead.
profile_name (Optional[str]): Name of profile to use.
If not set, uses value from AWS_PROFILE environment variable.
Expand All @@ -196,11 +201,14 @@ def new_profile(
Returns:
AwsCredentialsProvider:
"""
assert isinstance(client_bootstrap, ClientBootstrap)
assert isinstance(client_bootstrap, ClientBootstrap) or client_bootstrap is None
assert isinstance(profile_name, str) or profile_name is None
assert isinstance(config_filepath, str) or config_filepath is None
assert isinstance(credentials_filepath, str) or credentials_filepath is None

if client_bootstrap is None:
client_bootstrap = ClientBootstrap.get_or_create_static_default()

binding = _awscrt.credentials_provider_new_profile(
client_bootstrap, profile_name, config_filepath, credentials_filepath)
return cls(binding)
Expand Down
6 changes: 5 additions & 1 deletion awscrt/eventstream/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def connect(
handler: ClientConnectionHandler,
host_name: str,
port: int,
bootstrap: ClientBootstrap,
bootstrap: ClientBootstrap = None,
socket_options: Optional[SocketOptions] = None,
tls_connection_options: Optional[TlsConnectionOptions] = None) -> Future:
"""Asynchronously establish a new ClientConnection.
Expand All @@ -271,6 +271,7 @@ def connect(
port: Connect to port.
bootstrap: Client bootstrap to use when initiating socket connection.
If None is provided, the default singleton is used.
socket_options: Optional socket options.
If None is provided, then default options are used.
Expand All @@ -297,6 +298,9 @@ def connect(
# Connection is not made available to user until setup callback fires
connection = cls(host_name, port, handler)

if not bootstrap:
bootstrap = ClientBootstrap.get_or_create_static_default()

# connection._binding is set within the following call */
_awscrt.event_stream_rpc_client_connection_connect(
host_name,
Expand Down
11 changes: 5 additions & 6 deletions awscrt/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from concurrent.futures import Future
from awscrt import NativeResource
import awscrt.exceptions
from awscrt.io import ClientBootstrap, EventLoopGroup, DefaultHostResolver, InputStream, TlsConnectionOptions, SocketOptions
from awscrt.io import ClientBootstrap, InputStream, TlsConnectionOptions, SocketOptions
from enum import IntEnum


Expand Down Expand Up @@ -82,7 +82,7 @@ class HttpClientConnection(HttpConnectionBase):
def new(cls,
host_name,
port,
bootstrap,
bootstrap=None,
socket_options=None,
tls_connection_options=None,
proxy_options=None):
Expand All @@ -94,7 +94,8 @@ def new(cls,
port (int): Connect to port.
bootstrap (ClientBootstrap): Client bootstrap to use when initiating socket connection.
bootstrap (Optional [ClientBootstrap]): Client bootstrap to use when initiating socket connection.
If None is provided, the default singleton is used.
socket_options (Optional[SocketOptions]): Optional socket options.
If None is provided, then default options are used.
Expand Down Expand Up @@ -124,9 +125,7 @@ def new(cls,
socket_options = SocketOptions()

if not bootstrap:
event_loop_group = EventLoopGroup(1)
host_resolver = DefaultHostResolver(event_loop_group)
bootstrap = ClientBootstrap(event_loop_group, host_resolver)
bootstrap = ClientBootstrap.get_or_create_static_default()

connection = cls()
connection._host_name = host_name
Expand Down
47 changes: 47 additions & 0 deletions awscrt/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class EventLoopGroup(NativeResource):
EventLoopGroup object is destroyed.
"""

_static_event_loop_group = None
_static_event_loop_group_lock = threading.Lock()
__slots__ = ('shutdown_event')

def __init__(self, num_threads=None, cpu_group=None):
Expand All @@ -83,6 +85,18 @@ def on_shutdown():
self.shutdown_event = shutdown_event
self._binding = _awscrt.event_loop_group_new(num_threads, is_pinned, cpu_group, on_shutdown)

@staticmethod
def get_or_create_static_default():
with EventLoopGroup._static_event_loop_group_lock:
if EventLoopGroup._static_event_loop_group is None:
EventLoopGroup._static_event_loop_group = EventLoopGroup()
return EventLoopGroup._static_event_loop_group

@staticmethod
def release_static_default():
with EventLoopGroup._static_event_loop_group_lock:
EventLoopGroup._static_event_loop_group = None


class HostResolverBase(NativeResource):
"""DNS host resolver."""
Expand All @@ -96,6 +110,9 @@ class DefaultHostResolver(HostResolverBase):
event_loop_group (EventLoopGroup): EventLoopGroup to use.
max_hosts(int): Max host names to cache.
"""

_static_host_resolver = None
_static_host_resolver_lock = threading.Lock()
__slots__ = ()

def __init__(self, event_loop_group, max_hosts=16):
Expand All @@ -104,6 +121,19 @@ def __init__(self, event_loop_group, max_hosts=16):
super().__init__()
self._binding = _awscrt.host_resolver_new_default(max_hosts, event_loop_group)

@staticmethod
def get_or_create_static_default():
with DefaultHostResolver._static_host_resolver_lock:
if DefaultHostResolver._static_host_resolver is None:
DefaultHostResolver._static_host_resolver = DefaultHostResolver(
EventLoopGroup.get_or_create_static_default())
return DefaultHostResolver._static_host_resolver

@staticmethod
def release_static_default():
with DefaultHostResolver._static_host_resolver_lock:
DefaultHostResolver._static_host_resolver = None


class ClientBootstrap(NativeResource):
"""Handles creation and setup of client socket connections.
Expand All @@ -117,6 +147,9 @@ class ClientBootstrap(NativeResource):
internal resources finish shutting down.
Shutdown begins when the ClientBootstrap object is destroyed.
"""

_static_client_bootstrap = None
_static_client_bootstrap_lock = threading.Lock()
__slots__ = ('shutdown_event')

def __init__(self, event_loop_group, host_resolver):
Expand All @@ -133,6 +166,20 @@ def on_shutdown():
self.shutdown_event = shutdown_event
self._binding = _awscrt.client_bootstrap_new(event_loop_group, host_resolver, on_shutdown)

@staticmethod
def get_or_create_static_default():
with ClientBootstrap._static_client_bootstrap_lock:
if ClientBootstrap._static_client_bootstrap is None:
ClientBootstrap._static_client_bootstrap = ClientBootstrap(
EventLoopGroup.get_or_create_static_default(),
DefaultHostResolver.get_or_create_static_default())
return ClientBootstrap._static_client_bootstrap

@staticmethod
def release_static_default():
with ClientBootstrap._static_client_bootstrap_lock:
ClientBootstrap._static_client_bootstrap = None


def _read_binary_file(filepath):
with open(filepath, mode='rb') as fh:
Expand Down
10 changes: 7 additions & 3 deletions awscrt/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,23 @@ class Client(NativeResource):
"""MQTT client.
Args:
bootstrap (ClientBootstrap): Client bootstrap to use when initiating new socket connections.
bootstrap (Optional [ClientBootstrap]): Client bootstrap to use when initiating new socket connections.
If None is provided, the default singleton is used.
tls_ctx (Optional[ClientTlsContext]): TLS context for secure socket connections.
If None is provided, then an unencrypted connection is used.
"""

__slots__ = ('tls_ctx')

def __init__(self, bootstrap, tls_ctx=None):
assert isinstance(bootstrap, ClientBootstrap)
def __init__(self, bootstrap=None, tls_ctx=None):
assert isinstance(bootstrap, ClientBootstrap) or bootstrap is None
assert tls_ctx is None or isinstance(tls_ctx, ClientTlsContext)

super().__init__()
self.tls_ctx = tls_ctx
if not bootstrap:
bootstrap = ClientBootstrap.get_or_create_static_default()
self._binding = _awscrt.mqtt_client_new(bootstrap, tls_ctx)


Expand Down Expand Up @@ -435,6 +438,7 @@ def on_disconnect():

try:
_awscrt.mqtt_client_connection_disconnect(self._binding, on_disconnect)

except Exception as e:
future.set_exception(e)

Expand Down
8 changes: 6 additions & 2 deletions awscrt/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class S3Client(NativeResource):
"""S3 client
Keyword Args:
bootstrap (ClientBootstrap): Client bootstrap to use when initiating socket connection.
bootstrap (Optional [ClientBootstrap]): Client bootstrap to use when initiating socket connection.
If None is provided, the default singleton is used.
region (str): Region that the S3 bucket lives in.
Expand Down Expand Up @@ -94,7 +95,7 @@ def __init__(
tls_connection_options=None,
part_size=None,
throughput_target_gbps=None):
assert isinstance(bootstrap, ClientBootstrap)
assert isinstance(bootstrap, ClientBootstrap) or bootstrap is None
assert isinstance(region, str)
assert isinstance(credential_provider, AwsCredentialsProvider) or credential_provider is None
assert isinstance(tls_connection_options, TlsConnectionOptions) or tls_connection_options is None
Expand All @@ -113,6 +114,9 @@ def on_shutdown():
shutdown_event.set()
self._region = region
self.shutdown_event = shutdown_event

if not bootstrap:
bootstrap = ClientBootstrap.get_or_create_static_default()
s3_client_core = _S3ClientCore(bootstrap, credential_provider, tls_connection_options)

# C layer uses 0 to indicate defaults
Expand Down
37 changes: 37 additions & 0 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,35 @@ def test_shutdown_complete(self):
del event_loop_group
self.assertTrue(shutdown_event.wait(TIMEOUT))

def test_init_defaults_singleton(self):
event_loop_group = EventLoopGroup.get_or_create_static_default()

def test_init_defaults_singleton_is_singleton(self):
event_loop_group_one = EventLoopGroup.get_or_create_static_default()
event_loop_group_two = EventLoopGroup.get_or_create_static_default()
self.assertTrue(event_loop_group_one == event_loop_group_two)

def test_shutdown_complete_singleton(self):
event_loop_group = EventLoopGroup.get_or_create_static_default()
shutdown_event = event_loop_group.shutdown_event
del event_loop_group
EventLoopGroup.release_static_default()
self.assertTrue(shutdown_event.wait(TIMEOUT))


class DefaultHostResolverTest(NativeResourceTest):
def test_init(self):
event_loop_group = EventLoopGroup()
host_resolver = DefaultHostResolver(event_loop_group)

def test_init_singleton(self):
host_resolver = DefaultHostResolver.get_or_create_static_default()

def test_init_singleton_is_singleton(self):
host_resolver_one = DefaultHostResolver.get_or_create_static_default()
host_resolver_two = DefaultHostResolver.get_or_create_static_default()
self.assertTrue(host_resolver_one == host_resolver_two)


class ClientBootstrapTest(NativeResourceTest):
def test_create_destroy(self):
Expand All @@ -43,6 +66,20 @@ def test_create_destroy(self):
del bootstrap
self.assertTrue(bootstrap_shutdown_event.wait(TIMEOUT))

def test_create_destroy_singleton(self):
bootstrap = ClientBootstrap.get_or_create_static_default()

# ensure shutdown_event fires
bootstrap_shutdown_event = bootstrap.shutdown_event
del bootstrap
ClientBootstrap.release_static_default()
self.assertTrue(bootstrap_shutdown_event.wait(TIMEOUT))

def test_init_singleton_is_singleton(self):
client_one = ClientBootstrap.get_or_create_static_default()
client_two = ClientBootstrap.get_or_create_static_default()
self.assertTrue(client_one == client_two)


class ClientTlsContextTest(NativeResourceTest):
def test_init_defaults(self):
Expand Down
Loading

0 comments on commit fe9266a

Please sign in to comment.