From 8214e08bd78652f5644bce7b2e6be22e955be771 Mon Sep 17 00:00:00 2001 From: Marcel van der Veldt Date: Fri, 1 Mar 2024 11:26:10 +0100 Subject: [PATCH] Refactor locking --- matter_server/server/device_controller.py | 153 ++++++++++------------ 1 file changed, 72 insertions(+), 81 deletions(-) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 1384e344..642b2586 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -12,7 +12,7 @@ import logging from random import randint import time -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, cast from chip.ChipDeviceCtrl import DeviceProxyWrapper from chip.clusters import Attribute, Objects as Clusters @@ -502,15 +502,13 @@ async def interview_node(self, node_id: int) -> None: await self._resolve_node(node_id=node_id) LOGGER.info("Interviewing node: %s", node_id) read_response: Attribute.AsyncReadTransaction.ReadResponse = ( - await self._call_sdk( - self.chip_controller.Read( - nodeid=node_id, - attributes="*", - fabricFiltered=False, - ) + await self.chip_controller.Read( + nodeid=node_id, + attributes="*", + fabricFiltered=False, ) ) - except (ChipStackError, TimeoutError) as err: + except ChipStackError as err: raise NodeInterviewFailed(f"Failed to interview node {node_id}") from err is_new_node = node_id not in self._nodes @@ -565,15 +563,13 @@ async def send_device_command( cluster_cls: Cluster = ALL_CLUSTERS[cluster_id] command_cls = getattr(cluster_cls.Commands, command_name) command = dataclass_from_dict(command_cls, payload, allow_sdk_types=True) - return await self._call_sdk( - self.chip_controller.SendCommand( - nodeid=node_id, - endpoint=endpoint_id, - payload=command, - responseType=response_type, - timedRequestTimeoutMs=timed_request_timeout_ms, - interactionTimeoutMs=interaction_timeout_ms, - ) + return await self.chip_controller.SendCommand( + nodeid=node_id, + endpoint=endpoint_id, + payload=command, + responseType=response_type, + timedRequestTimeoutMs=timed_request_timeout_ms, + interactionTimeoutMs=interaction_timeout_ms, ) @api_command(APICommand.READ_ATTRIBUTE) @@ -586,38 +582,32 @@ async def read_attribute( if (node := self._nodes.get(node_id)) is None or not node.available: raise NodeNotReady(f"Node {node_id} is not (yet) available.") endpoint_id, cluster_id, attribute_id = parse_attribute_path(attribute_path) + device = await self._resolve_node(node_id) + # Read a list of attributes and/or events from a target node. + # This is basically a re-implementation of the chip controller's Read function + # but one that allows us to send/request custom attributes. - async def _do_read() -> Attribute.AsyncReadTransaction.ReadResponse: - """ - Read a list of attributes and/or events from a target node. - - This is basically a re-implementation of the chip controller's Read function - but one that allows us to send/request custom attributes. - """ - if TYPE_CHECKING: - assert self.server.loop - assert self.chip_controller - loop = self.server.loop - future = loop.create_future() - Attribute.Read( - future=future, - eventLoop=loop, - device=self.chip_controller.GetConnectedDeviceSync(node_id), - devCtrl=self.chip_controller, - attributes=[ - Attribute.AttributePath( - EndpointId=endpoint_id, - ClusterId=cluster_id, - AttributeId=attribute_id, - ) - ], - fabricFiltered=fabric_filtered, - ).raise_on_error() - return await future + if TYPE_CHECKING: + assert self.server.loop + assert self.chip_controller + loop = self.server.loop + future = loop.create_future() + Attribute.Read( + future=future, + eventLoop=loop, + device=device, + devCtrl=self.chip_controller, + attributes=[ + Attribute.AttributePath( + EndpointId=endpoint_id, + ClusterId=cluster_id, + AttributeId=attribute_id, + ) + ], + fabricFiltered=fabric_filtered, + ).raise_on_error() - result: Attribute.AsyncReadTransaction.ReadResponse = await self._call_sdk( - _do_read() - ) + result: Attribute.AsyncReadTransaction.ReadResponse = await future read_atributes = parse_attributes_from_read_result(result.tlvAttributes) # update cached info in node attributes self._nodes[node_id].attributes.update(read_atributes) @@ -647,11 +637,9 @@ async def write_attribute( value_type=attribute.attribute_type.Type, allow_sdk_types=True, ) - return await self._call_sdk( - self.chip_controller.WriteAttribute( - nodeid=node_id, - attributes=[(endpoint_id, attribute)], - ) + return await self.chip_controller.WriteAttribute( + nodeid=node_id, + attributes=[(endpoint_id, attribute)], ) @api_command(APICommand.REMOVE_NODE) @@ -737,7 +725,9 @@ async def subscribe_attribute( ) @api_command(APICommand.PING_NODE) - async def ping_node(self, node_id: int, attempts: int = 1) -> NodePingResult: + async def ping_node( + self, node_id: int, attempts: int = 1, allow_cached_ips: bool = True + ) -> NodePingResult: """Ping node on the currently known IP-adress(es).""" result: NodePingResult = {} node = self._nodes.get(node_id) @@ -771,7 +761,7 @@ async def _do_ping(ip_address: str) -> None: result[clean_ip] = await ping_ip(ip_address, timeout, attempts=attempts) ip_addresses = await self.get_node_ip_addresses( - node_id, prefer_cache=False, scoped=True + node_id, prefer_cache=False, scoped=True, allow_cache=allow_cached_ips ) tasks = [_do_ping(x) for x in ip_addresses] # TODO: replace this gather with a taskgroup once we bump our py version @@ -796,7 +786,11 @@ async def _do_ping(ip_address: str) -> None: @api_command(APICommand.GET_NODE_IP_ADRESSES) async def get_node_ip_addresses( - self, node_id: int, prefer_cache: bool = False, scoped: bool = False + self, + node_id: int, + prefer_cache: bool = False, + scoped: bool = False, + allow_cache: bool = True, ) -> list[str]: """Return the currently known (scoped) IP-adress(es).""" cached_info = self._last_known_ip_addresses.get(node_id, []) @@ -814,7 +808,7 @@ async def get_node_ip_addresses( info = AsyncServiceInfo(MDNS_TYPE_OPERATIONAL_NODE, mdns_name) if TYPE_CHECKING: assert self._aiozc is not None - if not await info.async_request(self._aiozc.zeroconf, 3000): + if not await info.async_request(self._aiozc.zeroconf, 3000) and allow_cache: node_logger.info( "Node could not be discovered on the network, returning cached IP's" ) @@ -1013,17 +1007,15 @@ def resubscription_succeeded( ) ) self._last_subscription_attempt[node_id] = 0 - sub: Attribute.SubscriptionTransaction = await self._call_sdk( - self.chip_controller.Read( - node_id, - attributes="*", - events=[("*", 1)], - returnClusterObject=False, - reportInterval=(interval_floor, interval_ceiling), - fabricFiltered=False, - keepSubscriptions=True, - autoResubscribe=True, - ) + sub: Attribute.SubscriptionTransaction = await self.chip_controller.Read( + node_id, + attributes="*", + events=[("*", 1)], + returnClusterObject=False, + reportInterval=(interval_floor, interval_ceiling), + fabricFiltered=False, + keepSubscriptions=True, + autoResubscribe=True, ) sub.SetAttributeUpdateCallback(attribute_updated_callback) @@ -1053,7 +1045,7 @@ def _get_next_node_id(self) -> int: async def _call_sdk( self, - target: Callable[..., _T] | Awaitable[_T], + target: Callable[..., _T], *args: Any, **kwargs: Any, ) -> _T: @@ -1061,16 +1053,13 @@ async def _call_sdk( if self.server.loop is None: raise RuntimeError("Server not started.") - # handle both awaitables and sync functions here - if not asyncio.iscoroutine(target): - target = self.server.loop.run_in_executor( + return cast( + _T, + await self.server.loop.run_in_executor( self._sdk_executor, - partial(target, *args, **kwargs), # type: ignore[arg-type] - ) - # we guard all calls to the sdk with a lock because we have no good way - # of knowing if all code in the python wrapper is thread safe. - async with self._sdk_lock: - return cast(_T, await target) + partial(target, *args, **kwargs), + ), + ) async def _setup_node(self, node_id: int) -> None: """Handle set-up of subscriptions and interview (if needed) for known/discovered node.""" @@ -1084,7 +1073,9 @@ async def _setup_node(self, node_id: int) -> None: # Ping the node to rule out stale mdns reports and to prevent that we # send an unreachable node to the sdk which is very slow with resolving it. # This will also precache the ip addresses of the node for later use. - ping_result = await self.ping_node(node_id, attempts=3) + ping_result = await self.ping_node( + node_id, attempts=3, allow_cached_ips=False + ) if not any(ping_result.values()): LOGGER.warning( "Skip set-up for node %s because it does not appear to be reachable...", @@ -1116,7 +1107,7 @@ async def _setup_node(self, node_id: int) -> None: # setup subscriptions for the node try: await self._subscribe_node(node_id) - except (NodeNotResolving, TimeoutError, ChipStackError) as err: + except (NodeNotResolving, ChipStackError) as err: LOGGER.warning( "Unable to subscribe to Node %s: %s", node_id, @@ -1150,7 +1141,7 @@ async def _resolve_node( allowPASE=False, timeoutMs=None, ) - except (ChipStackError, TimeoutError) as err: + except ChipStackError as err: if attempt >= retries: # when we're out of retries, raise NodeNotResolving raise NodeNotResolving(f"Unable to resolve Node {node_id}") from err @@ -1292,7 +1283,7 @@ async def _fallback_node_scanner(self) -> None: last_seen = self._node_last_seen.get(node_id, 0) if now - last_seen < FALLBACK_NODE_SCANNER_INTERVAL: continue - if await self.ping_node(node_id, attempts=3): + if await self.ping_node(node_id, attempts=3, allow_cached_ips=False): LOGGER.info("Node %s discovered using fallback ping", node_id) await self._setup_node(node_id)