Improvements and bugfixes for persistent storage of node data (#514)
marcelveldt authored Jan 29, 2024
1 parent bfad6ee commit 65a4d60
Showing 2 changed files with 63 additions and 40 deletions.
77 changes: 39 additions & 38 deletions matter_server/server/device_controller.py
@@ -37,6 +37,7 @@
from ..common.helpers.util import (
create_attribute_path_from_attribute,
dataclass_from_dict,
dataclass_to_dict,
parse_attribute_path,
parse_value,
)
@@ -101,7 +102,7 @@ def __init__(
self._attr_subscriptions: dict[int, list[Attribute.AttributePath]] = {}
self._resub_debounce_timer: dict[int, asyncio.TimerHandle] = {}
self._sub_retry_timer: dict[int, asyncio.TimerHandle] = {}
self._nodes: dict[int, MatterNodeData | None] = {}
self._nodes: dict[int, MatterNodeData] = {}
self._last_subscription_attempt: dict[int, int] = {}
self.wifi_credentials_set: bool = False
self.thread_credentials_set: bool = False
@@ -126,20 +127,22 @@ async def start(self) -> None:
"""Handle logic on controller start."""
# load nodes from persistent storage
nodes: dict[str, dict | None] = self.server.storage.get(DATA_KEY_NODES, {})
orphaned_nodes: set[str] = set()
for node_id_str, node_dict in nodes.items():
node_id = int(node_id_str)
if node_dict is None:
# ignore non-initialized (left-over) nodes
# from failed commissioning attempts
# Non-initialized (left-over) node from a failed commissioning attempt.
# NOTE: This code can be removed in a future version
# as this can no longer happen.
orphaned_nodes.add(node_id_str)
continue
if node_dict.get("interview_version") != SCHEMA_VERSION:
# invalidate node data if schema mismatch,
# the node will automatically be scheduled for re-interview
node = None
else:
node = dataclass_from_dict(MatterNodeData, node_dict)
# always mark node as unavailable at startup until subscriptions are ready
node.available = False
# Invalidate node attributes data if schema mismatch,
# the node will automatically be scheduled for re-interview.
node_dict["attributes"] = {}
node = dataclass_from_dict(MatterNodeData, node_dict)
# always mark node as unavailable at startup until subscriptions are ready
node.available = False
self._nodes[node_id] = node
# setup subscription and (re)interview as task in the background
# as we do not want it to block our startup
@@ -148,9 +151,12 @@ async def start(self) -> None:
# the first attempt to initialize so that we prioritize nodes
# that are probably available so they are back online as soon as
# possible and we're not stuck trying to initialize nodes that are offline
self._schedule_interview(node_id, 5)
self._schedule_interview(node_id, 30)
else:
asyncio.create_task(self._check_interview_and_subscription(node_id))
# cleanup orphaned nodes from storage
for node_id_str in orphaned_nodes:
self.server.storage.remove(DATA_KEY_NODES, node_id_str)
LOGGER.info("Loaded %s nodes from stored configuration", len(self._nodes))

async def stop(self) -> None:
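
For reference, the entries loaded here from DATA_KEY_NODES are keyed by the stringified node id. The exact schema is defined by MatterNodeData, but based on the fields touched in this diff (interview_version, available, attributes, attribute_subscriptions) a stored entry looks roughly like the hypothetical sketch below; a value of None marks a left-over entry from a failed commissioning attempt, which is now dropped at startup.

# Hypothetical illustration only; field names beyond those used in this diff may differ.
stored_nodes_example = {
    "5": {
        "node_id": 5,
        "interview_version": 5,         # compared against SCHEMA_VERSION on startup
        "available": False,             # always reset to False until subscriptions are ready
        "attributes": {                 # cached attribute values, keyed by attribute path
            "0/40/1": "Example Vendor",
        },
        "attribute_subscriptions": [],  # extra subscriptions added via subscribe_attribute
    },
    "6": None,  # orphaned entry from a failed commissioning attempt (cleaned up above)
}
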
@@ -461,12 +467,7 @@ async def interview_node(self, node_id: int) -> None:

# save updated node data
self._nodes[node_id] = node
self.server.storage.set(
DATA_KEY_NODES,
subkey=str(node_id),
value=node,
force=not existing_info,
)
self._write_node_state(node_id, True)
if is_new_node:
# new node - first interview
self.server.signal_event(EventType.NODE_ADDED, node)
@@ -536,9 +537,8 @@ async def read_attribute(
result: Attribute.AsyncReadTransaction.ReadResponse = await future
read_atributes = parse_attributes_from_read_result(result.tlvAttributes)
# update cached info in node attributes
self._nodes[node_id].attributes.update( # type: ignore[union-attr]
read_atributes
)
self._nodes[node_id].attributes.update(read_atributes)
self._write_node_state(node_id)
if len(read_atributes) > 1:
return read_atributes
return read_atributes.get(attribute_path, None)
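
A short sketch of the assumed behavior of this read path: parse_attributes_from_read_result presumably yields a dict keyed by attribute path, which is merged into the node's cached attributes and then persisted via the new _write_node_state helper; a wildcard read returns the whole dict while a single-attribute read returns just that path's value. The path strings and values below are illustrative only.

# Hedged sketch of the merge-and-return logic above (illustrative paths/values).
cached = {"1/6/0": False}                          # node.attributes before the read
read_atributes = {"1/6/0": True, "1/6/16385": 0}   # parsed from result.tlvAttributes
cached.update(read_atributes)                      # cache refreshed, then scheduled for storage

if len(read_atributes) > 1:
    result = read_atributes                        # wildcard read: return everything read
else:
    result = read_atributes.get("1/6/0", None)     # single attribute: return just its value
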
@@ -587,7 +587,6 @@ async def remove_node(self, node_id: int) -> None:
DATA_KEY_NODES,
subkey=str(node_id),
)
self.server.storage.save(immediate=True)
LOGGER.info("Node ID %s successfully removed from Matter server.", node_id)

self.server.signal_event(EventType.NODE_REMOVED, node_id)
@@ -654,11 +653,7 @@ async def subscribe_attribute(
if prev_subs == node.attribute_subscriptions:
return # nothing to do
# save updated node data
self.server.storage.set(
DATA_KEY_NODES,
subkey=str(node_id),
value=node,
)
self._write_node_state(node_id)

# (re)setup node subscription
# this could potentially be called multiple times within a short timeframe
@@ -755,7 +750,7 @@ async def _subscribe_node(self, node_id: int) -> None:

node_logger = LOGGER.getChild(f"[node {node_id}]")
node_lock = self._get_node_lock(node_id)
node = cast(MatterNodeData, self._nodes[node_id])
node = self._nodes[node_id]

# work out all (current) attribute subscriptions
attr_subscriptions: list[Attribute.AttributePath] = list(
@@ -825,6 +820,10 @@ def attribute_updated_callback(
attr_path = str(path.Path)
old_value = node.attributes.get(attr_path)

# return early if the value did not actually change at all
if old_value == new_value:
return

node_logger.debug(
"Attribute updated: %s - old value: %s - new value: %s",
path,
@@ -862,11 +861,7 @@ def attribute_updated_callback(
node.attributes[attr_path] = new_value

# schedule save to persistent storage
self.server.storage.set(
DATA_KEY_NODES,
subkey=str(node_id),
value=node,
)
self._write_node_state(node_id)

# This callback is running in the CHIP stack thread
loop.call_soon_threadsafe(
@@ -1108,7 +1103,7 @@ async def _resolve_node(

def _handle_endpoints_removed(self, node_id: int, endpoints: Iterable[int]) -> None:
"""Handle callback for when bridge endpoint(s) get deleted."""
node = cast(MatterNodeData, self._nodes[node_id])
node = self._nodes[node_id]
for endpoint_id in endpoints:
node.attributes = {
key: value
@@ -1120,11 +1115,7 @@ def _handle_endpoints_removed(self, node_id: int, endpoints: Iterable[int]) -> None:
{"node_id": node_id, "endpoint_id": endpoint_id},
)
# schedule save to persistent storage
self.server.storage.set(
DATA_KEY_NODES,
subkey=str(node_id),
value=node,
)
self._write_node_state(node_id)

async def _handle_endpoints_added(
self, node_id: int, endpoints: Iterable[int]
@@ -1144,3 +1135,13 @@ def _get_node_lock(self, node_id: int) -> asyncio.Lock:
if node_id not in self._node_lock:
self._node_lock[node_id] = asyncio.Lock()
return self._node_lock[node_id]

def _write_node_state(self, node_id: int, force: bool = False) -> None:
"""Schedule the write of the current node state to persistent storage."""
node = self._nodes[node_id]
self.server.storage.set(
DATA_KEY_NODES,
value=dataclass_to_dict(node),
subkey=str(node_id),
force=force,
)
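
The removed storage.save(immediate=True) and force=not existing_info calls above suggest the server's storage controller debounces disk writes and only forces an immediate write-through in special cases (such as the first save of a newly commissioned node). The actual implementation lives in the storage module; the class below is only a minimal sketch of that pattern, with a hypothetical DebouncedStorage name and API that is not the project's real interface.

import asyncio
from typing import Any


class DebouncedStorage:
    """Hypothetical sketch of a key/value store that batches disk writes."""

    def __init__(self, save_delay: float = 120) -> None:
        self._data: dict[str, Any] = {}
        self._save_handle: asyncio.TimerHandle | None = None
        self._save_delay = save_delay

    def get(self, key: str, default: Any = None) -> Any:
        return self._data.get(key, default)

    def set(self, key: str, value: Any, subkey: str | None = None, force: bool = False) -> None:
        if subkey is not None:
            self._data.setdefault(key, {})[subkey] = value
        else:
            self._data[key] = value
        # force=True writes through right away (e.g. first save of a new node),
        # otherwise the save is debounced to avoid hammering the disk.
        self.save(immediate=force)

    def remove(self, key: str, subkey: str | None = None) -> None:
        if subkey is not None:
            self._data.get(key, {}).pop(subkey, None)
        else:
            self._data.pop(key, None)
        self.save(immediate=True)

    def save(self, immediate: bool = False) -> None:
        # must be called from within the running event loop
        loop = asyncio.get_running_loop()
        if self._save_handle and not self._save_handle.cancelled():
            self._save_handle.cancel()
        delay = 0 if immediate else self._save_delay
        self._save_handle = loop.call_later(delay, self._write_to_disk)

    def _write_to_disk(self) -> None:
        """Serialize self._data to disk (e.g. as JSON)."""

With such an interface, _write_node_state serializes the node dataclass once with dataclass_to_dict and funnels every node write through a single debounced path, instead of repeating the storage.set call at every call site as the removed code did.
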
26 changes: 24 additions & 2 deletions scripts/example.py
@@ -43,6 +43,12 @@
default="info",
help="Provide logging level. Example --log-level debug, default=info, possible=(critical, error, warning, info, debug)",
)
parser.add_argument(
"--primary-interface",
type=str,
default=None,
help="Primary network interface for link-local addresses (optional).",
)

args = parser.parse_args()

@@ -58,7 +64,11 @@

# Init server
server = MatterServer(
args.storage_path, DEFAULT_VENDOR_ID, DEFAULT_FABRIC_ID, int(args.port)
args.storage_path,
DEFAULT_VENDOR_ID,
DEFAULT_FABRIC_ID,
int(args.port),
args.primary_interface,
)

async def run_matter():
@@ -71,7 +81,19 @@ async def run_matter():
async with aiohttp.ClientSession() as session:
async with MatterClient(url, session) as client:
# start listening
await client.start_listening()
asyncio.create_task(client.start_listening())
# allow the client to initialize
await asyncio.sleep(10)
# dump full node info on random (available) node
for node in client.get_nodes():
if not node.available:
continue
print()
print(node)
res = await client.node_diagnostics(node.node_id)
print(res)
print()
break

async def handle_stop(loop: asyncio.AbstractEventLoop):
"""Handle server stop."""
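
The fixed asyncio.sleep(10) above is a simple way to give the background listen task time to finish its initial sync before node info is dumped. If a less arbitrary wait is wanted, one hedged alternative is to poll the client until nodes show up, using only the calls already shown in this example; whether get_nodes() returns an empty collection or raises before the sync completes depends on the client, so the sketch guards for both.

# Sketch: poll for nodes instead of a fixed sleep (fragment for run_matter()).
asyncio.create_task(client.start_listening())
nodes = []
for _ in range(30):  # wait up to ~30 seconds for the initial node sync
    try:
        nodes = [node for node in client.get_nodes() if node.available]
    except Exception:  # client may not be ready to serve nodes yet
        nodes = []
    if nodes:
        break
    await asyncio.sleep(1)
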
