Skip to content

Commit

Permalink
auto-format python code (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
graebm authored Sep 20, 2019
1 parent 84fba92 commit 0d500b1
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 36 deletions.
1 change: 1 addition & 0 deletions awscrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

__all__ = ['io', 'mqtt', 'crypto', 'http']


class NativeResource(object):
"""
Base for classes that bind to a native type.
Expand Down
4 changes: 2 additions & 2 deletions awscrt/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def md5_new():
def update(self, to_hash):
_awscrt.hash_update(self._hash, to_hash)

def digest(self, truncate_to = 0):
def digest(self, truncate_to=0):
return _awscrt.hash_digest(self._hash, truncate_to)


Expand All @@ -60,5 +60,5 @@ def sha256_hmac_new(secret_key):
def update(self, to_hmac):
_awscrt.hmac_update(self._hmac, to_hmac)

def digest(self, truncate_to = 0):
def digest(self, truncate_to=0):
return _awscrt.hmac_digest(self._hmac, truncate_to)
8 changes: 2 additions & 6 deletions awscrt/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def on_connection_setup(binding, error_code):

return future


@property
def host_name(self):
return self._host_name
Expand All @@ -118,7 +117,6 @@ def host_name(self):
def port(self):
return self._port


def request(self, request, on_response=None, on_body=None):
return HttpClientStream(self, request, on_response, on_body)

Expand Down Expand Up @@ -148,7 +146,7 @@ def _on_complete(self, error_code):
if error_code == 0:
self._complete_future.set_result(None)
else:
self._complete_future.set_exception(Exception(error_code)) # TODO: Actual exceptions for error_codes
self._complete_future.set_exception(Exception(error_code)) # TODO: Actual exceptions for error_codes


class HttpClientStream(HttpStreamBase):
Expand All @@ -167,12 +165,10 @@ def __init__(self, connection, request, on_response=None, on_body=None):

_awscrt.http_client_stream_new(self, connection, request)


@property
def response_status_code(self):
return self._response_status_code


def _on_response(self, status_code, name_value_pairs):
self._response_status_code = status_code

Expand Down Expand Up @@ -201,6 +197,7 @@ def headers(self):
def body_stream(self):
return self._body_stream


class HttpRequest(HttpMessageBase):
"""
Definition for an outgoing HTTP request.
Expand Down Expand Up @@ -302,7 +299,6 @@ def remove_value(self, name, value):
return
raise ValueError("HttpHeaders.remove_value(name,value): value not found")


def clear(self):
"""
Clear all headers
Expand Down
6 changes: 6 additions & 0 deletions awscrt/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class LogLevel(IntEnum):
Debug = 5
Trace = 6


def init_logging(log_level, file_name):
"""
initialize a logger. log_level is type LogLevel, and file_name is of type str.
Expand Down Expand Up @@ -56,9 +57,11 @@ def __init__(self, num_threads=0):
super(EventLoopGroup, self).__init__()
self._binding = _awscrt.event_loop_group_new(num_threads)


class HostResolverBase(NativeResource):
__slots__ = ()


class DefaultHostResolver(HostResolverBase):
__slots__ = ()

Expand All @@ -68,6 +71,7 @@ def __init__(self, event_loop_group, max_hosts=16):
super(DefaultHostResolver, self).__init__()
self._binding = _awscrt.host_resolver_new_default(max_hosts, event_loop_group)


class ClientBootstrap(NativeResource):
__slots__ = ()

Expand All @@ -82,6 +86,7 @@ def __init__(self, event_loop_group, host_resolver=None):

self._binding = _awscrt.client_bootstrap_new(event_loop_group, host_resolver)


def _read_binary_file(filepath):
with open(filepath, mode='rb') as fh:
contents = fh.read()
Expand Down Expand Up @@ -226,6 +231,7 @@ def create_server_pkcs12(pkcs12_filepath, pkcs12_password):
opt.verify_peer = False
return opt


def _alpn_list_to_str(alpn_list):
"""
Transform ['h2', 'http/1.1'] -> "h2;http/1.1"
Expand Down
37 changes: 21 additions & 16 deletions awscrt/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
from awscrt import NativeResource
from awscrt.io import ClientBootstrap, ClientTlsContext


class QoS(IntEnum):
"""Quality of Service"""
AT_MOST_ONCE = 0
AT_LEAST_ONCE = 1
EXACTLY_ONCE = 2


class ConnectReturnCode(IntEnum):
ACCEPTED = 0
UNACCEPTABLE_PROTOCOL_VERSION = 1
Expand All @@ -31,6 +33,7 @@ class ConnectReturnCode(IntEnum):
BAD_USERNAME_OR_PASSWORD = 4
NOT_AUTHORIZED = 5


class Will(object):
__slots__ = ('topic', 'qos', 'payload', 'retain')

Expand All @@ -40,26 +43,28 @@ def __init__(self, topic, qos, payload, retain):
self.payload = payload
self.retain = retain


class Client(NativeResource):
__slots__ = ('tls_ctx')

def __init__(self, bootstrap, tls_ctx = None):
def __init__(self, bootstrap, tls_ctx=None):
assert isinstance(bootstrap, ClientBootstrap)
assert tls_ctx is None or isinstance(tls_ctx, ClientTlsContext)

super(Client, self).__init__()
self.tls_ctx = tls_ctx
self._binding = _awscrt.mqtt_client_new(bootstrap, tls_ctx)


class Connection(NativeResource):
__slots__ = ('client')

def __init__(self,
client,
on_connection_interrupted=None,
on_connection_resumed=None,
reconnect_min_timeout_sec=5.0,
reconnect_max_timeout_sec=60.0):
client,
on_connection_interrupted=None,
on_connection_resumed=None,
reconnect_min_timeout_sec=5.0,
reconnect_max_timeout_sec=60.0):
"""
on_connection_interrupted: optional callback, with signature (error_code)
on_connection_resumed: optional callback, with signature (error_code, session_present)
Expand All @@ -76,17 +81,17 @@ def __init__(self,
client,
on_connection_interrupted,
on_connection_resumed,
)
)

def connect(self,
client_id,
host_name, port,
use_websocket=False,
clean_session=True, keep_alive=0,
ping_timeout=0,
will=None,
username=None, password=None,
connect_timeout_sec=5.0):
client_id,
host_name, port,
use_websocket=False,
clean_session=True, keep_alive=0,
ping_timeout=0,
will=None,
username=None, password=None,
connect_timeout_sec=5.0):

future = Future()

Expand All @@ -112,7 +117,7 @@ def on_connect(error_code, return_code, session_present):
username,
password,
on_connect,
)
)

except Exception as e:
future.set_exception(e)
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[pep8]
max-line-length = 120
aggressive = 2
38 changes: 26 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from distutils.ccompiler import get_default_compiler
import setuptools
import os
import subprocess
Expand All @@ -6,29 +7,34 @@
from os import path
import sys


def is_64bit():
if sys.maxsize > 2**32:
return True

return False


def is_32bit():
return is_64bit() == False

def is_arm ():

def is_arm():
return platform.machine().startswith('arm')


def determine_cross_compile_string():
host_arch = platform.machine()
if (host_arch == 'AMD64' or host_arch == 'x86_64') and is_32bit() and sys.platform != 'win32':
return '-DCMAKE_C_FLAGS=-m32'
return ''


def determine_generator_string():
if sys.platform == 'win32':
vs_version = None
prog_x86_path = os.getenv('PROGRAMFILES(x86)')
if vs_version == None:
if vs_version is None:
if os.path.exists(prog_x86_path + '\\Microsoft Visual Studio\\2019'):
vs_version = '16.0'
print('found installed version of Visual Studio 2019')
Expand All @@ -40,7 +46,12 @@ def determine_generator_string():
print('found installed version of Visual Studio 2015')
else:
print('Making an attempt at calling vswhere')
vswhere_args = ['%ProgramFiles(x86)%\\Microsoft Visual Studio\\Installer\\vswhere.exe', '-legacy', '-latest', '-property', 'installationVersion']
vswhere_args = [
'%ProgramFiles(x86)%\\Microsoft Visual Studio\\Installer\\vswhere.exe',
'-legacy',
'-latest',
'-property',
'installationVersion']
vswhere_output = None

try:
Expand All @@ -49,7 +60,7 @@ def determine_generator_string():
print('No version of MSVC compiler could be found!')
exit(1)

if vswhere_output != None:
if vswhere_output is not None:
for out in vswhere_output.split():
vs_version = out.decode('utf-8')
else:
Expand All @@ -69,7 +80,7 @@ def determine_generator_string():
vs_version_gen_str = trimmed_out.split('[')[0].strip(' *')
break

if vs_version_gen_str == None:
if vs_version_gen_str is None:
print('CMake does not recognize an installed version of visual studio on your system.')
exit(1)

Expand All @@ -82,6 +93,7 @@ def determine_generator_string():
return vs_version_gen_str
return ''


generator_string = determine_generator_string()
cross_compile_string = determine_cross_compile_string()
current_dir = os.path.dirname(os.path.realpath(__file__))
Expand All @@ -99,6 +111,7 @@ def determine_generator_string():
if os.path.exists(os.path.join(dep_install_path, 'lib64')):
lib_dir = 'lib64'


def build_dependency(lib_name, extra_cmake_args=[]):
lib_source_dir = os.path.join(current_dir, lib_name)
global lib_dir
Expand Down Expand Up @@ -138,6 +151,7 @@ def build_dependency(lib_name, extra_cmake_args=[]):
os.chdir(build_dir)
return ret_code


if sys.platform != 'darwin' and sys.platform != 'win32':
build_dependency('s2n', ['-DUSE_S2N_PQ_CRYPTO=OFF'])
build_dependency('aws-c-common')
Expand All @@ -149,17 +163,18 @@ def build_dependency(lib_name, extra_cmake_args=[]):

os.chdir(current_dir)

from distutils.ccompiler import get_default_compiler
compiler_type = get_default_compiler()

aws_c_libs = ['aws-c-mqtt', 'aws-c-http', 'aws-c-io', 'aws-c-compression', 'aws-c-cal', 'aws-c-common']


def get_from_env(key):
try:
return os.environ[key]
except:
except BaseException:
return ""


# fetch the CFLAGS/LDFLAGS from env
cflags = get_from_env('CFLAGS').split()
ldflags = get_from_env('LDFLAGS').split()
Expand All @@ -170,18 +185,18 @@ def get_from_env(key):
extra_objects = []

if compiler_type == 'msvc':
#if this is old python, we need to statically link in the VS2015 CRT, the invoking script
# if this is old python, we need to statically link in the VS2015 CRT, the invoking script
# already overrode the compiler environment variables so that a decent compiler is used
# and this is C so it shouldn't really matter.
# actually, I couldn't get this to work, leave it here commented out for future brave souls
#if sys.version_info[0] == 2 or (sys.version_info[0] == 3 and sys.version_info[1] <= 4):
# if sys.version_info[0] == 2 or (sys.version_info[0] == 3 and sys.version_info[1] <= 4):
# cflags += ['/MT']
pass
else:
cflags += ['-O3', '-Wextra', '-Werror', '-Wno-strict-aliasing', '-std=gnu99']

if sys.platform == 'win32':
#the windows apis being used under the hood. Since we're static linking we have to follow the entire chain down
# the windows apis being used under the hood. Since we're static linking we have to follow the entire chain down
libraries += ['Secur32', 'Crypt32', 'Advapi32', 'BCrypt', 'Kernel32', 'Ws2_32', 'Shlwapi']
elif sys.platform == 'darwin':
ldflags += ['-framework Security']
Expand Down Expand Up @@ -239,6 +254,5 @@ def get_from_env(key):
'enum34 ; python_version<"3.4"',
'futures ; python_version<"3.2"',
],
ext_modules = [_awscrt],
ext_modules=[_awscrt],
)

1 change: 1 addition & 0 deletions test/test_http_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,5 +118,6 @@ def test_clear(self):
h.clear()
self.assertEqual([], [pair for pair in h])


if __name__ == '__main__':
unittest.main()

0 comments on commit 0d500b1

Please sign in to comment.