diff --git a/devlib/utils/asyn.py b/devlib/utils/asyn.py index 23c1f928a..ddbf5fac3 100644 --- a/devlib/utils/asyn.py +++ b/devlib/utils/asyn.py @@ -26,9 +26,45 @@ import pathlib import os.path import inspect +import threading import greenback + +def _wrap_coro(coro): + async def coro_f(): + await greenback.ensure_portal() + return await coro + return coro_f() + + +class _GreenbackEventLoop(asyncio.AbstractEventLoop): + def __init__(self, loop): + self.__loop = loop + + # Because AbstractEventLoop has default implementation for every method, we + # cannot just use __getattr__, otherwise we will hit these default + # implementation first. + def __getattribute__(self, attr): + if attr == '_GreenbackEventLoop__loop': + return object.__getattribute__(self, '_GreenbackEventLoop__loop') + else: + try: + override = type(self)._OVERRIDES[attr] + except KeyError: + return getattr(self.__loop, attr) + else: + return override.__get__(self, type(self)) + + def create_task(self, coro, *args, **kwargs): + loop = self.__loop + return loop.create_task(_wrap_coro(coro), *args, **kwargs) + + _OVERRIDES = dict( + create_task=create_task, + ) + + def create_task(awaitable, name=None): if isinstance(awaitable, asyncio.Task): task = awaitable @@ -282,6 +318,27 @@ def __set_name__(self, owner, name): self.name = name +# This thread runs an event loop that guarantees having greenback portals +# installed for each task that is created. This can otherwise not be +# guaranteed, even with import-time manipulations as the import itself might +# happen in an already-created tasks, e.g. inside Jupyter lab notebooks. +_LOOP = None +_LOOP_READY = threading.Event() + +def _coro_runner(): + global _LOOP + loop = asyncio.new_event_loop() + loop = _GreenbackEventLoop(loop) + asyncio.set_event_loop(loop) + + _LOOP = loop + _LOOP_READY.set() + _LOOP.run_forever() + +_RUN_THREAD = threading.Thread(target=_coro_runner, daemon=True) +_RUN_THREAD.start() + + def run(coro): """ Similar to :func:`asyncio.run` but can be called while an event loop is @@ -305,7 +362,14 @@ async def coro_f(): # re-enter the event loop, as this is not supported in the standard # library: # https://github.com/python/cpython/issues/66435 - return greenback.await_(coro) + if greenback.has_portal(): + return greenback.await_(coro) + else: + # If we don't have a portal setup for the current task, run the + # coroutine on the dedicated thread that has a policy which + # installs portals for every new task. + _LOOP_READY.wait() + return asyncio.run_coroutine_threadsafe(coro, _LOOP).result() def asyncf(f):