Skip to content

Commit

Permalink
Support local update file for OTA
Browse files Browse the repository at this point in the history
Support local update file for Matter OTA updates. The otaUrl must use
the file:/// URL scheme. The path is relative to the OTA provider
directory.
  • Loading branch information
agners committed Sep 12, 2024
1 parent f29cc35 commit 92404a9
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 34 deletions.
6 changes: 4 additions & 2 deletions matter_server/server/device_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,13 +972,15 @@ async def update_node(self, node_id: int, software_version: int | str) -> None:

# Add update to the OTA provider
ota_provider = ExternalOtaProvider(
self.server.vendor_id, self._ota_provider_dir / f"{node_id}"
self.server.vendor_id,
self._ota_provider_dir,
self._ota_provider_dir / f"{node_id}",
)

await ota_provider.initialize()

node_logger.info("Downloading update from '%s'", update["otaUrl"])
await ota_provider.download_update(update)
await ota_provider.fetch_update(update)

self._attribute_update_callbacks.setdefault(node_id, []).append(
ota_provider.check_update_state
Expand Down
89 changes: 57 additions & 32 deletions matter_server/server/ota/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,12 @@ class ExternalOtaProvider:

ENDPOINT_ID: Final[int] = 0

def __init__(self, vendor_id: int, ota_provider_dir: Path) -> None:
def __init__(
self, vendor_id: int, ota_provider_base_dir: Path, ota_provider_dir: Path
) -> None:
"""Initialize the OTA provider."""
self._vendor_id: int = vendor_id
self._ota_provider_base_dir: Path = ota_provider_base_dir
self._ota_provider_dir: Path = ota_provider_dir
self._ota_file_path: Path | None = None
self._ota_provider_proc: Process | None = None
Expand Down Expand Up @@ -261,10 +264,11 @@ async def stop(self) -> None:
self._ota_provider_proc = None
self._ota_provider_task = None

async def download_update(self, update_desc: dict) -> None:
"""Download update file from OTA Path and add it to the OTA provider."""
async def _download_update(
self, url: str, checksum_alg: hashlib._Hash | None
) -> Path:
"""Download update file from OTA URL."""

url = update_desc["otaUrl"]
parsed_url = urlparse(url)
file_name = unquote(Path(parsed_url.path).name)

Expand All @@ -273,20 +277,6 @@ async def download_update(self, update_desc: dict) -> None:
file_path = self._ota_provider_dir / file_name

try:
checksum_alg = None
if (
"otaChecksum" in update_desc
and "otaChecksumType" in update_desc
and update_desc["otaChecksumType"] in CHECHKSUM_TYPES
):
checksum_alg = hashlib.new(
CHECHKSUM_TYPES[update_desc["otaChecksumType"]]
)
else:
LOGGER.warning(
"No OTA checksum type or not supported, OTA will not be checked."
)

async with ClientSession(raise_for_status=True) as session:
# fetch the paa certificates list
LOGGER.debug("Download update from '%s'.", url)
Expand All @@ -300,20 +290,6 @@ async def download_update(self, update_desc: dict) -> None:
if checksum_alg:
checksum_alg.update(chunk)

# Download finished, check checksum if necessary
if checksum_alg:
checksum = b64encode(checksum_alg.digest()).decode("ascii")
checksum_expected = update_desc["otaChecksum"].strip()
if checksum != checksum_expected:
LOGGER.error(
"Checksum mismatch for file '%s', expected: '%s', got: '%s'",
file_name,
checksum_expected,
checksum,
)
await loop.run_in_executor(None, file_path.unlink)
raise UpdateError("Checksum mismatch!")

LOGGER.info(
"Update file '%s' downloaded to '%s'",
file_name,
Expand All @@ -326,6 +302,55 @@ async def download_update(self, update_desc: dict) -> None:
)
raise UpdateError("Fetching software version failed") from err

return file_path

async def fetch_update(self, update_desc: dict) -> None:
"""Fetch update file from OTA URL."""
url = update_desc["otaUrl"]
parsed_url = urlparse(url)
file_name = unquote(Path(parsed_url.path).name)

loop = asyncio.get_running_loop()

checksum_alg = None
if (
"otaChecksum" in update_desc
and "otaChecksumType" in update_desc
and update_desc["otaChecksumType"] in CHECHKSUM_TYPES
):
checksum_alg = hashlib.new(CHECHKSUM_TYPES[update_desc["otaChecksumType"]])
else:
LOGGER.warning(
"No OTA checksum type or not supported, OTA will not be checked."
)

if parsed_url.scheme in ["http", "https"]:
file_path = await self._download_update(url, checksum_alg)
elif parsed_url.scheme in ["file"]:
file_path = self._ota_provider_base_dir / Path(parsed_url.path[1:])
if not file_path.exists():
logging.warning("Local update file not found: %s", file_path)
raise UpdateError("Local update file not found")
if checksum_alg:
checksum_alg.update(
await loop.run_in_executor(None, file_path.read_bytes)
)
else:
raise UpdateError("Unsupported OTA URL scheme")

# Download finished, check checksum if necessary
if checksum_alg:
checksum_expected = update_desc["otaChecksum"].strip()
checksum = b64encode(checksum_alg.digest()).decode("ascii")
if checksum != checksum_expected:
LOGGER.error(
"Checksum mismatch for file '%s', expected: '%s', got: '%s'",
file_name,
checksum_expected,
checksum,
)
raise UpdateError("Checksum mismatch!")

self._ota_file_path = file_path

async def check_update_state(
Expand Down

0 comments on commit 92404a9

Please sign in to comment.