diff --git a/docs/source/conf.py b/docs/source/conf.py index 15c490c..a990ae5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -33,6 +33,6 @@ html_static_path = ['_static'] html_theme_options = { - 'page_width': '1200px', # Set this to your desired width - 'sidebar_width': '400px', # Adjust the sidebar width as well + 'page_width': '1300px', # Set this to your desired width + 'sidebar_width': '450px', # Adjust the sidebar width as well } diff --git a/picows/picows.pxd b/picows/picows.pxd index 266c9f4..60d6cc9 100644 --- a/picows/picows.pxd +++ b/picows/picows.pxd @@ -83,6 +83,8 @@ cdef class WSTransport: readonly bint is_client_side readonly bint is_secure + bint auto_ping_expect_pong + object _logger #: Logger bint _log_debug_enabled object _disconnected_future #: asyncio.Future @@ -95,6 +97,7 @@ cdef class WSTransport: cpdef send_pong(self, message=*) cpdef send_close(self, WSCloseCode close_code=*, close_message=*) cpdef disconnect(self) + cpdef notify_user_specific_pong_received(self) cdef inline _send_http_handshake(self, bytes ws_path, bytes host_port, bytes websocket_key_b64) cdef inline _send_http_handshake_response(self, bytes accept_val) @@ -109,6 +112,10 @@ cdef class WSListener: cpdef on_ws_connected(self, WSTransport transport) cpdef on_ws_frame(self, WSTransport transport, WSFrame frame) cpdef on_ws_disconnected(self, WSTransport transport) + + cpdef send_user_specific_ping(self, WSTransport transport) + cpdef is_user_specific_pong(self, WSFrame frame) + cpdef pause_writing(self) cpdef resume_writing(self) diff --git a/picows/picows.pyx b/picows/picows.pyx index d4b2741..893bd32 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -9,18 +9,18 @@ import struct import urllib.parse from ssl import SSLContext from typing import cast, Tuple, Optional, Callable + from multidict import CIMultiDict cimport cython - from cpython.bytes cimport PyBytes_GET_SIZE, PyBytes_AS_STRING, PyBytes_FromStringAndSize, PyBytes_CheckExact from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_GET_SIZE, PyByteArray_CheckExact from cpython.memoryview cimport PyMemoryView_FromMemory from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free from cpython.buffer cimport PyBUF_WRITE, PyBUF_READ, PyBUF_SIMPLE, PyObject_GetBuffer, PyBuffer_Release from cpython.unicode cimport PyUnicode_FromStringAndSize, PyUnicode_DecodeASCII -from libc cimport errno +from libc cimport errno from libc.string cimport memmove, memcpy, strerror from libc.stdlib cimport rand @@ -49,7 +49,8 @@ cdef extern from "picows_compat.h" nogil: cdef ssize_t PICOWS_SOCKET_ERROR int picows_get_errno() - ssize_t send(int sockfd, const void* buf, size_t len, int flags); + double picows_get_monotonic_time() + ssize_t send(int sockfd, const void* buf, size_t len, int flags) class WSError(RuntimeError): @@ -160,7 +161,6 @@ cdef class WSFrame: """ Received websocket frame. - Internally WSFrame just points to a chunk of memory in the receiving buffer without copying or owning memory.\n .. DANGER:: Only use WSFrame object during :any:`WSListener.on_ws_frame` callback. WSFrame objects are essentially just pointers to the underlying receiving buffer. After :any:`WSListener.on_ws_frame` has completed the buffer @@ -322,7 +322,7 @@ cdef class WSListener: WSFrame is essentially just a pointer to a chunk of memory in the receiving buffer. It does not own the memory. Do NOT cache or store WSFrame object for later processing because the data may be invalidated after :any:`WSListener.on_ws_frame` is complete. - Process the payload immediatelly or just copy it with one of `WSFrame.get_*` methods. + Process the payload immediately or just copy it with one of `WSFrame.get_*` methods. :param transport: :any:`WSTransport` object :param frame: :any:`WSFrame` object @@ -337,6 +337,42 @@ cdef class WSListener: """ pass + cpdef send_user_specific_ping(self, WSTransport transport): + """ + Called when the auto-ping logic wants to send a ping to a remote peer. + + User can override this method to send something else instead of + the standard PING frame. + + Default implementation: + + .. code:: python + + def send_user_specific_ping(self, transport: picows.WSTransport) + return transport.send_ping() + + :param transport: :any:`WSTransport` + """ + transport.send_ping() + + cpdef is_user_specific_pong(self, WSFrame frame): + """ + Called before :any:`WSListener.on_ws_frame` if auto ping is enabled and pong is expected. + + User can override this method to indicate that the received frame is a + valid response to a previously sent user specific ping message. + + The default implementation just do: + + .. code:: python + + def is_user_specific_pong(self, frame: picows.WSFrame) + return frame.msg_type == WSMsgType.PONG + + :return: Returns True if the frame is a response to a previously send ping. In such case the frame will be *consumed* by the protocol, i.e :any:`WSListener.on_ws_frame` will not be called for this frame. + """ + return frame.msg_type == WSMsgType.PONG + cpdef pause_writing(self): """ Called when the underlying transport’s buffer goes over the high watermark. @@ -355,6 +391,7 @@ cdef class WSTransport: self.underlying_transport = underlying_transport self.is_client_side = is_client_side self.is_secure = underlying_transport.get_extra_info('ssl_object') is not None + self.auto_ping_expect_pong = False self._logger = logger self._log_debug_enabled = self._logger.isEnabledFor(PICOWS_DEBUG_LL) self._disconnected_future = loop.create_future() @@ -532,6 +569,24 @@ cdef class WSTransport: """ await asyncio.shield(self._disconnected_future) + cpdef notify_user_specific_pong_received(self): + """ + Notify the auto-ping loop that a user-specific pong message + has been received. + + This method is useful when determining whether a frame contains a + user-specific pong is too expensive for is_user_specific_pong + (for example, it may require full JSON parsing). + In such cases, :any:`WSListener.is_user_specific_pong` should always + return `False`, and the logic in :any:`WSListener.on_ws_frame` should + call :any:`WSTransport.notify_user_specific_pong_received`. + + It is safe to call this method even if auto-ping is disabled or + the auto-ping loop doesn’t expect pong messages. + In such cases, the method simply does nothing. + """ + self.auto_ping_expect_pong = False + cdef _send_http_handshake(self, bytes ws_path, bytes host_port, bytes websocket_key_b64): initial_handshake = (b"GET %b HTTP/1.1\r\n" b"Host: %b\r\n" @@ -645,6 +700,12 @@ cdef class WSProtocol: bytes _websocket_key_b64 size_t _max_frame_size + bint _enable_auto_ping + double _auto_ping_idle_timeout + double _auto_ping_reply_timeout + object _auto_ping_loop_task + double _last_data_time + # The following are the parts of an unfinished frame # Once the frame is finished WSFrame is created and returned WSParserState _state @@ -662,7 +723,8 @@ cdef class WSProtocol: uint8_t _f_payload_length_flag def __init__(self, str host_port, str ws_path, bint is_client_side, ws_listener_factory, str logger_name, - bint disconnect_on_exception, websocket_handshake_timeout): + bint disconnect_on_exception, websocket_handshake_timeout, + enable_auto_ping, auto_ping_idle_timeout, auto_ping_reply_timeout): self.transport = None self.listener = None @@ -684,6 +746,16 @@ cdef class WSProtocol: self._websocket_key_b64 = b64encode(os.urandom(16)) self._max_frame_size = 1024 * 1024 + self._enable_auto_ping = enable_auto_ping + self._auto_ping_idle_timeout = auto_ping_idle_timeout + self._auto_ping_reply_timeout = auto_ping_reply_timeout + self._auto_ping_loop_task = None + self._last_data_time = 0 + + if self._enable_auto_ping: + assert self._auto_ping_reply_timeout <= self._auto_ping_idle_timeout, \ + "auto_ping_reply_timeout can't be bigger than auto_ping_idle_timeout" + self._state = WSParserState.WAIT_UPGRADE_RESPONSE self._buffer = MemoryBuffer() self._f_new_data_start_pos = 0 @@ -749,6 +821,9 @@ cdef class WSProtocol: if self._handshake_timeout_handle is not None: self._handshake_timeout_handle.cancel() + if self._auto_ping_loop_task is not None and not self._auto_ping_loop_task.done(): + self._auto_ping_loop_task.cancel() + self.transport._mark_disconnected() def eof_received(self) -> bool: @@ -827,6 +902,8 @@ cdef class WSProtocol: if not self._negotiate(): return + self._last_data_time = picows_get_monotonic_time() + cdef WSFrame frame = self._get_next_frame() if frame is None: return @@ -894,8 +971,51 @@ cdef class WSProtocol: self._handshake_timeout_handle = None self._handshake_complete_future.set_result(None) self._invoke_on_ws_connected() + self._last_data_time = picows_get_monotonic_time() + if self._enable_auto_ping: + self._auto_ping_loop_task = self._loop.create_task(self._auto_ping_loop()) return True + async def _auto_ping_loop(self): + cdef double now + cdef double prev_last_data_time + cdef double idle_delay + cdef object sleep = asyncio.sleep + try: + if self._log_debug_enabled: + self._logger.log(PICOWS_DEBUG_LL, "Auto-ping loop started with idle_timeout=%s, reply_timeout=%s", + self._auto_ping_idle_timeout, self._auto_ping_reply_timeout) + + while True: + now = picows_get_monotonic_time() + idle_delay = self._last_data_time + self._auto_ping_idle_timeout - now + prev_last_data_time = self._last_data_time + await sleep(idle_delay) + + if self._last_data_time > prev_last_data_time: + continue + + self.listener.send_user_specific_ping(self.transport) + + self.transport.auto_ping_expect_pong = True + await sleep(self._auto_ping_reply_timeout) + if self.transport.auto_ping_expect_pong: + # Pong hasn't arrived withing specified interval + self.transport.send_close(WSCloseCode.GOING_AWAY, f"peer has not replied to ping/heartbeat request within {self._auto_ping_reply_timeout} second(s)".encode()) + # Give a chance for the transport to send close message + # But don't wait for any tcp confirmation, use abort() + # because normal disconnect may hang until OS TCP/IP timeout + # for ACK is fired. + await sleep(0.01) + self.transport.underlying_transport.abort() + except asyncio.CancelledError: + if self._log_debug_enabled: + self._logger.log(PICOWS_DEBUG_LL, "Auto-ping loop cancelled") + except: + self._logger.exception("Auto-ping loop failed, disconnect websocket") + self.transport.send_close(WSCloseCode.INTERNAL_ERROR, b"an exception occurred in auto-ping loop") + self.transport.disconnect() + cdef inline tuple _try_read_upgrade_request(self): cdef bytes data = PyBytes_FromStringAndSize(self._buffer.data, self._f_new_data_start_pos) cdef list request = data.split(b"\r\n\r\n", 1) @@ -1153,6 +1273,13 @@ cdef class WSProtocol: cdef inline _invoke_on_ws_frame(self, WSFrame frame): try: + if self._enable_auto_ping and self.transport.auto_ping_expect_pong: + if self.listener.is_user_specific_pong(frame): + self.transport.auto_ping_expect_pong = False + if self._log_debug_enabled: + self._logger.log(PICOWS_DEBUG_LL, "Received pong for the previously sent ping, reset expect_pong flag") + return + self.listener.on_ws_frame(self.transport, frame) except Exception as e: if self._disconnect_on_exception: @@ -1193,6 +1320,9 @@ async def ws_connect(ws_listener_factory: Callable[[], WSListener], bint disconnect_on_exception: bool=True, websocket_handshake_timeout=5, logger_name: str="client", + enable_auto_ping: bool = False, + auto_ping_idle_timeout: float = 10, + auto_ping_reply_timeout: float = 10, **kwargs ) -> Tuple[WSTransport, WSListener]: """ @@ -1212,6 +1342,18 @@ async def ws_connect(ws_listener_factory: Callable[[], WSListener], is the time in seconds to wait for the websocket client to receive websocket handshake response before aborting the connection. :param logger_name: picows will use `picows.` logger to do all the logging. + :param enable_auto_ping: + Enable detection of a stale connection by periodically pinging remote peer. + + .. note:: + This does NOT enable automatic replies to incoming `ping` requests. + Library user is always supposed to explicitly implement replies + to incoming `ping` requests in `WSListener.on_ws_frame` + :param auto_ping_idle_timeout: + how long to wait before sending `ping` request when there is no + incoming data. + :param auto_ping_reply_timeout: + how long to wait for a `pong` reply before shutting down connection. :return: :any:`WSTransport` object and a user handler returned by `ws_listener_factory()` """ @@ -1234,7 +1376,8 @@ async def ws_connect(ws_listener_factory: Callable[[], WSListener], if url_parts.query: path_plus_query += "?" + url_parts.query ws_protocol_factory = lambda: WSProtocol(url_parts.netloc, path_plus_query, True, ws_listener_factory, - logger_name, disconnect_on_exception, websocket_handshake_timeout) + logger_name, disconnect_on_exception, websocket_handshake_timeout, + enable_auto_ping, auto_ping_idle_timeout, auto_ping_reply_timeout) cdef WSProtocol ws_protocol @@ -1253,6 +1396,9 @@ async def ws_create_server(ws_listener_factory: Callable[[WSUpgradeRequest], Opt bint disconnect_on_exception: bool=True, websocket_handshake_timeout=5, str logger_name: str="server", + enable_auto_ping: bool = False, + auto_ping_idle_timeout: float = 20, + auto_ping_reply_timeout: float = 20, **kwargs ) -> asyncio.Server: """ @@ -1291,10 +1437,23 @@ async def ws_create_server(ws_listener_factory: Callable[[WSUpgradeRequest], Opt is the time in seconds to wait for the websocket server to receive websocket handshake request before aborting the connection. :param logger_name: picows will use `picows.` logger to do all the logging. + :param enable_auto_ping: + Enable detection of a stale connection by periodically pinging remote peer. + + .. note:: + This does NOT enable automatic replies to incoming `ping` requests. + Library user is always supposed to explicitly implement replies + to incoming `ping` requests in `WSListener.on_ws_frame` + :param auto_ping_idle_timeout: + how long to wait before sending `ping` request when there is no + incoming data. + :param auto_ping_reply_timeout: + how long to wait for a `pong` reply before shutting down connection. :return: `asyncio.Server `_ object """ ws_protocol_factory = lambda: WSProtocol(None, None, False, ws_listener_factory, logger_name, - disconnect_on_exception, websocket_handshake_timeout) + disconnect_on_exception, websocket_handshake_timeout, + enable_auto_ping, auto_ping_idle_timeout, auto_ping_reply_timeout) return await asyncio.get_running_loop().create_server( ws_protocol_factory, diff --git a/picows/picows_compat.h b/picows/picows_compat.h index 4b856b8..88e253d 100644 --- a/picows/picows_compat.h +++ b/picows/picows_compat.h @@ -62,7 +62,7 @@ #include #define PICOWS_SOCKET_ERROR SOCKET_ERROR - inline int picows_convert_wsa_error_to_errno(int ec) + static inline int picows_convert_wsa_error_to_errno(int ec) { switch(ec) { @@ -93,16 +93,33 @@ } } - inline int picows_get_errno(void) + static inline int picows_get_errno(void) { return picows_convert_wsa_error_to_errno(WSAGetLastError()); } + + static inline double picows_get_monotonic_time(void) + { + LARGE_INTEGER frequency, counter; + QueryPerformanceFrequency(&frequency); + QueryPerformanceCounter(&counter); + return (double)counter.QuadPart / frequency.QuadPart; + } #else #include #include + #include + #define PICOWS_SOCKET_ERROR -1 #define PICOWS_EAGAIN EAGAIN #define PICOWS_EWOULDBLOCK EWOULDBLOCK - inline int picows_get_errno(void) { return errno; } + static inline int picows_get_errno(void) { return errno; } + + static inline double picows_get_monotonic_time(void) + { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (double)ts.tv_sec + (double)ts.tv_nsec * 1e-9; + } #endif diff --git a/tests/test_autoping.py b/tests/test_autoping.py new file mode 100644 index 0000000..b111cf6 --- /dev/null +++ b/tests/test_autoping.py @@ -0,0 +1,253 @@ +import asyncio +from idlelib.pyparse import trans + +import async_timeout +from aiohttp import WSMsgType + +import picows +from tests.utils import ServerAsyncContext, TIMEOUT, TextFrame, CloseFrame, \ + BinaryFrame, materialize_frame + + +class AccumulatingListener(picows.WSListener): + def __init__(self): + self.frames = [] + + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + self.frames.append(materialize_frame(frame)) + + +class AccumulatingServerListener(picows.WSListener): + def __init__(self, server_frames): + self.frames = server_frames + + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + self.frames.append(materialize_frame(frame)) + + + +async def test_ping_pong(): + server = await picows.ws_create_server(lambda _: picows.WSListener(), + "127.0.0.1", 0, + enable_auto_ping=True, + auto_ping_idle_timeout=0.1, + auto_ping_reply_timeout=0.1) + + class ClientListener(picows.WSListener): + def __init__(self): + self.ping_count = 0 + + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.PING: + self.ping_count += 1 + transport.send_pong() + + if self.ping_count == 3: + transport.disconnect() + + async with ServerAsyncContext(server): + url = f"ws://127.0.0.1:{server.sockets[0].getsockname()[1]}" + (transport, listener) = await picows.ws_connect(ClientListener, url) + async with async_timeout.timeout(TIMEOUT): + await transport.wait_disconnected() + assert listener.ping_count == 3 + + +async def test_custom_ping_consume_pong(): + server_frames = [] + + class ServerClientListener(AccumulatingServerListener): + def send_user_specific_ping(self, transport: picows.WSTransport): + transport.send(WSMsgType.TEXT, b"ping") + + def is_user_specific_pong(self, frame: picows.WSFrame): + return frame.get_payload_as_memoryview() == b"pong" + + server = await picows.ws_create_server(lambda _: ServerClientListener(server_frames), + "127.0.0.1", 0, + enable_auto_ping=True, + auto_ping_idle_timeout=0.1, + auto_ping_reply_timeout=0.1) + + class ClientListener(picows.WSListener): + def __init__(self): + self.ping_count = 0 + + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.TEXT and frame.get_payload_as_memoryview() == b"ping": + self.ping_count += 1 + transport.send(picows.WSMsgType.TEXT, b"pong") + + if self.ping_count == 3: + transport.disconnect() + + async with ServerAsyncContext(server): + url = f"ws://127.0.0.1:{server.sockets[0].getsockname()[1]}" + (transport, listener) = await picows.ws_connect(ClientListener, url) + async with async_timeout.timeout(TIMEOUT): + transport.send(WSMsgType.TEXT, b"hello") + await transport.wait_disconnected() + + assert listener.ping_count == 3 + assert len(server_frames) == 1 + assert server_frames[0].payload_as_ascii_text == "hello" + + +async def test_custom_ping_notify_pong(): + server_frames = [] + + class ServerClientListener(AccumulatingServerListener): + def send_user_specific_ping(self, transport: picows.WSTransport): + transport.send(WSMsgType.TEXT, b"ping") + + def is_user_specific_pong(self, frame: picows.WSFrame): + return False + + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.get_payload_as_memoryview() == b"pong": + transport.notify_user_specific_pong_received() + return + + super().on_ws_frame(transport, frame) + + server = await picows.ws_create_server(lambda _: ServerClientListener(server_frames), + "127.0.0.1", 0, + enable_auto_ping=True, + auto_ping_idle_timeout=0.1, + auto_ping_reply_timeout=0.1) + + class ClientListener(picows.WSListener): + def __init__(self): + self.ping_count = 0 + + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.TEXT and frame.get_payload_as_memoryview() == b"ping": + self.ping_count += 1 + transport.send(picows.WSMsgType.TEXT, b"pong") + + if self.ping_count == 3: + transport.disconnect() + + async with ServerAsyncContext(server): + url = f"ws://127.0.0.1:{server.sockets[0].getsockname()[1]}" + (transport, listener) = await picows.ws_connect(ClientListener, url) + async with async_timeout.timeout(TIMEOUT): + transport.send(WSMsgType.TEXT, b"hello") + await transport.wait_disconnected() + + assert listener.ping_count == 3 + assert len(server_frames) == 1 + assert server_frames[0].payload_as_ascii_text == "hello" + + +async def test_no_pong_reply(): + server = await picows.ws_create_server(lambda _: picows.WSListener(), + "127.0.0.1", 0, + enable_auto_ping=True, + auto_ping_idle_timeout=0.1, + auto_ping_reply_timeout=0.1) + + async with ServerAsyncContext(server): + url = f"ws://127.0.0.1:{server.sockets[0].getsockname()[1]}" + (transport, listener) = await picows.ws_connect(AccumulatingListener, url) + async with async_timeout.timeout(TIMEOUT): + await transport.wait_disconnected() + + assert len(listener.frames) == 2 + assert listener.frames[0].msg_type == picows.WSMsgType.PING + assert listener.frames[1].msg_type == picows.WSMsgType.CLOSE + assert listener.frames[1].close_code == picows.WSCloseCode.GOING_AWAY + + +async def test_no_ping_when_data_is_present(): + server = await picows.ws_create_server(lambda _: picows.WSListener(), + "127.0.0.1", 0, + enable_auto_ping=True, + auto_ping_idle_timeout=0.1, + auto_ping_reply_timeout=0.1) + + async with ServerAsyncContext(server): + url = f"ws://127.0.0.1:{server.sockets[0].getsockname()[1]}" + (transport, listener) = await picows.ws_connect(AccumulatingListener, url) + async with async_timeout.timeout(TIMEOUT): + for i in range(5): + await asyncio.sleep(0.05) + transport.send(picows.WSMsgType.TEXT, b"hi") + + transport.disconnect() + await transport.wait_disconnected() + + assert len(listener.frames) == 0 + + +async def test_consume_pong_when_data_is_present(): + server = await picows.ws_create_server(lambda _: picows.WSListener(), + "127.0.0.1", 0, + enable_auto_ping=True, + auto_ping_idle_timeout=0.1, + auto_ping_reply_timeout=0.1) + + async with ServerAsyncContext(server): + url = f"ws://127.0.0.1:{server.sockets[0].getsockname()[1]}" + (transport, listener) = await picows.ws_connect(AccumulatingListener, url) + async with async_timeout.timeout(TIMEOUT): + for i in range(5): + await asyncio.sleep(0.05) + transport.send(picows.WSMsgType.TEXT, b"hi") + + transport.disconnect() + await transport.wait_disconnected() + + assert len(listener.frames) == 0 + + +async def test_send_user_specific_ping_exception(): + class ServerClientListener(picows.WSListener): + def send_user_specific_ping(self, transport: picows.WSTransport): + raise RuntimeError("failed") + + server = await picows.ws_create_server(lambda _: ServerClientListener(), + "127.0.0.1", 0, + enable_auto_ping=True, + auto_ping_idle_timeout=0.1, + auto_ping_reply_timeout=0.1) + + async with ServerAsyncContext(server): + url = f"ws://127.0.0.1:{server.sockets[0].getsockname()[1]}" + (transport, listener) = await picows.ws_connect(AccumulatingListener, url) + async with async_timeout.timeout(TIMEOUT): + await transport.wait_disconnected() + + assert len(listener.frames) == 1 + assert isinstance(listener.frames[0], CloseFrame) + assert listener.frames[0].close_code == picows.WSCloseCode.INTERNAL_ERROR + + +async def test_is_user_specific_pong_exception(): + class ServerClientListener(picows.WSListener): + def is_user_specific_pong(self, transport: picows.WSTransport): + raise RuntimeError("failed") + + server = await picows.ws_create_server(lambda _: ServerClientListener(), + "127.0.0.1", 0, + enable_auto_ping=True, + auto_ping_idle_timeout=0.1, + auto_ping_reply_timeout=0.1) + + class ClientListener(AccumulatingListener): + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.PING: + transport.send_pong(frame.get_payload_as_bytes()) + + super().on_ws_frame(transport, frame) + + async with ServerAsyncContext(server): + url = f"ws://127.0.0.1:{server.sockets[0].getsockname()[1]}" + (transport, listener) = await picows.ws_connect(ClientListener, url) + async with async_timeout.timeout(TIMEOUT): + await transport.wait_disconnected() + + assert len(listener.frames) == 2 + assert listener.frames[0].msg_type == picows.WSMsgType.PING + assert listener.frames[1].msg_type == picows.WSMsgType.CLOSE + assert listener.frames[1].close_code == picows.WSCloseCode.INTERNAL_ERROR diff --git a/tests/test_basics.py b/tests/test_basics.py index 1612e54..987eb10 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -2,17 +2,14 @@ import base64 import os import sys -import pathlib -import ssl -import socket import picows import pytest import async_timeout - -TIMEOUT = 0.5 - +from tests.utils import create_client_ssl_context, create_server_ssl_context, \ + TextFrame, CloseFrame, BinaryFrame, ServerAsyncContext, TIMEOUT, \ + materialize_frame if os.name == 'nt': @pytest.fixture( @@ -43,68 +40,6 @@ def event_loop_policy(request): assert False, "unknown loop" -def create_server_ssl_context(): - ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - ssl_context.load_cert_chain(pathlib.Path(__file__).parent / "picows_test.crt", - pathlib.Path(__file__).parent / "picows_test.key") - ssl_context.check_hostname = False - ssl_context.hostname_checks_common_name = False - ssl_context.verify_mode = ssl.CERT_NONE - return ssl_context - - -def create_client_ssl_context(): - ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) - ssl_context.load_default_certs(ssl.Purpose.SERVER_AUTH) - ssl_context.check_hostname = False - ssl_context.hostname_checks_common_name = False - ssl_context.verify_mode = ssl.CERT_NONE - return ssl_context - - -class BinaryFrame: - def __init__(self, frame: picows.WSFrame): - self.msg_type = frame.msg_type - self.payload_as_bytes = frame.get_payload_as_bytes() - self.payload_as_bytes_from_mv = bytes(frame.get_payload_as_memoryview()) - self.fin = frame.fin - self.rsv1 = frame.rsv1 - - -class TextFrame: - def __init__(self, frame: picows.WSFrame): - self.msg_type = frame.msg_type - self.payload_as_ascii_text = frame.get_payload_as_ascii_text() - self.payload_as_utf8_text = frame.get_payload_as_utf8_text() - self.fin = frame.fin - self.rsv1 = frame.rsv1 - - -class CloseFrame: - def __init__(self, frame: picows.WSFrame): - self.msg_type = frame.msg_type - self.close_code = frame.get_close_code() - self.close_message = frame.get_close_message() - self.fin = frame.fin - self.rsv1 = frame.rsv1 - - -class ServerAsyncContext: - def __init__(self, server): - self.server = server - self.server_task = asyncio.create_task(server.serve_forever()) - - async def __aenter__(self): - return await self.server.__aenter__() - - async def __aexit__(self, *exc): - self.server_task.cancel() - await self.server.__aexit__(*exc) - with pytest.raises(asyncio.CancelledError): - async with async_timeout.timeout(TIMEOUT): - await self.server_task - - @pytest.fixture(params=["plain", "ssl"]) async def echo_server(request): class PicowsServerListener(picows.WSListener): @@ -142,12 +77,7 @@ def on_ws_connected(self, transport: picows.WSTransport): self.is_paused = False def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): - if frame.msg_type == picows.WSMsgType.TEXT: - self.msg_queue.put_nowait(TextFrame(frame)) - elif frame.msg_type == picows.WSMsgType.CLOSE: - self.msg_queue.put_nowait(CloseFrame(frame)) - else: - self.msg_queue.put_nowait(BinaryFrame(frame)) + self.msg_queue.put_nowait(materialize_frame(frame)) def pause_writing(self): self.is_paused = True diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..edc5616 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,80 @@ +import asyncio +import pathlib +import ssl + +import async_timeout +import pytest + +import picows + +TIMEOUT = 0.5 + +class BinaryFrame: + def __init__(self, frame: picows.WSFrame): + self.msg_type = frame.msg_type + self.payload_as_bytes = frame.get_payload_as_bytes() + self.payload_as_bytes_from_mv = bytes(frame.get_payload_as_memoryview()) + self.fin = frame.fin + self.rsv1 = frame.rsv1 + + +class TextFrame: + def __init__(self, frame: picows.WSFrame): + self.msg_type = frame.msg_type + self.payload_as_ascii_text = frame.get_payload_as_ascii_text() + self.payload_as_utf8_text = frame.get_payload_as_utf8_text() + self.fin = frame.fin + self.rsv1 = frame.rsv1 + + +class CloseFrame: + def __init__(self, frame: picows.WSFrame): + self.msg_type = frame.msg_type + self.close_code = frame.get_close_code() + self.close_message = frame.get_close_message() + self.fin = frame.fin + self.rsv1 = frame.rsv1 + + +def materialize_frame(frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.TEXT: + return TextFrame(frame) + elif frame.msg_type == picows.WSMsgType.CLOSE: + return CloseFrame(frame) + else: + return BinaryFrame(frame) + + +class ServerAsyncContext: + def __init__(self, server): + self.server = server + self.server_task = asyncio.create_task(server.serve_forever()) + + async def __aenter__(self): + return await self.server.__aenter__() + + async def __aexit__(self, *exc): + self.server_task.cancel() + await self.server.__aexit__(*exc) + with pytest.raises(asyncio.CancelledError): + async with async_timeout.timeout(TIMEOUT): + await self.server_task + + +def create_server_ssl_context(): + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(pathlib.Path(__file__).parent / "picows_test.crt", + pathlib.Path(__file__).parent / "picows_test.key") + ssl_context.check_hostname = False + ssl_context.hostname_checks_common_name = False + ssl_context.verify_mode = ssl.CERT_NONE + return ssl_context + + +def create_client_ssl_context(): + ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + ssl_context.load_default_certs(ssl.Purpose.SERVER_AUTH) + ssl_context.check_hostname = False + ssl_context.hostname_checks_common_name = False + ssl_context.verify_mode = ssl.CERT_NONE + return ssl_context \ No newline at end of file