Skip to content

Commit

Permalink
Callback kwargs (#107)
Browse files Browse the repository at this point in the history
* prep for GA, all callbacks now take kwargs.
  • Loading branch information
JonathanHenson authored Nov 28, 2019
1 parent 8a35dc7 commit eb3f0b8
Show file tree
Hide file tree
Showing 16 changed files with 181 additions and 89 deletions.
2 changes: 1 addition & 1 deletion aws-c-auth
2 changes: 1 addition & 1 deletion aws-c-mqtt
50 changes: 40 additions & 10 deletions awscrt/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,21 @@ class AwsSigningAlgorithm(IntEnum):
SigV4QueryParam = 1


class AwsBodySigningConfigType(IntEnum):
"""Body Signing config
BodySigningOff: No attempts will be made to sign the payload, and no
x-amz-content-sha256 header will be added to the request.
BodySigningOn: The body will be signed and x-amz-content-sha256 will contain
the value of the signature
UnsignedPayload: The body will not be signed, but x-amz-content-sha256 will contain
the value UNSIGNED-PAYLOAD. This value is currently only used for Amazon S3.
"""
BodySigningOff = 0
BodySigningOn = 1
UnsignedPayload = 2


class AwsSigningConfig(NativeResource):
"""
Configuration for use in AWS-related signing.
Expand All @@ -179,10 +194,10 @@ class AwsSigningConfig(NativeResource):
It is good practice to use a new config for each signature, or the date might get too old.
Naive dates (lacking timezone info) are assumed to be in local time.
"""
__slots__ = ()
__slots__ = ('_priv_should_sign_cb')

_attributes = ('algorithm', 'credentials_provider', 'region', 'service', 'date', 'should_sign_param',
'use_double_uri_encode', 'should_normalize_uri_path', 'sign_body')
'use_double_uri_encode', 'should_normalize_uri_path', 'body_signing_type')

def __init__(self,
algorithm, # type: AwsSigningAlgorithm
Expand All @@ -193,7 +208,7 @@ def __init__(self,
should_sign_param=None, # type: Optional[Callable[[str], bool]]
use_double_uri_encode=False, # type: bool
should_normalize_uri_path=True, # type: bool
sign_body=True # type: bool
body_signing_type=AwsBodySigningConfigType.BodySigningOn # type: AwsBodySigningConfigType
):
# type: (...) -> None

Expand All @@ -203,6 +218,7 @@ def __init__(self,
assert isinstance_str(service)
assert isinstance(date, datetime.datetime) or date is None
assert callable(should_sign_param) or should_sign_param is None
assert isinstance(body_signing_type, AwsBodySigningConfigType)

super(AwsSigningConfig, self).__init__()

Expand All @@ -220,17 +236,25 @@ def __init__(self,
epoch = datetime.datetime(1970, 1, 1, tzinfo=_utc)
timestamp = (date - epoch).total_seconds()

self._priv_should_sign_cb = should_sign_param

if should_sign_param is not None:
def should_sign_param_wrapper(name):
return should_sign_param(name=name)
else:
should_sign_param_wrapper = None

self._binding = _awscrt.signing_config_new(
algorithm,
credentials_provider,
region,
service,
date,
timestamp,
should_sign_param,
should_sign_param_wrapper,
use_double_uri_encode,
should_normalize_uri_path,
sign_body)
body_signing_type)

def replace(self, **kwargs):
"""
Expand Down Expand Up @@ -279,7 +303,7 @@ def should_sign_param(self):
supplements it. In particular, a header will get signed if and only if it returns true to both
the internal check (skips x-amzn-trace-id, user-agent) and this function (if defined).
"""
return _awscrt.signing_config_get_should_sign_param(self._binding)
return self._priv_should_sign_cb

@property
def use_double_uri_encode(self):
Expand All @@ -296,12 +320,18 @@ def should_normalize_uri_path(self):
return _awscrt.signing_config_get_should_normalize_uri_path(self._binding)

@property
def sign_body(self):
def body_signing_type(self):
"""
If true adds the x-amz-content-sha256 header (with appropriate value) to the canonical request,
otherwise does nothing
BodySigningOff: No attempts will be made to sign the payload, and no
x-amz-content-sha256 header will be added to the request.
BodySigningOn: The body will be signed and x-amz-content-sha256 will contain
the value of the signature
UnsignedPayload: The body will not be signed, but x-amz-content-sha256 will contain
the value UNSIGNED-PAYLOAD. This value is currently only used for Amazon S3.
"""
return _awscrt.signing_config_get_sign_body(self._binding)
return AwsBodySigningConfigType(_awscrt.signing_config_get_body_signing_type(self._binding))


def aws_sign_request(http_request, signing_config):
Expand Down
5 changes: 3 additions & 2 deletions awscrt/awsiot_mqtt_connection_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,9 @@ def websockets_with_default_aws_signing(region, credentials_provider, websocket_
"""
_check_required_kwargs(**kwargs)

def _should_sign_param(name):
def _should_sign_param(**kwargs):
blacklist = ['x-amz-date', 'x-amz-security-token']
name = kwargs['name']
return not (name.lower() in blacklist)

def _sign_websocket_handshake_request(handshake_args):
Expand All @@ -250,7 +251,7 @@ def _sign_websocket_handshake_request(handshake_args):
region=region,
service='iotdevicegateway',
should_sign_param=_should_sign_param,
sign_body=False)
body_signing_type=awscrt.auth.AwsBodySigningConfigType.BodySigningOff)

signing_future = awscrt.auth.aws_sign_request(handshake_args.http_request, signing_config)
signing_future.add_done_callback(lambda x: handshake_args.set_done(x.exception()))
Expand Down
4 changes: 2 additions & 2 deletions awscrt/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def completion_future(self):

def _on_body(self, chunk):
if self._on_body_cb:
self._on_body_cb(self, chunk)
self._on_body_cb(http_stream=self, chunk=chunk)


class HttpClientStream(HttpStreamBase):
Expand All @@ -184,7 +184,7 @@ def _on_response(self, status_code, name_value_pairs):
self._response_status_code = status_code

if self._on_response_cb:
self._on_response_cb(self, status_code, name_value_pairs)
self._on_response_cb(http_stream=self, status_code=status_code, headers=name_value_pairs)

def _on_complete(self, error_code):
if error_code == 0:
Expand Down
60 changes: 46 additions & 14 deletions awscrt/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,21 @@ def __init__(self,
If an existing session is resumed, the server remembers previous subscriptions
and sends mesages (with QoS1 or higher) that were published while the client was offline.
on_connection_interrupted (function): Optional callback with signature:
(Connection, awscrt.exceptions.AwsCrtError) -> None
Invoked when the MQTT connection is lost.
on_connection_interrupted (function): Optional callback invoked whenever the MQTT connection is lost.
The MQTT client will automatically attempt to reconnect.
on_connection_resumed (function): Optional callback with signature:
(Connection, ConnectReturnCode, session_present: bool) -> None
Invoked when the MQTT connection is automatically resumed.
The function should take **kwargs and return nothing.
The kwargs contain:
'connection': This MQTT Connection
'error': awscrt.exceptions.AwsCrtError
on_connection_resumed (function): Optional callback invoked whenever the MQTT connection
is automatically resumed. Function should take **kwargs and return nothing.
The kwargs contain:
'connection': This MQTT Connection
'return_code': ConnectReturnCode received from the server.
'session_present': True if resuming existing session. False if new session.
Note that the server has forgotten all previous subscriptions if this is False.
Subscriptions can be re-established via resubscribe_existing_topics().
reconnect_min_timeout_secs (int): Minimum time to wait between reconnect attempts.
Wait starts at min and doubles with each attempt until max is reached.
Expand Down Expand Up @@ -195,11 +202,14 @@ def __init__(self,

def _on_connection_interrupted(self, error_code):
if self._on_connection_interrupted_cb:
self._on_connection_interrupted_cb(self, awscrt.exceptions.from_code(error_code))
self._on_connection_interrupted_cb(connection=self, error=awscrt.exceptions.from_code(error_code))

def _on_connection_resumed(self, return_code, session_present):
if self._on_connection_resumed_cb:
self._on_connection_resumed_cb(self, ConnectReturnCode(return_code), session_present)
self._on_connection_resumed_cb(
connection=self,
error=connectionConnectReturnCode(return_code),
session_present=session_present)

def _ws_handshake_transform(self, http_request_binding, http_headers_binding, native_userdata):
if self._ws_handshake_transform_cb is None:
Expand All @@ -214,7 +224,7 @@ def _on_complete(f):
http_request = HttpRequest._from_bindings(http_request_binding, http_headers_binding)
transform_args = WebsocketHandshakeTransformArgs(self, http_request, future)
try:
self._ws_handshake_transform_cb(transform_args)
self._ws_handshake_transform_cb(transform_args=transform_args)
except Exception as e:
# Call set_done() if user failed to do so before uncaught exception was raised,
# there's a chance the callback wasn't callable and user has no idea we tried to hand them the baton.
Expand Down Expand Up @@ -289,12 +299,22 @@ def on_disconnect():

def subscribe(self, topic, qos, callback=None):
"""
callback: optional callback with signature (topic, message)
callback: Optional callback invoked when message received.
Function should take **kwargs and return nothing.
The kwargs contain:
'topic' (str): Topic receiving message.
'payload' (bytes): Payload of message.
"""

future = Future()
packet_id = 0

if callback:
def callback_wrapper(topic, payload):
callback(topic=topic, payload=payload)
else:
callback_wrapper = None

def suback(packet_id, topic, qos, error_code):
if error_code:
future.set_exception(awscrt.exceptions.from_code(error_code))
Expand All @@ -312,18 +332,30 @@ def suback(packet_id, topic, qos, error_code):
try:
assert callable(callback) or callback is None
assert isinstance(qos, QoS)
packet_id = _awscrt.mqtt_client_connection_subscribe(self._binding, topic, qos.value, callback, suback)
packet_id = _awscrt.mqtt_client_connection_subscribe(
self._binding, topic, qos.value, callback_wrapper, suback)
except Exception as e:
future.set_exception(e)

return future, packet_id

def on_message(self, callback):
"""
callback: callback with signature (topic, message), or None to disable.
callback: Callback invoked when message received, or None to disable.
Function should take **kwargs and return nothing.
The kwargs contain:
'topic' (str): Topic receiving message.
'payload' (bytes): Payload of message.
"""
assert callable(callback) or callback is None
_awscrt.mqtt_client_connection_on_message(self._binding, callback)

if callback:
def callback_wrapper(topic, payload):
callback(topic=topic, payload=payload)
else:
callback_wrapper = None

_awscrt.mqtt_client_connection_on_message(self._binding, callback_wrapper)

def unsubscribe(self, topic):
future = Future()
Expand Down
64 changes: 53 additions & 11 deletions elasticurl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,63 @@ def print_header_list(headers):


parser = argparse.ArgumentParser()
parser.add_argument('url', help='URL to make request to. HTTPS is assumed unless port 80 is specified or HTTP is specified in the scheme.')
parser.add_argument(
'url',
help='URL to make request to. HTTPS is assumed unless port 80 is specified or HTTP is specified in the scheme.')
parser.add_argument('--cacert', required=False, help='FILE: path to a CA certificate file.')
parser.add_argument('--capath', required=False, help='PATH: path to a directory containing CA files.')
parser.add_argument('--cert', required=False, help='FILE: path to a PEM encoded certificate to use with mTLS')
parser.add_argument('--key', required=False, help='FILE: Path to a PEM encoded private key that matches cert.')
parser.add_argument('--connect_timeout', required=False, type=int, help='INT: time in milliseconds to wait for a connection.', default=3000)
parser.add_argument('-H', '--header', required=False, help='STRING: line to send as a header in format "name:value". May be specified multiple times.', action='append')
parser.add_argument(
'--connect_timeout',
required=False,
type=int,
help='INT: time in milliseconds to wait for a connection.',
default=3000)
parser.add_argument(
'-H',
'--header',
required=False,
help='STRING: line to send as a header in format "name:value". May be specified multiple times.',
action='append')
parser.add_argument('-d', '--data', required=False, help='STRING: Data to POST or PUT.')
parser.add_argument('--data_file', required=False, help='FILE: File to read from file and POST or PUT')
parser.add_argument('-M', '--method', required=False, help='STRING: Http Method verb to use for the request', default='GET')
parser.add_argument(
'-M',
'--method',
required=False,
help='STRING: Http Method verb to use for the request',
default='GET')
parser.add_argument('-G', '--get', required=False, help='uses GET for the verb', action='store_true')
parser.add_argument('-P', '--post', required=False, help='uses POST for the verb', action='store_true')
parser.add_argument('-I', '--head', required=False, help='uses HEAD for the verb', action='store_true')
parser.add_argument('-i', '--include', required=False, help='Includes headers in output', action='store_true', default=False)
parser.add_argument('-k', '--insecure', required=False, help='Turns off x.509 validation', action='store_true', default=False)
parser.add_argument(
'-i',
'--include',
required=False,
help='Includes headers in output',
action='store_true',
default=False)
parser.add_argument(
'-k',
'--insecure',
required=False,
help='Turns off x.509 validation',
action='store_true',
default=False)
parser.add_argument('-o', '--output', required=False, help='FILE: dumps content-body to FILE instead of stdout.')
parser.add_argument('-t', '--trace', required=False, help='FILE: dumps logs to FILE instead of stderr.')
parser.add_argument('-p', '--alpn', required=False, help='STRING: protocol for ALPN. May be specified multiple times.', action='append')
parser.add_argument('-v', '--verbose', required=False, help='ERROR|INFO|DEBUG|TRACE: log level to configure. Default is none.')
parser.add_argument(
'-p',
'--alpn',
required=False,
help='STRING: protocol for ALPN. May be specified multiple times.',
action='append')
parser.add_argument(
'-v',
'--verbose',
required=False,
help='ERROR|INFO|DEBUG|TRACE: log level to configure. Default is none.')

args = parser.parse_args()

Expand Down Expand Up @@ -85,7 +123,7 @@ def print_header_list(headers):

host_resolver = io.DefaultHostResolver(event_loop_group)

# client bootstrap knows how to connect all the pieces. In this case it also has the default dns resolver
# client bootstrap knows how to connect all the pieces. In this case it also has the default dns resolver
# baked in.
client_bootstrap = io.ClientBootstrap(event_loop_group, host_resolver)

Expand Down Expand Up @@ -126,13 +164,15 @@ def print_header_list(headers):
tls_connection_options.set_alpn_list(args.alpn)

# invoked up on the connection closing


def on_connection_shutdown(shutdown_future):
print('connection close with error: {}'.format(shutdown_future.exception()))


# invoked by the http request call as the response body is received in chunks
def on_incoming_body(http_stream, body_data):
output.write(body_data)
def on_incoming_body(http_stream, chunk):
output.write(chunk)


data_len = 0
Expand Down Expand Up @@ -190,6 +230,8 @@ def on_incoming_body(http_stream, body_data):
request.headers.add(name.strip(), value.strip())

# invoked as soon as the response headers are received


def response_received_cb(http_stream, status_code, headers):
if args.include:
print('Response Code: {}'.format(status_code))
Expand Down
6 changes: 3 additions & 3 deletions mqtt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def on_resubscribe_complete(resubscribe_future):

receive_results = {}
receive_event = threading.Event()
def on_receive_message(topic, message):
receive_results.update(locals())
def on_receive_message(**kwargs):
receive_results.update(kwargs)
receive_event.set()

# Run
Expand Down Expand Up @@ -133,7 +133,7 @@ def on_receive_message(topic, message):
print("Waiting to receive messsage")
assert(receive_event.wait(TIMEOUT))
assert(receive_results['topic'] == TOPIC)
assert(receive_results['message'].decode() == MESSAGE)
assert(receive_results['payload'].decode() == MESSAGE)

# Unsubscribe
print("Unsubscribing from topic")
Expand Down
Loading

0 comments on commit eb3f0b8

Please sign in to comment.