Skip to content

Commit

Permalink
Refactor locking
Browse files Browse the repository at this point in the history
  • Loading branch information
marcelveldt committed Mar 1, 2024
1 parent f8a9666 commit 8214e08
Showing 1 changed file with 72 additions and 81 deletions.
153 changes: 72 additions & 81 deletions matter_server/server/device_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import logging
from random import randint
import time
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, cast

from chip.ChipDeviceCtrl import DeviceProxyWrapper
from chip.clusters import Attribute, Objects as Clusters
Expand Down Expand Up @@ -502,15 +502,13 @@ async def interview_node(self, node_id: int) -> None:
await self._resolve_node(node_id=node_id)
LOGGER.info("Interviewing node: %s", node_id)
read_response: Attribute.AsyncReadTransaction.ReadResponse = (
await self._call_sdk(
self.chip_controller.Read(
nodeid=node_id,
attributes="*",
fabricFiltered=False,
)
await self.chip_controller.Read(
nodeid=node_id,
attributes="*",
fabricFiltered=False,
)
)
except (ChipStackError, TimeoutError) as err:
except ChipStackError as err:
raise NodeInterviewFailed(f"Failed to interview node {node_id}") from err

is_new_node = node_id not in self._nodes
Expand Down Expand Up @@ -565,15 +563,13 @@ async def send_device_command(
cluster_cls: Cluster = ALL_CLUSTERS[cluster_id]
command_cls = getattr(cluster_cls.Commands, command_name)
command = dataclass_from_dict(command_cls, payload, allow_sdk_types=True)
return await self._call_sdk(
self.chip_controller.SendCommand(
nodeid=node_id,
endpoint=endpoint_id,
payload=command,
responseType=response_type,
timedRequestTimeoutMs=timed_request_timeout_ms,
interactionTimeoutMs=interaction_timeout_ms,
)
return await self.chip_controller.SendCommand(
nodeid=node_id,
endpoint=endpoint_id,
payload=command,
responseType=response_type,
timedRequestTimeoutMs=timed_request_timeout_ms,
interactionTimeoutMs=interaction_timeout_ms,
)

@api_command(APICommand.READ_ATTRIBUTE)
Expand All @@ -586,38 +582,32 @@ async def read_attribute(
if (node := self._nodes.get(node_id)) is None or not node.available:
raise NodeNotReady(f"Node {node_id} is not (yet) available.")
endpoint_id, cluster_id, attribute_id = parse_attribute_path(attribute_path)
device = await self._resolve_node(node_id)
# Read a list of attributes and/or events from a target node.
# This is basically a re-implementation of the chip controller's Read function
# but one that allows us to send/request custom attributes.

async def _do_read() -> Attribute.AsyncReadTransaction.ReadResponse:
"""
Read a list of attributes and/or events from a target node.
This is basically a re-implementation of the chip controller's Read function
but one that allows us to send/request custom attributes.
"""
if TYPE_CHECKING:
assert self.server.loop
assert self.chip_controller
loop = self.server.loop
future = loop.create_future()
Attribute.Read(
future=future,
eventLoop=loop,
device=self.chip_controller.GetConnectedDeviceSync(node_id),
devCtrl=self.chip_controller,
attributes=[
Attribute.AttributePath(
EndpointId=endpoint_id,
ClusterId=cluster_id,
AttributeId=attribute_id,
)
],
fabricFiltered=fabric_filtered,
).raise_on_error()
return await future
if TYPE_CHECKING:
assert self.server.loop
assert self.chip_controller
loop = self.server.loop
future = loop.create_future()
Attribute.Read(
future=future,
eventLoop=loop,
device=device,
devCtrl=self.chip_controller,
attributes=[
Attribute.AttributePath(
EndpointId=endpoint_id,
ClusterId=cluster_id,
AttributeId=attribute_id,
)
],
fabricFiltered=fabric_filtered,
).raise_on_error()

result: Attribute.AsyncReadTransaction.ReadResponse = await self._call_sdk(
_do_read()
)
result: Attribute.AsyncReadTransaction.ReadResponse = await future
read_atributes = parse_attributes_from_read_result(result.tlvAttributes)
# update cached info in node attributes
self._nodes[node_id].attributes.update(read_atributes)
Expand Down Expand Up @@ -647,11 +637,9 @@ async def write_attribute(
value_type=attribute.attribute_type.Type,
allow_sdk_types=True,
)
return await self._call_sdk(
self.chip_controller.WriteAttribute(
nodeid=node_id,
attributes=[(endpoint_id, attribute)],
)
return await self.chip_controller.WriteAttribute(
nodeid=node_id,
attributes=[(endpoint_id, attribute)],
)

@api_command(APICommand.REMOVE_NODE)
Expand Down Expand Up @@ -737,7 +725,9 @@ async def subscribe_attribute(
)

@api_command(APICommand.PING_NODE)
async def ping_node(self, node_id: int, attempts: int = 1) -> NodePingResult:
async def ping_node(
self, node_id: int, attempts: int = 1, allow_cached_ips: bool = True
) -> NodePingResult:
"""Ping node on the currently known IP-adress(es)."""
result: NodePingResult = {}
node = self._nodes.get(node_id)
Expand Down Expand Up @@ -771,7 +761,7 @@ async def _do_ping(ip_address: str) -> None:
result[clean_ip] = await ping_ip(ip_address, timeout, attempts=attempts)

ip_addresses = await self.get_node_ip_addresses(
node_id, prefer_cache=False, scoped=True
node_id, prefer_cache=False, scoped=True, allow_cache=allow_cached_ips
)
tasks = [_do_ping(x) for x in ip_addresses]
# TODO: replace this gather with a taskgroup once we bump our py version
Expand All @@ -796,7 +786,11 @@ async def _do_ping(ip_address: str) -> None:

@api_command(APICommand.GET_NODE_IP_ADRESSES)
async def get_node_ip_addresses(
self, node_id: int, prefer_cache: bool = False, scoped: bool = False
self,
node_id: int,
prefer_cache: bool = False,
scoped: bool = False,
allow_cache: bool = True,
) -> list[str]:
"""Return the currently known (scoped) IP-adress(es)."""
cached_info = self._last_known_ip_addresses.get(node_id, [])
Expand All @@ -814,7 +808,7 @@ async def get_node_ip_addresses(
info = AsyncServiceInfo(MDNS_TYPE_OPERATIONAL_NODE, mdns_name)
if TYPE_CHECKING:
assert self._aiozc is not None
if not await info.async_request(self._aiozc.zeroconf, 3000):
if not await info.async_request(self._aiozc.zeroconf, 3000) and allow_cache:
node_logger.info(
"Node could not be discovered on the network, returning cached IP's"
)
Expand Down Expand Up @@ -1013,17 +1007,15 @@ def resubscription_succeeded(
)
)
self._last_subscription_attempt[node_id] = 0
sub: Attribute.SubscriptionTransaction = await self._call_sdk(
self.chip_controller.Read(
node_id,
attributes="*",
events=[("*", 1)],
returnClusterObject=False,
reportInterval=(interval_floor, interval_ceiling),
fabricFiltered=False,
keepSubscriptions=True,
autoResubscribe=True,
)
sub: Attribute.SubscriptionTransaction = await self.chip_controller.Read(
node_id,
attributes="*",
events=[("*", 1)],
returnClusterObject=False,
reportInterval=(interval_floor, interval_ceiling),
fabricFiltered=False,
keepSubscriptions=True,
autoResubscribe=True,
)

sub.SetAttributeUpdateCallback(attribute_updated_callback)
Expand Down Expand Up @@ -1053,24 +1045,21 @@ def _get_next_node_id(self) -> int:

async def _call_sdk(
self,
target: Callable[..., _T] | Awaitable[_T],
target: Callable[..., _T],
*args: Any,
**kwargs: Any,
) -> _T:
"""Call function on the SDK in executor and return result."""
if self.server.loop is None:
raise RuntimeError("Server not started.")

# handle both awaitables and sync functions here
if not asyncio.iscoroutine(target):
target = self.server.loop.run_in_executor(
return cast(
_T,
await self.server.loop.run_in_executor(
self._sdk_executor,
partial(target, *args, **kwargs), # type: ignore[arg-type]
)
# we guard all calls to the sdk with a lock because we have no good way
# of knowing if all code in the python wrapper is thread safe.
async with self._sdk_lock:
return cast(_T, await target)
partial(target, *args, **kwargs),
),
)

async def _setup_node(self, node_id: int) -> None:
"""Handle set-up of subscriptions and interview (if needed) for known/discovered node."""
Expand All @@ -1084,7 +1073,9 @@ async def _setup_node(self, node_id: int) -> None:
# Ping the node to rule out stale mdns reports and to prevent that we
# send an unreachable node to the sdk which is very slow with resolving it.
# This will also precache the ip addresses of the node for later use.
ping_result = await self.ping_node(node_id, attempts=3)
ping_result = await self.ping_node(
node_id, attempts=3, allow_cached_ips=False
)
if not any(ping_result.values()):
LOGGER.warning(
"Skip set-up for node %s because it does not appear to be reachable...",
Expand Down Expand Up @@ -1116,7 +1107,7 @@ async def _setup_node(self, node_id: int) -> None:
# setup subscriptions for the node
try:
await self._subscribe_node(node_id)
except (NodeNotResolving, TimeoutError, ChipStackError) as err:
except (NodeNotResolving, ChipStackError) as err:
LOGGER.warning(
"Unable to subscribe to Node %s: %s",
node_id,
Expand Down Expand Up @@ -1150,7 +1141,7 @@ async def _resolve_node(
allowPASE=False,
timeoutMs=None,
)
except (ChipStackError, TimeoutError) as err:
except ChipStackError as err:
if attempt >= retries:
# when we're out of retries, raise NodeNotResolving
raise NodeNotResolving(f"Unable to resolve Node {node_id}") from err
Expand Down Expand Up @@ -1292,7 +1283,7 @@ async def _fallback_node_scanner(self) -> None:
last_seen = self._node_last_seen.get(node_id, 0)
if now - last_seen < FALLBACK_NODE_SCANNER_INTERVAL:
continue
if await self.ping_node(node_id, attempts=3):
if await self.ping_node(node_id, attempts=3, allow_cached_ips=False):
LOGGER.info("Node %s discovered using fallback ping", node_id)
await self._setup_node(node_id)

Expand Down

0 comments on commit 8214e08

Please sign in to comment.