From 612b867ff5c8470e5dc18a12535743f22aa21b76 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 17 May 2024 16:36:53 +0200 Subject: [PATCH 01/39] Implement Update check using DCL software information Check if there is a software update available using DCL software update information. --- matter_server/common/models.py | 1 + matter_server/server/device_controller.py | 49 +++++++++++++++ matter_server/server/helpers/__init__.py | 3 + .../server/helpers/paa_certificates.py | 8 +-- matter_server/server/ota/dcl.py | 60 +++++++++++++++++++ 5 files changed, 117 insertions(+), 4 deletions(-) create mode 100644 matter_server/server/ota/dcl.py diff --git a/matter_server/common/models.py b/matter_server/common/models.py index efd4049d..2b295ace 100644 --- a/matter_server/common/models.py +++ b/matter_server/common/models.py @@ -47,6 +47,7 @@ class APICommand(str, Enum): PING_NODE = "ping_node" GET_NODE_IP_ADDRESSES = "get_node_ip_addresses" IMPORT_TEST_NODE = "import_test_node" + UPDATE_NODE = "update_node" EventCallBackType = Callable[[EventType, Any], None] diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 730aec57..00431722 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -31,6 +31,7 @@ from matter_server.common.models import CommissionableNodeData, CommissioningParameters from matter_server.server.helpers.attributes import parse_attributes_from_read_result from matter_server.server.helpers.utils import ping_ip +from matter_server.server.ota.dcl import check_updates from matter_server.server.sdk import ChipDeviceControllerWrapper from ..common.errors import ( @@ -91,11 +92,22 @@ DESCRIPTOR_PARTS_LIST_ATTRIBUTE_PATH = create_attribute_path_from_attribute( 0, Clusters.Descriptor.Attributes.PartsList ) +BASIC_INFORMATION_VENDOR_ID_ATTRIBUTE_PATH = create_attribute_path_from_attribute( + 0, Clusters.BasicInformation.Attributes.VendorID +) +BASIC_INFORMATION_PRODUCT_ID_ATTRIBUTE_PATH = create_attribute_path_from_attribute( + 0, Clusters.BasicInformation.Attributes.ProductID +) BASIC_INFORMATION_SOFTWARE_VERSION_ATTRIBUTE_PATH = ( create_attribute_path_from_attribute( 0, Clusters.BasicInformation.Attributes.SoftwareVersion ) ) +BASIC_INFORMATION_SOFTWARE_VERSION_STRING_ATTRIBUTE_PATH = ( + create_attribute_path_from_attribute( + 0, Clusters.BasicInformation.Attributes.SoftwareVersionString + ) +) # pylint: disable=too-many-lines,too-many-instance-attributes,too-many-public-methods @@ -876,6 +888,43 @@ async def import_test_node(self, dump: str) -> None: self._nodes[node.node_id] = node self.server.signal_event(EventType.NODE_ADDED, node) + @api_command(APICommand.UPDATE_NODE) + async def update_node(self, node_id: int) -> dict | None: + """ + Check if there is an update for a particular node. + + Reads the current software version and checks the DCL if there is an update + available. If there is an update available, the command returns the version + information of the latest update available. + """ + + node_logger = LOGGER.getChild(f"node_{node_id}") + node = self._nodes[node_id] + + node_logger.debug("Check for updates.") + vid = cast(int, node.attributes.get(BASIC_INFORMATION_VENDOR_ID_ATTRIBUTE_PATH)) + pid = cast( + int, node.attributes.get(BASIC_INFORMATION_PRODUCT_ID_ATTRIBUTE_PATH) + ) + software_version = cast( + int, node.attributes.get(BASIC_INFORMATION_SOFTWARE_VERSION_ATTRIBUTE_PATH) + ) + software_version_string = node.attributes.get( + BASIC_INFORMATION_SOFTWARE_VERSION_STRING_ATTRIBUTE_PATH + ) + + update = await check_updates(node_id, vid, pid, software_version) + if update and "otaUrl" in update and len(update["otaUrl"]) > 0: + node_logger.info( + "New software update found: %s (current %s). Preparing updates...", + update["softwareVersionString"], + software_version_string, + ) + + # Add to OTA provider + + return update + async def _subscribe_node(self, node_id: int) -> None: """ Subscribe to all node state changes/events for an individual node. diff --git a/matter_server/server/helpers/__init__.py b/matter_server/server/helpers/__init__.py index 8cb651ea..b80f5f6a 100644 --- a/matter_server/server/helpers/__init__.py +++ b/matter_server/server/helpers/__init__.py @@ -1 +1,4 @@ """Helpers/utils for the Matter Server.""" + +DCL_PRODUCTION_URL = "https://on.dcl.csa-iot.org" +DCL_TEST_URL = "https://on.test-net.dcl.csa-iot.org" diff --git a/matter_server/server/helpers/paa_certificates.py b/matter_server/server/helpers/paa_certificates.py index 4a459e65..27fc5875 100644 --- a/matter_server/server/helpers/paa_certificates.py +++ b/matter_server/server/helpers/paa_certificates.py @@ -19,14 +19,14 @@ from cryptography.hazmat.primitives import serialization from cryptography.utils import CryptographyDeprecationWarning +from matter_server.server.helpers import DCL_PRODUCTION_URL, DCL_TEST_URL + # Git repo details OWNER = "project-chip" REPO = "connectedhomeip" PATH = "credentials/development/paa-root-certs" LOGGER = logging.getLogger(__name__) -PRODUCTION_URL = "https://on.dcl.csa-iot.org" -TEST_URL = "https://on.test-net.dcl.csa-iot.org" GIT_URL = f"https://raw.githubusercontent.com/{OWNER}/{REPO}/master/{PATH}" @@ -226,7 +226,7 @@ def _check_paa_root_dir( fetch_count = await fetch_dcl_certificates( paa_root_cert_dir=paa_root_cert_dir, base_name="dcld_production_", - base_url=PRODUCTION_URL, + base_url=DCL_PRODUCTION_URL, ) LOGGER.info("Fetched %s PAA root certificates from DCL.", fetch_count) total_fetch_count += fetch_count @@ -235,7 +235,7 @@ def _check_paa_root_dir( fetch_count = await fetch_dcl_certificates( paa_root_cert_dir=paa_root_cert_dir, base_name="dcld_test_", - base_url=TEST_URL, + base_url=DCL_TEST_URL, ) LOGGER.info("Fetched %s PAA root certificates from Test DCL.", fetch_count) total_fetch_count += fetch_count diff --git a/matter_server/server/ota/dcl.py b/matter_server/server/ota/dcl.py new file mode 100644 index 00000000..3ce0eace --- /dev/null +++ b/matter_server/server/ota/dcl.py @@ -0,0 +1,60 @@ +"""Handle OTA software version endpoints of the DCL.""" + +import logging +from typing import Any + +from aiohttp import ClientError, ClientSession + +from matter_server.server.helpers import DCL_PRODUCTION_URL + +LOGGER = logging.getLogger(__name__) + + +async def get_software_versions(node_id: int, vid: int, pid: int) -> Any: + """Check DCL if there are updates available for a particular node.""" + async with ClientSession(raise_for_status=True) as http_session: + # fetch the paa certificates list + async with http_session.get( + f"{DCL_PRODUCTION_URL}/dcl/model/versions/{vid}/{pid}" + ) as response: + return await response.json() + + +async def get_software_version( + node_id: int, vid: int, pid: int, software_version: int +) -> Any: + """Check DCL if there are updates available for a particular node.""" + async with ClientSession(raise_for_status=True) as http_session: + # fetch the paa certificates list + async with http_session.get( + f"{DCL_PRODUCTION_URL}/dcl/model/versions/{vid}/{pid}/{software_version}" + ) as response: + return await response.json() + + +async def check_updates( + node_id: int, vid: int, pid: int, current_software_version: int +) -> None | dict: + """Check if there is a newer software version available on the DCL.""" + try: + versions = await get_software_versions(node_id, vid, pid) + + software_versions: list[int] = versions["modelVersions"]["softwareVersions"] + latest_software_version = max(software_versions) + if latest_software_version <= current_software_version: + return None + + version: dict = await get_software_version( + node_id, vid, pid, latest_software_version + ) + if isinstance(version, dict) and "modelVersion" in version: + result: Any = version["modelVersion"] + if isinstance(result, dict): + return result + + logging.error("Unexpected DCL response.") + return None + + except (ClientError, TimeoutError) as err: + LOGGER.error("Fetching software version failed: error %s", err, exc_info=err) + return None From 17f9f547a346888bf42c681c4348054ceb130f05 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 17 May 2024 17:50:22 +0200 Subject: [PATCH 02/39] Initial implementation of OTA provider The OTA provider downloads the updates and prepares them so Matter devices can consume them. --- matter_server/server/device_controller.py | 3 + matter_server/server/ota/provider.py | 139 ++++++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 matter_server/server/ota/provider.py diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 00431722..2e41aa04 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -32,6 +32,7 @@ from matter_server.server.helpers.attributes import parse_attributes_from_read_result from matter_server.server.helpers.utils import ping_ip from matter_server.server.ota.dcl import check_updates +from matter_server.server.ota.provider import ExternalOtaProvider from matter_server.server.sdk import ChipDeviceControllerWrapper from ..common.errors import ( @@ -149,6 +150,7 @@ def __init__( self._polled_attributes: dict[int, set[str]] = {} self._custom_attribute_poller_timer: asyncio.TimerHandle | None = None self._custom_attribute_poller_task: asyncio.Task | None = None + self._ota_provider = ExternalOtaProvider() async def initialize(self) -> None: """Initialize the device controller.""" @@ -922,6 +924,7 @@ async def update_node(self, node_id: int) -> dict | None: ) # Add to OTA provider + await self._ota_provider.download_update(update) return update diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py new file mode 100644 index 00000000..be49a3d1 --- /dev/null +++ b/matter_server/server/ota/provider.py @@ -0,0 +1,139 @@ +"""Handling Matter OTA provider.""" + +import asyncio +from dataclasses import asdict, dataclass +import json +import logging +from pathlib import Path +from typing import Final +from urllib.parse import unquote, urlparse + +from aiohttp import ClientError, ClientSession + +from matter_server.common.helpers.util import dataclass_from_dict + +LOGGER = logging.getLogger(__name__) + +DEFAULT_UPDATES_PATH: Final[Path] = Path("updates") + + +@dataclass +class DeviceSoftwareVersionModel: # pylint: disable=C0103 + """Device Software Version Model for OTA Provider JSON descriptor file.""" + + vendorId: int + productId: int + softwareVersion: int + softwareVersionString: str + cDVersionNumber: int + softwareVersionValid: bool + minApplicableSoftwareVersion: int + maxApplicableSoftwareVersion: int + otaURL: str + + +@dataclass +class UpdateFile: # pylint: disable=C0103 + """Update File for OTA Provider JSON descriptor file.""" + + deviceSoftwareVersionModel: list[DeviceSoftwareVersionModel] + + +class ExternalOtaProvider: + """Class handling Matter OTA Provider. + + The OTA Provider class implements a Matter OTA (over-the-air) update provider + for devices. + """ + + def __init__(self) -> None: + """Initialize the OTA provider.""" + + def start(self) -> None: + """Start the OTA Provider.""" + + async def add_update(self, update_desc: dict, ota_file: Path) -> None: + """Add update to the OTA provider.""" + + update_json_path = DEFAULT_UPDATES_PATH / "updates.json" + + def _read_update_json(update_json_path: Path) -> None | UpdateFile: + if not update_json_path.exists(): + return None + + with open(update_json_path, "r") as json_file: + data = json.load(json_file) + return dataclass_from_dict(UpdateFile, data) + + loop = asyncio.get_running_loop() + update_file = await loop.run_in_executor( + None, _read_update_json, update_json_path + ) + + if not update_file: + update_file = UpdateFile(deviceSoftwareVersionModel=[]) + + # Convert to OTA Requestor descriptor file + update_file.deviceSoftwareVersionModel.append( + DeviceSoftwareVersionModel( + vendorId=update_desc["vid"], + productId=update_desc["pid"], + softwareVersion=update_desc["softwareVersion"], + softwareVersionString=update_desc["softwareVersionString"], + cDVersionNumber=update_desc["cdVersionNumber"], + softwareVersionValid=update_desc["softwareVersionValid"], + minApplicableSoftwareVersion=update_desc[ + "minApplicableSoftwareVersion" + ], + maxApplicableSoftwareVersion=update_desc[ + "maxApplicableSoftwareVersion" + ], + otaURL=str(ota_file), + ) + ) + + def _write_update_json(update_json_path: Path, update_file: UpdateFile) -> None: + update_file_dict = asdict(update_file) + with open(update_json_path, "w") as json_file: + json.dump(update_file_dict, json_file, indent=4) + + await loop.run_in_executor( + None, + _write_update_json, + update_json_path, + update_file, + ) + + async def download_update(self, update_desc: dict) -> None: + """Download update file from OTA Path and add it to the OTA provider.""" + + url = update_desc["otaUrl"] + parsed_url = urlparse(url) + file_name = unquote(Path(parsed_url.path).name) + + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, DEFAULT_UPDATES_PATH.mkdir) + + file_path = DEFAULT_UPDATES_PATH / file_name + + try: + async with ClientSession(raise_for_status=True) as session: + # fetch the paa certificates list + logging.debug("Download update from f{url}.") + async with session.get(url) as response: + with file_path.open("wb") as f: + while True: + chunk = await response.content.read(1024) + if not chunk: + break + f.write(chunk) + LOGGER.info( + "File '%s' downloaded to '%s'", file_name, DEFAULT_UPDATES_PATH + ) + + except (ClientError, TimeoutError) as err: + LOGGER.error( + "Fetching software version failed: error %s", err, exc_info=err + ) + + await self.add_update(update_desc, file_path) From 2bab7e3fc26889f8d302784c356ca5d19fa79ea6 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 17 May 2024 19:15:05 +0200 Subject: [PATCH 03/39] Implement update using OTA Provider app Use the OTA Provider example app to implement a OTA provider. The example app supports a JSON update descriptor file to manage update metadata. Tested with the OTA Requestor app. --- matter_server/server/device_controller.py | 50 ++++++++++--- matter_server/server/ota/provider.py | 85 +++++++++++++++++------ 2 files changed, 107 insertions(+), 28 deletions(-) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 2e41aa04..8b340dc3 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -221,6 +221,9 @@ async def stop(self) -> None: # shutdown the sdk device controller await self._chip_device_controller.shutdown() + # shutdown the OTA Provider + if self._ota_provider: + await self._ota_provider.stop() LOGGER.debug("Stopped.") @property @@ -903,6 +906,13 @@ async def update_node(self, node_id: int) -> dict | None: node_logger = LOGGER.getChild(f"node_{node_id}") node = self._nodes[node_id] + if self.chip_controller is None: + raise RuntimeError("Device Controller not initialized.") + + if not self._ota_provider: + LOGGER.warning("No OTA provider found, updates not possible.") + return None + node_logger.debug("Check for updates.") vid = cast(int, node.attributes.get(BASIC_INFORMATION_VENDOR_ID_ATTRIBUTE_PATH)) pid = cast( @@ -916,15 +926,39 @@ async def update_node(self, node_id: int) -> dict | None: ) update = await check_updates(node_id, vid, pid, software_version) - if update and "otaUrl" in update and len(update["otaUrl"]) > 0: - node_logger.info( - "New software update found: %s (current %s). Preparing updates...", - update["softwareVersionString"], - software_version_string, - ) + if not update: + node_logger.info("No new update found.") + return None + + if "otaUrl" not in update: + node_logger.warning("Update found, but no OTA URL provided.") + return None - # Add to OTA provider - await self._ota_provider.download_update(update) + node_logger.info( + "New software update found: %s (current %s). Preparing updates...", + update["softwareVersionString"], + software_version_string, + ) + + # Add to OTA provider + await self._ota_provider.download_update(update) + + self._ota_provider.start() + + # Wait for OTA provider to be ready + # TODO: Detect when OTA provider is ready + await asyncio.sleep(2) + + await self.chip_controller.SendCommand( + nodeid=node_id, + endpoint=0, + payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( + providerNodeID=32, + vendorID=0, # TODO: Use Server Vendor ID + announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable, + endpoint=0, + ), + ) return update diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index be49a3d1..06257c28 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -2,16 +2,20 @@ import asyncio from dataclasses import asdict, dataclass +import functools import json import logging from pathlib import Path -from typing import Final +from typing import TYPE_CHECKING, Final from urllib.parse import unquote, urlparse from aiohttp import ClientError, ClientSession from matter_server.common.helpers.util import dataclass_from_dict +if TYPE_CHECKING: + from asyncio.subprocess import Process + LOGGER = logging.getLogger(__name__) DEFAULT_UPDATES_PATH: Final[Path] = Path("updates") @@ -48,10 +52,42 @@ class ExternalOtaProvider: def __init__(self) -> None: """Initialize the OTA provider.""" + self._ota_provider_proc: Process | None = None + self._ota_provider_task: asyncio.Task | None = None + + async def _start_ota_provider(self) -> None: + # TODO: Randomize discriminator + ota_provider_cmd = [ + "chip-ota-provider-app", + "--discriminator", + "22", + "--secured-device-port", + "5565", + "--KVS", + "/data/chip_kvs_provider", + "--otaImageList", + str(DEFAULT_UPDATES_PATH / "updates.json"), + ] + + LOGGER.info("Starting OTA Provider") + self._ota_provider_proc = await asyncio.create_subprocess_exec( + *ota_provider_cmd + ) def start(self) -> None: """Start the OTA Provider.""" + loop = asyncio.get_event_loop() + self._ota_provider_task = loop.create_task(self._start_ota_provider()) + + async def stop(self) -> None: + """Stop the OTA Provider.""" + if self._ota_provider_proc: + LOGGER.info("Terminating OTA Provider") + self._ota_provider_proc.terminate() + if self._ota_provider_task: + await self._ota_provider_task + async def add_update(self, update_desc: dict, ota_file: Path) -> None: """Add update to the OTA provider.""" @@ -73,24 +109,25 @@ def _read_update_json(update_json_path: Path) -> None | UpdateFile: if not update_file: update_file = UpdateFile(deviceSoftwareVersionModel=[]) + local_ota_url = str(ota_file) + for i, device_software in enumerate(update_file.deviceSoftwareVersionModel): + if device_software.otaURL == local_ota_url: + LOGGER.debug("Device software entry exists already, replacing!") + del update_file.deviceSoftwareVersionModel[i] + # Convert to OTA Requestor descriptor file - update_file.deviceSoftwareVersionModel.append( - DeviceSoftwareVersionModel( - vendorId=update_desc["vid"], - productId=update_desc["pid"], - softwareVersion=update_desc["softwareVersion"], - softwareVersionString=update_desc["softwareVersionString"], - cDVersionNumber=update_desc["cdVersionNumber"], - softwareVersionValid=update_desc["softwareVersionValid"], - minApplicableSoftwareVersion=update_desc[ - "minApplicableSoftwareVersion" - ], - maxApplicableSoftwareVersion=update_desc[ - "maxApplicableSoftwareVersion" - ], - otaURL=str(ota_file), - ) + new_device_software = DeviceSoftwareVersionModel( + vendorId=update_desc["vid"], + productId=update_desc["pid"], + softwareVersion=update_desc["softwareVersion"], + softwareVersionString=update_desc["softwareVersionString"], + cDVersionNumber=update_desc["cdVersionNumber"], + softwareVersionValid=update_desc["softwareVersionValid"], + minApplicableSoftwareVersion=update_desc["minApplicableSoftwareVersion"], + maxApplicableSoftwareVersion=update_desc["maxApplicableSoftwareVersion"], + otaURL=local_ota_url, ) + update_file.deviceSoftwareVersionModel.append(new_device_software) def _write_update_json(update_json_path: Path, update_file: UpdateFile) -> None: update_file_dict = asdict(update_file) @@ -112,9 +149,14 @@ async def download_update(self, update_desc: dict) -> None: file_name = unquote(Path(parsed_url.path).name) loop = asyncio.get_running_loop() - await loop.run_in_executor(None, DEFAULT_UPDATES_PATH.mkdir) + await loop.run_in_executor( + None, functools.partial(DEFAULT_UPDATES_PATH.mkdir, exists_ok=True) + ) file_path = DEFAULT_UPDATES_PATH / file_name + if await loop.run_in_executor(None, file_path.exists): + LOGGER.info("File '%s' exists already, skipping download.", file_name) + return try: async with ClientSession(raise_for_status=True) as session: @@ -123,10 +165,13 @@ async def download_update(self, update_desc: dict) -> None: async with session.get(url) as response: with file_path.open("wb") as f: while True: - chunk = await response.content.read(1024) + chunk = await response.content.read(4048) if not chunk: break - f.write(chunk) + await loop.run_in_executor(None, f.write, chunk) + + # TODO: Check against otaChecksum + LOGGER.info( "File '%s' downloaded to '%s'", file_name, DEFAULT_UPDATES_PATH ) From 9d7717f0dea3a0124d473b4ed54d0071013f9208 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Sat, 18 May 2024 01:40:25 +0200 Subject: [PATCH 04/39] Setup OTA Provider App automatically when necessary Start and commission OTA Provider App when necessary. Use random discriminator and passcode. Store the Node ID of the OTA Provider App once setup for fast re-use. --- matter_server/server/__main__.py | 7 + matter_server/server/const.py | 2 + matter_server/server/device_controller.py | 86 +++++++++++- matter_server/server/ota/provider.py | 153 ++++++++++++++++------ matter_server/server/server.py | 19 ++- 5 files changed, 221 insertions(+), 46 deletions(-) diff --git a/matter_server/server/__main__.py b/matter_server/server/__main__.py index df66597a..f12bdd44 100644 --- a/matter_server/server/__main__.py +++ b/matter_server/server/__main__.py @@ -116,6 +116,12 @@ nargs="+", help="List of node IDs to show logs from (applies only to server logs).", ) +parser.add_argument( + "--ota-provider-dir", + type=str, + default=None, + help="Directory where OTA Provider stores software updates and configuration.", +) args = parser.parse_args() @@ -216,6 +222,7 @@ def main() -> None: args.paa_root_cert_dir, args.enable_test_net_dcl, args.bluetooth_adapter, + args.ota_provider_dir, ) async def handle_stop(loop: asyncio.AbstractEventLoop) -> None: diff --git a/matter_server/server/const.py b/matter_server/server/const.py index 5afa4c1f..f411f68f 100644 --- a/matter_server/server/const.py +++ b/matter_server/server/const.py @@ -20,3 +20,5 @@ .parent.resolve() .joinpath("credentials/development/paa-root-certs") ) + +DEFAULT_OTA_PROVIDER_DIR: Final[pathlib.Path] = pathlib.Path().cwd().joinpath("updates") diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 8b340dc3..871f8671 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -18,11 +18,12 @@ from typing import TYPE_CHECKING, Any, cast from chip.ChipDeviceCtrl import ChipDeviceController -from chip.clusters import Attribute, Objects as Clusters +from chip.clusters import Attribute, Objects as Clusters, Types from chip.clusters.Attribute import ValueDecodeFailure from chip.clusters.ClusterObjects import ALL_ATTRIBUTES, ALL_CLUSTERS, Cluster from chip.discovery import DiscoveryType from chip.exceptions import ChipStackError +from chip.interaction_model import Status from zeroconf import BadTypeInNameException, IPVersion, ServiceStateChange, Zeroconf from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf @@ -120,6 +121,7 @@ def __init__( self, server: MatterServer, paa_root_cert_dir: Path, + ota_provider_dir: Path, ): """Initialize the device controller.""" self.server = server @@ -150,7 +152,7 @@ def __init__( self._polled_attributes: dict[int, set[str]] = {} self._custom_attribute_poller_timer: asyncio.TimerHandle | None = None self._custom_attribute_poller_task: asyncio.Task | None = None - self._ota_provider = ExternalOtaProvider() + self._ota_provider = ExternalOtaProvider(ota_provider_dir) async def initialize(self) -> None: """Initialize the device controller.""" @@ -158,6 +160,7 @@ async def initialize(self) -> None: await self._chip_device_controller.get_compressed_fabric_id() ) self._fabric_id_hex = hex(self._compressed_fabric_id)[2:] + await self._ota_provider.initialize() async def start(self) -> None: """Handle logic on controller start.""" @@ -943,17 +946,94 @@ async def update_node(self, node_id: int) -> dict | None: # Add to OTA provider await self._ota_provider.download_update(update) + ota_provider_node_id = self._ota_provider.get_node_id() + if ota_provider_node_id not in self._nodes: + LOGGER.warning( + "OTA Provider node id %d no longer exists! Resetting...", + ota_provider_node_id, + ) + await self._ota_provider.reset() + ota_provider_node_id = None + + # Make sure any previous instances get stopped + await self._ota_provider.stop() self._ota_provider.start() # Wait for OTA provider to be ready # TODO: Detect when OTA provider is ready await asyncio.sleep(2) + if not ota_provider_node_id: + # The OTA Provider has not been commissioned yet, let's do it now. + LOGGER.info("Commissioning the built-in OTA Provider App.") + try: + ota_provider_node = await self.commission_on_network( + self._ota_provider.get_passcode(), + # TODO: Filtering by long discriminator seems broken + # filter_type=FilterType.LONG_DISCRIMINATOR, + # filter=self._ota_provider.get_descriminator(), + ) + ota_provider_node_id = ota_provider_node.node_id + except NodeCommissionFailed: + LOGGER.error("Failed to commission OTA Provider App!") + return None + LOGGER.info( + "OTA Provider App commissioned with node id %d.", + ota_provider_node_id, + ) + + # Adjust ACL of OTA Requestor such that Node peer-to-peer communication + # is allowed. + try: + read_result = await self.chip_controller.ReadAttribute( + ota_provider_node_id, [(0, Clusters.AccessControl.Attributes.Acl)] + ) + acl_list = cast( + list, + read_result[0][Clusters.AccessControl][ + Clusters.AccessControl.Attributes.Acl + ], + ) + + # Add new ACL entry... + acl_list.append( + Clusters.AccessControl.Structs.AccessControlEntryStruct( + fabricIndex=1, + privilege=3, + authMode=2, + subjects=Types.NullValue, + targets=[ + Clusters.AccessControl.Structs.AccessControlTargetStruct( + cluster=41, endpoint=0, deviceType=Types.NullValue + ) + ], + ) + ) + + # And write. This is persistent, so only need to be done after we commissioned + # the OTA Provider App. + write_result: Attribute.AttributeWriteResult = ( + await self.chip_controller.WriteAttribute( + ota_provider_node_id, + [(0, Clusters.AccessControl.Attributes.Acl(acl_list))], + ) + ) + if write_result[0].Status != Status.Success: + logging.error("Failed writing adjusted OTA Provider App ACL.") + await self.remove_node(ota_provider_node_id) + return None + except ChipStackError as ex: + logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex) + await self.remove_node(ota_provider_node_id) + else: + self._ota_provider.set_node_id(ota_provider_node_id) + + # Notify node about the new update! await self.chip_controller.SendCommand( nodeid=node_id, endpoint=0, payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( - providerNodeID=32, + providerNodeID=ota_provider_node_id, vendorID=0, # TODO: Use Server Vendor ID announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable, endpoint=0, diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 06257c28..0e540aec 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -6,6 +6,7 @@ import json import logging from pathlib import Path +import secrets from typing import TYPE_CHECKING, Final from urllib.parse import unquote, urlparse @@ -37,9 +38,12 @@ class DeviceSoftwareVersionModel: # pylint: disable=C0103 @dataclass -class UpdateFile: # pylint: disable=C0103 +class OtaProviderImageList: # pylint: disable=C0103 """Update File for OTA Provider JSON descriptor file.""" + otaProviderDiscriminator: int + otaProviderPasscode: int + otaProviderNodeId: int | None deviceSoftwareVersionModel: list[DeviceSoftwareVersionModel] @@ -50,23 +54,103 @@ class ExternalOtaProvider: for devices. """ - def __init__(self) -> None: + def __init__(self, ota_provider_dir: Path) -> None: """Initialize the OTA provider.""" + self._ota_provider_dir: Path = ota_provider_dir + self._ota_provider_image_list_file: Path = ota_provider_dir / "updates.json" + self._ota_provider_image_list: OtaProviderImageList | None = None self._ota_provider_proc: Process | None = None self._ota_provider_task: asyncio.Task | None = None + async def initialize(self) -> None: + """Initialize OTA Provider.""" + + loop = asyncio.get_event_loop() + + # Take existence of image list file as indicator if we need to initialize the + # OTA Provider. + if not await loop.run_in_executor( + None, self._ota_provider_image_list_file.exists + ): + await loop.run_in_executor( + None, functools.partial(DEFAULT_UPDATES_PATH.mkdir, exist_ok=True) + ) + + # Initialize with random data. Node ID will get written once paired by + # device controller. + self._ota_provider_image_list = OtaProviderImageList( + otaProviderDiscriminator=secrets.randbelow(2**12), + otaProviderPasscode=secrets.randbelow(2**21), + otaProviderNodeId=None, + deviceSoftwareVersionModel=[], + ) + else: + + def _read_update_json( + update_json_path: Path, + ) -> None | OtaProviderImageList: + with open(update_json_path, "r") as json_file: + data = json.load(json_file) + return dataclass_from_dict(OtaProviderImageList, data) + + self._ota_provider_image_list = await loop.run_in_executor( + None, _read_update_json, self._ota_provider_image_list_file + ) + + def _get_ota_provider_image_list(self) -> OtaProviderImageList: + if self._ota_provider_image_list is None: + raise RuntimeError("OTA provider image list not initialized.") + return self._ota_provider_image_list + + def get_node_id(self) -> int | None: + """Get Node ID of the OTA Provider App.""" + + return self._get_ota_provider_image_list().otaProviderNodeId + + def get_descriminator(self) -> int: + """Return OTA Provider App discriminator.""" + + return self._get_ota_provider_image_list().otaProviderDiscriminator + + def get_passcode(self) -> int: + """Return OTA Provider App passcode.""" + + return self._get_ota_provider_image_list().otaProviderPasscode + + def set_node_id(self, node_id: int) -> None: + """Set Node ID of the OTA Provider App.""" + + self._get_ota_provider_image_list().otaProviderNodeId = node_id + async def _start_ota_provider(self) -> None: - # TODO: Randomize discriminator + def _write_ota_provider_image_list_json( + ota_provider_image_list_file: Path, + ota_provider_image_list: OtaProviderImageList, + ) -> None: + update_file_dict = asdict(ota_provider_image_list) + with open(ota_provider_image_list_file, "w") as json_file: + json.dump(update_file_dict, json_file, indent=4) + + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + _write_ota_provider_image_list_json, + self._ota_provider_image_list_file, + self._get_ota_provider_image_list(), + ) + ota_provider_cmd = [ "chip-ota-provider-app", "--discriminator", - "22", + str(self._get_ota_provider_image_list().otaProviderDiscriminator), + "--passcode", + str(self._get_ota_provider_image_list().otaProviderPasscode), "--secured-device-port", "5565", "--KVS", - "/data/chip_kvs_provider", + str(self._ota_provider_dir / "chip_kvs_ota_provider"), "--otaImageList", - str(DEFAULT_UPDATES_PATH / "updates.json"), + str(self._ota_provider_image_list_file), ] LOGGER.info("Starting OTA Provider") @@ -80,40 +164,41 @@ def start(self) -> None: loop = asyncio.get_event_loop() self._ota_provider_task = loop.create_task(self._start_ota_provider()) + async def reset(self) -> None: + """Reset the OTA Provider App state.""" + + def _remove_update_data(ota_provider_dir: Path) -> None: + for path in ota_provider_dir.iterdir(): + if not path.is_dir(): + path.unlink() + + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, _remove_update_data, self._ota_provider_dir) + + await self.initialize() + async def stop(self) -> None: """Stop the OTA Provider.""" if self._ota_provider_proc: LOGGER.info("Terminating OTA Provider") - self._ota_provider_proc.terminate() + loop = asyncio.get_event_loop() + try: + await loop.run_in_executor(None, self._ota_provider_proc.terminate) + except ProcessLookupError as ex: + LOGGER.warning("Stopping OTA Provider failed with error:", exc_info=ex) if self._ota_provider_task: await self._ota_provider_task async def add_update(self, update_desc: dict, ota_file: Path) -> None: """Add update to the OTA provider.""" - update_json_path = DEFAULT_UPDATES_PATH / "updates.json" - - def _read_update_json(update_json_path: Path) -> None | UpdateFile: - if not update_json_path.exists(): - return None - - with open(update_json_path, "r") as json_file: - data = json.load(json_file) - return dataclass_from_dict(UpdateFile, data) - - loop = asyncio.get_running_loop() - update_file = await loop.run_in_executor( - None, _read_update_json, update_json_path - ) - - if not update_file: - update_file = UpdateFile(deviceSoftwareVersionModel=[]) - local_ota_url = str(ota_file) - for i, device_software in enumerate(update_file.deviceSoftwareVersionModel): + for i, device_software in enumerate( + self._get_ota_provider_image_list().deviceSoftwareVersionModel + ): if device_software.otaURL == local_ota_url: LOGGER.debug("Device software entry exists already, replacing!") - del update_file.deviceSoftwareVersionModel[i] + del self._get_ota_provider_image_list().deviceSoftwareVersionModel[i] # Convert to OTA Requestor descriptor file new_device_software = DeviceSoftwareVersionModel( @@ -127,18 +212,8 @@ def _read_update_json(update_json_path: Path) -> None | UpdateFile: maxApplicableSoftwareVersion=update_desc["maxApplicableSoftwareVersion"], otaURL=local_ota_url, ) - update_file.deviceSoftwareVersionModel.append(new_device_software) - - def _write_update_json(update_json_path: Path, update_file: UpdateFile) -> None: - update_file_dict = asdict(update_file) - with open(update_json_path, "w") as json_file: - json.dump(update_file_dict, json_file, indent=4) - - await loop.run_in_executor( - None, - _write_update_json, - update_json_path, - update_file, + self._get_ota_provider_image_list().deviceSoftwareVersionModel.append( + new_device_software ) async def download_update(self, update_desc: dict) -> None: diff --git a/matter_server/server/server.py b/matter_server/server/server.py index 63a71261..25d08da3 100644 --- a/matter_server/server/server.py +++ b/matter_server/server/server.py @@ -30,7 +30,11 @@ ServerInfoMessage, ) from ..server.client_handler import WebsocketClientHandler -from .const import DEFAULT_PAA_ROOT_CERTS_DIR, MIN_SCHEMA_VERSION +from .const import ( + DEFAULT_OTA_PROVIDER_DIR, + DEFAULT_PAA_ROOT_CERTS_DIR, + MIN_SCHEMA_VERSION, +) from .device_controller import MatterDeviceController from .stack import MatterStack from .storage import StorageController @@ -91,12 +95,12 @@ async def _handle_shutdown(app: web.Application) -> None: class MatterServer: """Serve Matter stack over WebSockets.""" - # pylint: disable=too-many-instance-attributes + # pylint: disable=too-many-instance-attributes,too-many-arguments _runner: web.AppRunner | None = None _http: MultiHostTCPSite | None = None - def __init__( + def __init__( # noqa: PLR0913 self, storage_path: str, vendor_id: int, @@ -107,6 +111,7 @@ def __init__( paa_root_cert_dir: Path | None = None, enable_test_net_dcl: bool = False, bluetooth_adapter_id: int | None = None, + ota_provider_dir: Path | None = None, ) -> None: """Initialize the Matter Server.""" self.storage_path = storage_path @@ -121,6 +126,10 @@ def __init__( self.paa_root_cert_dir = Path(paa_root_cert_dir).absolute() self.enable_test_net_dcl = enable_test_net_dcl self.bluetooth_enabled = bluetooth_adapter_id is not None + if ota_provider_dir is None: + self.ota_provider_dir = DEFAULT_OTA_PROVIDER_DIR + else: + self.ota_provider_dir = Path(ota_provider_dir).absolute() self.logger = logging.getLogger(__name__) self.app = web.Application() self.loop: asyncio.AbstractEventLoop | None = None @@ -165,7 +174,9 @@ async def start(self) -> None: # Initialize our (intermediate) device controller which keeps track # of Matter devices and their subscriptions. - self._device_controller = MatterDeviceController(self, self.paa_root_cert_dir) + self._device_controller = MatterDeviceController( + self, self.paa_root_cert_dir, self.ota_provider_dir + ) self._register_api_commands() await self._device_controller.initialize() From ee82e391023f09ee86ab7e0379f5db1bc1b1ac43 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Thu, 23 May 2024 14:31:53 +0200 Subject: [PATCH 05/39] Deploy chip-ota-provider-app in container --- Dockerfile | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/Dockerfile b/Dockerfile index 1f911762..dc3cee92 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,6 +9,7 @@ RUN \ set -x \ && apt-get update \ && apt-get install -y --no-install-recommends \ + curl \ libuv1 \ zlib1g \ libjson-c5 \ @@ -25,6 +26,21 @@ RUN \ ARG PYTHON_MATTER_SERVER +ENV chip_example_url "https://github.com/agners/matter-linux-example-apps/releases/download/v1.3.0.0" +ARG TARGETPLATFORM + +RUN \ + set -x \ + && echo "${TARGETPLATFORM}" \ + && if [ "${TARGETPLATFORM}" = "linux/amd64" ]; then \ + curl -Lo /usr/local/bin/chip-ota-provider-app "${chip_example_url}/chip-ota-provider-app-x86-64"; \ + elif [ "${TARGETPLATFORM}" = "linux/arm64" ]; then \ + curl -Lo /usr/local/bin/chip-ota-provider-app "${chip_example_url}/chip-ota-provider-app-aarch64"; \ + else \ + exit 1; \ + fi \ + && chmod +x /usr/local/bin/chip-ota-provider-app + # hadolint ignore=DL3013 RUN \ pip3 install --no-cache-dir "python-matter-server[server]==${PYTHON_MATTER_SERVER}" From 1cf634b62577baf865ef8f9a1ca3ce936b7aad02 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 24 May 2024 11:22:52 +0200 Subject: [PATCH 06/39] Check if DCL software updates are indeed applicable Verify that the DCL software update is indeed applicable to the currently software running on the device. Add test coverage as well. --- matter_server/server/device_controller.py | 4 +- matter_server/server/ota/dcl.py | 56 +++++++++++------ tests/server/ota/test_dcl.py | 75 +++++++++++++++++++++++ 3 files changed, 114 insertions(+), 21 deletions(-) create mode 100644 tests/server/ota/test_dcl.py diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 871f8671..565ff2cc 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -32,7 +32,7 @@ from matter_server.common.models import CommissionableNodeData, CommissioningParameters from matter_server.server.helpers.attributes import parse_attributes_from_read_result from matter_server.server.helpers.utils import ping_ip -from matter_server.server.ota.dcl import check_updates +from matter_server.server.ota.dcl import check_for_update from matter_server.server.ota.provider import ExternalOtaProvider from matter_server.server.sdk import ChipDeviceControllerWrapper @@ -928,7 +928,7 @@ async def update_node(self, node_id: int) -> dict | None: BASIC_INFORMATION_SOFTWARE_VERSION_STRING_ATTRIBUTE_PATH ) - update = await check_updates(node_id, vid, pid, software_version) + update = await check_for_update(vid, pid, software_version) if not update: node_logger.info("No new update found.") return None diff --git a/matter_server/server/ota/dcl.py b/matter_server/server/ota/dcl.py index 3ce0eace..29411a8d 100644 --- a/matter_server/server/ota/dcl.py +++ b/matter_server/server/ota/dcl.py @@ -1,7 +1,7 @@ """Handle OTA software version endpoints of the DCL.""" import logging -from typing import Any +from typing import Any, cast from aiohttp import ClientError, ClientSession @@ -10,7 +10,7 @@ LOGGER = logging.getLogger(__name__) -async def get_software_versions(node_id: int, vid: int, pid: int) -> Any: +async def get_software_versions(vid: int, pid: int) -> Any: """Check DCL if there are updates available for a particular node.""" async with ClientSession(raise_for_status=True) as http_session: # fetch the paa certificates list @@ -20,9 +20,7 @@ async def get_software_versions(node_id: int, vid: int, pid: int) -> Any: return await response.json() -async def get_software_version( - node_id: int, vid: int, pid: int, software_version: int -) -> Any: +async def get_software_version(vid: int, pid: int, software_version: int) -> Any: """Check DCL if there are updates available for a particular node.""" async with ClientSession(raise_for_status=True) as http_session: # fetch the paa certificates list @@ -32,27 +30,47 @@ async def get_software_version( return await response.json() -async def check_updates( - node_id: int, vid: int, pid: int, current_software_version: int +async def check_for_update( + vid: int, pid: int, current_software_version: int ) -> None | dict: """Check if there is a newer software version available on the DCL.""" try: - versions = await get_software_versions(node_id, vid, pid) + versions = await get_software_versions(vid, pid) - software_versions: list[int] = versions["modelVersions"]["softwareVersions"] - latest_software_version = max(software_versions) - if latest_software_version <= current_software_version: + all_software_versions: list[int] = versions["modelVersions"]["softwareVersions"] + newer_software_versions = [ + version + for version in all_software_versions + if version > current_software_version + ] + + # Check if there is a newer software version available + if not newer_software_versions: + LOGGER.info("No newer software version available.") return None - version: dict = await get_software_version( - node_id, vid, pid, latest_software_version - ) - if isinstance(version, dict) and "modelVersion" in version: - result: Any = version["modelVersion"] - if isinstance(result, dict): - return result + # Check if latest firmware is applicable, and backtrack from there + for version in sorted(newer_software_versions, reverse=True): + version_res: dict = await get_software_version(vid, pid, version) + if not isinstance(version_res, dict): + raise TypeError("Unexpected DCL response.") + + if "modelVersion" not in version_res: + raise ValueError("Unexpected DCL response.") + + version_candidate: dict = cast(dict, version_res["modelVersion"]) + + # Check minApplicableSoftwareVersion/maxApplicableSoftwareVersion + min_sw_version = version_candidate["minApplicableSoftwareVersion"] + max_sw_version = version_candidate["maxApplicableSoftwareVersion"] + if ( + current_software_version < min_sw_version + or current_software_version > max_sw_version + ): + LOGGER.debug("Software version %d not applicable.", version) + continue - logging.error("Unexpected DCL response.") + return version_candidate return None except (ClientError, TimeoutError) as err: diff --git a/tests/server/ota/test_dcl.py b/tests/server/ota/test_dcl.py new file mode 100644 index 00000000..7de82a4e --- /dev/null +++ b/tests/server/ota/test_dcl.py @@ -0,0 +1,75 @@ +"""Test DCL OTA updates.""" + +from unittest.mock import AsyncMock, patch + +from matter_server.server.ota.dcl import check_for_update + +# Mock the DCL responses (sample from https://on.dcl.csa-iot.org/dcl/model/versions/4447/8194) +DCL_RESPONSE_SOFTWARE_VERSIONS = { + "modelVersions": { + "vid": 4447, + "pid": 8194, + "softwareVersions": [1000, 1011], + } +} + +# Mock the DCL responses (sample from https://on.dcl.csa-iot.org/dcl/model/versions/4447/8194/1011) +DCL_RESPONSE_SOFTWARE_VERSION_1011 = { + "modelVersion": { + "vid": 4447, + "pid": 8194, + "softwareVersion": 1011, + "softwareVersionString": "1.0.1.1", + "cdVersionNumber": 1, + "firmwareInformation": "", + "softwareVersionValid": True, + "otaUrl": "https://cdn.aqara.com/cdn/opencloud-product/mainland/product-firmware/prd/aqara.matter.4447_8194/20240306154144_rel_up_to_enc_ota_sbl_app_aqara.matter.4447_8194_1.0.1.1_115F_2002_20240115195007_7a9b91.ota", + "otaFileSize": "615708", + "otaChecksum": "rFZ6WdH0DuuCf7HVoRmNjCF73mYZ98DGYpHoDKmf0Bw=", + "otaChecksumType": 1, + "minApplicableSoftwareVersion": 1000, + "maxApplicableSoftwareVersion": 1010, + "releaseNotesUrl": "", + "creator": "cosmos1qpz3ghnqj6my7gzegkftzav9hpxymkx6zdk73v", + } +} + + +async def test_check_updates(): + """Test the case where the latest software version is applicable.""" + with ( + patch( + "matter_server.server.ota.dcl.get_software_versions", + new_callable=AsyncMock, + return_value=DCL_RESPONSE_SOFTWARE_VERSIONS, + ), + patch( + "matter_server.server.ota.dcl.get_software_version", + new_callable=AsyncMock, + return_value=DCL_RESPONSE_SOFTWARE_VERSION_1011, + ), + ): + # Call the function with a current software version of 1000 + result = await check_for_update(4447, 8194, 1000) + + assert result == DCL_RESPONSE_SOFTWARE_VERSION_1011["modelVersion"] + + +async def test_check_updates_not_applicable(): + """Test the case where the latest software version is not applicable.""" + with ( + patch( + "matter_server.server.ota.dcl.get_software_versions", + new_callable=AsyncMock, + return_value=DCL_RESPONSE_SOFTWARE_VERSIONS, + ), + patch( + "matter_server.server.ota.dcl.get_software_version", + new_callable=AsyncMock, + return_value=DCL_RESPONSE_SOFTWARE_VERSION_1011, + ), + ): + # Call the function with a current software version of 1 + result = await check_for_update(4447, 8194, 1) + + assert result is None From f698b518897796bb53072dff9dab6807cd6d0707 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 24 May 2024 13:32:04 +0200 Subject: [PATCH 07/39] Introduce hardcoded updates Add global check for update where we can insert hardcoded updates or updates from other sources in the future. --- matter_server/server/device_controller.py | 2 +- matter_server/server/ota/__init__.py | 28 +++++++++++++++++++++++ matter_server/server/ota/provider.py | 4 ++-- 3 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 matter_server/server/ota/__init__.py diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 565ff2cc..e38e9f69 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -32,7 +32,7 @@ from matter_server.common.models import CommissionableNodeData, CommissioningParameters from matter_server.server.helpers.attributes import parse_attributes_from_read_result from matter_server.server.helpers.utils import ping_ip -from matter_server.server.ota.dcl import check_for_update +from matter_server.server.ota import check_for_update from matter_server.server.ota.provider import ExternalOtaProvider from matter_server.server.sdk import ChipDeviceControllerWrapper diff --git a/matter_server/server/ota/__init__.py b/matter_server/server/ota/__init__.py new file mode 100644 index 00000000..d638afc5 --- /dev/null +++ b/matter_server/server/ota/__init__.py @@ -0,0 +1,28 @@ +"""OTA implementation for the Matter Server.""" + +from matter_server.server.ota import dcl + +HARDCODED_UPDATES: dict[tuple[int, int], dict] = { + # OTA requestor example app, useful for testing + (0xFFF1, 0x8001): { + "vid": 0xFFF1, + "pid": 0x8001, + "softwareVersion": 2, + "softwareVersionString": "2.0", + "cdVersionNumber": 1, + "softwareVersionValid": True, + "minApplicableSoftwareVersion": 1, + "maxApplicableSoftwareVersion": 1, + "otaUrl": "https://github.com/agners/matter-linux-example-apps/releases/download/v1.3.0.0/chip-ota-requestor-app-x86-64.ota", + } +} + + +async def check_for_update( + vid: int, pid: int, current_software_version: int +) -> None | dict: + """Check for software updates.""" + if (vid, pid) in HARDCODED_UPDATES: + return HARDCODED_UPDATES[(vid, pid)] + + return await dcl.check_for_update(vid, pid, current_software_version) diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 0e540aec..3e605f87 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -225,7 +225,7 @@ async def download_update(self, update_desc: dict) -> None: loop = asyncio.get_running_loop() await loop.run_in_executor( - None, functools.partial(DEFAULT_UPDATES_PATH.mkdir, exists_ok=True) + None, functools.partial(DEFAULT_UPDATES_PATH.mkdir, exist_ok=True) ) file_path = DEFAULT_UPDATES_PATH / file_name @@ -236,7 +236,7 @@ async def download_update(self, update_desc: dict) -> None: try: async with ClientSession(raise_for_status=True) as session: # fetch the paa certificates list - logging.debug("Download update from f{url}.") + LOGGER.debug("Download update from '%s'.", url) async with session.get(url) as response: with file_path.open("wb") as f: while True: From 93f38941c9d710e2a73c5ad695726e7a7b23defb Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 24 May 2024 14:56:23 +0200 Subject: [PATCH 08/39] Split update WebSocket command into two commands Make check_node_update a separate WebSocket command which only checks for updates. The update_node command then will download and actually apply the update. --- matter_server/common/models.py | 1 + matter_server/server/device_controller.py | 94 +++++++++++++++-------- matter_server/server/ota/__init__.py | 12 ++- matter_server/server/ota/dcl.py | 66 ++++++++++------ tests/server/ota/test_dcl.py | 82 +++++++++++--------- 5 files changed, 161 insertions(+), 94 deletions(-) diff --git a/matter_server/common/models.py b/matter_server/common/models.py index 2b295ace..ea550de1 100644 --- a/matter_server/common/models.py +++ b/matter_server/common/models.py @@ -47,6 +47,7 @@ class APICommand(str, Enum): PING_NODE = "ping_node" GET_NODE_IP_ADDRESSES = "get_node_ip_addresses" IMPORT_TEST_NODE = "import_test_node" + CHECK_NODE_UPDATE = "check_node_update" UPDATE_NODE = "update_node" diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index e38e9f69..2c82a37e 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -896,8 +896,8 @@ async def import_test_node(self, dump: str) -> None: self._nodes[node.node_id] = node self.server.signal_event(EventType.NODE_ADDED, node) - @api_command(APICommand.UPDATE_NODE) - async def update_node(self, node_id: int) -> dict | None: + @api_command(APICommand.CHECK_NODE_UPDATE) + async def check_node_update(self, node_id: int) -> dict | None: """ Check if there is an update for a particular node. @@ -906,8 +906,27 @@ async def update_node(self, node_id: int) -> dict | None: information of the latest update available. """ - node_logger = LOGGER.getChild(f"node_{node_id}") - node = self._nodes[node_id] + return await self._check_node_update(node_id) + + @api_command(APICommand.UPDATE_NODE) + async def update_node(self, node_id: int, software_version: int) -> dict | None: + """ + Update a node to a new software version. + + This command checks if the requested software version is indeed still available + and if so, it will start the update process. The update process will be handled + by the built-in OTA provider. The OTA provider will download the update and + notify the node about the new update. + """ + + update = await self._check_node_update(node_id, software_version) + if update is None: + logging.error( + "Software version %d is not available for node %d", + software_version, + node_id, + ) + return None if self.chip_controller is None: raise RuntimeError("Device Controller not initialized.") @@ -916,34 +935,7 @@ async def update_node(self, node_id: int) -> dict | None: LOGGER.warning("No OTA provider found, updates not possible.") return None - node_logger.debug("Check for updates.") - vid = cast(int, node.attributes.get(BASIC_INFORMATION_VENDOR_ID_ATTRIBUTE_PATH)) - pid = cast( - int, node.attributes.get(BASIC_INFORMATION_PRODUCT_ID_ATTRIBUTE_PATH) - ) - software_version = cast( - int, node.attributes.get(BASIC_INFORMATION_SOFTWARE_VERSION_ATTRIBUTE_PATH) - ) - software_version_string = node.attributes.get( - BASIC_INFORMATION_SOFTWARE_VERSION_STRING_ATTRIBUTE_PATH - ) - - update = await check_for_update(vid, pid, software_version) - if not update: - node_logger.info("No new update found.") - return None - - if "otaUrl" not in update: - node_logger.warning("Update found, but no OTA URL provided.") - return None - - node_logger.info( - "New software update found: %s (current %s). Preparing updates...", - update["softwareVersionString"], - software_version_string, - ) - - # Add to OTA provider + # Add update to the OTA provider await self._ota_provider.download_update(update) ota_provider_node_id = self._ota_provider.get_node_id() @@ -1042,6 +1034,44 @@ async def update_node(self, node_id: int) -> dict | None: return update + async def _check_node_update( + self, + node_id: int, + requested_software_version: int | None = None, + ) -> dict | None: + node_logger = LOGGER.getChild(f"node_{node_id}") + node = self._nodes[node_id] + + node_logger.debug("Check for updates.") + vid = cast(int, node.attributes.get(BASIC_INFORMATION_VENDOR_ID_ATTRIBUTE_PATH)) + pid = cast( + int, node.attributes.get(BASIC_INFORMATION_PRODUCT_ID_ATTRIBUTE_PATH) + ) + software_version = cast( + int, node.attributes.get(BASIC_INFORMATION_SOFTWARE_VERSION_ATTRIBUTE_PATH) + ) + software_version_string = node.attributes.get( + BASIC_INFORMATION_SOFTWARE_VERSION_STRING_ATTRIBUTE_PATH + ) + + update = await check_for_update( + vid, pid, software_version, requested_software_version + ) + if not update: + node_logger.info("No new update found.") + return None + + if "otaUrl" not in update: + node_logger.warning("Update found, but no OTA URL provided.") + return None + + node_logger.info( + "New software update found: %s (current %s).", + update["softwareVersionString"], + software_version_string, + ) + return update + async def _subscribe_node(self, node_id: int) -> None: """ Subscribe to all node state changes/events for an individual node. diff --git a/matter_server/server/ota/__init__.py b/matter_server/server/ota/__init__.py index d638afc5..04ab6be5 100644 --- a/matter_server/server/ota/__init__.py +++ b/matter_server/server/ota/__init__.py @@ -19,10 +19,18 @@ async def check_for_update( - vid: int, pid: int, current_software_version: int + vid: int, + pid: int, + current_software_version: int, + requested_software_version: int | None = None, ) -> None | dict: """Check for software updates.""" if (vid, pid) in HARDCODED_UPDATES: - return HARDCODED_UPDATES[(vid, pid)] + update = HARDCODED_UPDATES[(vid, pid)] + if ( + requested_software_version is None + or update["softwareVersion"] == requested_software_version + ): + return update return await dcl.check_for_update(vid, pid, current_software_version) diff --git a/matter_server/server/ota/dcl.py b/matter_server/server/ota/dcl.py index 29411a8d..5d52fcc9 100644 --- a/matter_server/server/ota/dcl.py +++ b/matter_server/server/ota/dcl.py @@ -10,7 +10,7 @@ LOGGER = logging.getLogger(__name__) -async def get_software_versions(vid: int, pid: int) -> Any: +async def _get_software_versions(vid: int, pid: int) -> Any: """Check DCL if there are updates available for a particular node.""" async with ClientSession(raise_for_status=True) as http_session: # fetch the paa certificates list @@ -20,7 +20,7 @@ async def get_software_versions(vid: int, pid: int) -> Any: return await response.json() -async def get_software_version(vid: int, pid: int, software_version: int) -> Any: +async def _get_software_version(vid: int, pid: int, software_version: int) -> Any: """Check DCL if there are updates available for a particular node.""" async with ClientSession(raise_for_status=True) as http_session: # fetch the paa certificates list @@ -30,12 +30,45 @@ async def get_software_version(vid: int, pid: int, software_version: int) -> Any return await response.json() +async def _check_update_version( + vid: int, pid: int, version: int, current_software_version: int +) -> None | dict: + version_res: dict = await _get_software_version(vid, pid, version) + if not isinstance(version_res, dict): + raise TypeError("Unexpected DCL response.") + + if "modelVersion" not in version_res: + raise ValueError("Unexpected DCL response.") + + version_candidate: dict = cast(dict, version_res["modelVersion"]) + + # Check minApplicableSoftwareVersion/maxApplicableSoftwareVersion + min_sw_version = version_candidate["minApplicableSoftwareVersion"] + max_sw_version = version_candidate["maxApplicableSoftwareVersion"] + if ( + current_software_version < min_sw_version + or current_software_version > max_sw_version + ): + return None + + return version_candidate + + async def check_for_update( - vid: int, pid: int, current_software_version: int + vid: int, + pid: int, + current_software_version: int, + requested_software_version: int | None = None, ) -> None | dict: - """Check if there is a newer software version available on the DCL.""" + """Check if there is a software update available on the DCL.""" try: - versions = await get_software_versions(vid, pid) + if requested_software_version is not None: + return await _check_update_version( + vid, pid, requested_software_version, current_software_version + ) + + # Get all versions and check each one of them. + versions = await _get_software_versions(vid, pid) all_software_versions: list[int] = versions["modelVersions"]["softwareVersions"] newer_software_versions = [ @@ -51,26 +84,11 @@ async def check_for_update( # Check if latest firmware is applicable, and backtrack from there for version in sorted(newer_software_versions, reverse=True): - version_res: dict = await get_software_version(vid, pid, version) - if not isinstance(version_res, dict): - raise TypeError("Unexpected DCL response.") - - if "modelVersion" not in version_res: - raise ValueError("Unexpected DCL response.") - - version_candidate: dict = cast(dict, version_res["modelVersion"]) - - # Check minApplicableSoftwareVersion/maxApplicableSoftwareVersion - min_sw_version = version_candidate["minApplicableSoftwareVersion"] - max_sw_version = version_candidate["maxApplicableSoftwareVersion"] - if ( - current_software_version < min_sw_version - or current_software_version > max_sw_version + if version_candidate := await _check_update_version( + vid, pid, version, current_software_version ): - LOGGER.debug("Software version %d not applicable.", version) - continue - - return version_candidate + return version_candidate + LOGGER.debug("Software version %d not applicable.", version) return None except (ClientError, TimeoutError) as err: diff --git a/tests/server/ota/test_dcl.py b/tests/server/ota/test_dcl.py index 7de82a4e..c52f189f 100644 --- a/tests/server/ota/test_dcl.py +++ b/tests/server/ota/test_dcl.py @@ -2,6 +2,8 @@ from unittest.mock import AsyncMock, patch +import pytest + from matter_server.server.ota.dcl import check_for_update # Mock the DCL responses (sample from https://on.dcl.csa-iot.org/dcl/model/versions/4447/8194) @@ -35,41 +37,49 @@ } -async def test_check_updates(): +@pytest.fixture(name="get_software_versions") +def mock_get_software_versions(): + """Mock the _get_software_versions function.""" + with patch( + "matter_server.server.ota.dcl._get_software_versions", + new_callable=AsyncMock, + return_value=DCL_RESPONSE_SOFTWARE_VERSIONS, + ) as mock: + yield mock + + +@pytest.fixture(name="get_software_version") +def mock_get_software_version(): + """Mock the _get_software_version function.""" + with patch( + "matter_server.server.ota.dcl._get_software_version", + new_callable=AsyncMock, + return_value=DCL_RESPONSE_SOFTWARE_VERSION_1011, + ) as mock: + yield mock + + +async def test_check_updates(get_software_versions, get_software_version): """Test the case where the latest software version is applicable.""" - with ( - patch( - "matter_server.server.ota.dcl.get_software_versions", - new_callable=AsyncMock, - return_value=DCL_RESPONSE_SOFTWARE_VERSIONS, - ), - patch( - "matter_server.server.ota.dcl.get_software_version", - new_callable=AsyncMock, - return_value=DCL_RESPONSE_SOFTWARE_VERSION_1011, - ), - ): - # Call the function with a current software version of 1000 - result = await check_for_update(4447, 8194, 1000) - - assert result == DCL_RESPONSE_SOFTWARE_VERSION_1011["modelVersion"] - - -async def test_check_updates_not_applicable(): + # Call the function with a current software version of 1000 + result = await check_for_update(4447, 8194, 1000) + + assert result == DCL_RESPONSE_SOFTWARE_VERSION_1011["modelVersion"] + + +async def test_check_updates_not_applicable( + get_software_versions, get_software_version +): """Test the case where the latest software version is not applicable.""" - with ( - patch( - "matter_server.server.ota.dcl.get_software_versions", - new_callable=AsyncMock, - return_value=DCL_RESPONSE_SOFTWARE_VERSIONS, - ), - patch( - "matter_server.server.ota.dcl.get_software_version", - new_callable=AsyncMock, - return_value=DCL_RESPONSE_SOFTWARE_VERSION_1011, - ), - ): - # Call the function with a current software version of 1 - result = await check_for_update(4447, 8194, 1) - - assert result is None + # Call the function with a current software version of 1 + result = await check_for_update(4447, 8194, 1) + + assert result is None + + +async def test_check_updates_specific_version(get_software_version): + """Test the case to get a specific version.""" + # Call the function with a current software version of 1000 and request 1011 as update + result = await check_for_update(4447, 8194, 1000, 1011) + + assert result == DCL_RESPONSE_SOFTWARE_VERSION_1011["modelVersion"] From 09a446964c7261eb93d01c2e8d6a897881095ee1 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 24 May 2024 15:13:40 +0200 Subject: [PATCH 09/39] Introduce Update logic specific exceptions Add Update specific exceptions and raise them where appropriate. --- matter_server/common/errors.py | 12 ++++++ matter_server/server/device_controller.py | 48 ++++++++++++----------- matter_server/server/ota/dcl.py | 4 +- matter_server/server/ota/provider.py | 2 + 4 files changed, 42 insertions(+), 24 deletions(-) diff --git a/matter_server/common/errors.py b/matter_server/common/errors.py index d07fc222..8ce6dd0e 100644 --- a/matter_server/common/errors.py +++ b/matter_server/common/errors.py @@ -77,6 +77,18 @@ class InvalidCommand(MatterError): error_code = 9 +class UpdateCheckError(MatterError): + """Error raised when there was an error during searching for updates.""" + + error_code = 10 + + +class UpdateError(MatterError): + """Error raised when there was an error during applying updates.""" + + error_code = 11 + + def exception_from_error_code(error_code: int) -> type[MatterError]: """Return correct Exception class from error_code.""" return ERROR_MAP.get(error_code, MatterError) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 2c82a37e..7a1525c5 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -43,6 +43,8 @@ NodeNotExists, NodeNotReady, NodeNotResolving, + UpdateCheckError, + UpdateError, ) from ..common.helpers.api import api_command from ..common.helpers.json import JSON_DECODE_EXCEPTIONS, json_loads @@ -921,19 +923,15 @@ async def update_node(self, node_id: int, software_version: int) -> dict | None: update = await self._check_node_update(node_id, software_version) if update is None: - logging.error( - "Software version %d is not available for node %d", - software_version, - node_id, + raise UpdateCheckError( + f"Software version {software_version} is not available for node {node_id}." ) - return None if self.chip_controller is None: raise RuntimeError("Device Controller not initialized.") if not self._ota_provider: - LOGGER.warning("No OTA provider found, updates not possible.") - return None + raise UpdateError("No OTA provider found, updates not possible.") # Add update to the OTA provider await self._ota_provider.download_update(update) @@ -1011,26 +1009,33 @@ async def update_node(self, node_id: int, software_version: int) -> dict | None: ) ) if write_result[0].Status != Status.Success: - logging.error("Failed writing adjusted OTA Provider App ACL.") + logging.error( + "Failed writing adjusted OTA Provider App ACL: Status %s.", + str(write_result[0].Status), + ) await self.remove_node(ota_provider_node_id) - return None + raise UpdateError("Error while setting up OTA Provider.") except ChipStackError as ex: logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex) await self.remove_node(ota_provider_node_id) - else: - self._ota_provider.set_node_id(ota_provider_node_id) + raise UpdateError("Error while setting up OTA Provider.") from ex + + self._ota_provider.set_node_id(ota_provider_node_id) # Notify node about the new update! - await self.chip_controller.SendCommand( - nodeid=node_id, - endpoint=0, - payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( - providerNodeID=ota_provider_node_id, - vendorID=0, # TODO: Use Server Vendor ID - announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable, + try: + await self.chip_controller.SendCommand( + nodeid=node_id, endpoint=0, - ), - ) + payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( + providerNodeID=ota_provider_node_id, + vendorID=self.server.vendor_id, + announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable, + endpoint=ExternalOtaProvider.ENDPOINT_ID, + ), + ) + except ChipStackError as ex: + raise UpdateError("Error while announcing OTA Provider to node.") from ex return update @@ -1062,8 +1067,7 @@ async def _check_node_update( return None if "otaUrl" not in update: - node_logger.warning("Update found, but no OTA URL provided.") - return None + raise UpdateCheckError("Update found, but no OTA URL provided.") node_logger.info( "New software update found: %s (current %s).", diff --git a/matter_server/server/ota/dcl.py b/matter_server/server/ota/dcl.py index 5d52fcc9..d8bdcaea 100644 --- a/matter_server/server/ota/dcl.py +++ b/matter_server/server/ota/dcl.py @@ -5,6 +5,7 @@ from aiohttp import ClientError, ClientSession +from matter_server.common.errors import UpdateCheckError from matter_server.server.helpers import DCL_PRODUCTION_URL LOGGER = logging.getLogger(__name__) @@ -92,5 +93,4 @@ async def check_for_update( return None except (ClientError, TimeoutError) as err: - LOGGER.error("Fetching software version failed: error %s", err, exc_info=err) - return None + raise UpdateCheckError("Fetching software version failed.") from err diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 3e605f87..b37b3deb 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -54,6 +54,8 @@ class ExternalOtaProvider: for devices. """ + ENDPOINT_ID: Final[int] = 0 + def __init__(self, ota_provider_dir: Path) -> None: """Initialize the OTA provider.""" self._ota_provider_dir: Path = ota_provider_dir From e1a5941e1d793de23273e52bf3357d70fdbe21a3 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 24 May 2024 15:49:54 +0200 Subject: [PATCH 10/39] Implement OTA checksum verification Add capability to verify the checksum of the OTA file while downloading it. --- matter_server/server/ota/__init__.py | 2 ++ matter_server/server/ota/provider.py | 53 +++++++++++++++++++++++++--- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/matter_server/server/ota/__init__.py b/matter_server/server/ota/__init__.py index 04ab6be5..5dafba54 100644 --- a/matter_server/server/ota/__init__.py +++ b/matter_server/server/ota/__init__.py @@ -11,6 +11,8 @@ "softwareVersionString": "2.0", "cdVersionNumber": 1, "softwareVersionValid": True, + "otaChecksum": "7qcyvg2kPmKZaDLIk8C7Vyteqf4DI73x0tFZkmPALCo=", + "otaChecksumType": 1, "minApplicableSoftwareVersion": 1, "maxApplicableSoftwareVersion": 1, "otaUrl": "https://github.com/agners/matter-linux-example-apps/releases/download/v1.3.0.0/chip-ota-requestor-app-x86-64.ota", diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index b37b3deb..edb5ccd9 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -1,8 +1,10 @@ """Handling Matter OTA provider.""" import asyncio +from base64 import b64encode from dataclasses import asdict, dataclass import functools +import hashlib import json import logging from pathlib import Path @@ -12,6 +14,7 @@ from aiohttp import ClientError, ClientSession +from matter_server.common.errors import UpdateError from matter_server.common.helpers.util import dataclass_from_dict if TYPE_CHECKING: @@ -21,6 +24,22 @@ DEFAULT_UPDATES_PATH: Final[Path] = Path("updates") +# From Matter SDK src/app/ota_image_tool.py +CHECHKSUM_TYPES: Final[dict[int, str]] = { + 1: "sha256", + 2: "sha256_128", + 3: "sha256_120", + 4: "sha256_96", + 5: "sha256_64", + 6: "sha256_32", + 7: "sha384", + 8: "sha512", + 9: "sha3_224", + 10: "sha3_256", + 11: "sha3_384", + 12: "sha3_512", +} + @dataclass class DeviceSoftwareVersionModel: # pylint: disable=C0103 @@ -231,11 +250,22 @@ async def download_update(self, update_desc: dict) -> None: ) file_path = DEFAULT_UPDATES_PATH / file_name - if await loop.run_in_executor(None, file_path.exists): - LOGGER.info("File '%s' exists already, skipping download.", file_name) - return try: + checksum_alg = None + if ( + "otaChecksum" in update_desc + and "otaChecksumType" in update_desc + and update_desc["otaChecksumType"] in CHECHKSUM_TYPES + ): + checksum_alg = hashlib.new( + CHECHKSUM_TYPES[update_desc["otaChecksumType"]] + ) + else: + LOGGER.warning( + "No OTA checksum type or not supported, OTA will not be checked." + ) + async with ClientSession(raise_for_status=True) as session: # fetch the paa certificates list LOGGER.debug("Download update from '%s'.", url) @@ -246,8 +276,21 @@ async def download_update(self, update_desc: dict) -> None: if not chunk: break await loop.run_in_executor(None, f.write, chunk) - - # TODO: Check against otaChecksum + if checksum_alg: + checksum_alg.update(chunk) + + # Download finished, check checksum if necessary + if checksum_alg: + checksum = b64encode(checksum_alg.digest()).decode("ascii") + if checksum != update_desc["otaChecksum"]: + LOGGER.error( + "Checksum mismatch for file '%s', expected: %s, got: %s", + file_name, + update_desc["otaChecksum"], + checksum, + ) + await loop.run_in_executor(None, file_path.unlink) + raise UpdateError("Checksum mismatch!") LOGGER.info( "File '%s' downloaded to '%s'", file_name, DEFAULT_UPDATES_PATH From 116077de355d8f4feba0b6d09fde7e36c795414c Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Mon, 27 May 2024 15:30:32 +0200 Subject: [PATCH 11/39] Add client commands for updates Add two new client commands to check for updates and trigger the update. --- matter_server/client/client.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/matter_server/client/client.py b/matter_server/client/client.py index 3777d15c..f9ff1477 100644 --- a/matter_server/client/client.py +++ b/matter_server/client/client.py @@ -509,6 +509,26 @@ async def interview_node(self, node_id: int) -> None: """Interview a node.""" await self.send_command(APICommand.INTERVIEW_NODE, node_id=node_id) + async def check_node_update(self, node_id: int) -> dict[str, Any]: + """Check Node for updates. + + Return a dict with the available update information. Most notable + "softwareVersion" contains the integer value of the update version which then + can be used for the update_node command to trigger the update. + + The "softwareVersionString" is a human friendly version string. + """ + node_update = await self.send_command( + APICommand.CHECK_NODE_UPDATE, node_id=node_id + ) + return cast(dict[str, Any], node_update) + + async def update_node(self, node_id: int, software_version: int) -> None: + """Start node update to a particular version.""" + await self.send_command( + APICommand.UPDATE_NODE, node_id=node_id, software_version=software_version + ) + def _prepare_message( self, command: str, From 5b41888e8e652bf16f781c6ec80b5ceffd1db93f Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Mon, 27 May 2024 23:07:57 +0200 Subject: [PATCH 12/39] Improve DCL error message when download fails --- matter_server/server/ota/dcl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/matter_server/server/ota/dcl.py b/matter_server/server/ota/dcl.py index d8bdcaea..881b0fa0 100644 --- a/matter_server/server/ota/dcl.py +++ b/matter_server/server/ota/dcl.py @@ -93,4 +93,6 @@ async def check_for_update( return None except (ClientError, TimeoutError) as err: - raise UpdateCheckError("Fetching software version failed.") from err + raise UpdateCheckError( + f"Fetching software versions from DCL for device with vendor id {vid} product id {pid} failed." + ) from err From 4b0911cb4e56de3a4683513acbd57fe7af44fe61 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Tue, 28 May 2024 14:11:38 +0200 Subject: [PATCH 13/39] Improve OTA Provider handling Create log files for each OTA Provider run. Improve setup and commissioning of the OTA Provider. --- matter_server/server/device_controller.py | 154 ++++++++++++---------- matter_server/server/ota/__init__.py | 3 +- matter_server/server/ota/provider.py | 28 ++-- 3 files changed, 105 insertions(+), 80 deletions(-) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 7a1525c5..927d3121 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -910,6 +910,83 @@ async def check_node_update(self, node_id: int) -> dict | None: return await self._check_node_update(node_id) + async def _initialize_ota_provider(self, ota_provider: ExternalOtaProvider) -> None: + """Commissions the OTA Provider.""" + + if self.chip_controller is None: + raise RuntimeError("Device Controller not initialized.") + + # The OTA Provider has not been commissioned yet, let's do it now. + LOGGER.info("Commissioning the built-in OTA Provider App.") + try: + ota_provider_node = await self.commission_on_network( + ota_provider.get_passcode(), + # TODO: Filtering by long discriminator seems broken + # filter_type=FilterType.LONG_DISCRIMINATOR, + # filter=ota_provider.get_descriminator(), + ) + ota_provider_node_id = ota_provider_node.node_id + except NodeCommissionFailed: + LOGGER.error("Failed to commission OTA Provider App!") + return + + LOGGER.info( + "OTA Provider App commissioned with node id %d.", + ota_provider_node_id, + ) + + # Adjust ACL of OTA Requestor such that Node peer-to-peer communication + # is allowed. + try: + read_result = await self.chip_controller.ReadAttribute( + ota_provider_node_id, [(0, Clusters.AccessControl.Attributes.Acl)] + ) + acl_list = cast( + list, + read_result[0][Clusters.AccessControl][ + Clusters.AccessControl.Attributes.Acl + ], + ) + + # Add new ACL entry... + acl_list.append( + Clusters.AccessControl.Structs.AccessControlEntryStruct( + fabricIndex=1, + privilege=Clusters.AccessControl.Enums.AccessControlEntryPrivilegeEnum.kOperate, + authMode=Clusters.AccessControl.Enums.AccessControlEntryAuthModeEnum.kCase, + subjects=Types.NullValue, + targets=[ + Clusters.AccessControl.Structs.AccessControlTargetStruct( + cluster=Clusters.OtaSoftwareUpdateProvider.id, + endpoint=0, + deviceType=Types.NullValue, + ) + ], + ) + ) + + # And write. This is persistent, so only need to be done after we commissioned + # the OTA Provider App. + write_result: Attribute.AttributeWriteResult = ( + await self.chip_controller.WriteAttribute( + ota_provider_node_id, + [(0, Clusters.AccessControl.Attributes.Acl(acl_list))], + ) + ) + if write_result[0].Status != Status.Success: + logging.error( + "Failed writing adjusted OTA Provider App ACL: Status %s.", + str(write_result[0].Status), + ) + await self.remove_node(ota_provider_node_id) + raise UpdateError("Error while setting up OTA Provider.") + except ChipStackError as ex: + logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex) + await self.remove_node(ota_provider_node_id) + raise UpdateError("Error while setting up OTA Provider.") from ex + + ota_provider.set_node_id(ota_provider_node_id) + @api_command(APICommand.UPDATE_NODE) async def update_node(self, node_id: int, software_version: int) -> dict | None: """ @@ -937,7 +1014,9 @@ async def update_node(self, node_id: int, software_version: int) -> dict | None: await self._ota_provider.download_update(update) ota_provider_node_id = self._ota_provider.get_node_id() - if ota_provider_node_id not in self._nodes: + if ota_provider_node_id is None: + LOGGER.info("Initializing OTA Provider") + elif ota_provider_node_id not in self._nodes: LOGGER.warning( "OTA Provider node id %d no longer exists! Resetting...", ota_provider_node_id, @@ -947,82 +1026,17 @@ async def update_node(self, node_id: int, software_version: int) -> dict | None: # Make sure any previous instances get stopped await self._ota_provider.stop() - self._ota_provider.start() + await self._ota_provider.start() # Wait for OTA provider to be ready # TODO: Detect when OTA provider is ready await asyncio.sleep(2) if not ota_provider_node_id: - # The OTA Provider has not been commissioned yet, let's do it now. - LOGGER.info("Commissioning the built-in OTA Provider App.") - try: - ota_provider_node = await self.commission_on_network( - self._ota_provider.get_passcode(), - # TODO: Filtering by long discriminator seems broken - # filter_type=FilterType.LONG_DISCRIMINATOR, - # filter=self._ota_provider.get_descriminator(), - ) - ota_provider_node_id = ota_provider_node.node_id - except NodeCommissionFailed: - LOGGER.error("Failed to commission OTA Provider App!") - return None - LOGGER.info( - "OTA Provider App commissioned with node id %d.", - ota_provider_node_id, - ) - - # Adjust ACL of OTA Requestor such that Node peer-to-peer communication - # is allowed. - try: - read_result = await self.chip_controller.ReadAttribute( - ota_provider_node_id, [(0, Clusters.AccessControl.Attributes.Acl)] - ) - acl_list = cast( - list, - read_result[0][Clusters.AccessControl][ - Clusters.AccessControl.Attributes.Acl - ], - ) - - # Add new ACL entry... - acl_list.append( - Clusters.AccessControl.Structs.AccessControlEntryStruct( - fabricIndex=1, - privilege=3, - authMode=2, - subjects=Types.NullValue, - targets=[ - Clusters.AccessControl.Structs.AccessControlTargetStruct( - cluster=41, endpoint=0, deviceType=Types.NullValue - ) - ], - ) - ) - - # And write. This is persistent, so only need to be done after we commissioned - # the OTA Provider App. - write_result: Attribute.AttributeWriteResult = ( - await self.chip_controller.WriteAttribute( - ota_provider_node_id, - [(0, Clusters.AccessControl.Attributes.Acl(acl_list))], - ) - ) - if write_result[0].Status != Status.Success: - logging.error( - "Failed writing adjusted OTA Provider App ACL: Status %s.", - str(write_result[0].Status), - ) - await self.remove_node(ota_provider_node_id) - raise UpdateError("Error while setting up OTA Provider.") - except ChipStackError as ex: - logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex) - await self.remove_node(ota_provider_node_id) - raise UpdateError("Error while setting up OTA Provider.") from ex - - self._ota_provider.set_node_id(ota_provider_node_id) + await self._initialize_ota_provider(self._ota_provider) - # Notify node about the new update! + # Notify update node about the availability of the OTA Provider. It will query + # the OTA provider and start the update. try: await self.chip_controller.SendCommand( nodeid=node_id, diff --git a/matter_server/server/ota/__init__.py b/matter_server/server/ota/__init__.py index 5dafba54..7060adc4 100644 --- a/matter_server/server/ota/__init__.py +++ b/matter_server/server/ota/__init__.py @@ -16,7 +16,8 @@ "minApplicableSoftwareVersion": 1, "maxApplicableSoftwareVersion": 1, "otaUrl": "https://github.com/agners/matter-linux-example-apps/releases/download/v1.3.0.0/chip-ota-requestor-app-x86-64.ota", - } + "releaseNotesUrl": "https://github.com/agners/matter-linux-example-apps/releases/tag/v1.3.0.0", + }, } diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index edb5ccd9..2df6b801 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -3,6 +3,7 @@ import asyncio from base64 import b64encode from dataclasses import asdict, dataclass +from datetime import UTC, datetime import functools import hashlib import json @@ -13,6 +14,7 @@ from urllib.parse import unquote, urlparse from aiohttp import ClientError, ClientSession +from aiohttp.client_exceptions import InvalidURL from matter_server.common.errors import UpdateError from matter_server.common.helpers.util import dataclass_from_dict @@ -143,7 +145,9 @@ def set_node_id(self, node_id: int) -> None: self._get_ota_provider_image_list().otaProviderNodeId = node_id - async def _start_ota_provider(self) -> None: + async def start(self) -> None: + """Start the OTA Provider.""" + def _write_ota_provider_image_list_json( ota_provider_image_list_file: Path, ota_provider_image_list: OtaProviderImageList, @@ -174,16 +178,19 @@ def _write_ota_provider_image_list_json( str(self._ota_provider_image_list_file), ] + timestamp = datetime.now(tz=UTC).strftime("%Y%m%d_%H%M%S") + log_file_path = self._ota_provider_dir / f"ota_provider_{timestamp}.log" + + log_file = await loop.run_in_executor(None, log_file_path.open, "w") + LOGGER.info("Starting OTA Provider") self._ota_provider_proc = await asyncio.create_subprocess_exec( - *ota_provider_cmd + *ota_provider_cmd, stdout=log_file, stderr=log_file ) - def start(self) -> None: - """Start the OTA Provider.""" - - loop = asyncio.get_event_loop() - self._ota_provider_task = loop.create_task(self._start_ota_provider()) + self._ota_provider_task = loop.create_task( + self._ota_provider_proc.communicate() + ) async def reset(self) -> None: """Reset the OTA Provider App state.""" @@ -293,12 +300,15 @@ async def download_update(self, update_desc: dict) -> None: raise UpdateError("Checksum mismatch!") LOGGER.info( - "File '%s' downloaded to '%s'", file_name, DEFAULT_UPDATES_PATH + "Update file '%s' downloaded to '%s'", + file_name, + DEFAULT_UPDATES_PATH, ) - except (ClientError, TimeoutError) as err: + except (InvalidURL, ClientError, TimeoutError) as err: LOGGER.error( "Fetching software version failed: error %s", err, exc_info=err ) + raise UpdateError("Fetching software version failed") from err await self.add_update(update_desc, file_path) From 70e9b60204cdcd6032311dfa06f152894e04a695 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Tue, 28 May 2024 16:29:17 +0200 Subject: [PATCH 14/39] Move almost all update logic into ExternalOtaProvider Most update logic is related to the external OTA provider (like commissioning and configuring it). This commit moves most of the update logic into the ExternalOtaProvider class. --- matter_server/server/device_controller.py | 143 ++++------------- matter_server/server/ota/__init__.py | 13 ++ matter_server/server/ota/provider.py | 185 +++++++++++++++++++++- 3 files changed, 217 insertions(+), 124 deletions(-) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 927d3121..914c8667 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -18,12 +18,11 @@ from typing import TYPE_CHECKING, Any, cast from chip.ChipDeviceCtrl import ChipDeviceController -from chip.clusters import Attribute, Objects as Clusters, Types +from chip.clusters import Attribute, Objects as Clusters from chip.clusters.Attribute import ValueDecodeFailure from chip.clusters.ClusterObjects import ALL_ATTRIBUTES, ALL_CLUSTERS, Cluster from chip.discovery import DiscoveryType from chip.exceptions import ChipStackError -from chip.interaction_model import Status from zeroconf import BadTypeInNameException, IPVersion, ServiceStateChange, Zeroconf from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf @@ -112,6 +111,12 @@ 0, Clusters.BasicInformation.Attributes.SoftwareVersionString ) ) +OTA_SOFTWARE_UPDATE_REQUESTOR_UPDATE_STATE_ATTRIBUTE_PATH = ( + create_attribute_path_from_attribute( + 0, Clusters.OtaSoftwareUpdateRequestor.Attributes.UpdateState + ) +) + # pylint: disable=too-many-lines,too-many-instance-attributes,too-many-public-methods @@ -910,83 +915,6 @@ async def check_node_update(self, node_id: int) -> dict | None: return await self._check_node_update(node_id) - async def _initialize_ota_provider(self, ota_provider: ExternalOtaProvider) -> None: - """Commissions the OTA Provider.""" - - if self.chip_controller is None: - raise RuntimeError("Device Controller not initialized.") - - # The OTA Provider has not been commissioned yet, let's do it now. - LOGGER.info("Commissioning the built-in OTA Provider App.") - try: - ota_provider_node = await self.commission_on_network( - ota_provider.get_passcode(), - # TODO: Filtering by long discriminator seems broken - # filter_type=FilterType.LONG_DISCRIMINATOR, - # filter=ota_provider.get_descriminator(), - ) - ota_provider_node_id = ota_provider_node.node_id - except NodeCommissionFailed: - LOGGER.error("Failed to commission OTA Provider App!") - return - - LOGGER.info( - "OTA Provider App commissioned with node id %d.", - ota_provider_node_id, - ) - - # Adjust ACL of OTA Requestor such that Node peer-to-peer communication - # is allowed. - try: - read_result = await self.chip_controller.ReadAttribute( - ota_provider_node_id, [(0, Clusters.AccessControl.Attributes.Acl)] - ) - acl_list = cast( - list, - read_result[0][Clusters.AccessControl][ - Clusters.AccessControl.Attributes.Acl - ], - ) - - # Add new ACL entry... - acl_list.append( - Clusters.AccessControl.Structs.AccessControlEntryStruct( - fabricIndex=1, - privilege=Clusters.AccessControl.Enums.AccessControlEntryPrivilegeEnum.kOperate, - authMode=Clusters.AccessControl.Enums.AccessControlEntryAuthModeEnum.kCase, - subjects=Types.NullValue, - targets=[ - Clusters.AccessControl.Structs.AccessControlTargetStruct( - cluster=Clusters.OtaSoftwareUpdateProvider.id, - endpoint=0, - deviceType=Types.NullValue, - ) - ], - ) - ) - - # And write. This is persistent, so only need to be done after we commissioned - # the OTA Provider App. - write_result: Attribute.AttributeWriteResult = ( - await self.chip_controller.WriteAttribute( - ota_provider_node_id, - [(0, Clusters.AccessControl.Attributes.Acl(acl_list))], - ) - ) - if write_result[0].Status != Status.Success: - logging.error( - "Failed writing adjusted OTA Provider App ACL: Status %s.", - str(write_result[0].Status), - ) - await self.remove_node(ota_provider_node_id) - raise UpdateError("Error while setting up OTA Provider.") - except ChipStackError as ex: - logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex) - await self.remove_node(ota_provider_node_id) - raise UpdateError("Error while setting up OTA Provider.") from ex - - ota_provider.set_node_id(ota_provider_node_id) - @api_command(APICommand.UPDATE_NODE) async def update_node(self, node_id: int, software_version: int) -> dict | None: """ @@ -1008,48 +936,21 @@ async def update_node(self, node_id: int, software_version: int) -> dict | None: raise RuntimeError("Device Controller not initialized.") if not self._ota_provider: - raise UpdateError("No OTA provider found, updates not possible.") + raise UpdateError("No OTA provider found, updates not possible") + + if self._ota_provider.is_busy(): + raise UpdateError( + "No OTA provider currently busy, updates currently not possible" + ) # Add update to the OTA provider await self._ota_provider.download_update(update) - ota_provider_node_id = self._ota_provider.get_node_id() - if ota_provider_node_id is None: - LOGGER.info("Initializing OTA Provider") - elif ota_provider_node_id not in self._nodes: - LOGGER.warning( - "OTA Provider node id %d no longer exists! Resetting...", - ota_provider_node_id, - ) - await self._ota_provider.reset() - ota_provider_node_id = None - # Make sure any previous instances get stopped - await self._ota_provider.stop() - await self._ota_provider.start() - - # Wait for OTA provider to be ready - # TODO: Detect when OTA provider is ready - await asyncio.sleep(2) - - if not ota_provider_node_id: - await self._initialize_ota_provider(self._ota_provider) - - # Notify update node about the availability of the OTA Provider. It will query - # the OTA provider and start the update. - try: - await self.chip_controller.SendCommand( - nodeid=node_id, - endpoint=0, - payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( - providerNodeID=ota_provider_node_id, - vendorID=self.server.vendor_id, - announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable, - endpoint=ExternalOtaProvider.ENDPOINT_ID, - ), - ) - except ChipStackError as ex: - raise UpdateError("Error while announcing OTA Provider to node.") from ex + await self._ota_provider.start_update( + self, + node_id, + ) return update @@ -1142,6 +1043,16 @@ def attribute_updated_callback( # schedule a full interview of the node if the software version changed self._loop.create_task(self.interview_node(node_id)) + # work out if update state changed + if ( + str(path) == OTA_SOFTWARE_UPDATE_REQUESTOR_UPDATE_STATE_ATTRIBUTE_PATH + and new_value != old_value + ): + if self._ota_provider: + loop.create_task( + self._ota_provider.check_update_state(node_id, new_value) + ) + # store updated value in node attributes node.attributes[str(path)] = new_value diff --git a/matter_server/server/ota/__init__.py b/matter_server/server/ota/__init__.py index 7060adc4..db15e602 100644 --- a/matter_server/server/ota/__init__.py +++ b/matter_server/server/ota/__init__.py @@ -18,6 +18,19 @@ "otaUrl": "https://github.com/agners/matter-linux-example-apps/releases/download/v1.3.0.0/chip-ota-requestor-app-x86-64.ota", "releaseNotesUrl": "https://github.com/agners/matter-linux-example-apps/releases/tag/v1.3.0.0", }, + (0x143D, 0x1001): { + "vid": 0x143D, + "pid": 0x1001, + "softwareVersion": 10010011, + "softwareVersionString": "1.1.11-c85ba1e-dirty", + "cdVersionNumber": 1, + "softwareVersionValid": True, + "otaChecksum": "x2sK9xjVuGff0eefYa4cporDO+Z+WVxxw+JP5Ol+5og=", + "otaChecksumType": 1, + "minApplicableSoftwareVersion": 10010000, + "maxApplicableSoftwareVersion": 10010011, + "otaUrl": "https://raw.githubusercontent.com/ChampOnBon/Onvis/master/S4/debug.ota", + }, } diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 2df6b801..87862dc4 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -1,5 +1,7 @@ """Handling Matter OTA provider.""" +from __future__ import annotations + import asyncio from base64 import b64encode from dataclasses import asdict, dataclass @@ -10,15 +12,21 @@ import logging from pathlib import Path import secrets -from typing import TYPE_CHECKING, Final +from typing import TYPE_CHECKING, Final, cast from urllib.parse import unquote, urlparse from aiohttp import ClientError, ClientSession from aiohttp.client_exceptions import InvalidURL +from chip.clusters import Attribute, Objects as Clusters, Types +from chip.exceptions import ChipStackError +from chip.interaction_model import Status -from matter_server.common.errors import UpdateError +from matter_server.common.errors import NodeCommissionFailed, NodeNotExists, UpdateError from matter_server.common.helpers.util import dataclass_from_dict +if TYPE_CHECKING: + from matter_server.server.device_controller import MatterDeviceController + if TYPE_CHECKING: from asyncio.subprocess import Process @@ -84,6 +92,7 @@ def __init__(self, ota_provider_dir: Path) -> None: self._ota_provider_image_list: OtaProviderImageList | None = None self._ota_provider_proc: Process | None = None self._ota_provider_task: asyncio.Task | None = None + self._ota_target_node_id: int | None = None async def initialize(self) -> None: """Initialize OTA Provider.""" @@ -125,10 +134,9 @@ def _get_ota_provider_image_list(self) -> OtaProviderImageList: raise RuntimeError("OTA provider image list not initialized.") return self._ota_provider_image_list - def get_node_id(self) -> int | None: - """Get Node ID of the OTA Provider App.""" - - return self._get_ota_provider_image_list().otaProviderNodeId + def is_busy(self) -> bool: + """If OTA Provider is currently busy delivering updates.""" + return self._ota_target_node_id is not None def get_descriminator(self) -> int: """Return OTA Provider App discriminator.""" @@ -145,9 +153,98 @@ def set_node_id(self, node_id: int) -> None: self._get_ota_provider_image_list().otaProviderNodeId = node_id - async def start(self) -> None: + def get_node_id(self) -> int | None: + """Get Node ID of the OTA Provider App.""" + + return self._get_ota_provider_image_list().otaProviderNodeId + + async def _initialize(self, device_controller: MatterDeviceController) -> None: + """Commissions the OTA Provider.""" + + if device_controller.chip_controller is None: + raise RuntimeError("Device Controller not initialized.") + + # The OTA Provider has not been commissioned yet, let's do it now. + LOGGER.info("Commissioning the built-in OTA Provider App.") + try: + ota_provider_node = await device_controller.commission_on_network( + self.get_passcode(), + # TODO: Filtering by long discriminator seems broken + # filter_type=FilterType.LONG_DISCRIMINATOR, + # filter=ota_provider.get_descriminator(), + ) + ota_provider_node_id = ota_provider_node.node_id + except NodeCommissionFailed: + LOGGER.error("Failed to commission OTA Provider App!") + return + + LOGGER.info( + "OTA Provider App commissioned with node id %d.", + ota_provider_node_id, + ) + + # Adjust ACL of OTA Requestor such that Node peer-to-peer communication + # is allowed. + try: + read_result = await device_controller.chip_controller.ReadAttribute( + ota_provider_node_id, [(0, Clusters.AccessControl.Attributes.Acl)] + ) + acl_list = cast( + list, + read_result[0][Clusters.AccessControl][ + Clusters.AccessControl.Attributes.Acl + ], + ) + + # Add new ACL entry... + acl_list.append( + Clusters.AccessControl.Structs.AccessControlEntryStruct( + fabricIndex=1, + privilege=Clusters.AccessControl.Enums.AccessControlEntryPrivilegeEnum.kOperate, + authMode=Clusters.AccessControl.Enums.AccessControlEntryAuthModeEnum.kCase, + subjects=Types.NullValue, + targets=[ + Clusters.AccessControl.Structs.AccessControlTargetStruct( + cluster=Clusters.OtaSoftwareUpdateProvider.id, + endpoint=0, + deviceType=Types.NullValue, + ) + ], + ) + ) + + # And write. This is persistent, so only need to be done after we commissioned + # the OTA Provider App. + write_result: Attribute.AttributeWriteResult = ( + await device_controller.chip_controller.WriteAttribute( + ota_provider_node_id, + [(0, Clusters.AccessControl.Attributes.Acl(acl_list))], + ) + ) + if write_result[0].Status != Status.Success: + logging.error( + "Failed writing adjusted OTA Provider App ACL: Status %s.", + str(write_result[0].Status), + ) + await device_controller.remove_node(ota_provider_node_id) + raise UpdateError("Error while setting up OTA Provider.") + except ChipStackError as ex: + logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex) + await device_controller.remove_node(ota_provider_node_id) + raise UpdateError("Error while setting up OTA Provider.") from ex + + self.set_node_id(ota_provider_node_id) + + async def start_update( + self, device_controller: MatterDeviceController, node_id: int + ) -> None: """Start the OTA Provider.""" + if device_controller.chip_controller is None: + raise RuntimeError("Device Controller not initialized.") + + self._ota_target_node_id = node_id + def _write_ota_provider_image_list_json( ota_provider_image_list_file: Path, ota_provider_image_list: OtaProviderImageList, @@ -192,7 +289,51 @@ def _write_ota_provider_image_list_json( self._ota_provider_proc.communicate() ) - async def reset(self) -> None: + # Wait for OTA provider to be ready + # TODO: Detect when OTA provider is ready + await asyncio.sleep(2) + + # Handle if user deleted the OTA Provider node. + ota_provider_node_id = self.get_node_id() + if ota_provider_node_id is not None: + try: + device_controller.get_node(ota_provider_node_id) + except NodeNotExists: + LOGGER.warning( + "OTA Provider node id %d not known by device controller! Resetting...", + ota_provider_node_id, + ) + await self._reset() + ota_provider_node_id = None + + # Commission and prepare OTA Provider if not initialized yet. + # Use "ota_provider_node_id" to indicate if OTA Provider is setup or not. + try: + if ota_provider_node_id is None: + LOGGER.info("Initializing OTA Provider") + await self._initialize(device_controller) + finally: + self._ota_target_node_id = None + + # Notify update node about the availability of the OTA Provider. It will query + # the OTA provider and start the update. + try: + await device_controller.chip_controller.SendCommand( + nodeid=node_id, + endpoint=0, + payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( + providerNodeID=ota_provider_node_id, + vendorID=device_controller.server.vendor_id, + announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable, + endpoint=ExternalOtaProvider.ENDPOINT_ID, + ), + ) + except ChipStackError as ex: + raise UpdateError("Error while announcing OTA Provider to node.") from ex + finally: + self._ota_target_node_id = None + + async def _reset(self) -> None: """Reset the OTA Provider App state.""" def _remove_update_data(ota_provider_dir: Path) -> None: @@ -312,3 +453,31 @@ async def download_update(self, update_desc: dict) -> None: raise UpdateError("Fetching software version failed") from err await self.add_update(update_desc, file_path) + + async def check_update_state( + self, + node_id: int, + update_state: Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum, + ) -> None: + """ + Check the update state of a node and take appropriate action. + + Args: + node_id: The ID of the node. + update_state: The update state of the node. + """ + + if self._ota_target_node_id is None: + return + + if self._ota_target_node_id != node_id: + return + + # Update state of target node changed, check if update is done. + if ( + update_state + == Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum.kIdle + ): + LOGGER.info("Update of node %d done.", node_id) + await self.stop() + self._ota_target_node_id = None From c67c8506d67b6facc88a93b9385afd42e4fefcab Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Wed, 29 May 2024 15:59:42 +0200 Subject: [PATCH 15/39] Update implementation to work with latest refactoring Make use of the new ChipDeviceControllerWrapper to communicate with the device directly. This avoids unnecessary node interviewing and showing up on the controller side. --- matter_server/server/device_controller.py | 10 +-- matter_server/server/ota/provider.py | 99 +++++++++++------------ 2 files changed, 49 insertions(+), 60 deletions(-) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 914c8667..ecd8fe16 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -159,7 +159,7 @@ def __init__( self._polled_attributes: dict[int, set[str]] = {} self._custom_attribute_poller_timer: asyncio.TimerHandle | None = None self._custom_attribute_poller_task: asyncio.Task | None = None - self._ota_provider = ExternalOtaProvider(ota_provider_dir) + self._ota_provider = ExternalOtaProvider(server.vendor_id, ota_provider_dir) async def initialize(self) -> None: """Initialize the device controller.""" @@ -932,12 +932,6 @@ async def update_node(self, node_id: int, software_version: int) -> dict | None: f"Software version {software_version} is not available for node {node_id}." ) - if self.chip_controller is None: - raise RuntimeError("Device Controller not initialized.") - - if not self._ota_provider: - raise UpdateError("No OTA provider found, updates not possible") - if self._ota_provider.is_busy(): raise UpdateError( "No OTA provider currently busy, updates currently not possible" @@ -948,7 +942,7 @@ async def update_node(self, node_id: int, software_version: int) -> dict | None: # Make sure any previous instances get stopped await self._ota_provider.start_update( - self, + self._chip_device_controller, node_id, ) diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 87862dc4..16b50fdf 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -18,22 +18,26 @@ from aiohttp import ClientError, ClientSession from aiohttp.client_exceptions import InvalidURL from chip.clusters import Attribute, Objects as Clusters, Types +from chip.discovery import FilterType from chip.exceptions import ChipStackError from chip.interaction_model import Status -from matter_server.common.errors import NodeCommissionFailed, NodeNotExists, UpdateError +from matter_server.common.errors import UpdateError from matter_server.common.helpers.util import dataclass_from_dict -if TYPE_CHECKING: - from matter_server.server.device_controller import MatterDeviceController - if TYPE_CHECKING: from asyncio.subprocess import Process + from chip.native import PyChipError + + from matter_server.server.sdk import ChipDeviceControllerWrapper + LOGGER = logging.getLogger(__name__) DEFAULT_UPDATES_PATH: Final[Path] = Path("updates") +DEFAULT_OTA_PROVIDER_NODE_ID = 999900 + # From Matter SDK src/app/ota_image_tool.py CHECHKSUM_TYPES: Final[dict[int, str]] = { 1: "sha256", @@ -85,8 +89,9 @@ class ExternalOtaProvider: ENDPOINT_ID: Final[int] = 0 - def __init__(self, ota_provider_dir: Path) -> None: + def __init__(self, vendor_id: int, ota_provider_dir: Path) -> None: """Initialize the OTA provider.""" + self._vendor_id: int = vendor_id self._ota_provider_dir: Path = ota_provider_dir self._ota_provider_image_list_file: Path = ota_provider_dir / "updates.json" self._ota_provider_image_list: OtaProviderImageList | None = None @@ -158,40 +163,45 @@ def get_node_id(self) -> int | None: return self._get_ota_provider_image_list().otaProviderNodeId - async def _initialize(self, device_controller: MatterDeviceController) -> None: + async def _initialize( + self, chip_device_controller: ChipDeviceControllerWrapper + ) -> None: """Commissions the OTA Provider.""" - if device_controller.chip_controller is None: - raise RuntimeError("Device Controller not initialized.") - # The OTA Provider has not been commissioned yet, let's do it now. LOGGER.info("Commissioning the built-in OTA Provider App.") - try: - ota_provider_node = await device_controller.commission_on_network( - self.get_passcode(), - # TODO: Filtering by long discriminator seems broken - # filter_type=FilterType.LONG_DISCRIMINATOR, - # filter=ota_provider.get_descriminator(), + + res: PyChipError = await chip_device_controller.commission_on_network( + DEFAULT_OTA_PROVIDER_NODE_ID, + self.get_passcode(), + # TODO: Filtering by long discriminator seems broken + disc_filter_type=FilterType.LONG_DISCRIMINATOR, + disc_filter=self.get_descriminator(), + ) + if not res.is_success: + await self.stop() + raise UpdateError( + f"Failed to commission OTA Provider App: SDK Error {res.code}" ) - ota_provider_node_id = ota_provider_node.node_id - except NodeCommissionFailed: - LOGGER.error("Failed to commission OTA Provider App!") - return LOGGER.info( "OTA Provider App commissioned with node id %d.", - ota_provider_node_id, + DEFAULT_OTA_PROVIDER_NODE_ID, ) # Adjust ACL of OTA Requestor such that Node peer-to-peer communication # is allowed. try: - read_result = await device_controller.chip_controller.ReadAttribute( - ota_provider_node_id, [(0, Clusters.AccessControl.Attributes.Acl)] + read_result = cast( + Attribute.AsyncReadTransaction.ReadResponse, + await chip_device_controller.read_attribute( + DEFAULT_OTA_PROVIDER_NODE_ID, + [(0, Clusters.AccessControl.Attributes.Acl)], + ), ) acl_list = cast( list, - read_result[0][Clusters.AccessControl][ + read_result.attributes[0][Clusters.AccessControl][ Clusters.AccessControl.Attributes.Acl ], ) @@ -216,8 +226,8 @@ async def _initialize(self, device_controller: MatterDeviceController) -> None: # And write. This is persistent, so only need to be done after we commissioned # the OTA Provider App. write_result: Attribute.AttributeWriteResult = ( - await device_controller.chip_controller.WriteAttribute( - ota_provider_node_id, + await chip_device_controller.write_attribute( + DEFAULT_OTA_PROVIDER_NODE_ID, [(0, Clusters.AccessControl.Attributes.Acl(acl_list))], ) ) @@ -226,22 +236,17 @@ async def _initialize(self, device_controller: MatterDeviceController) -> None: "Failed writing adjusted OTA Provider App ACL: Status %s.", str(write_result[0].Status), ) - await device_controller.remove_node(ota_provider_node_id) + await self.stop() raise UpdateError("Error while setting up OTA Provider.") except ChipStackError as ex: logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex) - await device_controller.remove_node(ota_provider_node_id) + await self.stop() raise UpdateError("Error while setting up OTA Provider.") from ex - self.set_node_id(ota_provider_node_id) - async def start_update( - self, device_controller: MatterDeviceController, node_id: int + self, chip_device_controller: ChipDeviceControllerWrapper, node_id: int ) -> None: - """Start the OTA Provider.""" - - if device_controller.chip_controller is None: - raise RuntimeError("Device Controller not initialized.") + """Start the OTA Provider and trigger the update.""" self._ota_target_node_id = node_id @@ -291,39 +296,29 @@ def _write_ota_provider_image_list_json( # Wait for OTA provider to be ready # TODO: Detect when OTA provider is ready - await asyncio.sleep(2) + await asyncio.sleep(3) # Handle if user deleted the OTA Provider node. ota_provider_node_id = self.get_node_id() - if ota_provider_node_id is not None: - try: - device_controller.get_node(ota_provider_node_id) - except NodeNotExists: - LOGGER.warning( - "OTA Provider node id %d not known by device controller! Resetting...", - ota_provider_node_id, - ) - await self._reset() - ota_provider_node_id = None # Commission and prepare OTA Provider if not initialized yet. # Use "ota_provider_node_id" to indicate if OTA Provider is setup or not. try: if ota_provider_node_id is None: LOGGER.info("Initializing OTA Provider") - await self._initialize(device_controller) + await self._initialize(chip_device_controller) finally: self._ota_target_node_id = None # Notify update node about the availability of the OTA Provider. It will query # the OTA provider and start the update. try: - await device_controller.chip_controller.SendCommand( - nodeid=node_id, - endpoint=0, - payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( - providerNodeID=ota_provider_node_id, - vendorID=device_controller.server.vendor_id, + await chip_device_controller.send_command( + node_id, + endpoint_id=0, + command=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( + providerNodeID=DEFAULT_OTA_PROVIDER_NODE_ID, + vendorID=self._vendor_id, announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable, endpoint=ExternalOtaProvider.ENDPOINT_ID, ), From e4bbc4780cc22dd106004ae20b2e4df45a1a813b Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Thu, 30 May 2024 10:25:21 +0200 Subject: [PATCH 16/39] Simplify ExternalOtaProvider --- matter_server/server/device_controller.py | 6 - matter_server/server/ota/provider.py | 170 ++++++++++------------ 2 files changed, 79 insertions(+), 97 deletions(-) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index ecd8fe16..8bb72981 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -43,7 +43,6 @@ NodeNotReady, NodeNotResolving, UpdateCheckError, - UpdateError, ) from ..common.helpers.api import api_command from ..common.helpers.json import JSON_DECODE_EXCEPTIONS, json_loads @@ -932,11 +931,6 @@ async def update_node(self, node_id: int, software_version: int) -> dict | None: f"Software version {software_version} is not available for node {node_id}." ) - if self._ota_provider.is_busy(): - raise UpdateError( - "No OTA provider currently busy, updates currently not possible" - ) - # Add update to the OTA provider await self._ota_provider.download_update(update) diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 16b50fdf..4dbd17fb 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -97,6 +97,7 @@ def __init__(self, vendor_id: int, ota_provider_dir: Path) -> None: self._ota_provider_image_list: OtaProviderImageList | None = None self._ota_provider_proc: Process | None = None self._ota_provider_task: asyncio.Task | None = None + self._ota_provider_lock: asyncio.Lock = asyncio.Lock() self._ota_target_node_id: int | None = None async def initialize(self) -> None: @@ -139,44 +140,20 @@ def _get_ota_provider_image_list(self) -> OtaProviderImageList: raise RuntimeError("OTA provider image list not initialized.") return self._ota_provider_image_list - def is_busy(self) -> bool: - """If OTA Provider is currently busy delivering updates.""" - return self._ota_target_node_id is not None - - def get_descriminator(self) -> int: - """Return OTA Provider App discriminator.""" - - return self._get_ota_provider_image_list().otaProviderDiscriminator - - def get_passcode(self) -> int: - """Return OTA Provider App passcode.""" - - return self._get_ota_provider_image_list().otaProviderPasscode - - def set_node_id(self, node_id: int) -> None: - """Set Node ID of the OTA Provider App.""" - - self._get_ota_provider_image_list().otaProviderNodeId = node_id - - def get_node_id(self) -> int | None: - """Get Node ID of the OTA Provider App.""" - - return self._get_ota_provider_image_list().otaProviderNodeId - - async def _initialize( - self, chip_device_controller: ChipDeviceControllerWrapper - ) -> None: - """Commissions the OTA Provider.""" - - # The OTA Provider has not been commissioned yet, let's do it now. - LOGGER.info("Commissioning the built-in OTA Provider App.") + async def _commission_ota_provider( + self, + passcode: int, + descriminator: int, + chip_device_controller: ChipDeviceControllerWrapper, + ) -> int: + """Commissions the OTA Provider, returns node ID of the commissioned provider.""" res: PyChipError = await chip_device_controller.commission_on_network( DEFAULT_OTA_PROVIDER_NODE_ID, - self.get_passcode(), + passcode, # TODO: Filtering by long discriminator seems broken disc_filter_type=FilterType.LONG_DISCRIMINATOR, - disc_filter=self.get_descriminator(), + disc_filter=descriminator, ) if not res.is_success: await self.stop() @@ -243,35 +220,47 @@ async def _initialize( await self.stop() raise UpdateError("Error while setting up OTA Provider.") from ex + return DEFAULT_OTA_PROVIDER_NODE_ID + + def _write_ota_provider_image_list_json( + self, + ota_provider_image_list_file: Path, + ota_provider_image_list: OtaProviderImageList, + ) -> None: + update_file_dict = asdict(ota_provider_image_list) + with open(ota_provider_image_list_file, "w") as json_file: + json.dump(update_file_dict, json_file, indent=4) + async def start_update( self, chip_device_controller: ChipDeviceControllerWrapper, node_id: int ) -> None: """Start the OTA Provider and trigger the update.""" - self._ota_target_node_id = node_id + # Don't hold the response + if self._ota_provider_lock.locked(): + raise UpdateError( + "OTA Provider already running. Only one update at a time possible." + ) - def _write_ota_provider_image_list_json( - ota_provider_image_list_file: Path, - ota_provider_image_list: OtaProviderImageList, - ) -> None: - update_file_dict = asdict(ota_provider_image_list) - with open(ota_provider_image_list_file, "w") as json_file: - json.dump(update_file_dict, json_file, indent=4) + await self._ota_provider_lock.acquire() + + self._ota_target_node_id = node_id loop = asyncio.get_running_loop() + image_list = self._get_ota_provider_image_list() await loop.run_in_executor( None, - _write_ota_provider_image_list_json, + self._write_ota_provider_image_list_json, self._ota_provider_image_list_file, - self._get_ota_provider_image_list(), + image_list, ) ota_provider_cmd = [ "chip-ota-provider-app", "--discriminator", - str(self._get_ota_provider_image_list().otaProviderDiscriminator), + str(image_list.otaProviderDiscriminator), "--passcode", - str(self._get_ota_provider_image_list().otaProviderPasscode), + str(image_list.otaProviderPasscode), "--secured-device-port", "5565", "--KVS", @@ -285,48 +274,47 @@ def _write_ota_provider_image_list_json( log_file = await loop.run_in_executor(None, log_file_path.open, "w") - LOGGER.info("Starting OTA Provider") - self._ota_provider_proc = await asyncio.create_subprocess_exec( - *ota_provider_cmd, stdout=log_file, stderr=log_file - ) - - self._ota_provider_task = loop.create_task( - self._ota_provider_proc.communicate() - ) - - # Wait for OTA provider to be ready - # TODO: Detect when OTA provider is ready - await asyncio.sleep(3) - - # Handle if user deleted the OTA Provider node. - ota_provider_node_id = self.get_node_id() - - # Commission and prepare OTA Provider if not initialized yet. - # Use "ota_provider_node_id" to indicate if OTA Provider is setup or not. try: - if ota_provider_node_id is None: - LOGGER.info("Initializing OTA Provider") - await self._initialize(chip_device_controller) - finally: - self._ota_target_node_id = None + LOGGER.info("Starting OTA Provider") + self._ota_provider_proc = await asyncio.create_subprocess_exec( + *ota_provider_cmd, stdout=log_file, stderr=log_file + ) - # Notify update node about the availability of the OTA Provider. It will query - # the OTA provider and start the update. - try: - await chip_device_controller.send_command( - node_id, - endpoint_id=0, - command=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( - providerNodeID=DEFAULT_OTA_PROVIDER_NODE_ID, - vendorID=self._vendor_id, - announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable, - endpoint=ExternalOtaProvider.ENDPOINT_ID, - ), + self._ota_provider_task = loop.create_task( + self._ota_provider_proc.communicate() ) - except ChipStackError as ex: - raise UpdateError("Error while announcing OTA Provider to node.") from ex - finally: - self._ota_target_node_id = None + + # Commission and prepare OTA Provider if not initialized yet. + # Use "ota_provider_node_id" to indicate if OTA Provider is setup or not. + if image_list.otaProviderNodeId is None: + LOGGER.info("Commission and initialize OTA Provider") + image_list.otaProviderNodeId = await self._commission_ota_provider( + image_list.otaProviderPasscode, + image_list.otaProviderDiscriminator, + chip_device_controller, + ) + + # Notify update node about the availability of the OTA Provider. It will query + # the OTA provider and start the update. + try: + await chip_device_controller.send_command( + node_id, + endpoint_id=0, + command=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( + providerNodeID=image_list.otaProviderNodeId, + vendorID=self._vendor_id, + announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable, + endpoint=ExternalOtaProvider.ENDPOINT_ID, + ), + ) + except ChipStackError as ex: + raise UpdateError( + "Error while announcing OTA Provider to node." + ) from ex + except UpdateError as ex: + # On error, make sure we stop the OTA Provider. + await self.stop() + raise ex async def _reset(self) -> None: """Reset the OTA Provider App state.""" @@ -352,17 +340,18 @@ async def stop(self) -> None: LOGGER.warning("Stopping OTA Provider failed with error:", exc_info=ex) if self._ota_provider_task: await self._ota_provider_task + self._ota_provider_proc = None + self._ota_provider_task = None async def add_update(self, update_desc: dict, ota_file: Path) -> None: """Add update to the OTA provider.""" local_ota_url = str(ota_file) - for i, device_software in enumerate( - self._get_ota_provider_image_list().deviceSoftwareVersionModel - ): + image_list = self._get_ota_provider_image_list() + for i, device_software in enumerate(image_list.deviceSoftwareVersionModel): if device_software.otaURL == local_ota_url: LOGGER.debug("Device software entry exists already, replacing!") - del self._get_ota_provider_image_list().deviceSoftwareVersionModel[i] + del image_list.deviceSoftwareVersionModel[i] # Convert to OTA Requestor descriptor file new_device_software = DeviceSoftwareVersionModel( @@ -376,9 +365,7 @@ async def add_update(self, update_desc: dict, ota_file: Path) -> None: maxApplicableSoftwareVersion=update_desc["maxApplicableSoftwareVersion"], otaURL=local_ota_url, ) - self._get_ota_provider_image_list().deviceSoftwareVersionModel.append( - new_device_software - ) + image_list.deviceSoftwareVersionModel.append(new_device_software) async def download_update(self, update_desc: dict) -> None: """Download update file from OTA Path and add it to the OTA provider.""" @@ -476,3 +463,4 @@ async def check_update_state( LOGGER.info("Update of node %d done.", node_id) await self.stop() self._ota_target_node_id = None + self._ota_provider_lock.release() From be9ee651efa2cfccaa4bd8d924b64c47d57764d7 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Tue, 4 Jun 2024 15:22:40 +0200 Subject: [PATCH 17/39] Support specific version by string Support updating to a specific software version by string. This requires fetching all versions one by one, but is much more user friendly. The code still checks that all restrictions are met (min/max version, etc). --- matter_server/client/client.py | 6 ++++- matter_server/server/device_controller.py | 6 +++-- matter_server/server/ota/__init__.py | 7 ++++-- matter_server/server/ota/dcl.py | 29 +++++++++++++++++------ 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/matter_server/client/client.py b/matter_server/client/client.py index f9ff1477..b926746a 100644 --- a/matter_server/client/client.py +++ b/matter_server/client/client.py @@ -523,7 +523,11 @@ async def check_node_update(self, node_id: int) -> dict[str, Any]: ) return cast(dict[str, Any], node_update) - async def update_node(self, node_id: int, software_version: int) -> None: + async def update_node( + self, + node_id: int, + software_version: int | str, + ) -> None: """Start node update to a particular version.""" await self.send_command( APICommand.UPDATE_NODE, node_id=node_id, software_version=software_version diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 8bb72981..69842272 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -915,7 +915,9 @@ async def check_node_update(self, node_id: int) -> dict | None: return await self._check_node_update(node_id) @api_command(APICommand.UPDATE_NODE) - async def update_node(self, node_id: int, software_version: int) -> dict | None: + async def update_node( + self, node_id: int, software_version: int | str + ) -> dict | None: """ Update a node to a new software version. @@ -945,7 +947,7 @@ async def update_node(self, node_id: int, software_version: int) -> dict | None: async def _check_node_update( self, node_id: int, - requested_software_version: int | None = None, + requested_software_version: int | str | None = None, ) -> dict | None: node_logger = LOGGER.getChild(f"node_{node_id}") node = self._nodes[node_id] diff --git a/matter_server/server/ota/__init__.py b/matter_server/server/ota/__init__.py index db15e602..afd49b4f 100644 --- a/matter_server/server/ota/__init__.py +++ b/matter_server/server/ota/__init__.py @@ -38,7 +38,7 @@ async def check_for_update( vid: int, pid: int, current_software_version: int, - requested_software_version: int | None = None, + requested_software_version: int | str | None = None, ) -> None | dict: """Check for software updates.""" if (vid, pid) in HARDCODED_UPDATES: @@ -46,7 +46,10 @@ async def check_for_update( if ( requested_software_version is None or update["softwareVersion"] == requested_software_version + or update["softwareVersionString"] == requested_software_version ): return update - return await dcl.check_for_update(vid, pid, current_software_version) + return await dcl.check_for_update( + vid, pid, current_software_version, requested_software_version + ) diff --git a/matter_server/server/ota/dcl.py b/matter_server/server/ota/dcl.py index 881b0fa0..7038b696 100644 --- a/matter_server/server/ota/dcl.py +++ b/matter_server/server/ota/dcl.py @@ -32,9 +32,15 @@ async def _get_software_version(vid: int, pid: int, software_version: int) -> An async def _check_update_version( - vid: int, pid: int, version: int, current_software_version: int + vid: int, + pid: int, + current_software_version: int, + requested_software_version: int, + requested_software_version_string: str | None = None, ) -> None | dict: - version_res: dict = await _get_software_version(vid, pid, version) + version_res: dict = await _get_software_version( + vid, pid, requested_software_version + ) if not isinstance(version_res, dict): raise TypeError("Unexpected DCL response.") @@ -43,6 +49,14 @@ async def _check_update_version( version_candidate: dict = cast(dict, version_res["modelVersion"]) + # If we are looking for a specific version by string, check if it matches + if ( + requested_software_version_string is not None + and version_candidate["softwareVersionString"] + != requested_software_version_string + ): + return None + # Check minApplicableSoftwareVersion/maxApplicableSoftwareVersion min_sw_version = version_candidate["minApplicableSoftwareVersion"] max_sw_version = version_candidate["maxApplicableSoftwareVersion"] @@ -59,13 +73,14 @@ async def check_for_update( vid: int, pid: int, current_software_version: int, - requested_software_version: int | None = None, + requested_software_version: int | str | None = None, ) -> None | dict: """Check if there is a software update available on the DCL.""" try: - if requested_software_version is not None: + # If a specific version as integer is requested, just fetch it (and hope it exists) + if isinstance(requested_software_version, int): return await _check_update_version( - vid, pid, requested_software_version, current_software_version + vid, pid, current_software_version, requested_software_version ) # Get all versions and check each one of them. @@ -78,7 +93,7 @@ async def check_for_update( if version > current_software_version ] - # Check if there is a newer software version available + # Check if there is a newer software version available, no downgrade possible if not newer_software_versions: LOGGER.info("No newer software version available.") return None @@ -86,7 +101,7 @@ async def check_for_update( # Check if latest firmware is applicable, and backtrack from there for version in sorted(newer_software_versions, reverse=True): if version_candidate := await _check_update_version( - vid, pid, version, current_software_version + vid, pid, current_software_version, version, requested_software_version ): return version_candidate LOGGER.debug("Software version %d not applicable.", version) From 3c33d5ff700eab52c94a097e02313c62796994e0 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Wed, 5 Jun 2024 15:34:36 +0200 Subject: [PATCH 18/39] Use ephemeral OTA Provider instances Instead of trying to reuse the same OTA Provider instance for multiple OTA requests, create a new instance for each request. This is easier to implement and also allows parallel updates. Because the OTA Provider is now ephemeral, we need to commission it on every update. But this is quick and reliable, so not a big deal. To support multiple updates at once, we need to make sure the OTA Providers use a distinct Matter port (hence passing 0) and distinct node ids. The current implementation simply uses the target node id plus a fixed offset. Since a single node can only run one update at a time, this is sufficient. Furthermore, some updates seem to have a difference in reported versionNumberString value in the DCL vs. what is actually in the OTA metadata. Specifically Eve updates from the Testnet DCL are such updates (e.g. 3.2.0 vs 3.2.6705). When using the OTA Provider with the --otaImageList option, this discrepancy is an issue and causes OTA Provider to abort. Using the single OTA update per OTA Provider instance allows us to use --filepath, which doesn't check the versionNumberString. --- matter_server/server/device_controller.py | 53 ++--- matter_server/server/ota/provider.py | 236 ++++++---------------- 2 files changed, 94 insertions(+), 195 deletions(-) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 69842272..593aea6c 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -62,7 +62,7 @@ from .const import DATA_MODEL_SCHEMA_VERSION if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Callable, Iterable from pathlib import Path from .server import MatterServer @@ -110,11 +110,6 @@ 0, Clusters.BasicInformation.Attributes.SoftwareVersionString ) ) -OTA_SOFTWARE_UPDATE_REQUESTOR_UPDATE_STATE_ATTRIBUTE_PATH = ( - create_attribute_path_from_attribute( - 0, Clusters.OtaSoftwareUpdateRequestor.Attributes.UpdateState - ) -) # pylint: disable=too-many-lines,too-many-instance-attributes,too-many-public-methods @@ -131,6 +126,7 @@ def __init__( ): """Initialize the device controller.""" self.server = server + self._ota_provider_dir = ota_provider_dir self._chip_device_controller = ChipDeviceControllerWrapper( server, paa_root_cert_dir @@ -158,7 +154,7 @@ def __init__( self._polled_attributes: dict[int, set[str]] = {} self._custom_attribute_poller_timer: asyncio.TimerHandle | None = None self._custom_attribute_poller_task: asyncio.Task | None = None - self._ota_provider = ExternalOtaProvider(server.vendor_id, ota_provider_dir) + self._attribute_update_callbacks: dict[int, list[Callable]] = {} async def initialize(self) -> None: """Initialize the device controller.""" @@ -166,7 +162,6 @@ async def initialize(self) -> None: await self._chip_device_controller.get_compressed_fabric_id() ) self._fabric_id_hex = hex(self._compressed_fabric_id)[2:] - await self._ota_provider.initialize() async def start(self) -> None: """Handle logic on controller start.""" @@ -230,9 +225,6 @@ async def stop(self) -> None: # shutdown the sdk device controller await self._chip_device_controller.shutdown() - # shutdown the OTA Provider - if self._ota_provider: - await self._ota_provider.stop() LOGGER.debug("Stopped.") @property @@ -934,14 +926,29 @@ async def update_node( ) # Add update to the OTA provider - await self._ota_provider.download_update(update) + ota_provider = ExternalOtaProvider( + self.server.vendor_id, self._ota_provider_dir / f"{node_id}" + ) - # Make sure any previous instances get stopped - await self._ota_provider.start_update( - self._chip_device_controller, - node_id, + await ota_provider.initialize() + + await ota_provider.download_update(update) + + self._attribute_update_callbacks.setdefault(node_id, []).append( + ota_provider.check_update_state ) + try: + # Make sure any previous instances get stopped + await ota_provider.start_update( + self._chip_device_controller, + node_id, + ) + finally: + self._attribute_update_callbacks[node_id].remove( + ota_provider.check_update_state + ) + return update async def _check_node_update( @@ -1033,22 +1040,16 @@ def attribute_updated_callback( # schedule a full interview of the node if the software version changed self._loop.create_task(self.interview_node(node_id)) - # work out if update state changed - if ( - str(path) == OTA_SOFTWARE_UPDATE_REQUESTOR_UPDATE_STATE_ATTRIBUTE_PATH - and new_value != old_value - ): - if self._ota_provider: - loop.create_task( - self._ota_provider.check_update_state(node_id, new_value) - ) - # store updated value in node attributes node.attributes[str(path)] = new_value # schedule save to persistent storage self._write_node_state(node_id) + if node_id in self._attribute_update_callbacks: + for callback in self._attribute_update_callbacks[node_id]: + self._loop.create_task(callback(path, old_value, new_value)) + # This callback is running in the CHIP stack thread self.server.signal_event( EventType.ATTRIBUTE_UPDATED, diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 4dbd17fb..028deee0 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -4,15 +4,13 @@ import asyncio from base64 import b64encode -from dataclasses import asdict, dataclass from datetime import UTC, datetime import functools import hashlib -import json import logging from pathlib import Path import secrets -from typing import TYPE_CHECKING, Final, cast +from typing import TYPE_CHECKING, Any, Final, cast from urllib.parse import unquote, urlparse from aiohttp import ClientError, ClientSession @@ -23,7 +21,9 @@ from chip.interaction_model import Status from matter_server.common.errors import UpdateError -from matter_server.common.helpers.util import dataclass_from_dict +from matter_server.common.helpers.util import ( + create_attribute_path_from_attribute, +) if TYPE_CHECKING: from asyncio.subprocess import Process @@ -34,9 +34,13 @@ LOGGER = logging.getLogger(__name__) -DEFAULT_UPDATES_PATH: Final[Path] = Path("updates") +DEFAULT_OTA_PROVIDER_NODE_ID: Final[int] = 990000 -DEFAULT_OTA_PROVIDER_NODE_ID = 999900 +OTA_SOFTWARE_UPDATE_REQUESTOR_UPDATE_STATE_ATTRIBUTE_PATH = ( + create_attribute_path_from_attribute( + 0, Clusters.OtaSoftwareUpdateRequestor.Attributes.UpdateState + ) +) # From Matter SDK src/app/ota_image_tool.py CHECHKSUM_TYPES: Final[dict[int, str]] = { @@ -55,31 +59,6 @@ } -@dataclass -class DeviceSoftwareVersionModel: # pylint: disable=C0103 - """Device Software Version Model for OTA Provider JSON descriptor file.""" - - vendorId: int - productId: int - softwareVersion: int - softwareVersionString: str - cDVersionNumber: int - softwareVersionValid: bool - minApplicableSoftwareVersion: int - maxApplicableSoftwareVersion: int - otaURL: str - - -@dataclass -class OtaProviderImageList: # pylint: disable=C0103 - """Update File for OTA Provider JSON descriptor file.""" - - otaProviderDiscriminator: int - otaProviderPasscode: int - otaProviderNodeId: int | None - deviceSoftwareVersionModel: list[DeviceSoftwareVersionModel] - - class ExternalOtaProvider: """Class handling Matter OTA Provider. @@ -93,11 +72,10 @@ def __init__(self, vendor_id: int, ota_provider_dir: Path) -> None: """Initialize the OTA provider.""" self._vendor_id: int = vendor_id self._ota_provider_dir: Path = ota_provider_dir - self._ota_provider_image_list_file: Path = ota_provider_dir / "updates.json" - self._ota_provider_image_list: OtaProviderImageList | None = None + self._ota_file_path: Path | None = None self._ota_provider_proc: Process | None = None self._ota_provider_task: asyncio.Task | None = None - self._ota_provider_lock: asyncio.Lock = asyncio.Lock() + self._ota_done: asyncio.Event = asyncio.Event() self._ota_target_node_id: int | None = None async def initialize(self) -> None: @@ -105,55 +83,25 @@ async def initialize(self) -> None: loop = asyncio.get_event_loop() - # Take existence of image list file as indicator if we need to initialize the - # OTA Provider. - if not await loop.run_in_executor( - None, self._ota_provider_image_list_file.exists - ): - await loop.run_in_executor( - None, functools.partial(DEFAULT_UPDATES_PATH.mkdir, exist_ok=True) - ) - - # Initialize with random data. Node ID will get written once paired by - # device controller. - self._ota_provider_image_list = OtaProviderImageList( - otaProviderDiscriminator=secrets.randbelow(2**12), - otaProviderPasscode=secrets.randbelow(2**21), - otaProviderNodeId=None, - deviceSoftwareVersionModel=[], - ) - else: - - def _read_update_json( - update_json_path: Path, - ) -> None | OtaProviderImageList: - with open(update_json_path, "r") as json_file: - data = json.load(json_file) - return dataclass_from_dict(OtaProviderImageList, data) - - self._ota_provider_image_list = await loop.run_in_executor( - None, _read_update_json, self._ota_provider_image_list_file - ) - - def _get_ota_provider_image_list(self) -> OtaProviderImageList: - if self._ota_provider_image_list is None: - raise RuntimeError("OTA provider image list not initialized.") - return self._ota_provider_image_list + await loop.run_in_executor( + None, functools.partial(self._ota_provider_dir.mkdir, exist_ok=True) + ) async def _commission_ota_provider( self, - passcode: int, - descriminator: int, chip_device_controller: ChipDeviceControllerWrapper, - ) -> int: + passcode: int, + discriminator: int, + ota_provider_node_id: int, + ) -> None: """Commissions the OTA Provider, returns node ID of the commissioned provider.""" res: PyChipError = await chip_device_controller.commission_on_network( - DEFAULT_OTA_PROVIDER_NODE_ID, + ota_provider_node_id, passcode, # TODO: Filtering by long discriminator seems broken disc_filter_type=FilterType.LONG_DISCRIMINATOR, - disc_filter=descriminator, + disc_filter=discriminator, ) if not res.is_success: await self.stop() @@ -163,7 +111,7 @@ async def _commission_ota_provider( LOGGER.info( "OTA Provider App commissioned with node id %d.", - DEFAULT_OTA_PROVIDER_NODE_ID, + ota_provider_node_id, ) # Adjust ACL of OTA Requestor such that Node peer-to-peer communication @@ -172,7 +120,7 @@ async def _commission_ota_provider( read_result = cast( Attribute.AsyncReadTransaction.ReadResponse, await chip_device_controller.read_attribute( - DEFAULT_OTA_PROVIDER_NODE_ID, + ota_provider_node_id, [(0, Clusters.AccessControl.Attributes.Acl)], ), ) @@ -204,7 +152,7 @@ async def _commission_ota_provider( # the OTA Provider App. write_result: Attribute.AttributeWriteResult = ( await chip_device_controller.write_attribute( - DEFAULT_OTA_PROVIDER_NODE_ID, + ota_provider_node_id, [(0, Clusters.AccessControl.Attributes.Acl(acl_list))], ) ) @@ -220,56 +168,33 @@ async def _commission_ota_provider( await self.stop() raise UpdateError("Error while setting up OTA Provider.") from ex - return DEFAULT_OTA_PROVIDER_NODE_ID - - def _write_ota_provider_image_list_json( - self, - ota_provider_image_list_file: Path, - ota_provider_image_list: OtaProviderImageList, - ) -> None: - update_file_dict = asdict(ota_provider_image_list) - with open(ota_provider_image_list_file, "w") as json_file: - json.dump(update_file_dict, json_file, indent=4) - async def start_update( self, chip_device_controller: ChipDeviceControllerWrapper, node_id: int ) -> None: """Start the OTA Provider and trigger the update.""" - # Don't hold the response - if self._ota_provider_lock.locked(): - raise UpdateError( - "OTA Provider already running. Only one update at a time possible." - ) - - await self._ota_provider_lock.acquire() - self._ota_target_node_id = node_id loop = asyncio.get_running_loop() - image_list = self._get_ota_provider_image_list() - await loop.run_in_executor( - None, - self._write_ota_provider_image_list_json, - self._ota_provider_image_list_file, - image_list, - ) + ota_provider_passcode = secrets.randbelow(2**21) + ota_provider_discriminator = secrets.randbelow(2**12) + + timestamp = datetime.now(tz=UTC).strftime("%Y%m%d_%H%M%S") ota_provider_cmd = [ "chip-ota-provider-app", - "--discriminator", - str(image_list.otaProviderDiscriminator), "--passcode", - str(image_list.otaProviderPasscode), + str(ota_provider_passcode), + "--discriminator", + str(ota_provider_discriminator), "--secured-device-port", - "5565", + "0", "--KVS", - str(self._ota_provider_dir / "chip_kvs_ota_provider"), - "--otaImageList", - str(self._ota_provider_image_list_file), + str(self._ota_provider_dir / f"chip_kvs_ota_provider_{timestamp}"), + "--filepath", + str(self._ota_file_path), ] - timestamp = datetime.now(tz=UTC).strftime("%Y%m%d_%H%M%S") log_file_path = self._ota_provider_dir / f"ota_provider_{timestamp}.log" log_file = await loop.run_in_executor(None, log_file_path.open, "w") @@ -284,15 +209,17 @@ async def start_update( self._ota_provider_proc.communicate() ) - # Commission and prepare OTA Provider if not initialized yet. - # Use "ota_provider_node_id" to indicate if OTA Provider is setup or not. - if image_list.otaProviderNodeId is None: - LOGGER.info("Commission and initialize OTA Provider") - image_list.otaProviderNodeId = await self._commission_ota_provider( - image_list.otaProviderPasscode, - image_list.otaProviderDiscriminator, - chip_device_controller, - ) + # Commission and prepare ephemeral OTA Provider + LOGGER.info("Commission and initialize OTA Provider") + ota_provider_node_id = ( + DEFAULT_OTA_PROVIDER_NODE_ID + self._ota_target_node_id + ) + await self._commission_ota_provider( + chip_device_controller, + ota_provider_passcode, + ota_provider_discriminator, + ota_provider_node_id, + ) # Notify update node about the availability of the OTA Provider. It will query # the OTA provider and start the update. @@ -301,7 +228,7 @@ async def start_update( node_id, endpoint_id=0, command=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( - providerNodeID=image_list.otaProviderNodeId, + providerNodeID=ota_provider_node_id, vendorID=self._vendor_id, announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable, endpoint=ExternalOtaProvider.ENDPOINT_ID, @@ -311,10 +238,11 @@ async def start_update( raise UpdateError( "Error while announcing OTA Provider to node." ) from ex - except UpdateError as ex: - # On error, make sure we stop the OTA Provider. + + await self._ota_done.wait() + finally: await self.stop() - raise ex + self._ota_target_node_id = None async def _reset(self) -> None: """Reset the OTA Provider App state.""" @@ -343,30 +271,6 @@ async def stop(self) -> None: self._ota_provider_proc = None self._ota_provider_task = None - async def add_update(self, update_desc: dict, ota_file: Path) -> None: - """Add update to the OTA provider.""" - - local_ota_url = str(ota_file) - image_list = self._get_ota_provider_image_list() - for i, device_software in enumerate(image_list.deviceSoftwareVersionModel): - if device_software.otaURL == local_ota_url: - LOGGER.debug("Device software entry exists already, replacing!") - del image_list.deviceSoftwareVersionModel[i] - - # Convert to OTA Requestor descriptor file - new_device_software = DeviceSoftwareVersionModel( - vendorId=update_desc["vid"], - productId=update_desc["pid"], - softwareVersion=update_desc["softwareVersion"], - softwareVersionString=update_desc["softwareVersionString"], - cDVersionNumber=update_desc["cdVersionNumber"], - softwareVersionValid=update_desc["softwareVersionValid"], - minApplicableSoftwareVersion=update_desc["minApplicableSoftwareVersion"], - maxApplicableSoftwareVersion=update_desc["maxApplicableSoftwareVersion"], - otaURL=local_ota_url, - ) - image_list.deviceSoftwareVersionModel.append(new_device_software) - async def download_update(self, update_desc: dict) -> None: """Download update file from OTA Path and add it to the OTA provider.""" @@ -375,11 +279,8 @@ async def download_update(self, update_desc: dict) -> None: file_name = unquote(Path(parsed_url.path).name) loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, functools.partial(DEFAULT_UPDATES_PATH.mkdir, exist_ok=True) - ) - file_path = DEFAULT_UPDATES_PATH / file_name + file_path = self._ota_provider_dir / file_name try: checksum_alg = None @@ -425,7 +326,7 @@ async def download_update(self, update_desc: dict) -> None: LOGGER.info( "Update file '%s' downloaded to '%s'", file_name, - DEFAULT_UPDATES_PATH, + self._ota_provider_dir, ) except (InvalidURL, ClientError, TimeoutError) as err: @@ -434,33 +335,30 @@ async def download_update(self, update_desc: dict) -> None: ) raise UpdateError("Fetching software version failed") from err - await self.add_update(update_desc, file_path) + self._ota_file_path = file_path async def check_update_state( self, - node_id: int, - update_state: Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum, + path: Attribute.AttributePath, + old_value: Any, + new_value: Any, ) -> None: - """ - Check the update state of a node and take appropriate action. + """Check the update state of a node and take appropriate action.""" - Args: - node_id: The ID of the node. - update_state: The update state of the node. - """ - - if self._ota_target_node_id is None: + LOGGER.info("Update state changed: %s, %s %s", str(path), old_value, new_value) + if str(path) != OTA_SOFTWARE_UPDATE_REQUESTOR_UPDATE_STATE_ATTRIBUTE_PATH: return - if self._ota_target_node_id != node_id: - return + update_state = cast( + Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum, new_value + ) # Update state of target node changed, check if update is done. if ( update_state == Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum.kIdle ): - LOGGER.info("Update of node %d done.", node_id) - await self.stop() - self._ota_target_node_id = None - self._ota_provider_lock.release() + LOGGER.info( + "Node %d update state idle, assuming done.", self._ota_target_node_id + ) + self._ota_done.set() From 07b8254a093a5249312964ed88050d5fbf92b976 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Wed, 5 Jun 2024 16:28:34 +0200 Subject: [PATCH 19/39] Raise update error if the node moves from querying to idle When the node's UpdateStateEnum changes from Querying to Idle it means the update file did not get processed. This could be due to temporary network issues or the update file not being honored by the target node. --- matter_server/server/ota/provider.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 028deee0..a646d0d6 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -353,11 +353,21 @@ async def check_update_state( Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum, new_value ) + old_update_state = cast( + Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum, old_value + ) + # Update state of target node changed, check if update is done. if ( update_state == Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum.kIdle ): + if ( + old_update_state + == Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum.kQuerying + ): + raise UpdateError("Target node did not process the update file") + LOGGER.info( "Node %d update state idle, assuming done.", self._ota_target_node_id ) From 02d43d6c3d4128b1431eda8f7ba189e76b3da434 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Wed, 5 Jun 2024 21:08:14 +0200 Subject: [PATCH 20/39] Improve logging and use Future to mark completion Instead of using an Event use a Future to mark completion of the update. This allows to set an Exception in case we see update state transitions which are unexpected (specifically kQuerying -> kIdle). --- matter_server/server/device_controller.py | 5 +++ matter_server/server/ota/provider.py | 39 +++++++++++------------ 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 593aea6c..49e526dc 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -919,6 +919,9 @@ async def update_node( notify the node about the new update. """ + node_logger = LOGGER.getChild(f"node_{node_id}") + node_logger.info("Update to software version %r", software_version) + update = await self._check_node_update(node_id, software_version) if update is None: raise UpdateCheckError( @@ -932,6 +935,7 @@ async def update_node( await ota_provider.initialize() + node_logger.info("Downloading update from '%s'", update["otaUrl"]) await ota_provider.download_update(update) self._attribute_update_callbacks.setdefault(node_id, []).append( @@ -940,6 +944,7 @@ async def update_node( try: # Make sure any previous instances get stopped + node_logger.info("Starting update using OTA Provider.") await ota_provider.start_update( self._chip_device_controller, node_id, diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index a646d0d6..6dc5f755 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -75,7 +75,7 @@ def __init__(self, vendor_id: int, ota_provider_dir: Path) -> None: self._ota_file_path: Path | None = None self._ota_provider_proc: Process | None = None self._ota_provider_task: asyncio.Task | None = None - self._ota_done: asyncio.Event = asyncio.Event() + self._ota_done: asyncio.Future = asyncio.Future() self._ota_target_node_id: int | None = None async def initialize(self) -> None: @@ -161,11 +161,9 @@ async def _commission_ota_provider( "Failed writing adjusted OTA Provider App ACL: Status %s.", str(write_result[0].Status), ) - await self.stop() raise UpdateError("Error while setting up OTA Provider.") except ChipStackError as ex: logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex) - await self.stop() raise UpdateError("Error while setting up OTA Provider.") from ex async def start_update( @@ -239,8 +237,11 @@ async def start_update( "Error while announcing OTA Provider to node." ) from ex - await self._ota_done.wait() + LOGGER.info("Waiting for target node update state change") + await self._ota_done + LOGGER.info("OTA update finished successfully") finally: + LOGGER.info("Cleaning up OTA provider") await self.stop() self._ota_target_node_id = None @@ -345,30 +346,28 @@ async def check_update_state( ) -> None: """Check the update state of a node and take appropriate action.""" - LOGGER.info("Update state changed: %s, %s %s", str(path), old_value, new_value) if str(path) != OTA_SOFTWARE_UPDATE_REQUESTOR_UPDATE_STATE_ATTRIBUTE_PATH: return - update_state = cast( - Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum, new_value - ) + UpdateState = Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum # noqa: N806 - old_update_state = cast( - Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum, old_value + new_update_state = UpdateState(new_value) + old_update_state = UpdateState(old_value) + + LOGGER.info( + "Update state changed from %r to %r", + old_update_state, + new_update_state, ) # Update state of target node changed, check if update is done. - if ( - update_state - == Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum.kIdle - ): - if ( - old_update_state - == Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum.kQuerying - ): - raise UpdateError("Target node did not process the update file") + if new_update_state == UpdateState.kIdle: + if old_update_state == UpdateState.kQuerying: + self._ota_done.set_exception( + UpdateError("Target node did not process the update file") + ) LOGGER.info( "Node %d update state idle, assuming done.", self._ota_target_node_id ) - self._ota_done.set() + self._ota_done.set_result(None) From 56d5b06e451c3a266ab64004d3363a209386c60c Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Wed, 5 Jun 2024 21:27:28 +0200 Subject: [PATCH 21/39] Make sure that only one updates is running at a time Running multiple updates on the same node doesn't make sense. This should be mostly handled by the client/UX. But we do check on server side for the critical path. --- matter_server/server/device_controller.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 49e526dc..63cdc082 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -43,6 +43,7 @@ NodeNotReady, NodeNotResolving, UpdateCheckError, + UpdateError, ) from ..common.helpers.api import api_command from ..common.helpers.json import JSON_DECODE_EXCEPTIONS, json_loads @@ -139,6 +140,7 @@ def __init__( self._wifi_credentials_set: bool = False self._thread_credentials_set: bool = False self._nodes_in_setup: set[int] = set() + self._nodes_in_ota: set[int] = set() self._node_last_seen: dict[int, float] = {} self._nodes: dict[int, MatterNodeData] = {} self._last_known_ip_addresses: dict[int, list[str]] = {} @@ -943,6 +945,13 @@ async def update_node( ) try: + if node_id in self._nodes_in_ota: + raise UpdateError( + f"Node {node_id} is already in the process of updating." + ) + + self._nodes_in_ota.add(node_id) + # Make sure any previous instances get stopped node_logger.info("Starting update using OTA Provider.") await ota_provider.start_update( @@ -953,6 +962,7 @@ async def update_node( self._attribute_update_callbacks[node_id].remove( ota_provider.check_update_state ) + self._nodes_in_ota.remove(node_id) return update From 2f535aaff3b122be5191c0642c0438307490d4c3 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Thu, 20 Jun 2024 19:01:39 +0200 Subject: [PATCH 22/39] Use new commissioning API --- matter_server/server/ota/provider.py | 35 +++++++++++----------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 6dc5f755..91418cbf 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -28,8 +28,6 @@ if TYPE_CHECKING: from asyncio.subprocess import Process - from chip.native import PyChipError - from matter_server.server.sdk import ChipDeviceControllerWrapper LOGGER = logging.getLogger(__name__) @@ -96,27 +94,22 @@ async def _commission_ota_provider( ) -> None: """Commissions the OTA Provider, returns node ID of the commissioned provider.""" - res: PyChipError = await chip_device_controller.commission_on_network( - ota_provider_node_id, - passcode, - # TODO: Filtering by long discriminator seems broken - disc_filter_type=FilterType.LONG_DISCRIMINATOR, - disc_filter=discriminator, - ) - if not res.is_success: - await self.stop() - raise UpdateError( - f"Failed to commission OTA Provider App: SDK Error {res.code}" - ) - - LOGGER.info( - "OTA Provider App commissioned with node id %d.", - ota_provider_node_id, - ) - # Adjust ACL of OTA Requestor such that Node peer-to-peer communication # is allowed. try: + commissioned_node_id = await chip_device_controller.commission_on_network( + ota_provider_node_id, + passcode, + disc_filter_type=FilterType.LONG_DISCRIMINATOR, + disc_filter=discriminator, + ) + assert commissioned_node_id == ota_provider_node_id + + LOGGER.info( + "OTA Provider App commissioned with node id %d.", + ota_provider_node_id, + ) + read_result = cast( Attribute.AsyncReadTransaction.ReadResponse, await chip_device_controller.read_attribute( @@ -163,7 +156,7 @@ async def _commission_ota_provider( ) raise UpdateError("Error while setting up OTA Provider.") except ChipStackError as ex: - logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex) + logging.exception("Failed setting up OTA Provider.", exc_info=ex) raise UpdateError("Error while setting up OTA Provider.") from ex async def start_update( From 683b33f5eb87bb7ec8d4a1ec3a09b981f8470ead Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Mon, 24 Jun 2024 18:33:40 +0200 Subject: [PATCH 23/39] Ignore when there is no software version info on DCL Don't raise an exception when there is no software version info on DCL. It is perfectly fine to operate such nodes. An informational message is good enough for this case. --- matter_server/server/ota/dcl.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/matter_server/server/ota/dcl.py b/matter_server/server/ota/dcl.py index 7038b696..342f6337 100644 --- a/matter_server/server/ota/dcl.py +++ b/matter_server/server/ota/dcl.py @@ -1,5 +1,6 @@ """Handle OTA software version endpoints of the DCL.""" +from http import HTTPStatus import logging from typing import Any, cast @@ -13,11 +14,14 @@ async def _get_software_versions(vid: int, pid: int) -> Any: """Check DCL if there are updates available for a particular node.""" - async with ClientSession(raise_for_status=True) as http_session: + async with ClientSession(raise_for_status=False) as http_session: # fetch the paa certificates list async with http_session.get( f"{DCL_PRODUCTION_URL}/dcl/model/versions/{vid}/{pid}" ) as response: + if response.status == HTTPStatus.NOT_FOUND: + return None + response.raise_for_status() return await response.json() @@ -85,6 +89,9 @@ async def check_for_update( # Get all versions and check each one of them. versions = await _get_software_versions(vid, pid) + if versions is None: + LOGGER.info("There is no update information for this device on the DCL.") + return None all_software_versions: list[int] = versions["modelVersions"]["softwareVersions"] newer_software_versions = [ From 76ed950ba1124876831744e3361f2733b1bb17a6 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Mon, 24 Jun 2024 23:27:39 +0200 Subject: [PATCH 24/39] Add MatterSoftwareVersion model for check_node_update Use a new model MatterSoftwareVersion to store the software version information typically fetched from DCL for the Matter nodes. --- matter_server/client/client.py | 12 ++++--- matter_server/common/models.py | 18 ++++++++++ matter_server/server/device_controller.py | 42 ++++++++++++++++++----- 3 files changed, 59 insertions(+), 13 deletions(-) diff --git a/matter_server/client/client.py b/matter_server/client/client.py index b926746a..9253582d 100644 --- a/matter_server/client/client.py +++ b/matter_server/client/client.py @@ -29,6 +29,7 @@ EventType, MatterNodeData, MatterNodeEvent, + MatterSoftwareVersion, MessageType, NodePingResult, ResultMessageBase, @@ -509,7 +510,7 @@ async def interview_node(self, node_id: int) -> None: """Interview a node.""" await self.send_command(APICommand.INTERVIEW_NODE, node_id=node_id) - async def check_node_update(self, node_id: int) -> dict[str, Any]: + async def check_node_update(self, node_id: int) -> MatterSoftwareVersion | None: """Check Node for updates. Return a dict with the available update information. Most notable @@ -518,10 +519,11 @@ async def check_node_update(self, node_id: int) -> dict[str, Any]: The "softwareVersionString" is a human friendly version string. """ - node_update = await self.send_command( - APICommand.CHECK_NODE_UPDATE, node_id=node_id - ) - return cast(dict[str, Any], node_update) + data = await self.send_command(APICommand.CHECK_NODE_UPDATE, node_id=node_id) + if data is None: + return None + + return dataclass_from_dict(MatterSoftwareVersion, data) async def update_node( self, diff --git a/matter_server/common/models.py b/matter_server/common/models.py index ea550de1..16237d30 100644 --- a/matter_server/common/models.py +++ b/matter_server/common/models.py @@ -211,3 +211,21 @@ class CommissioningParameters: setup_pin_code: int setup_manual_code: str setup_qr_code: str + + +@dataclass +class MatterSoftwareVersion: + """Representation of a Matter software version. Return by the check_node_update command. + + This holds Matter software version information similar to what is available from the CSA DCL. + https://on.dcl.csa-iot.org/#/Query/ModelVersion. + """ + + vid: int + pid: int + software_version: int + software_version_string: str + firmware_information: str | None + min_applicable_software_version: int + max_applicable_software_version: int + release_notes_url: str | None diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 63cdc082..3f2aee07 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -28,7 +28,11 @@ from matter_server.common.const import VERBOSE_LOG_LEVEL from matter_server.common.custom_clusters import check_polled_attributes -from matter_server.common.models import CommissionableNodeData, CommissioningParameters +from matter_server.common.models import ( + CommissionableNodeData, + CommissioningParameters, + MatterSoftwareVersion, +) from matter_server.server.helpers.attributes import parse_attributes_from_read_result from matter_server.server.helpers.utils import ping_ip from matter_server.server.ota import check_for_update @@ -897,7 +901,7 @@ async def import_test_node(self, dump: str) -> None: self.server.signal_event(EventType.NODE_ADDED, node) @api_command(APICommand.CHECK_NODE_UPDATE) - async def check_node_update(self, node_id: int) -> dict | None: + async def check_node_update(self, node_id: int) -> MatterSoftwareVersion | None: """ Check if there is an update for a particular node. @@ -906,12 +910,36 @@ async def check_node_update(self, node_id: int) -> dict | None: information of the latest update available. """ - return await self._check_node_update(node_id) + update = await self._check_node_update(node_id) + if update is None: + return None + + if not all( + key in update + for key in [ + "vid", + "pid", + "softwareVersion", + "softwareVersionString", + "minApplicableSoftwareVersion", + "maxApplicableSoftwareVersion", + ] + ): + raise UpdateCheckError("Invalid update data") + + return MatterSoftwareVersion( + vid=update["vid"], + pid=update["pid"], + software_version=update["softwareVersion"], + software_version_string=update["softwareVersionString"], + firmware_information=update.get("firmwareInformation", None), + min_applicable_software_version=update["minApplicableSoftwareVersion"], + max_applicable_software_version=update["maxApplicableSoftwareVersion"], + release_notes_url=update.get("releaseNotesUrl", None), + ) @api_command(APICommand.UPDATE_NODE) - async def update_node( - self, node_id: int, software_version: int | str - ) -> dict | None: + async def update_node(self, node_id: int, software_version: int | str) -> None: """ Update a node to a new software version. @@ -964,8 +992,6 @@ async def update_node( ) self._nodes_in_ota.remove(node_id) - return update - async def _check_node_update( self, node_id: int, From 475a1dc183003c1f47944a2fe562be310633d578 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Tue, 25 Jun 2024 00:11:11 +0200 Subject: [PATCH 25/39] Bump Server schema We've added new commands, so a schema version bump is needed. --- matter_server/common/const.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matter_server/common/const.py b/matter_server/common/const.py index 45cc3184..26a4737e 100644 --- a/matter_server/common/const.py +++ b/matter_server/common/const.py @@ -2,7 +2,7 @@ # schema version is used to determine compatibility between server and client # bump schema if we add new features and/or make other (breaking) changes -SCHEMA_VERSION = 9 +SCHEMA_VERSION = 10 VERBOSE_LOG_LEVEL = 5 From 23a6e6b97f43e888bae21c717b741e9697e29688 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Thu, 11 Jul 2024 19:03:22 +0200 Subject: [PATCH 26/39] Use OTA Provider from dedicated repository --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index dc3cee92..e787703f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -26,7 +26,7 @@ RUN \ ARG PYTHON_MATTER_SERVER -ENV chip_example_url "https://github.com/agners/matter-linux-example-apps/releases/download/v1.3.0.0" +ENV chip_example_url "https://github.com/home-assistant-libs/matter-linux-ota-provider/releases/download/2024.7.0" ARG TARGETPLATFORM RUN \ From b0dca4bb1952d9a80c20ff29ca32053ffca6804c Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Thu, 11 Jul 2024 19:57:37 +0200 Subject: [PATCH 27/39] Bump OTA Provider to 2024.7.1 --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index e787703f..70426ab7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -26,7 +26,7 @@ RUN \ ARG PYTHON_MATTER_SERVER -ENV chip_example_url "https://github.com/home-assistant-libs/matter-linux-ota-provider/releases/download/2024.7.0" +ENV chip_example_url "https://github.com/home-assistant-libs/matter-linux-ota-provider/releases/download/2024.7.1" ARG TARGETPLATFORM RUN \ From 87cd0a4a79fa2bf76172bf3a1792f6bee8157782 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Thu, 11 Jul 2024 20:47:15 +0200 Subject: [PATCH 28/39] Use new node logger --- matter_server/server/device_controller.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 3f2aee07..04fc493e 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -949,7 +949,7 @@ async def update_node(self, node_id: int, software_version: int | str) -> None: notify the node about the new update. """ - node_logger = LOGGER.getChild(f"node_{node_id}") + node_logger = self.get_node_logger(LOGGER, node_id) node_logger.info("Update to software version %r", software_version) update = await self._check_node_update(node_id, software_version) @@ -997,7 +997,7 @@ async def _check_node_update( node_id: int, requested_software_version: int | str | None = None, ) -> dict | None: - node_logger = LOGGER.getChild(f"node_{node_id}") + node_logger = self.get_node_logger(LOGGER, node_id) node = self._nodes[node_id] node_logger.debug("Check for updates.") From 7e7537b10ca41021a970bc3f162d736e0dbe35a9 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Thu, 11 Jul 2024 20:59:12 +0200 Subject: [PATCH 29/39] Complete future only once on error --- matter_server/server/ota/provider.py | 1 + 1 file changed, 1 insertion(+) diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 91418cbf..a8703ecc 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -359,6 +359,7 @@ async def check_update_state( self._ota_done.set_exception( UpdateError("Target node did not process the update file") ) + return LOGGER.info( "Node %d update state idle, assuming done.", self._ota_target_node_id From 7a307007d5d616a5c1ad7b996a5c9ec252d91b24 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Thu, 11 Jul 2024 21:02:16 +0200 Subject: [PATCH 30/39] Apply suggestions from code review Co-authored-by: Marcel van der Veldt --- matter_server/client/client.py | 7 +++++-- matter_server/server/ota/provider.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/matter_server/client/client.py b/matter_server/client/client.py index 9253582d..6365311f 100644 --- a/matter_server/client/client.py +++ b/matter_server/client/client.py @@ -519,7 +519,7 @@ async def check_node_update(self, node_id: int) -> MatterSoftwareVersion | None: The "softwareVersionString" is a human friendly version string. """ - data = await self.send_command(APICommand.CHECK_NODE_UPDATE, node_id=node_id) + data = await self.send_command(APICommand.CHECK_NODE_UPDATE, node_id=node_id, require_schema=10) if data is None: return None @@ -532,7 +532,10 @@ async def update_node( ) -> None: """Start node update to a particular version.""" await self.send_command( - APICommand.UPDATE_NODE, node_id=node_id, software_version=software_version + APICommand.UPDATE_NODE, + node_id=node_id, + software_version=software_version, + require_schema=10 ) def _prepare_message( diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index a8703ecc..62932bc4 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -212,8 +212,8 @@ async def start_update( ota_provider_node_id, ) - # Notify update node about the availability of the OTA Provider. It will query - # the OTA provider and start the update. + # Notify update node about the availability of the OTA Provider. + # It will query the OTA provider and start the update. try: await chip_device_controller.send_command( node_id, From 07e20dd8aa2ac187910c25303884e21cd05e3d46 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Thu, 11 Jul 2024 21:27:45 +0200 Subject: [PATCH 31/39] Share client session for update check --- matter_server/server/ota/dcl.py | 109 ++++++++++++++++++-------------- 1 file changed, 62 insertions(+), 47 deletions(-) diff --git a/matter_server/server/ota/dcl.py b/matter_server/server/ota/dcl.py index 342f6337..8a08c3b3 100644 --- a/matter_server/server/ota/dcl.py +++ b/matter_server/server/ota/dcl.py @@ -12,30 +12,30 @@ LOGGER = logging.getLogger(__name__) -async def _get_software_versions(vid: int, pid: int) -> Any: +async def _get_software_versions(session: ClientSession, vid: int, pid: int) -> Any: """Check DCL if there are updates available for a particular node.""" - async with ClientSession(raise_for_status=False) as http_session: - # fetch the paa certificates list - async with http_session.get( - f"{DCL_PRODUCTION_URL}/dcl/model/versions/{vid}/{pid}" - ) as response: - if response.status == HTTPStatus.NOT_FOUND: - return None - response.raise_for_status() - return await response.json() + # fetch the paa certificates list + async with session.get(f"/dcl/model/versions/{vid}/{pid}") as response: + if response.status == HTTPStatus.NOT_FOUND: + return None + response.raise_for_status() + return await response.json() -async def _get_software_version(vid: int, pid: int, software_version: int) -> Any: +async def _get_software_version( + session: ClientSession, vid: int, pid: int, software_version: int +) -> Any: """Check DCL if there are updates available for a particular node.""" - async with ClientSession(raise_for_status=True) as http_session: - # fetch the paa certificates list - async with http_session.get( - f"{DCL_PRODUCTION_URL}/dcl/model/versions/{vid}/{pid}/{software_version}" - ) as response: - return await response.json() + # fetch the paa certificates list + async with session.get( + f"/dcl/model/versions/{vid}/{pid}/{software_version}" + ) as response: + response.raise_for_status() + return await response.json() async def _check_update_version( + session: ClientSession, vid: int, pid: int, current_software_version: int, @@ -43,7 +43,7 @@ async def _check_update_version( requested_software_version_string: str | None = None, ) -> None | dict: version_res: dict = await _get_software_version( - vid, pid, requested_software_version + session, vid, pid, requested_software_version ) if not isinstance(version_res, dict): raise TypeError("Unexpected DCL response.") @@ -81,39 +81,54 @@ async def check_for_update( ) -> None | dict: """Check if there is a software update available on the DCL.""" try: - # If a specific version as integer is requested, just fetch it (and hope it exists) - if isinstance(requested_software_version, int): - return await _check_update_version( - vid, pid, current_software_version, requested_software_version - ) - - # Get all versions and check each one of them. - versions = await _get_software_versions(vid, pid) - if versions is None: - LOGGER.info("There is no update information for this device on the DCL.") - return None + async with ClientSession( + base_url=DCL_PRODUCTION_URL, raise_for_status=False + ) as session: + # If a specific version as integer is requested, just fetch it (and hope it exists) + if isinstance(requested_software_version, int): + return await _check_update_version( + session, + vid, + pid, + current_software_version, + requested_software_version, + ) + + # Get all versions and check each one of them. + versions = await _get_software_versions(session, vid, pid) + if versions is None: + LOGGER.info( + "There is no update information for this device on the DCL." + ) + return None - all_software_versions: list[int] = versions["modelVersions"]["softwareVersions"] - newer_software_versions = [ - version - for version in all_software_versions - if version > current_software_version - ] + all_software_versions: list[int] = versions["modelVersions"][ + "softwareVersions" + ] + newer_software_versions = [ + version + for version in all_software_versions + if version > current_software_version + ] + + # Check if there is a newer software version available, no downgrade possible + if not newer_software_versions: + return None - # Check if there is a newer software version available, no downgrade possible - if not newer_software_versions: - LOGGER.info("No newer software version available.") + # Check if latest firmware is applicable, and backtrack from there + for version in sorted(newer_software_versions, reverse=True): + if version_candidate := await _check_update_version( + session, + vid, + pid, + current_software_version, + version, + requested_software_version, + ): + return version_candidate + LOGGER.debug("Software version %d not applicable.", version) return None - # Check if latest firmware is applicable, and backtrack from there - for version in sorted(newer_software_versions, reverse=True): - if version_candidate := await _check_update_version( - vid, pid, current_software_version, version, requested_software_version - ): - return version_candidate - LOGGER.debug("Software version %d not applicable.", version) - return None - except (ClientError, TimeoutError) as err: raise UpdateCheckError( f"Fetching software versions from DCL for device with vendor id {vid} product id {pid} failed." From b057eae34023d7feb1ccdf06332d621d29c2899b Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Thu, 11 Jul 2024 23:37:07 +0200 Subject: [PATCH 32/39] Provide methods to convert dataclass as dict --- matter_server/common/models.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/matter_server/common/models.py b/matter_server/common/models.py index 16237d30..ab6a27f3 100644 --- a/matter_server/common/models.py +++ b/matter_server/common/models.py @@ -229,3 +229,30 @@ class MatterSoftwareVersion: min_applicable_software_version: int max_applicable_software_version: int release_notes_url: str | None + + @classmethod + def from_dict(cls, data: dict) -> MatterSoftwareVersion: + """Initialize from dict.""" + return cls( + vid=data["vid"], + pid=data["pid"], + software_version=data["software_version"], + software_version_string=data["software_version_string"], + firmware_information=data["firmware_information"], + min_applicable_software_version=data["min_applicable_software_version"], + max_applicable_software_version=data["max_applicable_software_version"], + release_notes_url=data["release_notes_url"], + ) + + def as_dict(self) -> dict: + """Return dict representation of the object.""" + return { + "vid": self.vid, + "pid": self.pid, + "software_version": self.software_version, + "software_version_string": self.software_version_string, + "firmware_information": self.firmware_information, + "min_applicable_software_version": self.min_applicable_software_version, + "max_applicable_software_version": self.max_applicable_software_version, + "release_notes_url": self.release_notes_url, + } From 09c92f72e94d0d218c48cfab92dcd236c308454f Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 12 Jul 2024 00:08:14 +0200 Subject: [PATCH 33/39] Log with node logger when checking for updates --- matter_server/server/device_controller.py | 2 +- matter_server/server/ota/__init__.py | 5 ++++- matter_server/server/ota/dcl.py | 7 +++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 04fc493e..a37a9eae 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -1013,7 +1013,7 @@ async def _check_node_update( ) update = await check_for_update( - vid, pid, software_version, requested_software_version + node_logger, vid, pid, software_version, requested_software_version ) if not update: node_logger.info("No new update found.") diff --git a/matter_server/server/ota/__init__.py b/matter_server/server/ota/__init__.py index afd49b4f..dd4a04a3 100644 --- a/matter_server/server/ota/__init__.py +++ b/matter_server/server/ota/__init__.py @@ -1,5 +1,7 @@ """OTA implementation for the Matter Server.""" +from logging import LoggerAdapter + from matter_server.server.ota import dcl HARDCODED_UPDATES: dict[tuple[int, int], dict] = { @@ -35,6 +37,7 @@ async def check_for_update( + logger: LoggerAdapter, vid: int, pid: int, current_software_version: int, @@ -51,5 +54,5 @@ async def check_for_update( return update return await dcl.check_for_update( - vid, pid, current_software_version, requested_software_version + logger, vid, pid, current_software_version, requested_software_version ) diff --git a/matter_server/server/ota/dcl.py b/matter_server/server/ota/dcl.py index 8a08c3b3..d555def0 100644 --- a/matter_server/server/ota/dcl.py +++ b/matter_server/server/ota/dcl.py @@ -9,8 +9,6 @@ from matter_server.common.errors import UpdateCheckError from matter_server.server.helpers import DCL_PRODUCTION_URL -LOGGER = logging.getLogger(__name__) - async def _get_software_versions(session: ClientSession, vid: int, pid: int) -> Any: """Check DCL if there are updates available for a particular node.""" @@ -74,6 +72,7 @@ async def _check_update_version( async def check_for_update( + logger: logging.LoggerAdapter, vid: int, pid: int, current_software_version: int, @@ -97,7 +96,7 @@ async def check_for_update( # Get all versions and check each one of them. versions = await _get_software_versions(session, vid, pid) if versions is None: - LOGGER.info( + logger.info( "There is no update information for this device on the DCL." ) return None @@ -126,7 +125,7 @@ async def check_for_update( requested_software_version, ): return version_candidate - LOGGER.debug("Software version %d not applicable.", version) + logger.debug("Software version %d not applicable.", version) return None except (ClientError, TimeoutError) as err: From 57fb7d2480ac75d7752bc1c400cbaf676474abac Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 12 Jul 2024 00:10:07 +0200 Subject: [PATCH 34/39] Fix trailing whitespace --- matter_server/server/ota/provider.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 62932bc4..cb583552 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -179,7 +179,7 @@ async def start_update( "--discriminator", str(ota_provider_discriminator), "--secured-device-port", - "0", + "5540", "--KVS", str(self._ota_provider_dir / f"chip_kvs_ota_provider_{timestamp}"), "--filepath", @@ -212,7 +212,7 @@ async def start_update( ota_provider_node_id, ) - # Notify update node about the availability of the OTA Provider. + # Notify update node about the availability of the OTA Provider. # It will query the OTA provider and start the update. try: await chip_device_controller.send_command( From 9ad23485d0b82d0539c67ed982a6d3db7034c06b Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 12 Jul 2024 08:04:43 +0200 Subject: [PATCH 35/39] Fix tests --- tests/server/ota/test_dcl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/server/ota/test_dcl.py b/tests/server/ota/test_dcl.py index c52f189f..e86cfb8b 100644 --- a/tests/server/ota/test_dcl.py +++ b/tests/server/ota/test_dcl.py @@ -1,6 +1,6 @@ """Test DCL OTA updates.""" -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -62,7 +62,7 @@ def mock_get_software_version(): async def test_check_updates(get_software_versions, get_software_version): """Test the case where the latest software version is applicable.""" # Call the function with a current software version of 1000 - result = await check_for_update(4447, 8194, 1000) + result = await check_for_update(MagicMock(), 4447, 8194, 1000) assert result == DCL_RESPONSE_SOFTWARE_VERSION_1011["modelVersion"] @@ -72,7 +72,7 @@ async def test_check_updates_not_applicable( ): """Test the case where the latest software version is not applicable.""" # Call the function with a current software version of 1 - result = await check_for_update(4447, 8194, 1) + result = await check_for_update(MagicMock(), 4447, 8194, 1) assert result is None @@ -80,6 +80,6 @@ async def test_check_updates_not_applicable( async def test_check_updates_specific_version(get_software_version): """Test the case to get a specific version.""" # Call the function with a current software version of 1000 and request 1011 as update - result = await check_for_update(4447, 8194, 1000, 1011) + result = await check_for_update(MagicMock(), 4447, 8194, 1000, 1011) assert result == DCL_RESPONSE_SOFTWARE_VERSION_1011["modelVersion"] From 507a42941ded1a1bb9f5198671bbbcdfd3ec1b76 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 12 Jul 2024 10:16:28 +0200 Subject: [PATCH 36/39] ruff format --- matter_server/client/client.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/matter_server/client/client.py b/matter_server/client/client.py index 6365311f..ec87dba7 100644 --- a/matter_server/client/client.py +++ b/matter_server/client/client.py @@ -519,7 +519,9 @@ async def check_node_update(self, node_id: int) -> MatterSoftwareVersion | None: The "softwareVersionString" is a human friendly version string. """ - data = await self.send_command(APICommand.CHECK_NODE_UPDATE, node_id=node_id, require_schema=10) + data = await self.send_command( + APICommand.CHECK_NODE_UPDATE, node_id=node_id, require_schema=10 + ) if data is None: return None @@ -535,7 +537,7 @@ async def update_node( APICommand.UPDATE_NODE, node_id=node_id, software_version=software_version, - require_schema=10 + require_schema=10, ) def _prepare_message( From 39f025e9552dd2aa7cce4ec571eee4c834a4a71d Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Mon, 15 Jul 2024 12:51:09 +0200 Subject: [PATCH 37/39] Support loading updates from local json file --- matter_server/server/device_controller.py | 3 +- matter_server/server/ota/__init__.py | 55 ++++++++++------------- 2 files changed, 25 insertions(+), 33 deletions(-) diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index a37a9eae..871e5cfc 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -35,7 +35,7 @@ ) from matter_server.server.helpers.attributes import parse_attributes_from_read_result from matter_server.server.helpers.utils import ping_ip -from matter_server.server.ota import check_for_update +from matter_server.server.ota import check_for_update, load_local_updates from matter_server.server.ota.provider import ExternalOtaProvider from matter_server.server.sdk import ChipDeviceControllerWrapper @@ -168,6 +168,7 @@ async def initialize(self) -> None: await self._chip_device_controller.get_compressed_fabric_id() ) self._fabric_id_hex = hex(self._compressed_fabric_id)[2:] + await load_local_updates(self._ota_provider_dir) async def start(self) -> None: """Handle logic on controller start.""" diff --git a/matter_server/server/ota/__init__.py b/matter_server/server/ota/__init__.py index dd4a04a3..1c102dca 100644 --- a/matter_server/server/ota/__init__.py +++ b/matter_server/server/ota/__init__.py @@ -1,39 +1,30 @@ """OTA implementation for the Matter Server.""" +import asyncio +import json from logging import LoggerAdapter +from pathlib import Path from matter_server.server.ota import dcl -HARDCODED_UPDATES: dict[tuple[int, int], dict] = { - # OTA requestor example app, useful for testing - (0xFFF1, 0x8001): { - "vid": 0xFFF1, - "pid": 0x8001, - "softwareVersion": 2, - "softwareVersionString": "2.0", - "cdVersionNumber": 1, - "softwareVersionValid": True, - "otaChecksum": "7qcyvg2kPmKZaDLIk8C7Vyteqf4DI73x0tFZkmPALCo=", - "otaChecksumType": 1, - "minApplicableSoftwareVersion": 1, - "maxApplicableSoftwareVersion": 1, - "otaUrl": "https://github.com/agners/matter-linux-example-apps/releases/download/v1.3.0.0/chip-ota-requestor-app-x86-64.ota", - "releaseNotesUrl": "https://github.com/agners/matter-linux-example-apps/releases/tag/v1.3.0.0", - }, - (0x143D, 0x1001): { - "vid": 0x143D, - "pid": 0x1001, - "softwareVersion": 10010011, - "softwareVersionString": "1.1.11-c85ba1e-dirty", - "cdVersionNumber": 1, - "softwareVersionValid": True, - "otaChecksum": "x2sK9xjVuGff0eefYa4cporDO+Z+WVxxw+JP5Ol+5og=", - "otaChecksumType": 1, - "minApplicableSoftwareVersion": 10010000, - "maxApplicableSoftwareVersion": 10010011, - "otaUrl": "https://raw.githubusercontent.com/ChampOnBon/Onvis/master/S4/debug.ota", - }, -} +_local_updates: dict[tuple[int, int], dict] = {} + + +async def load_local_updates(ota_provider_dir: Path) -> None: + """Load updates from locally stored json files.""" + + def _load_update(ota_provider_dir: Path) -> None: + for update_file in ota_provider_dir.glob("*.json"): + with open(update_file) as f: + update = json.load(f) + model_version = update["modelVersion"] + _local_updates[(model_version["vid"], model_version["pid"])] = ( + model_version + ) + + await asyncio.get_running_loop().run_in_executor( + None, _load_update, ota_provider_dir + ) async def check_for_update( @@ -44,8 +35,8 @@ async def check_for_update( requested_software_version: int | str | None = None, ) -> None | dict: """Check for software updates.""" - if (vid, pid) in HARDCODED_UPDATES: - update = HARDCODED_UPDATES[(vid, pid)] + if (vid, pid) in _local_updates: + update = _local_updates[(vid, pid)] if ( requested_software_version is None or update["softwareVersion"] == requested_software_version From 51bec820b202d2d336b5d0f4c0f0e76cea32cc47 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Mon, 15 Jul 2024 12:54:59 +0200 Subject: [PATCH 38/39] Check if update directory exists --- matter_server/server/ota/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/matter_server/server/ota/__init__.py b/matter_server/server/ota/__init__.py index 1c102dca..d0fef4c8 100644 --- a/matter_server/server/ota/__init__.py +++ b/matter_server/server/ota/__init__.py @@ -14,6 +14,8 @@ async def load_local_updates(ota_provider_dir: Path) -> None: """Load updates from locally stored json files.""" def _load_update(ota_provider_dir: Path) -> None: + if not ota_provider_dir.exists(): + return for update_file in ota_provider_dir.glob("*.json"): with open(update_file) as f: update = json.load(f) From d4162fe60bdc15e33a60090b0082310916d7d8b2 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Mon, 15 Jul 2024 13:24:24 +0200 Subject: [PATCH 39/39] Add software update source information --- matter_server/common/models.py | 11 +++++++++++ matter_server/server/device_controller.py | 21 ++++++++++++--------- matter_server/server/ota/__init__.py | 17 ++++++++++------- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/matter_server/common/models.py b/matter_server/common/models.py index ab6a27f3..f049782d 100644 --- a/matter_server/common/models.py +++ b/matter_server/common/models.py @@ -213,6 +213,14 @@ class CommissioningParameters: setup_qr_code: str +class UpdateSource(Enum): + """Enum with possible sources for a software update.""" + + MAIN_NET_DCL = "main-net-dcl" + TEST_NET_DCL = "test-net-dcl" + LOCAL = "local" + + @dataclass class MatterSoftwareVersion: """Representation of a Matter software version. Return by the check_node_update command. @@ -229,6 +237,7 @@ class MatterSoftwareVersion: min_applicable_software_version: int max_applicable_software_version: int release_notes_url: str | None + update_source: UpdateSource @classmethod def from_dict(cls, data: dict) -> MatterSoftwareVersion: @@ -242,6 +251,7 @@ def from_dict(cls, data: dict) -> MatterSoftwareVersion: min_applicable_software_version=data["min_applicable_software_version"], max_applicable_software_version=data["max_applicable_software_version"], release_notes_url=data["release_notes_url"], + update_source=UpdateSource(data["update_source"]), ) def as_dict(self) -> dict: @@ -255,4 +265,5 @@ def as_dict(self) -> dict: "min_applicable_software_version": self.min_applicable_software_version, "max_applicable_software_version": self.max_applicable_software_version, "release_notes_url": self.release_notes_url, + "update_source": str(self.update_source), } diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 871e5cfc..0ae9a299 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -63,6 +63,7 @@ MatterNodeData, MatterNodeEvent, NodePingResult, + UpdateSource, ) from .const import DATA_MODEL_SCHEMA_VERSION @@ -911,8 +912,8 @@ async def check_node_update(self, node_id: int) -> MatterSoftwareVersion | None: information of the latest update available. """ - update = await self._check_node_update(node_id) - if update is None: + update_source, update = await self._check_node_update(node_id) + if update_source is None or update is None: return None if not all( @@ -937,6 +938,7 @@ async def check_node_update(self, node_id: int) -> MatterSoftwareVersion | None: min_applicable_software_version=update["minApplicableSoftwareVersion"], max_applicable_software_version=update["maxApplicableSoftwareVersion"], release_notes_url=update.get("releaseNotesUrl", None), + update_source=update_source, ) @api_command(APICommand.UPDATE_NODE) @@ -953,7 +955,7 @@ async def update_node(self, node_id: int, software_version: int | str) -> None: node_logger = self.get_node_logger(LOGGER, node_id) node_logger.info("Update to software version %r", software_version) - update = await self._check_node_update(node_id, software_version) + _, update = await self._check_node_update(node_id, software_version) if update is None: raise UpdateCheckError( f"Software version {software_version} is not available for node {node_id}." @@ -997,7 +999,7 @@ async def _check_node_update( self, node_id: int, requested_software_version: int | str | None = None, - ) -> dict | None: + ) -> tuple[UpdateSource, dict] | tuple[None, None]: node_logger = self.get_node_logger(LOGGER, node_id) node = self._nodes[node_id] @@ -1013,22 +1015,23 @@ async def _check_node_update( BASIC_INFORMATION_SOFTWARE_VERSION_STRING_ATTRIBUTE_PATH ) - update = await check_for_update( + update_source, update = await check_for_update( node_logger, vid, pid, software_version, requested_software_version ) - if not update: + if not update_source or not update: node_logger.info("No new update found.") - return None + return None, None if "otaUrl" not in update: raise UpdateCheckError("Update found, but no OTA URL provided.") node_logger.info( - "New software update found: %s (current %s).", + "New software update found: %s on %s (current %s).", update["softwareVersionString"], + update_source, software_version_string, ) - return update + return update_source, update async def _subscribe_node(self, node_id: int) -> None: """ diff --git a/matter_server/server/ota/__init__.py b/matter_server/server/ota/__init__.py index d0fef4c8..21cf52c7 100644 --- a/matter_server/server/ota/__init__.py +++ b/matter_server/server/ota/__init__.py @@ -5,6 +5,7 @@ from logging import LoggerAdapter from pathlib import Path +from matter_server.common.models import UpdateSource from matter_server.server.ota import dcl _local_updates: dict[tuple[int, int], dict] = {} @@ -35,17 +36,19 @@ async def check_for_update( pid: int, current_software_version: int, requested_software_version: int | str | None = None, -) -> None | dict: +) -> tuple[UpdateSource, dict] | tuple[None, None]: """Check for software updates.""" if (vid, pid) in _local_updates: - update = _local_updates[(vid, pid)] + local_update = _local_updates[(vid, pid)] if ( requested_software_version is None - or update["softwareVersion"] == requested_software_version - or update["softwareVersionString"] == requested_software_version + or local_update["softwareVersion"] == requested_software_version + or local_update["softwareVersionString"] == requested_software_version ): - return update + return UpdateSource.LOCAL, local_update - return await dcl.check_for_update( + if dcl_update := await dcl.check_for_update( logger, vid, pid, current_software_version, requested_software_version - ) + ): + return UpdateSource.MAIN_NET_DCL, dcl_update + return None, None