Skip to content

Commit

Permalink
Retry subscription setup if necessary (#873)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcel van der Veldt <[email protected]>
  • Loading branch information
agners and marcelveldt authored Sep 4, 2024
1 parent d3cb494 commit 68754b0
Showing 1 changed file with 65 additions and 50 deletions.
115 changes: 65 additions & 50 deletions matter_server/server/device_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@
NODE_RESUBSCRIBE_FORCE_TIMEOUT = 5
NODE_PING_TIMEOUT = 10
NODE_PING_TIMEOUT_BATTERY_POWERED = 60
NODE_MDNS_BACKOFF = 610 # must be higher than (highest) sub ceiling
FALLBACK_NODE_SCANNER_INTERVAL = 1800
NODE_MDNS_SUBSCRIPTION_RETRY_TIMEOUT = 30 * 60
FALLBACK_NODE_SCANNER_INTERVAL = 30 * 60
CUSTOM_ATTRIBUTES_POLLER_INTERVAL = 30

MDNS_TYPE_OPERATIONAL_NODE = "_matter._tcp.local."
Expand Down Expand Up @@ -144,9 +144,9 @@ def __init__(
self._fabric_id_hex: str | None = None
self._wifi_credentials_set: bool = False
self._thread_credentials_set: bool = False
self._nodes_in_setup: set[int] = set()
self._setup_node_tasks = dict[int, asyncio.Task]()
self._nodes_in_ota: set[int] = set()
self._node_last_seen: dict[int, float] = {}
self._node_last_seen_on_mdns: dict[int, float] = {}
self._nodes: dict[int, MatterNodeData] = {}
self._last_known_ip_addresses: dict[int, list[str]] = {}
self._last_subscription_attempt: dict[int, int] = {}
Expand Down Expand Up @@ -230,6 +230,9 @@ async def stop(self) -> None:
scan_task.cancel()
if self._aiozc:
await self._aiozc.async_close()
# Ensure any in-progress setup tasks are cancelled
for task in self._setup_node_tasks.values():
task.cancel()

# shutdown the sdk device controller
await self._chip_device_controller.shutdown()
Expand Down Expand Up @@ -340,7 +343,8 @@ async def commission_with_code(
break

# make sure we start a subscription for this newly added node
await self._setup_node(node_id)
if task := self._setup_node_create_task(node_id):
await task
LOGGER.info("Commissioning of Node ID %s completed.", node_id)
# return full node object once we're complete
return self.get_node(node_id)
Expand Down Expand Up @@ -418,7 +422,8 @@ async def commission_on_network(
else:
break
# make sure we start a subscription for this newly added node
await self._setup_node(node_id)
if task := self._setup_node_create_task(node_id):
await task
LOGGER.info("Commissioning of Node ID %s completed.", node_id)
# return full node object once we're complete
return self.get_node(node_id)
Expand Down Expand Up @@ -765,6 +770,9 @@ async def remove_node(self, node_id: int) -> None:

LOGGER.info("Removing Node ID %s.", node_id)

if task := self._setup_node_tasks.pop(node_id, None):
task.cancel()

# shutdown any existing subscriptions
await self._chip_device_controller.shutdown_subscription(node_id)
self._polled_attributes.pop(node_id, None)
Expand Down Expand Up @@ -1114,7 +1122,6 @@ def attribute_updated_callback_threadsafe(
path: Attribute.AttributePath,
transaction: Attribute.SubscriptionTransaction,
) -> None:
self._node_last_seen[node_id] = time.time()
new_value = transaction.GetTLVAttribute(path)
# failsafe: ignore ValueDecodeErrors
# these are set by the SDK if parsing the value failed miserably
Expand Down Expand Up @@ -1142,7 +1149,6 @@ def event_callback(
data,
transaction,
)
self._node_last_seen[node_id] = time.time()
node_event = MatterNodeEvent(
node_id=node_id,
endpoint_id=data.Header.EndpointId,
Expand Down Expand Up @@ -1203,7 +1209,6 @@ def resubscription_succeeded(
transaction: Attribute.SubscriptionTransaction,
) -> None:
# pylint: disable=unused-argument, invalid-name
self._node_last_seen[node_id] = time.time()
node_logger.info("Re-Subscription succeeded")
self._last_subscription_attempt[node_id] = 0
# mark node as available and signal consumers
Expand Down Expand Up @@ -1263,7 +1268,6 @@ def resubscription_succeeded(
report_interval_ceiling,
)

self._node_last_seen[node_id] = time.time()
self.server.signal_event(EventType.NODE_UPDATED, node)

def _get_next_node_id(self) -> int:
Expand All @@ -1272,16 +1276,12 @@ def _get_next_node_id(self) -> int:
self.server.storage.set(DATA_KEY_LAST_NODE_ID, next_node_id, force=True)
return next_node_id

async def _setup_node(self, node_id: int) -> None:
async def _setup_node_try_once(
self,
node_logger: logging.LoggerAdapter,
node_id: int,
) -> None:
"""Handle set-up of subscriptions and interview (if needed) for known/discovered node."""
if node_id not in self._nodes:
raise NodeNotExists(f"Node {node_id} does not exist.")
if node_id in self._nodes_in_setup:
# prevent duplicate setup actions
return
self._nodes_in_setup.add(node_id)

node_logger = self.get_node_logger(LOGGER, node_id)
node_data = self._nodes[node_id]
log_timers: dict[int, asyncio.TimerHandle] = {}
is_thread_node = (
Expand Down Expand Up @@ -1332,6 +1332,7 @@ async def log_node_long_setup(time_start: float) -> None:
log_timers[node_id] = self._loop.call_later(
15 * 60, lambda: asyncio.create_task(log_node_long_setup(time_start))
)

try:
node_logger.info("Setting-up node...")

Expand All @@ -1347,9 +1348,7 @@ async def log_node_long_setup(time_start: float) -> None:
# log full stack trace if verbose logging is enabled
exc_info=err if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL) else None,
)
# NOTE: the node will be picked up by mdns discovery automatically
# when it comes available again.
return
raise err

# (re)interview node (only) if needed
if (
Expand All @@ -1369,9 +1368,7 @@ async def log_node_long_setup(time_start: float) -> None:
if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL)
else None,
)
# NOTE: the node will be picked up by mdns discovery automatically
# when it comes available again.
return
raise err

# setup subscriptions for the node
try:
Expand All @@ -1383,21 +1380,53 @@ async def log_node_long_setup(time_start: float) -> None:
# log full stack trace if verbose logging is enabled
exc_info=err if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL) else None,
)
# NOTE: the node will be picked up by mdns discovery automatically
# when it becomes available again.
return
raise err

# check if this node has any custom clusters that need to be polled
if polled_attributes := check_polled_attributes(node_data):
self._polled_attributes[node_id] = polled_attributes
self._schedule_custom_attributes_poller()

finally:
log_timers[node_id].cancel()
self._nodes_in_setup.discard(node_id)
if is_thread_node:
self._thread_node_setup_throttle.release()

async def _setup_node(self, node_id: int) -> None:
if node_id not in self._nodes:
raise NodeNotExists(f"Node {node_id} does not exist.")

node_logger = self.get_node_logger(LOGGER, node_id)

while True:
try:
await self._setup_node_try_once(node_logger, node_id)
break
except (NodeNotResolving, NodeInterviewFailed, ChipStackError):
if (
time.time() - self._node_last_seen_on_mdns.get(node_id, 0)
> NODE_MDNS_SUBSCRIPTION_RETRY_TIMEOUT
):
# NOTE: assume the node will be picked up by mdns discovery later
# automatically when it becomes available again.
node_logger.warning(
"Node setup not completed after %s minutes, giving up.",
NODE_MDNS_SUBSCRIPTION_RETRY_TIMEOUT // 60,
)
break

node_logger.info("Retrying node setup in 60 seconds...")
await asyncio.sleep(60)

def _setup_node_create_task(self, node_id: int) -> asyncio.Task | None:
"""Create a task for setting up a node with retry."""
if node_id in self._setup_node_tasks:
node_logger = self.get_node_logger(LOGGER, node_id)
node_logger.debug("Setup task exists already for this Node")
return None
task = asyncio.create_task(self._setup_node(node_id))
self._setup_node_tasks[node_id] = task
return task

def _handle_endpoints_removed(self, node_id: int, endpoints: Iterable[int]) -> None:
"""Handle callback for when bridge endpoint(s) get deleted."""
node = self._nodes[node_id]
Expand Down Expand Up @@ -1479,24 +1508,17 @@ def _on_mdns_operational_node_state(
if not (node := self._nodes.get(node_id)):
return # this should not happen, but guard just in case

now = time.time()
last_seen = self._node_last_seen.get(node_id, 0)
self._node_last_seen[node_id] = now
self._node_last_seen_on_mdns[node_id] = time.time()

# we only treat UPDATE state changes as ADD if the node is marked as
# unavailable to ensure we catch a node being operational
if node.available and state_change == ServiceStateChange.Updated:
return

if node_id in self._nodes_in_setup:
# prevent duplicate setup actions
return

if not self._chip_device_controller.node_has_subscription(node_id):
node_logger.info("Discovered on mDNS")
elif (now - last_seen) > NODE_MDNS_BACKOFF:
# node came back online after being offline for a while or restarted
node_logger.info("Re-discovered on mDNS")
# Setup the node - this will setup the subscriptions etc.
self._setup_node_create_task(node_id)
elif state_change == ServiceStateChange.Added:
# Trigger node re-subscriptions when mDNS entry got added
node_logger.info("Activity on mDNS, trigger resubscribe")
Expand All @@ -1505,13 +1527,6 @@ def _on_mdns_operational_node_state(
node_id, "mDNS state change detected"
)
)
return
else:
# ignore all other cases
return

# setup the node - this will (re) setup the subscriptions etc.
asyncio.create_task(self._setup_node(node_id))

def _on_mdns_commissionable_node_state(
self, name: str, state_change: ServiceStateChange
Expand Down Expand Up @@ -1609,13 +1624,13 @@ async def _fallback_node_scanner(self) -> None:
if node.available:
continue
now = time.time()
last_seen = self._node_last_seen.get(node_id, 0)
last_seen = self._node_last_seen_on_mdns.get(node_id, 0)
if now - last_seen < FALLBACK_NODE_SCANNER_INTERVAL:
continue
if await self.ping_node(node_id, attempts=3):
LOGGER.info("Node %s discovered using fallback ping", node_id)
self._node_last_seen[node_id] = now
await self._setup_node(node_id)
if task := self._setup_node_create_task(node_id):
await task

# reschedule self to run at next interval
self._schedule_fallback_scanner()
Expand Down

0 comments on commit 68754b0

Please sign in to comment.