From e3020ff49fcf10438ce923ea53bc82a710684070 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 13 Dec 2022 14:44:15 -1000 Subject: [PATCH] Significantly reduce heartbeat overhead (#301) --- aioshelly/rpc_device/wsrpc.py | 72 ++++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/aioshelly/rpc_device/wsrpc.py b/aioshelly/rpc_device/wsrpc.py index d422615c..4ef5f251 100644 --- a/aioshelly/rpc_device/wsrpc.py +++ b/aioshelly/rpc_device/wsrpc.py @@ -33,6 +33,8 @@ _LOGGER = logging.getLogger(__name__) +WS_HEATBEAT_HALF_INTERVAL = WS_HEARTBEAT / 2 + def _receive_json_or_raise(msg: WSMessage) -> dict[str, Any]: """Receive json or raise.""" @@ -147,6 +149,10 @@ def __init__(self, ip_address: str, on_notification: Callable) -> None: self._calls: dict[int, RPCCall] = {} self._call_id = 0 self._session = SessionData(f"aios-{id(self)}", None, None) + self._heartbeat_cb: asyncio.TimerHandle | None = None + self._pong_response_cb: asyncio.TimerHandle | None = None + self._loop = asyncio.get_running_loop() + self._last_time: float = 0 @property def _next_id(self) -> int: @@ -161,7 +167,7 @@ async def connect(self, aiohttp_session: aiohttp.ClientSession) -> None: _LOGGER.debug("Trying to connect to device at %s", self._ip_address) try: self._client = await aiohttp_session.ws_connect( - f"http://{self._ip_address}/rpc", heartbeat=WS_HEARTBEAT + f"http://{self._ip_address}/rpc", autoping=False ) except ( client_exceptions.WSServerHandshakeError, @@ -170,12 +176,64 @@ async def connect(self, aiohttp_session: aiohttp.ClientSession) -> None: raise DeviceConnectionError(err) from err self._rx_task = asyncio.create_task(self._rx_msgs()) - + self._schedule_heartbeat() _LOGGER.info("Connected to %s", self._ip_address) + def _cancel_heartbeat(self) -> None: + """Cancel heartbeat.""" + if self._heartbeat_cb is not None: + self._heartbeat_cb.cancel() + self._heartbeat_cb = None + + def _cancel_pong_response_cb(self) -> None: + """Cancel pong response callback.""" + if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = None + + def _cancel_heatbeat_and_pong_response_cb(self) -> None: + """Cancel heartbeat and pong response callback.""" + self._cancel_heartbeat() + self._cancel_pong_response_cb() + + def _schedule_heartbeat(self) -> None: + """Schedule heartbeat.""" + self._cancel_heatbeat_and_pong_response_cb() + self._heartbeat_cb = self._loop.call_later( + WS_HEATBEAT_HALF_INTERVAL, self._maybe_send_heartbeat + ) + + def _schedule_pong_response_cb(self) -> None: + """Schedule pong response callback.""" + self._cancel_pong_response_cb() + self._pong_response_cb = self._loop.call_later( + WS_HEATBEAT_HALF_INTERVAL, self._pong_not_received + ) + + def _maybe_send_heartbeat(self) -> None: + """Send heartbeat if no messages have been received in the last WS_HEARTBEAT seconds.""" + if not self._client or self._client.closed: + return + if time.time() - self._last_time < WS_HEARTBEAT: + # No need to send heartbeat + # so schedule next heartbeat + self._schedule_heartbeat() + return + self._schedule_pong_response_cb() + asyncio.create_task(self._client.ping()) + + def _pong_not_received(self) -> None: + """Pong not received.""" + _LOGGER.error( + "%s: Pong not received, device is likely unresponsive; disconnecting", + self._ip_address, + ) + asyncio.create_task(self.disconnect()) + async def disconnect(self) -> None: """Disconnect all sessions.""" self._rx_task = None + self._cancel_heatbeat_and_pong_response_cb() if self._client is None: return @@ -239,17 +297,27 @@ async def _rx_msgs(self) -> None: while not self._client.closed: try: msg = await self._client.receive() + self._last_time = time.time() + if msg.type == WSMsgType.PONG: + self._schedule_heartbeat() + continue + if msg.type == WSMsgType.PING: + await self._client.pong(msg.data) + continue frame = _receive_json_or_raise(msg) _LOGGER.debug("recv(%s): %s", self._ip_address, frame) except InvalidMessage as err: _LOGGER.error("Invalid Message from host %s: %s", self._ip_address, err) except ConnectionClosed: break + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Unexpected error while receiving message") if not self._client.closed: self.handle_frame(frame) _LOGGER.debug("Websocket client connection from %s closed", self._ip_address) + self._cancel_heatbeat_and_pong_response_cb() for call_item in self._calls.values(): if not call_item.resolve.done():