diff --git a/awscrt/websocket.py b/awscrt/websocket.py index b8c1d9cdf..f93c8ea11 100644 --- a/awscrt/websocket.py +++ b/awscrt/websocket.py @@ -3,9 +3,17 @@ Use the :func:`connect()` to establish a :class:`WebSocket` client connection. +Note from the developer: This is a very low-level API, which forces the +user to deal with things like data fragmentation. +A higher-level API could easily be built on top of this. + +.. _authoring-callbacks: + +Authoring Callbacks +------------------- All network operations in `awscrt.websocket` are asynchronous. Callbacks are always invoked on the WebSocket's networking thread. -You MUST NOT perform blocking operations from any callback, or you will cause a deadlock. +You MUST NOT perform blocking network operations from any callback, or you will cause a deadlock. For example: do not send a frame, and then wait for that frame to complete, within a callback. The WebSocket cannot do work until your callback returns, so the thread will be stuck. You can send the frame from within the callback, @@ -17,9 +25,66 @@ All functions and methods in `awscrt.websocket` are thread-safe. They can be called from any mix of threads. -Note from the developer: This is a very low-level API, which forces the -user to deal with things like data fragmentation. -A higher-level API could easily be built on top of this. +.. _flow-control-reading: + +Flow Control (reading) +---------------------- +By default, the WebSocket will read from the network as fast as it can hand you the data. +You must prevent the WebSocket from reading data faster than you can process it, +or memory usage could balloon until your application explodes. + +There are two ways to manage this. + +First, and simplest, is to process incoming data synchronously within the +`on_incoming_frame` callbacks. Since callbacks are invoked on the WebSocket's +networking thread, the WebSocket cannot read more data until the callback returns. +Therefore, processing the data in a synchronous manner +(i.e. writing to disk, printing to screen, etc) will naturally +affect `TCP flow control `_, +and prevent data from arriving too fast. However, you MUST NOT perform a blocking +network operation from within the callback or you risk deadlock (see :ref:`authoring-callbacks`). + +The second, more complex, way requires you to manage the size of the read window. +Do this if you are processing the data asynchronously +(i.e. sending the data along on another network connection). +Create the WebSocket with `manage_read_window` set true, +and set `initial_read_window` to the number of bytes you are ready to receive right away. +Whenever the read window reaches 0, you will stop receiving anything. +The read window shrinks as you receive the payload from "data" frames (TEXT, BINARY, CONTINUATION). +Call :meth:`WebSocket.increment_read_window()` to increase the window again keep frames flowing in. +You only need to worry about the payload from "data" frames. +The WebSocket automatically increments its window to account for any +other incoming bytes, including other parts of a frame (opcode, payload-length, etc) +and the payload of other frame types (PING, PONG, CLOSE). +You'll probably want to do it like this: +Pick the max amount of memory to buffer, and set this as the `initial_read_window`. +When data arrives, the window has shrunk by that amount. +Send this data along on the other network connection. +When that data is done sending, call `increment_read_window()` +by the amount you just finished sending. +If you don't want to receive any data at first, set the `initial_read_window` to 0, +and `increment_read_window()` when you're ready. +Maintaining a larger window is better for overall throughput. + +.. _flow-control-writing: + +Flow Control (writing) +---------------------- +You must also ensure that you do not continually send frames faster than the other +side can read them, or memory usage could balloon until your application explodes. + +The simplest approach is to only send 1 frame at a time. +Use the :meth:`WebSocket.send_frame()` `on_complete` callback to know when the send is complete. +Then you can try and send another. + +A more complex, but higher throughput, way is to let multiple frames be in flight +but have a cap. If the number of frames in flight, or bytes in flight, reaches +your cap then wait until some frames complete before trying to send more. + +.. _api: + +API +--- """ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. @@ -84,6 +149,17 @@ class Opcode(IntEnum): See `RFC 6455 section 5.5.3 `_. """ + def is_data_frame(self): + """True if this is a "data frame" opcode. + + TEXT, BINARY, and CONTINUATION are "data frames". The rest are "control" frames. + + If the WebSocket was created with `manage_read_window`, + then the read window shrinks as "data frames" are received. + See :ref:`flow-control-reading` for a thorough explanation. + """ + return self.value in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION) + MAX_PAYLOAD_LENGTH = 0x7FFFFFFFFFFFFFFF """The maximum frame payload length allowed by RFC 6455""" @@ -164,6 +240,17 @@ class IncomingFrame: See `RFC 6455 section 5.4 - Fragmentation `_""" + def is_data_frame(self): + """True if this is a "data frame". + + TEXT, BINARY, and CONTINUATION are "data frames". The rest are "control frames". + + If the WebSocket was created with `manage_read_window`, + then the read window shrinks as "data frames" are received. + See :ref:`flow-control-reading` for a thorough explanation. + """ + return self.opcode.is_data_frame() + @dataclass class OnIncomingFrameBeginData: @@ -186,6 +273,11 @@ class OnIncomingFramePayloadData: Once all `frame.payload_length` bytes have been received (or the network connection is lost), the `on_incoming_frame_complete` callback will be invoked. + + If the WebSocket was created with `manage_read_window`, + and this is a "data frame" (TEXT, BINARY, CONTINUATION), + then the read window shrinks by `len(data)`. + See :ref:`flow-control-reading` for a thorough explanation. """ frame: IncomingFrame @@ -261,8 +353,8 @@ def send_frame( If you are not an expert, stick to sending :attr:`Opcode.TEXT` or :attr:`Opcode.BINARY` frames, and don't touch the FIN bit. - If you want to limit the amount of unsent data buffered in memory, - wait until one frame completes before sending another. + See :ref:`flow-control-writing` to learn about limiting the amount of + unsent data buffered in memory. Args: opcode: :class:`Opcode` for this frame. @@ -286,7 +378,7 @@ def send_frame( or even guarantee that the data has left the machine yet, but it's on track to get there). - Read the :mod:`page notes` before authoring any callbacks. + Be sure to read about :ref:`authoring-callbacks`. """ def _on_complete(error_code): cbdata = OnSendFrameCompleteData() @@ -309,6 +401,21 @@ def _on_complete(error_code): fin, _on_complete) + def increment_read_window(self, size: int): + """Manually increment the read window by this many bytes, to continue receiving frames. + + See :ref:`flow-control-reading` for a thorough explanation. + If the WebSocket was created without `manage_read_window`, this function does nothing. + This function may be called from any thread. + + Args: + size: in bytes + """ + if size < 0: + raise ValueError("Increment size cannot be negative") + + _awscrt.websocket_increment_read_window(self._binding, size) + class _WebSocketCore(NativeResource): # Private class that handles wrangling callback data from C -> Python. @@ -431,13 +538,13 @@ def connect( socket_options: Optional[SocketOptions] = None, tls_connection_options: Optional[TlsConnectionOptions] = None, proxy_options: Optional[HttpProxyOptions] = None, + manage_read_window: bool = False, + initial_read_window: Optional[int] = None, on_connection_setup: Callable[[OnConnectionSetupData], None], on_connection_shutdown: Optional[Callable[[OnConnectionShutdownData], None]] = None, on_incoming_frame_begin: Optional[Callable[[OnIncomingFrameBeginData], None]] = None, on_incoming_frame_payload: Optional[Callable[[OnIncomingFramePayloadData], None]] = None, on_incoming_frame_complete: Optional[Callable[[OnIncomingFrameCompleteData], None]] = None, - enable_read_backpressure: bool = False, - initial_read_window: Optional[int] = None, ): """Asynchronously establish a client WebSocket connection. @@ -459,7 +566,7 @@ def connect( done with a healthy WebSocket, to ensure that it shuts down and cleans up. It is very easy to accidentally keep a reference around without realizing it. - Read the :mod:`page notes` before authoring your callbacks. + Be sure to read about :ref:`authoring-callbacks`. Args: host: Hostname to connect to. @@ -491,6 +598,17 @@ def connect( proxy_options: HTTP Proxy options. If not specified, no proxy is used. + manage_read_window: Set true to manually manage the flow-control read window. + If false (the default), data arrives as fast as possible. + See :ref:`flow-control-reading` for a thorough explanation. + + initial_read_window: The initial size of the read window, in bytes. + This must be set if `manage_read_window` is true, + otherwise it is ignored. + See :ref:`flow-control-reading` for a thorough explanation. + An initial size of 0 will prevent any frames from arriving + until :meth:`WebSocket.increment_read_window()` is called. + on_connection_setup: Callback invoked when the connect completes. Takes a single :class:`OnConnectionSetupData` argument. @@ -526,6 +644,10 @@ def connect( on_incoming_frame_payload: Optional callback, invoked 0+ times as payload data arrives. Takes a single :class:`OnIncomingFramePayloadData` argument. + If `manage_read_window` is on, and this is a "data frame", + then the read window shrinks accordingly. + See :ref:`flow-control-reading` for a thorough explanation. + If this callback raises an exception, the connection will shut down. on_incoming_frame_complete: Optional callback, invoked when the WebSocket @@ -538,12 +660,11 @@ def connect( If this callback raises an exception, the connection will shut down. """ - # TODO: document backpressure - if enable_read_backpressure: + if manage_read_window: if initial_read_window is None: - raise ValueError("'initial_read_window' must be set if 'enable_read_backpressure' is enabled") + raise ValueError("'initial_read_window' must be set if 'manage_read_window' is enabled") else: - initial_read_window = 0x7FFFFFFF # TODO: fix how this works in C + initial_read_window = 0 # value is ignored anyway if initial_read_window < 0: raise ValueError("'initial_read_window' cannot be negative") @@ -572,7 +693,7 @@ def connect( socket_options, tls_connection_options, proxy_options, - enable_read_backpressure, + manage_read_window, initial_read_window, core) diff --git a/crt/aws-c-event-stream b/crt/aws-c-event-stream index bfbf25451..e812dd4df 160000 --- a/crt/aws-c-event-stream +++ b/crt/aws-c-event-stream @@ -1 +1 @@ -Subproject commit bfbf254517513e3a5dd90bb0ee417cd98eb89850 +Subproject commit e812dd4df0fcc350ad7b5b7babe82cfe5664f4a4 diff --git a/crt/aws-c-http b/crt/aws-c-http index 5400050bf..4e82c1e50 160000 --- a/crt/aws-c-http +++ b/crt/aws-c-http @@ -1 +1 @@ -Subproject commit 5400050bff7a730f81be3f238a702bea73570e00 +Subproject commit 4e82c1e5022d3dd4d6eda4b8fa9cdba6e6def050 diff --git a/crt/aws-c-io b/crt/aws-c-io index 453c48b02..6c19e25f5 160000 --- a/crt/aws-c-io +++ b/crt/aws-c-io @@ -1 +1 @@ -Subproject commit 453c48b02a0886407d4b5c376b5a39fa60c30f7e +Subproject commit 6c19e25f55fa060d4e42010756967b979361dc66 diff --git a/source/websocket.c b/source/websocket.c index d2b35a047..56eee7702 100644 --- a/source/websocket.c +++ b/source/websocket.c @@ -62,7 +62,7 @@ PyObject *aws_py_websocket_client_connect(PyObject *self, PyObject *args) { PyObject *socket_options_py; /* O */ PyObject *tls_options_py; /* O */ PyObject *proxy_options_py; /* O */ - int enable_read_backpressure; /* p - boolean predicate */ + int manage_read_window; /* p - boolean predicate */ Py_ssize_t initial_read_window; /* n */ PyObject *websocket_core_py; /* O */ @@ -77,7 +77,7 @@ PyObject *aws_py_websocket_client_connect(PyObject *self, PyObject *args) { &socket_options_py, &tls_options_py, &proxy_options_py, - &enable_read_backpressure, + &manage_read_window, &initial_read_window, &websocket_core_py)) { return NULL; @@ -142,7 +142,7 @@ PyObject *aws_py_websocket_client_connect(PyObject *self, PyObject *args) { .on_incoming_frame_begin = s_websocket_on_incoming_frame_begin, .on_incoming_frame_payload = s_websocket_on_incoming_frame_payload, .on_incoming_frame_complete = s_websocket_on_incoming_frame_complete, - .manual_window_management = enable_read_backpressure != 0, + .manual_window_management = manage_read_window != 0, }; if (aws_websocket_client_connect(&options) != AWS_OP_SUCCESS) { PyErr_SetAwsLastError(); @@ -516,10 +516,23 @@ PyObject *aws_py_websocket_send_frame(PyObject *self, PyObject *args) { } PyObject *aws_py_websocket_increment_read_window(PyObject *self, PyObject *args) { - /* TODO implement */ (void)self; - (void)args; - return NULL; + + PyObject *binding_py; /* O */ + Py_ssize_t size; /* n */ + + if (!PyArg_ParseTuple(args, "On", &binding_py, &size)) { + return NULL; + } + + struct aws_websocket *websocket = PyCapsule_GetPointer(binding_py, s_websocket_capsule_name); + if (!websocket) { + return NULL; + } + + /* already checked that size was non-negative out in python */ + aws_websocket_increment_read_window(websocket, (size_t)size); + Py_RETURN_NONE; } PyObject *aws_py_websocket_create_handshake_request(PyObject *self, PyObject *args) { diff --git a/test/test_mqtt5.py b/test/test_mqtt5.py index db77ec792..117592944 100644 --- a/test/test_mqtt5.py +++ b/test/test_mqtt5.py @@ -539,7 +539,7 @@ def test_connect_with_invalid_port(self): def test_connect_with_invalid_port_for_websocket_connection(self): client_options = mqtt5.ClientOptions("will be set by _create_client", 1883) client, callbacks = self._test_connect_fail( - auth_type=AuthType.WS_BAD_PORT, client_options=client_options, expected_error_code=46) + auth_type=AuthType.WS_BAD_PORT, client_options=client_options) client.stop() callbacks.future_stopped.result(TIMEOUT) diff --git a/test/test_websocket.py b/test/test_websocket.py index 4f1541419..fd09949da 100644 --- a/test/test_websocket.py +++ b/test/test_websocket.py @@ -10,11 +10,12 @@ from io import StringIO import logging from os import urandom -from queue import Queue +from queue import Empty, Queue +import secrets import socket from test import NativeResourceTest import threading -from time import sleep +from time import sleep, time from typing import Optional # using a 3rdparty websocket library for the server @@ -46,7 +47,7 @@ def __init__(self): self.incoming_frame_payload = bytearray() self.exception = None - def connect_sync(self, host, port): + def connect_sync(self, host, port, **connect_kwargs): connect(host=host, port=port, handshake_request=create_handshake_request(host=host), @@ -54,7 +55,8 @@ def connect_sync(self, host, port): on_connection_shutdown=self._on_connection_shutdown, on_incoming_frame_begin=self._on_incoming_frame_begin, on_incoming_frame_payload=self._on_incoming_frame_payload, - on_incoming_frame_complete=self._on_incoming_frame_complete) + on_incoming_frame_complete=self._on_incoming_frame_complete, + **connect_kwargs) # wait for on_connection_setup to fire setup_data = self.setup_future.result(TIMEOUT) assert setup_data.exception is None @@ -128,6 +130,8 @@ def __enter__(self): # don't return until the server signals that it's started up and is listening for connections assert self._server_started_event.wait(TIMEOUT) + return self + def __exit__(self, exc_type, exc_value, exc_tb): # main thread is exiting the `with` block: tell the server to stop... @@ -158,6 +162,7 @@ async def _run_asyncio_server(self): async def _run_connection(self, server_connection: websockets_server_3rdparty.WebSocketServerProtocol): # this coroutine runs once for each connection to the server # when this coroutine exits, the connection gets shut down + self._current_connection = server_connection try: # await each message... async for msg in server_connection: @@ -170,6 +175,12 @@ async def _run_connection(self, server_connection: websockets_server_3rdparty.We # even if the connection ends cleanly, so just swallow it pass + finally: + self._current_connection = None + + def send_async(self, msg): + asyncio.run_coroutine_threadsafe(self._current_connection.send(msg), self._server_loop) + class TestClient(NativeResourceTest): def setUp(self): @@ -504,3 +515,55 @@ def bad_incoming_frame_callback(data): # wait for the frame to echo back, firing the bad callback, # which raises an exception, which should result in the WebSocket closing shutdown_future.result(TIMEOUT) + + def test_manage_read_window(self): + # test that users can manage how much data is read by managing the read window + with WebSocketServer(self.host, self.port) as server: + handler = ClientHandler() + handler.connect_sync(self.host, self.port, manage_read_window=True, initial_read_window=1000) + + # client's read window is 1000-bytes + # have the server send 10 messages with 100-byte payloads + # they should all get through + + for i in range(10): + msg = secrets.token_bytes(100) # random msg for server to send + server.send_async(msg) + recv: RecvFrame = handler.complete_frames.get(timeout=TIMEOUT) + self.assertEqual(recv.payload, msg, "did not receive expected payload") + + # client window is now 0 + # have server send a 1000 byte message, NONE of its payload should arrive + + msg = secrets.token_bytes(1000) # random msg for server to send + server.send_async(msg) + with self.assertRaises(Empty): + handler.complete_frames.get(timeout=1.0) + self.assertEqual(len(handler.incoming_frame_payload), 0, "No payload should arrive while window is 0") + + # now increment client's window to 500 + # half (500/1000) the bytes should flow in + + handler.websocket.increment_read_window(500) + max_wait_until = time() + TIMEOUT + while len(handler.incoming_frame_payload) < 500: + sleep(0.001) + self.assertLess(time(), max_wait_until, "timed out waiting for all bytes") + sleep(1.0) # sleep a moment to be sure we don't receive MORE than 500 bytes + self.assertEqual(len(handler.incoming_frame_payload), 500, "received more bytes than expected") + + # client's window is 0 again, 500 bytes are still waiting to flow in + # increment the window to let the rest in + # let's do it by calling increment a bunch of times in a row, just to be different + + handler.websocket.increment_read_window(100) + handler.websocket.increment_read_window(100) + handler.websocket.increment_read_window(100) + handler.websocket.increment_read_window(100) + handler.websocket.increment_read_window(100) + + recv: RecvFrame = handler.complete_frames.get(timeout=TIMEOUT) + self.assertEqual(recv.payload, msg, "did not receive expected payload") + + # done! + handler.close_sync()