Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor subscription logic #335

Merged
merged 33 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7a29ba8
run initial interview and subscription logic in parallel
marcelveldt Jun 22, 2023
c68912d
bump schema version
marcelveldt Jun 22, 2023
fefba62
Only setup explicit subscriptions to devices
marcelveldt Jun 26, 2023
aceb711
use list instead of set
marcelveldt Jun 26, 2023
0903726
silence logging a bit
marcelveldt Jun 26, 2023
869b7ce
change default log handler
marcelveldt Jun 26, 2023
3b552a5
change save logic to prevent warning printed in logs
marcelveldt Jun 26, 2023
2b6de3d
add logic to detect endpoint changes on bridges
marcelveldt Jun 26, 2023
aff6b74
some renaming and cleanup
marcelveldt Jun 26, 2023
96e4ea2
Merge branch 'main' into change-subscriptions
marcelveldt Jun 26, 2023
20ee47c
silence logger some more
marcelveldt Jun 26, 2023
26eae82
Add locking to prevent concurrent access of same node
marcelveldt Jun 26, 2023
73302b7
change node logger
marcelveldt Jun 26, 2023
21ee14b
improve node lock
marcelveldt Jun 26, 2023
3dfb1fc
address some review feedback
marcelveldt Jun 26, 2023
cd6ff8e
attribute_paths as set
marcelveldt Jun 26, 2023
1c16fc6
some more review feedback
marcelveldt Jun 26, 2023
8f5ee31
some linting
marcelveldt Jun 26, 2023
2205e93
change wildcard
marcelveldt Jun 27, 2023
e030d67
Update matter_server/server/device_controller.py
marcelveldt Jun 27, 2023
3e9dc1c
some typos
marcelveldt Jun 27, 2023
87da481
comment
marcelveldt Jun 27, 2023
6818b52
prevent subscriptions of non existing nodes
marcelveldt Jun 27, 2023
2fda946
debounce resubscriptions
marcelveldt Jun 27, 2023
df713be
update attributes with current state from read request
marcelveldt Jun 27, 2023
f60a045
lint
marcelveldt Jun 27, 2023
455e3a0
some linting
marcelveldt Jun 27, 2023
eec2eb5
more lint
marcelveldt Jun 27, 2023
bb78873
more linting
marcelveldt Jun 27, 2023
e0aa83d
Update matter_server/common/helpers/util.py
marcelveldt Jun 27, 2023
f11b5fe
Update matter_server/common/helpers/util.py
marcelveldt Jun 27, 2023
54f6ba7
Neighbor Discovery Protocol
marcelveldt Jun 27, 2023
6fd35c8
Merge branch 'change-subscriptions' of https://github.com/home-assist…
marcelveldt Jun 27, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ relies on the networking managed by your operating system.
Make sure your you run the container on the host network. The host network
interface needs to be in the same network as the Android/iPhone device
you are using for commissioning. Matter uses link-local multicast protocols
which do not work accross different LANs or VLANs.
which do not work across different LANs or VLANs.

The host network interface needs IPv6 support enabled.

Expand Down
42 changes: 36 additions & 6 deletions matter_server/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def subscribe(
Optionally filter by specific events or node attributes.
Returns:
function to unsubscribe.

NOTE: To receive attribute changed events,
you must also register the attributes to subscribe to
with the `subscribe_attributes` method.
"""
# for fast lookups we create a key based on the filters, allowing
# a "catch all" with a wildcard (*).
Expand Down Expand Up @@ -94,11 +98,11 @@ def unsubscribe() -> None:

return unsubscribe

async def get_nodes(self) -> list[MatterNode]:
def get_nodes(self) -> list[MatterNode]:
marcelveldt marked this conversation as resolved.
Show resolved Hide resolved
"""Return all Matter nodes."""
return list(self._nodes.values())

async def get_node(self, node_id: int) -> MatterNode:
def get_node(self, node_id: int) -> MatterNode:
"""Return Matter node by id."""
return self._nodes[node_id]

Expand Down Expand Up @@ -164,7 +168,7 @@ async def get_matter_fabrics(self, node_id: int) -> list[MatterFabricData]:
Returns a list of MatterFabricData objects.
"""

node = await self.get_node(node_id)
node = self.get_node(node_id)
fabrics: list[
Clusters.OperationalCredentials.Structs.FabricDescriptor
] = node.get_attribute_value(
Expand Down Expand Up @@ -229,6 +233,22 @@ async def remove_node(self, node_id: int) -> None:
"""Remove a Matter node/device from the fabric."""
await self.send_command(APICommand.REMOVE_NODE, node_id=node_id)

async def subscribe_attribute(
self, node_id: int, attribute_path: str | list[str]
) -> None:
"""
Subscribe to given AttributePath(s).

Either supply a single attribute path or a list of paths.
The given attribute path(s) will be added to the list of attributes that
are watched for the given node. This is persistent over restarts.
"""
await self.send_command(
APICommand.SUBSCRIBE_ATTRIBUTE,
node_id=node_id,
attribute_path=attribute_path,
)

async def send_command(
self,
command: str,
Expand Down Expand Up @@ -366,7 +386,7 @@ def _handle_incoming_message(self, msg: MessageType) -> None:

# handle EventMessage
if isinstance(msg, EventMessage):
self.logger.debug("Received event: %s", msg)
self.logger.debug("Received event: %s", msg.event)
self._handle_event_message(msg)
return

Expand All @@ -392,10 +412,20 @@ def _handle_event_message(self, msg: EventMessage) -> None:
node.update(node_data)
self._signal_event(event, data=node, node_id=node.node_id)
return
if msg.event == EventType.NODE_DELETED:
if msg.event == EventType.NODE_REMOVED:
node_id = msg.data
self._signal_event(EventType.NODE_REMOVED, data=node_id, node_id=node_id)
# cleanup node only after signalling subscribers
self._nodes.pop(node_id, None)
self._signal_event(EventType.NODE_DELETED, data=node_id, node_id=node_id)
return
if msg.event == EventType.ENDPOINT_REMOVED:
node_id = msg.data["node_id"]
self._signal_event(
EventType.ENDPOINT_REMOVED, data=msg.data, node_id=node_id
)
# cleanup endpoint only after signalling subscribers
if node := self._nodes.get(node_id):
node.endpoints.pop(msg.data["endpoint_id"], None)
return
if msg.event == EventType.ATTRIBUTE_UPDATED:
# data is tuple[node_id, attribute_path, new_value]
Expand Down
12 changes: 10 additions & 2 deletions matter_server/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import logging
import os
import pprint
from typing import Any, Callable, Dict, Final, cast

Expand Down Expand Up @@ -33,6 +34,7 @@
from .models.node import MatterNode

LOGGER = logging.getLogger(f"{__package__}.connection")
VERBOSE_LOGGER = os.environ.get("MATTER_VERBOSE_LOGGING")
SUB_WILDCARD: Final = "*"


Expand Down Expand Up @@ -129,7 +131,10 @@ async def receive_message_or_raise(self) -> MessageType:
raise InvalidMessage("Received invalid JSON.") from err

if LOGGER.isEnabledFor(logging.DEBUG):
LOGGER.debug("Received message:\n%s\n", pprint.pformat(ws_msg))
if VERBOSE_LOGGER:
LOGGER.debug("Received message:\n%s\n", pprint.pformat(ws_msg))
else:
LOGGER.debug("Received message: %s ...", ws_msg.data[:50])

return msg

Expand All @@ -143,7 +148,10 @@ async def send_message(self, message: CommandMessage) -> None:
raise NotConnected

if LOGGER.isEnabledFor(logging.DEBUG):
LOGGER.debug("Publishing message:\n%s\n", pprint.pformat(message))
if VERBOSE_LOGGER:
LOGGER.debug("Publishing message:\n%s\n", pprint.pformat(message))
else:
LOGGER.debug("Publishing message: %s", message)

assert self._ws_client
assert isinstance(message, CommandMessage)
Expand Down
10 changes: 4 additions & 6 deletions matter_server/client/models/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ class MatterNode:
def __init__(self, node_data: MatterNodeData) -> None:
"""Initialize MatterNode from MatterNodeData."""
self.endpoints: dict[int, MatterEndpoint] = {}
self._is_bridge_device: bool = False
# composed devices reference to other endpoints through the partsList attribute
# create a mapping table
self._composed_endpoints: dict[int, int] = {}
Expand Down Expand Up @@ -251,7 +250,7 @@ def device_info(self) -> Clusters.BasicInformation:
@property
def is_bridge_device(self) -> bool:
"""Return if this Node is a Bridge/Aggregator device."""
return self._is_bridge_device
return self.node_data.is_bridge

def get_attribute_value(
self,
Expand Down Expand Up @@ -310,10 +309,6 @@ def update(self, node_data: MatterNodeData) -> None:
self.endpoints[endpoint_id] = MatterEndpoint(
endpoint_id=endpoint_id, attributes_data=attributes_data, node=self
)
# lookup if this is a bridge device
self._is_bridge_device = any(
Aggregator in x.device_types for x in self.endpoints.values()
)
# composed devices reference to other endpoints through the partsList attribute
# create a mapping table to quickly map this
for endpoint in self.endpoints.values():
Expand All @@ -339,6 +334,9 @@ def update(self, node_data: MatterNodeData) -> None:
def update_attribute(self, attribute_path: str, new_value: Any) -> None:
"""Handle Attribute value update."""
endpoint_id = int(attribute_path.split("/")[0])
if endpoint_id not in self.endpoints:
# race condition when a bridge is in the process of adding a new endpoint
return
self.endpoints[endpoint_id].set_attribute_value(attribute_path, new_value)

def __repr__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion matter_server/common/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

# schema version is used to determine compatibility between server and client
# bump schema if we add new features and/or make other (breaking) changes
SCHEMA_VERSION = 3
SCHEMA_VERSION = 4
2 changes: 1 addition & 1 deletion matter_server/common/helpers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def parse_arguments(
if strict:
for key, value in args.items():
if key not in func_sig.parameters:
raise KeyError("Invalid parameter: '%s'" % key)
raise KeyError(f"Invalid parameter: '{key}'")
# parse arguments to correct type
for name, param in func_sig.parameters.items():
value = args.get(name)
Expand Down
5 changes: 1 addition & 4 deletions matter_server/common/helpers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dataclasses import is_dataclass
from typing import Any

from chip.clusters.Attribute import ValueDecodeFailure
from chip.clusters.Types import Nullable
from chip.tlv import float32, uint
import orjson
Expand All @@ -22,8 +21,6 @@ def json_encoder_default(obj: Any) -> Any:
"""
if getattr(obj, "do_not_serialize", None):
return None
if isinstance(obj, ValueDecodeFailure):
marcelveldt marked this conversation as resolved.
Show resolved Hide resolved
return None
if isinstance(obj, (set, tuple)):
return list(obj)
if isinstance(obj, float32):
Expand All @@ -40,7 +37,7 @@ def json_encoder_default(obj: Any) -> Any:
return b64encode(obj).decode("utf-8")
if isinstance(obj, Exception):
return str(obj)
if type(obj) == type:
if type(obj) == type: # pylint: disable=unidiomatic-typecheck
return f"{obj.__module__}.{obj.__qualname__}"

raise TypeError
Expand Down
17 changes: 9 additions & 8 deletions matter_server/common/helpers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def parse_utc_timestamp(datetime_string: str) -> datetime:

def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING) -> Any:
"""Try to parse a value from raw (json) data and type annotations."""
# pylint: disable=too-many-return-statements,too-many-branches

if isinstance(value_type, str):
# this shouldn't happen, but just in case
Expand All @@ -105,14 +106,14 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING)
return dataclass_from_dict(value_type, value)
# get origin value type and inspect one-by-one
origin: Any = get_origin(value_type)
if origin in (list, tuple) and isinstance(value, list | tuple):
if origin in (list, tuple, set) and isinstance(value, (list, tuple, set)):
return origin(
parse_value(name, subvalue, get_args(value_type)[0])
for subvalue in value
if subvalue is not None
)
# handle dictionary where we should inspect all values
elif origin is dict:
if origin is dict:
subkey_type = get_args(value_type)[0]
subvalue_type = get_args(value_type)[1]
return {
Expand All @@ -122,7 +123,7 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING)
for subkey, subvalue in value.items()
}
# handle Union type
elif origin is Union or origin is UnionType:
if origin is Union or origin is UnionType:
# try all possible types
sub_value_types = get_args(value_type)
for sub_arg_type in sub_value_types:
Expand All @@ -143,9 +144,9 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING)
# raise exception, we have no idea how to handle this value
raise TypeError(err)
# failed to parse the (sub) value but None allowed, log only
logging.getLogger(__name__).warn(err)
logging.getLogger(__name__).warning(err)
return None
elif origin is type:
if origin is type:
return get_type_hints(value, globals(), locals())
# handle Any as value type (which is basically unprocessable)
if value_type is Any:
Expand All @@ -157,6 +158,7 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING)
try:
if issubclass(value_type, Enum):
# handle enums from the SDK that have a value that does not exist in the enum (sigh)
# pylint: disable=protected-access
if value not in value_type._value2member_map_:
# we do not want to crash so we return the raw value
return value
Expand Down Expand Up @@ -208,11 +210,10 @@ def dataclass_from_dict(cls: type[_T], dict_obj: dict, strict: bool = False) ->
If strict mode enabled, any additional keys in the provided dict will result in a KeyError.
"""
if strict:
extra_keys = dict_obj.keys() - set([f.name for f in fields(cls)])
extra_keys = dict_obj.keys() - {f.name for f in fields(cls)}
if extra_keys:
raise KeyError(
"Extra key(s) %s not allowed for %s"
% (",".join(extra_keys), (str(cls)))
f'Extra key(s) {",".join(extra_keys)} not allowed for {str(cls)}'
)
type_hints = get_type_hints(cls)
return cls(
Expand Down
12 changes: 11 additions & 1 deletion matter_server/common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ class EventType(Enum):

NODE_ADDED = "node_added"
NODE_UPDATED = "node_updated"
NODE_DELETED = "node_deleted"
NODE_REMOVED = "node_removed"
NODE_EVENT = "node_event"
ATTRIBUTE_UPDATED = "attribute_updated"
SERVER_SHUTDOWN = "server_shutdown"
ENDPOINT_ADDED = "endpoint_added"
ENDPOINT_REMOVED = "endpoint_removed"


class APICommand(str, Enum):
Expand All @@ -38,6 +40,7 @@ class APICommand(str, Enum):
DEVICE_COMMAND = "device_command"
REMOVE_NODE = "remove_node"
GET_VENDOR_NAMES = "get_vendor_names"
SUBSCRIBE_ATTRIBUTE = "subscribe_attribute"


EventCallBackType = Callable[[EventType, Any], None]
Expand Down Expand Up @@ -66,8 +69,15 @@ class MatterNodeData:
last_interview: datetime
interview_version: int
available: bool = False
is_bridge: bool = False
# attributes are stored in form of AttributePath: ENDPOINT/CLUSTER_ID/ATTRIBUTE_ID
attributes: dict[str, Any] = field(default_factory=dict)
# all attribute subscriptions we need to persist for this node,
# a set of tuples in format (endpoint_id, cluster_id, attribute_id)
# where each value can also be a '*' for wildcard
attribute_subscriptions: set[tuple[int | str, int | str, int | str]] = field(
default_factory=set
)


@dataclass
Expand Down
2 changes: 2 additions & 0 deletions matter_server/server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def main() -> None:
handlers = [logging.FileHandler(args.log_file)]
logging.basicConfig(handlers=handlers, level=args.log_level.upper())
coloredlogs.install(level=args.log_level.upper())
if not logging.getLogger().isEnabledFor(logging.DEBUG):
logging.getLogger("chip").setLevel(logging.WARNING)

# make sure storage path exists
if not os.path.isdir(args.storage_path):
Expand Down
Loading