Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use to_thread where appropriate #886

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions matter_server/server/ota/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def start_update(

log_file_path = self._ota_provider_dir / f"ota_provider_{timestamp}.log"

log_file = await loop.run_in_executor(None, log_file_path.open, "w")
log_file = await asyncio.to_thread(log_file_path.open, "w")

try:
LOGGER.info("Starting OTA Provider")
Expand Down Expand Up @@ -245,18 +245,16 @@ def _remove_update_data(ota_provider_dir: Path) -> None:
if not path.is_dir():
path.unlink()

loop = asyncio.get_event_loop()
await loop.run_in_executor(None, _remove_update_data, self._ota_provider_dir)
await asyncio.to_thread(_remove_update_data, self._ota_provider_dir)

await self.initialize()

async def stop(self) -> None:
"""Stop the OTA Provider."""
if self._ota_provider_proc:
LOGGER.info("Terminating OTA Provider")
loop = asyncio.get_event_loop()
try:
await loop.run_in_executor(None, self._ota_provider_proc.terminate)
await asyncio.to_thread(self._ota_provider_proc.terminate)
except ProcessLookupError as ex:
LOGGER.warning("Stopping OTA Provider failed with error:", exc_info=ex)
if self._ota_provider_task:
Expand All @@ -272,8 +270,6 @@ async def _download_update(
parsed_url = urlparse(url)
file_name = unquote(Path(parsed_url.path).name)

loop = asyncio.get_running_loop()

file_path = self._ota_provider_dir / file_name

try:
Expand All @@ -286,7 +282,7 @@ async def _download_update(
chunk = await response.content.read(4048)
if not chunk:
break
await loop.run_in_executor(None, f.write, chunk)
await asyncio.to_thread(f.write, chunk)
if checksum_alg:
checksum_alg.update(chunk)

Expand All @@ -310,8 +306,6 @@ async def fetch_update(self, update_desc: dict) -> None:
parsed_url = urlparse(url)
file_name = unquote(Path(parsed_url.path).name)

loop = asyncio.get_running_loop()

checksum_alg = None
if (
"otaChecksum" in update_desc
Expand All @@ -328,13 +322,11 @@ async def fetch_update(self, update_desc: dict) -> None:
file_path = await self._download_update(url, checksum_alg)
elif parsed_url.scheme in ["file"]:
file_path = self._ota_provider_base_dir / Path(parsed_url.path[1:])
if not file_path.exists():
if not await asyncio.to_thread(file_path.exists):
logging.warning("Local update file not found: %s", file_path)
raise UpdateError("Local update file not found")
if checksum_alg:
checksum_alg.update(
await loop.run_in_executor(None, file_path.read_bytes)
)
checksum_alg.update(await asyncio.to_thread(file_path.read_bytes))
else:
raise UpdateError("Unsupported OTA URL scheme")

Expand Down