diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index f2b34880..d2885b71 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -27,7 +27,7 @@ from matter_server.common.models import CommissionableNodeData, CommissioningParameters from matter_server.server.helpers.attributes import parse_attributes_from_read_result from matter_server.server.helpers.utils import ping_ip -from matter_server.server.ota.dcl import check_updates +from matter_server.server.ota.dcl import check_for_update from matter_server.server.ota.provider import ExternalOtaProvider from ..common.errors import ( @@ -971,7 +971,7 @@ async def update_node(self, node_id: int) -> dict | None: BASIC_INFORMATION_SOFTWARE_VERSION_STRING_ATTRIBUTE_PATH ) - update = await check_updates(node_id, vid, pid, software_version) + update = await check_for_update(vid, pid, software_version) if not update: node_logger.info("No new update found.") return None diff --git a/matter_server/server/ota/dcl.py b/matter_server/server/ota/dcl.py index 3ce0eace..29411a8d 100644 --- a/matter_server/server/ota/dcl.py +++ b/matter_server/server/ota/dcl.py @@ -1,7 +1,7 @@ """Handle OTA software version endpoints of the DCL.""" import logging -from typing import Any +from typing import Any, cast from aiohttp import ClientError, ClientSession @@ -10,7 +10,7 @@ LOGGER = logging.getLogger(__name__) -async def get_software_versions(node_id: int, vid: int, pid: int) -> Any: +async def get_software_versions(vid: int, pid: int) -> Any: """Check DCL if there are updates available for a particular node.""" async with ClientSession(raise_for_status=True) as http_session: # fetch the paa certificates list @@ -20,9 +20,7 @@ async def get_software_versions(node_id: int, vid: int, pid: int) -> Any: return await response.json() -async def get_software_version( - node_id: int, vid: int, pid: int, software_version: int -) -> Any: +async def get_software_version(vid: int, pid: int, software_version: int) -> Any: """Check DCL if there are updates available for a particular node.""" async with ClientSession(raise_for_status=True) as http_session: # fetch the paa certificates list @@ -32,27 +30,47 @@ async def get_software_version( return await response.json() -async def check_updates( - node_id: int, vid: int, pid: int, current_software_version: int +async def check_for_update( + vid: int, pid: int, current_software_version: int ) -> None | dict: """Check if there is a newer software version available on the DCL.""" try: - versions = await get_software_versions(node_id, vid, pid) + versions = await get_software_versions(vid, pid) - software_versions: list[int] = versions["modelVersions"]["softwareVersions"] - latest_software_version = max(software_versions) - if latest_software_version <= current_software_version: + all_software_versions: list[int] = versions["modelVersions"]["softwareVersions"] + newer_software_versions = [ + version + for version in all_software_versions + if version > current_software_version + ] + + # Check if there is a newer software version available + if not newer_software_versions: + LOGGER.info("No newer software version available.") return None - version: dict = await get_software_version( - node_id, vid, pid, latest_software_version - ) - if isinstance(version, dict) and "modelVersion" in version: - result: Any = version["modelVersion"] - if isinstance(result, dict): - return result + # Check if latest firmware is applicable, and backtrack from there + for version in sorted(newer_software_versions, reverse=True): + version_res: dict = await get_software_version(vid, pid, version) + if not isinstance(version_res, dict): + raise TypeError("Unexpected DCL response.") + + if "modelVersion" not in version_res: + raise ValueError("Unexpected DCL response.") + + version_candidate: dict = cast(dict, version_res["modelVersion"]) + + # Check minApplicableSoftwareVersion/maxApplicableSoftwareVersion + min_sw_version = version_candidate["minApplicableSoftwareVersion"] + max_sw_version = version_candidate["maxApplicableSoftwareVersion"] + if ( + current_software_version < min_sw_version + or current_software_version > max_sw_version + ): + LOGGER.debug("Software version %d not applicable.", version) + continue - logging.error("Unexpected DCL response.") + return version_candidate return None except (ClientError, TimeoutError) as err: diff --git a/tests/server/ota/test_dcl.py b/tests/server/ota/test_dcl.py new file mode 100644 index 00000000..7de82a4e --- /dev/null +++ b/tests/server/ota/test_dcl.py @@ -0,0 +1,75 @@ +"""Test DCL OTA updates.""" + +from unittest.mock import AsyncMock, patch + +from matter_server.server.ota.dcl import check_for_update + +# Mock the DCL responses (sample from https://on.dcl.csa-iot.org/dcl/model/versions/4447/8194) +DCL_RESPONSE_SOFTWARE_VERSIONS = { + "modelVersions": { + "vid": 4447, + "pid": 8194, + "softwareVersions": [1000, 1011], + } +} + +# Mock the DCL responses (sample from https://on.dcl.csa-iot.org/dcl/model/versions/4447/8194/1011) +DCL_RESPONSE_SOFTWARE_VERSION_1011 = { + "modelVersion": { + "vid": 4447, + "pid": 8194, + "softwareVersion": 1011, + "softwareVersionString": "1.0.1.1", + "cdVersionNumber": 1, + "firmwareInformation": "", + "softwareVersionValid": True, + "otaUrl": "https://cdn.aqara.com/cdn/opencloud-product/mainland/product-firmware/prd/aqara.matter.4447_8194/20240306154144_rel_up_to_enc_ota_sbl_app_aqara.matter.4447_8194_1.0.1.1_115F_2002_20240115195007_7a9b91.ota", + "otaFileSize": "615708", + "otaChecksum": "rFZ6WdH0DuuCf7HVoRmNjCF73mYZ98DGYpHoDKmf0Bw=", + "otaChecksumType": 1, + "minApplicableSoftwareVersion": 1000, + "maxApplicableSoftwareVersion": 1010, + "releaseNotesUrl": "", + "creator": "cosmos1qpz3ghnqj6my7gzegkftzav9hpxymkx6zdk73v", + } +} + + +async def test_check_updates(): + """Test the case where the latest software version is applicable.""" + with ( + patch( + "matter_server.server.ota.dcl.get_software_versions", + new_callable=AsyncMock, + return_value=DCL_RESPONSE_SOFTWARE_VERSIONS, + ), + patch( + "matter_server.server.ota.dcl.get_software_version", + new_callable=AsyncMock, + return_value=DCL_RESPONSE_SOFTWARE_VERSION_1011, + ), + ): + # Call the function with a current software version of 1000 + result = await check_for_update(4447, 8194, 1000) + + assert result == DCL_RESPONSE_SOFTWARE_VERSION_1011["modelVersion"] + + +async def test_check_updates_not_applicable(): + """Test the case where the latest software version is not applicable.""" + with ( + patch( + "matter_server.server.ota.dcl.get_software_versions", + new_callable=AsyncMock, + return_value=DCL_RESPONSE_SOFTWARE_VERSIONS, + ), + patch( + "matter_server.server.ota.dcl.get_software_version", + new_callable=AsyncMock, + return_value=DCL_RESPONSE_SOFTWARE_VERSION_1011, + ), + ): + # Call the function with a current software version of 1 + result = await check_for_update(4447, 8194, 1) + + assert result is None