diff --git a/env.example b/env.example index 637d36d..a132e8a 100644 --- a/env.example +++ b/env.example @@ -2,7 +2,8 @@ DEBUG=False BROKER_URL=amqp://guest:guest@localhost:5672/ BROKER_EXCHANGE=guest-exchange -BROKER_ROUTING_KEYS=["test-create-event-key", "test-notification-event-key"] +BROKER_QUEUE=websockets-notifications-queue +BROKER_ROUTING_KEYS_CONSUME_FROM=["test-event-boobs"] WEBSOCKETS_HOST=localhost WEBSOCKETS_PORT=6789 diff --git a/src/app/conf/settings.py b/src/app/conf/settings.py index dd6b7f6..8c43778 100644 --- a/src/app/conf/settings.py +++ b/src/app/conf/settings.py @@ -9,7 +9,8 @@ class Settings(BaseSettings): BROKER_URL: AmqpDsn BROKER_EXCHANGE: str - BROKER_ROUTING_KEYS: list[str] + BROKER_QUEUE: str + BROKER_ROUTING_KEYS_CONSUME_FROM: list[str] WEBSOCKETS_HOST: str WEBSOCKETS_PORT: int WEBSOCKETS_PATH: str diff --git a/src/app/fixtures.py b/src/app/fixtures.py index 4811005..581197b 100644 --- a/src/app/fixtures.py +++ b/src/app/fixtures.py @@ -4,10 +4,15 @@ @pytest.fixture -def ws(): - return MockedWebSocketServerProtocol() +def create_ws(): + return lambda: MockedWebSocketServerProtocol() @pytest.fixture -def ya_ws(): - return MockedWebSocketServerProtocol() +def ws(create_ws): + return create_ws() + + +@pytest.fixture +def ya_ws(create_ws): + return create_ws() diff --git a/src/conftest.py b/src/conftest.py index 0eab3d2..6e7047e 100644 --- a/src/conftest.py +++ b/src/conftest.py @@ -15,7 +15,8 @@ def settings(mocker): return_value=Settings( BROKER_URL="amqp://guest:guest@localhost/", BROKER_EXCHANGE="test-exchange", - BROKER_ROUTING_KEYS=["test-routing-key", "ya-test-routing-key"], + BROKER_QUEUE="test-queue", + BROKER_ROUTING_KEYS_CONSUME_FROM=["test-routing-key", "ya-test-routing-key"], WEBSOCKETS_HOST="localhost", WEBSOCKETS_PORT=50000, WEBSOCKETS_PATH="/v2/test-subscription-websocket", diff --git a/src/consumer/consumer.py b/src/consumer/consumer.py new file mode 100644 index 0000000..83f0730 --- /dev/null +++ b/src/consumer/consumer.py @@ -0,0 +1,73 @@ +import asyncio +from dataclasses import dataclass +import logging +from typing import Protocol + +import aio_pika + +from app import conf +from consumer.dto import ConsumedMessage +from consumer.dto import OutgoingMessage +from storage.subscription_storage import SubscriptionStorage +from pydantic import ValidationError +import websockets + + +logger = logging.getLogger(__name__) + + +class ConsumerProtocol(Protocol): + async def consume(self) -> None: + pass + + +@dataclass +class Consumer: + storage: SubscriptionStorage + + def __post_init__(self) -> None: + settings = conf.get_app_settings() + + self.broker_url: str = str(settings.BROKER_URL) + self.exchange: str = settings.BROKER_EXCHANGE + self.queue: str = settings.BROKER_QUEUE + self.routing_keys_consume_from: list[str] = settings.BROKER_ROUTING_KEYS_CONSUME_FROM + + async def consume(self, stop_signal: asyncio.Future) -> None: + connection = await aio_pika.connect_robust(self.broker_url) + + async with connection: + channel = await connection.channel() + + exchange = await channel.declare_exchange(self.exchange, type=aio_pika.ExchangeType.DIRECT) + queue = await channel.declare_queue(name=self.queue, exclusive=True) + + for routing_key in self.routing_keys_consume_from: + await queue.bind(exchange=exchange, routing_key=routing_key) + + await queue.consume(self.on_message) + + await stop_signal + + async def on_message(self, raw_message: aio_pika.abc.AbstractIncomingMessage) -> None: + async with raw_message.process(): + processed_messages = self.parse_message(raw_message) + + if processed_messages: + self.broadcast_subscribers(self.storage, processed_messages) + + @staticmethod + def parse_message(raw_message: aio_pika.abc.AbstractIncomingMessage) -> ConsumedMessage | None: + try: + return ConsumedMessage.model_validate_json(raw_message.body) + except ValidationError as exc: + logger.error("Consumed message not in expected format. Errors: %s", exc.errors()) + return None + + @staticmethod + def broadcast_subscribers(storage: SubscriptionStorage, message: ConsumedMessage) -> None: + websockets_to_broadcast = storage.get_event_subscribers_websockets(message.event) + + if websockets_to_broadcast: + outgoing_message = OutgoingMessage.model_construct(payload=message) + websockets.broadcast(websockets=websockets_to_broadcast, message=outgoing_message.model_dump_json()) diff --git a/src/consumer/dto.py b/src/consumer/dto.py new file mode 100644 index 0000000..fd1b87a --- /dev/null +++ b/src/consumer/dto.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel +from pydantic import ConfigDict +from typing import Literal +from app.types import Event + + +class ConsumedMessage(BaseModel): + model_config = ConfigDict(extra="allow") + + event: Event + + +class OutgoingMessage(BaseModel): + message_type: Literal["EventNotification"] = "EventNotification" + payload: ConsumedMessage diff --git a/src/consumer/tests/conftest.py b/src/consumer/tests/conftest.py new file mode 100644 index 0000000..e447d98 --- /dev/null +++ b/src/consumer/tests/conftest.py @@ -0,0 +1,37 @@ +import pytest + +from consumer.consumer import Consumer +from dataclasses import dataclass +from contextlib import asynccontextmanager +from typing import AsyncGenerator + + +@pytest.fixture(autouse=True) +def _adjust_settings(settings): + settings.BROKER_URL = "amqp://guest:guest@localhost/" + settings.BROKER_EXCHANGE = "test-exchange" + settings.BROKER_QUEUE = "test-queue" + settings.BROKER_ROUTING_KEYS_CONSUME_FROM = [ + "event-routing-key", + "ya-event-routing-key", + ] + + +@pytest.fixture +def consumer(storage) -> Consumer: + return Consumer(storage=storage) + + +@dataclass +class MockedIncomingMessage: + """The simplest Incoming message class that represent incoming amqp message. + + The safer choice is to use 'aio_pika.abc.AbstractIncomingMessage,' but the test setup will be significantly more challenging. + """ + + body: bytes + + @asynccontextmanager + async def process(self) -> AsyncGenerator: + """Do nothing, just for compatibility with aio_pika.abc.AbstractIncomingMessage.""" + yield None diff --git a/src/consumer/tests/tests_consumer.py b/src/consumer/tests/tests_consumer.py new file mode 100644 index 0000000..f01db13 --- /dev/null +++ b/src/consumer/tests/tests_consumer.py @@ -0,0 +1,30 @@ +import asyncio +import pytest + + +@pytest.fixture +def run_consume_task(consumer): + def run(stop_signal: asyncio.Future): + return asyncio.create_task(consumer.consume(stop_signal)) + + return run + + +def test_consumer_attributes(consumer): + assert consumer.broker_url == "amqp://guest:guest@localhost/" + assert consumer.exchange == "test-exchange" + assert consumer.queue == "test-queue" + assert consumer.routing_keys_consume_from == [ + "event-routing-key", + "ya-event-routing-key", + ] + + +async def test_consumer_correctly_stopped_on_stop_signal(run_consume_task): + stop_signal = asyncio.get_running_loop().create_future() + consumer_task = run_consume_task(stop_signal) + + stop_signal.set_result(None) + + await asyncio.sleep(0.1) # get enough time to stop the task + assert consumer_task.done() is True diff --git a/src/consumer/tests/tests_consumer_on_message.py b/src/consumer/tests/tests_consumer_on_message.py new file mode 100644 index 0000000..74a527d --- /dev/null +++ b/src/consumer/tests/tests_consumer_on_message.py @@ -0,0 +1,68 @@ +from contextlib import nullcontext as does_not_raise +import json +import pytest + +from consumer.tests.conftest import MockedIncomingMessage + + +@pytest.fixture(autouse=True) +def mock_broadcast(mocker): + return mocker.patch("websockets.broadcast") + + +def python_to_bytes(data: dict) -> bytes: + return json.dumps(data).encode() + + +@pytest.fixture +def broker_message_data(event): + return { + "event": event, + "size": 3, + "quantity": 2, + } + + +@pytest.fixture +def ya_ws_subscribed(create_ws, ya_valid_token, ws_register_and_subscribe, event): + return ws_register_and_subscribe(create_ws(), ya_valid_token, event) + + +@pytest.fixture +def ya_user_ws_subscribed(create_ws, ya_user_valid_token, ws_register_and_subscribe, event): + return ws_register_and_subscribe(create_ws(), ya_user_valid_token, event) + + +@pytest.fixture +def consumed_message(broker_message_data): + return MockedIncomingMessage(body=python_to_bytes(broker_message_data)) + + +@pytest.fixture +def on_message(consumer, consumed_message): + return lambda message=consumed_message: consumer.on_message(message) + + +async def test_broadcast_message_to_subscriber_websockets(on_message, ws_subscribed, mock_broadcast, mocker): + await on_message() + + mock_broadcast.assert_called_once_with(websockets=[ws_subscribed], message=mocker.ANY) + + +async def test_broadcast_message_to_all_subscribers_websockets(on_message, mock_broadcast, ws_subscribed, ya_ws_subscribed, ya_user_ws_subscribed, mocker): + await on_message() + + mock_broadcast.assert_called_once() + broadcasted_websockets = mock_broadcast.call_args.kwargs["websockets"] + assert len(broadcasted_websockets) == 3 + assert set(broadcasted_websockets) == {ws_subscribed, ya_ws_subscribed, ya_user_ws_subscribed} + + +async def test_log_and_do_nothing_if_message_not_expected_format(on_message, ws_subscribed, mock_broadcast, consumed_message, caplog): + consumed_message.body = b"invalid-json" + + with does_not_raise(): + await on_message(consumed_message) + + assert "Consumed message not in expected format" in caplog.text + mock_broadcast.assert_not_called() diff --git a/src/entrypoint.py b/src/entrypoint.py index 29ba9a1..ec8d2cc 100644 --- a/src/entrypoint.py +++ b/src/entrypoint.py @@ -2,12 +2,15 @@ import signal import websockets +import logging from app import conf from handlers import WebSocketsHandler from handlers import WebSocketsAccessGuardian from storage.subscription_storage import SubscriptionStorage +logging.basicConfig(level=logging.INFO) + def create_stop_signal() -> asyncio.Future[None]: loop = asyncio.get_running_loop() diff --git a/src/storage/fixtures.py b/src/storage/fixtures.py index 80c8952..7742bd4 100644 --- a/src/storage/fixtures.py +++ b/src/storage/fixtures.py @@ -66,3 +66,13 @@ def subscribe(ws, event): def ws_subscribed(ws_registered, subscribe_ws, event): subscribe_ws(ws_registered, event) return ws_registered + + +@pytest.fixture +def ws_register_and_subscribe(register_ws, subscribe_ws): + def register_and_subscribe(ws, token, event): + register_ws(ws, token) + subscribe_ws(ws, event) + return ws + + return register_and_subscribe diff --git a/src/storage/subscription_storage.py b/src/storage/subscription_storage.py index 73f26b7..5d08477 100644 --- a/src/storage/subscription_storage.py +++ b/src/storage/subscription_storage.py @@ -26,22 +26,15 @@ def get_websocket_user_id(self, websocket: WebSocketServerProtocol) -> UserId | websocket_meta = self.registered_websockets.get(websocket) return websocket_meta.user_id if websocket_meta else None - def get_event_subscribers_user_ids(self, event: Event) -> set[UserId]: - return self.subscriptions.get(event) or set() + def get_event_subscribers_websockets(self, event: Event) -> list[WebSocketServerProtocol]: + subscribers_user_ids = self.subscriptions.get(event) or set() - def is_event_has_subscribers(self, event: Event) -> bool: - return event in self.subscriptions + user_websockets = [] - def get_users_websockets(self, user_ids: set[UserId]) -> list[WebSocketServerProtocol]: - users_websockets = [] + for user_id in subscribers_user_ids: + user_websockets.extend(self.user_connections[user_id].websockets) - for user_id in user_ids: - user_connection_meta = self.user_connections.get(user_id) - - if user_connection_meta: - users_websockets.extend(user_connection_meta.websockets) - - return users_websockets + return user_websockets def get_expired_websockets(self) -> list[WebSocketServerProtocol]: now_timestamp = time.time()