diff --git a/test/test_mqtt.py b/test/test_mqtt.py index 7a6acdf3b..c9473dfa1 100644 --- a/test/test_mqtt.py +++ b/test/test_mqtt.py @@ -45,11 +45,12 @@ def test_lifetime(self): class Config: cache = None - def __init__(self, endpoint, cert, key, region): + def __init__(self, endpoint, cert, key, region, cognito_creds): self.cert = cert self.key = key self.endpoint = endpoint self.region = region + self.cognito_creds = cognito_creds @staticmethod def get(): @@ -73,7 +74,16 @@ def get(): response = secrets.get_secret_value(SecretId='unit-test/privatekey') key = response['SecretString'].encode('utf8') region = secrets.meta.region_name - Config.cache = Config(endpoint, cert, key, region) + response = secrets.get_secret_value(SecretId='unit-test/cognitopool') + cognito_pool = response['SecretString'] + + cognito = boto3.client('cognito-identity') + response = cognito.get_id(IdentityPoolId=cognito_pool) + cognito_id = response['IdentityId'] + response = cognito.get_credentials_for_identity(IdentityId=cognito_id) + cognito_creds = response['Credentials'] + + Config.cache = Config(endpoint, cert, key, region, cognito_creds) except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as ex: raise unittest.SkipTest("No credentials") @@ -229,6 +239,24 @@ def test_websockets_default(self): client_bootstrap=bootstrap) self._test_connection(connection) + def test_websockets_sts(self): + """Websocket connection with X-Amz-Security-Token query param""" + config = Config.get() + elg = EventLoopGroup() + resolver = DefaultHostResolver(elg) + bootstrap = ClientBootstrap(elg, resolver) + cred_provider = AwsCredentialsProvider.new_static( + access_key_id=config.cognito_creds['AccessKeyId'], + secret_access_key=config.cognito_creds['SecretKey'], + session_token=config.cognito_creds['SessionToken']) + connection = awsiot_mqtt_connection_builder.websockets_with_default_aws_signing( + region=config.region, + credentials_provider=cred_provider, + endpoint=config.endpoint, + client_id=create_client_id(), + client_bootstrap=bootstrap) + self._test_connection(connection) + @unittest.skipIf(PROXY_HOST is None, 'requires "proxyhost" and "proxyport" env vars') def test_websockets_proxy(self): config = Config.get()