diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index eaa1b11a..909e9056 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -158,7 +158,7 @@ def __init__( self._polled_attributes: dict[int, set[str]] = {} self._custom_attribute_poller_timer: asyncio.TimerHandle | None = None self._custom_attribute_poller_task: asyncio.Task | None = None - self._ota_provider = ExternalOtaProvider(ota_provider_dir) + self._ota_provider = ExternalOtaProvider(server.vendor_id, ota_provider_dir) async def initialize(self) -> None: """Initialize the device controller.""" @@ -913,12 +913,6 @@ async def update_node(self, node_id: int, software_version: int) -> dict | None: f"Software version {software_version} is not available for node {node_id}." ) - if self.chip_controller is None: - raise RuntimeError("Device Controller not initialized.") - - if not self._ota_provider: - raise UpdateError("No OTA provider found, updates not possible") - if self._ota_provider.is_busy(): raise UpdateError( "No OTA provider currently busy, updates currently not possible" @@ -929,7 +923,7 @@ async def update_node(self, node_id: int, software_version: int) -> dict | None: # Make sure any previous instances get stopped await self._ota_provider.start_update( - self, + self._chip_device_controller, node_id, ) diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 87862dc4..16b50fdf 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -18,22 +18,26 @@ from aiohttp import ClientError, ClientSession from aiohttp.client_exceptions import InvalidURL from chip.clusters import Attribute, Objects as Clusters, Types +from chip.discovery import FilterType from chip.exceptions import ChipStackError from chip.interaction_model import Status -from matter_server.common.errors import NodeCommissionFailed, NodeNotExists, UpdateError +from matter_server.common.errors import UpdateError from matter_server.common.helpers.util import dataclass_from_dict -if TYPE_CHECKING: - from matter_server.server.device_controller import MatterDeviceController - if TYPE_CHECKING: from asyncio.subprocess import Process + from chip.native import PyChipError + + from matter_server.server.sdk import ChipDeviceControllerWrapper + LOGGER = logging.getLogger(__name__) DEFAULT_UPDATES_PATH: Final[Path] = Path("updates") +DEFAULT_OTA_PROVIDER_NODE_ID = 999900 + # From Matter SDK src/app/ota_image_tool.py CHECHKSUM_TYPES: Final[dict[int, str]] = { 1: "sha256", @@ -85,8 +89,9 @@ class ExternalOtaProvider: ENDPOINT_ID: Final[int] = 0 - def __init__(self, ota_provider_dir: Path) -> None: + def __init__(self, vendor_id: int, ota_provider_dir: Path) -> None: """Initialize the OTA provider.""" + self._vendor_id: int = vendor_id self._ota_provider_dir: Path = ota_provider_dir self._ota_provider_image_list_file: Path = ota_provider_dir / "updates.json" self._ota_provider_image_list: OtaProviderImageList | None = None @@ -158,40 +163,45 @@ def get_node_id(self) -> int | None: return self._get_ota_provider_image_list().otaProviderNodeId - async def _initialize(self, device_controller: MatterDeviceController) -> None: + async def _initialize( + self, chip_device_controller: ChipDeviceControllerWrapper + ) -> None: """Commissions the OTA Provider.""" - if device_controller.chip_controller is None: - raise RuntimeError("Device Controller not initialized.") - # The OTA Provider has not been commissioned yet, let's do it now. LOGGER.info("Commissioning the built-in OTA Provider App.") - try: - ota_provider_node = await device_controller.commission_on_network( - self.get_passcode(), - # TODO: Filtering by long discriminator seems broken - # filter_type=FilterType.LONG_DISCRIMINATOR, - # filter=ota_provider.get_descriminator(), + + res: PyChipError = await chip_device_controller.commission_on_network( + DEFAULT_OTA_PROVIDER_NODE_ID, + self.get_passcode(), + # TODO: Filtering by long discriminator seems broken + disc_filter_type=FilterType.LONG_DISCRIMINATOR, + disc_filter=self.get_descriminator(), + ) + if not res.is_success: + await self.stop() + raise UpdateError( + f"Failed to commission OTA Provider App: SDK Error {res.code}" ) - ota_provider_node_id = ota_provider_node.node_id - except NodeCommissionFailed: - LOGGER.error("Failed to commission OTA Provider App!") - return LOGGER.info( "OTA Provider App commissioned with node id %d.", - ota_provider_node_id, + DEFAULT_OTA_PROVIDER_NODE_ID, ) # Adjust ACL of OTA Requestor such that Node peer-to-peer communication # is allowed. try: - read_result = await device_controller.chip_controller.ReadAttribute( - ota_provider_node_id, [(0, Clusters.AccessControl.Attributes.Acl)] + read_result = cast( + Attribute.AsyncReadTransaction.ReadResponse, + await chip_device_controller.read_attribute( + DEFAULT_OTA_PROVIDER_NODE_ID, + [(0, Clusters.AccessControl.Attributes.Acl)], + ), ) acl_list = cast( list, - read_result[0][Clusters.AccessControl][ + read_result.attributes[0][Clusters.AccessControl][ Clusters.AccessControl.Attributes.Acl ], ) @@ -216,8 +226,8 @@ async def _initialize(self, device_controller: MatterDeviceController) -> None: # And write. This is persistent, so only need to be done after we commissioned # the OTA Provider App. write_result: Attribute.AttributeWriteResult = ( - await device_controller.chip_controller.WriteAttribute( - ota_provider_node_id, + await chip_device_controller.write_attribute( + DEFAULT_OTA_PROVIDER_NODE_ID, [(0, Clusters.AccessControl.Attributes.Acl(acl_list))], ) ) @@ -226,22 +236,17 @@ async def _initialize(self, device_controller: MatterDeviceController) -> None: "Failed writing adjusted OTA Provider App ACL: Status %s.", str(write_result[0].Status), ) - await device_controller.remove_node(ota_provider_node_id) + await self.stop() raise UpdateError("Error while setting up OTA Provider.") except ChipStackError as ex: logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex) - await device_controller.remove_node(ota_provider_node_id) + await self.stop() raise UpdateError("Error while setting up OTA Provider.") from ex - self.set_node_id(ota_provider_node_id) - async def start_update( - self, device_controller: MatterDeviceController, node_id: int + self, chip_device_controller: ChipDeviceControllerWrapper, node_id: int ) -> None: - """Start the OTA Provider.""" - - if device_controller.chip_controller is None: - raise RuntimeError("Device Controller not initialized.") + """Start the OTA Provider and trigger the update.""" self._ota_target_node_id = node_id @@ -291,39 +296,29 @@ def _write_ota_provider_image_list_json( # Wait for OTA provider to be ready # TODO: Detect when OTA provider is ready - await asyncio.sleep(2) + await asyncio.sleep(3) # Handle if user deleted the OTA Provider node. ota_provider_node_id = self.get_node_id() - if ota_provider_node_id is not None: - try: - device_controller.get_node(ota_provider_node_id) - except NodeNotExists: - LOGGER.warning( - "OTA Provider node id %d not known by device controller! Resetting...", - ota_provider_node_id, - ) - await self._reset() - ota_provider_node_id = None # Commission and prepare OTA Provider if not initialized yet. # Use "ota_provider_node_id" to indicate if OTA Provider is setup or not. try: if ota_provider_node_id is None: LOGGER.info("Initializing OTA Provider") - await self._initialize(device_controller) + await self._initialize(chip_device_controller) finally: self._ota_target_node_id = None # Notify update node about the availability of the OTA Provider. It will query # the OTA provider and start the update. try: - await device_controller.chip_controller.SendCommand( - nodeid=node_id, - endpoint=0, - payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( - providerNodeID=ota_provider_node_id, - vendorID=device_controller.server.vendor_id, + await chip_device_controller.send_command( + node_id, + endpoint_id=0, + command=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( + providerNodeID=DEFAULT_OTA_PROVIDER_NODE_ID, + vendorID=self._vendor_id, announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable, endpoint=ExternalOtaProvider.ENDPOINT_ID, ),