From beb3d64a657647ef4096dd8892077a1ae8c9506a Mon Sep 17 00:00:00 2001 From: Marcel van der Veldt Date: Tue, 27 Jun 2023 11:57:34 +0200 Subject: [PATCH] Refactor subscription logic (#335) Refactor of the subscription logic for both optimizations reasons and less traffic (prevent congestion). Run interviews and subscriptions in parallel tasks to speedup startup. Only subscribe to specific attributes Detect endpoint additions and removals on bridges Co-authored-by: Martin Hjelmare --- README.md | 4 +- matter_server/client/client.py | 42 +- matter_server/client/connection.py | 12 +- matter_server/client/models/node.py | 10 +- matter_server/common/const.py | 2 +- matter_server/common/helpers/api.py | 2 +- matter_server/common/helpers/json.py | 5 +- matter_server/common/helpers/util.py | 17 +- matter_server/common/models.py | 12 +- matter_server/server/__main__.py | 2 + matter_server/server/device_controller.py | 451 +++++++++++++++------- matter_server/server/stack.py | 4 +- matter_server/server/storage.py | 9 +- 13 files changed, 400 insertions(+), 172 deletions(-) diff --git a/README.md b/README.md index 12d93f30..3efe4f89 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ relies on the networking managed by your operating system. Make sure your you run the container on the host network. The host network interface needs to be in the same network as the Android/iPhone device you are using for commissioning. Matter uses link-local multicast protocols -which do not work accross different LANs or VLANs. +which do not work across different LANs or VLANs. The host network interface needs IPv6 support enabled. @@ -54,7 +54,7 @@ For communication through Thread border routers which are not running on the sam host as the Matter Controller server to work, IPv6 routing needs to be properly working. IPv6 routing is largely setup automatically through the IPv6 Neighbor Discovery Protocol, specifically the Route Information Options (RIO). However, -if IPv6 ND RIO's are processed, and processed correctly depends on the network +if IPv6 Neighbor Discovery RIO's are processed, and processed correctly depends on the network management software your system is using. There may be bugs and cavats in processing this Route Information Options. diff --git a/matter_server/client/client.py b/matter_server/client/client.py index 3ce953e4..bf9edf29 100644 --- a/matter_server/client/client.py +++ b/matter_server/client/client.py @@ -67,6 +67,10 @@ def subscribe( Optionally filter by specific events or node attributes. Returns: function to unsubscribe. + + NOTE: To receive attribute changed events, + you must also register the attributes to subscribe to + with the `subscribe_attributes` method. """ # for fast lookups we create a key based on the filters, allowing # a "catch all" with a wildcard (*). @@ -94,11 +98,11 @@ def unsubscribe() -> None: return unsubscribe - async def get_nodes(self) -> list[MatterNode]: + def get_nodes(self) -> list[MatterNode]: """Return all Matter nodes.""" return list(self._nodes.values()) - async def get_node(self, node_id: int) -> MatterNode: + def get_node(self, node_id: int) -> MatterNode: """Return Matter node by id.""" return self._nodes[node_id] @@ -164,7 +168,7 @@ async def get_matter_fabrics(self, node_id: int) -> list[MatterFabricData]: Returns a list of MatterFabricData objects. """ - node = await self.get_node(node_id) + node = self.get_node(node_id) fabrics: list[ Clusters.OperationalCredentials.Structs.FabricDescriptor ] = node.get_attribute_value( @@ -229,6 +233,22 @@ async def remove_node(self, node_id: int) -> None: """Remove a Matter node/device from the fabric.""" await self.send_command(APICommand.REMOVE_NODE, node_id=node_id) + async def subscribe_attribute( + self, node_id: int, attribute_path: str | list[str] + ) -> None: + """ + Subscribe to given AttributePath(s). + + Either supply a single attribute path or a list of paths. + The given attribute path(s) will be added to the list of attributes that + are watched for the given node. This is persistent over restarts. + """ + await self.send_command( + APICommand.SUBSCRIBE_ATTRIBUTE, + node_id=node_id, + attribute_path=attribute_path, + ) + async def send_command( self, command: str, @@ -366,7 +386,7 @@ def _handle_incoming_message(self, msg: MessageType) -> None: # handle EventMessage if isinstance(msg, EventMessage): - self.logger.debug("Received event: %s", msg) + self.logger.debug("Received event: %s", msg.event) self._handle_event_message(msg) return @@ -392,10 +412,20 @@ def _handle_event_message(self, msg: EventMessage) -> None: node.update(node_data) self._signal_event(event, data=node, node_id=node.node_id) return - if msg.event == EventType.NODE_DELETED: + if msg.event == EventType.NODE_REMOVED: node_id = msg.data + self._signal_event(EventType.NODE_REMOVED, data=node_id, node_id=node_id) + # cleanup node only after signalling subscribers self._nodes.pop(node_id, None) - self._signal_event(EventType.NODE_DELETED, data=node_id, node_id=node_id) + return + if msg.event == EventType.ENDPOINT_REMOVED: + node_id = msg.data["node_id"] + self._signal_event( + EventType.ENDPOINT_REMOVED, data=msg.data, node_id=node_id + ) + # cleanup endpoint only after signalling subscribers + if node := self._nodes.get(node_id): + node.endpoints.pop(msg.data["endpoint_id"], None) return if msg.event == EventType.ATTRIBUTE_UPDATED: # data is tuple[node_id, attribute_path, new_value] diff --git a/matter_server/client/connection.py b/matter_server/client/connection.py index 7b4dac73..52101da9 100644 --- a/matter_server/client/connection.py +++ b/matter_server/client/connection.py @@ -3,6 +3,7 @@ import asyncio import logging +import os import pprint from typing import Any, Callable, Dict, Final, cast @@ -33,6 +34,7 @@ from .models.node import MatterNode LOGGER = logging.getLogger(f"{__package__}.connection") +VERBOSE_LOGGER = os.environ.get("MATTER_VERBOSE_LOGGING") SUB_WILDCARD: Final = "*" @@ -129,7 +131,10 @@ async def receive_message_or_raise(self) -> MessageType: raise InvalidMessage("Received invalid JSON.") from err if LOGGER.isEnabledFor(logging.DEBUG): - LOGGER.debug("Received message:\n%s\n", pprint.pformat(ws_msg)) + if VERBOSE_LOGGER: + LOGGER.debug("Received message:\n%s\n", pprint.pformat(ws_msg)) + else: + LOGGER.debug("Received message: %s ...", ws_msg.data[:50]) return msg @@ -143,7 +148,10 @@ async def send_message(self, message: CommandMessage) -> None: raise NotConnected if LOGGER.isEnabledFor(logging.DEBUG): - LOGGER.debug("Publishing message:\n%s\n", pprint.pformat(message)) + if VERBOSE_LOGGER: + LOGGER.debug("Publishing message:\n%s\n", pprint.pformat(message)) + else: + LOGGER.debug("Publishing message: %s", message) assert self._ws_client assert isinstance(message, CommandMessage) diff --git a/matter_server/client/models/node.py b/matter_server/client/models/node.py index 38b939a7..fad3d6c8 100644 --- a/matter_server/client/models/node.py +++ b/matter_server/client/models/node.py @@ -216,7 +216,6 @@ class MatterNode: def __init__(self, node_data: MatterNodeData) -> None: """Initialize MatterNode from MatterNodeData.""" self.endpoints: dict[int, MatterEndpoint] = {} - self._is_bridge_device: bool = False # composed devices reference to other endpoints through the partsList attribute # create a mapping table self._composed_endpoints: dict[int, int] = {} @@ -251,7 +250,7 @@ def device_info(self) -> Clusters.BasicInformation: @property def is_bridge_device(self) -> bool: """Return if this Node is a Bridge/Aggregator device.""" - return self._is_bridge_device + return self.node_data.is_bridge def get_attribute_value( self, @@ -310,10 +309,6 @@ def update(self, node_data: MatterNodeData) -> None: self.endpoints[endpoint_id] = MatterEndpoint( endpoint_id=endpoint_id, attributes_data=attributes_data, node=self ) - # lookup if this is a bridge device - self._is_bridge_device = any( - Aggregator in x.device_types for x in self.endpoints.values() - ) # composed devices reference to other endpoints through the partsList attribute # create a mapping table to quickly map this for endpoint in self.endpoints.values(): @@ -339,6 +334,9 @@ def update(self, node_data: MatterNodeData) -> None: def update_attribute(self, attribute_path: str, new_value: Any) -> None: """Handle Attribute value update.""" endpoint_id = int(attribute_path.split("/")[0]) + if endpoint_id not in self.endpoints: + # race condition when a bridge is in the process of adding a new endpoint + return self.endpoints[endpoint_id].set_attribute_value(attribute_path, new_value) def __repr__(self) -> str: diff --git a/matter_server/common/const.py b/matter_server/common/const.py index 6042368c..9b85927f 100644 --- a/matter_server/common/const.py +++ b/matter_server/common/const.py @@ -2,4 +2,4 @@ # schema version is used to determine compatibility between server and client # bump schema if we add new features and/or make other (breaking) changes -SCHEMA_VERSION = 3 +SCHEMA_VERSION = 4 diff --git a/matter_server/common/helpers/api.py b/matter_server/common/helpers/api.py index e7d5cb44..c13bf8fa 100644 --- a/matter_server/common/helpers/api.py +++ b/matter_server/common/helpers/api.py @@ -56,7 +56,7 @@ def parse_arguments( if strict: for key, value in args.items(): if key not in func_sig.parameters: - raise KeyError("Invalid parameter: '%s'" % key) + raise KeyError(f"Invalid parameter: '{key}'") # parse arguments to correct type for name, param in func_sig.parameters.items(): value = args.get(name) diff --git a/matter_server/common/helpers/json.py b/matter_server/common/helpers/json.py index 9ce22f06..e79d71a5 100644 --- a/matter_server/common/helpers/json.py +++ b/matter_server/common/helpers/json.py @@ -4,7 +4,6 @@ from dataclasses import is_dataclass from typing import Any -from chip.clusters.Attribute import ValueDecodeFailure from chip.clusters.Types import Nullable from chip.tlv import float32, uint import orjson @@ -22,8 +21,6 @@ def json_encoder_default(obj: Any) -> Any: """ if getattr(obj, "do_not_serialize", None): return None - if isinstance(obj, ValueDecodeFailure): - return None if isinstance(obj, (set, tuple)): return list(obj) if isinstance(obj, float32): @@ -40,7 +37,7 @@ def json_encoder_default(obj: Any) -> Any: return b64encode(obj).decode("utf-8") if isinstance(obj, Exception): return str(obj) - if type(obj) == type: + if type(obj) == type: # pylint: disable=unidiomatic-typecheck return f"{obj.__module__}.{obj.__qualname__}" raise TypeError diff --git a/matter_server/common/helpers/util.py b/matter_server/common/helpers/util.py index b62da6a3..5b2e1fc5 100644 --- a/matter_server/common/helpers/util.py +++ b/matter_server/common/helpers/util.py @@ -79,6 +79,7 @@ def parse_utc_timestamp(datetime_string: str) -> datetime: def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING) -> Any: """Try to parse a value from raw (json) data and type annotations.""" + # pylint: disable=too-many-return-statements,too-many-branches if isinstance(value_type, str): # this shouldn't happen, but just in case @@ -105,14 +106,14 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING) return dataclass_from_dict(value_type, value) # get origin value type and inspect one-by-one origin: Any = get_origin(value_type) - if origin in (list, tuple) and isinstance(value, list | tuple): + if origin in (list, tuple, set) and isinstance(value, (list, tuple, set)): return origin( parse_value(name, subvalue, get_args(value_type)[0]) for subvalue in value if subvalue is not None ) # handle dictionary where we should inspect all values - elif origin is dict: + if origin is dict: subkey_type = get_args(value_type)[0] subvalue_type = get_args(value_type)[1] return { @@ -122,7 +123,7 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING) for subkey, subvalue in value.items() } # handle Union type - elif origin is Union or origin is UnionType: + if origin is Union or origin is UnionType: # try all possible types sub_value_types = get_args(value_type) for sub_arg_type in sub_value_types: @@ -143,9 +144,9 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING) # raise exception, we have no idea how to handle this value raise TypeError(err) # failed to parse the (sub) value but None allowed, log only - logging.getLogger(__name__).warn(err) + logging.getLogger(__name__).warning(err) return None - elif origin is type: + if origin is type: return get_type_hints(value, globals(), locals()) # handle Any as value type (which is basically unprocessable) if value_type is Any: @@ -157,6 +158,7 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING) try: if issubclass(value_type, Enum): # handle enums from the SDK that have a value that does not exist in the enum (sigh) + # pylint: disable=protected-access if value not in value_type._value2member_map_: # we do not want to crash so we return the raw value return value @@ -208,11 +210,10 @@ def dataclass_from_dict(cls: type[_T], dict_obj: dict, strict: bool = False) -> If strict mode enabled, any additional keys in the provided dict will result in a KeyError. """ if strict: - extra_keys = dict_obj.keys() - set([f.name for f in fields(cls)]) + extra_keys = dict_obj.keys() - {f.name for f in fields(cls)} if extra_keys: raise KeyError( - "Extra key(s) %s not allowed for %s" - % (",".join(extra_keys), (str(cls))) + f'Extra key(s) {",".join(extra_keys)} not allowed for {str(cls)}' ) type_hints = get_type_hints(cls) return cls( diff --git a/matter_server/common/models.py b/matter_server/common/models.py index a8b719eb..baf701c6 100644 --- a/matter_server/common/models.py +++ b/matter_server/common/models.py @@ -14,10 +14,12 @@ class EventType(Enum): NODE_ADDED = "node_added" NODE_UPDATED = "node_updated" - NODE_DELETED = "node_deleted" + NODE_REMOVED = "node_removed" NODE_EVENT = "node_event" ATTRIBUTE_UPDATED = "attribute_updated" SERVER_SHUTDOWN = "server_shutdown" + ENDPOINT_ADDED = "endpoint_added" + ENDPOINT_REMOVED = "endpoint_removed" class APICommand(str, Enum): @@ -38,6 +40,7 @@ class APICommand(str, Enum): DEVICE_COMMAND = "device_command" REMOVE_NODE = "remove_node" GET_VENDOR_NAMES = "get_vendor_names" + SUBSCRIBE_ATTRIBUTE = "subscribe_attribute" EventCallBackType = Callable[[EventType, Any], None] @@ -66,8 +69,15 @@ class MatterNodeData: last_interview: datetime interview_version: int available: bool = False + is_bridge: bool = False # attributes are stored in form of AttributePath: ENDPOINT/CLUSTER_ID/ATTRIBUTE_ID attributes: dict[str, Any] = field(default_factory=dict) + # all attribute subscriptions we need to persist for this node, + # a set of tuples in format (endpoint_id, cluster_id, attribute_id) + # where each value can also be a '*' for wildcard + attribute_subscriptions: set[tuple[int | str, int | str, int | str]] = field( + default_factory=set + ) @dataclass diff --git a/matter_server/server/__main__.py b/matter_server/server/__main__.py index bcd16847..8c2706ce 100644 --- a/matter_server/server/__main__.py +++ b/matter_server/server/__main__.py @@ -70,6 +70,8 @@ def main() -> None: handlers = [logging.FileHandler(args.log_file)] logging.basicConfig(handlers=handlers, level=args.log_level.upper()) coloredlogs.install(level=args.log_level.upper()) + if not logging.getLogger().isEnabledFor(logging.DEBUG): + logging.getLogger("chip").setLevel(logging.WARNING) # make sure storage path exists if not os.path.isdir(args.storage_path): diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index a0cc2509..9c40b64e 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -7,13 +7,17 @@ from datetime import datetime from functools import partial import logging -import time -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Deque, Type, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, Deque, Iterable, Type, TypeVar, cast from chip.ChipDeviceCtrl import CommissionableNode from chip.clusters import Attribute, Objects as Clusters from chip.clusters.Attribute import ValueDecodeFailure -from chip.clusters.ClusterObjects import ALL_CLUSTERS, Cluster +from chip.clusters.ClusterObjects import ( + ALL_ATTRIBUTES, + ALL_CLUSTERS, + Cluster, + ClusterAttributeDescriptor, +) from chip.exceptions import ChipStackError from ..common.const import SCHEMA_VERSION @@ -29,6 +33,7 @@ create_attribute_path, create_attribute_path_from_attribute, dataclass_from_dict, + parse_attribute_path, ) from ..common.models import APICommand, EventType, MatterNodeData from .const import PAA_ROOT_CERTS_DIR @@ -45,7 +50,15 @@ DATA_KEY_LAST_NODE_ID = "last_node_id" LOGGER = logging.getLogger(__name__) -INTERVIEW_TASK_LIMIT = 10 +INTERVIEW_TASK_LIMIT = 5 + +# a list of attributes we should always watch on all nodes +DEFAULT_SUBSCRIBE_ATTRIBUTES: set[tuple[int | str, int | str, int | str]] = { + ("*", 0x001D, 0x00000000), # all endpoints, descriptor cluster, deviceTypeList + ("*", 0x001D, 0x00000003), # all endpoints, descriptor cluster, partsList + (0, 0x0028, "*"), # endpoint 0, BasicInformation cluster, all attributes + ("*", 0x0039, "*"), # BridgedDeviceBasicInformation +} class MatterDeviceController: @@ -62,11 +75,16 @@ def __init__( # we keep the last events in memory so we can include them in the diagnostics dump self.event_history: Deque[Attribute.EventReadResult] = deque(maxlen=25) self._subscriptions: dict[int, Attribute.SubscriptionTransaction] = {} + self._attr_subscriptions: dict[int, list[tuple[Any, ...]] | str] = {} + self._resub_timer: dict[int, asyncio.TimerHandle] = {} self._nodes: dict[int, MatterNodeData | None] = {} self.wifi_credentials_set: bool = False self.thread_credentials_set: bool = False self.compressed_fabric_id: int | None = None - self._interview_task: asyncio.Task | None = None + self._interview_limit: asyncio.Semaphore = asyncio.Semaphore( + INTERVIEW_TASK_LIMIT + ) + self._node_lock: dict[int, asyncio.Lock] = {} async def initialize(self) -> None: """Async initialize of controller.""" @@ -97,11 +115,9 @@ async def start(self) -> None: # always mark node as unavailable at startup until subscriptions are ready node.available = False self._nodes[node_id] = node - # setup subscriptions and (re)interviews as task in the background - # as we do not want it to block our startup - self._interview_task = asyncio.create_task( - self._check_subscriptions_and_interviews() - ) + # setup subscription and (re)interview as task in the background + # as we do not want it to block our startup + asyncio.create_task(self._check_interview_and_subscription(node_id)) LOGGER.debug("Loaded %s nodes", len(self._nodes)) async def stop(self) -> None: @@ -160,7 +176,7 @@ async def commission_with_code(self, code: str) -> MatterNodeData: # full interview of the device await self.interview_node(node_id) # make sure we start a subscription for this newly added node - await self.subscribe_node(node_id) + await self._subscribe_node(node_id) # return full node object once we're complete return self.get_node(node_id) @@ -205,7 +221,7 @@ async def commission_on_network( # full interview of the device await self.interview_node(node_id) # make sure we start a subscription for this newly added node - await self.subscribe_node(node_id) + await self._subscribe_node(node_id) # return full node object once we're complete return self.get_node(node_id) @@ -245,8 +261,7 @@ async def open_commissioning_window( option: int = 1, discriminator: int | None = None, ) -> tuple[int, str]: - """ - Open a commissioning window to commission a device present on this controller to another. + """Open a commissioning window to commission a device present on this controller to another. Returns code to use as discriminator. """ @@ -288,11 +303,16 @@ async def interview_node(self, node_id: int) -> None: LOGGER.debug("Interviewing node: %s", node_id) try: await self._resolve_node(node_id=node_id) - read_response: Attribute.AsyncReadTransaction.ReadResponse = ( - await self.chip_controller.Read( - nodeid=node_id, attributes="*", events="*", fabricFiltered=False - ) - ) + async with self._interview_limit: + async with self._get_node_lock(node_id): + read_response: Attribute.AsyncReadTransaction.ReadResponse = ( + await self.chip_controller.Read( + nodeid=node_id, + attributes="*", + events="*", + fabricFiltered=False, + ) + ) except (ChipStackError, NodeNotResolving) as err: raise NodeInterviewFailed(f"Failed to interview node {node_id}") from err @@ -310,6 +330,12 @@ async def interview_node(self, node_id: int) -> None: ), ) + if existing_info: + node.attribute_subscriptions = existing_info.attribute_subscriptions + # work out if the node is a bridge device by looking at the devicetype of endpoint 1 + if attr_data := node.attributes.get("1/29/0"): + node.is_bridge = any(x.deviceType == 14 for x in attr_data) + # save updated node data self._nodes[node_id] = node self.server.storage.set( @@ -323,6 +349,7 @@ async def interview_node(self, node_id: int) -> None: self.server.signal_event(EventType.NODE_ADDED, node) else: # existing node, signal node updated event + # TODO: maybe only signal this event if attributes actually changed ? self.server.signal_event(EventType.NODE_UPDATED, node) LOGGER.debug("Interview of node %s completed", node_id) @@ -346,14 +373,15 @@ async def send_device_command( cluster_cls: Cluster = ALL_CLUSTERS[cluster_id] command_cls = getattr(cluster_cls.Commands, command_name) command = dataclass_from_dict(command_cls, payload) - return await self.chip_controller.SendCommand( - nodeid=node_id, - endpoint=endpoint_id, - payload=command, - responseType=response_type, - timedRequestTimeoutMs=timed_request_timeout_ms, - interactionTimeoutMs=interaction_timeout_ms, - ) + async with self._get_node_lock(node_id): + return await self.chip_controller.SendCommand( + nodeid=node_id, + endpoint=endpoint_id, + payload=command, + responseType=response_type, + timedRequestTimeoutMs=timed_request_timeout_ms, + interactionTimeoutMs=interaction_timeout_ms, + ) @api_command(APICommand.REMOVE_NODE) async def remove_node(self, node_id: int) -> None: @@ -380,7 +408,7 @@ async def remove_node(self, node_id: int) -> None: ) fabric_index = node.attributes[attribute_path] - self.server.signal_event(EventType.NODE_DELETED, node_id) + self.server.signal_event(EventType.NODE_REMOVED, node_id) await self.chip_controller.SendCommand( nodeid=node_id, @@ -390,11 +418,16 @@ async def remove_node(self, node_id: int) -> None: ), ) - async def subscribe_node(self, node_id: int) -> None: + @api_command(APICommand.SUBSCRIBE_ATTRIBUTE) + async def subscribe_attribute( + self, node_id: int, attribute_path: str | list[str] + ) -> None: """ - Subscribe to all node state changes/events for an individual node. + Subscribe to given AttributePath(s). - Note that by using the listen command at server level, you will receive all node events. + Either supply a single attribute path or a list of paths. + The given attribute path(s) will be added to the list of attributes that + are watched for the given node. This is persistent over restarts. """ if self.chip_controller is None: raise RuntimeError("Device Controller not initialized.") @@ -403,27 +436,123 @@ async def subscribe_node(self, node_id: int) -> None: raise NodeNotExists( f"Node {node_id} does not exist or has not been interviewed." ) - assert node_id not in self._subscriptions, "Already subscribed to node" - node_logger = LOGGER.getChild(str(node_id)) - node_logger.debug("Setting up subscriptions...") + node = self._nodes[node_id] + assert node is not None + + # work out added subscriptions + if not isinstance(attribute_path, list): + attribute_path = [attribute_path] + attribute_paths = {parse_attribute_path(x) for x in attribute_path} + prev_subs = set(node.attribute_subscriptions) + node.attribute_subscriptions.update(attribute_paths) + if prev_subs == node.attribute_subscriptions: + return # nothing to do + # save updated node data + self.server.storage.set( + DATA_KEY_NODES, + subkey=str(node_id), + value=node, + ) + + # (re)setup node subscription + # this could potentially be called multiple times within a short timeframe + # so debounce it a bit + def resubscribe() -> None: + self._resub_timer.pop(node_id, None) + asyncio.create_task(self._subscribe_node(node_id)) + + if existing_timer := self._resub_timer.pop(node_id, None): + existing_timer.cancel() + assert self.server.loop is not None + self._resub_timer[node_id] = self.server.loop.call_later(5, resubscribe) + + async def _subscribe_node(self, node_id: int) -> None: + """ + Subscribe to all node state changes/events for an individual node. + + Note that by using the listen command at server level, + you will receive all (subscribed) node events and attribute updates. + """ + # pylint: disable=too-many-locals,too-many-statements + if self.chip_controller is None: + raise RuntimeError("Device Controller not initialized.") + + if self._nodes.get(node_id) is None: + raise NodeNotExists( + f"Node {node_id} does not exist or has not been interviewed." + ) + + node_logger = LOGGER.getChild(f"[node {node_id}]") + node_lock = self._get_node_lock(node_id) node = cast(MatterNodeData, self._nodes[node_id]) - node.available = False await self._resolve_node(node_id=node_id) - node.available = True - # we follow the pattern of apple and google here and - # just do a wildcard subscription for all clusters and properties - # the client will handle filtering of the events. - # if it turns out in the future that this is too much traffic (I don't think so now) - # we can revisit this choice and do some selected subscriptions. - sub: Attribute.SubscriptionTransaction = await self.chip_controller.Read( - nodeid=node_id, - attributes="*", - events=[("*", 1)], - reportInterval=(0, 120), - fabricFiltered=False, - ) + # work out all (current) attribute subscriptions + attr_subscriptions: list[Any] = [] + for ( + endpoint_id, + cluster_id, + attribute_id, + ) in set.union(DEFAULT_SUBSCRIBE_ATTRIBUTES, node.attribute_subscriptions): + endpoint: int | None = None if endpoint_id == "*" else int(endpoint_id) + cluster: Type[Cluster] = ALL_CLUSTERS[cluster_id] + attribute: Type[ClusterAttributeDescriptor] | None = ( + None + if attribute_id == "*" + else ALL_ATTRIBUTES[cluster_id][attribute_id] + ) + if endpoint and attribute: + # Concrete path: specific endpoint, specific clusterattribute + attr_subscriptions.append((endpoint, attribute)) + elif endpoint and cluster: + # Specific endpoint, Wildcard attribute id (specific cluster) + attr_subscriptions.append((endpoint, cluster)) + elif attribute: + # Wildcard endpoint, specific attribute + attr_subscriptions.append(attribute) + elif cluster: + # Wildcard endpoint, specific cluster + attr_subscriptions.append(cluster) + + if len(attr_subscriptions) > 50: + # prevent memory overload on node and fallback to wildcard sub if too many + # individual subscriptions + attr_subscriptions = "*" # type: ignore[assignment] + + # check if we already have an subscription for this node, + # if so, we need to unsubscribe first because a device can only maintain + # a very limited amount of concurrent subscriptions. + if prev_sub := self._subscriptions.pop(node_id, None): + if self._attr_subscriptions.get(node_id) == attr_subscriptions: + # the current subscription already matches, no need to re-setup + node_logger.debug("Re-using existing subscription.") + return + async with node_lock: + node_logger.debug("Unsubscribing from existing subscription.") + await self._call_sdk(prev_sub.Shutdown) + + node_logger.debug("Setting up attributes and events subscription.") + self._attr_subscriptions[node_id] = attr_subscriptions + async with node_lock: + sub: Attribute.SubscriptionTransaction = await self.chip_controller.Read( + nodeid=node_id, + # In order to prevent network congestion due to wildcard subscriptions on all nodes, + # we keep a list of attributes we are explicitly interested in. + attributes=attr_subscriptions, + # simply subscribe to all (urgent and non urgent) device events + events=[("*", 1), ("*", 0)], + # Use a report interval of 0, 300 which means we want to receive state changes + # as soon as possible (the 0 as floor) but we want to receive a report + # at least once every 5 minutes (300 as ceiling). + # This is also used to detect the node is still alive. + # A resubscription will be initiated automatically by the sdk + # if there was no report within the interval. + reportInterval=(0, 300), + # Use fabricfiltered as False to detect changes made by other controllers + # and to be able to provide a list of all fabrics attached to the device + fabricFiltered=False, + ) def attribute_updated_callback( path: Attribute.TypedAttributePath, @@ -435,8 +564,36 @@ def attribute_updated_callback( # these are set by the SDK if parsing the value failed miserably if isinstance(new_value, ValueDecodeFailure): return - node_logger.debug("Attribute updated: %s - new value: %s", path, new_value) + attr_path = str(path.Path) + old_value = node.attributes.get(attr_path) + + node_logger.debug( + "Attribute updated: %s - old value: %s - new value: %s", + path, + old_value, + new_value, + ) + + # work out added/removed endpoints on bridges + if ( + node.is_bridge + and path.Path.EndpointId == 0 + and path.AttributeType == Clusters.Descriptor.Attributes.PartsList + ): + endpoints_removed = set(old_value or []) - set(new_value) + endpoints_added = set(new_value) - set(old_value or []) + if endpoints_removed: + self.server.loop.call_soon_threadsafe( + self._handle_endpoints_removed, node_id, endpoints_removed + ) + if endpoints_added: + self.server.loop.create_task( + self._handle_endpoints_added(node_id, endpoints_added) + ) + return + + # store updated value in node attributes node.attributes[attr_path] = new_value # schedule save to persistent storage @@ -460,7 +617,9 @@ def event_callback( ) -> None: # pylint: disable=unused-argument assert self.server.loop is not None - node_logger.debug("Received node event: %s", data) + node_logger.debug( + "Received node event: %s - transaction: %s", data, transaction + ) self.event_history.append(data) self.server.loop.call_soon_threadsafe( self.server.signal_event, EventType.NODE_EVENT, data @@ -478,7 +637,7 @@ def resubscription_attempted( nextResubscribeIntervalMsec: int, ) -> None: # pylint: disable=unused-argument, invalid-name - node_logger.debug( + node_logger.info( "Previous subscription failed with Error: %s, re-subscribing in %s ms...", terminationError, nextResubscribeIntervalMsec, @@ -492,7 +651,7 @@ def resubscription_succeeded( transaction: Attribute.SubscriptionTransaction, ) -> None: # pylint: disable=unused-argument, invalid-name - node_logger.debug("Re-Subscription succeeded") + node_logger.info("Re-Subscription succeeded") # mark node as available and signal consumers if not node.available: node.available = True @@ -504,10 +663,14 @@ def resubscription_succeeded( sub.SetResubscriptionAttemptedCallback(resubscription_attempted) sub.SetResubscriptionSucceededCallback(resubscription_succeeded) self._subscriptions[node_id] = sub + # if we reach this point, it means the node could be resolved # and the initial subscription succeeded, mark the node available. node.available = True - node_logger.debug("Subscription succeeded") + node_logger.info("Subscription succeeded") + # update attributes with current state from read request + current_atributes = self._parse_attributes_from_read_result(sub.GetAttributes()) + node.attributes.update(current_atributes) self.server.signal_event(EventType.NODE_UPDATED, node) def _get_next_node_id(self) -> int: @@ -529,88 +692,57 @@ async def _call_sdk(self, func: Callable[..., _T], *args: Any, **kwargs: Any) -> ), ) - async def _check_subscriptions_and_interviews(self) -> None: - """Run subscriptions (and interviews) for known nodes.""" - # Set default resubscribe interval to 1 hour - reschedule_interval = 3600 - start_time = time.time() - tasks: list[Coroutine[Any, Any, None]] = [] - task_limit: asyncio.Semaphore = asyncio.Semaphore(INTERVIEW_TASK_LIMIT) - - for node_id, node in self._nodes.items(): - # (re)interview node (only) if needed - if ( - node is None - or node.interview_version < SCHEMA_VERSION - or (datetime.utcnow() - node.last_interview).days > 30 - ): + async def _check_interview_and_subscription( + self, node_id: int, reschedule_interval: int = 300 + ) -> None: + """Handle interview (if needed) and subscription for known node.""" - async def _interview_node(node_id: int) -> None: - """Run interview for node.""" - try: - await self.interview_node(node_id) - except NodeInterviewFailed as err: - LOGGER.warning( - "Unable to interview Node %s, we will retry later in the background.", - node_id, - exc_info=err, - ) - raise err - - tasks.append(_interview_node(node_id)) - continue - - # setup subscriptions for the node - if node_id in self._subscriptions: - continue - - async def _subscribe_node(node_id: int) -> None: - """Subscribe to node events.""" - try: - await self.subscribe_node(node_id) - except NodeNotResolving as err: - LOGGER.warning( - "Unable to subscribe to Node %s, " - "we will retry later in the background.", - node_id, - exc_info=err, - ) - raise err + def reschedule() -> None: + """(Re)Schedule interview and/or initial subscription for a node.""" + assert self.server.loop is not None + self.server.loop.call_later( + reschedule_interval, + asyncio.create_task, + self._check_interview_and_subscription( + node_id, + # increase interval at each attempt with maximum of 1 hour + min(reschedule_interval + 300, 3600), + ), + ) - tasks.append(_subscribe_node(node_id)) + # (re)interview node (only) if needed + node_data = self._nodes.get(node_id) + if ( + node_data is None + # re-interview if the schema has changed + or node_data.interview_version < SCHEMA_VERSION + # re-interview every 30 days + or (datetime.utcnow() - node_data.last_interview).days > 30 + ): + try: + await self.interview_node(node_id) + except NodeInterviewFailed: + LOGGER.warning( + "Unable to interview Node %s, will retry later in the background.", + node_id, + ) + # reschedule self on error + reschedule() + return - async def _run_task(task: Coroutine[Any, Any, None]) -> None: - """Run coroutine and release semaphore.""" - async with task_limit: - await task + # setup subscriptions for the node + if node_id in self._subscriptions: + return - LOGGER.debug("Running %s tasks", len(tasks)) - # wait for all tasks to finish - results: list[Exception | None] = await asyncio.gather( - *(_run_task(task) for task in tasks), return_exceptions=True - ) - LOGGER.debug( - "Done running %s tasks in %s seconds", - len(results), - start_time - time.time(), - ) - # check if any of the tasks failed - for result in results: - if isinstance(result, Exception): - # if any of the tasks failed, reschedule in 5 minutes - reschedule_interval = 300 - break - - # reschedule self to run every hour - def _schedule() -> None: - """Schedule task.""" - self._interview_task = asyncio.create_task( - self._check_subscriptions_and_interviews() + try: + await self._subscribe_node(node_id) + except NodeNotResolving: + LOGGER.warning( + "Unable to subscribe to Node %s as it is unavailable, " + "will retry later in the background.", + node_id, ) - - LOGGER.debug("Rescheduling interviews in %s seconds", reschedule_interval) - loop = cast(asyncio.AbstractEventLoop, self.server.loop) - loop.call_later(reschedule_interval, _schedule) + reschedule() @staticmethod def _parse_attributes_from_read_result( @@ -658,6 +790,7 @@ async def _resolve_node( self, node_id: int, retries: int = 3, allow_pase: bool = False ) -> None: """Resolve a Node on the network.""" + node_lock = self._get_node_lock(node_id) if self.chip_controller is None: raise RuntimeError("Device Controller not initialized.") try: @@ -666,11 +799,13 @@ async def _resolve_node( LOGGER.debug( "Attempting to resolve node %s (with PASE connection)", node_id ) - await self._call_sdk( - self.chip_controller.GetConnectedDeviceSync, - nodeid=node_id, - allowPASE=True, - ) + async with node_lock: + await self._call_sdk( + self.chip_controller.GetConnectedDeviceSync, + nodeid=node_id, + allowPASE=True, + timeoutMs=30000, + ) return LOGGER.debug("Resolving node %s", node_id) await self._call_sdk(self.chip_controller.ResolveNode, nodeid=node_id) @@ -678,7 +813,47 @@ async def _resolve_node( if not retries: # when we're out of retries, raise NodeNotResolving raise NodeNotResolving(f"Unable to resolve Node {node_id}") from err - await self._resolve_node( - node_id=node_id, retries=retries - 1, allow_pase=retries - 1 == 0 - ) + async with node_lock: + await self._resolve_node( + node_id=node_id, retries=retries - 1, allow_pase=retries - 1 == 0 + ) await asyncio.sleep(2) + + def _handle_endpoints_removed(self, node_id: int, endpoints: Iterable[int]) -> None: + """Handle callback for when bridge endpoint(s) get deleted.""" + node = cast(MatterNodeData, self._nodes[node_id]) + for endpoint_id in endpoints: + node.attributes = { + key: value + for key, value in node.attributes.items() + if not key.startswith(f"{endpoint_id}/") + } + self.server.signal_event( + EventType.ENDPOINT_REMOVED, + {"node_id": node_id, "endpoint_id": endpoint_id}, + ) + # schedule save to persistent storage + self.server.storage.set( + DATA_KEY_NODES, + subkey=str(node_id), + value=node, + ) + + async def _handle_endpoints_added( + self, node_id: int, endpoints: Iterable[int] + ) -> None: + """Handle callback for when bridge endpoint(s) get added.""" + # we simply do a full interview of the node + await self.interview_node(node_id) + # signal event to consumers + for endpoint_id in endpoints: + self.server.signal_event( + EventType.ENDPOINT_ADDED, + {"node_id": node_id, "endpoint_id": endpoint_id}, + ) + + def _get_node_lock(self, node_id: int) -> asyncio.Lock: + """Return lock for given node.""" + if node_id not in self._node_lock: + self._node_lock[node_id] = asyncio.Lock() + return self._node_lock[node_id] diff --git a/matter_server/server/stack.py b/matter_server/server/stack.py index 99ca91de..7a356dd5 100644 --- a/matter_server/server/stack.py +++ b/matter_server/server/stack.py @@ -30,7 +30,9 @@ def __init__( chip.logging.RedirectToPythonLogging() self._chip_stack = ChipStack( - persistentStoragePath=storage_file, enableServerInteractions=False + persistentStoragePath=storage_file, + installDefaultLogHandler=False, + enableServerInteractions=False, ) # Initialize Certificate Authority Manager diff --git a/matter_server/server/storage.py b/matter_server/server/storage.py index 31f0c97a..f454819f 100644 --- a/matter_server/server/storage.py +++ b/matter_server/server/storage.py @@ -123,16 +123,21 @@ def save(self, immediate: bool = False) -> None: """Schedule save of data to disk.""" assert self.server.loop is not None + def _do_save() -> None: + assert self.server.loop is not None + self.server.loop.create_task(self.async_save()) + if self._timer_handle is not None: self._timer_handle.cancel() self._timer_handle = None if immediate: - self.server.loop.create_task(self.async_save()) + _do_save() + else: # schedule the save for later self._timer_handle = self.server.loop.call_later( - DEFAULT_SAVE_DELAY, self.server.loop.create_task, self.async_save() + DEFAULT_SAVE_DELAY, _do_save ) async def async_save(self) -> None: