Skip to content

Commit

Permalink
Add server argument to specify custom PAA root certificate location (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
agners authored Mar 7, 2024
1 parent 39c5906 commit 9950f8b
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 29 deletions.
7 changes: 7 additions & 0 deletions matter_server/server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@
default=None,
help="Primary network interface for link-local addresses (optional).",
)
parser.add_argument(
"--paa-root-cert-dir",
type=str,
default=None,
help="Directory where PAA root certificates are stored.",
)

args = parser.parse_args()

Expand Down Expand Up @@ -175,6 +181,7 @@ def main() -> None:
int(args.port),
args.listen_address,
args.primary_interface,
args.paa_root_cert_dir,
)

async def handle_stop(loop: asyncio.AbstractEventLoop) -> None:
Expand Down
8 changes: 3 additions & 5 deletions matter_server/server/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
# and a full re-interview is mandatory
DATA_MODEL_SCHEMA_VERSION = 6

# the paa-root-certs path is hardcoded in the sdk at this time
# and always uses the development subfolder
# regardless of anything you pass into instantiating the controller
# revisit this once matter 1.1 is released
PAA_ROOT_CERTS_DIR: Final[pathlib.Path] = (
# Keep default location inherited from early version of the Python
# bindings.
DEFAULT_PAA_ROOT_CERTS_DIR: Final[pathlib.Path] = (
pathlib.Path(__file__)
.parent.resolve()
.parent.resolve()
Expand Down
9 changes: 5 additions & 4 deletions matter_server/server/device_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from datetime import datetime
from functools import partial
import logging
from pathlib import Path
from random import randint
import time
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, cast
Expand Down Expand Up @@ -48,7 +49,7 @@
MatterNodeEvent,
NodePingResult,
)
from .const import DATA_MODEL_SCHEMA_VERSION, PAA_ROOT_CERTS_DIR
from .const import DATA_MODEL_SCHEMA_VERSION
from .helpers.paa_certificates import fetch_certificates

if TYPE_CHECKING:
Expand Down Expand Up @@ -117,15 +118,15 @@ def __init__(
self._mdns_event_timer: dict[str, asyncio.TimerHandle] = {}
self._node_lock: dict[int, asyncio.Lock] = {}

async def initialize(self) -> None:
async def initialize(self, paa_root_cert_dir: Path) -> None:
"""Async initialize of controller."""
# (re)fetch all PAA certificates once at startup
# NOTE: this must be done before initializing the controller
await fetch_certificates()
await fetch_certificates(paa_root_cert_dir)

# Instantiate the underlying ChipDeviceController instance on the Fabric
self.chip_controller = self.server.stack.fabric_admin.NewController(
paaTrustStorePath=str(PAA_ROOT_CERTS_DIR)
paaTrustStorePath=str(paa_root_cert_dir)
)
self.compressed_fabric_id = cast(
int, await self._call_sdk(self.chip_controller.GetCompressedFabricId)
Expand Down
34 changes: 16 additions & 18 deletions matter_server/server/helpers/paa_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
from datetime import UTC, datetime, timedelta
import logging
from os import makedirs
from pathlib import Path
import re

from aiohttp import ClientError, ClientSession
from cryptography import x509
from cryptography.hazmat.primitives import serialization

from matter_server.server.const import PAA_ROOT_CERTS_DIR

# Git repo details
OWNER = "project-chip"
REPO = "connectedhomeip"
Expand All @@ -33,14 +32,16 @@
LAST_CERT_IDS: set[str] = set()


async def write_paa_root_cert(certificate: str, subject: str) -> None:
async def write_paa_root_cert(
paa_root_cert_dir: Path, certificate: str, subject: str
) -> None:
"""Write certificate from string to file."""

def _write() -> None:
filename_base = "dcld_mirror_" + re.sub(
"[^a-zA-Z0-9_-]", "", re.sub("[=, ]", "_", subject)
)
filepath_base = PAA_ROOT_CERTS_DIR.joinpath(filename_base)
filepath_base = paa_root_cert_dir.joinpath(filename_base)
# handle PEM certificate file
file_path_pem = f"{filepath_base}.pem"
LOGGER.debug("Writing certificate %s", file_path_pem)
Expand All @@ -58,6 +59,7 @@ def _write() -> None:


async def fetch_dcl_certificates(
paa_root_cert_dir: Path,
fetch_test_certificates: bool = True,
fetch_production_certificates: bool = True,
) -> int:
Expand Down Expand Up @@ -99,6 +101,7 @@ async def fetch_dcl_certificates(
certificate = certificate.rstrip("\n")

await write_paa_root_cert(
paa_root_cert_dir,
certificate,
subject,
)
Expand All @@ -119,7 +122,7 @@ async def fetch_dcl_certificates(
# are correctly captured


async def fetch_git_certificates() -> int:
async def fetch_git_certificates(paa_root_cert_dir: Path) -> int:
"""Fetch Git PAA Certificates."""
fetch_count = 0
LOGGER.info("Fetching the latest PAA root certificates from Git.")
Expand All @@ -137,7 +140,7 @@ async def fetch_git_certificates() -> int:
continue
async with http_session.get(f"{GIT_URL}/{cert}.pem") as response:
certificate = await response.text()
await write_paa_root_cert(certificate, cert)
await write_paa_root_cert(paa_root_cert_dir, certificate, cert)
LAST_CERT_IDS.add(cert)
fetch_count += 1
except (ClientError, TimeoutError) as err:
Expand All @@ -150,24 +153,18 @@ async def fetch_git_certificates() -> int:
return fetch_count


async def _get_certificate_age() -> datetime:
"""Get last time PAA Certificates have been fetched."""
loop = asyncio.get_running_loop()
stat = await loop.run_in_executor(None, PAA_ROOT_CERTS_DIR.stat)
return datetime.fromtimestamp(stat.st_mtime, tz=UTC)


async def fetch_certificates(
paa_root_cert_dir: Path,
fetch_test_certificates: bool = True,
fetch_production_certificates: bool = True,
) -> int:
"""Fetch PAA Certificates."""
loop = asyncio.get_running_loop()

if not PAA_ROOT_CERTS_DIR.is_dir():
await loop.run_in_executor(None, makedirs, PAA_ROOT_CERTS_DIR)
if not paa_root_cert_dir.is_dir():
await loop.run_in_executor(None, makedirs, paa_root_cert_dir)
else:
stat = await loop.run_in_executor(None, PAA_ROOT_CERTS_DIR.stat)
stat = await loop.run_in_executor(None, paa_root_cert_dir.stat)
last_fetch = datetime.fromtimestamp(stat.st_mtime, tz=UTC)
if last_fetch > datetime.now(tz=UTC) - timedelta(days=1):
LOGGER.info(
Expand All @@ -176,13 +173,14 @@ async def fetch_certificates(
return 0

fetch_count = await fetch_dcl_certificates(
paa_root_cert_dir=paa_root_cert_dir,
fetch_test_certificates=fetch_test_certificates,
fetch_production_certificates=fetch_production_certificates,
)

if fetch_test_certificates:
fetch_count += await fetch_git_certificates()
fetch_count += await fetch_git_certificates(paa_root_cert_dir)

await loop.run_in_executor(None, PAA_ROOT_CERTS_DIR.touch)
await loop.run_in_executor(None, paa_root_cert_dir.touch)

return fetch_count
9 changes: 7 additions & 2 deletions matter_server/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ServerInfoMessage,
)
from ..server.client_handler import WebsocketClientHandler
from .const import MIN_SCHEMA_VERSION
from .const import DEFAULT_PAA_ROOT_CERTS_DIR, MIN_SCHEMA_VERSION
from .device_controller import MatterDeviceController
from .stack import MatterStack
from .storage import StorageController
Expand Down Expand Up @@ -100,6 +100,7 @@ def __init__(
port: int,
listen_addresses: list[str] | None = None,
primary_interface: str | None = None,
paa_root_cert_dir: Path | None = None,
) -> None:
"""Initialize the Matter Server."""
self.storage_path = storage_path
Expand All @@ -108,6 +109,10 @@ def __init__(
self.port = port
self.listen_addresses = listen_addresses
self.primary_interface = primary_interface
if paa_root_cert_dir is None:
self.paa_root_cert_dir = DEFAULT_PAA_ROOT_CERTS_DIR
else:
self.paa_root_cert_dir = Path(paa_root_cert_dir).absolute()
self.logger = logging.getLogger(__name__)
self.app = web.Application()
self.loop: asyncio.AbstractEventLoop | None = None
Expand All @@ -134,7 +139,7 @@ async def start(self) -> None:
self.loop = asyncio.get_running_loop()
self.loop.set_exception_handler(_global_loop_exception_handler)
self.loop.set_debug(os.environ.get("PYTHONDEBUG", "") != "")
await self.device_controller.initialize()
await self.device_controller.initialize(self.paa_root_cert_dir)
await self.storage.start()
await self.device_controller.start()
await self.vendor_info.start()
Expand Down

0 comments on commit 9950f8b

Please sign in to comment.