Skip to content

Commit

Permalink
Refactor subscription logic (#335)
Browse files Browse the repository at this point in the history
Refactor of the subscription logic, both for optimization reasons and to reduce traffic (preventing congestion).

Run interviews and subscriptions in parallel tasks to speed up startup.
Only subscribe to specific attributes
Detect endpoint additions and removals on bridges


Co-authored-by: Martin Hjelmare <[email protected]>
  • Loading branch information
marcelveldt and MartinHjelmare committed Jun 27, 2023
1 parent 3921c04 commit beb3d64
Show file tree
Hide file tree
Showing 13 changed files with 400 additions and 172 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ relies on the networking managed by your operating system.
Make sure you run the container on the host network. The host network
interface needs to be in the same network as the Android/iPhone device
you are using for commissioning. Matter uses link-local multicast protocols
which do not work accross different LANs or VLANs.
which do not work across different LANs or VLANs.

The host network interface needs IPv6 support enabled.

Expand All @@ -54,7 +54,7 @@ For communication through Thread border routers which are not running on the sam
host as the Matter Controller server to work, IPv6 routing needs to be properly
working. IPv6 routing is largely setup automatically through the IPv6 Neighbor
Discovery Protocol, specifically the Route Information Options (RIO). However,
if IPv6 ND RIO's are processed, and processed correctly depends on the network
if IPv6 Neighbor Discovery RIO's are processed, and processed correctly depends on the network
management software your system is using. There may be bugs and caveats in
processing these Route Information Options.

Expand Down
42 changes: 36 additions & 6 deletions matter_server/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def subscribe(
Optionally filter by specific events or node attributes.
Returns:
function to unsubscribe.
NOTE: To receive attribute changed events,
you must also register the attributes to subscribe to
with the `subscribe_attributes` method.
"""
# for fast lookups we create a key based on the filters, allowing
# a "catch all" with a wildcard (*).
Expand Down Expand Up @@ -94,11 +98,11 @@ def unsubscribe() -> None:

return unsubscribe

async def get_nodes(self) -> list[MatterNode]:
def get_nodes(self) -> list[MatterNode]:
"""Return all Matter nodes."""
return list(self._nodes.values())

async def get_node(self, node_id: int) -> MatterNode:
def get_node(self, node_id: int) -> MatterNode:
"""Return Matter node by id."""
return self._nodes[node_id]

Expand Down Expand Up @@ -164,7 +168,7 @@ async def get_matter_fabrics(self, node_id: int) -> list[MatterFabricData]:
Returns a list of MatterFabricData objects.
"""

node = await self.get_node(node_id)
node = self.get_node(node_id)
fabrics: list[
Clusters.OperationalCredentials.Structs.FabricDescriptor
] = node.get_attribute_value(
Expand Down Expand Up @@ -229,6 +233,22 @@ async def remove_node(self, node_id: int) -> None:
"""Remove a Matter node/device from the fabric."""
await self.send_command(APICommand.REMOVE_NODE, node_id=node_id)

async def subscribe_attribute(
self, node_id: int, attribute_path: str | list[str]
) -> None:
"""
Subscribe to given AttributePath(s).
Either supply a single attribute path or a list of paths.
The given attribute path(s) will be added to the list of attributes that
are watched for the given node. This is persistent over restarts.
"""
await self.send_command(
APICommand.SUBSCRIBE_ATTRIBUTE,
node_id=node_id,
attribute_path=attribute_path,
)

async def send_command(
self,
command: str,
Expand Down Expand Up @@ -366,7 +386,7 @@ def _handle_incoming_message(self, msg: MessageType) -> None:

# handle EventMessage
if isinstance(msg, EventMessage):
self.logger.debug("Received event: %s", msg)
self.logger.debug("Received event: %s", msg.event)
self._handle_event_message(msg)
return

Expand All @@ -392,10 +412,20 @@ def _handle_event_message(self, msg: EventMessage) -> None:
node.update(node_data)
self._signal_event(event, data=node, node_id=node.node_id)
return
if msg.event == EventType.NODE_DELETED:
if msg.event == EventType.NODE_REMOVED:
node_id = msg.data
self._signal_event(EventType.NODE_REMOVED, data=node_id, node_id=node_id)
# cleanup node only after signalling subscribers
self._nodes.pop(node_id, None)
self._signal_event(EventType.NODE_DELETED, data=node_id, node_id=node_id)
return
if msg.event == EventType.ENDPOINT_REMOVED:
node_id = msg.data["node_id"]
self._signal_event(
EventType.ENDPOINT_REMOVED, data=msg.data, node_id=node_id
)
# cleanup endpoint only after signalling subscribers
if node := self._nodes.get(node_id):
node.endpoints.pop(msg.data["endpoint_id"], None)
return
if msg.event == EventType.ATTRIBUTE_UPDATED:
# data is tuple[node_id, attribute_path, new_value]
Expand Down
12 changes: 10 additions & 2 deletions matter_server/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import logging
import os
import pprint
from typing import Any, Callable, Dict, Final, cast

Expand Down Expand Up @@ -33,6 +34,7 @@
from .models.node import MatterNode

LOGGER = logging.getLogger(f"{__package__}.connection")
VERBOSE_LOGGER = os.environ.get("MATTER_VERBOSE_LOGGING")
SUB_WILDCARD: Final = "*"


Expand Down Expand Up @@ -129,7 +131,10 @@ async def receive_message_or_raise(self) -> MessageType:
raise InvalidMessage("Received invalid JSON.") from err

if LOGGER.isEnabledFor(logging.DEBUG):
LOGGER.debug("Received message:\n%s\n", pprint.pformat(ws_msg))
if VERBOSE_LOGGER:
LOGGER.debug("Received message:\n%s\n", pprint.pformat(ws_msg))
else:
LOGGER.debug("Received message: %s ...", ws_msg.data[:50])

return msg

Expand All @@ -143,7 +148,10 @@ async def send_message(self, message: CommandMessage) -> None:
raise NotConnected

if LOGGER.isEnabledFor(logging.DEBUG):
LOGGER.debug("Publishing message:\n%s\n", pprint.pformat(message))
if VERBOSE_LOGGER:
LOGGER.debug("Publishing message:\n%s\n", pprint.pformat(message))
else:
LOGGER.debug("Publishing message: %s", message)

assert self._ws_client
assert isinstance(message, CommandMessage)
Expand Down
10 changes: 4 additions & 6 deletions matter_server/client/models/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ class MatterNode:
def __init__(self, node_data: MatterNodeData) -> None:
"""Initialize MatterNode from MatterNodeData."""
self.endpoints: dict[int, MatterEndpoint] = {}
self._is_bridge_device: bool = False
# composed devices reference to other endpoints through the partsList attribute
# create a mapping table
self._composed_endpoints: dict[int, int] = {}
Expand Down Expand Up @@ -251,7 +250,7 @@ def device_info(self) -> Clusters.BasicInformation:
@property
def is_bridge_device(self) -> bool:
"""Return if this Node is a Bridge/Aggregator device."""
return self._is_bridge_device
return self.node_data.is_bridge

def get_attribute_value(
self,
Expand Down Expand Up @@ -310,10 +309,6 @@ def update(self, node_data: MatterNodeData) -> None:
self.endpoints[endpoint_id] = MatterEndpoint(
endpoint_id=endpoint_id, attributes_data=attributes_data, node=self
)
# lookup if this is a bridge device
self._is_bridge_device = any(
Aggregator in x.device_types for x in self.endpoints.values()
)
# composed devices reference to other endpoints through the partsList attribute
# create a mapping table to quickly map this
for endpoint in self.endpoints.values():
Expand All @@ -339,6 +334,9 @@ def update(self, node_data: MatterNodeData) -> None:
def update_attribute(self, attribute_path: str, new_value: Any) -> None:
"""Handle Attribute value update."""
endpoint_id = int(attribute_path.split("/")[0])
if endpoint_id not in self.endpoints:
# race condition when a bridge is in the process of adding a new endpoint
return
self.endpoints[endpoint_id].set_attribute_value(attribute_path, new_value)

def __repr__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion matter_server/common/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

# schema version is used to determine compatibility between server and client
# bump schema if we add new features and/or make other (breaking) changes
SCHEMA_VERSION = 3
SCHEMA_VERSION = 4
2 changes: 1 addition & 1 deletion matter_server/common/helpers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def parse_arguments(
if strict:
for key, value in args.items():
if key not in func_sig.parameters:
raise KeyError("Invalid parameter: '%s'" % key)
raise KeyError(f"Invalid parameter: '{key}'")
# parse arguments to correct type
for name, param in func_sig.parameters.items():
value = args.get(name)
Expand Down
5 changes: 1 addition & 4 deletions matter_server/common/helpers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dataclasses import is_dataclass
from typing import Any

from chip.clusters.Attribute import ValueDecodeFailure
from chip.clusters.Types import Nullable
from chip.tlv import float32, uint
import orjson
Expand All @@ -22,8 +21,6 @@ def json_encoder_default(obj: Any) -> Any:
"""
if getattr(obj, "do_not_serialize", None):
return None
if isinstance(obj, ValueDecodeFailure):
return None
if isinstance(obj, (set, tuple)):
return list(obj)
if isinstance(obj, float32):
Expand All @@ -40,7 +37,7 @@ def json_encoder_default(obj: Any) -> Any:
return b64encode(obj).decode("utf-8")
if isinstance(obj, Exception):
return str(obj)
if type(obj) == type:
if type(obj) == type: # pylint: disable=unidiomatic-typecheck
return f"{obj.__module__}.{obj.__qualname__}"

raise TypeError
Expand Down
17 changes: 9 additions & 8 deletions matter_server/common/helpers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def parse_utc_timestamp(datetime_string: str) -> datetime:

def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING) -> Any:
"""Try to parse a value from raw (json) data and type annotations."""
# pylint: disable=too-many-return-statements,too-many-branches

if isinstance(value_type, str):
# this shouldn't happen, but just in case
Expand All @@ -105,14 +106,14 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING)
return dataclass_from_dict(value_type, value)
# get origin value type and inspect one-by-one
origin: Any = get_origin(value_type)
if origin in (list, tuple) and isinstance(value, list | tuple):
if origin in (list, tuple, set) and isinstance(value, (list, tuple, set)):
return origin(
parse_value(name, subvalue, get_args(value_type)[0])
for subvalue in value
if subvalue is not None
)
# handle dictionary where we should inspect all values
elif origin is dict:
if origin is dict:
subkey_type = get_args(value_type)[0]
subvalue_type = get_args(value_type)[1]
return {
Expand All @@ -122,7 +123,7 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING)
for subkey, subvalue in value.items()
}
# handle Union type
elif origin is Union or origin is UnionType:
if origin is Union or origin is UnionType:
# try all possible types
sub_value_types = get_args(value_type)
for sub_arg_type in sub_value_types:
Expand All @@ -143,9 +144,9 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING)
# raise exception, we have no idea how to handle this value
raise TypeError(err)
# failed to parse the (sub) value but None allowed, log only
logging.getLogger(__name__).warn(err)
logging.getLogger(__name__).warning(err)
return None
elif origin is type:
if origin is type:
return get_type_hints(value, globals(), locals())
# handle Any as value type (which is basically unprocessable)
if value_type is Any:
Expand All @@ -157,6 +158,7 @@ def parse_value(name: str, value: Any, value_type: Any, default: Any = MISSING)
try:
if issubclass(value_type, Enum):
# handle enums from the SDK that have a value that does not exist in the enum (sigh)
# pylint: disable=protected-access
if value not in value_type._value2member_map_:
# we do not want to crash so we return the raw value
return value
Expand Down Expand Up @@ -208,11 +210,10 @@ def dataclass_from_dict(cls: type[_T], dict_obj: dict, strict: bool = False) ->
If strict mode enabled, any additional keys in the provided dict will result in a KeyError.
"""
if strict:
extra_keys = dict_obj.keys() - set([f.name for f in fields(cls)])
extra_keys = dict_obj.keys() - {f.name for f in fields(cls)}
if extra_keys:
raise KeyError(
"Extra key(s) %s not allowed for %s"
% (",".join(extra_keys), (str(cls)))
f'Extra key(s) {",".join(extra_keys)} not allowed for {str(cls)}'
)
type_hints = get_type_hints(cls)
return cls(
Expand Down
12 changes: 11 additions & 1 deletion matter_server/common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ class EventType(Enum):

NODE_ADDED = "node_added"
NODE_UPDATED = "node_updated"
NODE_DELETED = "node_deleted"
NODE_REMOVED = "node_removed"
NODE_EVENT = "node_event"
ATTRIBUTE_UPDATED = "attribute_updated"
SERVER_SHUTDOWN = "server_shutdown"
ENDPOINT_ADDED = "endpoint_added"
ENDPOINT_REMOVED = "endpoint_removed"


class APICommand(str, Enum):
Expand All @@ -38,6 +40,7 @@ class APICommand(str, Enum):
DEVICE_COMMAND = "device_command"
REMOVE_NODE = "remove_node"
GET_VENDOR_NAMES = "get_vendor_names"
SUBSCRIBE_ATTRIBUTE = "subscribe_attribute"


EventCallBackType = Callable[[EventType, Any], None]
Expand Down Expand Up @@ -66,8 +69,15 @@ class MatterNodeData:
last_interview: datetime
interview_version: int
available: bool = False
is_bridge: bool = False
# attributes are stored in form of AttributePath: ENDPOINT/CLUSTER_ID/ATTRIBUTE_ID
attributes: dict[str, Any] = field(default_factory=dict)
# all attribute subscriptions we need to persist for this node,
# a set of tuples in format (endpoint_id, cluster_id, attribute_id)
# where each value can also be a '*' for wildcard
attribute_subscriptions: set[tuple[int | str, int | str, int | str]] = field(
default_factory=set
)


@dataclass
Expand Down
2 changes: 2 additions & 0 deletions matter_server/server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def main() -> None:
handlers = [logging.FileHandler(args.log_file)]
logging.basicConfig(handlers=handlers, level=args.log_level.upper())
coloredlogs.install(level=args.log_level.upper())
if not logging.getLogger().isEnabledFor(logging.DEBUG):
logging.getLogger("chip").setLevel(logging.WARNING)

# make sure storage path exists
if not os.path.isdir(args.storage_path):
Expand Down
Loading

0 comments on commit beb3d64

Please sign in to comment.