Skip to content

Commit

Permalink
Simplest possible consumer (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkiryanov authored Jun 9, 2024
1 parent f14f2e5 commit df755ed
Show file tree
Hide file tree
Showing 12 changed files with 257 additions and 20 deletions.
3 changes: 2 additions & 1 deletion env.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ DEBUG=False

BROKER_URL=amqp://guest:guest@localhost:5672/
BROKER_EXCHANGE=guest-exchange
BROKER_ROUTING_KEYS=["test-create-event-key", "test-notification-event-key"]
BROKER_QUEUE=websockets-notifications-queue
BROKER_ROUTING_KEYS_CONSUME_FROM=["test-event-boobs"]

WEBSOCKETS_HOST=localhost
WEBSOCKETS_PORT=6789
Expand Down
3 changes: 2 additions & 1 deletion src/app/conf/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
class Settings(BaseSettings):
BROKER_URL: AmqpDsn
BROKER_EXCHANGE: str
BROKER_ROUTING_KEYS: list[str]
BROKER_QUEUE: str
BROKER_ROUTING_KEYS_CONSUME_FROM: list[str]
WEBSOCKETS_HOST: str
WEBSOCKETS_PORT: int
WEBSOCKETS_PATH: str
Expand Down
13 changes: 9 additions & 4 deletions src/app/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@


@pytest.fixture
def ws():
return MockedWebSocketServerProtocol()
def create_ws():
return lambda: MockedWebSocketServerProtocol()


@pytest.fixture
def ya_ws():
return MockedWebSocketServerProtocol()
def ws(create_ws):
return create_ws()


@pytest.fixture
def ya_ws(create_ws):
return create_ws()
3 changes: 2 additions & 1 deletion src/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def settings(mocker):
return_value=Settings(
BROKER_URL="amqp://guest:guest@localhost/",
BROKER_EXCHANGE="test-exchange",
BROKER_ROUTING_KEYS=["test-routing-key", "ya-test-routing-key"],
BROKER_QUEUE="test-queue",
BROKER_ROUTING_KEYS_CONSUME_FROM=["test-routing-key", "ya-test-routing-key"],
WEBSOCKETS_HOST="localhost",
WEBSOCKETS_PORT=50000,
WEBSOCKETS_PATH="/v2/test-subscription-websocket",
Expand Down
73 changes: 73 additions & 0 deletions src/consumer/consumer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import asyncio
from dataclasses import dataclass
import logging
from typing import Protocol

import aio_pika

from app import conf
from consumer.dto import ConsumedMessage
from consumer.dto import OutgoingMessage
from storage.subscription_storage import SubscriptionStorage
from pydantic import ValidationError
import websockets


logger = logging.getLogger(__name__)


class ConsumerProtocol(Protocol):
async def consume(self) -> None:
pass


@dataclass
class Consumer:
storage: SubscriptionStorage

def __post_init__(self) -> None:
settings = conf.get_app_settings()

self.broker_url: str = str(settings.BROKER_URL)
self.exchange: str = settings.BROKER_EXCHANGE
self.queue: str = settings.BROKER_QUEUE
self.routing_keys_consume_from: list[str] = settings.BROKER_ROUTING_KEYS_CONSUME_FROM

async def consume(self, stop_signal: asyncio.Future) -> None:
connection = await aio_pika.connect_robust(self.broker_url)

async with connection:
channel = await connection.channel()

exchange = await channel.declare_exchange(self.exchange, type=aio_pika.ExchangeType.DIRECT)
queue = await channel.declare_queue(name=self.queue, exclusive=True)

for routing_key in self.routing_keys_consume_from:
await queue.bind(exchange=exchange, routing_key=routing_key)

await queue.consume(self.on_message)

await stop_signal

async def on_message(self, raw_message: aio_pika.abc.AbstractIncomingMessage) -> None:
async with raw_message.process():
processed_messages = self.parse_message(raw_message)

if processed_messages:
self.broadcast_subscribers(self.storage, processed_messages)

@staticmethod
def parse_message(raw_message: aio_pika.abc.AbstractIncomingMessage) -> ConsumedMessage | None:
try:
return ConsumedMessage.model_validate_json(raw_message.body)
except ValidationError as exc:
logger.error("Consumed message not in expected format. Errors: %s", exc.errors())
return None

@staticmethod
def broadcast_subscribers(storage: SubscriptionStorage, message: ConsumedMessage) -> None:
websockets_to_broadcast = storage.get_event_subscribers_websockets(message.event)

if websockets_to_broadcast:
outgoing_message = OutgoingMessage.model_construct(payload=message)
websockets.broadcast(websockets=websockets_to_broadcast, message=outgoing_message.model_dump_json())
15 changes: 15 additions & 0 deletions src/consumer/dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from pydantic import BaseModel
from pydantic import ConfigDict
from typing import Literal
from app.types import Event


class ConsumedMessage(BaseModel):
model_config = ConfigDict(extra="allow")

event: Event


class OutgoingMessage(BaseModel):
message_type: Literal["EventNotification"] = "EventNotification"
payload: ConsumedMessage
37 changes: 37 additions & 0 deletions src/consumer/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from consumer.consumer import Consumer
from dataclasses import dataclass
from contextlib import asynccontextmanager
from typing import AsyncGenerator


@pytest.fixture(autouse=True)
def _adjust_settings(settings):
settings.BROKER_URL = "amqp://guest:guest@localhost/"
settings.BROKER_EXCHANGE = "test-exchange"
settings.BROKER_QUEUE = "test-queue"
settings.BROKER_ROUTING_KEYS_CONSUME_FROM = [
"event-routing-key",
"ya-event-routing-key",
]


@pytest.fixture
def consumer(storage) -> Consumer:
return Consumer(storage=storage)


@dataclass
class MockedIncomingMessage:
"""The simplest Incoming message class that represent incoming amqp message.
The safer choice is to use 'aio_pika.abc.AbstractIncomingMessage,' but the test setup will be significantly more challenging.
"""

body: bytes

@asynccontextmanager
async def process(self) -> AsyncGenerator:
"""Do nothing, just for compatibility with aio_pika.abc.AbstractIncomingMessage."""
yield None
30 changes: 30 additions & 0 deletions src/consumer/tests/tests_consumer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import asyncio
import pytest


@pytest.fixture
def run_consume_task(consumer):
def run(stop_signal: asyncio.Future):
return asyncio.create_task(consumer.consume(stop_signal))

return run


def test_consumer_attributes(consumer):
assert consumer.broker_url == "amqp://guest:guest@localhost/"
assert consumer.exchange == "test-exchange"
assert consumer.queue == "test-queue"
assert consumer.routing_keys_consume_from == [
"event-routing-key",
"ya-event-routing-key",
]


async def test_consumer_correctly_stopped_on_stop_signal(run_consume_task):
stop_signal = asyncio.get_running_loop().create_future()
consumer_task = run_consume_task(stop_signal)

stop_signal.set_result(None)

await asyncio.sleep(0.1) # get enough time to stop the task
assert consumer_task.done() is True
68 changes: 68 additions & 0 deletions src/consumer/tests/tests_consumer_on_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from contextlib import nullcontext as does_not_raise
import json
import pytest

from consumer.tests.conftest import MockedIncomingMessage


@pytest.fixture(autouse=True)
def mock_broadcast(mocker):
return mocker.patch("websockets.broadcast")


def python_to_bytes(data: dict) -> bytes:
return json.dumps(data).encode()


@pytest.fixture
def broker_message_data(event):
return {
"event": event,
"size": 3,
"quantity": 2,
}


@pytest.fixture
def ya_ws_subscribed(create_ws, ya_valid_token, ws_register_and_subscribe, event):
return ws_register_and_subscribe(create_ws(), ya_valid_token, event)


@pytest.fixture
def ya_user_ws_subscribed(create_ws, ya_user_valid_token, ws_register_and_subscribe, event):
return ws_register_and_subscribe(create_ws(), ya_user_valid_token, event)


@pytest.fixture
def consumed_message(broker_message_data):
return MockedIncomingMessage(body=python_to_bytes(broker_message_data))


@pytest.fixture
def on_message(consumer, consumed_message):
return lambda message=consumed_message: consumer.on_message(message)


async def test_broadcast_message_to_subscriber_websockets(on_message, ws_subscribed, mock_broadcast, mocker):
await on_message()

mock_broadcast.assert_called_once_with(websockets=[ws_subscribed], message=mocker.ANY)


async def test_broadcast_message_to_all_subscribers_websockets(on_message, mock_broadcast, ws_subscribed, ya_ws_subscribed, ya_user_ws_subscribed, mocker):
await on_message()

mock_broadcast.assert_called_once()
broadcasted_websockets = mock_broadcast.call_args.kwargs["websockets"]
assert len(broadcasted_websockets) == 3
assert set(broadcasted_websockets) == {ws_subscribed, ya_ws_subscribed, ya_user_ws_subscribed}


async def test_log_and_do_nothing_if_message_not_expected_format(on_message, ws_subscribed, mock_broadcast, consumed_message, caplog):
consumed_message.body = b"invalid-json"

with does_not_raise():
await on_message(consumed_message)

assert "Consumed message not in expected format" in caplog.text
mock_broadcast.assert_not_called()
3 changes: 3 additions & 0 deletions src/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import signal

import websockets
import logging

from app import conf
from handlers import WebSocketsHandler
from handlers import WebSocketsAccessGuardian
from storage.subscription_storage import SubscriptionStorage

logging.basicConfig(level=logging.INFO)


def create_stop_signal() -> asyncio.Future[None]:
loop = asyncio.get_running_loop()
Expand Down
10 changes: 10 additions & 0 deletions src/storage/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,13 @@ def subscribe(ws, event):
def ws_subscribed(ws_registered, subscribe_ws, event):
subscribe_ws(ws_registered, event)
return ws_registered


@pytest.fixture
def ws_register_and_subscribe(register_ws, subscribe_ws):
def register_and_subscribe(ws, token, event):
register_ws(ws, token)
subscribe_ws(ws, event)
return ws

return register_and_subscribe
19 changes: 6 additions & 13 deletions src/storage/subscription_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,15 @@ def get_websocket_user_id(self, websocket: WebSocketServerProtocol) -> UserId |
websocket_meta = self.registered_websockets.get(websocket)
return websocket_meta.user_id if websocket_meta else None

def get_event_subscribers_user_ids(self, event: Event) -> set[UserId]:
return self.subscriptions.get(event) or set()
def get_event_subscribers_websockets(self, event: Event) -> list[WebSocketServerProtocol]:
subscribers_user_ids = self.subscriptions.get(event) or set()

def is_event_has_subscribers(self, event: Event) -> bool:
return event in self.subscriptions
user_websockets = []

def get_users_websockets(self, user_ids: set[UserId]) -> list[WebSocketServerProtocol]:
users_websockets = []
for user_id in subscribers_user_ids:
user_websockets.extend(self.user_connections[user_id].websockets)

for user_id in user_ids:
user_connection_meta = self.user_connections.get(user_id)

if user_connection_meta:
users_websockets.extend(user_connection_meta.websockets)

return users_websockets
return user_websockets

def get_expired_websockets(self) -> list[WebSocketServerProtocol]:
now_timestamp = time.time()
Expand Down

0 comments on commit df755ed

Please sign in to comment.