Skip to content

Commit

Permalink
Websockets handler and functional test (#8)
Browse files Browse the repository at this point in the history
* Rename `WebSocketMessageHandler` to `WebSocketMessagesHandler`

* WebsocketsHandler and functional tests
  • Loading branch information
nkiryanov authored Jun 8, 2024
1 parent 32646f7 commit 3813945
Show file tree
Hide file tree
Showing 15 changed files with 393 additions and 31 deletions.
39 changes: 39 additions & 0 deletions src/entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import asyncio
import signal

import websockets

from app import conf
from handlers.websockets_handler import WebSocketsHandler
from storage.subscription_storage import SubscriptionStorage


def create_stop_signal() -> asyncio.Future[None]:
loop = asyncio.get_running_loop()
stop_signal = loop.create_future()
loop.add_signal_handler(signal.SIGTERM, stop_signal.set_result, None)
return stop_signal


async def app_runner(settings: conf.Settings, websockets_handler: WebSocketsHandler) -> None:
async with websockets.serve(
ws_handler=websockets_handler.websockets_handler,
host=settings.WEBSOCKETS_HOST,
port=settings.WEBSOCKETS_PORT,
):
await asyncio.Future()


async def main() -> None:
settings = conf.get_app_settings()
storage = SubscriptionStorage()
websockets_handler = WebSocketsHandler(storage=storage)

await app_runner(
settings=settings,
websockets_handler=websockets_handler,
)


if __name__ == "__main__":
asyncio.run(main())
4 changes: 2 additions & 2 deletions src/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from handlers.message_handler import WebSocketMessageHandler
from handlers.websockets_handler import WebSocketsHandler

__all__ = [
"WebSocketMessageHandler",
"WebSocketsHandler",
]
9 changes: 5 additions & 4 deletions src/handlers/dto.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Literal

from pydantic import Field
from pydantic_core import ErrorDetails
from pydantic import SecretStr

from pydantic import BaseModel
from app.types import Event
Expand All @@ -10,13 +11,13 @@


class AuthMessageParams(BaseModel):
token: str
token: SecretStr


class AuthMessage(BaseModel):
message_id: messageId
message_type: Literal["Authenticate"]
params: AuthMessageParams = Field(exclude=True)
params: AuthMessageParams


class SubscribeParams(BaseModel):
Expand Down Expand Up @@ -45,5 +46,5 @@ class SuccessResponseMessage(BaseModel):

class ErrorResponseMessage(BaseModel):
message_type: Literal["ErrorResponse"] = "ErrorResponse"
error_detail: str
errors: list[ErrorDetails | str]
incoming_message: IncomingMessage | None # may be null if incoming message was not valid
6 changes: 3 additions & 3 deletions src/handlers/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
class WebsocketMessageException(Exception):
"""Raise if error occurred during message handling."""

def __init__(self, error_detail: str, incoming_message: IncomingMessage | None = None) -> None:
self.error_detail = error_detail
def __init__(self, error_detail: str, incoming_message: IncomingMessage) -> None:
self.errors = [error_detail]
self.incoming_message = incoming_message

def as_error_message(self) -> ErrorResponseMessage:
return ErrorResponseMessage.model_construct(error_detail=self.error_detail, incoming_message=self.incoming_message)
return ErrorResponseMessage.model_construct(errors=self.errors, incoming_message=self.incoming_message)
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


@dataclass
class WebSocketMessageHandler:
class WebSocketMessagesHandler:
storage: SubscriptionStorage

def __post_init__(self) -> None:
Expand All @@ -40,7 +40,7 @@ async def handle_message(self, websocket: WebSocketServerProtocol, message: Inco

async def handle_auth_message(self, websocket: WebSocketServerProtocol, message: AuthMessage) -> SuccessResponseMessage:
try:
validated_token = await self.jwk_client.decode(message.params.token)
validated_token = await self.jwk_client.decode(message.params.token.get_secret_value())
StorageWebSocketRegister(storage=self.storage, websocket=websocket, validated_token=validated_token)()
except (AsyncJWKClientException, StorageOperationException) as exc:
raise WebsocketMessageException(str(exc), message)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from handlers import WebSocketMessageHandler
from handlers.messages_handler import WebSocketMessagesHandler
from handlers.dto import AuthMessage, SubscribeMessage, UnsubscribeMessage


Expand All @@ -12,13 +12,13 @@ def settings(settings):


@pytest.fixture
def force_token_on_validation(mocker, valid_token):
def force_token_validation(mocker, valid_token):
return mocker.patch("a12n.jwk_client.AsyncJWKClient.decode", return_value=valid_token)


@pytest.fixture
def message_handler(storage):
return WebSocketMessageHandler(storage=storage)
return WebSocketMessagesHandler(storage=storage)


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from app.types import DecodedValidToken
from handlers.dto import SuccessResponseMessage
from handlers.exceptions import WebsocketMessageException
from handlers import WebSocketMessageHandler
from handlers.messages_handler import WebSocketMessagesHandler
from storage.storage_updaters import StorageWebSocketRegister

pytestmark = [
pytest.mark.usefixtures("force_token_on_validation"),
pytest.mark.usefixtures("force_token_validation"),
]


Expand All @@ -18,7 +18,7 @@ def ya_user_decoded_valid_token():


@pytest.fixture
def auth_handler(message_handler: WebSocketMessageHandler, ws):
def auth_handler(message_handler: WebSocketMessagesHandler, ws):
return lambda auth_message: message_handler.handle_auth_message(ws, auth_message)


Expand Down Expand Up @@ -50,18 +50,18 @@ async def test_auth_handler_raise_if_user_send_token_for_different_user(auth_han
await auth_handler(auth_message) # send valid user1 token while connection registered with ya_user

raised_exception = exc_info.value
assert raised_exception.error_detail == "The user has different public id"
assert raised_exception.errors == ["The user has different public id"]
assert raised_exception.incoming_message == auth_message
assert storage.is_websocket_registered(ws) is True, "The existed connection should not be touched"


async def test_auth_handler_raise_if_user_try_to_auth_with_expired_token(auth_handler, ws, auth_message, force_token_on_validation, storage):
force_token_on_validation.side_effect = AsyncJWKClientException("The token is expired")
async def test_auth_handler_raise_if_user_try_to_auth_with_expired_token(auth_handler, ws, auth_message, force_token_validation, storage):
force_token_validation.side_effect = AsyncJWKClientException("The token is expired")

with pytest.raises(WebsocketMessageException) as exc_info:
await auth_handler(auth_message)

raised_exception = exc_info.value
assert raised_exception.error_detail == "The token is expired"
assert raised_exception.errors == ["The token is expired"]
assert raised_exception.incoming_message == auth_message
assert storage.is_websocket_registered(ws) is False, "The ws should not be added to registered websockets"
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
import pytest
from handlers.message_handler import WebSocketMessageHandler
from handlers.messages_handler import WebSocketMessagesHandler


@pytest.fixture
def get_message_handler(storage):
return lambda: WebSocketMessageHandler(storage)
return lambda: WebSocketMessagesHandler(storage)


def test_message_handler_jwk_client_settings(message_handler):
assert message_handler.jwk_client.jwks_url == "https://auth.clowns.com/auth/realms/clowns-realm/protocol/openid-connect/certs"
assert message_handler.jwk_client.supported_signing_algorithms == ["RS256"]


@pytest.mark.usefixtures("force_token_on_validation")
@pytest.mark.usefixtures("force_token_validation")
async def test_message_handler_call_auth_handler_on_auth_message(get_message_handler, auth_message, mocker, ws):
spy_auth_handler = mocker.spy(WebSocketMessageHandler, "handle_auth_message")
spy_auth_handler = mocker.spy(WebSocketMessagesHandler, "handle_auth_message")

await get_message_handler().handle_message(ws, auth_message)

spy_auth_handler.assert_awaited_once()


async def test_message_handler_call_subscribe_handler_on_subscribe_message(get_message_handler, subscribe_message, mocker, ws_registered):
spy_subscribe_handler = mocker.spy(WebSocketMessageHandler, "handle_subscribe_message")
spy_subscribe_handler = mocker.spy(WebSocketMessagesHandler, "handle_subscribe_message")

await get_message_handler().handle_message(ws_registered, subscribe_message)

spy_subscribe_handler.assert_awaited_once()


async def test_message_handler_call_unsubscribe_handler_on_unsubscribe_message(get_message_handler, unsubscribe_message, mocker, ws_subscribed):
spy_unsubscribe_handler = mocker.spy(WebSocketMessageHandler, "handle_unsubscribe_message")
spy_unsubscribe_handler = mocker.spy(WebSocketMessagesHandler, "handle_unsubscribe_message")

await get_message_handler().handle_message(ws_subscribed, unsubscribe_message)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest
from handlers.message_handler import WebSocketMessageHandler
from handlers.messages_handler import WebSocketMessagesHandler
from storage.storage_updaters import StorageUserSubscriber


@pytest.fixture
def subscribe_handler(message_handler: WebSocketMessageHandler, ws_registered):
def subscribe_handler(message_handler: WebSocketMessagesHandler, ws_registered):
return lambda subscribe_message: message_handler.handle_subscribe_message(ws_registered, subscribe_message)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest
from handlers.message_handler import WebSocketMessageHandler
from handlers.messages_handler import WebSocketMessagesHandler
from storage.storage_updaters import StorageUserUnsubscriber


@pytest.fixture
def unsubscribe_handler(message_handler: WebSocketMessageHandler, ws_subscribed):
def unsubscribe_handler(message_handler: WebSocketMessagesHandler, ws_subscribed):
return lambda unsubscribe_message: message_handler.handle_unsubscribe_message(ws_subscribed, unsubscribe_message)


Expand Down
72 changes: 72 additions & 0 deletions src/handlers/websockets_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from dataclasses import dataclass
import logging
from typing import Annotated

import pydantic
from pydantic import Field
from pydantic import TypeAdapter
from websockets import WebSocketServerProtocol
from websockets.exceptions import ConnectionClosedError

from app import conf
from handlers.dto import AuthMessage
from handlers.dto import ErrorResponseMessage
from handlers.dto import IncomingMessage
from handlers.dto import SuccessResponseMessage
from handlers.exceptions import WebsocketMessageException
from storage.storage_updaters import StorageWebSocketRemover
from storage.subscription_storage import SubscriptionStorage
from handlers.messages_handler import WebSocketMessagesHandler

logger = logging.getLogger(__name__)

IncomingMessageAdapter = TypeAdapter(Annotated[IncomingMessage, Field(discriminator="message_type")])
AuthMessageAdapter = TypeAdapter(Annotated[AuthMessage, Field(discriminator="message_type")])


@dataclass
class WebSocketsHandler:
storage: SubscriptionStorage

def __post_init__(self) -> None:
settings = conf.get_app_settings()
self.websockets_path = settings.WEBSOCKETS_PATH

self.messages_handler = WebSocketMessagesHandler(storage=self.storage)

async def websockets_handler(self, websocket: WebSocketServerProtocol) -> None:
if websocket.path != self.websockets_path:
return

try:
async for message in websocket:
response_message = await self.process_message(websocket=websocket, raw_message=message)
await websocket.send(response_message.model_dump_json(exclude_none=True))
except ConnectionClosedError:
logger.warning("Trying to send message to closed connection. Connection id: '%s'", websocket.id)
finally:
StorageWebSocketRemover(storage=self.storage, websocket=websocket)()

async def process_message(self, websocket: WebSocketServerProtocol, raw_message: str | bytes) -> SuccessResponseMessage | ErrorResponseMessage:
try:
message = self.parse_raw_message(websocket, raw_message)
except pydantic.ValidationError as exc:
return ErrorResponseMessage.model_construct(errors=exc.errors(include_url=False, include_context=False), incoming_message=None)

try:
success_response = await self.messages_handler.handle_message(websocket, message)
except WebsocketMessageException as exc:
return exc.as_error_message()

return success_response

def parse_raw_message(self, websocket: WebSocketServerProtocol, raw_message: str | bytes) -> IncomingMessage:
adapter = self.get_message_adapter(websocket)
return adapter.validate_json(raw_message)

def get_message_adapter(self, websocket: WebSocketServerProtocol) -> TypeAdapter:
"""Only registered websockets can send all messages. Unregistered websockets can only send Auth messages."""
if self.storage.is_websocket_registered(websocket):
return IncomingMessageAdapter

return AuthMessageAdapter
Loading

0 comments on commit 3813945

Please sign in to comment.