Skip to content

Commit

Permalink
WebSocketsAccessGuardian and test fot it
Browse files Browse the repository at this point in the history
  • Loading branch information
nkiryanov committed Jun 8, 2024
1 parent 8a38955 commit 54ec107
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 18 deletions.
13 changes: 8 additions & 5 deletions src/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import websockets

from app import conf
from handlers.websockets_handler import WebSocketsHandler
from handlers import WebSocketsHandler
from handlers import WebSocketsAccessGuardian
from storage.subscription_storage import SubscriptionStorage


Expand All @@ -15,25 +16,27 @@ def create_stop_signal() -> asyncio.Future[None]:
return stop_signal


async def app_runner(settings: conf.Settings, websockets_handler: WebSocketsHandler) -> None:
stop_signal = create_stop_signal()

async def app_runner(settings: conf.Settings, websockets_handler: WebSocketsHandler, access_guardian: WebSocketsAccessGuardian) -> None:
async with websockets.serve(
ws_handler=websockets_handler.websockets_handler,
host=settings.WEBSOCKETS_HOST,
port=settings.WEBSOCKETS_PORT,
):
await stop_signal
await access_guardian.run()


async def main() -> None:
settings = conf.get_app_settings()
stop_signal = create_stop_signal()

storage = SubscriptionStorage()
websockets_handler = WebSocketsHandler(storage=storage)
access_guardian = WebSocketsAccessGuardian(storage=storage, check_interval=60.0, stop_signal=stop_signal)

await app_runner(
settings=settings,
websockets_handler=websockets_handler,
access_guardian=access_guardian,
)


Expand Down
10 changes: 7 additions & 3 deletions src/handlers/tests/tests_websockets_access_guardian.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@ def get_guardian_enough_time_to_do_its_job():


@pytest.fixture
def guardian(storage):
return WebSocketsAccessGuardian(storage=storage, check_interval=0.1)
async def guardian(storage):
stop_signal = asyncio.Future() # it's better to use loop.create_future() but ok for tests

yield WebSocketsAccessGuardian(storage=storage, check_interval=0.1, stop_signal=stop_signal)

stop_signal.set_result(None)


@pytest.fixture(autouse=True)
def guardian_as_task(guardian, event_loop):
return event_loop.create_task(guardian.run_validate_access())
return event_loop.create_task(guardian.run())


@pytest.fixture(autouse=True)
Expand Down
11 changes: 4 additions & 7 deletions src/handlers/websockets_access_guardian.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from dataclasses import dataclass
from dataclasses import dataclass, field
import logging

import websockets
Expand All @@ -15,13 +15,10 @@
class WebSocketsAccessGuardian:
storage: SubscriptionStorage
check_interval: float = 60.0 # in seconds
stop_signal: asyncio.Future[None] | None = None
stop_signal: asyncio.Future[None] = field(default_factory=asyncio.Future) # default feature will run forever

def is_guardian_stopped(self) -> bool:
return self.stop_signal.done() if self.stop_signal else False

async def run_validate_access(self) -> None:
while not self.is_guardian_stopped():
async def run(self) -> None:
while True:
await asyncio.sleep(self.check_interval)

self.monitor_and_manage_access()
Expand Down
20 changes: 17 additions & 3 deletions src/tests/functional/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from websockets import client

from entrypoint import app_runner
from handlers.websockets_handler import WebSocketsHandler
from handlers import WebSocketsHandler
from handlers import WebSocketsAccessGuardian


@pytest.fixture
Expand All @@ -24,18 +25,31 @@ def websockets_handler(storage):
return WebSocketsHandler(storage=storage)


@pytest.fixture
async def stop_signal():
return asyncio.Future()


@pytest.fixture
async def access_guardian(storage, stop_signal):
return WebSocketsAccessGuardian(storage=storage, check_interval=0.5, stop_signal=stop_signal)


@pytest.fixture(autouse=True)
async def serve_app_runner(settings, websockets_handler):
async def serve_app_runner(settings, websockets_handler, access_guardian, stop_signal):
serve_task = asyncio.get_running_loop().create_task(
app_runner(
settings=settings,
websockets_handler=websockets_handler,
access_guardian=access_guardian,
),
)

await asyncio.sleep(0.1) # give enough time to start the server
assert serve_task.done() is False # be sure server is running
return serve_task
yield serve_task

stop_signal.set_result(None)


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import asyncio
import pytest

pytestmark = [
pytest.mark.slow,
]


@pytest.fixture
def set_storage_connections_expired(storage):
def set_expired():
for _, websocket_meta in storage.registered_websockets.items():
websocket_meta.expiration_timestamp = 1000 # far in the past

return set_expired


async def test_expired_connections_removed_from_active_connections(ws_client_authenticated, ws_client_recv_decoded, set_storage_connections_expired):
set_storage_connections_expired()
await asyncio.sleep(1.1) # give enough time to validator to do its job

received = await ws_client_recv_decoded(ws_client_authenticated)

assert len(received) == 2
assert received["message_type"] == "ErrorResponse"
assert received["errors"] == ["Token expired, user subscriptions disabled or removed"]

0 comments on commit 54ec107

Please sign in to comment.