Skip to content

Commit

Permalink
Support local update file for OTA (#884)
Browse files Browse the repository at this point in the history
  • Loading branch information
agners authored Sep 12, 2024
1 parent 0af331d commit fcf6229
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 34 deletions.
6 changes: 4 additions & 2 deletions matter_server/server/device_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,13 +972,15 @@ async def update_node(self, node_id: int, software_version: int | str) -> None:

# Add update to the OTA provider
ota_provider = ExternalOtaProvider(
self.server.vendor_id, self._ota_provider_dir / f"{node_id}"
self.server.vendor_id,
self._ota_provider_dir,
self._ota_provider_dir / f"{node_id}",
)

await ota_provider.initialize()

node_logger.info("Downloading update from '%s'", update["otaUrl"])
await ota_provider.download_update(update)
await ota_provider.fetch_update(update)

self._attribute_update_callbacks.setdefault(node_id, []).append(
ota_provider.check_update_state
Expand Down
89 changes: 57 additions & 32 deletions matter_server/server/ota/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,12 @@ class ExternalOtaProvider:

ENDPOINT_ID: Final[int] = 0

def __init__(self, vendor_id: int, ota_provider_dir: Path) -> None:
def __init__(
self, vendor_id: int, ota_provider_base_dir: Path, ota_provider_dir: Path
) -> None:
"""Initialize the OTA provider."""
self._vendor_id: int = vendor_id
self._ota_provider_base_dir: Path = ota_provider_base_dir
self._ota_provider_dir: Path = ota_provider_dir
self._ota_file_path: Path | None = None
self._ota_provider_proc: Process | None = None
Expand Down Expand Up @@ -261,10 +264,11 @@ async def stop(self) -> None:
self._ota_provider_proc = None
self._ota_provider_task = None

async def download_update(self, update_desc: dict) -> None:
"""Download update file from OTA Path and add it to the OTA provider."""
async def _download_update(
self, url: str, checksum_alg: hashlib._Hash | None
) -> Path:
"""Download update file from OTA URL."""

url = update_desc["otaUrl"]
parsed_url = urlparse(url)
file_name = unquote(Path(parsed_url.path).name)

Expand All @@ -273,20 +277,6 @@ async def download_update(self, update_desc: dict) -> None:
file_path = self._ota_provider_dir / file_name

try:
checksum_alg = None
if (
"otaChecksum" in update_desc
and "otaChecksumType" in update_desc
and update_desc["otaChecksumType"] in CHECHKSUM_TYPES
):
checksum_alg = hashlib.new(
CHECHKSUM_TYPES[update_desc["otaChecksumType"]]
)
else:
LOGGER.warning(
"No OTA checksum type or not supported, OTA will not be checked."
)

async with ClientSession(raise_for_status=True) as session:
# fetch the paa certificates list
LOGGER.debug("Download update from '%s'.", url)
Expand All @@ -300,20 +290,6 @@ async def download_update(self, update_desc: dict) -> None:
if checksum_alg:
checksum_alg.update(chunk)

# Download finished, check checksum if necessary
if checksum_alg:
checksum = b64encode(checksum_alg.digest()).decode("ascii")
checksum_expected = update_desc["otaChecksum"].strip()
if checksum != checksum_expected:
LOGGER.error(
"Checksum mismatch for file '%s', expected: '%s', got: '%s'",
file_name,
checksum_expected,
checksum,
)
await loop.run_in_executor(None, file_path.unlink)
raise UpdateError("Checksum mismatch!")

LOGGER.info(
"Update file '%s' downloaded to '%s'",
file_name,
Expand All @@ -326,6 +302,55 @@ async def download_update(self, update_desc: dict) -> None:
)
raise UpdateError("Fetching software version failed") from err

return file_path

async def fetch_update(self, update_desc: dict) -> None:
"""Fetch update file from OTA URL."""
url = update_desc["otaUrl"]
parsed_url = urlparse(url)
file_name = unquote(Path(parsed_url.path).name)

loop = asyncio.get_running_loop()

checksum_alg = None
if (
"otaChecksum" in update_desc
and "otaChecksumType" in update_desc
and update_desc["otaChecksumType"] in CHECHKSUM_TYPES
):
checksum_alg = hashlib.new(CHECHKSUM_TYPES[update_desc["otaChecksumType"]])
else:
LOGGER.warning(
"No OTA checksum type or not supported, OTA will not be checked."
)

if parsed_url.scheme in ["http", "https"]:
file_path = await self._download_update(url, checksum_alg)
elif parsed_url.scheme in ["file"]:
file_path = self._ota_provider_base_dir / Path(parsed_url.path[1:])
if not file_path.exists():
logging.warning("Local update file not found: %s", file_path)
raise UpdateError("Local update file not found")
if checksum_alg:
checksum_alg.update(
await loop.run_in_executor(None, file_path.read_bytes)
)
else:
raise UpdateError("Unsupported OTA URL scheme")

# Download finished, check checksum if necessary
if checksum_alg:
checksum_expected = update_desc["otaChecksum"].strip()
checksum = b64encode(checksum_alg.digest()).decode("ascii")
if checksum != checksum_expected:
LOGGER.error(
"Checksum mismatch for file '%s', expected: '%s', got: '%s'",
file_name,
checksum_expected,
checksum,
)
raise UpdateError("Checksum mismatch!")

self._ota_file_path = file_path

async def check_update_state(
Expand Down

0 comments on commit fcf6229

Please sign in to comment.