Skip to content

Commit

Permalink
Fixes for logger getting cleaned up too early, also handles null term… (
Browse files Browse the repository at this point in the history
#51)

* Fixes for logger getting cleaned up too early, also handles null terminated buffers properly.
Updated submodules
  • Loading branch information
JonathanHenson authored May 10, 2019
1 parent b839e81 commit 0f1725d
Show file tree
Hide file tree
Showing 12 changed files with 80 additions and 100 deletions.
2 changes: 1 addition & 1 deletion aws-c-common
Submodule aws-c-common updated 217 files
2 changes: 1 addition & 1 deletion aws-c-mqtt
43 changes: 12 additions & 31 deletions awscrt/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,15 @@ class LogLevel(IntEnum):
Debug = 5
Trace = 6


class Logger(object):
def init_logging(log_level, file_name):
"""
initialize a logger. log_level is type LogLevel, and file_name is of type str.
To write to stdout, or stderr, simply pass 'stdout' or 'stderr' as strings. Otherwise, a file path is assumed.
"""
__slots__ = ('_internal_logger')

def __init__(self, log_level, file_name):
assert log_level is not None
assert file_name is not None
assert log_level is not None
assert file_name is not None

self._internal_logger = _aws_crt_python.aws_py_io_init_logging(log_level, file_name)
_aws_crt_python.aws_py_io_init_logging(log_level, file_name)


def is_alpn_available():
Expand Down Expand Up @@ -80,21 +76,6 @@ def __init__(self, elg, host_resolver=None):
self._internal_bootstrap = _aws_crt_python.aws_py_io_client_bootstrap_new(self.elg._internal_elg, host_resolver._internal_host_resolver)


#
def byte_buf_null_terminate(buf):
"""
force null termination at the end of buffer
:param buf: buffer to null terminate
:return: null terminated buffer
"""
if not buf.endswith(bytes([0])):
# I know this looks hacky. please don't change it
# because appending bytes([0]) does not work in python 2.7
# this works in both.
buf = buf + b'\0'
return buf


def byte_buf_from_file(filepath):
with open(filepath, mode='rb') as fh:
contents = fh.read()
Expand Down Expand Up @@ -169,7 +150,7 @@ def override_default_trust_store_from_path(self, ca_path, ca_file):
def override_default_trust_store(self, rootca_buffer):
assert isinstance(rootca_buffer, bytes)

self.ca_buffer = byte_buf_null_terminate(rootca_buffer)
self.ca_buffer = rootca_buffer

@staticmethod
def create_client_with_mtls_from_path(cert_path, pk_path):
Expand All @@ -188,8 +169,8 @@ def create_client_with_mtls(cert_buffer, key_buffer):
assert isinstance(key_buffer, bytes)

opt = TlsContextOptions()
opt.certificate_buffer = byte_buf_null_terminate(cert_buffer)
opt.private_key_buffer = byte_buf_null_terminate(key_buffer)
opt.certificate_buffer = cert_buffer
opt.private_key_buffer = key_buffer

opt.verify_peer = True
return opt
Expand All @@ -207,7 +188,7 @@ def create_client_with_mtls_pkcs12(pkcs12_path, pkcs12_password):
return opt

@staticmethod
def create_server_with_mtls_from_path(cert_path, pk_path):
def create_server_from_path(cert_path, pk_path):

assert isinstance(cert_path, str)
assert isinstance(pk_path, str)
Expand All @@ -218,18 +199,18 @@ def create_server_with_mtls_from_path(cert_path, pk_path):
return TlsContextOptions.create_server_with_mtls(cert_buffer, key_buffer)

@staticmethod
def create_server_with_mtls(cert_buffer, key_buffer):
def create_server(cert_buffer, key_buffer):
assert isinstance(cert_buffer, bytes)
assert isinstance(key_buffer, bytes)

opt = TlsContextOptions()
opt.certificate_buffer = byte_buf_null_terminate(cert_buffer)
opt.private_key_buffer = byte_buf_null_terminate(key_buffer)
opt.certificate_buffer = cert_buffer
opt.private_key_buffer = key_buffer
opt.verify_peer = False
return opt

@staticmethod
def create_server_with_mtls_pkcs12(pkcs12_path, pkcs12_password):
def create_server_pkcs12(pkcs12_path, pkcs12_password):

assert isinstance(pkcs12_path, str)
assert isinstance(pkcs12_password, str)
Expand Down
2 changes: 1 addition & 1 deletion awscrt/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,4 +203,4 @@ def puback(packet_id):
return future, packet_id

def ping(self):
_aws_crt_python.aws_py_mqtt_client_connection_ping(self._internal_connection)
_aws_crt_python.aws_py_mqtt_client_connection_ping(self._internal_connection)
3 changes: 1 addition & 2 deletions elasticurl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def print_header_list(headers):
output = open(args.output, mode='wb')

# setup the logger if user request logging
logger = None

if args.verbose:
log_level = io.LogLevel.NoLogs
Expand All @@ -76,7 +75,7 @@ def print_header_list(headers):
if args.trace:
log_output = args.trace

logger = io.Logger(log_level, log_output)
io.init_logging(log_level, log_output)

# an event loop group is needed for IO operations. Unless you're a server or a client doing hundreds of connections
# you only want one of these.
Expand Down
4 changes: 4 additions & 0 deletions mqtt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import argparse
from awscrt import io, mqtt
from awscrt.io import LogLevel
import threading
import uuid

Expand All @@ -31,6 +32,8 @@
parser.add_argument('--key', help="File path to your private key, in PEM format")
parser.add_argument('--root-ca', help="File path to root certificate authority, in PEM format")

io.init_logging(LogLevel.Trace, 'stderr')

def on_connection_interrupted(error_code):
print("Connection has been interrupted with error code", error_code)

Expand Down Expand Up @@ -96,6 +99,7 @@ def on_receive_message(topic, message):
subscribe_results = subscribe_future.result(TIMEOUT)
assert(subscribe_results['packet_id'] == subscribe_packet_id)
assert(subscribe_results['topic'] == TOPIC)
print(subscribe_results)
assert(subscribe_results['qos'] == qos)

# Publish
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def get_from_env(key):

setuptools.setup(
name="awscrt",
version="v0.2.15",
version="v0.2.16",
author="Amazon Web Services, Inc",
author_email="[email protected]",
description="A common runtime for AWS Python projects",
Expand Down
58 changes: 1 addition & 57 deletions source/io.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include <aws/io/channel_bootstrap.h>
#include <aws/io/event_loop.h>
#include <aws/io/logging.h>
#include <aws/io/tls_channel_handler.h>

#include <stdio.h>
Expand All @@ -27,7 +26,6 @@ static const char *s_capsule_name_elg = "aws_event_loop_group";
const char *s_capsule_name_host_resolver = "aws_host_resolver";
const char *s_capsule_name_tls_ctx = "aws_client_tls_ctx";
const char *s_capsule_name_tls_conn_options = "aws_tls_connection_options";
const char *s_capsule_name_logger = "aws_logger";

PyObject *aws_py_is_alpn_available(PyObject *self, PyObject *args) {

Expand All @@ -37,64 +35,9 @@ PyObject *aws_py_is_alpn_available(PyObject *self, PyObject *args) {
return PyBool_FromLong(aws_tls_is_alpn_available());
}

static void s_logger_destructor(PyObject *logger_capsule) {
struct aws_logger *logger = PyCapsule_GetPointer(logger_capsule, s_capsule_name_logger);
assert(logger);

struct aws_allocator *allocator = aws_crt_python_get_allocator();

aws_logger_clean_up(logger);
aws_mem_release(allocator, logger);
}

PyObject *aws_py_io_init_logging(PyObject *self, PyObject *args) {
(void)self;

struct aws_allocator *allocator = aws_crt_python_get_allocator();

int log_level = 0;
const char *file_path = NULL;
Py_ssize_t file_path_len = 0;

if (!PyArg_ParseTuple(args, "bs#", &log_level, &file_path, &file_path_len)) {
PyErr_SetNone(PyExc_ValueError);
return NULL;
}

struct aws_logger *logger = aws_mem_acquire(allocator, sizeof(struct aws_logger));

if (!logger) {
return PyErr_AwsLastError();
}

struct aws_logger_standard_options log_options = {
.level = log_level,
.file = NULL,
.filename = NULL,
};

Py_ssize_t stdout_len = (Py_ssize_t)strlen("stdout");

Py_ssize_t cmp_len = file_path_len > stdout_len ? stdout_len : file_path_len;

if (!memcmp("stdout", file_path, (size_t)cmp_len)) {
log_options.file = stdout;
} else if (!memcmp("stderr", file_path, (size_t)cmp_len)) {
log_options.file = stderr;
} else {
log_options.filename = file_path;
}

aws_logger_init_standard(logger, allocator, &log_options);
aws_logger_set(logger);

return PyCapsule_New(logger, s_capsule_name_logger, s_logger_destructor);
}

static void s_elg_destructor(PyObject *elg_capsule) {

assert(PyCapsule_CheckExact(elg_capsule));

struct aws_event_loop_group *elg = PyCapsule_GetPointer(elg_capsule, s_capsule_name_elg);
assert(elg);

Expand Down Expand Up @@ -271,6 +214,7 @@ PyObject *aws_py_io_client_tls_ctx_new(PyObject *self, PyObject *args) {
}
if (ca_buffer && ca_buffer_len > 0) {
struct aws_byte_cursor ca = aws_byte_cursor_from_array(ca_buffer, ca_buffer_len);

aws_tls_ctx_options_override_default_trust_store(&ctx_options, &ca);
}

Expand Down
57 changes: 55 additions & 2 deletions source/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,53 @@

#include <memoryobject.h>

static struct aws_logger s_logger;
static bool s_logger_init = false;

PyObject *aws_py_io_init_logging(PyObject *self, PyObject *args) {
(void)self;

if (s_logger_init) {
Py_RETURN_NONE;
}

s_logger_init = true;

struct aws_allocator *allocator = aws_crt_python_get_allocator();

int log_level = 0;
const char *file_path = NULL;
Py_ssize_t file_path_len = 0;

if (!PyArg_ParseTuple(args, "bs#", &log_level, &file_path, &file_path_len)) {
PyErr_SetNone(PyExc_ValueError);
return NULL;
}

struct aws_logger_standard_options log_options = {
.level = log_level,
.file = NULL,
.filename = NULL,
};

Py_ssize_t stdout_len = (Py_ssize_t)strlen("stdout");

Py_ssize_t cmp_len = file_path_len > stdout_len ? stdout_len : file_path_len;

if (!memcmp("stdout", file_path, (size_t)cmp_len)) {
log_options.file = stdout;
} else if (!memcmp("stderr", file_path, (size_t)cmp_len)) {
log_options.file = stderr;
} else {
log_options.filename = file_path;
}

aws_logger_init_standard(&s_logger, allocator, &log_options);
aws_logger_set(&s_logger);

Py_RETURN_NONE;
}

#if PY_MAJOR_VERSION == 3
# define INIT_FN PyInit__aws_crt_python
# define UNICODE_GET_BYTES_FN PyUnicode_DATA
Expand Down Expand Up @@ -168,7 +215,13 @@ static void s_module_free(void *userdata) {
(void)userdata;

aws_tls_clean_up_static_state();

if (s_logger_init) {
aws_logger_clean_up(&s_logger);
}
aws_mqtt_library_clean_up();
}

#endif /* PY_MAJOR_VERSION == 3 */

PyMODINIT_FUNC INIT_FN(void) {
Expand All @@ -193,12 +246,12 @@ PyMODINIT_FUNC INIT_FN(void) {

aws_load_error_strings();
aws_io_load_error_strings();
aws_mqtt_load_error_strings();

aws_io_load_log_subject_strings();
aws_tls_init_static_state(aws_crt_python_get_allocator());
aws_http_library_init(aws_crt_python_get_allocator());
aws_mqtt_library_init(aws_crt_python_get_allocator());

aws_tls_init_static_state(aws_crt_python_get_allocator());
if (!PyEval_ThreadsInitialized()) {
PyEval_InitThreads();
}
Expand Down
3 changes: 1 addition & 2 deletions source/mqtt_client_connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -621,8 +621,7 @@ static void s_suback_callback(

const char *topic_str = (const char *)topic->ptr;
Py_ssize_t topic_len = topic->len;

PyObject *result = PyObject_CallFunction(callback, "(Hs#L)", packet_id, topic_str, topic_len, qos);
PyObject *result = PyObject_CallFunction(callback, "(Hs#b)", packet_id, topic_str, topic_len, qos);
if (!result) {
PyErr_WriteUnraisable(PyErr_Occurred());
abort();
Expand Down

0 comments on commit 0f1725d

Please sign in to comment.