From 61cecbe8c9620fcf544309d322e726ea48b64ae5 Mon Sep 17 00:00:00 2001 From: Colden Cullen Date: Wed, 16 Jan 2019 12:35:12 -0800 Subject: [PATCH] Callbacks 4 Dayz (#29) - Compatibility with recent aws-c-mqtt API changes (interruption callbacks, etc) - Functions that had completion-callbacks now return Futures instead. - Use Python-3-style enums. - Update aws-c-common, aws-c-io, aws-c-mqtt, s2n dependencies - Scrubbed over mqtt_client_connection.c, being more careful about refcounts and error-handling. --- .gitignore | 1 + aws-c-common | 2 +- aws-c-io | 2 +- aws-c-mqtt | 2 +- aws_crt/io.py | 28 +- aws_crt/mqtt.py | 196 +++++++++--- codebuild/common-macos.sh | 1 - codebuild/linux-clang3-x64.yml | 3 +- codebuild/linux-clang6-x64.yml | 4 +- codebuild/linux-gcc-4x-x64.yml | 3 +- codebuild/linux-gcc-5x-x64.yml | 3 +- codebuild/linux-gcc-6x-x64.yml | 3 +- codebuild/linux-gcc-7x-x64.yml | 3 +- s2n | 2 +- setup.py | 13 +- source/module.c | 18 ++ source/module.h | 3 + source/mqtt_client_connection.c | 533 ++++++++++++++++++++++---------- source/mqtt_client_connection.h | 2 + test.py | 72 ++--- 20 files changed, 611 insertions(+), 283 deletions(-) diff --git a/.gitignore b/.gitignore index bcf6c5795..fbd1bd7a4 100644 --- a/.gitignore +++ b/.gitignore @@ -72,6 +72,7 @@ __pycache__/ # Distribution / packaging .Python build/ +deps_build/ develop-eggs/ dist/ downloads/ diff --git a/aws-c-common b/aws-c-common index a54b7f6e8..67262a9a4 160000 --- a/aws-c-common +++ b/aws-c-common @@ -1 +1 @@ -Subproject commit a54b7f6e8fdf539c252ace4b3fadef41eec00e3d +Subproject commit 67262a9a458585400187417d49c39c737507f671 diff --git a/aws-c-io b/aws-c-io index 532241470..3da314e54 160000 --- a/aws-c-io +++ b/aws-c-io @@ -1 +1 @@ -Subproject commit 532241470ed4a9f9bdede13b756c438ad4727ed7 +Subproject commit 3da314e5444614f784a6f6e206999f7255b41078 diff --git a/aws-c-mqtt b/aws-c-mqtt index 1b9d13a20..c12c5974f 160000 --- a/aws-c-mqtt +++ b/aws-c-mqtt @@ -1 +1 @@ -Subproject commit 1b9d13a205b2e3147afd8b9e365de4e5e04c2106 +Subproject commit c12c5974f6224c7b3031973eff00182242d0e622 diff --git a/aws_crt/io.py b/aws_crt/io.py index a551bfc40..0f3291611 100644 --- a/aws_crt/io.py +++ b/aws_crt/io.py @@ -12,18 +12,19 @@ # permissions and limitations under the License. import _aws_crt_python +from enum import IntEnum def is_alpn_available(): return _aws_crt_python.aws_py_is_alpn_available() class EventLoopGroup(object): - __slots__ = ['_internal_elg'] + __slots__ = ('_internal_elg') def __init__(self, num_threads): self._internal_elg = _aws_crt_python.aws_py_io_event_loop_group_new(num_threads) class ClientBootstrap(object): - __slots__ = ['elg', '_internal_bootstrap'] + __slots__ = ('elg', '_internal_bootstrap') def __init__(self, elg): assert isinstance(elg, EventLoopGroup) @@ -31,24 +32,23 @@ def __init__(self, elg): self.elg = elg self._internal_bootstrap = _aws_crt_python.aws_py_io_client_bootstrap_new(self.elg._internal_elg) -TlsVersion = type('TlsVersion', (), dict( - SSLv3 = 0, - TLSV1 = 1, - TLSV1_1 = 2, - TLSV1_2 = 3, - TLSV1_3 = 4, - Default = 128, -)) +class TlsVersion(IntEnum): + SSLv3 = 0 + TLSv1 = 1 + TLSv1_1 = 2 + TLSv1_2 = 3 + TLSv1_3 = 4 + DEFAULT = 128 class TlsContextOptions(object): - __slots__ = ['min_tls_ver', 'ca_file', 'ca_path', 'alpn_list', 'certificate_path', 'private_key_path', 'pkcs12_path', 'pkcs12_password', 'verify_peer'] + __slots__ = ('min_tls_ver', 'ca_file', 'ca_path', 'alpn_list', 'certificate_path', 'private_key_path', 'pkcs12_path', 'pkcs12_password', 'verify_peer') def __init__(self): for slot in self.__slots__: setattr(self, slot, None) - self.min_tls_ver = TlsVersion.Default + self.min_tls_ver = TlsVersion.DEFAULT def override_default_trust_store(self, ca_path, ca_file): @@ -107,14 +107,14 @@ def create_server_with_mtls_pkcs12(clazz, pkcs12_path, pkcs12_password): return opt class ClientTlsContext(object): - __slots__ = ['options', '_internal_tls_ctx'] + __slots__ = ('options', '_internal_tls_ctx') def __init__(self, options): assert isinstance(options, TlsContextOptions) self.options = options self._internal_tls_ctx = _aws_crt_python.aws_py_io_client_tls_ctx_new( - options.min_tls_ver, + options.min_tls_ver.value, options.ca_file, options.ca_path, options.alpn_list, diff --git a/aws_crt/mqtt.py b/aws_crt/mqtt.py index 692e08d18..9a03ce857 100644 --- a/aws_crt/mqtt.py +++ b/aws_crt/mqtt.py @@ -12,21 +12,26 @@ # permissions and limitations under the License. import _aws_crt_python +from concurrent.futures import Future +from enum import IntEnum from aws_crt.io import ClientBootstrap, ClientTlsContext -def _default_on_connect(return_code, session_present): - pass -def _default_on_disconnect(return_code): - return False +class QoS(IntEnum): + """Quality of Service""" + AT_MOST_ONCE = 0 + AT_LEAST_ONCE = 1 + EXACTLY_ONCE = 2 -QoS = type('QoS', (), dict( - AtMostOnce = 0, - AtLeastOnce = 1, - ExactlyOnce = 2, -)) +class ConnectReturnCode(IntEnum): + ACCEPTED = 0 + UNACCEPTABLE_PROTOCOL_VERSION = 1 + IDENTIFIER_REJECTED = 2 + SERVER_UNAVAILABLE = 3 + BAD_USERNAME_OR_PASSWORD = 4 + NOT_AUTHORIZED = 5 class Will(object): - __slots__ = ['topic', 'qos', 'payload', 'retain'] + __slots__ = ('topic', 'qos', 'payload', 'retain') def __init__(self, topic, qos, payload, retain): self.topic = topic @@ -35,7 +40,7 @@ def __init__(self, topic, qos, payload, retain): self.retain = retain class Client(object): - __slots__ = ['_internal_client', 'bootstrap', 'tls_ctx'] + __slots__ = ('_internal_client', 'bootstrap', 'tls_ctx') def __init__(self, bootstrap, tls_ctx = None): assert isinstance(bootstrap, ClientBootstrap) @@ -46,51 +51,154 @@ def __init__(self, bootstrap, tls_ctx = None): self._internal_client = _aws_crt_python.aws_py_mqtt_client_new(self.bootstrap._internal_bootstrap) class Connection(object): - __slots__ = ['_internal_connection'] + __slots__ = ('_internal_connection', 'client') + + def __init__(self, + client, + on_connection_interrupted=None, + on_connection_resumed=None, + reconnect_min_timeout_sec=5.0, + reconnect_max_timeout_sec=60.0): + """ + on_connection_interrupted: optional callback, with signature (error_code) + on_connection_resumed: optional callback, with signature (error_code, session_present) + """ - def __init__(self, client, client_id, + assert isinstance(client, Client) + self.client = client + + self._internal_connection = _aws_crt_python.aws_py_mqtt_client_connection_new( + client._internal_client, + on_connection_interrupted, + on_connection_resumed, + ) + + def connect(self, + client_id, host_name, port, - on_connect=_default_on_connect, - on_disconnect=_default_on_disconnect, - use_websocket=False, alpn=None, + use_websocket=False, clean_session=True, keep_alive=0, will=None, - username=None, password=None): + username=None, password=None, + connect_timeout_sec=5.0): + + future = Future() + + def on_connect(error_code, return_code, session_present): + if error_code == 0 and return_code == 0: + future.set_result(dict(session_present=session_present)) + else: + future.set_exception(Exception("Error during connect.")) + + try: + assert will is None or isinstance(will, Will) + assert use_websocket == False + + tls_ctx_cap = None + if self.client.tls_ctx: + tls_ctx_cap = self.client.tls_ctx._internal_tls_ctx + + _aws_crt_python.aws_py_mqtt_client_connection_connect( + self._internal_connection, + client_id, + host_name, + port, + tls_ctx_cap, + keep_alive, + will, + username, + password, + on_connect, + ) + + except Exception as e: + future.set_exception(e) + + return future + + def reconnect(self): + future = Future() + + def on_connect(error_code, return_code, session_present): + if error_code == 0 and return_code == 0: + future.set_result(dict(session_present=session_present)) + else: + future.set_exception(Exception("Error during reconnect")) + + try: + _aws_crt_python.aws_py_mqtt_client_connection_reconnect(self._internal_connection, on_connect) + except Exception as e: + future.set_exception(e) + + return future - assert isinstance(client, Client) - assert will is None or isinstance(will, Will) + def disconnect(self): - assert use_websocket == False + future = Future() - tls_ctx_cap = None - if client.tls_ctx: - tls_ctx_cap = client.tls_ctx._internal_tls_ctx + def on_disconnect(): + future.set_result(dict()) - self._internal_connection = _aws_crt_python.aws_py_mqtt_client_connection_new( - client._internal_client, - tls_ctx_cap, - host_name, - port, - client_id, - keep_alive, - on_connect, - on_disconnect, - will, - username, - password, - ) + try: + _aws_crt_python.aws_py_mqtt_client_connection_disconnect(self._internal_connection, on_disconnect) + except Exception as e: + future.set_exception(e) - def disconnect(self): - _aws_crt_python.aws_py_mqtt_client_connection_disconnect(self._internal_connection) + return future + + def subscribe(self, topic, qos, callback): + """ + callback: callback with signature (topic, message) + """ + future = Future() + packet_id = 0 + + def suback(packet_id, topic, qos): + future.set_result(dict( + packet_id=packet_id, + topic=topic, + qos=QoS(qos), + )) + + try: + packet_id = _aws_crt_python.aws_py_mqtt_client_connection_subscribe(self._internal_connection, topic, qos.value, callback, suback) + except Exception as e: + future.set_exception(e) + + return future, packet_id + + def unsubscribe(self, topic): + future = Future() + packet_id = 0 + + def unsuback(packet_id): + future.set_result(dict( + packet_id=packet_id + )) + + try: + packet_id = _aws_crt_python.aws_py_mqtt_client_connection_unsubscribe(self._internal_connection, topic, unsuback) + + except Exception as e: + future.set_exception(e) + + return future, packet_id + + def publish(self, topic, payload, qos, retain=False): + future = Future() + packet_id = 0 - def subscribe(self, topic, qos, callback, suback_callback=None): - return _aws_crt_python.aws_py_mqtt_client_connection_subscribe(self._internal_connection, topic, qos, callback, suback_callback) + def puback(packet_id): + future.set_result(dict( + packet_id=packet_id + )) - def unsubscribe(self, topic, unsuback_callback=None): - return _aws_crt_python.aws_py_mqtt_client_connection_unsubscribe(self._internal_connection, topic, unsuback_callback) + try: + packet_id = _aws_crt_python.aws_py_mqtt_client_connection_publish(self._internal_connection, topic, payload, qos.value, retain, puback) + except Exception as e: + future.set_exception(e) - def publish(self, topic, payload, qos, retain=False, puback_callback=None): - return _aws_crt_python.aws_py_mqtt_client_connection_publish(self._internal_connection, topic, payload, qos, retain, puback_callback) + return future, packet_id def ping(self): _aws_crt_python.aws_py_mqtt_client_connection_ping(self._internal_connection) \ No newline at end of file diff --git a/codebuild/common-macos.sh b/codebuild/common-macos.sh index 895c0087f..3a8d1241c 100755 --- a/codebuild/common-macos.sh +++ b/codebuild/common-macos.sh @@ -27,7 +27,6 @@ function install_from_brew { install_from_brew openssl install_from_brew gdbm -install_from_brew sqlite install_from_brew python git submodule update --init --recursive diff --git a/codebuild/linux-clang3-x64.yml b/codebuild/linux-clang3-x64.yml index 4b190a696..9f4802317 100644 --- a/codebuild/linux-clang3-x64.yml +++ b/codebuild/linux-clang3-x64.yml @@ -5,7 +5,8 @@ phases: commands: - sudo apt-get update -y - sudo apt-get update - - sudo apt-get install clang-3.9 cmake3 cppcheck clang-tidy-3.9 python3 python3-dev python3-setuptools ninja-build libssl-dev -y + - sudo apt-get install clang-3.9 cmake3 cppcheck clang-tidy-3.9 python3 python3-dev python3-pip ninja-build libssl-dev -y + - pip3 install --upgrade setuptools pre_build: commands: - export CC=clang-3.9 diff --git a/codebuild/linux-clang6-x64.yml b/codebuild/linux-clang6-x64.yml index 6c62ac917..dff85dfc1 100644 --- a/codebuild/linux-clang6-x64.yml +++ b/codebuild/linux-clang6-x64.yml @@ -7,8 +7,8 @@ phases: - sudo add-apt-repository ppa:ubuntu-toolchain-r/test - sudo apt-add-repository "deb http://apt.llvm.org/trusty/ llvm-toolchain-trusty-6.0 main" - sudo apt-get update -y - - sudo apt-get install clang-6.0 cmake3 cppcheck clang-tidy-6.0 clang-format-6.0 python3 python3-dev python3-setuptools ninja-build libssl-dev -y -f --force-yes - + - sudo apt-get install clang-6.0 cmake3 cppcheck clang-tidy-6.0 clang-format-6.0 python3 python3-dev python3-pip ninja-build libssl-dev -y -f --force-yes + - pip3 install --upgrade setuptools pre_build: commands: - export CC=clang-6.0 diff --git a/codebuild/linux-gcc-4x-x64.yml b/codebuild/linux-gcc-4x-x64.yml index 3eac1d009..d5ef3783a 100644 --- a/codebuild/linux-gcc-4x-x64.yml +++ b/codebuild/linux-gcc-4x-x64.yml @@ -4,7 +4,8 @@ phases: install: commands: - sudo apt-get update -y - - sudo apt-get install gcc cmake3 cppcheck python3 python3-dev python3-setuptools ninja-build libssl-dev -y + - sudo apt-get install gcc cmake3 cppcheck python3 python3-dev python3-pip ninja-build libssl-dev -y + - pip3 install --upgrade setuptools pre_build: commands: - export CC=gcc diff --git a/codebuild/linux-gcc-5x-x64.yml b/codebuild/linux-gcc-5x-x64.yml index 7b92c41ac..67df88e40 100644 --- a/codebuild/linux-gcc-5x-x64.yml +++ b/codebuild/linux-gcc-5x-x64.yml @@ -5,7 +5,8 @@ phases: commands: - sudo add-apt-repository ppa:ubuntu-toolchain-r/test - sudo apt-get update -y - - sudo apt-get install gcc-5 cmake3 cppcheck python3 python3-dev python3-setuptools ninja-build libssl-dev -y + - sudo apt-get install gcc-5 cmake3 cppcheck python3 python3-dev python3-pip ninja-build libssl-dev -y + - pip3 install --upgrade setuptools pre_build: commands: - export CC=gcc-5 diff --git a/codebuild/linux-gcc-6x-x64.yml b/codebuild/linux-gcc-6x-x64.yml index 73060d36f..1ec86524b 100644 --- a/codebuild/linux-gcc-6x-x64.yml +++ b/codebuild/linux-gcc-6x-x64.yml @@ -5,7 +5,8 @@ phases: commands: - sudo add-apt-repository ppa:ubuntu-toolchain-r/test - sudo apt-get update -y - - sudo apt-get install gcc-6 cmake3 cppcheck python3 python3-dev python3-setuptools ninja-build libssl-dev -y -f --force-yes + - sudo apt-get install gcc-6 cmake3 cppcheck python3 python3-dev python3-pip ninja-build libssl-dev -y -f --force-yes + - pip3 install --upgrade setuptools pre_build: commands: - export CC=gcc-6 diff --git a/codebuild/linux-gcc-7x-x64.yml b/codebuild/linux-gcc-7x-x64.yml index 6c303a7d8..307845276 100644 --- a/codebuild/linux-gcc-7x-x64.yml +++ b/codebuild/linux-gcc-7x-x64.yml @@ -5,7 +5,8 @@ phases: commands: - sudo add-apt-repository ppa:ubuntu-toolchain-r/test - sudo apt-get update -y - - sudo apt-get install gcc-7 cmake3 cppcheck python3 python3-dev python3-setuptools ninja-build libssl-dev -y + - sudo apt-get install gcc-7 cmake3 cppcheck python3 python3-dev python3-pip ninja-build libssl-dev -y + - pip3 install --upgrade setuptools pre_build: commands: - export CC=gcc-7 diff --git a/s2n b/s2n index 2b418ce17..383586162 160000 --- a/s2n +++ b/s2n @@ -1 +1 @@ -Subproject commit 2b418ce17c4a2eb9a4da8aaa272538d707907f44 +Subproject commit 383586162b3ee60bbd75105fcfe583b14bf60d46 diff --git a/setup.py b/setup.py index 705cfec0e..1a51b2ef9 100644 --- a/setup.py +++ b/setup.py @@ -109,7 +109,14 @@ def build_dependency(lib_name, pass_dversion_libs=True): os.mkdir(lib_build_dir) os.chdir(lib_build_dir) - cmake_args = ['cmake', generator_string, cross_compile_string, '-DCMAKE_INSTALL_PREFIX={}'.format(dep_install_path), '-DBUILD_SHARED_LIBS=OFF'] + cmake_args = [ + 'cmake', + generator_string, + cross_compile_string, + '-DCMAKE_PREFIX_PATH={}'.format(dep_install_path), + '-DCMAKE_INSTALL_PREFIX={}'.format(dep_install_path), + '-DBUILD_SHARED_LIBS=OFF', + ] if pass_dversion_libs: cmake_args.append('-DVERSION_LIBS=OFF') cmake_args.append(lib_source_dir) @@ -212,6 +219,10 @@ def build_dependency(lib_name, pass_dversion_libs=True): "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", ], + install_requires=[ + 'enum34 ; python_version<"3.4"', + 'futures ; python_version<"3.2"', + ], ext_modules = [_aws_crt_python], ) diff --git a/source/module.c b/source/module.c index 2acd3ae11..b0d7a58f1 100644 --- a/source/module.c +++ b/source/module.c @@ -43,6 +43,22 @@ struct aws_byte_cursor aws_byte_cursor_from_pystring(PyObject *str) { return aws_byte_cursor_from_array(NULL, 0); } +int PyIntEnum_Check(PyObject *int_enum_obj) { +#if PY_MAJOR_VERSION == 2 + return PyInt_Check(int_enum_obj); +#else + return PyLong_Check(int_enum_obj); +#endif +} + +long PyIntEnum_AsLong(PyObject *int_enum_obj) { +#if PY_MAJOR_VERSION == 2 + return PyInt_AsLong(int_enum_obj); +#else + return PyLong_AsLong(int_enum_obj); +#endif +} + void PyErr_SetAwsLastError(void) { PyErr_AwsLastError(); } @@ -77,6 +93,8 @@ static PyMethodDef s_module_methods[] = { /* MQTT Client Connection */ {"aws_py_mqtt_client_connection_new", aws_py_mqtt_client_connection_new, METH_VARARGS, NULL}, + {"aws_py_mqtt_client_connection_connect", aws_py_mqtt_client_connection_connect, METH_VARARGS, NULL}, + {"aws_py_mqtt_client_connection_reconnect", aws_py_mqtt_client_connection_reconnect, METH_VARARGS, NULL}, {"aws_py_mqtt_client_connection_publish", aws_py_mqtt_client_connection_publish, METH_VARARGS, NULL}, {"aws_py_mqtt_client_connection_subscribe", aws_py_mqtt_client_connection_subscribe, METH_VARARGS, NULL}, {"aws_py_mqtt_client_connection_unsubscribe", aws_py_mqtt_client_connection_unsubscribe, METH_VARARGS, NULL}, diff --git a/source/module.h b/source/module.h index 33c0305ef..b9ce3b51e 100644 --- a/source/module.h +++ b/source/module.h @@ -32,6 +32,9 @@ #define PyBool_FromAwsResult(result) PyBool_FromLong((result) == AWS_OP_SUCCESS) #define PyString_FromAwsByteCursor(cursor) PyString_FromStringAndSize((const char *)(cursor)->ptr, (cursor)->len) +int PyIntEnum_Check(PyObject *int_enum_obj); +long PyIntEnum_AsLong(PyObject *int_enum_obj); + struct aws_byte_cursor aws_byte_cursor_from_pystring(PyObject *str); /* Set current thread's error indicator based on aws_last_error() */ diff --git a/source/mqtt_client_connection.c b/source/mqtt_client_connection.c index 090214627..66d31ee16 100644 --- a/source/mqtt_client_connection.c +++ b/source/mqtt_client_connection.c @@ -47,169 +47,261 @@ struct mqtt_python_connection { struct aws_mqtt_client_connection *connection; PyObject *on_connect; - PyObject *on_disconnect; + PyObject *on_connection_interrupted; + PyObject *on_connection_resumed; }; -static void s_mqtt_python_connection_destructor(PyObject *connection_capsule) { +static void s_mqtt_python_connection_destructor_on_disconnect( + struct aws_mqtt_client_connection *connection, + void *userdata) { struct aws_allocator *allocator = aws_crt_python_get_allocator(); + struct mqtt_python_connection *py_connection = userdata; + + Py_CLEAR(py_connection->on_connection_interrupted); + Py_CLEAR(py_connection->on_connection_resumed); + + aws_mqtt_client_connection_destroy(connection); + aws_mem_release(allocator, py_connection); +} - assert(PyCapsule_CheckExact(connection_capsule)); +static void s_mqtt_python_connection_destructor(PyObject *connection_capsule) { struct mqtt_python_connection *py_connection = PyCapsule_GetPointer(connection_capsule, s_capsule_name_mqtt_client_connection); assert(py_connection); - Py_XDECREF(py_connection->on_connect); - Py_XDECREF(py_connection->on_disconnect); + if (aws_mqtt_client_connection_disconnect( + py_connection->connection, s_mqtt_python_connection_destructor_on_disconnect, py_connection)) { - aws_mqtt_client_connection_disconnect(py_connection->connection); - - aws_mem_release(allocator, py_connection); + /* If this returns an error, we should immediately destroy the connection */ + s_mqtt_python_connection_destructor_on_disconnect(py_connection->connection, py_connection); + } } -static void s_on_connect_failed(struct aws_mqtt_client_connection *connection, int error_code, void *user_data) { +static void s_on_connection_interrupted(struct aws_mqtt_client_connection *connection, int error_code, void *userdata) { (void)connection; - struct mqtt_python_connection *py_connection = user_data; + struct mqtt_python_connection *py_connection = userdata; - if (py_connection->on_disconnect) { - PyGILState_STATE state = PyGILState_Ensure(); - - PyObject *result = PyObject_CallFunction(py_connection->on_disconnect, "(I)", error_code); - Py_XDECREF(result); + PyGILState_STATE state = PyGILState_Ensure(); - PyGILState_Release(state); + PyObject *result = PyObject_CallFunction(py_connection->on_connection_interrupted, "(I)", error_code); + if (result) { + Py_DECREF(result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); } + + PyGILState_Release(state); } -static void s_on_connect( +static void s_on_connection_resumed( struct aws_mqtt_client_connection *connection, enum aws_mqtt_connect_return_code return_code, bool session_present, - void *user_data) { + void *userdata) { (void)connection; - struct mqtt_python_connection *py_connection = user_data; + struct mqtt_python_connection *py_connection = userdata; - if (py_connection->on_connect) { - PyGILState_STATE state = PyGILState_Ensure(); + PyGILState_STATE state = PyGILState_Ensure(); - PyObject *result = - PyObject_CallFunction(py_connection->on_connect, "(IN)", return_code, PyBool_FromLong(session_present)); - Py_XDECREF(result); + PyObject *result = PyObject_CallFunction( + py_connection->on_connection_resumed, "(IN)", return_code, PyBool_FromLong(session_present)); + if (result) { + Py_DECREF(result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } - PyGILState_Release(state); + PyGILState_Release(state); +} + +PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) { + (void)self; + + struct aws_allocator *allocator = aws_crt_python_get_allocator(); + + /* If anything goes wrong in this function: goto error */ + struct mqtt_python_connection *py_connection = NULL; + + PyObject *client_capsule = NULL; + PyObject *on_connection_interrupted = NULL; + PyObject *on_connection_resumed = NULL; + + if (!PyArg_ParseTuple(args, "OOO", &client_capsule, &on_connection_interrupted, &on_connection_resumed)) { + goto error; } + + py_connection = aws_mem_acquire(allocator, sizeof(struct mqtt_python_connection)); + if (!py_connection) { + PyErr_SetAwsLastError(); + goto error; + } + AWS_ZERO_STRUCT(*py_connection); + + py_connection->py_client = PyCapsule_GetPointer(client_capsule, s_capsule_name_mqtt_client); + if (!py_connection->py_client) { + goto error; + } + + if (on_connection_interrupted != Py_None) { + if (!PyCallable_Check(on_connection_interrupted)) { + PyErr_SetString(PyExc_TypeError, "on_connection_interrupted is invalid"); + goto error; + } + + Py_INCREF(on_connection_interrupted); + py_connection->on_connection_interrupted = on_connection_interrupted; + } + + if (on_connection_resumed != Py_None) { + if (!PyCallable_Check(on_connection_resumed)) { + PyErr_SetString(PyExc_TypeError, "on_connection_resumed is invalid"); + goto error; + } + + Py_INCREF(on_connection_resumed); + py_connection->on_connection_resumed = on_connection_resumed; + } + + py_connection->connection = aws_mqtt_client_connection_new(&py_connection->py_client->native_client); + if (!py_connection->connection) { + PyErr_SetAwsLastError(); + goto error; + } + + if (py_connection->on_connection_interrupted || py_connection->on_connection_resumed) { + if (aws_mqtt_client_connection_set_connection_interruption_handlers( + py_connection->connection, + py_connection->on_connection_interrupted ? s_on_connection_interrupted : NULL, + py_connection, + py_connection->on_connection_resumed ? s_on_connection_resumed : NULL, + py_connection)) { + + PyErr_SetAwsLastError(); + goto error; + } + } + + PyObject *impl_capsule = + PyCapsule_New(py_connection, s_capsule_name_mqtt_client_connection, s_mqtt_python_connection_destructor); + if (!impl_capsule) { + goto error; + } + + return impl_capsule; + +error: + if (py_connection) { + if (py_connection->connection) { + aws_mqtt_client_connection_destroy(py_connection->connection); + } + + Py_CLEAR(py_connection->on_connection_interrupted); + Py_CLEAR(py_connection->on_connection_resumed); + aws_mem_release(allocator, py_connection); + } + + return NULL; } -static bool s_on_disconnect(struct aws_mqtt_client_connection *connection, int error_code, void *user_data) { +/******************************************************************************* + * Connect + ******************************************************************************/ + +static void s_on_connect( + struct aws_mqtt_client_connection *connection, + int error_code, + enum aws_mqtt_connect_return_code return_code, + bool session_present, + void *user_data) { (void)connection; struct mqtt_python_connection *py_connection = user_data; - bool should_reconnect = true; - - if (py_connection->on_disconnect) { + if (py_connection->on_connect) { PyGILState_STATE state = PyGILState_Ensure(); - PyObject *result = PyObject_CallFunction(py_connection->on_disconnect, "(I)", error_code); + PyObject *callback = py_connection->on_connect; + py_connection->on_connect = NULL; + + PyObject *result = + PyObject_CallFunction(callback, "(IIN)", error_code, return_code, PyBool_FromLong(session_present)); if (result) { - if (result != Py_None) { - should_reconnect = PyObject_IsTrue(result); - } Py_DECREF(result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); } + Py_XDECREF(callback); + PyGILState_Release(state); } - - return should_reconnect; } -PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) { - (void)self; +PyObject *aws_py_mqtt_client_connection_connect(PyObject *self, PyObject *args) { - struct aws_allocator *allocator = aws_crt_python_get_allocator(); + (void)self; - /* If anything goes wrong in this function: goto error */ - struct mqtt_python_connection *py_connection = NULL; struct aws_tls_ctx *tls_ctx = NULL; - PyObject *client_capsule = NULL; - PyObject *tls_ctx_capsule = NULL; + PyObject *impl_capsule = NULL; + const char *client_id = NULL; + Py_ssize_t client_id_len = 0; const char *server_name = NULL; Py_ssize_t server_name_len = 0; uint16_t port_number = 0; - const char *client_id = NULL; - Py_ssize_t client_id_len = 0; + PyObject *tls_ctx_capsule = NULL; uint16_t keep_alive_time = 0; - PyObject *on_connect = NULL; - PyObject *on_disconnect = NULL; PyObject *will = NULL; const char *username = NULL; Py_ssize_t username_len = 0; const char *password = NULL; Py_ssize_t password_len = 0; + PyObject *on_connect = NULL; if (!PyArg_ParseTuple( args, - "OOs#Hs#HOOOz#z#", - &client_capsule, - &tls_ctx_capsule, + "Os#s#HOHOz#z#O", + &impl_capsule, + &client_id, + &client_id_len, &server_name, &server_name_len, &port_number, - &client_id, - &client_id_len, + &tls_ctx_capsule, &keep_alive_time, - &on_connect, - &on_disconnect, &will, &username, &username_len, &password, - &password_len)) { - goto error; + &password_len, + &on_connect)) { + return NULL; } - py_connection = aws_mem_acquire(allocator, sizeof(struct mqtt_python_connection)); + struct mqtt_python_connection *py_connection = + PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection); if (!py_connection) { - PyErr_SetAwsLastError(); - goto error; + return NULL; } - AWS_ZERO_STRUCT(*py_connection); - if (!client_capsule || !PyCapsule_CheckExact(client_capsule)) { - PyErr_SetNone(PyExc_ValueError); - goto error; - } - py_connection->py_client = PyCapsule_GetPointer(client_capsule, s_capsule_name_mqtt_client); - if (!py_connection->py_client) { - goto error; + if (py_connection->on_connect) { + PyErr_SetString(PyExc_RuntimeError, "Connection already in progress"); + return NULL; } - if (tls_ctx_capsule != Py_None && PyCapsule_CheckExact(tls_ctx_capsule)) { + if (tls_ctx_capsule != Py_None) { tls_ctx = PyCapsule_GetPointer(tls_ctx_capsule, s_capsule_name_tls_ctx); if (!tls_ctx) { - goto error; + return NULL; } - } - - if (on_connect && PyCallable_Check(on_connect)) { - Py_INCREF(on_connect); - py_connection->on_connect = on_connect; - } - if (on_disconnect && PyCallable_Check(on_disconnect)) { - Py_INCREF(on_disconnect); - py_connection->on_disconnect = on_disconnect; - } - - if (tls_ctx) { aws_tls_connection_options_init_from_ctx(&py_connection->tls_options, tls_ctx); aws_tls_connection_options_set_server_name(&py_connection->tls_options, server_name); } @@ -219,46 +311,48 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) { py_connection->socket_options.type = AWS_SOCKET_STREAM; py_connection->socket_options.domain = AWS_SOCKET_IPV4; - struct aws_mqtt_client_connection_callbacks callbacks; - AWS_ZERO_STRUCT(callbacks); - callbacks.on_connection_failed = s_on_connect_failed; - callbacks.on_connack = s_on_connect; - callbacks.on_disconnect = s_on_disconnect; - callbacks.user_data = py_connection; - struct aws_byte_cursor server_name_cur = aws_byte_cursor_from_array(server_name, server_name_len); - py_connection->connection = aws_mqtt_client_connection_new( - &py_connection->py_client->native_client, - callbacks, - &server_name_cur, - port_number, - &py_connection->socket_options, - tls_ctx ? &py_connection->tls_options : NULL); - - if (!py_connection->connection) { - PyErr_SetAwsLastError(); - goto error; - } - - if (will && will != Py_None) { + if (will != Py_None) { + struct aws_byte_cursor topic; + AWS_ZERO_STRUCT(topic); PyObject *py_topic = PyObject_GetAttrString(will, "topic"); - assert(py_topic); - struct aws_byte_cursor topic = aws_byte_cursor_from_pystring(py_topic); + if (py_topic) { + topic = aws_byte_cursor_from_pystring(py_topic); + } + if (!topic.ptr) { + PyErr_SetString(PyExc_TypeError, "Will.topic is invalid"); + return NULL; + } PyObject *py_qos = PyObject_GetAttrString(will, "qos"); - assert(py_qos && PyLong_Check(py_qos)); - enum aws_mqtt_qos qos = (enum aws_mqtt_qos)PyLong_AsUnsignedLong(py_qos); + if (!py_qos || !PyIntEnum_Check(py_qos)) { + PyErr_SetString(PyExc_TypeError, "Will.qos is invalid"); + return NULL; + } + enum aws_mqtt_qos qos = (enum aws_mqtt_qos)PyIntEnum_AsLong(py_qos); + struct aws_byte_cursor payload; + AWS_ZERO_STRUCT(payload); PyObject *py_payload = PyObject_GetAttrString(will, "payload"); - assert(py_payload); - struct aws_byte_cursor payload = aws_byte_cursor_from_pystring(py_payload); + if (py_payload) { + payload = aws_byte_cursor_from_pystring(py_payload); + } + if (!payload.ptr) { + PyErr_SetString(PyExc_TypeError, "Will.payload is invalid"); + return NULL; + } PyObject *py_retain = PyObject_GetAttrString(will, "retain"); - assert(py_retain && PyBool_Check(py_retain)); + if (!PyBool_Check(py_retain)) { + PyErr_SetString(PyExc_TypeError, "Will.retain is invalid"); + return NULL; + } bool retain = py_retain == Py_True; - aws_mqtt_client_connection_set_will(py_connection->connection, &topic, qos, retain, &payload); + if (aws_mqtt_client_connection_set_will(py_connection->connection, &topic, qos, retain, &payload)) { + return PyErr_AwsLastError(); + } } if (username) { @@ -272,27 +366,82 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) { password_cur_ptr = &password_cur; } - aws_mqtt_client_connection_set_login(py_connection->connection, &username_cur, password_cur_ptr); + if (aws_mqtt_client_connection_set_login(py_connection->connection, &username_cur, password_cur_ptr)) { + return PyErr_AwsLastError(); + } + } + + if (on_connect != Py_None) { + if (!PyCallable_Check(on_connect)) { + PyErr_SetString(PyExc_TypeError, "on_connect is invalid"); + return NULL; + } + Py_INCREF(on_connect); + py_connection->on_connect = on_connect; } struct aws_byte_cursor client_id_cur = aws_byte_cursor_from_array(client_id, client_id_len); - if (aws_mqtt_client_connection_connect(py_connection->connection, &client_id_cur, true, keep_alive_time)) { - PyErr_SetAwsLastError(); - goto error; + if (aws_mqtt_client_connection_connect( + py_connection->connection, + &server_name_cur, + port_number, + &py_connection->socket_options, + tls_ctx ? &py_connection->tls_options : NULL, + &client_id_cur, + true, + keep_alive_time, + s_on_connect, + py_connection)) { + Py_CLEAR(py_connection->on_connect); + return PyErr_AwsLastError(); } - return PyCapsule_New(py_connection, s_capsule_name_mqtt_client_connection, s_mqtt_python_connection_destructor); + Py_RETURN_NONE; +} -error: - if (py_connection) { - if (py_connection->connection) { - aws_mem_release(allocator, py_connection->connection); /* TODO: need aws_mqtt_client_connection_destroy() */ +/******************************************************************************* + * Reconnect + ******************************************************************************/ + +PyObject *aws_py_mqtt_client_connection_reconnect(PyObject *self, PyObject *args) { + + (void)self; + + PyObject *impl_capsule = NULL; + PyObject *on_connect = NULL; + + if (!PyArg_ParseTuple(args, "OO", &impl_capsule, &on_connect)) { + return NULL; + } + + struct mqtt_python_connection *py_connection = + PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection); + if (!py_connection) { + return NULL; + } + + if (py_connection->on_connect) { + PyErr_SetString(PyExc_RuntimeError, "Connection already in progress"); + return NULL; + } + + if (on_connect != Py_None) { + if (!PyCallable_Check(on_connect)) { + PyErr_SetString(PyExc_TypeError, "on_connect is invalid"); + return NULL; } - aws_mem_release(allocator, py_connection); + Py_INCREF(on_connect); + py_connection->on_connect = on_connect; } - return NULL; + if (aws_mqtt_client_connection_reconnect(py_connection->connection, s_on_connect, py_connection)) { + Py_CLEAR(py_connection->on_connect); + PyErr_SetAwsLastError(); + return NULL; + } + + Py_RETURN_NONE; } /******************************************************************************* @@ -319,7 +468,13 @@ static void s_publish_complete( PyGILState_STATE state = PyGILState_Ensure(); if (metadata->callback) { - PyObject_CallFunction(metadata->callback, "(H)", packet_id); + PyObject *result = PyObject_CallFunction(metadata->callback, "(H)", packet_id); + if (result) { + Py_DECREF(result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + Py_DECREF(metadata->callback); } @@ -348,20 +503,22 @@ PyObject *aws_py_mqtt_client_connection_publish(PyObject *self, PyObject *args) return NULL; } - if (!impl_capsule || !PyCapsule_CheckExact(impl_capsule)) { - PyErr_SetNone(PyExc_TypeError); - return NULL; - } - struct mqtt_python_connection *connection = PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection); + if (!connection) { + return NULL; + } - if (qos_val > 3) { + if (qos_val >= AWS_MQTT_QOS_EXACTLY_ONCE) { PyErr_SetNone(PyExc_ValueError); return NULL; } - if (puback_callback && PyCallable_Check(puback_callback)) { - Py_INCREF(puback_callback); + + if (puback_callback != Py_None) { + if (!PyCallable_Check(puback_callback)) { + PyErr_SetString(PyExc_TypeError, "puback callback is invalid"); + return NULL; + } } else { puback_callback = NULL; } @@ -386,10 +543,14 @@ PyObject *aws_py_mqtt_client_connection_publish(PyObject *self, PyObject *args) enum aws_mqtt_qos qos = (enum aws_mqtt_qos)qos_val; + Py_XINCREF(metadata->callback); + uint16_t msg_id = aws_mqtt_client_connection_publish( connection->connection, &topic_cursor, qos, retain == Py_True, &payload_cursor, s_publish_complete, metadata); if (msg_id == 0) { + Py_CLEAR(metadata->callback); + aws_mem_release(aws_crt_python_get_allocator(), metadata); return PyErr_AwsLastError(); } @@ -459,9 +620,11 @@ static void s_suback_callback( PyErr_WriteUnraisable(PyErr_Occurred()); abort(); } else { - Py_DECREF(callback); + Py_DECREF(result); } + Py_DECREF(callback); + PyGILState_Release(state); } } @@ -480,30 +643,33 @@ PyObject *aws_py_mqtt_client_connection_subscribe(PyObject *self, PyObject *args return NULL; } - if (!impl_capsule || !PyCapsule_CheckExact(impl_capsule)) { - PyErr_SetNone(PyExc_TypeError); + struct mqtt_python_connection *py_connection = + PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection); + if (!py_connection) { return NULL; } - if (!callback || !PyCallable_Check(callback)) { - PyErr_SetNone(PyExc_ValueError); + if (qos_val >= AWS_MQTT_QOS_EXACTLY_ONCE) { + PyErr_SetString(PyExc_ValueError, "qos is invalid"); return NULL; } - Py_INCREF(callback); - if (suback_callback && PyCallable_Check(suback_callback)) { - Py_INCREF(suback_callback); - } else { - suback_callback = NULL; + if (!PyCallable_Check(callback)) { + PyErr_SetString(PyExc_TypeError, "callback is invalid"); + return NULL; } - struct mqtt_python_connection *py_connection = - PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection); - - if (qos_val > 3) { - PyErr_SetNone(PyExc_ValueError); + if (suback_callback != Py_None) { + if (!PyCallable_Check(suback_callback)) { + PyErr_SetString(PyExc_TypeError, "suback_callback is invalid"); + return NULL; + } + } else { + suback_callback = NULL; } + Py_INCREF(callback); + Py_XINCREF(suback_callback); struct aws_byte_cursor topic_filter = aws_byte_cursor_from_array(topic, topic_len); uint16_t msg_id = aws_mqtt_client_connection_subscribe( py_connection->connection, @@ -516,6 +682,8 @@ PyObject *aws_py_mqtt_client_connection_subscribe(PyObject *self, PyObject *args suback_callback); if (msg_id == 0) { + Py_CLEAR(callback); + Py_CLEAR(suback_callback); return PyErr_AwsLastError(); } @@ -540,7 +708,13 @@ static void s_unsuback_callback( PyGILState_STATE state = PyGILState_Ensure(); - PyObject_CallFunction(callback, "(H)", packet_id); + PyObject *result = PyObject_CallFunction(callback, "(H)", packet_id); + if (result) { + Py_DECREF(result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + Py_DECREF(callback); PyGILState_Release(state); @@ -559,25 +733,28 @@ PyObject *aws_py_mqtt_client_connection_unsubscribe(PyObject *self, PyObject *ar return NULL; } - if (!impl_capsule || !PyCapsule_CheckExact(impl_capsule)) { - PyErr_SetNone(PyExc_TypeError); + struct mqtt_python_connection *connection = + PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection); + if (!connection) { return NULL; } - if (unsuback_callback && PyCallable_Check(unsuback_callback)) { - Py_INCREF(unsuback_callback); + if (unsuback_callback != Py_None) { + if (!PyCallable_Check(unsuback_callback)) { + PyErr_SetString(PyExc_TypeError, "unsuback callback is invalid"); + return NULL; + } } else { unsuback_callback = NULL; } - struct mqtt_python_connection *connection = - PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection); - struct aws_byte_cursor filter = aws_byte_cursor_from_array(topic, topic_len); + Py_XINCREF(unsuback_callback); uint16_t msg_id = aws_mqtt_client_connection_unsubscribe(connection->connection, &filter, s_unsuback_callback, unsuback_callback); if (msg_id == 0) { + Py_CLEAR(unsuback_callback); return PyErr_AwsLastError(); } @@ -597,13 +774,11 @@ PyObject *aws_py_mqtt_client_connection_ping(PyObject *self, PyObject *args) { return NULL; } - if (!impl_capsule || !PyCapsule_CheckExact(impl_capsule)) { - PyErr_SetNone(PyExc_TypeError); - return NULL; - } - struct mqtt_python_connection *connection = PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection); + if (!connection) { + return NULL; + } int err = aws_mqtt_client_connection_ping(connection->connection); if (err) { @@ -617,25 +792,57 @@ PyObject *aws_py_mqtt_client_connection_ping(PyObject *self, PyObject *args) { * Disconnect ******************************************************************************/ +static void s_on_disconnect(struct aws_mqtt_client_connection *connection, void *user_data) { + + (void)connection; + + PyObject *on_disconnect = user_data; + + if (on_disconnect) { + PyGILState_STATE state = PyGILState_Ensure(); + + PyObject *result = PyObject_CallFunction(on_disconnect, "()"); + if (result) { + Py_DECREF(result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + + Py_DECREF(on_disconnect); + + PyGILState_Release(state); + } +} + PyObject *aws_py_mqtt_client_connection_disconnect(PyObject *self, PyObject *args) { (void)self; PyObject *impl_capsule = NULL; + PyObject *on_disconnect = NULL; - if (!PyArg_ParseTuple(args, "O", &impl_capsule)) { + if (!PyArg_ParseTuple(args, "OO", &impl_capsule, &on_disconnect)) { return NULL; } - if (!impl_capsule || !PyCapsule_CheckExact(impl_capsule)) { - PyErr_SetNone(PyExc_TypeError); + struct mqtt_python_connection *connection = + PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection); + if (!connection) { return NULL; } - struct mqtt_python_connection *connection = - PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection); + if (on_disconnect != Py_None) { + if (!PyCallable_Check(on_disconnect)) { + PyErr_SetString(PyExc_TypeError, "on_disconnect is invalid"); + return NULL; + } + Py_INCREF(on_disconnect); + } else { + on_disconnect = NULL; + } - int err = aws_mqtt_client_connection_disconnect(connection->connection); + int err = aws_mqtt_client_connection_disconnect(connection->connection, s_on_disconnect, on_disconnect); if (err) { + Py_CLEAR(on_disconnect); return PyErr_AwsLastError(); } diff --git a/source/mqtt_client_connection.h b/source/mqtt_client_connection.h index 1a0295fec..806d1d4a6 100644 --- a/source/mqtt_client_connection.h +++ b/source/mqtt_client_connection.h @@ -25,6 +25,8 @@ extern const char *s_capsule_name_mqtt_client_connection; PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args); +PyObject *aws_py_mqtt_client_connection_connect(PyObject *self, PyObject *args); +PyObject *aws_py_mqtt_client_connection_reconnect(PyObject *self, PyObject *args); PyObject *aws_py_mqtt_client_connection_publish(PyObject *self, PyObject *args); PyObject *aws_py_mqtt_client_connection_subscribe(PyObject *self, PyObject *args); PyObject *aws_py_mqtt_client_connection_unsubscribe(PyObject *self, PyObject *args); diff --git a/test.py b/test.py index 46dc808b3..07e76150e 100644 --- a/test.py +++ b/test.py @@ -11,6 +11,8 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. +from __future__ import print_function + import argparse from aws_crt import io, mqtt import threading @@ -24,23 +26,16 @@ parser = argparse.ArgumentParser() parser.add_argument('--endpoint', required=True, help="Connect to this endpoint (aka host-name)") -parser.add_argument('--port', help="Override default connection port") +parser.add_argument('--port', type=int, help="Override default connection port") parser.add_argument('--cert', help="File path to your client certificate, in PEM format") parser.add_argument('--key', help="File path to your private key, in PEM format") parser.add_argument('--root-ca', help="File path to root certificate authority, in PEM format") -connect_results = {} -connect_event = threading.Event() -def on_connect(return_code, session_present): - connect_results.update(locals()) - connect_event.set() +def on_connection_interrupted(error_code): + print("Connection has been interrupted with error code", error_code) -disconnect_results = {} -disconnect_event = threading.Event() -def on_disconnect(return_code): - disconnect_results.update(locals()) - disconnect_event.set() - return False +def on_connection_resumed(return_code, session_present): + print("Connection has been resumed with return code", return_code, "and session present:", session_present) receive_results = {} receive_event = threading.Event() @@ -48,29 +43,12 @@ def on_receive_message(topic, message): receive_results.update(locals()) receive_event.set() -subscribe_results = {} -subscribe_event = threading.Event() -def on_subscribe(packet_id, topic, qos): - subscribe_results.update(locals()) - subscribe_event.set() - -unsubscribe_results = {} -unsubscribe_event = threading.Event() -def on_unsubscribe(packet_id): - unsubscribe_results.update(locals()) - unsubscribe_event.set() - -publish_results = {} -publish_event = threading.Event() -def on_publish(packet_id): - publish_results.update(locals()) - publish_event.set() - # Run args = parser.parse_args() event_loop_group = io.EventLoopGroup(1) client_bootstrap = io.ClientBootstrap(event_loop_group) +tls_options = None if args.cert or args.key or args.root_ca: if args.cert: assert(args.key) @@ -97,36 +75,34 @@ def on_publish(packet_id): print("Connecting to {}:{} with client-id:{}".format(args.endpoint, port, CLIENT_ID)) mqtt_connection = mqtt.Connection( client=mqtt_client, + on_connection_interrupted=on_connection_interrupted, + on_connection_resumed=on_connection_resumed) + +connect_results = mqtt_connection.connect( client_id=CLIENT_ID, host_name=args.endpoint, - port=port, - on_connect=on_connect, - on_disconnect=on_disconnect) -assert(connect_event.wait(TIMEOUT)) -assert(connect_results['return_code'] == 0) + port=port).result(TIMEOUT) assert(connect_results['session_present'] == False) # Subscribe print("Subscribing to:", TOPIC) -qos = mqtt.QoS.AtLeastOnce -subscribe_packet_id = mqtt_connection.subscribe( +qos = mqtt.QoS.AT_LEAST_ONCE +subscribe_future, subscribe_packet_id = mqtt_connection.subscribe( topic=TOPIC, qos=qos, - callback=on_receive_message, - suback_callback=on_subscribe) -assert(subscribe_event.wait(TIMEOUT)) + callback=on_receive_message) +subscribe_results = subscribe_future.result(TIMEOUT) assert(subscribe_results['packet_id'] == subscribe_packet_id) assert(subscribe_results['topic'] == TOPIC) assert(subscribe_results['qos'] == qos) # Publish print("Publishing to '{}': {}".format(TOPIC, MESSAGE)) -publish_packet_id = mqtt_connection.publish( +publish_future, publish_packet_id = mqtt_connection.publish( topic=TOPIC, payload=MESSAGE, - qos=mqtt.QoS.AtLeastOnce, - puback_callback=on_publish) -assert(publish_event.wait(TIMEOUT)) + qos=mqtt.QoS.AT_LEAST_ONCE) +publish_results = publish_future.result(TIMEOUT) assert(publish_results['packet_id'] == publish_packet_id) # Receive Message @@ -137,15 +113,13 @@ def on_publish(packet_id): # Unsubscribe print("Unsubscribing from topic") -unsubscribe_packet_id = mqtt_connection.unsubscribe(TOPIC, on_unsubscribe) -assert(unsubscribe_event.wait(TIMEOUT)) +unsubscribe_future, unsubscribe_packet_id = mqtt_connection.unsubscribe(TOPIC) +unsubscribe_results = unsubscribe_future.result(TIMEOUT) assert(unsubscribe_results['packet_id'] == unsubscribe_packet_id) # Disconnect print("Disconnecting") -mqtt_connection.disconnect() -assert(disconnect_event.wait(TIMEOUT)) -assert(disconnect_results['return_code'] == 0) +mqtt_connection.disconnect().result(TIMEOUT) # Done print("Test Success")