Skip to content

Commit

Permalink
attribute_paths as set
Browse files Browse the repository at this point in the history
  • Loading branch information
marcelveldt committed Jun 26, 2023
1 parent 3dfb1fc commit cd6ff8e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion matter_server/common/helpers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING)
return dataclass_from_dict(value_type, value)
# get origin value type and inspect one-by-one
origin: Any = get_origin(value_type)
if origin in (list, tuple, set) and isinstance(value, list | tuple | set):
if origin in (list, tuple, set) and isinstance(value, (list, tuple, set)):
return origin(
parse_value(name, subvalue, get_args(value_type)[0])
for subvalue in value
Expand Down
2 changes: 1 addition & 1 deletion matter_server/common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class MatterNodeData:
# a set of tuples in format (endpoint_id, cluster_id, attribute_id)
# where each value can also be a '*' for wildcard
attribute_subscriptions: set[tuple[int | str, int | str, int | str]] = field(
default_factory=list
default_factory=set
)


Expand Down
11 changes: 5 additions & 6 deletions matter_server/server/device_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ..common.helpers.util import (
create_attribute_path,
create_attribute_path_from_attribute,
parse_attribute_path,
dataclass_from_dict,
)
from ..common.models import APICommand, EventType, MatterNodeData
Expand Down Expand Up @@ -438,11 +439,9 @@ async def subscribe_attribute(
node = self._nodes[node_id]

# work out added subscriptions
attribute_paths = (
set(attribute_path)
if isinstance(attribute_path, list)
else {attribute_path}
)
if not isinstance(attribute_path, list):
attribute_path = [attribute_path]
attribute_paths = {parse_attribute_path(x) for x in attribute_path}
if not node.attribute_subscriptions.difference(attribute_paths):
return # nothing to do
node.attribute_subscriptions.update(attribute_paths)
Expand Down Expand Up @@ -490,7 +489,7 @@ async def _subscribe_node(self, node_id: int) -> None:
endpoint_id,
cluster_id,
attribute_id,
) in set.union(DEFAULT_SUBSCRIBE_ATTRIBUTES, node.attribute_subscriptions):
) in set.union(DEFAULT_SUBSCRIBE_ATTRIBUTES, *node.attribute_subscriptions):
endpoint: int | None = None if endpoint_id == "*" else endpoint_id
cluster: Type[Cluster] = ALL_CLUSTERS[cluster_id]
attribute: Type[ClusterAttributeDescriptor] | None = (
Expand Down

0 comments on commit cd6ff8e

Please sign in to comment.