Skip to content

Commit

Permalink
Introduce a single lock for all sdk operations
Browse files Browse the repository at this point in the history
  • Loading branch information
marcelveldt committed Mar 1, 2024
1 parent c6893b7 commit ccfc3d1
Showing 1 changed file with 73 additions and 69 deletions.
142 changes: 73 additions & 69 deletions matter_server/server/device_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
import logging
from random import randint
import time
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, cast
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, TypeVar, cast

import async_timeout
from chip.ChipDeviceCtrl import DeviceProxyWrapper
from chip.clusters import Attribute, Objects as Clusters
from chip.clusters.Attribute import ValueDecodeFailure
from chip.clusters.ClusterObjects import ALL_ATTRIBUTES, ALL_CLUSTERS, Cluster
from chip.exceptions import ChipStackError
from chip.native import PyChipError
from zeroconf import IPVersion, ServiceStateChange, Zeroconf
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf

Expand Down Expand Up @@ -113,7 +114,7 @@ def __init__(
self.wifi_credentials_set: bool = False
self.thread_credentials_set: bool = False
self.compressed_fabric_id: int | None = None
self._node_lock: dict[int, asyncio.Lock] = {}
self._sdk_lock = asyncio.Lock()
self._aiobrowser: AsyncServiceBrowser | None = None
self._aiozc: AsyncZeroconf | None = None
self._fallback_node_scanner_timer: asyncio.TimerHandle | None = None
Expand Down Expand Up @@ -248,13 +249,13 @@ async def commission_with_code(
attempts,
MAX_COMMISSION_RETRIES,
)
success = await self._call_sdk(
success: bool | PyChipError = await self._call_sdk(
self.chip_controller.CommissionWithCode,
setupPayload=code,
nodeid=node_id,
networkOnly=network_only,
)
if success:
if not isinstance(success, PyChipError) and success:
break
if not success and attempts >= MAX_COMMISSION_RETRIES:
raise NodeCommissionFailed(
Expand Down Expand Up @@ -502,16 +503,17 @@ async def interview_node(self, node_id: int) -> None:
try:
if not (node := self._nodes.get(node_id)) or not node.available:
await self._resolve_node(node_id=node_id)
async with self._get_node_lock(node_id):
LOGGER.info("Interviewing node: %s", node_id)
read_response: Attribute.AsyncReadTransaction.ReadResponse = (
await self.chip_controller.Read(
nodeid=node_id,
attributes="*",
fabricFiltered=False,
await self._call_sdk(
self.chip_controller.Read(
nodeid=node_id,
attributes="*",
fabricFiltered=False,
)
)
)
except ChipStackError as err:
except (ChipStackError, TimeoutError) as err:
raise NodeInterviewFailed(f"Failed to interview node {node_id}") from err

is_new_node = node_id not in self._nodes
Expand Down Expand Up @@ -566,15 +568,16 @@ async def send_device_command(
cluster_cls: Cluster = ALL_CLUSTERS[cluster_id]
command_cls = getattr(cluster_cls.Commands, command_name)
command = dataclass_from_dict(command_cls, payload, allow_sdk_types=True)
async with self._get_node_lock(node_id):
return await self.chip_controller.SendCommand(
return await self._call_sdk(
self.chip_controller.SendCommand(
nodeid=node_id,
endpoint=endpoint_id,
payload=command,
responseType=response_type,
timedRequestTimeoutMs=timed_request_timeout_ms,
interactionTimeoutMs=interaction_timeout_ms,
)
)

@api_command(APICommand.READ_ATTRIBUTE)
async def read_attribute(
Expand All @@ -586,14 +589,23 @@ async def read_attribute(
if (node := self._nodes.get(node_id)) is None or not node.available:
raise NodeNotReady(f"Node {node_id} is not (yet) available.")
endpoint_id, cluster_id, attribute_id = parse_attribute_path(attribute_path)
assert self.server.loop is not None
async with self._get_node_lock(node_id):
future = self.server.loop.create_future()
device = await self._resolve_node(node_id)

async def _do_read() -> Attribute.AsyncReadTransaction.ReadResponse:
"""
Read a list of attributes and/or events from a target node.
This is basically a re-implementation of the chip controller's Read function
but one that allows us to send/request custom attributes.
"""
if TYPE_CHECKING:
assert self.server.loop
assert self.chip_controller
loop = self.server.loop
future = loop.create_future()
Attribute.Read(
future=future,
eventLoop=self.server.loop,
device=device.deviceProxy,
eventLoop=loop,
device=self.chip_controller.GetConnectedDeviceSync(node_id),
devCtrl=self.chip_controller,
attributes=[
Attribute.AttributePath(
Expand All @@ -604,12 +616,16 @@ async def read_attribute(
],
fabricFiltered=fabric_filtered,
).raise_on_error()
result: Attribute.AsyncReadTransaction.ReadResponse = await future
read_atributes = parse_attributes_from_read_result(result.tlvAttributes)
# update cached info in node attributes
self._nodes[node_id].attributes.update(read_atributes)
self._write_node_state(node_id)
return read_atributes
return await future

result: Attribute.AsyncReadTransaction.ReadResponse = await self._call_sdk(
_do_read()
)
read_atributes = parse_attributes_from_read_result(result.tlvAttributes)
# update cached info in node attributes
self._nodes[node_id].attributes.update(read_atributes)
self._write_node_state(node_id)
return read_atributes

@api_command(APICommand.WRITE_ATTRIBUTE)
async def write_attribute(
Expand All @@ -634,11 +650,12 @@ async def write_attribute(
value_type=attribute.attribute_type.Type,
allow_sdk_types=True,
)
async with self._get_node_lock(node_id):
return await self.chip_controller.WriteAttribute(
return await self._call_sdk(
self.chip_controller.WriteAttribute(
nodeid=node_id,
attributes=[(endpoint_id, attribute)],
)
)

@api_command(APICommand.REMOVE_NODE)
async def remove_node(self, node_id: int) -> None:
Expand Down Expand Up @@ -677,12 +694,14 @@ async def remove_node(self, node_id: int) -> None:
return
result: Clusters.OperationalCredentials.Commands.NOCResponse | None = None
try:
result = await self.chip_controller.SendCommand(
nodeid=node_id,
endpoint=0,
payload=Clusters.OperationalCredentials.Commands.RemoveFabric(
fabricIndex=fabric_index,
),
result = await self._call_sdk(
self.chip_controller.SendCommand(
nodeid=node_id,
endpoint=0,
payload=Clusters.OperationalCredentials.Commands.RemoveFabric(
fabricIndex=fabric_index,
),
)
)
except ChipStackError as err:
LOGGER.warning(
Expand Down Expand Up @@ -997,30 +1016,18 @@ def resubscription_succeeded(
)
)
self._last_subscription_attempt[node_id] = 0
future = loop.create_future()
device = await self._resolve_node(node_id)
async with async_timeout.timeout(DEFAULT_CALL_TIMEOUT):
Attribute.Read(
future=future,
eventLoop=loop,
device=device.deviceProxy,
devCtrl=self.chip_controller,
attributes=[Attribute.AttributePath()], # wildcard
events=[
Attribute.EventPath(
EndpointId=None, Cluster=None, Event=None, Urgent=1
)
],
sub: Attribute.SubscriptionTransaction = await self._call_sdk(
self.chip_controller.Read(
node_id,
attributes="*",
events=[("*", 1)],
returnClusterObject=False,
subscriptionParameters=Attribute.SubscriptionParameters(
interval_floor, interval_ceiling
),
# Use fabricfiltered as False to detect changes made by other controllers
# and to be able to provide a list of all fabrics attached to the device
reportInterval=(interval_floor, interval_ceiling),
fabricFiltered=False,
keepSubscriptions=True,
autoResubscribe=True,
).raise_on_error()
sub: Attribute.SubscriptionTransaction = await future
)
)

sub.SetAttributeUpdateCallback(attribute_updated_callback)
sub.SetEventUpdateCallback(event_callback)
Expand Down Expand Up @@ -1049,7 +1056,7 @@ def _get_next_node_id(self) -> int:

async def _call_sdk(
self,
func: Callable[..., _T],
target: Callable[..., _T] | Awaitable[_T],
*args: Any,
call_timeout: int = DEFAULT_CALL_TIMEOUT,
**kwargs: Any,
Expand All @@ -1058,15 +1065,18 @@ async def _call_sdk(
if self.server.loop is None:
raise RuntimeError("Server not started.")

# prevent a single job in the executor blocking everything with a timeout.
async with async_timeout.timeout(call_timeout):
return cast(
_T,
await self.server.loop.run_in_executor(
self._sdk_executor,
partial(func, *args, **kwargs),
),
# handle both awaitables and sync functions here
if not asyncio.iscoroutine(target):
target = self.server.loop.run_in_executor(
self._sdk_executor,
partial(target, *args, **kwargs), # type: ignore[arg-type]
)
# we guard all calls to the sdk with a lock because we have no good way
# of knowing if all code in the python wrapper is thread safe.
# The additional timeout is a guard to prevent ourselves from deadlocking somehow.
async with async_timeout.timeout(call_timeout):
async with self._sdk_lock:
return cast(_T, await target)

async def _setup_node(self, node_id: int) -> None:
"""Handle set-up of subscriptions and interview (if needed) for known/discovered node."""
Expand Down Expand Up @@ -1112,7 +1122,7 @@ async def _setup_node(self, node_id: int) -> None:
# setup subscriptions for the node
try:
await self._subscribe_node(node_id)
except (NodeNotResolving, TimeoutError) as err:
except (NodeNotResolving, TimeoutError, ChipStackError) as err:
LOGGER.warning(
"Unable to subscribe to Node %s: %s",
node_id,
Expand Down Expand Up @@ -1253,12 +1263,6 @@ async def _on_mdns_commissionable_node_state(
await info.async_request(self._aiozc.zeroconf, 3000)
LOGGER.debug("Discovered commissionable Matter node using MDNS: %s", info)

def _get_node_lock(self, node_id: int) -> asyncio.Lock:
"""Return lock for given node."""
if node_id not in self._node_lock:
self._node_lock[node_id] = asyncio.Lock()
return self._node_lock[node_id]

def _write_node_state(self, node_id: int, force: bool = False) -> None:
"""Schedule the write of the current node state to persistent storage."""
node = self._nodes[node_id]
Expand Down

0 comments on commit ccfc3d1

Please sign in to comment.