Skip to content

Commit

Permalink
Handlers: AuthMessage and it's handler
Browse files Browse the repository at this point in the history
  • Loading branch information
nkiryanov committed Jun 8, 2024
1 parent 49cdda9 commit 9ddf26c
Show file tree
Hide file tree
Showing 13 changed files with 229 additions and 13 deletions.
5 changes: 5 additions & 0 deletions src/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from handlers.websocket_auth_message_handler import WebsocketAuthMessageHandler

__all__ = [
"WebsocketAuthMessageHandler",
]
49 changes: 49 additions & 0 deletions src/handlers/dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Literal

from pydantic import Field

from pydantic import BaseModel
from app.types import Event


messageId = int | str


class AuthMessageParams(BaseModel):
token: str


class AuthMessage(BaseModel):
message_id: messageId
message_type: Literal["Authenticate"]
params: AuthMessageParams = Field(exclude=True)


class SubscribeParams(BaseModel):
event: Event


class SubscribeMessage(BaseModel):
message_id: messageId
message_type: Literal["Subscribe"]
params: SubscribeParams


class UnsubscribeMessage(BaseModel):
message_id: messageId
message_type: Literal["Unsubscribe"]
params: SubscribeParams


IncomingMessage = AuthMessage | SubscribeMessage | UnsubscribeMessage


class SuccessResponseMessage(BaseModel):
status: Literal["Success"] = "Success"
incoming_message: IncomingMessage


class ErrorResponseMessage(BaseModel):
status: Literal["Error"] = "Error"
error_detail: str
incoming_message: IncomingMessage | None # may be null if incoming message was not valid
13 changes: 13 additions & 0 deletions src/handlers/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from handlers.dto import IncomingMessage
from handlers.dto import ErrorResponseMessage


class WebsocketMessageException(Exception):
"""Raise if error occurred during message handling."""

def __init__(self, error_detail: str, incoming_message: IncomingMessage | None = None) -> None:
self.error_detail = error_detail
self.incoming_message = incoming_message

def as_error_message(self) -> ErrorResponseMessage:
return ErrorResponseMessage.model_construct(error_detail=self.error_detail, incoming_message=self.incoming_message)
14 changes: 14 additions & 0 deletions src/handlers/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest

from app.testing import MockedWebSocketServerProtocol
from storage import SubscriptionStorage


@pytest.fixture
def storage():
return SubscriptionStorage()


@pytest.fixture
def ws():
return MockedWebSocketServerProtocol()
104 changes: 104 additions & 0 deletions src/handlers/tests/tests_auth_message_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest

from a12n.jwk_client import AsyncJWKClientException
from app.types import DecodedValidToken
from handlers.dto import AuthMessage
from handlers.dto import SuccessResponseMessage
from handlers.exceptions import WebsocketMessageException
from handlers import WebsocketAuthMessageHandler
from storage.storage_updaters import StorageWebSocketRegister


@pytest.fixture(autouse=True)
def settings(settings):
settings.AUTH_JWKS_URL = "https://auth.clowns.com/auth/realms/clowns-realm/protocol/openid-connect/certs"
settings.AUTH_SUPPORTED_SIGNING_ALGORITHMS = ["RS256"]
return settings


@pytest.fixture
def decoded_valid_token():
return DecodedValidToken(sub="user1", exp="4852128170") # 2123 year


@pytest.fixture
def ya_user_decoded_valid_token():
return DecodedValidToken(sub="ya_user", exp="4852128170")


@pytest.fixture(autouse=True)
def set_token_validation(mocker, decoded_valid_token):
return mocker.patch("a12n.jwk_client.AsyncJWKClient.decode", return_value=decoded_valid_token)


@pytest.fixture
def auth_message():
return AuthMessage(message_id=23, message_type="Authenticate", params={"token": "some-valid-token-value"})


@pytest.fixture
def register_ws(storage):
return lambda ws, decoded_valid_token: StorageWebSocketRegister(storage, ws, decoded_valid_token)()


@pytest.fixture
def auth_message_handler(storage):
return WebsocketAuthMessageHandler(storage=storage)


@pytest.fixture
def handle(auth_message_handler, ws):
def get_handler(message):
return auth_message_handler.handle_message(ws, message)

return get_handler


def test_auth_handler_jwk_client_settings(auth_message_handler):
assert auth_message_handler.jwk_client.jwks_url == "https://auth.clowns.com/auth/realms/clowns-realm/protocol/openid-connect/certs"
assert auth_message_handler.jwk_client.supported_signing_algorithms == ["RS256"]


async def test_auth_handler_response_on_correct_authenticate(handle, auth_message):
auth_response = await handle(auth_message)

assert isinstance(auth_response, SuccessResponseMessage)
assert auth_response.status == "Success"
assert auth_response.incoming_message == auth_message


async def test_auth_handler_register_websocket_in_storage(handle, ws, auth_message, mocker, storage, decoded_valid_token):
spy_websocket_register = mocker.spy(StorageWebSocketRegister, "__call__")

await handle(auth_message)

assert storage.is_websocket_registered(ws) is True
spy_websocket_register.assert_called_once()
called_service = spy_websocket_register.call_args.args[0]
assert called_service.storage == storage
assert called_service.websocket == ws
assert called_service.validated_token == decoded_valid_token


async def test_auth_handler_raise_if_user_send_token_for_different_user(handle, auth_message, storage, ws, register_ws, ya_user_decoded_valid_token):
register_ws(ws, ya_user_decoded_valid_token)

with pytest.raises(WebsocketMessageException) as exc_info:
await handle(auth_message) # send valid user1 token while connection registered with ya_user

raised_exception = exc_info.value
assert raised_exception.error_detail == "The user has different public id"
assert raised_exception.incoming_message == auth_message
assert storage.is_websocket_registered(ws) is True, "The existed connection should not be touched"


async def test_auth_handler_raise_if_user_try_to_auth_with_expired_token(handle, ws, auth_message, set_token_validation, storage):
set_token_validation.side_effect = AsyncJWKClientException("The token is expired")

with pytest.raises(WebsocketMessageException) as exc_info:
await handle(auth_message)

raised_exception = exc_info.value
assert raised_exception.error_detail == "The token is expired"
assert raised_exception.incoming_message == auth_message
assert storage.is_websocket_registered(ws) is False, "The ws should not be added to registered websockets"
31 changes: 31 additions & 0 deletions src/handlers/websocket_auth_message_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from dataclasses import dataclass

from websockets import WebSocketServerProtocol

from a12n.jwk_client import AsyncJWKClient
from a12n.jwk_client import AsyncJWKClientException
from app import conf
from handlers.dto import AuthMessage
from handlers.dto import SuccessResponseMessage
from handlers.exceptions import WebsocketMessageException
from storage.exceptions import StorageOperationException
from storage.storage_updaters import StorageWebSocketRegister
from storage.subscription_storage import SubscriptionStorage


@dataclass
class WebsocketAuthMessageHandler:
storage: SubscriptionStorage

def __post_init__(self) -> None:
settings = conf.get_app_settings()
self.jwk_client = AsyncJWKClient(jwks_url=settings.AUTH_JWKS_URL, supported_signing_algorithms=settings.AUTH_SUPPORTED_SIGNING_ALGORITHMS)

async def handle_message(self, websocket: WebSocketServerProtocol, message: AuthMessage) -> SuccessResponseMessage:
try:
validated_token = await self.jwk_client.decode(message.params.token)
StorageWebSocketRegister(storage=self.storage, websocket=websocket, validated_token=validated_token)()
except (AsyncJWKClientException, StorageOperationException) as exc:
raise WebsocketMessageException(str(exc), message)

return SuccessResponseMessage.model_construct(incoming_message=message)
8 changes: 4 additions & 4 deletions src/storage/storage_updaters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from storage.storage_updaters.storage_connection_register import StorageConnectionRegister
from storage.storage_updaters.storage_connection_remover import StorageConnectionRemover
from storage.storage_updaters.storage_websocket_register import StorageWebSocketRegister
from storage.storage_updaters.storage_websocket_remover import StorageWebSocketRemover
from storage.storage_updaters.storage_user_subscriber import StorageUserSubscriber
from storage.storage_updaters.storage_user_unsubscriber import StorageUserUnsubscriber

__all__ = [
"StorageConnectionRegister",
"StorageWebSocketRegister",
"StorageUserSubscriber",
"StorageUserUnsubscriber",
"StorageConnectionRemover",
"StorageWebSocketRemover",
]
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@dataclass
class StorageConnectionRegister(BaseService):
class StorageWebSocketRegister(BaseService):
"""Add or update websocket in storage
If websocket not registered: just register it.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


@dataclass
class StorageConnectionRemover(BaseService):
class StorageWebSocketRemover(BaseService):
""" "Remove connection from storage.
If websocket is not registered then nothing to do.
Expand Down
2 changes: 1 addition & 1 deletion src/storage/subscription_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_websocket_user_id(self, websocket: WebSocketServerProtocol) -> UserId |
def get_event_subscribers_user_ids(self, event: Event) -> set[UserId]:
return self.subscriptions.get(event) or set()

def is_event_active(self, event: Event) -> bool:
def is_event_has_subscribers(self, event: Event) -> bool:
return event in self.subscriptions

def get_users_websockets(self, user_ids: set[UserId]) -> list[WebSocketServerProtocol]:
Expand Down
4 changes: 2 additions & 2 deletions src/storage/tests/storage_updaters/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from app.types import DecodedValidToken
from app.testing import MockedWebSocketServerProtocol
from storage.storage_updaters.storage_connection_register import StorageConnectionRegister
from storage.storage_updaters.storage_websocket_register import StorageWebSocketRegister
from storage.storage_updaters.storage_user_subscriber import StorageUserSubscriber


Expand Down Expand Up @@ -34,7 +34,7 @@ def ya_ws():
@pytest.fixture
def register_ws(storage):
def register(ws, token):
StorageConnectionRegister(storage, ws, token)()
StorageWebSocketRegister(storage, ws, token)()

return register

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import pytest

from storage.exceptions import StorageOperationException
from storage.storage_updaters import StorageConnectionRegister
from storage.storage_updaters import StorageWebSocketRegister
from storage.subscription_storage import SubscriptionStorage
from storage.types import ConnectedUserMeta
from storage.types import WebSocketMeta


@pytest.fixture
def register(storage: SubscriptionStorage):
return lambda ws, token: StorageConnectionRegister(
return lambda ws, token: StorageWebSocketRegister(
storage=storage,
websocket=ws,
validated_token=token,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextlib import nullcontext as does_not_raise
import pytest

from storage.storage_updaters import StorageConnectionRemover
from storage.storage_updaters import StorageWebSocketRemover


@pytest.fixture
Expand All @@ -12,7 +12,7 @@ def ya_user_ws_registered(ya_ws, ya_user_valid_token, register_ws):

@pytest.fixture
def remove(storage):
return lambda ws: StorageConnectionRemover(storage, ws)()
return lambda ws: StorageWebSocketRemover(storage, ws)()


def test_remove_websocket_and_user_subscriptions_from_storage(remove, ws_subscribed, storage):
Expand Down

0 comments on commit 9ddf26c

Please sign in to comment.