Skip to content

Commit

Permalink
on_message (any publish) + MQTT tests (#94)
Browse files Browse the repository at this point in the history
* Updated CRT libs

* Added boto3 as a depedency for testing

* Fixed self-referencing in MQTT connection

* Added support for null sub callbacks
  • Loading branch information
Justin Boswell authored Nov 19, 2019
1 parent a34bb20 commit c10c6d2
Show file tree
Hide file tree
Showing 11 changed files with 208 additions and 33 deletions.
32 changes: 17 additions & 15 deletions awscrt/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, bootstrap, tls_ctx=None):


class Connection(NativeResource):
__slots__ = ('client', '_on_connection_interrupted_cb', '_on_connection_resumed_cb')
__slots__ = ('client')

def __init__(self,
client,
Expand All @@ -85,23 +85,21 @@ def __init__(self,

super(Connection, self).__init__()
self.client = client
self._on_connection_interrupted_cb = on_connection_interrupted
self._on_connection_resumed_cb = on_connection_resumed

def _on_connection_interrupted(error_code):
if on_connection_interrupted:
on_connection_interrupted(self, error_code)

def _on_connection_resumed(error_code, session_present):
if on_connection_resumed:
on_connection_resumed(self, error_code, session_present)

self._binding = _awscrt.mqtt_client_connection_new(
client,
self._on_connection_interrupted,
self._on_connection_resumed,
_on_connection_interrupted,
_on_connection_resumed,
)

def _on_connection_interrupted(self, error_code):
if self._on_connection_interrupted_cb:
self._on_connection_interrupted_cb(self, error_code)

def _on_connection_resumed(self, error_code, session_present):
if self._on_connection_resumed_cb:
self._on_connection_resumed_cb(self, error_code, session_present)

def connect(self,
client_id,
host_name, port,
Expand Down Expand Up @@ -184,7 +182,7 @@ def on_disconnect():

return future

def subscribe(self, topic, qos, callback):
def subscribe(self, topic, qos, callback=None):
"""
callback: callback with signature (topic, message)
"""
Expand All @@ -207,14 +205,18 @@ def suback(packet_id, topic, qos, error_code):
))

try:
assert callable(callback)
assert callable(callback) or callback is None
assert isinstance(qos, QoS)
packet_id = _awscrt.mqtt_client_connection_subscribe(self._binding, topic, qos.value, callback, suback)
except Exception as e:
future.set_exception(e)

return future, packet_id

def on_message(self, callback):
assert callable(callback)
_awscrt.mqtt_client_connection_on_message(self._binding, callback)

def unsubscribe(self, topic):
future = Future()
packet_id = 0
Expand Down
7 changes: 6 additions & 1 deletion builder.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,27 @@
},
"post_build_steps": [
["echo", "------ Python 3.6 ------"],
["/opt/python/cp36-cp36m/bin/python", "-m", "pip", "install", "--upgrade", "--trusted-host", "pypi.org", "--trusted-host", "files.pythonhosted.org", "pip", "setuptools", "boto3", "autopep8"],
["/opt/python/cp36-cp36m/bin/python", "setup.py", "--verbose", "build_ext", "--include-dirs{openssl_include}", "--library-dirs{openssl_lib}", "install"],
["/opt/python/cp36-cp36m/bin/python", "-m", "unittest", "discover", "--verbose"],
["/opt/python/cp37-cp37m/bin/python", "aws-c-http/integration-testing/http_client_test.py", "/opt/python/cp36-cp36m/bin/python", "elasticurl.py"],
["echo", "------ Python 3.5 ------"],
["/opt/python/cp35-cp35m/bin/python", "-m", "pip", "install", "--upgrade", "--trusted-host", "pypi.org", "--trusted-host", "files.pythonhosted.org", "pip", "setuptools", "boto3", "autopep8"],
["/opt/python/cp35-cp35m/bin/python", "setup.py", "--verbose", "build_ext", "--include-dirs{openssl_include}", "--library-dirs{openssl_lib}", "install"],
["/opt/python/cp35-cp35m/bin/python", "-m", "unittest", "discover", "--verbose"],
["/opt/python/cp37-cp37m/bin/python", "aws-c-http/integration-testing/http_client_test.py", "/opt/python/cp35-cp35m/bin/python", "elasticurl.py"],
["echo", "------ Python 3.4 ------"],
["/opt/python/cp34-cp34m/bin/python", "-m", "pip", "install", "--upgrade", "--trusted-host", "pypi.org", "--trusted-host", "files.pythonhosted.org", "pip", "setuptools", "boto3", "autopep8"],
["/opt/python/cp34-cp34m/bin/python", "setup.py", "--verbose", "build_ext", "--include-dirs{openssl_include}", "--library-dirs{openssl_lib}", "install"],
["/opt/python/cp34-cp34m/bin/python", "-m", "unittest", "discover", "--verbose"],
["/opt/python/cp37-cp37m/bin/python", "aws-c-http/integration-testing/http_client_test.py", "/opt/python/cp34-cp34m/bin/python", "elasticurl.py"],
["echo", "------ Python 2.7 narrow-unicode ------"],
["/opt/python/cp27-cp27m/bin/python", "-m", "pip", "install", "--upgrade", "--trusted-host", "pypi.org", "--trusted-host", "files.pythonhosted.org", "pip", "setuptools", "boto3", "autopep8"],
["/opt/python/cp27-cp27m/bin/python", "setup.py", "--verbose", "build_ext", "--include-dirs{openssl_include}", "--library-dirs{openssl_lib}", "install"],
["/opt/python/cp27-cp27m/bin/python", "-m", "unittest", "discover", "--verbose"],
["/opt/python/cp37-cp37m/bin/python", "aws-c-http/integration-testing/http_client_test.py", "/opt/python/cp27-cp27m/bin/python", "elasticurl.py"],
["echo", "------ Python 2.7 wide-unicode ------"],
["/opt/python/cp27-cp27mu/bin/python", "-m", "pip", "install", "--upgrade", "--trusted-host", "pypi.org", "--trusted-host", "files.pythonhosted.org", "pip", "setuptools", "boto3", "autopep8"],
["/opt/python/cp27-cp27mu/bin/python", "setup.py", "--verbose", "build_ext", "--include-dirs{openssl_include}", "--library-dirs{openssl_lib}", "install"],
["/opt/python/cp27-cp27mu/bin/python", "-m", "unittest", "discover", "--verbose"],
["/opt/python/cp37-cp37m/bin/python", "aws-c-http/integration-testing/http_client_test.py", "/opt/python/cp27-cp27mu/bin/python", "elasticurl.py"]
Expand Down Expand Up @@ -83,9 +88,9 @@
["{python}", "setup.py", "--verbose", "build_ext", "--include-dirs{openssl_include}", "--library-dirs{openssl_lib}", "install"]
],
"test": [
["{python}", "-m", "pip", "install", "--upgrade", "--trusted-host", "pypi.org", "--trusted-host", "files.pythonhosted.org", "pip", "setuptools", "boto3", "autopep8"],
["{python}", "-m", "unittest", "discover", "--verbose"],
["{python}", "aws-c-http/integration-testing/http_client_test.py", "{python}", "elasticurl.py"],
["{python}", "-m", "pip", "install", "autopep8"],
["{python}", "-m", "autopep8", "--exit-code", "--diff", "--recursive", "awscrt", "test", "setup.py"]
]
}
2 changes: 1 addition & 1 deletion s2n
Submodule s2n updated from 4675f7 to f58bc0
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,4 +243,7 @@ def awscrt_ext():
ext_modules=[awscrt_ext()],
cmdclass={'build_ext': awscrt_build_ext},
test_suite='test',
tests_require=[
'boto3'
],
)
4 changes: 3 additions & 1 deletion source/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ PyObject *aws_py_init_logging(PyObject *self, PyObject *args) {
(void)self;

if (s_logger_init) {
Py_RETURN_NONE;
aws_logger_set(NULL);
aws_logger_clean_up(&s_logger);
}

s_logger_init = true;
Expand Down Expand Up @@ -264,6 +265,7 @@ static PyMethodDef s_module_methods[] = {
AWS_PY_METHOD_DEF(mqtt_client_connection_reconnect, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt_client_connection_publish, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt_client_connection_subscribe, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt_client_connection_on_message, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt_client_connection_resubscribe_existing_topics, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt_client_connection_unsubscribe, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt_client_connection_disconnect, METH_VARARGS),
Expand Down
54 changes: 42 additions & 12 deletions source/mqtt_client_connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ struct mqtt_connection_binding {
static void s_mqtt_python_connection_finish_destruction(struct mqtt_connection_binding *py_connection) {
aws_mqtt_client_connection_destroy(py_connection->native);

Py_DECREF(py_connection->on_connection_interrupted);
Py_DECREF(py_connection->on_connection_resumed);
Py_XDECREF(py_connection->on_connection_interrupted);
Py_XDECREF(py_connection->on_connection_resumed);
Py_DECREF(py_connection->client);

aws_mem_release(aws_py_get_allocator(), py_connection);
Expand Down Expand Up @@ -183,10 +183,8 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) {

/* From hereon, nothing will fail */

py_connection->on_connection_interrupted = on_connection_interrupted;
Py_INCREF(py_connection->on_connection_interrupted);
py_connection->on_connection_resumed = on_connection_resumed;
Py_INCREF(py_connection->on_connection_resumed);
py_connection->on_connection_interrupted = PyWeakref_NewProxy(on_connection_interrupted, NULL);
py_connection->on_connection_resumed = PyWeakref_NewProxy(on_connection_resumed, NULL);
py_connection->client = client_py;
Py_INCREF(py_connection->client);

Expand Down Expand Up @@ -603,9 +601,12 @@ static void s_subscribe_callback(

(void)connection;

PyGILState_STATE state = PyGILState_Ensure();

PyObject *callback = user_data;
if (!callback) {
return;
}

PyGILState_STATE state = PyGILState_Ensure();

PyObject *result = PyObject_CallFunction(
callback,
Expand Down Expand Up @@ -643,6 +644,10 @@ static void s_suback_callback(
(void)connection;

PyObject *callback = userdata;
if (!callback) {
return;
}

PyGILState_STATE state = PyGILState_Ensure();

const char *topic_str = (const char *)topic->ptr;
Expand Down Expand Up @@ -678,8 +683,9 @@ PyObject *aws_py_mqtt_client_connection_subscribe(PyObject *self, PyObject *args
return NULL;
}

Py_INCREF(callback);
Py_INCREF(suback_callback);
Py_XINCREF(callback);
Py_XINCREF(suback_callback);

struct aws_byte_cursor topic_filter = aws_byte_cursor_from_array(topic, topic_len);
uint16_t msg_id = aws_mqtt_client_connection_subscribe(
py_connection->native,
Expand All @@ -692,14 +698,38 @@ PyObject *aws_py_mqtt_client_connection_subscribe(PyObject *self, PyObject *args
suback_callback);

if (msg_id == 0) {
Py_DECREF(callback);
Py_DECREF(suback_callback);
Py_XDECREF(callback);
Py_XDECREF(suback_callback);
return PyErr_AwsLastError();
}

return PyLong_FromUnsignedLong(msg_id);
}

PyObject *aws_py_mqtt_client_connection_on_message(PyObject *self, PyObject *args) {
(void)self;

PyObject *impl_capsule;
PyObject *callback;
if (!PyArg_ParseTuple(args, "OO", &impl_capsule, &callback)) {
return NULL;
}

struct mqtt_connection_binding *py_connection =
PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection);
if (!py_connection) {
return NULL;
}

callback = PyWeakref_NewProxy(callback, NULL);
if (aws_mqtt_client_connection_set_on_any_publish_handler(py_connection->native, s_subscribe_callback, callback)) {
Py_DECREF(callback);
return PyErr_AwsLastError();
}

Py_RETURN_NONE;
}

/*******************************************************************************
* Unsubscribe
******************************************************************************/
Expand Down
1 change: 1 addition & 0 deletions source/mqtt_client_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ PyObject *aws_py_mqtt_client_connection_connect(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt_client_connection_reconnect(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt_client_connection_publish(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt_client_connection_subscribe(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt_client_connection_on_message(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt_client_connection_unsubscribe(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt_client_connection_resubscribe_existing_topics(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt_client_connection_disconnect(PyObject *self, PyObject *args);
Expand Down
132 changes: 132 additions & 0 deletions test/test_mqtt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://aws.amazon.com/apache2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from __future__ import absolute_import
from awscrt.io import ClientBootstrap, ClientTlsContext, DefaultHostResolver, EventLoopGroup, TlsConnectionOptions, TlsContextOptions, LogLevel, init_logging
from awscrt.mqtt import Client, Connection, QoS
from test import NativeResourceTest
from concurrent.futures import Future
import os
import unittest
import boto3
import time
import warnings


class MqttClientTest(NativeResourceTest):
def test_lifetime(self):
client = Client(ClientBootstrap(EventLoopGroup()))


class Config:
cache = None

def __init__(self, endpoint, cert, key):
try:
self.cert = cert
self.key = key
self.endpoint = endpoint
self.valid = True
except BaseException:
self.valid = False

@staticmethod
def get():
if Config.cache:
return Config.cache

# boto3 caches the HTTPS connection for the API calls, which appears to the unit test
# framework as a leak, so ignore it, that's not what we're testing here
warnings.simplefilter('ignore', ResourceWarning)

secrets = boto3.client('secretsmanager')
response = secrets.get_secret_value(SecretId='unit-test/endpoint')
endpoint = response['SecretString']
response = secrets.get_secret_value(SecretId='unit-test/certificate')
cert = bytes(response['SecretString'], 'utf8')
response = secrets.get_secret_value(SecretId='unit-test/privatekey')
key = bytes(response['SecretString'], 'utf8')
Config.cache = Config(endpoint, cert, key)
return Config.cache


class MqttConnectionTest(NativeResourceTest):
TEST_TOPIC = '/test/me/senpai'
TEST_MSG = 'NOTICE ME!'

def _test_connection(self):
try:
config = Config.get()
except Exception as ex:
return self.skipTest("No credentials")

try:
tls_opts = TlsContextOptions.create_client_with_mtls(config.cert, config.key)
tls = ClientTlsContext(tls_opts)
client = Client(ClientBootstrap(EventLoopGroup()), tls)
connection = Connection(client)
connection.connect('aws-crt-python-unit-test-'.format(time.gmtime()), config.endpoint, 8883).result()
return connection
except Exception as ex:
self.assertFalse(ex)

def test_connect_disconnect(self):
connection = self._test_connection()
connection.disconnect().result()

def test_pub_sub(self):
connection = self._test_connection()
disconnected = Future()

def on_disconnect(result):
disconnected.set_result(True)

def on_message(topic, payload):
self.assertEqual(self.TEST_TOPIC, topic)
self.assertEqual(self.TEST_MSG, str(payload, 'utf8'))
connection.unsubscribe(self.TEST_TOPIC)
connection.disconnect().add_done_callback(on_disconnect)

def do_publish(result):
connection.publish(self.TEST_TOPIC, bytes(self.TEST_MSG, 'utf8'), QoS.AT_LEAST_ONCE)

subscribed, packet_id = connection.subscribe(self.TEST_TOPIC, QoS.AT_LEAST_ONCE, on_message)
subscribed.add_done_callback(do_publish)

disconnected.result()

def test_on_message(self):
connection = self._test_connection()
disconnected = Future()

def on_disconnect(result):
disconnected.set_result(True)

def on_message(topic, payload):
self.assertEqual(self.TEST_TOPIC, topic)
self.assertEqual(self.TEST_MSG, str(payload, 'utf8'))
connection.unsubscribe(self.TEST_TOPIC)
connection.disconnect().add_done_callback(on_disconnect)

def do_publish(result):
connection.publish(self.TEST_TOPIC, bytes(self.TEST_MSG, 'utf8'), QoS.AT_LEAST_ONCE)

connection.on_message(on_message)
subscribed, packet_id = connection.subscribe(self.TEST_TOPIC, QoS.AT_LEAST_ONCE)
subscribed.add_done_callback(do_publish)

disconnected.result()


if __name__ == 'main':
unittest.main()

0 comments on commit c10c6d2

Please sign in to comment.