From cd6ff8e34117f6e35b1a1edf964a9427bb389fad Mon Sep 17 00:00:00 2001 From: Marcel van der Veldt Date: Mon, 26 Jun 2023 20:26:29 +0200 Subject: [PATCH] attribute_paths as set --- matter_server/common/helpers/util.py | 2 +- matter_server/common/models.py | 2 +- matter_server/server/device_controller.py | 11 +++++------ 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/matter_server/common/helpers/util.py b/matter_server/common/helpers/util.py index 39e24719..e87ed238 100644 --- a/matter_server/common/helpers/util.py +++ b/matter_server/common/helpers/util.py @@ -105,7 +105,7 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING) return dataclass_from_dict(value_type, value) # get origin value type and inspect one-by-one origin: Any = get_origin(value_type) - if origin in (list, tuple, set) and isinstance(value, list | tuple | set): + if origin in (list, tuple, set) and isinstance(value, (list, tuple, set)): return origin( parse_value(name, subvalue, get_args(value_type)[0]) for subvalue in value diff --git a/matter_server/common/models.py b/matter_server/common/models.py index 4d7e1baa..baf701c6 100644 --- a/matter_server/common/models.py +++ b/matter_server/common/models.py @@ -76,7 +76,7 @@ class MatterNodeData: # a set of tuples in format (endpoint_id, cluster_id, attribute_id) # where each value can also be a '*' for wildcard attribute_subscriptions: set[tuple[int | str, int | str, int | str]] = field( - default_factory=list + default_factory=set ) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index e23cc897..550b2f59 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -32,6 +32,7 @@ from ..common.helpers.util import ( create_attribute_path, create_attribute_path_from_attribute, + parse_attribute_path, dataclass_from_dict, ) from ..common.models import APICommand, EventType, MatterNodeData @@ -438,11 +439,9 @@ async def subscribe_attribute( node = self._nodes[node_id] # work out added subscriptions - attribute_paths = ( - set(attribute_path) - if isinstance(attribute_path, list) - else {attribute_path} - ) + if not isinstance(attribute_path, list): + attribute_path = [attribute_path] + attribute_paths = {parse_attribute_path(x) for x in attribute_path} if not node.attribute_subscriptions.difference(attribute_paths): return # nothing to do node.attribute_subscriptions.update(attribute_paths) @@ -490,7 +489,7 @@ async def _subscribe_node(self, node_id: int) -> None: endpoint_id, cluster_id, attribute_id, - ) in set.union(DEFAULT_SUBSCRIBE_ATTRIBUTES, node.attribute_subscriptions): + ) in set.union(DEFAULT_SUBSCRIBE_ATTRIBUTES, *node.attribute_subscriptions): endpoint: int | None = None if endpoint_id == "*" else endpoint_id cluster: Type[Cluster] = ALL_CLUSTERS[cluster_id] attribute: Type[ClusterAttributeDescriptor] | None = (