diff --git a/Dockerfile b/Dockerfile index 1f911762..70426ab7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,6 +9,7 @@ RUN \ set -x \ && apt-get update \ && apt-get install -y --no-install-recommends \ + curl \ libuv1 \ zlib1g \ libjson-c5 \ @@ -25,6 +26,21 @@ RUN \ ARG PYTHON_MATTER_SERVER +ENV chip_example_url "https://github.com/home-assistant-libs/matter-linux-ota-provider/releases/download/2024.7.1" +ARG TARGETPLATFORM + +RUN \ + set -x \ + && echo "${TARGETPLATFORM}" \ + && if [ "${TARGETPLATFORM}" = "linux/amd64" ]; then \ + curl -Lo /usr/local/bin/chip-ota-provider-app "${chip_example_url}/chip-ota-provider-app-x86-64"; \ + elif [ "${TARGETPLATFORM}" = "linux/arm64" ]; then \ + curl -Lo /usr/local/bin/chip-ota-provider-app "${chip_example_url}/chip-ota-provider-app-aarch64"; \ + else \ + exit 1; \ + fi \ + && chmod +x /usr/local/bin/chip-ota-provider-app + # hadolint ignore=DL3013 RUN \ pip3 install --no-cache-dir "python-matter-server[server]==${PYTHON_MATTER_SERVER}" diff --git a/matter_server/client/client.py b/matter_server/client/client.py index 3777d15c..ec87dba7 100644 --- a/matter_server/client/client.py +++ b/matter_server/client/client.py @@ -29,6 +29,7 @@ EventType, MatterNodeData, MatterNodeEvent, + MatterSoftwareVersion, MessageType, NodePingResult, ResultMessageBase, @@ -509,6 +510,36 @@ async def interview_node(self, node_id: int) -> None: """Interview a node.""" await self.send_command(APICommand.INTERVIEW_NODE, node_id=node_id) + async def check_node_update(self, node_id: int) -> MatterSoftwareVersion | None: + """Check Node for updates. + + Return a dict with the available update information. Most notable + "softwareVersion" contains the integer value of the update version which then + can be used for the update_node command to trigger the update. + + The "softwareVersionString" is a human friendly version string. + """ + data = await self.send_command( + APICommand.CHECK_NODE_UPDATE, node_id=node_id, require_schema=10 + ) + if data is None: + return None + + return dataclass_from_dict(MatterSoftwareVersion, data) + + async def update_node( + self, + node_id: int, + software_version: int | str, + ) -> None: + """Start node update to a particular version.""" + await self.send_command( + APICommand.UPDATE_NODE, + node_id=node_id, + software_version=software_version, + require_schema=10, + ) + def _prepare_message( self, command: str, diff --git a/matter_server/common/const.py b/matter_server/common/const.py index 45cc3184..26a4737e 100644 --- a/matter_server/common/const.py +++ b/matter_server/common/const.py @@ -2,7 +2,7 @@ # schema version is used to determine compatibility between server and client # bump schema if we add new features and/or make other (breaking) changes -SCHEMA_VERSION = 9 +SCHEMA_VERSION = 10 VERBOSE_LOG_LEVEL = 5 diff --git a/matter_server/common/errors.py b/matter_server/common/errors.py index d07fc222..8ce6dd0e 100644 --- a/matter_server/common/errors.py +++ b/matter_server/common/errors.py @@ -77,6 +77,18 @@ class InvalidCommand(MatterError): error_code = 9 +class UpdateCheckError(MatterError): + """Error raised when there was an error during searching for updates.""" + + error_code = 10 + + +class UpdateError(MatterError): + """Error raised when there was an error during applying updates.""" + + error_code = 11 + + def exception_from_error_code(error_code: int) -> type[MatterError]: """Return correct Exception class from error_code.""" return ERROR_MAP.get(error_code, MatterError) diff --git a/matter_server/common/models.py b/matter_server/common/models.py index efd4049d..f049782d 100644 --- a/matter_server/common/models.py +++ b/matter_server/common/models.py @@ -47,6 +47,8 @@ class APICommand(str, Enum): PING_NODE = "ping_node" GET_NODE_IP_ADDRESSES = "get_node_ip_addresses" IMPORT_TEST_NODE = "import_test_node" + CHECK_NODE_UPDATE = "check_node_update" + UPDATE_NODE = "update_node" EventCallBackType = Callable[[EventType, Any], None] @@ -209,3 +211,59 @@ class CommissioningParameters: setup_pin_code: int setup_manual_code: str setup_qr_code: str + + +class UpdateSource(Enum): + """Enum with possible sources for a software update.""" + + MAIN_NET_DCL = "main-net-dcl" + TEST_NET_DCL = "test-net-dcl" + LOCAL = "local" + + +@dataclass +class MatterSoftwareVersion: + """Representation of a Matter software version. Return by the check_node_update command. + + This holds Matter software version information similar to what is available from the CSA DCL. + https://on.dcl.csa-iot.org/#/Query/ModelVersion. + """ + + vid: int + pid: int + software_version: int + software_version_string: str + firmware_information: str | None + min_applicable_software_version: int + max_applicable_software_version: int + release_notes_url: str | None + update_source: UpdateSource + + @classmethod + def from_dict(cls, data: dict) -> MatterSoftwareVersion: + """Initialize from dict.""" + return cls( + vid=data["vid"], + pid=data["pid"], + software_version=data["software_version"], + software_version_string=data["software_version_string"], + firmware_information=data["firmware_information"], + min_applicable_software_version=data["min_applicable_software_version"], + max_applicable_software_version=data["max_applicable_software_version"], + release_notes_url=data["release_notes_url"], + update_source=UpdateSource(data["update_source"]), + ) + + def as_dict(self) -> dict: + """Return dict representation of the object.""" + return { + "vid": self.vid, + "pid": self.pid, + "software_version": self.software_version, + "software_version_string": self.software_version_string, + "firmware_information": self.firmware_information, + "min_applicable_software_version": self.min_applicable_software_version, + "max_applicable_software_version": self.max_applicable_software_version, + "release_notes_url": self.release_notes_url, + "update_source": str(self.update_source), + } diff --git a/matter_server/server/__main__.py b/matter_server/server/__main__.py index df66597a..f12bdd44 100644 --- a/matter_server/server/__main__.py +++ b/matter_server/server/__main__.py @@ -116,6 +116,12 @@ nargs="+", help="List of node IDs to show logs from (applies only to server logs).", ) +parser.add_argument( + "--ota-provider-dir", + type=str, + default=None, + help="Directory where OTA Provider stores software updates and configuration.", +) args = parser.parse_args() @@ -216,6 +222,7 @@ def main() -> None: args.paa_root_cert_dir, args.enable_test_net_dcl, args.bluetooth_adapter, + args.ota_provider_dir, ) async def handle_stop(loop: asyncio.AbstractEventLoop) -> None: diff --git a/matter_server/server/const.py b/matter_server/server/const.py index 5afa4c1f..f411f68f 100644 --- a/matter_server/server/const.py +++ b/matter_server/server/const.py @@ -20,3 +20,5 @@ .parent.resolve() .joinpath("credentials/development/paa-root-certs") ) + +DEFAULT_OTA_PROVIDER_DIR: Final[pathlib.Path] = pathlib.Path().cwd().joinpath("updates") diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 730aec57..0ae9a299 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -28,9 +28,15 @@ from matter_server.common.const import VERBOSE_LOG_LEVEL from matter_server.common.custom_clusters import check_polled_attributes -from matter_server.common.models import CommissionableNodeData, CommissioningParameters +from matter_server.common.models import ( + CommissionableNodeData, + CommissioningParameters, + MatterSoftwareVersion, +) from matter_server.server.helpers.attributes import parse_attributes_from_read_result from matter_server.server.helpers.utils import ping_ip +from matter_server.server.ota import check_for_update, load_local_updates +from matter_server.server.ota.provider import ExternalOtaProvider from matter_server.server.sdk import ChipDeviceControllerWrapper from ..common.errors import ( @@ -40,6 +46,8 @@ NodeNotExists, NodeNotReady, NodeNotResolving, + UpdateCheckError, + UpdateError, ) from ..common.helpers.api import api_command from ..common.helpers.json import JSON_DECODE_EXCEPTIONS, json_loads @@ -55,11 +63,12 @@ MatterNodeData, MatterNodeEvent, NodePingResult, + UpdateSource, ) from .const import DATA_MODEL_SCHEMA_VERSION if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Callable, Iterable from pathlib import Path from .server import MatterServer @@ -91,11 +100,23 @@ DESCRIPTOR_PARTS_LIST_ATTRIBUTE_PATH = create_attribute_path_from_attribute( 0, Clusters.Descriptor.Attributes.PartsList ) +BASIC_INFORMATION_VENDOR_ID_ATTRIBUTE_PATH = create_attribute_path_from_attribute( + 0, Clusters.BasicInformation.Attributes.VendorID +) +BASIC_INFORMATION_PRODUCT_ID_ATTRIBUTE_PATH = create_attribute_path_from_attribute( + 0, Clusters.BasicInformation.Attributes.ProductID +) BASIC_INFORMATION_SOFTWARE_VERSION_ATTRIBUTE_PATH = ( create_attribute_path_from_attribute( 0, Clusters.BasicInformation.Attributes.SoftwareVersion ) ) +BASIC_INFORMATION_SOFTWARE_VERSION_STRING_ATTRIBUTE_PATH = ( + create_attribute_path_from_attribute( + 0, Clusters.BasicInformation.Attributes.SoftwareVersionString + ) +) + # pylint: disable=too-many-lines,too-many-instance-attributes,too-many-public-methods @@ -107,9 +128,11 @@ def __init__( self, server: MatterServer, paa_root_cert_dir: Path, + ota_provider_dir: Path, ): """Initialize the device controller.""" self.server = server + self._ota_provider_dir = ota_provider_dir self._chip_device_controller = ChipDeviceControllerWrapper( server, paa_root_cert_dir @@ -122,6 +145,7 @@ def __init__( self._wifi_credentials_set: bool = False self._thread_credentials_set: bool = False self._nodes_in_setup: set[int] = set() + self._nodes_in_ota: set[int] = set() self._node_last_seen: dict[int, float] = {} self._nodes: dict[int, MatterNodeData] = {} self._last_known_ip_addresses: dict[int, list[str]] = {} @@ -137,6 +161,7 @@ def __init__( self._polled_attributes: dict[int, set[str]] = {} self._custom_attribute_poller_timer: asyncio.TimerHandle | None = None self._custom_attribute_poller_task: asyncio.Task | None = None + self._attribute_update_callbacks: dict[int, list[Callable]] = {} async def initialize(self) -> None: """Initialize the device controller.""" @@ -144,6 +169,7 @@ async def initialize(self) -> None: await self._chip_device_controller.get_compressed_fabric_id() ) self._fabric_id_hex = hex(self._compressed_fabric_id)[2:] + await load_local_updates(self._ota_provider_dir) async def start(self) -> None: """Handle logic on controller start.""" @@ -876,6 +902,137 @@ async def import_test_node(self, dump: str) -> None: self._nodes[node.node_id] = node self.server.signal_event(EventType.NODE_ADDED, node) + @api_command(APICommand.CHECK_NODE_UPDATE) + async def check_node_update(self, node_id: int) -> MatterSoftwareVersion | None: + """ + Check if there is an update for a particular node. + + Reads the current software version and checks the DCL if there is an update + available. If there is an update available, the command returns the version + information of the latest update available. + """ + + update_source, update = await self._check_node_update(node_id) + if update_source is None or update is None: + return None + + if not all( + key in update + for key in [ + "vid", + "pid", + "softwareVersion", + "softwareVersionString", + "minApplicableSoftwareVersion", + "maxApplicableSoftwareVersion", + ] + ): + raise UpdateCheckError("Invalid update data") + + return MatterSoftwareVersion( + vid=update["vid"], + pid=update["pid"], + software_version=update["softwareVersion"], + software_version_string=update["softwareVersionString"], + firmware_information=update.get("firmwareInformation", None), + min_applicable_software_version=update["minApplicableSoftwareVersion"], + max_applicable_software_version=update["maxApplicableSoftwareVersion"], + release_notes_url=update.get("releaseNotesUrl", None), + update_source=update_source, + ) + + @api_command(APICommand.UPDATE_NODE) + async def update_node(self, node_id: int, software_version: int | str) -> None: + """ + Update a node to a new software version. + + This command checks if the requested software version is indeed still available + and if so, it will start the update process. The update process will be handled + by the built-in OTA provider. The OTA provider will download the update and + notify the node about the new update. + """ + + node_logger = self.get_node_logger(LOGGER, node_id) + node_logger.info("Update to software version %r", software_version) + + _, update = await self._check_node_update(node_id, software_version) + if update is None: + raise UpdateCheckError( + f"Software version {software_version} is not available for node {node_id}." + ) + + # Add update to the OTA provider + ota_provider = ExternalOtaProvider( + self.server.vendor_id, self._ota_provider_dir / f"{node_id}" + ) + + await ota_provider.initialize() + + node_logger.info("Downloading update from '%s'", update["otaUrl"]) + await ota_provider.download_update(update) + + self._attribute_update_callbacks.setdefault(node_id, []).append( + ota_provider.check_update_state + ) + + try: + if node_id in self._nodes_in_ota: + raise UpdateError( + f"Node {node_id} is already in the process of updating." + ) + + self._nodes_in_ota.add(node_id) + + # Make sure any previous instances get stopped + node_logger.info("Starting update using OTA Provider.") + await ota_provider.start_update( + self._chip_device_controller, + node_id, + ) + finally: + self._attribute_update_callbacks[node_id].remove( + ota_provider.check_update_state + ) + self._nodes_in_ota.remove(node_id) + + async def _check_node_update( + self, + node_id: int, + requested_software_version: int | str | None = None, + ) -> tuple[UpdateSource, dict] | tuple[None, None]: + node_logger = self.get_node_logger(LOGGER, node_id) + node = self._nodes[node_id] + + node_logger.debug("Check for updates.") + vid = cast(int, node.attributes.get(BASIC_INFORMATION_VENDOR_ID_ATTRIBUTE_PATH)) + pid = cast( + int, node.attributes.get(BASIC_INFORMATION_PRODUCT_ID_ATTRIBUTE_PATH) + ) + software_version = cast( + int, node.attributes.get(BASIC_INFORMATION_SOFTWARE_VERSION_ATTRIBUTE_PATH) + ) + software_version_string = node.attributes.get( + BASIC_INFORMATION_SOFTWARE_VERSION_STRING_ATTRIBUTE_PATH + ) + + update_source, update = await check_for_update( + node_logger, vid, pid, software_version, requested_software_version + ) + if not update_source or not update: + node_logger.info("No new update found.") + return None, None + + if "otaUrl" not in update: + raise UpdateCheckError("Update found, but no OTA URL provided.") + + node_logger.info( + "New software update found: %s on %s (current %s).", + update["softwareVersionString"], + update_source, + software_version_string, + ) + return update_source, update + async def _subscribe_node(self, node_id: int) -> None: """ Subscribe to all node state changes/events for an individual node. @@ -934,6 +1091,10 @@ def attribute_updated_callback( # schedule save to persistent storage self._write_node_state(node_id) + if node_id in self._attribute_update_callbacks: + for callback in self._attribute_update_callbacks[node_id]: + self._loop.create_task(callback(path, old_value, new_value)) + # This callback is running in the CHIP stack thread self.server.signal_event( EventType.ATTRIBUTE_UPDATED, diff --git a/matter_server/server/helpers/__init__.py b/matter_server/server/helpers/__init__.py index 8cb651ea..b80f5f6a 100644 --- a/matter_server/server/helpers/__init__.py +++ b/matter_server/server/helpers/__init__.py @@ -1 +1,4 @@ """Helpers/utils for the Matter Server.""" + +DCL_PRODUCTION_URL = "https://on.dcl.csa-iot.org" +DCL_TEST_URL = "https://on.test-net.dcl.csa-iot.org" diff --git a/matter_server/server/helpers/paa_certificates.py b/matter_server/server/helpers/paa_certificates.py index 4a459e65..27fc5875 100644 --- a/matter_server/server/helpers/paa_certificates.py +++ b/matter_server/server/helpers/paa_certificates.py @@ -19,14 +19,14 @@ from cryptography.hazmat.primitives import serialization from cryptography.utils import CryptographyDeprecationWarning +from matter_server.server.helpers import DCL_PRODUCTION_URL, DCL_TEST_URL + # Git repo details OWNER = "project-chip" REPO = "connectedhomeip" PATH = "credentials/development/paa-root-certs" LOGGER = logging.getLogger(__name__) -PRODUCTION_URL = "https://on.dcl.csa-iot.org" -TEST_URL = "https://on.test-net.dcl.csa-iot.org" GIT_URL = f"https://raw.githubusercontent.com/{OWNER}/{REPO}/master/{PATH}" @@ -226,7 +226,7 @@ def _check_paa_root_dir( fetch_count = await fetch_dcl_certificates( paa_root_cert_dir=paa_root_cert_dir, base_name="dcld_production_", - base_url=PRODUCTION_URL, + base_url=DCL_PRODUCTION_URL, ) LOGGER.info("Fetched %s PAA root certificates from DCL.", fetch_count) total_fetch_count += fetch_count @@ -235,7 +235,7 @@ def _check_paa_root_dir( fetch_count = await fetch_dcl_certificates( paa_root_cert_dir=paa_root_cert_dir, base_name="dcld_test_", - base_url=TEST_URL, + base_url=DCL_TEST_URL, ) LOGGER.info("Fetched %s PAA root certificates from Test DCL.", fetch_count) total_fetch_count += fetch_count diff --git a/matter_server/server/ota/__init__.py b/matter_server/server/ota/__init__.py new file mode 100644 index 00000000..21cf52c7 --- /dev/null +++ b/matter_server/server/ota/__init__.py @@ -0,0 +1,54 @@ +"""OTA implementation for the Matter Server.""" + +import asyncio +import json +from logging import LoggerAdapter +from pathlib import Path + +from matter_server.common.models import UpdateSource +from matter_server.server.ota import dcl + +_local_updates: dict[tuple[int, int], dict] = {} + + +async def load_local_updates(ota_provider_dir: Path) -> None: + """Load updates from locally stored json files.""" + + def _load_update(ota_provider_dir: Path) -> None: + if not ota_provider_dir.exists(): + return + for update_file in ota_provider_dir.glob("*.json"): + with open(update_file) as f: + update = json.load(f) + model_version = update["modelVersion"] + _local_updates[(model_version["vid"], model_version["pid"])] = ( + model_version + ) + + await asyncio.get_running_loop().run_in_executor( + None, _load_update, ota_provider_dir + ) + + +async def check_for_update( + logger: LoggerAdapter, + vid: int, + pid: int, + current_software_version: int, + requested_software_version: int | str | None = None, +) -> tuple[UpdateSource, dict] | tuple[None, None]: + """Check for software updates.""" + if (vid, pid) in _local_updates: + local_update = _local_updates[(vid, pid)] + if ( + requested_software_version is None + or local_update["softwareVersion"] == requested_software_version + or local_update["softwareVersionString"] == requested_software_version + ): + return UpdateSource.LOCAL, local_update + + if dcl_update := await dcl.check_for_update( + logger, vid, pid, current_software_version, requested_software_version + ): + return UpdateSource.MAIN_NET_DCL, dcl_update + return None, None diff --git a/matter_server/server/ota/dcl.py b/matter_server/server/ota/dcl.py new file mode 100644 index 00000000..d555def0 --- /dev/null +++ b/matter_server/server/ota/dcl.py @@ -0,0 +1,134 @@ +"""Handle OTA software version endpoints of the DCL.""" + +from http import HTTPStatus +import logging +from typing import Any, cast + +from aiohttp import ClientError, ClientSession + +from matter_server.common.errors import UpdateCheckError +from matter_server.server.helpers import DCL_PRODUCTION_URL + + +async def _get_software_versions(session: ClientSession, vid: int, pid: int) -> Any: + """Check DCL if there are updates available for a particular node.""" + # fetch the paa certificates list + async with session.get(f"/dcl/model/versions/{vid}/{pid}") as response: + if response.status == HTTPStatus.NOT_FOUND: + return None + response.raise_for_status() + return await response.json() + + +async def _get_software_version( + session: ClientSession, vid: int, pid: int, software_version: int +) -> Any: + """Check DCL if there are updates available for a particular node.""" + # fetch the paa certificates list + async with session.get( + f"/dcl/model/versions/{vid}/{pid}/{software_version}" + ) as response: + response.raise_for_status() + return await response.json() + + +async def _check_update_version( + session: ClientSession, + vid: int, + pid: int, + current_software_version: int, + requested_software_version: int, + requested_software_version_string: str | None = None, +) -> None | dict: + version_res: dict = await _get_software_version( + session, vid, pid, requested_software_version + ) + if not isinstance(version_res, dict): + raise TypeError("Unexpected DCL response.") + + if "modelVersion" not in version_res: + raise ValueError("Unexpected DCL response.") + + version_candidate: dict = cast(dict, version_res["modelVersion"]) + + # If we are looking for a specific version by string, check if it matches + if ( + requested_software_version_string is not None + and version_candidate["softwareVersionString"] + != requested_software_version_string + ): + return None + + # Check minApplicableSoftwareVersion/maxApplicableSoftwareVersion + min_sw_version = version_candidate["minApplicableSoftwareVersion"] + max_sw_version = version_candidate["maxApplicableSoftwareVersion"] + if ( + current_software_version < min_sw_version + or current_software_version > max_sw_version + ): + return None + + return version_candidate + + +async def check_for_update( + logger: logging.LoggerAdapter, + vid: int, + pid: int, + current_software_version: int, + requested_software_version: int | str | None = None, +) -> None | dict: + """Check if there is a software update available on the DCL.""" + try: + async with ClientSession( + base_url=DCL_PRODUCTION_URL, raise_for_status=False + ) as session: + # If a specific version as integer is requested, just fetch it (and hope it exists) + if isinstance(requested_software_version, int): + return await _check_update_version( + session, + vid, + pid, + current_software_version, + requested_software_version, + ) + + # Get all versions and check each one of them. + versions = await _get_software_versions(session, vid, pid) + if versions is None: + logger.info( + "There is no update information for this device on the DCL." + ) + return None + + all_software_versions: list[int] = versions["modelVersions"][ + "softwareVersions" + ] + newer_software_versions = [ + version + for version in all_software_versions + if version > current_software_version + ] + + # Check if there is a newer software version available, no downgrade possible + if not newer_software_versions: + return None + + # Check if latest firmware is applicable, and backtrack from there + for version in sorted(newer_software_versions, reverse=True): + if version_candidate := await _check_update_version( + session, + vid, + pid, + current_software_version, + version, + requested_software_version, + ): + return version_candidate + logger.debug("Software version %d not applicable.", version) + return None + + except (ClientError, TimeoutError) as err: + raise UpdateCheckError( + f"Fetching software versions from DCL for device with vendor id {vid} product id {pid} failed." + ) from err diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py new file mode 100644 index 00000000..cb583552 --- /dev/null +++ b/matter_server/server/ota/provider.py @@ -0,0 +1,367 @@ +"""Handling Matter OTA provider.""" + +from __future__ import annotations + +import asyncio +from base64 import b64encode +from datetime import UTC, datetime +import functools +import hashlib +import logging +from pathlib import Path +import secrets +from typing import TYPE_CHECKING, Any, Final, cast +from urllib.parse import unquote, urlparse + +from aiohttp import ClientError, ClientSession +from aiohttp.client_exceptions import InvalidURL +from chip.clusters import Attribute, Objects as Clusters, Types +from chip.discovery import FilterType +from chip.exceptions import ChipStackError +from chip.interaction_model import Status + +from matter_server.common.errors import UpdateError +from matter_server.common.helpers.util import ( + create_attribute_path_from_attribute, +) + +if TYPE_CHECKING: + from asyncio.subprocess import Process + + from matter_server.server.sdk import ChipDeviceControllerWrapper + +LOGGER = logging.getLogger(__name__) + +DEFAULT_OTA_PROVIDER_NODE_ID: Final[int] = 990000 + +OTA_SOFTWARE_UPDATE_REQUESTOR_UPDATE_STATE_ATTRIBUTE_PATH = ( + create_attribute_path_from_attribute( + 0, Clusters.OtaSoftwareUpdateRequestor.Attributes.UpdateState + ) +) + +# From Matter SDK src/app/ota_image_tool.py +CHECHKSUM_TYPES: Final[dict[int, str]] = { + 1: "sha256", + 2: "sha256_128", + 3: "sha256_120", + 4: "sha256_96", + 5: "sha256_64", + 6: "sha256_32", + 7: "sha384", + 8: "sha512", + 9: "sha3_224", + 10: "sha3_256", + 11: "sha3_384", + 12: "sha3_512", +} + + +class ExternalOtaProvider: + """Class handling Matter OTA Provider. + + The OTA Provider class implements a Matter OTA (over-the-air) update provider + for devices. + """ + + ENDPOINT_ID: Final[int] = 0 + + def __init__(self, vendor_id: int, ota_provider_dir: Path) -> None: + """Initialize the OTA provider.""" + self._vendor_id: int = vendor_id + self._ota_provider_dir: Path = ota_provider_dir + self._ota_file_path: Path | None = None + self._ota_provider_proc: Process | None = None + self._ota_provider_task: asyncio.Task | None = None + self._ota_done: asyncio.Future = asyncio.Future() + self._ota_target_node_id: int | None = None + + async def initialize(self) -> None: + """Initialize OTA Provider.""" + + loop = asyncio.get_event_loop() + + await loop.run_in_executor( + None, functools.partial(self._ota_provider_dir.mkdir, exist_ok=True) + ) + + async def _commission_ota_provider( + self, + chip_device_controller: ChipDeviceControllerWrapper, + passcode: int, + discriminator: int, + ota_provider_node_id: int, + ) -> None: + """Commissions the OTA Provider, returns node ID of the commissioned provider.""" + + # Adjust ACL of OTA Requestor such that Node peer-to-peer communication + # is allowed. + try: + commissioned_node_id = await chip_device_controller.commission_on_network( + ota_provider_node_id, + passcode, + disc_filter_type=FilterType.LONG_DISCRIMINATOR, + disc_filter=discriminator, + ) + assert commissioned_node_id == ota_provider_node_id + + LOGGER.info( + "OTA Provider App commissioned with node id %d.", + ota_provider_node_id, + ) + + read_result = cast( + Attribute.AsyncReadTransaction.ReadResponse, + await chip_device_controller.read_attribute( + ota_provider_node_id, + [(0, Clusters.AccessControl.Attributes.Acl)], + ), + ) + acl_list = cast( + list, + read_result.attributes[0][Clusters.AccessControl][ + Clusters.AccessControl.Attributes.Acl + ], + ) + + # Add new ACL entry... + acl_list.append( + Clusters.AccessControl.Structs.AccessControlEntryStruct( + fabricIndex=1, + privilege=Clusters.AccessControl.Enums.AccessControlEntryPrivilegeEnum.kOperate, + authMode=Clusters.AccessControl.Enums.AccessControlEntryAuthModeEnum.kCase, + subjects=Types.NullValue, + targets=[ + Clusters.AccessControl.Structs.AccessControlTargetStruct( + cluster=Clusters.OtaSoftwareUpdateProvider.id, + endpoint=0, + deviceType=Types.NullValue, + ) + ], + ) + ) + + # And write. This is persistent, so only need to be done after we commissioned + # the OTA Provider App. + write_result: Attribute.AttributeWriteResult = ( + await chip_device_controller.write_attribute( + ota_provider_node_id, + [(0, Clusters.AccessControl.Attributes.Acl(acl_list))], + ) + ) + if write_result[0].Status != Status.Success: + logging.error( + "Failed writing adjusted OTA Provider App ACL: Status %s.", + str(write_result[0].Status), + ) + raise UpdateError("Error while setting up OTA Provider.") + except ChipStackError as ex: + logging.exception("Failed setting up OTA Provider.", exc_info=ex) + raise UpdateError("Error while setting up OTA Provider.") from ex + + async def start_update( + self, chip_device_controller: ChipDeviceControllerWrapper, node_id: int + ) -> None: + """Start the OTA Provider and trigger the update.""" + + self._ota_target_node_id = node_id + + loop = asyncio.get_running_loop() + + ota_provider_passcode = secrets.randbelow(2**21) + ota_provider_discriminator = secrets.randbelow(2**12) + + timestamp = datetime.now(tz=UTC).strftime("%Y%m%d_%H%M%S") + ota_provider_cmd = [ + "chip-ota-provider-app", + "--passcode", + str(ota_provider_passcode), + "--discriminator", + str(ota_provider_discriminator), + "--secured-device-port", + "5540", + "--KVS", + str(self._ota_provider_dir / f"chip_kvs_ota_provider_{timestamp}"), + "--filepath", + str(self._ota_file_path), + ] + + log_file_path = self._ota_provider_dir / f"ota_provider_{timestamp}.log" + + log_file = await loop.run_in_executor(None, log_file_path.open, "w") + + try: + LOGGER.info("Starting OTA Provider") + self._ota_provider_proc = await asyncio.create_subprocess_exec( + *ota_provider_cmd, stdout=log_file, stderr=log_file + ) + + self._ota_provider_task = loop.create_task( + self._ota_provider_proc.communicate() + ) + + # Commission and prepare ephemeral OTA Provider + LOGGER.info("Commission and initialize OTA Provider") + ota_provider_node_id = ( + DEFAULT_OTA_PROVIDER_NODE_ID + self._ota_target_node_id + ) + await self._commission_ota_provider( + chip_device_controller, + ota_provider_passcode, + ota_provider_discriminator, + ota_provider_node_id, + ) + + # Notify update node about the availability of the OTA Provider. + # It will query the OTA provider and start the update. + try: + await chip_device_controller.send_command( + node_id, + endpoint_id=0, + command=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider( + providerNodeID=ota_provider_node_id, + vendorID=self._vendor_id, + announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable, + endpoint=ExternalOtaProvider.ENDPOINT_ID, + ), + ) + except ChipStackError as ex: + raise UpdateError( + "Error while announcing OTA Provider to node." + ) from ex + + LOGGER.info("Waiting for target node update state change") + await self._ota_done + LOGGER.info("OTA update finished successfully") + finally: + LOGGER.info("Cleaning up OTA provider") + await self.stop() + self._ota_target_node_id = None + + async def _reset(self) -> None: + """Reset the OTA Provider App state.""" + + def _remove_update_data(ota_provider_dir: Path) -> None: + for path in ota_provider_dir.iterdir(): + if not path.is_dir(): + path.unlink() + + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, _remove_update_data, self._ota_provider_dir) + + await self.initialize() + + async def stop(self) -> None: + """Stop the OTA Provider.""" + if self._ota_provider_proc: + LOGGER.info("Terminating OTA Provider") + loop = asyncio.get_event_loop() + try: + await loop.run_in_executor(None, self._ota_provider_proc.terminate) + except ProcessLookupError as ex: + LOGGER.warning("Stopping OTA Provider failed with error:", exc_info=ex) + if self._ota_provider_task: + await self._ota_provider_task + self._ota_provider_proc = None + self._ota_provider_task = None + + async def download_update(self, update_desc: dict) -> None: + """Download update file from OTA Path and add it to the OTA provider.""" + + url = update_desc["otaUrl"] + parsed_url = urlparse(url) + file_name = unquote(Path(parsed_url.path).name) + + loop = asyncio.get_running_loop() + + file_path = self._ota_provider_dir / file_name + + try: + checksum_alg = None + if ( + "otaChecksum" in update_desc + and "otaChecksumType" in update_desc + and update_desc["otaChecksumType"] in CHECHKSUM_TYPES + ): + checksum_alg = hashlib.new( + CHECHKSUM_TYPES[update_desc["otaChecksumType"]] + ) + else: + LOGGER.warning( + "No OTA checksum type or not supported, OTA will not be checked." + ) + + async with ClientSession(raise_for_status=True) as session: + # fetch the paa certificates list + LOGGER.debug("Download update from '%s'.", url) + async with session.get(url) as response: + with file_path.open("wb") as f: + while True: + chunk = await response.content.read(4048) + if not chunk: + break + await loop.run_in_executor(None, f.write, chunk) + if checksum_alg: + checksum_alg.update(chunk) + + # Download finished, check checksum if necessary + if checksum_alg: + checksum = b64encode(checksum_alg.digest()).decode("ascii") + if checksum != update_desc["otaChecksum"]: + LOGGER.error( + "Checksum mismatch for file '%s', expected: %s, got: %s", + file_name, + update_desc["otaChecksum"], + checksum, + ) + await loop.run_in_executor(None, file_path.unlink) + raise UpdateError("Checksum mismatch!") + + LOGGER.info( + "Update file '%s' downloaded to '%s'", + file_name, + self._ota_provider_dir, + ) + + except (InvalidURL, ClientError, TimeoutError) as err: + LOGGER.error( + "Fetching software version failed: error %s", err, exc_info=err + ) + raise UpdateError("Fetching software version failed") from err + + self._ota_file_path = file_path + + async def check_update_state( + self, + path: Attribute.AttributePath, + old_value: Any, + new_value: Any, + ) -> None: + """Check the update state of a node and take appropriate action.""" + + if str(path) != OTA_SOFTWARE_UPDATE_REQUESTOR_UPDATE_STATE_ATTRIBUTE_PATH: + return + + UpdateState = Clusters.OtaSoftwareUpdateRequestor.Enums.UpdateStateEnum # noqa: N806 + + new_update_state = UpdateState(new_value) + old_update_state = UpdateState(old_value) + + LOGGER.info( + "Update state changed from %r to %r", + old_update_state, + new_update_state, + ) + + # Update state of target node changed, check if update is done. + if new_update_state == UpdateState.kIdle: + if old_update_state == UpdateState.kQuerying: + self._ota_done.set_exception( + UpdateError("Target node did not process the update file") + ) + return + + LOGGER.info( + "Node %d update state idle, assuming done.", self._ota_target_node_id + ) + self._ota_done.set_result(None) diff --git a/matter_server/server/server.py b/matter_server/server/server.py index 63a71261..25d08da3 100644 --- a/matter_server/server/server.py +++ b/matter_server/server/server.py @@ -30,7 +30,11 @@ ServerInfoMessage, ) from ..server.client_handler import WebsocketClientHandler -from .const import DEFAULT_PAA_ROOT_CERTS_DIR, MIN_SCHEMA_VERSION +from .const import ( + DEFAULT_OTA_PROVIDER_DIR, + DEFAULT_PAA_ROOT_CERTS_DIR, + MIN_SCHEMA_VERSION, +) from .device_controller import MatterDeviceController from .stack import MatterStack from .storage import StorageController @@ -91,12 +95,12 @@ async def _handle_shutdown(app: web.Application) -> None: class MatterServer: """Serve Matter stack over WebSockets.""" - # pylint: disable=too-many-instance-attributes + # pylint: disable=too-many-instance-attributes,too-many-arguments _runner: web.AppRunner | None = None _http: MultiHostTCPSite | None = None - def __init__( + def __init__( # noqa: PLR0913 self, storage_path: str, vendor_id: int, @@ -107,6 +111,7 @@ def __init__( paa_root_cert_dir: Path | None = None, enable_test_net_dcl: bool = False, bluetooth_adapter_id: int | None = None, + ota_provider_dir: Path | None = None, ) -> None: """Initialize the Matter Server.""" self.storage_path = storage_path @@ -121,6 +126,10 @@ def __init__( self.paa_root_cert_dir = Path(paa_root_cert_dir).absolute() self.enable_test_net_dcl = enable_test_net_dcl self.bluetooth_enabled = bluetooth_adapter_id is not None + if ota_provider_dir is None: + self.ota_provider_dir = DEFAULT_OTA_PROVIDER_DIR + else: + self.ota_provider_dir = Path(ota_provider_dir).absolute() self.logger = logging.getLogger(__name__) self.app = web.Application() self.loop: asyncio.AbstractEventLoop | None = None @@ -165,7 +174,9 @@ async def start(self) -> None: # Initialize our (intermediate) device controller which keeps track # of Matter devices and their subscriptions. - self._device_controller = MatterDeviceController(self, self.paa_root_cert_dir) + self._device_controller = MatterDeviceController( + self, self.paa_root_cert_dir, self.ota_provider_dir + ) self._register_api_commands() await self._device_controller.initialize() diff --git a/tests/server/ota/test_dcl.py b/tests/server/ota/test_dcl.py new file mode 100644 index 00000000..e86cfb8b --- /dev/null +++ b/tests/server/ota/test_dcl.py @@ -0,0 +1,85 @@ +"""Test DCL OTA updates.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from matter_server.server.ota.dcl import check_for_update + +# Mock the DCL responses (sample from https://on.dcl.csa-iot.org/dcl/model/versions/4447/8194) +DCL_RESPONSE_SOFTWARE_VERSIONS = { + "modelVersions": { + "vid": 4447, + "pid": 8194, + "softwareVersions": [1000, 1011], + } +} + +# Mock the DCL responses (sample from https://on.dcl.csa-iot.org/dcl/model/versions/4447/8194/1011) +DCL_RESPONSE_SOFTWARE_VERSION_1011 = { + "modelVersion": { + "vid": 4447, + "pid": 8194, + "softwareVersion": 1011, + "softwareVersionString": "1.0.1.1", + "cdVersionNumber": 1, + "firmwareInformation": "", + "softwareVersionValid": True, + "otaUrl": "https://cdn.aqara.com/cdn/opencloud-product/mainland/product-firmware/prd/aqara.matter.4447_8194/20240306154144_rel_up_to_enc_ota_sbl_app_aqara.matter.4447_8194_1.0.1.1_115F_2002_20240115195007_7a9b91.ota", + "otaFileSize": "615708", + "otaChecksum": "rFZ6WdH0DuuCf7HVoRmNjCF73mYZ98DGYpHoDKmf0Bw=", + "otaChecksumType": 1, + "minApplicableSoftwareVersion": 1000, + "maxApplicableSoftwareVersion": 1010, + "releaseNotesUrl": "", + "creator": "cosmos1qpz3ghnqj6my7gzegkftzav9hpxymkx6zdk73v", + } +} + + +@pytest.fixture(name="get_software_versions") +def mock_get_software_versions(): + """Mock the _get_software_versions function.""" + with patch( + "matter_server.server.ota.dcl._get_software_versions", + new_callable=AsyncMock, + return_value=DCL_RESPONSE_SOFTWARE_VERSIONS, + ) as mock: + yield mock + + +@pytest.fixture(name="get_software_version") +def mock_get_software_version(): + """Mock the _get_software_version function.""" + with patch( + "matter_server.server.ota.dcl._get_software_version", + new_callable=AsyncMock, + return_value=DCL_RESPONSE_SOFTWARE_VERSION_1011, + ) as mock: + yield mock + + +async def test_check_updates(get_software_versions, get_software_version): + """Test the case where the latest software version is applicable.""" + # Call the function with a current software version of 1000 + result = await check_for_update(MagicMock(), 4447, 8194, 1000) + + assert result == DCL_RESPONSE_SOFTWARE_VERSION_1011["modelVersion"] + + +async def test_check_updates_not_applicable( + get_software_versions, get_software_version +): + """Test the case where the latest software version is not applicable.""" + # Call the function with a current software version of 1 + result = await check_for_update(MagicMock(), 4447, 8194, 1) + + assert result is None + + +async def test_check_updates_specific_version(get_software_version): + """Test the case to get a specific version.""" + # Call the function with a current software version of 1000 and request 1011 as update + result = await check_for_update(MagicMock(), 4447, 8194, 1000, 1011) + + assert result == DCL_RESPONSE_SOFTWARE_VERSION_1011["modelVersion"]