diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 4b51754d..1591eae1 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -12,7 +12,7 @@ import logging from random import randint import time -from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, cast +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, TypeVar, cast import async_timeout from chip.ChipDeviceCtrl import DeviceProxyWrapper @@ -20,6 +20,7 @@ from chip.clusters.Attribute import ValueDecodeFailure from chip.clusters.ClusterObjects import ALL_ATTRIBUTES, ALL_CLUSTERS, Cluster from chip.exceptions import ChipStackError +from chip.native import PyChipError from zeroconf import IPVersion, ServiceStateChange, Zeroconf from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf @@ -113,7 +114,7 @@ def __init__( self.wifi_credentials_set: bool = False self.thread_credentials_set: bool = False self.compressed_fabric_id: int | None = None - self._node_lock: dict[int, asyncio.Lock] = {} + self._sdk_lock = asyncio.Lock() self._aiobrowser: AsyncServiceBrowser | None = None self._aiozc: AsyncZeroconf | None = None self._fallback_node_scanner_timer: asyncio.TimerHandle | None = None @@ -248,13 +249,13 @@ async def commission_with_code( attempts, MAX_COMMISSION_RETRIES, ) - success = await self._call_sdk( + success: bool | PyChipError = await self._call_sdk( self.chip_controller.CommissionWithCode, setupPayload=code, nodeid=node_id, networkOnly=network_only, ) - if success: + if not isinstance(success, PyChipError) and success: break if not success and attempts >= MAX_COMMISSION_RETRIES: raise NodeCommissionFailed( @@ -502,16 +503,17 @@ async def interview_node(self, node_id: int) -> None: try: if not (node := self._nodes.get(node_id)) or not node.available: await self._resolve_node(node_id=node_id) - async with self._get_node_lock(node_id): LOGGER.info("Interviewing node: %s", node_id) read_response: Attribute.AsyncReadTransaction.ReadResponse = ( - await self.chip_controller.Read( - nodeid=node_id, - attributes="*", - fabricFiltered=False, + await self._call_sdk( + self.chip_controller.Read( + nodeid=node_id, + attributes="*", + fabricFiltered=False, + ) ) ) - except ChipStackError as err: + except (ChipStackError, TimeoutError) as err: raise NodeInterviewFailed(f"Failed to interview node {node_id}") from err is_new_node = node_id not in self._nodes @@ -566,8 +568,8 @@ async def send_device_command( cluster_cls: Cluster = ALL_CLUSTERS[cluster_id] command_cls = getattr(cluster_cls.Commands, command_name) command = dataclass_from_dict(command_cls, payload, allow_sdk_types=True) - async with self._get_node_lock(node_id): - return await self.chip_controller.SendCommand( + return await self._call_sdk( + self.chip_controller.SendCommand( nodeid=node_id, endpoint=endpoint_id, payload=command, @@ -575,6 +577,7 @@ async def send_device_command( timedRequestTimeoutMs=timed_request_timeout_ms, interactionTimeoutMs=interaction_timeout_ms, ) + ) @api_command(APICommand.READ_ATTRIBUTE) async def read_attribute( @@ -586,14 +589,23 @@ async def read_attribute( if (node := self._nodes.get(node_id)) is None or not node.available: raise NodeNotReady(f"Node {node_id} is not (yet) available.") endpoint_id, cluster_id, attribute_id = parse_attribute_path(attribute_path) - assert self.server.loop is not None - async with self._get_node_lock(node_id): - future = self.server.loop.create_future() - device = await self._resolve_node(node_id) + + async def _do_read() -> Attribute.AsyncReadTransaction.ReadResponse: + """ + Read a list of attributes and/or events from a target node. + + This is basically a re-implementation of the chip controller's Read function + but one that allows us to send/request custom attributes. + """ + if TYPE_CHECKING: + assert self.server.loop + assert self.chip_controller + loop = self.server.loop + future = loop.create_future() Attribute.Read( future=future, - eventLoop=self.server.loop, - device=device.deviceProxy, + eventLoop=loop, + device=self.chip_controller.GetConnectedDeviceSync(node_id), devCtrl=self.chip_controller, attributes=[ Attribute.AttributePath( @@ -604,12 +616,16 @@ async def read_attribute( ], fabricFiltered=fabric_filtered, ).raise_on_error() - result: Attribute.AsyncReadTransaction.ReadResponse = await future - read_atributes = parse_attributes_from_read_result(result.tlvAttributes) - # update cached info in node attributes - self._nodes[node_id].attributes.update(read_atributes) - self._write_node_state(node_id) - return read_atributes + return await future + + result: Attribute.AsyncReadTransaction.ReadResponse = await self._call_sdk( + _do_read() + ) + read_atributes = parse_attributes_from_read_result(result.tlvAttributes) + # update cached info in node attributes + self._nodes[node_id].attributes.update(read_atributes) + self._write_node_state(node_id) + return read_atributes @api_command(APICommand.WRITE_ATTRIBUTE) async def write_attribute( @@ -634,11 +650,12 @@ async def write_attribute( value_type=attribute.attribute_type.Type, allow_sdk_types=True, ) - async with self._get_node_lock(node_id): - return await self.chip_controller.WriteAttribute( + return await self._call_sdk( + self.chip_controller.WriteAttribute( nodeid=node_id, attributes=[(endpoint_id, attribute)], ) + ) @api_command(APICommand.REMOVE_NODE) async def remove_node(self, node_id: int) -> None: @@ -677,12 +694,14 @@ async def remove_node(self, node_id: int) -> None: return result: Clusters.OperationalCredentials.Commands.NOCResponse | None = None try: - result = await self.chip_controller.SendCommand( - nodeid=node_id, - endpoint=0, - payload=Clusters.OperationalCredentials.Commands.RemoveFabric( - fabricIndex=fabric_index, - ), + result = await self._call_sdk( + self.chip_controller.SendCommand( + nodeid=node_id, + endpoint=0, + payload=Clusters.OperationalCredentials.Commands.RemoveFabric( + fabricIndex=fabric_index, + ), + ) ) except ChipStackError as err: LOGGER.warning( @@ -997,30 +1016,18 @@ def resubscription_succeeded( ) ) self._last_subscription_attempt[node_id] = 0 - future = loop.create_future() - device = await self._resolve_node(node_id) - async with async_timeout.timeout(DEFAULT_CALL_TIMEOUT): - Attribute.Read( - future=future, - eventLoop=loop, - device=device.deviceProxy, - devCtrl=self.chip_controller, - attributes=[Attribute.AttributePath()], # wildcard - events=[ - Attribute.EventPath( - EndpointId=None, Cluster=None, Event=None, Urgent=1 - ) - ], + sub: Attribute.SubscriptionTransaction = await self._call_sdk( + self.chip_controller.Read( + node_id, + attributes="*", + events=[("*", 1)], returnClusterObject=False, - subscriptionParameters=Attribute.SubscriptionParameters( - interval_floor, interval_ceiling - ), - # Use fabricfiltered as False to detect changes made by other controllers - # and to be able to provide a list of all fabrics attached to the device + reportInterval=(interval_floor, interval_ceiling), fabricFiltered=False, + keepSubscriptions=True, autoResubscribe=True, - ).raise_on_error() - sub: Attribute.SubscriptionTransaction = await future + ) + ) sub.SetAttributeUpdateCallback(attribute_updated_callback) sub.SetEventUpdateCallback(event_callback) @@ -1049,7 +1056,7 @@ def _get_next_node_id(self) -> int: async def _call_sdk( self, - func: Callable[..., _T], + target: Callable[..., _T] | Awaitable[_T], *args: Any, call_timeout: int = DEFAULT_CALL_TIMEOUT, **kwargs: Any, @@ -1058,15 +1065,18 @@ async def _call_sdk( if self.server.loop is None: raise RuntimeError("Server not started.") - # prevent a single job in the executor blocking everything with a timeout. - async with async_timeout.timeout(call_timeout): - return cast( - _T, - await self.server.loop.run_in_executor( - self._sdk_executor, - partial(func, *args, **kwargs), - ), + # handle both awaitables and sync functions here + if not asyncio.iscoroutine(target): + target = self.server.loop.run_in_executor( + self._sdk_executor, + partial(target, *args, **kwargs), # type: ignore[arg-type] ) + # we guard all calls to the sdk with a lock because we have no good way + # of knowing if all code in the python wrapper is thread safe. + # The additional timeout is a guard to prevent ourselves from deadlocking somehow. + async with async_timeout.timeout(call_timeout): + async with self._sdk_lock: + return cast(_T, await target) async def _setup_node(self, node_id: int) -> None: """Handle set-up of subscriptions and interview (if needed) for known/discovered node.""" @@ -1112,7 +1122,7 @@ async def _setup_node(self, node_id: int) -> None: # setup subscriptions for the node try: await self._subscribe_node(node_id) - except (NodeNotResolving, TimeoutError) as err: + except (NodeNotResolving, TimeoutError, ChipStackError) as err: LOGGER.warning( "Unable to subscribe to Node %s: %s", node_id, @@ -1253,12 +1263,6 @@ async def _on_mdns_commissionable_node_state( await info.async_request(self._aiozc.zeroconf, 3000) LOGGER.debug("Discovered commissionable Matter node using MDNS: %s", info) - def _get_node_lock(self, node_id: int) -> asyncio.Lock: - """Return lock for given node.""" - if node_id not in self._node_lock: - self._node_lock[node_id] = asyncio.Lock() - return self._node_lock[node_id] - def _write_node_state(self, node_id: int, force: bool = False) -> None: """Schedule the write of the current node state to persistent storage.""" node = self._nodes[node_id]