diff --git a/matter_server/server/__main__.py b/matter_server/server/__main__.py index 9620612a..bbf445e9 100644 --- a/matter_server/server/__main__.py +++ b/matter_server/server/__main__.py @@ -91,6 +91,12 @@ default=None, help="Primary network interface for link-local addresses (optional).", ) +parser.add_argument( + "--paa-root-cert-dir", + type=str, + default=None, + help="Directory where PAA root certificates are stored.", +) args = parser.parse_args() @@ -175,6 +181,7 @@ def main() -> None: int(args.port), args.listen_address, args.primary_interface, + args.paa_root_cert_dir, ) async def handle_stop(loop: asyncio.AbstractEventLoop) -> None: diff --git a/matter_server/server/const.py b/matter_server/server/const.py index 2a6140b0..8cca3cf4 100644 --- a/matter_server/server/const.py +++ b/matter_server/server/const.py @@ -11,11 +11,9 @@ # and a full re-interview is mandatory DATA_MODEL_SCHEMA_VERSION = 6 -# the paa-root-certs path is hardcoded in the sdk at this time -# and always uses the development subfolder -# regardless of anything you pass into instantiating the controller -# revisit this once matter 1.1 is released -PAA_ROOT_CERTS_DIR: Final[pathlib.Path] = ( +# Keep default location inherited from early version of the Python +# bindings. +DEFAULT_PAA_ROOT_CERTS_DIR: Final[pathlib.Path] = ( pathlib.Path(__file__) .parent.resolve() .parent.resolve() diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index e49caf48..d18006b0 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -9,6 +9,7 @@ from datetime import datetime from functools import partial import logging +from pathlib import Path from random import randint import time from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, cast @@ -48,7 +49,7 @@ MatterNodeEvent, NodePingResult, ) -from .const import DATA_MODEL_SCHEMA_VERSION, PAA_ROOT_CERTS_DIR +from .const import DATA_MODEL_SCHEMA_VERSION from .helpers.paa_certificates import fetch_certificates if TYPE_CHECKING: @@ -117,15 +118,15 @@ def __init__( self._mdns_event_timer: dict[str, asyncio.TimerHandle] = {} self._node_lock: dict[int, asyncio.Lock] = {} - async def initialize(self) -> None: + async def initialize(self, paa_root_cert_dir: Path) -> None: """Async initialize of controller.""" # (re)fetch all PAA certificates once at startup # NOTE: this must be done before initializing the controller - await fetch_certificates() + await fetch_certificates(paa_root_cert_dir) # Instantiate the underlying ChipDeviceController instance on the Fabric self.chip_controller = self.server.stack.fabric_admin.NewController( - paaTrustStorePath=str(PAA_ROOT_CERTS_DIR) + paaTrustStorePath=str(paa_root_cert_dir) ) self.compressed_fabric_id = cast( int, await self._call_sdk(self.chip_controller.GetCompressedFabricId) diff --git a/matter_server/server/helpers/paa_certificates.py b/matter_server/server/helpers/paa_certificates.py index d186be18..e5308380 100644 --- a/matter_server/server/helpers/paa_certificates.py +++ b/matter_server/server/helpers/paa_certificates.py @@ -11,14 +11,13 @@ from datetime import UTC, datetime, timedelta import logging from os import makedirs +from pathlib import Path import re from aiohttp import ClientError, ClientSession from cryptography import x509 from cryptography.hazmat.primitives import serialization -from matter_server.server.const import PAA_ROOT_CERTS_DIR - # Git repo details OWNER = "project-chip" REPO = "connectedhomeip" @@ -33,14 +32,16 @@ LAST_CERT_IDS: set[str] = set() -async def write_paa_root_cert(certificate: str, subject: str) -> None: +async def write_paa_root_cert( + paa_root_cert_dir: Path, certificate: str, subject: str +) -> None: """Write certificate from string to file.""" def _write() -> None: filename_base = "dcld_mirror_" + re.sub( "[^a-zA-Z0-9_-]", "", re.sub("[=, ]", "_", subject) ) - filepath_base = PAA_ROOT_CERTS_DIR.joinpath(filename_base) + filepath_base = paa_root_cert_dir.joinpath(filename_base) # handle PEM certificate file file_path_pem = f"{filepath_base}.pem" LOGGER.debug("Writing certificate %s", file_path_pem) @@ -58,6 +59,7 @@ def _write() -> None: async def fetch_dcl_certificates( + paa_root_cert_dir: Path, fetch_test_certificates: bool = True, fetch_production_certificates: bool = True, ) -> int: @@ -99,6 +101,7 @@ async def fetch_dcl_certificates( certificate = certificate.rstrip("\n") await write_paa_root_cert( + paa_root_cert_dir, certificate, subject, ) @@ -119,7 +122,7 @@ async def fetch_dcl_certificates( # are correctly captured -async def fetch_git_certificates() -> int: +async def fetch_git_certificates(paa_root_cert_dir: Path) -> int: """Fetch Git PAA Certificates.""" fetch_count = 0 LOGGER.info("Fetching the latest PAA root certificates from Git.") @@ -137,7 +140,7 @@ async def fetch_git_certificates() -> int: continue async with http_session.get(f"{GIT_URL}/{cert}.pem") as response: certificate = await response.text() - await write_paa_root_cert(certificate, cert) + await write_paa_root_cert(paa_root_cert_dir, certificate, cert) LAST_CERT_IDS.add(cert) fetch_count += 1 except (ClientError, TimeoutError) as err: @@ -150,24 +153,18 @@ async def fetch_git_certificates() -> int: return fetch_count -async def _get_certificate_age() -> datetime: - """Get last time PAA Certificates have been fetched.""" - loop = asyncio.get_running_loop() - stat = await loop.run_in_executor(None, PAA_ROOT_CERTS_DIR.stat) - return datetime.fromtimestamp(stat.st_mtime, tz=UTC) - - async def fetch_certificates( + paa_root_cert_dir: Path, fetch_test_certificates: bool = True, fetch_production_certificates: bool = True, ) -> int: """Fetch PAA Certificates.""" loop = asyncio.get_running_loop() - if not PAA_ROOT_CERTS_DIR.is_dir(): - await loop.run_in_executor(None, makedirs, PAA_ROOT_CERTS_DIR) + if not paa_root_cert_dir.is_dir(): + await loop.run_in_executor(None, makedirs, paa_root_cert_dir) else: - stat = await loop.run_in_executor(None, PAA_ROOT_CERTS_DIR.stat) + stat = await loop.run_in_executor(None, paa_root_cert_dir.stat) last_fetch = datetime.fromtimestamp(stat.st_mtime, tz=UTC) if last_fetch > datetime.now(tz=UTC) - timedelta(days=1): LOGGER.info( @@ -176,13 +173,14 @@ async def fetch_certificates( return 0 fetch_count = await fetch_dcl_certificates( + paa_root_cert_dir=paa_root_cert_dir, fetch_test_certificates=fetch_test_certificates, fetch_production_certificates=fetch_production_certificates, ) if fetch_test_certificates: - fetch_count += await fetch_git_certificates() + fetch_count += await fetch_git_certificates(paa_root_cert_dir) - await loop.run_in_executor(None, PAA_ROOT_CERTS_DIR.touch) + await loop.run_in_executor(None, paa_root_cert_dir.touch) return fetch_count diff --git a/matter_server/server/server.py b/matter_server/server/server.py index 5b29aeb8..87456b09 100644 --- a/matter_server/server/server.py +++ b/matter_server/server/server.py @@ -29,7 +29,7 @@ ServerInfoMessage, ) from ..server.client_handler import WebsocketClientHandler -from .const import MIN_SCHEMA_VERSION +from .const import DEFAULT_PAA_ROOT_CERTS_DIR, MIN_SCHEMA_VERSION from .device_controller import MatterDeviceController from .stack import MatterStack from .storage import StorageController @@ -100,6 +100,7 @@ def __init__( port: int, listen_addresses: list[str] | None = None, primary_interface: str | None = None, + paa_root_cert_dir: Path | None = None, ) -> None: """Initialize the Matter Server.""" self.storage_path = storage_path @@ -108,6 +109,10 @@ def __init__( self.port = port self.listen_addresses = listen_addresses self.primary_interface = primary_interface + if paa_root_cert_dir is None: + self.paa_root_cert_dir = DEFAULT_PAA_ROOT_CERTS_DIR + else: + self.paa_root_cert_dir = Path(paa_root_cert_dir).absolute() self.logger = logging.getLogger(__name__) self.app = web.Application() self.loop: asyncio.AbstractEventLoop | None = None @@ -134,7 +139,7 @@ async def start(self) -> None: self.loop = asyncio.get_running_loop() self.loop.set_exception_handler(_global_loop_exception_handler) self.loop.set_debug(os.environ.get("PYTHONDEBUG", "") != "") - await self.device_controller.initialize() + await self.device_controller.initialize(self.paa_root_cert_dir) await self.storage.start() await self.device_controller.start() await self.vendor_info.start()