Skip to content

Commit

Permalink
Use to_thread where appropriate
Browse files Browse the repository at this point in the history
Use asyncio.to_thread where appropriate, and replace existing
loop.run_in_executor() calls with the shorter asyncio.to_thread() call.
  • Loading branch information
agners committed Sep 13, 2024
1 parent 52a9fce commit 0432647
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions matter_server/server/ota/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def start_update(

log_file_path = self._ota_provider_dir / f"ota_provider_{timestamp}.log"

log_file = await loop.run_in_executor(None, log_file_path.open, "w")
log_file = await asyncio.to_thread(log_file_path.open, "w")

try:
LOGGER.info("Starting OTA Provider")
Expand Down Expand Up @@ -245,18 +245,16 @@ def _remove_update_data(ota_provider_dir: Path) -> None:
if not path.is_dir():
path.unlink()

loop = asyncio.get_event_loop()
await loop.run_in_executor(None, _remove_update_data, self._ota_provider_dir)
await asyncio.to_thread(_remove_update_data, self._ota_provider_dir)

await self.initialize()

async def stop(self) -> None:
"""Stop the OTA Provider."""
if self._ota_provider_proc:
LOGGER.info("Terminating OTA Provider")
loop = asyncio.get_event_loop()
try:
await loop.run_in_executor(None, self._ota_provider_proc.terminate)
await asyncio.to_thread(self._ota_provider_proc.terminate)
except ProcessLookupError as ex:
LOGGER.warning("Stopping OTA Provider failed with error:", exc_info=ex)
if self._ota_provider_task:
Expand All @@ -272,8 +270,6 @@ async def _download_update(
parsed_url = urlparse(url)
file_name = unquote(Path(parsed_url.path).name)

loop = asyncio.get_running_loop()

file_path = self._ota_provider_dir / file_name

try:
Expand All @@ -286,7 +282,7 @@ async def _download_update(
chunk = await response.content.read(4048)
if not chunk:
break
await loop.run_in_executor(None, f.write, chunk)
await asyncio.to_thread(f.write, chunk)
if checksum_alg:
checksum_alg.update(chunk)

Expand All @@ -310,8 +306,6 @@ async def fetch_update(self, update_desc: dict) -> None:
parsed_url = urlparse(url)
file_name = unquote(Path(parsed_url.path).name)

loop = asyncio.get_running_loop()

checksum_alg = None
if (
"otaChecksum" in update_desc
Expand All @@ -328,13 +322,11 @@ async def fetch_update(self, update_desc: dict) -> None:
file_path = await self._download_update(url, checksum_alg)
elif parsed_url.scheme in ["file"]:
file_path = self._ota_provider_base_dir / Path(parsed_url.path[1:])
if not file_path.exists():
if not await asyncio.to_thread(file_path.exists):
logging.warning("Local update file not found: %s", file_path)
raise UpdateError("Local update file not found")
if checksum_alg:
checksum_alg.update(
await loop.run_in_executor(None, file_path.read_bytes)
)
checksum_alg.update(await asyncio.to_thread(file_path.read_bytes))
else:
raise UpdateError("Unsupported OTA URL scheme")

Expand Down

0 comments on commit 0432647

Please sign in to comment.