From 52fd08ad67f3e6773a12a46d7842ce80ec9698a9 Mon Sep 17 00:00:00 2001 From: Nikolay Kiryanov Date: Wed, 12 Jun 2024 02:19:12 +0300 Subject: [PATCH] Drop AsyncJWK to use static JWK public key --- env.example | 3 +- secrets/.gitkeep | 0 src/a12n/fixtures.py | 108 ++++++++++++ src/a12n/jwk_client.py | 157 ------------------ src/a12n/jwt_decode.py | 15 ++ src/a12n/tests/async_jwk_client/conftest.py | 73 -------- .../tests_async_jwk_common.py | 9 - .../tests_async_jwk_decode.py | 35 ---- .../tests_async_jwk_fetch_data.py | 57 ------- .../tests_async_jwk_get_jwk_set.py | 56 ------- ...ests_async_jwk_get_signing_key_from_jwt.py | 61 ------- src/a12n/tests/tests_jwt_decode.py | 21 +++ src/app/conf/settings.py | 6 +- src/conftest.py | 1 + src/handlers/messages_handler.py | 27 ++- .../tests/messages_handler/conftest.py | 18 +- .../tests_auth_message_handler.py | 27 ++- .../tests_message_handler_common.py | 25 ++- .../tests_subscirbe_message_handler.py | 8 +- .../tests_unsubscirbe_message_handler.py | 8 +- src/handlers/websockets_handler.py | 2 +- src/tests/functional/conftest.py | 15 +- src/tests/functional/tests_authentication.py | 1 - 23 files changed, 204 insertions(+), 529 deletions(-) create mode 100644 secrets/.gitkeep create mode 100644 src/a12n/fixtures.py delete mode 100644 src/a12n/jwk_client.py create mode 100644 src/a12n/jwt_decode.py delete mode 100644 src/a12n/tests/async_jwk_client/conftest.py delete mode 100644 src/a12n/tests/async_jwk_client/tests_async_jwk_common.py delete mode 100644 src/a12n/tests/async_jwk_client/tests_async_jwk_decode.py delete mode 100644 src/a12n/tests/async_jwk_client/tests_async_jwk_fetch_data.py delete mode 100644 src/a12n/tests/async_jwk_client/tests_async_jwk_get_jwk_set.py delete mode 100644 src/a12n/tests/async_jwk_client/tests_async_jwk_get_signing_key_from_jwt.py create mode 100644 src/a12n/tests/tests_jwt_decode.py diff --git a/env.example b/env.example index 7fe361b..4d37460 100644 --- a/env.example +++ b/env.example @@ -9,5 +9,4 @@ WEBSOCKETS_HOST=localhost WEBSOCKETS_PORT=6789 WEBSOCKETS_PATH=/v1/test-subscription-websocket -AUTH_JWKS_URL=https://auth-test.contoso.com/auth/realms/test-realm/protocol/openid-connect/certs -AUTH_SUPPORTED_SIGNING_ALGORITHMS=["RS256"] +JWT_PUBLIC_KEY=-----BEGIN PUBLIC KEY-----\nMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAqdnObC47NUAqOOCvOzOg\n4i0KxZaTe2m7rfQe9a+rWbtkJ2TNuakN7eRshvv2UVGP4uSKEKe356v4GwP/yGAi\n92XEGr0Y6ieypnhu1wi0wuK4Z62abRkvsEZdDwKQpcde1rvyuvt0YeDMh9dCi/3P\nBLhcOlgvAu+6M79iWlRTZxzFe3KVzQabU8CIfgG7MXokutHUxT2dsRNfX4VwxMsW\no9o0o1QqSPJ6OOx2DwLEKat5n1w5ysIYYvkgHs36B3nPnZYc2b66uEp9AP9JlRjc\npWuH8vn3/OsvxHMErhyn+h9/H+aXRRIk/JuokqpVbPlOY8l+5z+JG6zn9onWpjcM\njQ19NP8C/CTwvcB8O+s3qEHKECkggyywCOe7EQqrB0uMU7IQ1srH8ENspuY16UQV\nqQPBlYVQOywVW6+25z+ILNyPjEdzukn6Oyh9ChU+m08Tw9SsBAV2vnrkUCT1wJhC\nc2X/i2WqBM4lExJu18tau0X26iKdbRZwx50OKUVn9w8AbW/iglCExYkpDs3VKH80\nallIBcfAOXJ00X6jhWETT2T9U1c1KiTqNHMgBflvG17CbkKkyaFIILdYIfMkW2EA\nngpOxaBC1cLJTvXGuTx4R54wpQNyt8k/7P/r8UY+FxBlA/3Upb9LidmItvd7yvVr\nzUMftEW5MYWWWnm6ZUS3Q6sCAwEAAQ==\n-----END PUBLIC KEY-----\n diff --git a/secrets/.gitkeep b/secrets/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/src/a12n/fixtures.py b/src/a12n/fixtures.py new file mode 100644 index 0000000..a143689 --- /dev/null +++ b/src/a12n/fixtures.py @@ -0,0 +1,108 @@ +import pytest + +import jwt + +from app.types import UserId + + +@pytest.fixture +def jwt_private_key(): + return """-----BEGIN PRIVATE KEY----- +MIIJQQIBADANBgkqhkiG9w0BAQEFAASCCSswggknAgEAAoICAQCp2c5sLjs1QCo4 +4K87M6DiLQrFlpN7abut9B71r6tZu2QnZM25qQ3t5GyG+/ZRUY/i5IoQp7fnq/gb +A//IYCL3ZcQavRjqJ7KmeG7XCLTC4rhnrZptGS+wRl0PApClx17Wu/K6+3Rh4MyH +10KL/c8EuFw6WC8C77ozv2JaVFNnHMV7cpXNBptTwIh+AbsxeiS60dTFPZ2xE19f +hXDEyxaj2jSjVCpI8no47HYPAsQpq3mfXDnKwhhi+SAezfoHec+dlhzZvrq4Sn0A +/0mVGNyla4fy+ff86y/EcwSuHKf6H38f5pdFEiT8m6iSqlVs+U5jyX7nP4kbrOf2 +idamNwyNDX00/wL8JPC9wHw76zeoQcoQKSCDLLAI57sRCqsHS4xTshDWysfwQ2ym +5jXpRBWpA8GVhVA7LBVbr7bnP4gs3I+MR3O6Sfo7KH0KFT6bTxPD1KwEBXa+euRQ +JPXAmEJzZf+LZaoEziUTEm7Xy1q7RfbqIp1tFnDHnQ4pRWf3DwBtb+KCUITFiSkO +zdUofzRqWUgFx8A5cnTRfqOFYRNPZP1TVzUqJOo0cyAF+W8bXsJuQqTJoUggt1gh +8yRbYQCeCk7FoELVwslO9ca5PHhHnjClA3K3yT/s/+vxRj4XEGUD/dSlv0uJ2Yi2 +93vK9WvNQx+0RbkxhZZaebplRLdDqwIDAQABAoICAETiW7BOE58mFbmZjhexeZcg +81RtHAUaPY5wCjpT82dh811yqWiZolePo2AfQad7L6KyUzgr/Q7NFMNIHO1T5/pz +4FODy13zmaWgBDvbgQvkzSrnIlEKvOd9sfILdUR2lgT6lpe0sV+cvvZ8m7WQyuu8 +JVNYPkCvntGr1aSSvHx+E61cLFrJSiduVyzbYOLRCaJmxSb1NUYCeFSSFskJIHZ1 +YZG36apKBL2fUMYHtiy8KYgy7BFKJH/HT3qOyM9NXKEppyu8CZgCRa4o2tvICHxi +HvGw5R1C+M1wZD6Eyq9LJNB4QXM2x59XNce9owWeGmen6Xq5rs51kmHPRymD++dz +eorzF1YhYrJ72l9dVUWtaZAlxna99u4u3++WRGOczkibpROa9A3PvO1ipMOCjZOa +SeeZ787Pl9ttgZb2z0Fmx4S5eSpAyrsmopEzZO6FRp/mfzHcV62ztx7ol0b7esPD +a5sU/Vq/lVv65LkhzTo2ZWK1Z5pB9w7N4UYR7c/lFCeaEmEtJOjB2X6+2yWEpkWX +t5pn+oOrQ61qssxR1I/eU/4x/on7rjmtlR4xqaliLkfDrk4Dd6YKiYmpQV7IhtV2 +KCYdE239U+C2DDa5PA5OqyyUQWTS+oBKL1nB4Y+co55+W+0VBTc68Gd9Gv4Wf9Mj +5GJsw0Qk6OUr23WetpURAoIBAQDZ6q4jt6vCDtm7AT7W1Ms+m2CVrcA3wiq3+dHn +zlRCqxLyq0hPVbUFWmFVL4/FVdBWbLES9uyLZDoOcxZ5Ffa+AC5t80vGgqKpRNES +BAxkOlFHvjW4yFIk7TxGEgbbfxZLIEyHdnJU1gZeY7hucMlVyu77iXyP/biF3DaC +WjC40t2oOxJVFI2f52NS5TU01mHv8qLVhTAppv2N3BOPfC4G88qB13rIXFCEWcpN +C6A2Agk9rfPl6cKh8cfzJMOd/ZzCM+Jp8ohd6Z9Cd4XC3t8fD9RjDWmlKib3wHpz +ZyaUk1eFKur4+3oqdqHeqc6GoFb2dp62LIwFuVYXF8FWTj4jAoIBAQDHiLe2Qfga +6BhCHPZSL7qNqgOPkgpLW3SgGPK3czR0F0hWjKihw6ZvbTGG8tqHCCPh5Xb3D0Lm +8FgBvr/jFDbGJGICtP6Lva0dZL3pIqJQRPlB/mDl3iXMoQM/yrxITyLU5Pswrs1c +ydQiVDAlWGaeqbPBUP0AmtSPQMgSi4l+lROCjoIfiXk8RPe//1QZrmghi/8x8eCm +iTmAuOQXpICCjV4gdPv1xdPec11KPkr0SA6frUzR62sYHEDVnmixXE9WNaRG1h1J +RUzNn7HxreimBp1LltSa43SFrIAHe7QZ0IuoDm8b27pEztlqDJBhSsH1eq95KxuG +MOONp2TQNojZAoIBAB+Af4AGUzwQbYVNHsprpJ3+VC4PGhR1azuBT8jU2PVySaDv +BdsCJtMJR7zKzVvXlfCIceit7XARIxtno74JYMwCtrOKUk/2HpGdsyOJlkj+7TUT +2CxIOSfBa88tV/RvIMfneWizxL9i2TTX8Zd1koVmerm+HFWsdfpT5UVeyGBPi1+A +epv2BqsxBfi7zb8/ppTLXKLFSDsdOtZBFErPxs+WepXeko9YWQNo/4e3wIdOMAvM +k8+OxWYnz6HklKrIONsSKQ7r0q7Q0QcIxDtxgIu6/Bb9n2IS/+Mc3hbEuJ0N178W +fzVTFUwCLlBD9+kaULf8WeE3+13wdvOLqZVSZkUCggEAcsELPvOjt/3BbcxwUYYH +mU+s6pYH+5zmbujKNn04LofxX21X0mjOQIkhEcZ7rWseD93DVIVfaafSRXaprvRC +KCRmhb4IIt/8PspgekMj7FwuqiidG7ZuMMhtMPPs4v04QA5M9IujqfidWvzmD6RO +qHNa4RQt3XouQxDzv86mTbl41f4VkgOjSOk1PyOd/4MRejGkm9nK5JxJCOHMtFg0 +XGDnQG1nNssGdYoNnhRDUUhbuLOXWac2GVCubOzEszQuoJsLFn4vq6MCb8OnOCJX +iZyGPCHLtiSYMASsQSGAy9PnbciXWAM/ljEMUvRU2M+Ayyg64MnM85kMVbxuu1yR +yQKCAQB7AkFj7k2gldP7BqMgemTQ2HgnxX3+DrFkvscY5rNHoYYyQ9iCOBh4CcVN +HLpYkvN6omfmIGaeIlrFChk0xPoq2rC8Es0NO++p6Ufa/RlPe/qInmQsVe0dl0ou +44fxYKe6UunrLa/MAEy9H1Ev4GeiN4KmB14dqSvV2MKHv79W4J+knEqPcj9VGNGM +KwU5VgPsz/bDn9LJYWggdIsQUZB5OsHl54gTmVPXGjoYRMCDk9jn/BTkEZUkXkHs +tFb7jcKo9XUoe0PG6Bx634wrlLx8clCwHWFfPbji4NdSGm8UTMKDAXKtMVwQyPzk +8qn2Iv/fqfJHIYs6DvxWgLy8ZoWL +-----END PRIVATE KEY----- +""" + + +@pytest.fixture +def jwt_public_key(): + return """-----BEGIN PUBLIC KEY----- +MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAqdnObC47NUAqOOCvOzOg +4i0KxZaTe2m7rfQe9a+rWbtkJ2TNuakN7eRshvv2UVGP4uSKEKe356v4GwP/yGAi +92XEGr0Y6ieypnhu1wi0wuK4Z62abRkvsEZdDwKQpcde1rvyuvt0YeDMh9dCi/3P +BLhcOlgvAu+6M79iWlRTZxzFe3KVzQabU8CIfgG7MXokutHUxT2dsRNfX4VwxMsW +o9o0o1QqSPJ6OOx2DwLEKat5n1w5ysIYYvkgHs36B3nPnZYc2b66uEp9AP9JlRjc +pWuH8vn3/OsvxHMErhyn+h9/H+aXRRIk/JuokqpVbPlOY8l+5z+JG6zn9onWpjcM +jQ19NP8C/CTwvcB8O+s3qEHKECkggyywCOe7EQqrB0uMU7IQ1srH8ENspuY16UQV +qQPBlYVQOywVW6+25z+ILNyPjEdzukn6Oyh9ChU+m08Tw9SsBAV2vnrkUCT1wJhC +c2X/i2WqBM4lExJu18tau0X26iKdbRZwx50OKUVn9w8AbW/iglCExYkpDs3VKH80 +allIBcfAOXJ00X6jhWETT2T9U1c1KiTqNHMgBflvG17CbkKkyaFIILdYIfMkW2EA +ngpOxaBC1cLJTvXGuTx4R54wpQNyt8k/7P/r8UY+FxBlA/3Upb9LidmItvd7yvVr +zUMftEW5MYWWWnm6ZUS3Q6sCAwEAAQ== +-----END PUBLIC KEY----- +""" + + +@pytest.fixture +def set_jwt_public_key(settings, jwt_public_key): + settings.JWT_PUBLIC_KEY = jwt_public_key + return settings + + +@pytest.fixture +def create_jwt_for_user(jwt_private_key): + def create_jwt(user_id: UserId, timestamp_expired_at: int) -> str: + payload = { + "sub": user_id, + "exp": timestamp_expired_at, + } + + return jwt.encode(payload=payload, key=jwt_private_key, algorithm="RS256") + + return create_jwt + + +@pytest.fixture +def jwt_user_valid_token(create_jwt_for_user): + return create_jwt_for_user( + user_id="user", + timestamp_expired_at=4700000000, # year of expiration 2118 + ) diff --git a/src/a12n/jwk_client.py b/src/a12n/jwk_client.py deleted file mode 100644 index dbefd32..0000000 --- a/src/a12n/jwk_client.py +++ /dev/null @@ -1,157 +0,0 @@ -import asyncio -import json -import logging -from dataclasses import dataclass - -import httpx -import jwt -from jwt.api_jwk import PyJWK, PyJWKSet -from jwt.api_jwt import decode_complete as decode_token -from jwt.exceptions import PyJWKSetError -from jwt.jwk_set_cache import JWKSetCache - -from app.types import DecodedValidToken - -logger = logging.getLogger(__name__) - - -class AsyncJWKClientException(Exception): - pass - - -@dataclass -class AsyncJWKClient: - """Async JW Keys client. - - Inspired and partially copy-pasted from 'jwt.jwks_client.PyJWKClient'. - The purpose is the same but querying the JWKS endpoint is async. - """ - - jwks_url: str - supported_signing_algorithms: list[str] - cache_lifespan: int = 1 * 60 * 60 * 24 # 1 day - - def __post_init__(self) -> None: - self.jwk_set_cache = JWKSetCache(self.cache_lifespan) - - # Lock is used to synchronize coroutines to prevent multiple concurrent attempts to refresh cached jwk_set - self.fetch_data_lock = asyncio.Lock() - - async def fetch_data(self) -> PyJWKSet: - try: - async with self.fetch_data_lock: - async with httpx.AsyncClient() as client: - response = await client.get(url=self.jwks_url) - - response.raise_for_status() - jwk_set_data = response.json() - - if not isinstance(jwk_set_data, dict): - raise AsyncJWKClientException("Fetched data from JWKS endpoint is JSON but not an object") - - jwk_set = PyJWKSet.from_dict(jwk_set_data) - self.jwk_set_cache.put(jwk_set) - - logger.info( - "Signing keys fetched. Key ids: '%s'", - (" ,").join([key.key_id for key in jwk_set.keys if key.key_id]), - ) - - return jwk_set - - except httpx.HTTPError as exc: - raise AsyncJWKClientException(f"Fail to fetch data from JWKS endpoint: '{exc}'") from exc - except json.JSONDecodeError as exc: - raise AsyncJWKClientException(f"Fetched data from JWKS endpoint not a JSON: '{exc}'") from exc - except PyJWKSetError as exc: - raise AsyncJWKClientException(exc) from exc - - async def get_jwk_set(self, *, refresh: bool = False) -> PyJWKSet: - jwk_set: PyJWKSet | None = None - - while self.fetch_data_lock.locked(): - await asyncio.sleep(0) - - if not refresh: - jwk_set = self.jwk_set_cache.get() - - if jwk_set is None: - jwk_set = await self.fetch_data() - - return jwk_set - - async def get_signing_keys(self, *, refresh: bool = False) -> list[PyJWK]: - jwk_set = await self.get_jwk_set(refresh=refresh) - - signing_keys = [ - jwk_set_key - for jwk_set_key in jwk_set.keys - if jwk_set_key.public_key_use - in [ - "sig", - None, - ] - and jwk_set_key.key_id - ] - - if not signing_keys: - raise AsyncJWKClientException("The JWKS endpoint did not contain any signing keys") - - return signing_keys - - async def get_signing_key(self, kid: str) -> PyJWK: - signing_keys = await self.get_signing_keys(refresh=False) - signing_key = self.match_kid(signing_keys, kid) - - if not signing_key: - # If no matching signing key from the jwk set, refresh the jwk set and try again. - signing_keys = await self.get_signing_keys(refresh=True) - signing_key = self.match_kid(signing_keys, kid) - - if not signing_key: - raise AsyncJWKClientException(f"Unable to find a signing key that matches: '{kid}'") - - return signing_key - - async def get_signing_key_from_jwt(self, token: str) -> PyJWK: - unverified = decode_token(token, options={"verify_signature": False}) - header = unverified["header"] - return await self.get_signing_key(header.get("kid")) - - @staticmethod - def match_kid(signing_keys: list[PyJWK], kid: str) -> PyJWK | None: - signing_key = None - - for key in signing_keys: - if key.key_id == kid: - signing_key = key - break - - return signing_key - - async def decode(self, token: str, options: dict | None = None) -> DecodedValidToken: - decode_options = { - "verify_aud": False, - "verify_exp": True, - "verify_iat": True, - "require": ["exp", "iat", "sub"], - } - - if options is not None: - decode_options.update(options) - - verify_signature = decode_options.get("verify_signature", True) - - try: - signing_key = (await self.get_signing_key_from_jwt(token)).key if verify_signature else "" - - verified_payload = jwt.decode( - jwt=token, - algorithms=self.supported_signing_algorithms, - key=signing_key, - options=decode_options, - ) - except jwt.PyJWTError as exc: - raise AsyncJWKClientException(exc) from exc - - return DecodedValidToken(sub=verified_payload["sub"], exp=verified_payload["exp"]) diff --git a/src/a12n/jwt_decode.py b/src/a12n/jwt_decode.py new file mode 100644 index 0000000..e88212e --- /dev/null +++ b/src/a12n/jwt_decode.py @@ -0,0 +1,15 @@ +from typing import Any + +import jwt + +from app.conf import get_app_settings +from app.types import DecodedValidToken + + +def decode(jwt_token: str, **kwargs: Any) -> DecodedValidToken: + """Validate and decode a JWT token with public key. + + Adjust validation parameters to project requirements (like algorithms, required claims, etc). + """ + decoded = jwt.decode(jwt=jwt_token, key=get_app_settings().JWT_PUBLIC_KEY, algorithms=["RS256"], **kwargs) + return DecodedValidToken(**decoded) diff --git a/src/a12n/tests/async_jwk_client/conftest.py b/src/a12n/tests/async_jwk_client/conftest.py deleted file mode 100644 index 0046681..0000000 --- a/src/a12n/tests/async_jwk_client/conftest.py +++ /dev/null @@ -1,73 +0,0 @@ -import pytest - -from respx import MockRouter, Route - -from a12n.jwk_client import AsyncJWKClient - -JWKS_URL = "https://auth.test.com/auth/realms/test-realm/protocol/openid-connect/certs" - - -@pytest.fixture -def expired_token(): - return "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjNMcjhuTjh1R29wUElMZlFvUGpfRCJ9.eyJpc3MiOiJodHRwczovL2Rldi1wcm50ZG1vMTYzc2NsczR4LnVzLmF1dGgwLmNvbS8iLCJhdWQiOiIxV1NiR1hGUnlhS0NsTHJvWHZteTlXdndrZUtHb1JvayIsImlhdCI6MTY5ODUyODU1OSwiZXhwIjoxNjk4NTI4ODU5LCJzdWIiOiJhdXRoMHw2NTNjMzI2MGEzMDQ0OGM1OTRhNjllMTIiLCJzaWQiOiI5MEZ3WFNDSFUtd0N3QmY0Y1YyQ3NZTnpBMldieDNUcSIsIm5vbmNlIjoiZTIxZWVhNTljNGY1MDg0N2Q3YzFhOGUzZjQ0NjVjYTcifQ.FO_xoMA9RGI7uAVauv00-zdORgkvCwyWfeAPd7lmU_nKzGp5avPa2MN66S0fjLKOxb8tgzrfpXYLUhDl1nqUvtj1A54-PfNW0n0ctdn2zk_CCOxsAjKyImlIgq7Y4DIuil0wikj7FdoWkB-bCBrKs7JaOoWkSHws9uQxRyvZzBwPHExW0myHWvB3G0x8g23PfSv2oALbvXBp0OAniGwru2Br9e2iXCVyGAUMTCpQmjPDAyfeYXGxF9BhxuX3e-GL80oyngBQK0kTxw-2Xz8LDSC-MI2jTs1gUo9qdVrg_1fzQtvAW9LGaWg5L_CJe92ZH3l1fBPfSh7Gc6uBtwF-YA" # noqa: E501 - - -@pytest.fixture -def token(): - # The token won't expire in ~100 years (expiration date 2123-10-05, it's more than enough to rely on it in test) - return "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjNMcjhuTjh1R29wUElMZlFvUGpfRCJ9.eyJpc3MiOiJodHRwczovL2Rldi1wcm50ZG1vMTYzc2NsczR4LnVzLmF1dGgwLmNvbS8iLCJhdWQiOiIxV1NiR1hGUnlhS0NsTHJvWHZteTlXdndrZUtHb1JvayIsImlhdCI6MTY5ODUyODE3MCwiZXhwIjo0ODUyMTI4MTcwLCJzdWIiOiJhdXRoMHw2NTNjMzI2MGEzMDQ0OGM1OTRhNjllMTIiLCJzaWQiOiI5MEZ3WFNDSFUtd0N3QmY0Y1YyQ3NZTnpBMldieDNUcSIsIm5vbmNlIjoiMTVhNWI2M2Y3MzI5MDcwMmU3MGViZmJlMDc5ODgxYmIifQ.FQYBaTnjKJHcskRl1WsB4kKQmyvXRcG8RDWlB2woSbzukZx7SnWghC1qRhYeqOLBUBpe3Iu_EzxgF26YDZJ28bKKNgL4fVmYak3jOg2nRP2lulrkF8USmkqT9Vx85hlIEVCisYOS6DJE0bHJL5WbHjCmDjQ6RGRyVZ3s6UPFXIwe2CMC_egAdWrsLYrgA1mqozQhwLJN2zSuObkDffkpHbX9XXB225v3-ryY-Rr0rPh9AOfKtEeMUEmNG0gsGyIbi0DoPDjAxlxCDx7ULVSChIKhUv4DKICqrqzHyopA7oE8LlpDbPTshQsL6L4u1EwUT7maP9VTcEQUTnp3Cu5msw" # noqa: E501 - - -@pytest.fixture -def matching_kid_data(): - return { - "keys": [ - { - "kty": "RSA", - "use": "sig", - "n": "oIQkRCY4X-_ItMUPt65wVIGewOJfjMhlu6HG_rHik5-dTK0o6oyUne2Gevetn2Vrn8NSIaARobLZ8expuJBYDS121w_RloC6MCuzlc-j_nHj-BcBOCqGWPVwKX4un0HueD3aW3buqzYcmX_9LhdSE8ARyN0S9O6RbYWDCTKFhrRXtIP4wzP8vdPGXGurtGIiBbhVCK1LHG2lO5Gt8IIQ_DAcX6swnXCfbHwR1OXc9Do06o8c7ZsZdjMty5b4Fpv8rAKA-HTP_One4yhKtqCMYs3_gcTeQdHi-0w634VnpdzC_0f_MMzNIgvXC8VdJgkGpa6jLBp3mTqaFUdkAXFYlw", # noqa: E501 - "e": "AQAB", - "kid": "3Lr8nN8uGopPILfQoPj_D", - "x5t": "f93zLhSTsgVJiS9JA0x8sHkaLMg", - "x5c": [ - "MIIDHTCCAgWgAwIBAgIJA2x2yGZ3QbP7MA0GCSqGSIb3DQEBCwUAMCwxKjAoBgNVBAMTIWRldi1wcm50ZG1vMTYzc2NsczR4LnVzLmF1dGgwLmNvbTAeFw0yMzEwMjcyMTU0MDhaFw0zNzA3MDUyMTU0MDhaMCwxKjAoBgNVBAMTIWRldi1wcm50ZG1vMTYzc2NsczR4LnVzLmF1dGgwLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKCEJEQmOF/vyLTFD7eucFSBnsDiX4zIZbuhxv6x4pOfnUytKOqMlJ3thnr3rZ9la5/DUiGgEaGy2fHsabiQWA0tdtcP0ZaAujArs5XPo/5x4/gXATgqhlj1cCl+Lp9B7ng92lt27qs2HJl//S4XUhPAEcjdEvTukW2FgwkyhYa0V7SD+MMz/L3Txlxrq7RiIgW4VQitSxxtpTuRrfCCEPwwHF+rMJ1wn2x8EdTl3PQ6NOqPHO2bGXYzLcuW+Bab/KwCgPh0z/zp3uMoSragjGLN/4HE3kHR4vtMOt+FZ6Xcwv9H/zDMzSIL1wvFXSYJBqWuoywad5k6mhVHZAFxWJcCAwEAAaNCMEAwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUs39B3IFR+w1DjzejWvr7ZnNU9DswDgYDVR0PAQH/BAQDAgKEMA0GCSqGSIb3DQEBCwUAA4IBAQBimrxvOHAssNkEU7r5aTiz0lvtlshUe4zN6r9pqA7P0m0OoxJMiEkGrtPblVIL8hNSRcSsD0AmlA/dfP8RR39BY44ac4KLh5WfCRsi5LXENqmPeNvFiVGKL3UBvtpp6KIc9mT1vFY9Jhdh8srF3AS9STFD9O1/qexevvgqkq4ZCng+kHuRP9C7eU3yQUQeZ9QYWloZuPaNe7DT3J6v7OW1gy41xjUpL0GisRcCqsVI+dzHDYi1MFfvmUwxcmtg8GXYexuR6FUkgocdRXQsDQ1qIhS9M54WVEEgC+fat25Kb/Ca59GO3okJ4suqMAXKCtlbVh3JUgBsCjBdbk0tYwp0" - ], - "alg": "RS256", - }, - ] - } - - -@pytest.fixture -def not_matching_kid_data(): - return { - "keys": [ - { - "kty": "RSA", - "use": "sig", - "n": "zB0xsH539lpLVejR6Hq1bHN3EzDt_0tJyr5JVHz3GSnNYAaZzkqL7HyLlhwttl7_bRyZJeZ8X6aasBxVK2JCDc9U-0KMJXmSoJs1oWYRo79DqdzCXK3ZYXcgkvI9OWF1qVx76vbZVwiRv5qUzpINdLnsX2CXChyd0LFkg14bYrSfdN-eMmG1PXtHZufeKG6HW17PFXS7OwesMQIfQ9kFfSvgFkJgkNM0o6NaeB-ZPDvzfKmmpBXjtGcze0A56NdQ7Z42DRDURROS82sPISrX-iAt93tZ1F0IW_U4niIYc6NFcWPPXpQpiVDDwdrz-L1H63mSUDSDFsWVcv2xWry6kQ", # noqa: E501 - "e": "AQAB", - "kid": "ICOpsXGmpNaDPiljjRjiE", - "x5t": "1GDK6kGV6HvZ1m_-VdSKIFNEtEU", - "x5c": [ - "MIIDHTCCAgWgAwIBAgIJYH5BBAgUHJCVMA0GCSqGSIb3DQEBCwUAMCwxKjAoBgNVBAMTIWRldi1wcm50ZG1vMTYzc2NsczR4LnVzLmF1dGgwLmNvbTAeFw0yMzEwMjcyMTU0MDlaFw0zNzA3MDUyMTU0MDlaMCwxKjAoBgNVBAMTIWRldi1wcm50ZG1vMTYzc2NsczR4LnVzLmF1dGgwLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMwdMbB+d/ZaS1Xo0eh6tWxzdxMw7f9LScq+SVR89xkpzWAGmc5Ki+x8i5YcLbZe/20cmSXmfF+mmrAcVStiQg3PVPtCjCV5kqCbNaFmEaO/Q6ncwlyt2WF3IJLyPTlhdalce+r22VcIkb+alM6SDXS57F9glwocndCxZINeG2K0n3TfnjJhtT17R2bn3ihuh1tezxV0uzsHrDECH0PZBX0r4BZCYJDTNKOjWngfmTw783yppqQV47RnM3tAOejXUO2eNg0Q1EUTkvNrDyEq1/ogLfd7WdRdCFv1OJ4iGHOjRXFjz16UKYlQw8Ha8/i9R+t5klA0gxbFlXL9sVq8upECAwEAAaNCMEAwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUD06HTot2wVkYAi77ZUSLnBDXJx8wDgYDVR0PAQH/BAQDAgKEMA0GCSqGSIb3DQEBCwUAA4IBAQB+YY5dSdTqsO1ErV+ZusJ/+z+WZ/Kf+rBhX7pbPGdL00mbwyF5kKc1g9Nd2S6Uz+w5FrU7cv3ABkppQNK/07ipyad9EOEd1rWiVp9/f18VB4OUqSgxZyXrAuqEVTFTPL3wwBOG/cw3pYtF2DZ26Y5tIxic4T+Z+dmtxZm/7387XrGisUTngdQvs3X+3xvjou2Z+pCIP1+Qe14S+WM77ZMa62O/rajtdsvOXWGh68oKitzaE0/gKpGjP8mBkd5Taxl+MLXU+Ea/RvVZOtvtOomANyyEXRX2WBN90djFqdlTF7Lhb7X6OTvZGt9ZmgXDtVGgBkeJxJOgFEAyDpn3ErWx" - ], - "alg": "RS256", - }, - ] - } - - -@pytest.fixture -def mock_jwk_endpoint(respx_mock: MockRouter): - return respx_mock.get(JWKS_URL) - - -@pytest.fixture -def mock_success_response(mock_jwk_endpoint: Route, matching_kid_data): - return mock_jwk_endpoint.respond(json=matching_kid_data) - - -@pytest.fixture -def jwk_client(): - return AsyncJWKClient(jwks_url=JWKS_URL, supported_signing_algorithms=["RS256"]) diff --git a/src/a12n/tests/async_jwk_client/tests_async_jwk_common.py b/src/a12n/tests/async_jwk_client/tests_async_jwk_common.py deleted file mode 100644 index 29238a4..0000000 --- a/src/a12n/tests/async_jwk_client/tests_async_jwk_common.py +++ /dev/null @@ -1,9 +0,0 @@ -from jwt.jwk_set_cache import JWKSetCache - - -def test_create_jwk_client_with_empty_pyjwkset_cache(jwk_client): - jwk_cache = jwk_client.jwk_set_cache - - assert isinstance(jwk_cache, JWKSetCache) - assert jwk_cache.lifespan == 1 * 60 * 60 * 24 # one day by default - assert jwk_cache.get() is None, "JWK set cache should be empty on init" diff --git a/src/a12n/tests/async_jwk_client/tests_async_jwk_decode.py b/src/a12n/tests/async_jwk_client/tests_async_jwk_decode.py deleted file mode 100644 index dfe4afe..0000000 --- a/src/a12n/tests/async_jwk_client/tests_async_jwk_decode.py +++ /dev/null @@ -1,35 +0,0 @@ -import pytest -from contextlib import nullcontext as does_not_raise - -from a12n.jwk_client import AsyncJWKClientException - -pytestmark = [ - pytest.mark.usefixtures("mock_success_response"), -] - - -async def test_decode_token_and_return_valid_data(jwk_client, token): - with does_not_raise(): - decoded_valid_token = await jwk_client.decode(token) - - assert decoded_valid_token.sub == "auth0|653c3260a30448c594a69e12" - assert decoded_valid_token.exp == 4852128170 # 2123-10-05 00:22:50 GMT+03:00 - - -async def test_raise_if_token_is_not_valid(jwk_client, expired_token): - with pytest.raises(AsyncJWKClientException, match="Signature has expired"): - await jwk_client.decode(expired_token) - - -async def test_decode_token_without_validation_if_signature_verification_disable(jwk_client, expired_token): - with does_not_raise(): - expired_token = await jwk_client.decode(expired_token, options={"verify_signature": False, "verify_exp": False}) - - assert expired_token.sub == "auth0|653c3260a30448c594a69e12" - - -async def test_raise_if_token_signed_with_not_supported_algorithm(jwk_client, token): - jwk_client.supported_signing_algorithms = ["EdDSA"] - - with pytest.raises(AsyncJWKClientException, match="The specified alg value is not allowed"): - await jwk_client.decode(token) diff --git a/src/a12n/tests/async_jwk_client/tests_async_jwk_fetch_data.py b/src/a12n/tests/async_jwk_client/tests_async_jwk_fetch_data.py deleted file mode 100644 index f397cbb..0000000 --- a/src/a12n/tests/async_jwk_client/tests_async_jwk_fetch_data.py +++ /dev/null @@ -1,57 +0,0 @@ -import logging -import pytest - -from a12n.jwk_client import AsyncJWKClientException - - -async def test_fetch_data_send_request(jwk_client, respx_mock, mock_success_response): - await jwk_client.fetch_data() - - sent_request = respx_mock.calls.last.request - assert sent_request.url == "https://auth.test.com/auth/realms/test-realm/protocol/openid-connect/certs" - - -async def test_fetch_data_return_fetched_data_and_cache(jwk_client, mock_success_response): - jwk_set = await jwk_client.fetch_data() - - assert jwk_client.jwk_set_cache.get() == jwk_set - assert len(jwk_set.keys) == 1 - assert jwk_set.keys[0].key_id == "3Lr8nN8uGopPILfQoPj_D" - assert jwk_set.keys[0].public_key_use == "sig" - assert jwk_set.keys[0].key_type == "RSA" - - -async def test_log_fetched_keys_ids(jwk_client, mock_success_response, caplog): - caplog.set_level(logging.INFO) - - await jwk_client.fetch_data() - - assert "3Lr8nN8uGopPILfQoPj_D" in caplog.text - - -async def test_fetch_data_raise_on_http_error(jwk_client, mock_jwk_endpoint): - mock_jwk_endpoint.respond(status_code=404) - - with pytest.raises(AsyncJWKClientException, match="Fail to fetch data"): - await jwk_client.fetch_data() - - -async def test_fetch_data_raise_on_not_json_response(jwk_client, mock_jwk_endpoint): - mock_jwk_endpoint.respond(content=b"not-json-content") - - with pytest.raises(AsyncJWKClientException, match="not a JSON"): - await jwk_client.fetch_data() - - -async def test_fetch_data_raise_if_serialized_json_response_not_a_dict(jwk_client, mock_jwk_endpoint): - mock_jwk_endpoint.respond(json=[]) - - with pytest.raises(AsyncJWKClientException, match="is JSON but not an object"): - await jwk_client.fetch_data() - - -async def test_fetch_data_raise_if_serialized_json_not_jwkset(jwk_client, mock_jwk_endpoint): - mock_jwk_endpoint.respond(json={"keys": []}) # no keys - - with pytest.raises(AsyncJWKClientException, match="did not contain any keys"): - await jwk_client.fetch_data() diff --git a/src/a12n/tests/async_jwk_client/tests_async_jwk_get_jwk_set.py b/src/a12n/tests/async_jwk_client/tests_async_jwk_get_jwk_set.py deleted file mode 100644 index 675d185..0000000 --- a/src/a12n/tests/async_jwk_client/tests_async_jwk_get_jwk_set.py +++ /dev/null @@ -1,56 +0,0 @@ -import asyncio -import pytest - -import httpx - - -@pytest.fixture -async def amock_latency_success_response(mock_jwk_endpoint, matching_kid_data): - async def latency_response(request): - await asyncio.sleep(0.1) - return httpx.Response(status_code=200, json=matching_kid_data) - - return mock_jwk_endpoint.mock(side_effect=latency_response) - - -async def test_get_jwk_set_return_jwk_key_set(jwk_client, respx_mock, mock_success_response): - jwk_set = await jwk_client.get_jwk_set() - - assert len(jwk_set.keys) == 1 - assert respx_mock.calls.call_count == 1 - - -async def test_get_jwk_set_is_cached(jwk_client, respx_mock, mock_success_response): - await jwk_client.get_jwk_set() - - await jwk_client.get_jwk_set() - - assert respx_mock.calls.call_count == 1 - - -@pytest.mark.freeze_time("2023-01-01 00:00:00") -async def test_get_jwk_set_cache_invalidated_when_lifespan_expired(jwk_client, respx_mock, freezer, mock_success_response): - await jwk_client.get_jwk_set() - freezer.move_to("2023-01-02 00:00:01") # 1 day + 1 second passed - - await jwk_client.get_jwk_set() - - assert respx_mock.calls.call_count == 2 - - -@pytest.mark.usefixtures("amock_latency_success_response") -async def test_get_jwk_set_corotines_do_not_try_to_update_cache_simultaneously(jwk_client, respx_mock): - first_get_jwk_task = asyncio.create_task(jwk_client.get_jwk_set()) - second_get_jwk_task = asyncio.create_task(jwk_client.get_jwk_set()) - - await asyncio.gather(first_get_jwk_task, second_get_jwk_task) - - assert respx_mock.calls.call_count == 1 - - -async def test_get_jwk_set_could_be_forced_to_update_cache(jwk_client, respx_mock, mock_success_response): - await jwk_client.get_jwk_set() - - await jwk_client.get_jwk_set(refresh=True) - - assert respx_mock.calls.call_count == 2 diff --git a/src/a12n/tests/async_jwk_client/tests_async_jwk_get_signing_key_from_jwt.py b/src/a12n/tests/async_jwk_client/tests_async_jwk_get_signing_key_from_jwt.py deleted file mode 100644 index c6310ac..0000000 --- a/src/a12n/tests/async_jwk_client/tests_async_jwk_get_signing_key_from_jwt.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest - -import httpx -import jwt -from jwt.api_jwk import PyJWK -from respx import Route - -from a12n.jwk_client import AsyncJWKClientException - - -@pytest.fixture -def mock_jwk_endpoint_first_call_wrong_kid_second_call_correct_kid(mock_jwk_endpoint: Route, matching_kid_data, not_matching_kid_data): - return mock_jwk_endpoint.mock( - side_effect=[ - httpx.Response(status_code=200, json=not_matching_kid_data), - httpx.Response(status_code=200, json=matching_kid_data), - ] - ) - - -async def test_get_signing_key_from_jwt(jwk_client, token, mock_success_response): - signing_key = await jwk_client.get_signing_key_from_jwt(token) - - assert isinstance(signing_key, PyJWK) - assert signing_key.key_id == "3Lr8nN8uGopPILfQoPj_D" - - -async def test_token_could_be_decoded_with_signing_key(jwk_client, token, mock_success_response): - signing_key = await jwk_client.get_signing_key_from_jwt(token) - - data = jwt.decode( - jwt=token, - key=signing_key.key, - algorithms=["RS256"], - options={"verify_aud": False}, - ) - - assert data == { - "iss": "https://dev-prntdmo163scls4x.us.auth0.com/", - "aud": "1WSbGXFRyaKClLroXvmy9WvwkeKGoRok", - "iat": 1698528170, - "exp": 4852128170, - "sub": "auth0|653c3260a30448c594a69e12", - "sid": "90FwXSCHU-wCwBf4cV2CsYNzA2Wbx3Tq", - "nonce": "15a5b63f73290702e70ebfbe079881bb", - } - - -async def test_raise_if_jwt_key_not_match_fetched_jwk_set(jwk_client, mock_jwk_endpoint, not_matching_kid_data, token): - mock_jwk_endpoint.respond(json=not_matching_kid_data) - - with pytest.raises(AsyncJWKClientException, match="Unable to find a signing key"): - await jwk_client.get_signing_key_from_jwt(token) - - -@pytest.mark.usefixtures("mock_jwk_endpoint_first_call_wrong_kid_second_call_correct_kid") -async def test_refresh_jwk_set_if_cached_jwk_set_not_match_jwt_kid(jwk_client, respx_mock, token): - signing_key = await jwk_client.get_signing_key_from_jwt(token) - - assert signing_key.key_id == "3Lr8nN8uGopPILfQoPj_D" - assert respx_mock.calls.call_count == 2 diff --git a/src/a12n/tests/tests_jwt_decode.py b/src/a12n/tests/tests_jwt_decode.py new file mode 100644 index 0000000..c78c6a1 --- /dev/null +++ b/src/a12n/tests/tests_jwt_decode.py @@ -0,0 +1,21 @@ +import pytest +from datetime import UTC, datetime + +import jwt + +from a12n import jwt_decode +from app.types import DecodedValidToken + + +def test_decode_valid_token(jwt_user_valid_token): + decoded = jwt_decode.decode(jwt_user_valid_token) + + assert isinstance(decoded, DecodedValidToken) + assert decoded.sub == "user" + assert decoded.exp == 4700000000 + + +@pytest.mark.freeze_time(datetime.fromtimestamp(4700000001, tz=UTC)) +def test_decode_expired_token(jwt_user_valid_token): + with pytest.raises(jwt.ExpiredSignatureError, match="Signature has expired"): + jwt_decode.decode(jwt_user_valid_token) diff --git a/src/app/conf/settings.py b/src/app/conf/settings.py index d4dcf4f..ae19d35 100644 --- a/src/app/conf/settings.py +++ b/src/app/conf/settings.py @@ -1,7 +1,7 @@ from functools import lru_cache from typing import Literal -from pydantic import AmqpDsn +from pydantic import AliasChoices, AmqpDsn, Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -14,8 +14,7 @@ class Settings(BaseSettings): WEBSOCKETS_PORT: int WEBSOCKETS_PATH: str - AUTH_JWKS_URL: str - AUTH_SUPPORTED_SIGNING_ALGORITHMS: list[str] + JWT_PUBLIC_KEY: str = Field(validation_alias=AliasChoices("jwt_public_key", "JWT_PUBLIC_KEY")) DEBUG: bool = False LOG_LEVEL: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "WARNING" @@ -27,6 +26,7 @@ class Settings(BaseSettings): env_file=".env", env_file_encoding="utf-8", case_sensitive=True, + secrets_dir="../secrets", ) diff --git a/src/conftest.py b/src/conftest.py index 042509a..45ee854 100644 --- a/src/conftest.py +++ b/src/conftest.py @@ -3,6 +3,7 @@ from app.conf import Settings, get_app_settings pytest_plugins = [ + "a12n.fixtures", "app.fixtures", "storage.fixtures", ] diff --git a/src/handlers/messages_handler.py b/src/handlers/messages_handler.py index 0b04c50..5794d89 100644 --- a/src/handlers/messages_handler.py +++ b/src/handlers/messages_handler.py @@ -1,18 +1,18 @@ -from collections.abc import Callable, Coroutine +from collections.abc import Callable from dataclasses import dataclass from typing import Any +from jwt.exceptions import InvalidTokenError from websockets import WebSocketServerProtocol -from a12n.jwk_client import AsyncJWKClient, AsyncJWKClientException -from app import conf +from a12n import jwt_decode from handlers.dto import AuthMessage, IncomingMessage, SubscribeMessage, SuccessResponseMessage, UnsubscribeMessage from handlers.exceptions import WebsocketMessageException from storage import SubscriptionStorage from storage.exceptions import StorageOperationException from storage.storage_updaters import StorageUserSubscriber, StorageUserUnsubscriber, StorageWebSocketRegister -AsyncMessageHandler = Callable[[WebSocketServerProtocol, Any], Coroutine[Any, Any, SuccessResponseMessage]] +MessageHandler = Callable[[WebSocketServerProtocol, Any], SuccessResponseMessage] @dataclass @@ -20,31 +20,28 @@ class WebSocketMessagesHandler: storage: SubscriptionStorage def __post_init__(self) -> None: - settings = conf.get_app_settings() - self.jwk_client = AsyncJWKClient(jwks_url=settings.AUTH_JWKS_URL, supported_signing_algorithms=settings.AUTH_SUPPORTED_SIGNING_ALGORITHMS) - - self.message_handlers: dict[str, AsyncMessageHandler] = { + self.message_handlers: dict[str, MessageHandler] = { "Authenticate": self.handle_auth_message, "Subscribe": self.handle_subscribe_message, "Unsubscribe": self.handle_unsubscribe_message, } - async def handle_message(self, websocket: WebSocketServerProtocol, message: IncomingMessage) -> SuccessResponseMessage: - return await self.message_handlers[message.message_type](websocket, message) + def handle_message(self, websocket: WebSocketServerProtocol, message: IncomingMessage) -> SuccessResponseMessage: + return self.message_handlers[message.message_type](websocket, message) - async def handle_auth_message(self, websocket: WebSocketServerProtocol, message: AuthMessage) -> SuccessResponseMessage: + def handle_auth_message(self, websocket: WebSocketServerProtocol, message: AuthMessage) -> SuccessResponseMessage: try: - validated_token = await self.jwk_client.decode(message.params.token.get_secret_value()) + validated_token = jwt_decode.decode(jwt_token=message.params.token.get_secret_value()) StorageWebSocketRegister(storage=self.storage, websocket=websocket, validated_token=validated_token)() - except (AsyncJWKClientException, StorageOperationException) as exc: + except (InvalidTokenError, StorageOperationException) as exc: raise WebsocketMessageException(str(exc), message) from exc return SuccessResponseMessage.model_construct(incoming_message=message) - async def handle_subscribe_message(self, websocket: WebSocketServerProtocol, message: SubscribeMessage) -> SuccessResponseMessage: + def handle_subscribe_message(self, websocket: WebSocketServerProtocol, message: SubscribeMessage) -> SuccessResponseMessage: StorageUserSubscriber(storage=self.storage, websocket=websocket, event=message.params.event)() return SuccessResponseMessage.model_construct(incoming_message=message) - async def handle_unsubscribe_message(self, websocket: WebSocketServerProtocol, message: UnsubscribeMessage) -> SuccessResponseMessage: + def handle_unsubscribe_message(self, websocket: WebSocketServerProtocol, message: UnsubscribeMessage) -> SuccessResponseMessage: StorageUserUnsubscriber(storage=self.storage, websocket=websocket, event=message.params.event)() return SuccessResponseMessage.model_construct(incoming_message=message) diff --git a/src/handlers/tests/messages_handler/conftest.py b/src/handlers/tests/messages_handler/conftest.py index f91052b..c33c2a6 100644 --- a/src/handlers/tests/messages_handler/conftest.py +++ b/src/handlers/tests/messages_handler/conftest.py @@ -3,17 +3,9 @@ from handlers.dto import AuthMessage, SubscribeMessage, UnsubscribeMessage from handlers.messages_handler import WebSocketMessagesHandler - -@pytest.fixture(autouse=True) -def settings(settings): - settings.AUTH_JWKS_URL = "https://auth.clowns.com/auth/realms/clowns-realm/protocol/openid-connect/certs" - settings.AUTH_SUPPORTED_SIGNING_ALGORITHMS = ["RS256"] - return settings - - -@pytest.fixture -def force_token_validation(mocker, valid_token): - return mocker.patch("a12n.jwk_client.AsyncJWKClient.decode", return_value=valid_token) +pytestmark = [ + pytest.mark.usefixtures("set_jwt_public_key"), +] @pytest.fixture @@ -22,8 +14,8 @@ def message_handler(storage): @pytest.fixture -def auth_message(): - return AuthMessage(message_id=23, message_type="Authenticate", params={"token": "some-valid-token-value"}) +def auth_message(jwt_user_valid_token): + return AuthMessage(message_id=23, message_type="Authenticate", params={"token": jwt_user_valid_token}) @pytest.fixture diff --git a/src/handlers/tests/messages_handler/tests_auth_message_handler.py b/src/handlers/tests/messages_handler/tests_auth_message_handler.py index 304b28d..c19a610 100644 --- a/src/handlers/tests/messages_handler/tests_auth_message_handler.py +++ b/src/handlers/tests/messages_handler/tests_auth_message_handler.py @@ -1,6 +1,6 @@ import pytest +from datetime import UTC, datetime -from a12n.jwk_client import AsyncJWKClientException from app.types import DecodedValidToken from handlers.dto import SuccessResponseMessage from handlers.exceptions import WebsocketMessageException @@ -8,7 +8,7 @@ from storage.storage_updaters import StorageWebSocketRegister pytestmark = [ - pytest.mark.usefixtures("force_token_validation"), + pytest.mark.usefixtures("set_jwt_public_key"), ] @@ -22,32 +22,32 @@ def auth_handler(message_handler: WebSocketMessagesHandler, ws): return lambda auth_message: message_handler.handle_auth_message(ws, auth_message) -async def test_auth_handler_response_on_correct_auth_message(auth_handler, auth_message): - auth_response = await auth_handler(auth_message) +def test_auth_handler_response_on_correct_auth_message(auth_handler, auth_message): + auth_response = auth_handler(auth_message) assert isinstance(auth_response, SuccessResponseMessage) assert auth_response.message_type == "SuccessResponse" assert auth_response.incoming_message == auth_message -async def test_auth_handler_register_websocket_in_storage(auth_handler, ws, auth_message, mocker, storage, valid_token): +def test_auth_handler_register_websocket_in_storage(auth_handler, ws, auth_message, mocker, storage): spy_websocket_register = mocker.spy(StorageWebSocketRegister, "__call__") - await auth_handler(auth_message) + auth_handler(auth_message) assert storage.is_websocket_registered(ws) is True spy_websocket_register.assert_called_once() called_service = spy_websocket_register.call_args.args[0] assert called_service.storage == storage assert called_service.websocket == ws - assert called_service.validated_token == valid_token + assert called_service.validated_token == DecodedValidToken(sub="user", exp=4700000000) -async def test_auth_handler_raise_if_user_send_token_for_different_user(auth_handler, auth_message, storage, ws, register_ws, ya_user_decoded_valid_token): +def test_auth_handler_raise_if_user_send_token_for_different_user(auth_handler, auth_message, storage, ws, register_ws, ya_user_decoded_valid_token): register_ws(ws, ya_user_decoded_valid_token) with pytest.raises(WebsocketMessageException) as exc_info: - await auth_handler(auth_message) # send valid user1 token while connection registered with ya_user + auth_handler(auth_message) # send valid user1 token while connection registered with ya_user raised_exception = exc_info.value assert raised_exception.errors == ["The user has different public id"] @@ -55,13 +55,12 @@ async def test_auth_handler_raise_if_user_send_token_for_different_user(auth_han assert storage.is_websocket_registered(ws) is True, "The existed connection should not be touched" -async def test_auth_handler_raise_if_user_try_to_auth_with_expired_token(auth_handler, ws, auth_message, force_token_validation, storage): - force_token_validation.side_effect = AsyncJWKClientException("The token is expired") - +@pytest.mark.freeze_time(datetime.fromtimestamp(4700000001, tz=UTC)) # one second after token expiration +def test_auth_handler_raise_if_user_try_to_auth_with_expired_token(auth_handler, ws, auth_message, storage): with pytest.raises(WebsocketMessageException) as exc_info: - await auth_handler(auth_message) + auth_handler(auth_message) raised_exception = exc_info.value - assert raised_exception.errors == ["The token is expired"] + assert raised_exception.errors == ["Signature has expired"] assert raised_exception.incoming_message == auth_message assert storage.is_websocket_registered(ws) is False, "The ws should not be added to registered websockets" diff --git a/src/handlers/tests/messages_handler/tests_message_handler_common.py b/src/handlers/tests/messages_handler/tests_message_handler_common.py index bc0ecd0..1c00fc7 100644 --- a/src/handlers/tests/messages_handler/tests_message_handler_common.py +++ b/src/handlers/tests/messages_handler/tests_message_handler_common.py @@ -8,31 +8,26 @@ def get_message_handler(storage): return lambda: WebSocketMessagesHandler(storage) -def test_message_handler_jwk_client_settings(message_handler): - assert message_handler.jwk_client.jwks_url == "https://auth.clowns.com/auth/realms/clowns-realm/protocol/openid-connect/certs" - assert message_handler.jwk_client.supported_signing_algorithms == ["RS256"] - - -@pytest.mark.usefixtures("force_token_validation") -async def test_message_handler_call_auth_handler_on_auth_message(get_message_handler, auth_message, mocker, ws): +@pytest.mark.usefixtures("set_jwt_public_key") +def test_message_handler_call_auth_handler_on_auth_message(get_message_handler, auth_message, mocker, ws): spy_auth_handler = mocker.spy(WebSocketMessagesHandler, "handle_auth_message") - await get_message_handler().handle_message(ws, auth_message) + get_message_handler().handle_message(ws, auth_message) - spy_auth_handler.assert_awaited_once() + spy_auth_handler.assert_called_once() -async def test_message_handler_call_subscribe_handler_on_subscribe_message(get_message_handler, subscribe_message, mocker, ws_registered): +def test_message_handler_call_subscribe_handler_on_subscribe_message(get_message_handler, subscribe_message, mocker, ws_registered): spy_subscribe_handler = mocker.spy(WebSocketMessagesHandler, "handle_subscribe_message") - await get_message_handler().handle_message(ws_registered, subscribe_message) + get_message_handler().handle_message(ws_registered, subscribe_message) - spy_subscribe_handler.assert_awaited_once() + spy_subscribe_handler.assert_called_once() -async def test_message_handler_call_unsubscribe_handler_on_unsubscribe_message(get_message_handler, unsubscribe_message, mocker, ws_subscribed): +def test_message_handler_call_unsubscribe_handler_on_unsubscribe_message(get_message_handler, unsubscribe_message, mocker, ws_subscribed): spy_unsubscribe_handler = mocker.spy(WebSocketMessagesHandler, "handle_unsubscribe_message") - await get_message_handler().handle_message(ws_subscribed, unsubscribe_message) + get_message_handler().handle_message(ws_subscribed, unsubscribe_message) - spy_unsubscribe_handler.assert_awaited_once() + spy_unsubscribe_handler.assert_called_once() diff --git a/src/handlers/tests/messages_handler/tests_subscirbe_message_handler.py b/src/handlers/tests/messages_handler/tests_subscirbe_message_handler.py index 3794220..0a800a3 100644 --- a/src/handlers/tests/messages_handler/tests_subscirbe_message_handler.py +++ b/src/handlers/tests/messages_handler/tests_subscirbe_message_handler.py @@ -9,17 +9,17 @@ def subscribe_handler(message_handler: WebSocketMessagesHandler, ws_registered): return lambda subscribe_message: message_handler.handle_subscribe_message(ws_registered, subscribe_message) -async def test_subscribe_handler_return_success_response(subscribe_handler, subscribe_message): - subscribe_response = await subscribe_handler(subscribe_message) +def test_subscribe_handler_return_success_response(subscribe_handler, subscribe_message): + subscribe_response = subscribe_handler(subscribe_message) assert subscribe_response.message_type == "SuccessResponse" assert subscribe_response.incoming_message == subscribe_message -async def test_subscribe_handler_call_storage_subscriber_under_the_hood(subscribe_handler, subscribe_message, mocker, storage, ws_registered): +def test_subscribe_handler_call_storage_subscriber_under_the_hood(subscribe_handler, subscribe_message, mocker, storage, ws_registered): spy_storage_subscriber = mocker.spy(StorageUserSubscriber, "__call__") - await subscribe_handler(subscribe_message) + subscribe_handler(subscribe_message) spy_storage_subscriber.assert_called_once() called_service = spy_storage_subscriber.call_args.args[0] diff --git a/src/handlers/tests/messages_handler/tests_unsubscirbe_message_handler.py b/src/handlers/tests/messages_handler/tests_unsubscirbe_message_handler.py index 546db33..e10ca2f 100644 --- a/src/handlers/tests/messages_handler/tests_unsubscirbe_message_handler.py +++ b/src/handlers/tests/messages_handler/tests_unsubscirbe_message_handler.py @@ -9,17 +9,17 @@ def unsubscribe_handler(message_handler: WebSocketMessagesHandler, ws_subscribed return lambda unsubscribe_message: message_handler.handle_unsubscribe_message(ws_subscribed, unsubscribe_message) -async def test_unsubscribe_handler_return_success_response(unsubscribe_handler, unsubscribe_message): - unsubscribe_response = await unsubscribe_handler(unsubscribe_message) +def test_unsubscribe_handler_return_success_response(unsubscribe_handler, unsubscribe_message): + unsubscribe_response = unsubscribe_handler(unsubscribe_message) assert unsubscribe_response.message_type == "SuccessResponse" assert unsubscribe_response.incoming_message == unsubscribe_message -async def test_unsubscribe_handler_call_storage_unsubscriber_under_the_hood(unsubscribe_handler, unsubscribe_message, mocker, storage, ws_subscribed): +def test_unsubscribe_handler_call_storage_unsubscriber_under_the_hood(unsubscribe_handler, unsubscribe_message, mocker, storage, ws_subscribed): spy_storage_unsubscriber = mocker.spy(StorageUserUnsubscriber, "__call__") - await unsubscribe_handler(unsubscribe_message) + unsubscribe_handler(unsubscribe_message) spy_storage_unsubscriber.assert_called_once() called_service = spy_storage_unsubscriber.call_args.args[0] diff --git a/src/handlers/websockets_handler.py b/src/handlers/websockets_handler.py index 379dcc3..b328aca 100644 --- a/src/handlers/websockets_handler.py +++ b/src/handlers/websockets_handler.py @@ -50,7 +50,7 @@ async def process_message(self, websocket: WebSocketServerProtocol, raw_message: return ErrorResponseMessage.model_construct(errors=exc.errors(include_url=False, include_context=False), incoming_message=None) try: - success_response = await self.messages_handler.handle_message(websocket, message) + success_response = self.messages_handler.handle_message(websocket, message) except WebsocketMessageException as exc: return exc.as_error_message() diff --git a/src/tests/functional/conftest.py b/src/tests/functional/conftest.py index 5424b2b..b39faab 100644 --- a/src/tests/functional/conftest.py +++ b/src/tests/functional/conftest.py @@ -9,16 +9,13 @@ from handlers import WebSocketsAccessGuardian, WebSocketsHandler -@pytest.fixture -def force_token_validation(mocker, valid_token): - return mocker.patch("a12n.jwk_client.AsyncJWKClient.decode", return_value=valid_token) - - @pytest.fixture(autouse=True) -def _adjust_settings(settings, unused_tcp_port): +def adjust_settings(settings, unused_tcp_port, jwt_public_key): settings.BROKER_QUEUE = None # force consumer to create a queue with a random name settings.WEBSOCKETS_HOST = "0.0.0.0" # noqa: S104 settings.WEBSOCKETS_PORT = unused_tcp_port + settings.JWT_PUBLIC_KEY = jwt_public_key + return settings @pytest.fixture @@ -98,12 +95,12 @@ async def send_and_recv(ws_client, message: str): @pytest.fixture -def auth_message_data(): +def auth_message_data(jwt_user_valid_token): return { "message_id": 777, "message_type": "Authenticate", "params": { - "token": "valid-token", + "token": jwt_user_valid_token, }, } @@ -114,5 +111,5 @@ def auth_message(auth_message_data): @pytest.fixture -async def ws_client_authenticated(auth_message, ws_client_send_and_recv, ws_client, force_token_validation): +async def ws_client_authenticated(auth_message, ws_client_send_and_recv, ws_client): return await ws_client_send_and_recv(ws_client, auth_message) diff --git a/src/tests/functional/tests_authentication.py b/src/tests/functional/tests_authentication.py index f0f3294..e0e9c76 100644 --- a/src/tests/functional/tests_authentication.py +++ b/src/tests/functional/tests_authentication.py @@ -2,7 +2,6 @@ pytestmark = [ pytest.mark.rabbitmq, - pytest.mark.usefixtures("force_token_validation"), ]