diff --git a/devlib/utils/asyn.py b/devlib/utils/asyn.py
index c0e415612..77451b4b3 100644
--- a/devlib/utils/asyn.py
+++ b/devlib/utils/asyn.py
@@ -26,17 +26,11 @@
 import pathlib
 import os.path
 import inspect
+import threading
+from concurrent.futures import ThreadPoolExecutor
+from weakref import WeakSet
 
-# Allow nesting asyncio loops, which is necessary for:
-# * Being able to call the blocking variant of a function from an async
-#   function for backward compat
-# * Critically, run the blocking variant of a function in a Jupyter notebook
-#   environment, since it also uses asyncio.
-#
-# Maybe there is still hope for future versions of Python though:
-# https://bugs.python.org/issue22239
-import nest_asyncio
-nest_asyncio.apply()
+from greenlet import greenlet
 
 
 def create_task(awaitable, name=None):
@@ -292,12 +286,195 @@ def __set_name__(self, owner, name):
         self.name = name
 
 
+class _Genlet(greenlet):
+    """
+    Greenlet that drives a coroutine and hands each future it yields back to
+    the greenlet that called :meth:`send`, generator-style.
+    """
+
+    @classmethod
+    def from_coro(cls, coro):
+        f = lambda x: self.consume_coro(coro, x)
+        self = cls(f)
+        return self
+
+    def consume_coro(self, coro, x):
+        while True:
+            try:
+                future = coro.send(x)
+            except StopIteration as e:
+                return e.value
+            else:
+                # Switch back to the consumer that returns the values via
+                # __next__
+                x = self.consumer_genlet.switch(future)
+
+    @classmethod
+    def get_enclosing(cls):
+        g = greenlet.getcurrent()
+        while not (isinstance(g, cls) or g is None):
+            g = g.parent
+        return g
+
+    def send(self, x):
+        self.consumer_genlet = greenlet.getcurrent()
+        # Switch back to the function yielding values
+        result = self.switch(x)
+        if self:
+            return result
+        else:
+            raise StopIteration(result)
+
+
+class _AwaitableGen:
+    """
+    Awaitable wrapping a coroutine so that it is driven through a
+    :class:`_Genlet`, which nested :func:`run` calls use to yield futures.
+    """
+
+    @classmethod
+    def wrap_coro(cls, coro):
+        if _Genlet.get_enclosing() is None:
+            # Create a top-level _Genlet that all nested runs will use to yield
+            # their futures
+            aw = cls(coro)
+            async def coro_f():
+                return await aw
+            return coro_f()
+        else:
+            return coro
+
+    def __init__(self, coro):
+        self._coro = coro
+
+    def __await__(self):
+        coro = self._coro
+        # Any state other than CORO_CREATED means the coroutine has already
+        # been started.
+        is_started = inspect.getcoroutinestate(coro) != inspect.CORO_CREATED
+
+        def genf():
+            gen = _Genlet.from_coro(coro)
+            # The coroutine is already started, so we need to dispatch the
+            # value from the upcoming send() to the gen without running
+            # gen first.
+            if is_started:
+                x = yield
+            else:
+                x = None
+
+            while True:
+                try:
+                    x = yield gen.send(x)
+                except StopIteration as e:
+                    return e.value
+
+        gen = genf()
+        if is_started:
+            # Start the generator so it waits at the first yield point
+            gen.send(None)
+
+        return gen
+
+
+def allow_nested_run(coro):
+    """
+    Wrap the coroutine ``coro`` such that nested calls to :func:`run` will be
+    allowed.
+
+    .. warning:: The coroutine needs to be consumed in the same OS thread it
+        was created in.
+    """
+    return _AwaitableGen.wrap_coro(coro)
+
+
+_CORO_THREAD_EXECUTOR = ThreadPoolExecutor(max_workers=1)
+def _coro_thread_f(coro):
+    """
+    Run ``coro`` to completion in the current thread, creating an event loop
+    if none is running there.
+    """
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+    # The coroutine needs to be wrapped in the same thread that will consume it,
+    coro = allow_nested_run(coro)
+    return loop.run_until_complete(coro)
+
+
+def _run_in_thread(coro):
+    """
+    Run ``coro`` on the single-worker executor thread and block until its
+    result is available.
+    """
+    future = _CORO_THREAD_EXECUTOR.submit(_coro_thread_f, coro)
+    return future.result()
+
+
+_PATCHED_LOOP_LOCK = threading.Lock()
+_PATCHED_LOOP = WeakSet()
+
+def _install_task_factory(loop):
+    """
+    Install, at most once per ``loop``, a task factory that wraps every new
+    task's coroutine with :func:`allow_nested_run`.
+    """
+    def install(loop):
+        def default_factory(loop, coro, context=None):
+            return asyncio.Task(coro, loop=loop, context=context)
+
+        make_task = loop.get_task_factory() or default_factory
+        def factory(loop, coro, context=None):
+            coro = allow_nested_run(coro)
+            return make_task(loop, coro, context=context)
+
+        loop.set_task_factory(factory)
+
+    with _PATCHED_LOOP_LOCK:
+        if loop in _PATCHED_LOOP:
+            return
+        else:
+            install(loop)
+            _PATCHED_LOOP.add(loop)
+
+
 def run(coro):
     """
     Similar to :func:`asyncio.run` but can be called while an event loop is
-    running.
+    running if a coroutine higher in the callstack has been wrapped using
+    :func:`allow_nested_run`.
     """
-    return asyncio.run(coro)
+    assert inspect.getcoroutinestate(coro) == inspect.CORO_CREATED
+
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        # We are not currently running an event loop, so it's ok to just use
+        # asyncio.run() and let it create one.
+        # Once the coroutine is wrapped, we will be able to yield across
+        # blocking function boundaries thanks to _Genlet
+        return asyncio.run(allow_nested_run(coro))
+    else:
+        # Increase the odds that in the future, we have a wrapped coroutine in
+        # our callstack to avoid the _run_in_thread() path.
+        _install_task_factory(loop)
+
+        if loop.is_running():
+            g = _Genlet.get_enclosing()
+            if g is None:
+                # If we are not running under a wrapped coroutine, we don't
+                # have a choice and we need to run in a separate event loop.
+                return _run_in_thread(coro)
+            else:
+                # This requires that we have an coroutine wrapped with
+                # allow_nested_run() higher in the callstack, that we will be
+                # able to use as a conduit to yield the futures.
+                return g.consume_coro(coro, None)
+        else:
+            return loop.run_until_complete(coro)
 
 
 def asyncf(f):
diff --git a/setup.py b/setup.py
index e8b7d0fbe..0bd3724f9 100644
--- a/setup.py
+++ b/setup.py
@@ -104,7 +104,7 @@ def _load_path(filepath):
         'pandas',
         'pytest',
         'lxml', # More robust xml parsing
-        'nest_asyncio', # Allows running nested asyncio loops
+        'greenlet', # Allows running nested asyncio loops
         'future', # for the "past" Python package
         'ruamel.yaml >= 0.15.72', # YAML formatted config parsing
     ],