Skip to content

Commit

Permalink
utils/asyn: Ensure we always have a task with a greenback portals
Browse files Browse the repository at this point in the history
Ensure that we have a thread with an event loop that has a greenback
portal installed. Having that separate thread avoids having to
manipulate the current thread event loop, which would be brittle as a
user could change the event loop after we have setup ours. It also would
not work in environments where the import is executing inside a aysncio
Task, which is the case in Jupyter lab.
  • Loading branch information
douglas-raillard-arm committed May 8, 2024
1 parent 6d5be35 commit 4b053b4
Showing 1 changed file with 65 additions and 1 deletion.
66 changes: 65 additions & 1 deletion devlib/utils/asyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,45 @@
import pathlib
import os.path
import inspect
import threading

import greenback


def _wrap_coro(coro):
async def coro_f():
await greenback.ensure_portal()
return await coro
return coro_f()


class _GreenbackEventLoop(asyncio.AbstractEventLoop):
def __init__(self, loop):
self.__loop = loop

# Because AbstractEventLoop has default implementation for every method, we
# cannot just use __getattr__, otherwise we will hit these default
# implementation first.
def __getattribute__(self, attr):
if attr == '_GreenbackEventLoop__loop':
return object.__getattribute__(self, '_GreenbackEventLoop__loop')
else:
try:
override = type(self)._OVERRIDES[attr]
except KeyError:
return getattr(self.__loop, attr)
else:
return override.__get__(self, type(self))

def create_task(self, coro, *args, **kwargs):
loop = self.__loop
return loop.create_task(_wrap_coro(coro), *args, **kwargs)

_OVERRIDES = dict(
create_task=create_task,
)


def create_task(awaitable, name=None):
if isinstance(awaitable, asyncio.Task):
task = awaitable
Expand Down Expand Up @@ -282,6 +318,27 @@ def __set_name__(self, owner, name):
self.name = name


# This thread runs an event loop that guarantees having greenback portals
# installed for each task that is created. This can otherwise not be
# guaranteed, even with import-time manipulations as the import itself might
# happen in an already-created tasks, e.g. inside Jupyter lab notebooks.
_LOOP = None
_LOOP_READY = threading.Event()

def _coro_runner():
global _LOOP
loop = asyncio.new_event_loop()
loop = _GreenbackEventLoop(loop)
asyncio.set_event_loop(loop)

_LOOP = loop
_LOOP_READY.set()
_LOOP.run_forever()

_RUN_THREAD = threading.Thread(target=_coro_runner, daemon=True)
_RUN_THREAD.start()


def run(coro):
"""
Similar to :func:`asyncio.run` but can be called while an event loop is
Expand All @@ -305,7 +362,14 @@ async def coro_f():
# re-enter the event loop, as this is not supported in the standard
# library:
# https://github.com/python/cpython/issues/66435
return greenback.await_(coro)
if greenback.has_portal():
return greenback.await_(coro)
else:
# If we don't have a portal setup for the current task, run the
# coroutine on the dedicated thread that has a policy which
# installs portals for every new task.
_LOOP_READY.wait()
return asyncio.run_coroutine_threadsafe(coro, _LOOP).result()


def asyncf(f):
Expand Down

0 comments on commit 4b053b4

Please sign in to comment.