-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
257 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import asyncio | ||
from dataclasses import dataclass | ||
import logging | ||
from typing import Protocol | ||
|
||
import aio_pika | ||
|
||
from app import conf | ||
from consumer.dto import ConsumedMessage | ||
from consumer.dto import OutgoingMessage | ||
from storage.subscription_storage import SubscriptionStorage | ||
from pydantic import ValidationError | ||
import websockets | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ConsumerProtocol(Protocol): | ||
async def consume(self) -> None: | ||
pass | ||
|
||
|
||
@dataclass | ||
class Consumer: | ||
storage: SubscriptionStorage | ||
|
||
def __post_init__(self) -> None: | ||
settings = conf.get_app_settings() | ||
|
||
self.broker_url: str = str(settings.BROKER_URL) | ||
self.exchange: str = settings.BROKER_EXCHANGE | ||
self.queue: str = settings.BROKER_QUEUE | ||
self.routing_keys_consume_from: list[str] = settings.BROKER_ROUTING_KEYS_CONSUME_FROM | ||
|
||
async def consume(self, stop_signal: asyncio.Future) -> None: | ||
connection = await aio_pika.connect_robust(self.broker_url) | ||
|
||
async with connection: | ||
channel = await connection.channel() | ||
|
||
exchange = await channel.declare_exchange(self.exchange, type=aio_pika.ExchangeType.DIRECT) | ||
queue = await channel.declare_queue(name=self.queue, exclusive=True) | ||
|
||
for routing_key in self.routing_keys_consume_from: | ||
await queue.bind(exchange=exchange, routing_key=routing_key) | ||
|
||
await queue.consume(self.on_message) | ||
|
||
await stop_signal | ||
|
||
async def on_message(self, raw_message: aio_pika.abc.AbstractIncomingMessage) -> None: | ||
async with raw_message.process(): | ||
processed_messages = self.parse_message(raw_message) | ||
|
||
if processed_messages: | ||
self.broadcast_subscribers(self.storage, processed_messages) | ||
|
||
@staticmethod | ||
def parse_message(raw_message: aio_pika.abc.AbstractIncomingMessage) -> ConsumedMessage | None: | ||
try: | ||
return ConsumedMessage.model_validate_json(raw_message.body) | ||
except ValidationError as exc: | ||
logger.error("Consumed message not in expected format. Errors: %s", exc.errors()) | ||
return None | ||
|
||
@staticmethod | ||
def broadcast_subscribers(storage: SubscriptionStorage, message: ConsumedMessage) -> None: | ||
websockets_to_broadcast = storage.get_event_subscribers_websockets(message.event) | ||
|
||
if websockets_to_broadcast: | ||
outgoing_message = OutgoingMessage.model_construct(payload=message) | ||
websockets.broadcast(websockets=websockets_to_broadcast, message=outgoing_message.model_dump_json()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from pydantic import BaseModel | ||
from pydantic import ConfigDict | ||
from typing import Literal | ||
from app.types import Event | ||
|
||
|
||
class ConsumedMessage(BaseModel): | ||
model_config = ConfigDict(extra="allow") | ||
|
||
event: Event | ||
|
||
|
||
class OutgoingMessage(BaseModel): | ||
message_type: Literal["EventNotification"] = "EventNotification" | ||
payload: ConsumedMessage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import pytest | ||
|
||
from consumer.consumer import Consumer | ||
from dataclasses import dataclass | ||
from contextlib import asynccontextmanager | ||
from typing import AsyncGenerator | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def _adjust_settings(settings): | ||
settings.BROKER_URL = "amqp://guest:guest@localhost/" | ||
settings.BROKER_EXCHANGE = "test-exchange" | ||
settings.BROKER_QUEUE = "test-queue" | ||
settings.BROKER_ROUTING_KEYS_CONSUME_FROM = [ | ||
"event-routing-key", | ||
"ya-event-routing-key", | ||
] | ||
|
||
|
||
@pytest.fixture | ||
def consumer(storage) -> Consumer: | ||
return Consumer(storage=storage) | ||
|
||
|
||
@dataclass | ||
class MockedIncomingMessage: | ||
"""The simplest Incoming message class that represent incoming amqp message. | ||
The safer choice is to use 'aio_pika.abc.AbstractIncomingMessage,' but the test setup will be significantly more challenging. | ||
""" | ||
|
||
body: bytes | ||
|
||
@asynccontextmanager | ||
async def process(self) -> AsyncGenerator: | ||
"""Do nothing, just for compatibility with aio_pika.abc.AbstractIncomingMessage.""" | ||
yield None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import asyncio | ||
import pytest | ||
|
||
|
||
@pytest.fixture | ||
def run_consume_task(consumer): | ||
def run(stop_signal: asyncio.Future): | ||
return asyncio.create_task(consumer.consume(stop_signal)) | ||
|
||
return run | ||
|
||
|
||
def test_consumer_attributes(consumer): | ||
assert consumer.broker_url == "amqp://guest:guest@localhost/" | ||
assert consumer.exchange == "test-exchange" | ||
assert consumer.queue == "test-queue" | ||
assert consumer.routing_keys_consume_from == [ | ||
"event-routing-key", | ||
"ya-event-routing-key", | ||
] | ||
|
||
|
||
async def test_consumer_correctly_stopped_on_stop_signal(run_consume_task): | ||
stop_signal = asyncio.get_running_loop().create_future() | ||
consumer_task = run_consume_task(stop_signal) | ||
|
||
stop_signal.set_result(None) | ||
|
||
await asyncio.sleep(0.1) # get enough time to stop the task | ||
assert consumer_task.done() is True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from contextlib import nullcontext as does_not_raise | ||
import json | ||
import pytest | ||
|
||
from consumer.tests.conftest import MockedIncomingMessage | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def mock_broadcast(mocker): | ||
return mocker.patch("websockets.broadcast") | ||
|
||
|
||
def python_to_bytes(data: dict) -> bytes: | ||
return json.dumps(data).encode() | ||
|
||
|
||
@pytest.fixture | ||
def broker_message_data(event): | ||
return { | ||
"event": event, | ||
"size": 3, | ||
"quantity": 2, | ||
} | ||
|
||
|
||
@pytest.fixture | ||
def ya_ws_subscribed(create_ws, ya_valid_token, ws_register_and_subscribe, event): | ||
return ws_register_and_subscribe(create_ws(), ya_valid_token, event) | ||
|
||
|
||
@pytest.fixture | ||
def ya_user_ws_subscribed(create_ws, ya_user_valid_token, ws_register_and_subscribe, event): | ||
return ws_register_and_subscribe(create_ws(), ya_user_valid_token, event) | ||
|
||
|
||
@pytest.fixture | ||
def consumed_message(broker_message_data): | ||
return MockedIncomingMessage(body=python_to_bytes(broker_message_data)) | ||
|
||
|
||
@pytest.fixture | ||
def on_message(consumer, consumed_message): | ||
return lambda message=consumed_message: consumer.on_message(message) | ||
|
||
|
||
async def test_broadcast_message_to_subscriber_websockets(on_message, ws_subscribed, mock_broadcast, mocker): | ||
await on_message() | ||
|
||
mock_broadcast.assert_called_once_with(websockets=[ws_subscribed], message=mocker.ANY) | ||
|
||
|
||
async def test_broadcast_message_to_all_subscribers_websockets(on_message, mock_broadcast, ws_subscribed, ya_ws_subscribed, ya_user_ws_subscribed, mocker): | ||
await on_message() | ||
|
||
mock_broadcast.assert_called_once() | ||
broadcasted_websockets = mock_broadcast.call_args.kwargs["websockets"] | ||
assert len(broadcasted_websockets) == 3 | ||
assert set(broadcasted_websockets) == {ws_subscribed, ya_ws_subscribed, ya_user_ws_subscribed} | ||
|
||
|
||
async def test_log_and_do_nothing_if_message_not_expected_format(on_message, ws_subscribed, mock_broadcast, consumed_message, caplog): | ||
consumed_message.body = b"invalid-json" | ||
|
||
with does_not_raise(): | ||
await on_message(consumed_message) | ||
|
||
assert "Consumed message not in expected format" in caplog.text | ||
mock_broadcast.assert_not_called() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters