From 0432647b8df42d79047649d72d1d7a42c52c1947 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 13 Sep 2024 12:01:26 +0200 Subject: [PATCH] Use to_thread where appropriate Use asyncio.to_thread where appropriate, and replace existing loop.run_in_executor() calls with the shorter asyncio.to_thread() call. --- matter_server/server/ota/provider.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/matter_server/server/ota/provider.py b/matter_server/server/ota/provider.py index 45214d73..f53aa316 100644 --- a/matter_server/server/ota/provider.py +++ b/matter_server/server/ota/provider.py @@ -187,7 +187,7 @@ async def start_update( log_file_path = self._ota_provider_dir / f"ota_provider_{timestamp}.log" - log_file = await loop.run_in_executor(None, log_file_path.open, "w") + log_file = await asyncio.to_thread(log_file_path.open, "w") try: LOGGER.info("Starting OTA Provider") @@ -245,8 +245,7 @@ def _remove_update_data(ota_provider_dir: Path) -> None: if not path.is_dir(): path.unlink() - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, _remove_update_data, self._ota_provider_dir) + await asyncio.to_thread(_remove_update_data, self._ota_provider_dir) await self.initialize() @@ -254,9 +253,8 @@ async def stop(self) -> None: """Stop the OTA Provider.""" if self._ota_provider_proc: LOGGER.info("Terminating OTA Provider") - loop = asyncio.get_event_loop() try: - await loop.run_in_executor(None, self._ota_provider_proc.terminate) + await asyncio.to_thread(self._ota_provider_proc.terminate) except ProcessLookupError as ex: LOGGER.warning("Stopping OTA Provider failed with error:", exc_info=ex) if self._ota_provider_task: @@ -272,8 +270,6 @@ async def _download_update( parsed_url = urlparse(url) file_name = unquote(Path(parsed_url.path).name) - loop = asyncio.get_running_loop() - file_path = self._ota_provider_dir / file_name try: @@ -286,7 +282,7 @@ async def _download_update( chunk = await response.content.read(4048) if not chunk: break - await loop.run_in_executor(None, f.write, chunk) + await asyncio.to_thread(f.write, chunk) if checksum_alg: checksum_alg.update(chunk) @@ -310,8 +306,6 @@ async def fetch_update(self, update_desc: dict) -> None: parsed_url = urlparse(url) file_name = unquote(Path(parsed_url.path).name) - loop = asyncio.get_running_loop() - checksum_alg = None if ( "otaChecksum" in update_desc @@ -328,13 +322,11 @@ async def fetch_update(self, update_desc: dict) -> None: file_path = await self._download_update(url, checksum_alg) elif parsed_url.scheme in ["file"]: file_path = self._ota_provider_base_dir / Path(parsed_url.path[1:]) - if not file_path.exists(): + if not await asyncio.to_thread(file_path.exists): logging.warning("Local update file not found: %s", file_path) raise UpdateError("Local update file not found") if checksum_alg: - checksum_alg.update( - await loop.run_in_executor(None, file_path.read_bytes) - ) + checksum_alg.update(await asyncio.to_thread(file_path.read_bytes)) else: raise UpdateError("Unsupported OTA URL scheme")