Skip to content

Commit

Permalink
Implement update using OTA Provider app
Browse files Browse the repository at this point in the history
Use the OTA Provider example app to implement a OTA provider. The
example app supports a JSON update descriptor file to manage update
metadata. Tested with the OTA Requestor app.
  • Loading branch information
agners committed May 17, 2024
1 parent a90afb7 commit 4e75b77
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 30 deletions.
52 changes: 42 additions & 10 deletions matter_server/server/device_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ async def stop(self) -> None:
for sub in self._subscriptions.values():
await self._call_sdk(sub.Shutdown)
self._subscriptions = {}
# shutdown the OTA Provider
if self._ota_provider:
await self._ota_provider.stop()
# shutdown the sdk device controller
await self._call_sdk(self.chip_controller.Shutdown)
LOGGER.debug("Stopped.")
Expand Down Expand Up @@ -921,6 +924,13 @@ async def update_node(self, node_id: int) -> dict | None:
node_logger = LOGGER.getChild(f"node_{node_id}")
node = self._nodes[node_id]

if self.chip_controller is None:
raise RuntimeError("Device Controller not initialized.")

if not self._ota_provider:
LOGGER.warning("No OTA provider found, updates not possible.")
return None

node_logger.debug("Check for updates.")
vid = cast(int, node.attributes.get(BASIC_INFORMATION_VENDOR_ID))
pid = cast(int, node.attributes.get(BASIC_INFORMATION_PRODUCT_ID))
Expand All @@ -932,17 +942,39 @@ async def update_node(self, node_id: int) -> dict | None:
)

update = await check_updates(node_id, vid, pid, software_version)
if update and "otaUrl" in update and len(update["otaUrl"]) > 0:
node_logger.info(
"New software update found: %s (current %s). Preparing updates...",
update["softwareVersionString"],
software_version_string,
)
if not update:
node_logger.info("No new update found.")
return None

if "otaUrl" not in update:
node_logger.warning("Update found, but no OTA URL provided.")
return None

# Add to OTA provider
if not self._ota_provider:
return None
await self._ota_provider.download_update(update)
node_logger.info(
"New software update found: %s (current %s). Preparing updates...",
update["softwareVersionString"],
software_version_string,
)

# Add to OTA provider
await self._ota_provider.download_update(update)

self._ota_provider.start()

# Wait for OTA provider to be ready
# TODO: Detect when OTA provider is ready
await asyncio.sleep(2)

await self.chip_controller.SendCommand(
nodeid=node_id,
endpoint=0,
payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider(
providerNodeID=32,
vendorID=0, # TODO: Use Server Vendor ID
announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable,
endpoint=0,
),
)

return update

Expand Down
4 changes: 4 additions & 0 deletions matter_server/server/ota/dcl.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,13 @@ async def check_updates(

software_versions: list[int] = versions["modelVersions"]["softwareVersions"]
latest_software_version = max(software_versions)

# Check if the software is indeed newer
if latest_software_version <= current_software_version:
return None

# TODO: Check minApplicableSoftwareVersion/maxApplicableSoftwareVersion

version: dict = await get_software_version(
node_id, vid, pid, latest_software_version
)
Expand Down
85 changes: 65 additions & 20 deletions matter_server/server/ota/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@

import asyncio
from dataclasses import asdict, dataclass
import functools
import json
import logging
from pathlib import Path
from typing import Final
from typing import TYPE_CHECKING, Final
from urllib.parse import unquote, urlparse

from aiohttp import ClientError, ClientSession

from matter_server.common.helpers.util import dataclass_from_dict

if TYPE_CHECKING:
from asyncio.subprocess import Process

LOGGER = logging.getLogger(__name__)

DEFAULT_UPDATES_PATH: Final[Path] = Path("updates")
Expand Down Expand Up @@ -48,10 +52,42 @@ class ExternalOtaProvider:

def __init__(self) -> None:
"""Initialize the OTA provider."""
self._ota_provider_proc: Process | None = None
self._ota_provider_task: asyncio.Task | None = None

async def _start_ota_provider(self) -> None:
# TODO: Randomize discriminator
ota_provider_cmd = [
"chip-ota-provider-app",
"--discriminator",
"22",
"--secured-device-port",
"5565",
"--KVS",
"/data/chip_kvs_provider",
"--otaImageList",
str(DEFAULT_UPDATES_PATH / "updates.json"),
]

LOGGER.info("Starting OTA Provider")
self._ota_provider_proc = await asyncio.create_subprocess_exec(
*ota_provider_cmd
)

def start(self) -> None:
"""Start the OTA Provider."""

loop = asyncio.get_event_loop()
self._ota_provider_task = loop.create_task(self._start_ota_provider())

async def stop(self) -> None:
"""Stop the OTA Provider."""
if self._ota_provider_proc:
LOGGER.info("Terminating OTA Provider")
self._ota_provider_proc.terminate()
if self._ota_provider_task:
await self._ota_provider_task

async def add_update(self, update_desc: dict, ota_file: Path) -> None:
"""Add update to the OTA provider."""

Expand All @@ -73,24 +109,25 @@ def _read_update_json(update_json_path: Path) -> None | UpdateFile:
if not update_file:
update_file = UpdateFile(deviceSoftwareVersionModel=[])

local_ota_url = str(ota_file)
for i, device_software in enumerate(update_file.deviceSoftwareVersionModel):
if device_software.otaURL == local_ota_url:
LOGGER.debug("Device software entry exists already, replacing!")
del update_file.deviceSoftwareVersionModel[i]

# Convert to OTA Requestor descriptor file
update_file.deviceSoftwareVersionModel.append(
DeviceSoftwareVersionModel(
vendorId=update_desc["vid"],
productId=update_desc["pid"],
softwareVersion=update_desc["softwareVersion"],
softwareVersionString=update_desc["softwareVersionString"],
cDVersionNumber=update_desc["cdVersionNumber"],
softwareVersionValid=update_desc["softwareVersionValid"],
minApplicableSoftwareVersion=update_desc[
"minApplicableSoftwareVersion"
],
maxApplicableSoftwareVersion=update_desc[
"maxApplicableSoftwareVersion"
],
otaURL=str(ota_file),
)
new_device_software = DeviceSoftwareVersionModel(
vendorId=update_desc["vid"],
productId=update_desc["pid"],
softwareVersion=update_desc["softwareVersion"],
softwareVersionString=update_desc["softwareVersionString"],
cDVersionNumber=update_desc["cdVersionNumber"],
softwareVersionValid=update_desc["softwareVersionValid"],
minApplicableSoftwareVersion=update_desc["minApplicableSoftwareVersion"],
maxApplicableSoftwareVersion=update_desc["maxApplicableSoftwareVersion"],
otaURL=local_ota_url,
)
update_file.deviceSoftwareVersionModel.append(new_device_software)

def _write_update_json(update_json_path: Path, update_file: UpdateFile) -> None:
update_file_dict = asdict(update_file)
Expand All @@ -112,9 +149,14 @@ async def download_update(self, update_desc: dict) -> None:
file_name = unquote(Path(parsed_url.path).name)

loop = asyncio.get_running_loop()
await loop.run_in_executor(None, DEFAULT_UPDATES_PATH.mkdir)
await loop.run_in_executor(
None, functools.partial(DEFAULT_UPDATES_PATH.mkdir, exists_ok=True)
)

file_path = DEFAULT_UPDATES_PATH / file_name
if await loop.run_in_executor(None, file_path.exists):
LOGGER.info("File '%s' exists already, skipping download.", file_name)
return

try:
async with ClientSession(raise_for_status=True) as session:
Expand All @@ -123,10 +165,13 @@ async def download_update(self, update_desc: dict) -> None:
async with session.get(url) as response:
with file_path.open("wb") as f:
while True:
chunk = await response.content.read(1024)
chunk = await response.content.read(4048)
if not chunk:
break
f.write(chunk)
await loop.run_in_executor(None, f.write, chunk)

# TODO: Check against otaChecksum

LOGGER.info(
"File '%s' downloaded to '%s'", file_name, DEFAULT_UPDATES_PATH
)
Expand Down

0 comments on commit 4e75b77

Please sign in to comment.