From 89240380496636dbd73de19013721be359c07367 Mon Sep 17 00:00:00 2001 From: Michael Graeb Date: Mon, 28 Oct 2019 16:30:25 -0700 Subject: [PATCH] aws_credentials_provider binding (#90) Only exposing "default chain" and "static" providers for now --- .gitmodules | 3 + aws-c-auth | 1 + aws-c-common | 2 +- awscrt/__init__.py | 8 +- awscrt/auth.py | 111 ++++++++++++ setup.py | 1 + source/auth.h | 30 ++++ source/auth_credentials.c | 288 ++++++++++++++++++++++++++++++++ source/module.c | 14 +- test/resources/credentials_test | 3 + test/test_credentials.py | 86 ++++++++++ 11 files changed, 543 insertions(+), 4 deletions(-) create mode 160000 aws-c-auth create mode 100644 awscrt/auth.py create mode 100644 source/auth.h create mode 100644 source/auth_credentials.c create mode 100644 test/resources/credentials_test create mode 100644 test/test_credentials.py diff --git a/.gitmodules b/.gitmodules index f672310ab..ef1f7bdfc 100644 --- a/.gitmodules +++ b/.gitmodules @@ -19,3 +19,6 @@ [submodule "aws-c-compression"] path = aws-c-compression url = git://github.com/awslabs/aws-c-compression +[submodule "aws-c-auth"] + path = aws-c-auth + url = git://github.com/awslabs/aws-c-auth diff --git a/aws-c-auth b/aws-c-auth new file mode 160000 index 000000000..fecefc800 --- /dev/null +++ b/aws-c-auth @@ -0,0 +1 @@ +Subproject commit fecefc8009760218e29f7bf53a8efa0193988f8c diff --git a/aws-c-common b/aws-c-common index 97081187e..e3e7ccd35 160000 --- a/aws-c-common +++ b/aws-c-common @@ -1 +1 @@ -Subproject commit 97081187e3c8744fc3baa86c71d1ece921e4bd83 +Subproject commit e3e7ccd35a85f9cd38c67cb1988251f1543b6632 diff --git a/awscrt/__init__.py b/awscrt/__init__.py index 59ba1dbdb..183a541e6 100644 --- a/awscrt/__init__.py +++ b/awscrt/__init__.py @@ -14,7 +14,13 @@ from sys import version_info from weakref import WeakSet -__all__ = ['io', 'mqtt', 'crypto', 'http'] +__all__ = [ + 'auth', + 'crypto', + 'http', + 'io', + 'mqtt', +] class NativeResource(object): diff --git a/awscrt/auth.py b/awscrt/auth.py new file mode 100644 index 000000000..b9f61d714 --- /dev/null +++ b/awscrt/auth.py @@ -0,0 +1,111 @@ +# Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://aws.amazon.com/apache2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from __future__ import absolute_import +import _awscrt +from awscrt import isinstance_str, NativeResource +from awscrt.io import ClientBootstrap +from concurrent.futures import Future + + +class Credentials(object): + """ + Credentials are the public/private data needed to sign an authenticated AWS request. + """ + + __slots__ = ('access_key_id', 'secret_access_key', 'session_token') + + def __init__(self, access_key_id, secret_access_key, session_token=None): + assert isinstance_str(access_key_id) + assert isinstance_str(secret_access_key) + assert isinstance_str(session_token) or session_token is None + + self.access_key_id = access_key_id + self.secret_access_key = secret_access_key + self.session_token = session_token + + +class CredentialsProviderBase(NativeResource): + """ + Base class for providers that source the Credentials needed to sign an authenticated AWS request. + """ + + def get_credentials(self): + """ + Asynchronously fetch Credentials. + + Returns a Future which will contain Credentials (or an exception) + when the call completes. The call may complete on a different thread. + """ + future = Future() + + def _on_complete(error_code, access_key_id, secret_access_key, session_token): + try: + if error_code: + future.set_exception(Exception(error_code)) # TODO: Actual exceptions for error_codes + else: + credentials = Credentials(access_key_id, secret_access_key, session_token) + future.set_result(credentials) + + except Exception as e: + future.set_exception(e) + + try: + _awscrt.credentials_provider_get_credentials(self._binding, _on_complete) + except Exception as e: + future.set_result(e) + + return future + + def close(self): + """ + Signal a provider (and all linked providers) to cancel pending queries and + stop accepting new ones. Useful to hasten shutdown time if you know the provider + is going away. + """ + _awscrt.credentials_provider_shutdown(self._binding) + + +class DefaultCredentialsProviderChain(CredentialsProviderBase): + """ + Providers source the Credentials needed to sign an authenticated AWS request. + This is the default provider chain used by most AWS SDKs. + + Generally: + + (1) Environment + (2) Profile + (3) (conditional, off by default) ECS + (4) (conditional, on by default) EC2 Instance Metadata + """ + + def __init__(self, client_bootstrap): + assert isinstance(client_bootstrap, ClientBootstrap) + + super(DefaultCredentialsProviderChain, self).__init__() + self._binding = _awscrt.credentials_provider_new_chain_default(client_bootstrap) + + +class StaticCredentialsProvider(CredentialsProviderBase): + """ + Providers source the Credentials needed to sign an authenticated AWS request. + This is a simple provider that just returns a fixed set of credentials + """ + + def __init__(self, access_key_id, secret_access_key, session_token=None): + assert isinstance_str(access_key_id) + assert isinstance_str(secret_access_key) + assert isinstance_str(session_token) or session_token is None + + super(StaticCredentialsProvider, self).__init__() + self._binding = _awscrt.credentials_provider_new_static(access_key_id, secret_access_key, session_token) diff --git a/setup.py b/setup.py index c6208fbcf..4fd94efd3 100644 --- a/setup.py +++ b/setup.py @@ -99,6 +99,7 @@ def __init__(self, name, extra_cmake_args=[]): AWS_LIBS.append(AwsLib('aws-c-cal')) AWS_LIBS.append(AwsLib('aws-c-compression')) AWS_LIBS.append(AwsLib('aws-c-http')) +AWS_LIBS.append(AwsLib('aws-c-auth')) AWS_LIBS.append(AwsLib('aws-c-mqtt')) diff --git a/source/auth.h b/source/auth.h new file mode 100644 index 000000000..28d262341 --- /dev/null +++ b/source/auth.h @@ -0,0 +1,30 @@ +#ifndef AWS_CRT_PYTHON_AUTH_H +#define AWS_CRT_PYTHON_AUTH_H +/* + * Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +#include "module.h" + +PyObject *aws_py_credentials_provider_get_credentials(PyObject *self, PyObject *args); +PyObject *aws_py_credentials_provider_shutdown(PyObject *self, PyObject *args); + +PyObject *aws_py_credentials_provider_new_chain_default(PyObject *self, PyObject *args); +PyObject *aws_py_credentials_provider_new_static(PyObject *self, PyObject *args); + +/* Given a python object, return a pointer to its underlying native type. + * If NULL is returned, a python error has been set */ + +struct aws_credentials_provider *aws_py_get_credentials_provider(PyObject *credentials_provider); + +#endif // AWS_CRT_PYTHON_AUTH_H diff --git a/source/auth_credentials.c b/source/auth_credentials.c new file mode 100644 index 000000000..401ebb028 --- /dev/null +++ b/source/auth_credentials.c @@ -0,0 +1,288 @@ +/* + * Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +#include "auth.h" + +#include "io.h" + +#include +#include + +static const char *s_capsule_name_credentials_provider = "aws_credentials_provider"; + +/** + * Binds a Python CredentialsProvider to a native aws_credentials_provider. + */ +struct credentials_provider_binding { + struct aws_credentials_provider *native; + + /* Dependencies that must outlive this. + * Note that different types of providers have different dependencies */ + PyObject *bootstrap; +}; + +/* Runs when the GC destroys the capsule containing the binding */ +static void s_credentials_provider_capsule_destructor(PyObject *capsule) { + struct credentials_provider_binding *provider = PyCapsule_GetPointer(capsule, s_capsule_name_credentials_provider); + + /* Note that destructor might run due to setup failing, and some/all members might still be NULL. */ + + if (provider->native) { + aws_credentials_provider_release(provider->native); + } + + Py_XDECREF(provider->bootstrap); + aws_mem_release(aws_py_get_allocator(), provider); +} + +struct aws_credentials_provider *aws_py_get_credentials_provider(PyObject *credentials_provider) { + struct aws_credentials_provider *native = NULL; + + PyObject *capsule = PyObject_GetAttrString(credentials_provider, "_binding"); + if (capsule) { + struct credentials_provider_binding *binding = + PyCapsule_GetPointer(capsule, s_capsule_name_credentials_provider); + if (binding) { + native = binding->native; + AWS_FATAL_ASSERT(native); + } + Py_DECREF(capsule); + } + + return native; +} + +int s_aws_string_to_cstr_and_ssize(const struct aws_string *source, const char **out_cstr, Py_ssize_t *out_ssize) { + *out_cstr = NULL; + *out_ssize = 0; + if (source) { + if (source->len > PY_SSIZE_T_MAX) { + return aws_raise_error(AWS_ERROR_OVERFLOW_DETECTED); + } + *out_cstr = aws_string_c_str(source); + *out_ssize = source->len; + } + return AWS_OP_SUCCESS; +} + +/* Completion callback for get_credentials() */ +static void s_on_get_credentials_complete(struct aws_credentials *credentials, void *user_data) { + PyObject *on_complete_cb = user_data; + + /* NOTE: This callback doesn't currently supply an error_code, but it should. */ + int error_code = AWS_ERROR_UNKNOWN; + + /* Note that we don't actually bind the native aws_credentials to the python Credentials class, + * we simply copy the contents back and forth. */ + const char *access_key_id = NULL; + Py_ssize_t access_key_id_len = 0; + const char *secret_access_key = NULL; + Py_ssize_t secret_access_key_len = 0; + const char *session_token = NULL; + Py_ssize_t session_token_len = 0; + + if (credentials) { + error_code = AWS_ERROR_SUCCESS; + + if (s_aws_string_to_cstr_and_ssize(credentials->access_key_id, &access_key_id, &access_key_id_len)) { + error_code = aws_last_error(); + } + if (s_aws_string_to_cstr_and_ssize( + credentials->secret_access_key, &secret_access_key, &secret_access_key_len)) { + error_code = aws_last_error(); + } + if (s_aws_string_to_cstr_and_ssize(credentials->session_token, &session_token, &session_token_len)) { + error_code = aws_last_error(); + } + } + + /*************** GIL ACQUIRE ***************/ + PyGILState_STATE state = PyGILState_Ensure(); + + PyObject *result = PyObject_CallFunction( + on_complete_cb, + "(is#s#s#)", + error_code, + access_key_id, + access_key_id_len, + secret_access_key, + secret_access_key_len, + session_token, + session_token_len); + if (result) { + Py_DECREF(result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + + Py_DECREF(on_complete_cb); + + PyGILState_Release(state); + /*************** GIL RELEASE ***************/ +} + +PyObject *aws_py_credentials_provider_get_credentials(PyObject *self, PyObject *args) { + (void)self; + PyObject *capsule; + PyObject *on_complete_cb; + if (!PyArg_ParseTuple(args, "OO", &capsule, &on_complete_cb)) { + return NULL; + } + + struct credentials_provider_binding *provider = PyCapsule_GetPointer(capsule, s_capsule_name_credentials_provider); + if (!provider) { + return NULL; + } + + AWS_FATAL_ASSERT(on_complete_cb != Py_None); + + Py_INCREF(on_complete_cb); + if (aws_credentials_provider_get_credentials(provider->native, s_on_get_credentials_complete, on_complete_cb)) { + Py_DECREF(on_complete_cb); + return PyErr_AwsLastError(); + } + + Py_RETURN_NONE; +} + +PyObject *aws_py_credentials_provider_shutdown(PyObject *self, PyObject *args) { + (void)self; + PyObject *capsule; + if (!PyArg_ParseTuple(args, "O", &capsule)) { + return NULL; + } + + struct credentials_provider_binding *provider = PyCapsule_GetPointer(capsule, s_capsule_name_credentials_provider); + if (!provider) { + return NULL; + } + + aws_credentials_provider_shutdown(provider->native); + Py_RETURN_NONE; +} + +/* Create binding and capsule. + * Helper function for every aws_py_credentials_provider_new_XYZ() function */ +PyObject *s_new_credentials_provider_binding_and_capsule(struct credentials_provider_binding **out_binding) { + *out_binding = NULL; + + struct credentials_provider_binding *binding = + aws_mem_calloc(aws_py_get_allocator(), 1, sizeof(struct credentials_provider_binding)); + if (!binding) { + return PyErr_AwsLastError(); + } + + PyObject *capsule = + PyCapsule_New(binding, s_capsule_name_credentials_provider, s_credentials_provider_capsule_destructor); + if (!capsule) { + aws_mem_release(aws_py_get_allocator(), binding); + return NULL; + } + + *out_binding = binding; + return capsule; +} + +PyObject *aws_py_credentials_provider_new_chain_default(PyObject *self, PyObject *args) { + (void)self; + + PyObject *bootstrap_py; + if (!PyArg_ParseTuple(args, "O", &bootstrap_py)) { + return NULL; + } + + struct aws_client_bootstrap *bootstrap = aws_py_get_client_bootstrap(bootstrap_py); + if (!bootstrap) { + return NULL; + } + + struct credentials_provider_binding *binding; + PyObject *capsule = s_new_credentials_provider_binding_and_capsule(&binding); + if (!capsule) { + return NULL; + } + + /* From hereon, we need to clean up if errors occur. + * Fortunately, the capsule destructor will clean up anything stored inside the binding */ + + binding->bootstrap = bootstrap_py; + Py_INCREF(binding->bootstrap); + + struct aws_credentials_provider_chain_default_options options = { + .bootstrap = bootstrap, + }; + + binding->native = aws_credentials_provider_new_chain_default(aws_py_get_allocator(), &options); + if (!binding->native) { + PyErr_SetAwsLastError(); + goto error; + } + + return capsule; + +error: + Py_DECREF(capsule); + return NULL; +} + +PyObject *aws_py_credentials_provider_new_static(PyObject *self, PyObject *args) { + (void)self; + + struct aws_allocator *allocator = aws_py_get_allocator(); + + const char *access_key_id; + Py_ssize_t access_key_id_len; + const char *secret_access_key; + Py_ssize_t secret_access_key_len; + const char *session_token; /* optional */ + Py_ssize_t session_token_len; + + if (!PyArg_ParseTuple( + args, + "s#s#z#", + &access_key_id, + &access_key_id_len, + &secret_access_key, + &secret_access_key_len, + &session_token, + &session_token_len)) { + return NULL; + } + + struct credentials_provider_binding *binding; + PyObject *capsule = s_new_credentials_provider_binding_and_capsule(&binding); + if (!capsule) { + return NULL; + } + + /* From hereon, we need to clean up if errors occur. + * Fortunately, the capsule destructor will clean up anything stored inside the binding */ + + binding->native = aws_credentials_provider_new_static( + allocator, + aws_byte_cursor_from_array(access_key_id, access_key_id_len), + aws_byte_cursor_from_array(secret_access_key, secret_access_key_len), + aws_byte_cursor_from_array(session_token, session_token_len)); + + if (!binding->native) { + PyErr_SetAwsLastError(); + goto error; + } + + return capsule; +error: + Py_DECREF(capsule); + return NULL; +} diff --git a/source/module.c b/source/module.c index 187696d17..4623ec2b5 100644 --- a/source/module.c +++ b/source/module.c @@ -13,18 +13,20 @@ * permissions and limitations under the License. */ #include "module.h" + +#include "auth.h" #include "crypto.h" #include "http.h" #include "io.h" #include "mqtt_client.h" #include "mqtt_client_connection.h" +#include #include +#include #include #include #include - -#include #include #include @@ -277,6 +279,12 @@ static PyMethodDef s_module_methods[] = { AWS_PY_METHOD_DEF(http_client_stream_new, METH_VARARGS), AWS_PY_METHOD_DEF(http_request_new, METH_VARARGS), + /* Auth */ + AWS_PY_METHOD_DEF(credentials_provider_get_credentials, METH_VARARGS), + AWS_PY_METHOD_DEF(credentials_provider_shutdown, METH_VARARGS), + AWS_PY_METHOD_DEF(credentials_provider_new_chain_default, METH_VARARGS), + AWS_PY_METHOD_DEF(credentials_provider_new_static, METH_VARARGS), + {NULL, NULL, 0, NULL}, }; @@ -297,6 +305,7 @@ static void s_module_free(void *userdata) { aws_logger_clean_up(&s_logger); } aws_mqtt_library_clean_up(); + aws_auth_library_clean_up(); aws_http_library_clean_up(); } @@ -323,6 +332,7 @@ PyMODINIT_FUNC INIT_FN(void) { #endif /* PY_MAJOR_VERSION */ aws_http_library_init(aws_py_get_allocator()); + aws_auth_library_init(aws_py_get_allocator()); aws_mqtt_library_init(aws_py_get_allocator()); if (!PyEval_ThreadsInitialized()) { diff --git a/test/resources/credentials_test b/test/resources/credentials_test new file mode 100644 index 000000000..f5745e2f8 --- /dev/null +++ b/test/resources/credentials_test @@ -0,0 +1,3 @@ +[default] +aws_access_key_id = credentials_test_access_key_id +aws_secret_access_key = credentials_test_secret_access_key diff --git a/test/test_credentials.py b/test/test_credentials.py new file mode 100644 index 000000000..5fde96b26 --- /dev/null +++ b/test/test_credentials.py @@ -0,0 +1,86 @@ +# Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://aws.amazon.com/apache2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from __future__ import absolute_import +from awscrt.auth import DefaultCredentialsProviderChain, StaticCredentialsProvider +from awscrt.io import ClientBootstrap, EventLoopGroup +import os +from test import NativeResourceTest + + +class ScopedEnvironmentVariable(object): + """ + Set environment variable for lifetime of this object. + """ + + def __init__(self, key, value): + self.key = key + self.prev_value = os.environ.get(key) + os.environ[key] = value + + def __del__(self): + if self.prev_value is None: + del os.environ[self.key] + else: + os.environ[self.key] = self.prev_value + + +class TestProvider(NativeResourceTest): + example_access_key_id = 'example_access_key_id' + example_secret_access_key = 'example_secret_access_key' + example_session_token = 'example_session_token' + + def test_static_provider(self): + provider = StaticCredentialsProvider( + self.example_access_key_id, + self.example_secret_access_key, + self.example_session_token) + + future = provider.get_credentials() + credentials = future.result() + + self.assertEqual(self.example_access_key_id, credentials.access_key_id) + self.assertEqual(self.example_secret_access_key, credentials.secret_access_key) + self.assertEqual(self.example_session_token, credentials.session_token) + + # TODO: test currently broken because None session_token comes back as empty string do to inconsistent use of + # aws_byte_cursor by value/pointer in aws-c-auth APIs. + # + # def test_static_provider_no_session_token(self): + # provider = StaticCredentialsProvider( + # self.example_access_key_id, + # self.example_secret_access_key) + + # future = provider.get_credentials() + # credentials = future.result() + + # self.assertEqual(self.example_access_key_id, credentials.access_key_id) + # self.assertEqual(self.example_secret_access_key, credentials.secret_access_key) + # self.assertIsNone(credentials.session_token) + + def test_default_provider(self): + # Use environment variable to force specific credentials file + scoped_env = ScopedEnvironmentVariable('AWS_SHARED_CREDENTIALS_FILE', 'test/resources/credentials_test') + + event_loop_group = EventLoopGroup() + bootstrap = ClientBootstrap(event_loop_group) + provider = DefaultCredentialsProviderChain(bootstrap) + + future = provider.get_credentials() + credentials = future.result() + + self.assertEqual('credentials_test_access_key_id', credentials.access_key_id) + self.assertEqual('credentials_test_secret_access_key', credentials.secret_access_key) + self.assertIsNone(credentials.session_token) + + del scoped_env