diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 5f730ff5..54bf15b4 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -972,13 +972,15 @@ async def update_node(self, node_id: int, software_version: int | str) -> None: # Add update to the OTA provider ota_provider = ExternalOtaProvider( - self.server.vendor_id, self._ota_provider_dir / f"{node_id}" + self.server.vendor_id, + self._ota_provider_dir, + self._ota_provider_dir / f"{node_id}", ) await ota_provider.initialize() node_logger.info("Downloading update from '%s'", update["otaUrl"]) - await ota_provider.download_update(update) + await ota_provider.fetch_update(update) self._attribute_update_callbacks.setdefault(node_id, []).append( ota_provider.check_update_state diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 6a39fab2..034a027d 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -65,9 +65,12 @@ class ExternalOtaProvider: ENDPOINT_ID: Final[int] = 0 - def __init__(self, vendor_id: int, ota_provider_dir: Path) -> None: + def __init__( + self, vendor_id: int, ota_provider_base_dir: Path, ota_provider_dir: Path + ) -> None: """Initialize the OTA provider.""" self._vendor_id: int = vendor_id + self._ota_provider_base_dir: Path = ota_provider_base_dir self._ota_provider_dir: Path = ota_provider_dir self._ota_file_path: Path | None = None self._ota_provider_proc: Process | None = None @@ -261,10 +264,11 @@ async def stop(self) -> None: self._ota_provider_proc = None self._ota_provider_task = None - async def download_update(self, update_desc: dict) -> None: - """Download update file from OTA Path and add it to the OTA provider.""" + async def _download_update( + self, url: str, checksum_alg: hashlib._Hash | None + ) -> Path: + """Download update file from OTA URL.""" - url = update_desc["otaUrl"] parsed_url = urlparse(url) file_name = unquote(Path(parsed_url.path).name) @@ -273,20 +277,6 @@ async def download_update(self, update_desc: dict) -> None: file_path = self._ota_provider_dir / file_name try: - checksum_alg = None - if ( - "otaChecksum" in update_desc - and "otaChecksumType" in update_desc - and update_desc["otaChecksumType"] in CHECHKSUM_TYPES - ): - checksum_alg = hashlib.new( - CHECHKSUM_TYPES[update_desc["otaChecksumType"]] - ) - else: - LOGGER.warning( - "No OTA checksum type or not supported, OTA will not be checked." - ) - async with ClientSession(raise_for_status=True) as session: # fetch the paa certificates list LOGGER.debug("Download update from '%s'.", url) @@ -300,20 +290,6 @@ async def download_update(self, update_desc: dict) -> None: if checksum_alg: checksum_alg.update(chunk) - # Download finished, check checksum if necessary - if checksum_alg: - checksum = b64encode(checksum_alg.digest()).decode("ascii") - checksum_expected = update_desc["otaChecksum"].strip() - if checksum != checksum_expected: - LOGGER.error( - "Checksum mismatch for file '%s', expected: '%s', got: '%s'", - file_name, - checksum_expected, - checksum, - ) - await loop.run_in_executor(None, file_path.unlink) - raise UpdateError("Checksum mismatch!") - LOGGER.info( "Update file '%s' downloaded to '%s'", file_name, @@ -326,6 +302,55 @@ async def download_update(self, update_desc: dict) -> None: ) raise UpdateError("Fetching software version failed") from err + return file_path + + async def fetch_update(self, update_desc: dict) -> None: + """Fetch update file from OTA URL.""" + url = update_desc["otaUrl"] + parsed_url = urlparse(url) + file_name = unquote(Path(parsed_url.path).name) + + loop = asyncio.get_running_loop() + + checksum_alg = None + if ( + "otaChecksum" in update_desc + and "otaChecksumType" in update_desc + and update_desc["otaChecksumType"] in CHECHKSUM_TYPES + ): + checksum_alg = hashlib.new(CHECHKSUM_TYPES[update_desc["otaChecksumType"]]) + else: + LOGGER.warning( + "No OTA checksum type or not supported, OTA will not be checked." + ) + + if parsed_url.scheme in ["http", "https"]: + file_path = await self._download_update(url, checksum_alg) + elif parsed_url.scheme in ["file"]: + file_path = self._ota_provider_base_dir / Path(parsed_url.path[1:]) + if not file_path.exists(): + logging.warning("Local update file not found: %s", file_path) + raise UpdateError("Local update file not found") + if checksum_alg: + checksum_alg.update( + await loop.run_in_executor(None, file_path.read_bytes) + ) + else: + raise UpdateError("Unsupported OTA URL scheme") + + # Download finished, check checksum if necessary + if checksum_alg: + checksum_expected = update_desc["otaChecksum"].strip() + checksum = b64encode(checksum_alg.digest()).decode("ascii") + if checksum != checksum_expected: + LOGGER.error( + "Checksum mismatch for file '%s', expected: '%s', got: '%s'", + file_name, + checksum_expected, + checksum, + ) + raise UpdateError("Checksum mismatch!") + self._ota_file_path = file_path async def check_update_state(