From cb9cd9add7ba81321dd9a7d007c059a1f278d6ae Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Thu, 9 May 2024 16:08:42 +0100 Subject: [PATCH] utils/asyn: Use nest_asyncio when possible nest_asyncio only supports the event loop implementation from the standard library, but allows seemingly blocking calls to actually be non-blocking in any situation, as opposed to the greenlet implementation that sometimes is forced to run the coroutine in a separate thread. --- devlib/utils/asyn.py | 136 ++++++++++++++++++++++++++++++++++--------- setup.py | 1 + 2 files changed, 109 insertions(+), 28 deletions(-) diff --git a/devlib/utils/asyn.py b/devlib/utils/asyn.py index 88ceb99fb..a16981d5f 100644 --- a/devlib/utils/asyn.py +++ b/devlib/utils/asyn.py @@ -20,6 +20,7 @@ import abc import asyncio +import asyncio.events import functools import itertools import contextlib @@ -28,9 +29,73 @@ import inspect import threading from concurrent.futures import ThreadPoolExecutor -from weakref import WeakSet +from weakref import WeakSet, WeakKeyDictionary from greenlet import greenlet +import nest_asyncio + + + +def _apply_nest_asyncio(): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + + policy = asyncio.get_event_loop_policy() + + # Only apply nest_asyncio if the current event loop policy provides a + # BaseEventLoop from the standard library, as once nest_asyncio.apply() has + # been called, the policy itself is patched and there is no coming back + # from that. If the loop turns out to be non-patchable, every + # loop.run_until_complete()/asyncio.run() will just raise. 
+ if isinstance(loop, asyncio.BaseEventLoop) and isinstance(policy, asyncio.events.BaseDefaultEventLoopPolicy): + try: + nest_asyncio.apply() + except Exception: + pass + +_apply_nest_asyncio() + + +_USE_NEST_ASYNCIO_LOCK = threading.RLock() +_USE_NEST_ASYNCIO_LOOP = WeakKeyDictionary() +def _use_nest_asyncio(loop=None): + if loop is None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + + def _run(coro): + try: + return loop.run_until_complete(coro) + finally: + # Suppress the "coroutine was never awaited" warning + coro.close() + + async def test_nested(): + # Smoke-test that should trigger an exception if the event loop + # in use is not supported by nest_asyncio or it did not get patched + # somehow. + _run(asyncio.sleep(0)) + + with _USE_NEST_ASYNCIO_LOCK: + try: + return _USE_NEST_ASYNCIO_LOOP[loop] + except KeyError: + # Break infinite recursion as _use_nest_asyncio() can be used by + # the task factory set using loop.set_task_factory() + _USE_NEST_ASYNCIO_LOOP[loop] = False + try: + _run(test_nested()) + except Exception: + ok = False + else: + ok = True + + _USE_NEST_ASYNCIO_LOOP[loop] = ok + return ok def create_task(awaitable, name=None): @@ -419,6 +484,17 @@ def allow_nested_run(coro): .. warning:: The coroutine needs to be consumed in the same OS thread it was created in.
""" + return _allow_nested_run(coro, loop=None) + + +def _allow_nested_run(coro, loop=None): + if _use_nest_asyncio(loop): + return coro + else: + return _do_allow_nested_run(coro) + + +def _do_allow_nested_run(coro): return _AwaitableGen.wrap_coro(coro) @@ -430,8 +506,9 @@ def _coro_thread_f(coro): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - # The coroutine needs to be wrapped in the same thread that will consume it, - coro = allow_nested_run(coro) + _install_task_factory(loop) + # The coroutine needs to be wrapped in the same thread that will consume it, + coro = _allow_nested_run(coro, loop) return loop.run_until_complete(coro) @@ -450,7 +527,7 @@ def default_factory(loop, coro, context=None): make_task = loop.get_task_factory() or default_factory def factory(loop, coro, context=None): - coro = allow_nested_run(coro) + coro = _allow_nested_run(coro, loop) return make_task(loop, coro, context=context) loop.set_task_factory(factory) @@ -483,32 +560,35 @@ def run(coro): """ assert inspect.getcoroutinestate(coro) == inspect.CORO_CREATED - try: - loop = asyncio.get_running_loop() - except RuntimeError: - # We are not currently running an event loop, so it's ok to just use - # asyncio.run() and let it create one. - # Once the coroutine is wrapped, we will be able to yield across - # blocking function boundaries thanks to _Genlet - return asyncio.run(allow_nested_run(coro)) + if _use_nest_asyncio(): + return asyncio.run(coro) else: - # Increase the odds that in the future, we have a wrapped coroutine in - # our callstack to avoid the _run_in_thread() path. - _install_task_factory(loop) - - if loop.is_running(): - g = _Genlet.get_enclosing() - if g is None: - # If we are not running under a wrapped coroutine, we don't - # have a choice and we need to run in a separate event loop. 
- return _run_in_thread(coro) - else: - # This requires that we have an coroutine wrapped with - # allow_nested_run() higher in the callstack, that we will be - # able to use as a conduit to yield the futures. - return g.consume_coro(coro, None) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # We are not currently running an event loop, so it's ok to just use + # asyncio.run() and let it create one. + # Once the coroutine is wrapped, we will be able to yield across + # blocking function boundaries thanks to _Genlet + return asyncio.run(_do_allow_nested_run(coro)) else: - return loop.run_until_complete(coro) + # Increase the odds that in the future, we have a wrapped coroutine in + # our callstack to avoid the _run_in_thread() path. + _install_task_factory(loop) + + if loop.is_running(): + g = _Genlet.get_enclosing() + if g is None: + # If we are not running under a wrapped coroutine, we don't + # have a choice and we need to run in a separate event loop. + return _run_in_thread(coro) + else: + # This requires that we have a coroutine wrapped with + # allow_nested_run() higher in the callstack, that we will be + # able to use as a conduit to yield the futures. + return g.consume_coro(coro, None) + else: + return loop.run_until_complete(coro) def asyncf(f): diff --git a/setup.py b/setup.py index 0bd3724f9..527383ba7 100644 --- a/setup.py +++ b/setup.py @@ -104,6 +104,7 @@ def _load_path(filepath): 'pandas', 'pytest', 'lxml', # More robust xml parsing + 'nest_asyncio', # Allows running nested asyncio loops 'greenlet', # Allows running nested asyncio loops 'future', # for the "past" Python package 'ruamel.yaml >= 0.15.72', # YAML formatted config parsing