From be554151d6270570d48a810f3dc0e9df1b030d03 Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Wed, 22 May 2024 18:31:36 +0100 Subject: [PATCH] utils/asyn: Ensure the async generators inside context managers are fully consumed on a single event loop Context managers implemented using contextlib.asynccontextmanager() rely on an async generators. This async generator must be fully consumed on a single event loop, otherwise: 1. The async generator will be closed after the __enter__ run() call, as event loops close all async generators upon closing. 2. The async generator would end up being migrated from one event loop to another, which is not going to work. Sidestep both issues by ensuring that the same event loop is in use for both __enter__() and __exit__(). If necessary, _AsyncPolymorphicCM() creates and manages its own event loop to ensure this is the case. --- devlib/utils/asyn.py | 106 +++++++++++++++++++++++++++++++++---------- 1 file changed, 82 insertions(+), 24 deletions(-) diff --git a/devlib/utils/asyn.py b/devlib/utils/asyn.py index cbad976c6..7b209f7de 100644 --- a/devlib/utils/asyn.py +++ b/devlib/utils/asyn.py @@ -546,30 +546,34 @@ def run(coro): # blocking function boundaries thanks to _Genlet return asyncio.run(_do_allow_nested_run(coro)) else: - # Increase the odds that in the future, we have a wrapped coroutine in - # our callstack to avoid the _run_in_thread() path. - _install_task_factory(loop) + return _run_in_loop(loop, coro) - if loop.is_running(): - g = _Genlet.get_enclosing() - if g is None: - # If we are not running under a wrapped coroutine, we don't - # have a choice and we need to run in a separate event loop. We - # cannot just create another event loop and install it, as - # asyncio forbids that, so the only choice is doing this in a - # separate thread that we fully control. - return _run_in_thread(coro) - else: - # This requires that we have an coroutine wrapped with - # allow_nested_run() higher in the callstack, that we will be - # able to use as a conduit to yield the futures. - return g.consume_coro(coro, None) + +def _run_in_loop(loop, coro): + # Increase the odds that in the future, we have a wrapped coroutine in + # our callstack to avoid the _run_in_thread() path. + _install_task_factory(loop) + + if loop.is_running(): + g = _Genlet.get_enclosing() + if g is None: + # If we are not running under a wrapped coroutine, we don't + # have a choice and we need to run in a separate event loop. We + # cannot just create another event loop and install it, as + # asyncio forbids that, so the only choice is doing this in a + # separate thread that we fully control. + return _run_in_thread(coro) else: - # In the odd case a loop was installed but is not running, we just - # use it. With _install_task_factory(), we should have the - # top-level Task run an instrumented coroutine (wrapped with - # allow_nested_run()) - return loop.run_until_complete(coro) + # This requires that we have an coroutine wrapped with + # allow_nested_run() higher in the callstack, that we will be + # able to use as a conduit to yield the futures. + return g.consume_coro(coro, None) + else: + # In the odd case a loop was installed but is not running, we just + # use it. With _install_task_factory(), we should have the + # top-level Task run an instrumented coroutine (wrapped with + # allow_nested_run()) + return loop.run_until_complete(coro) def asyncf(f): @@ -628,8 +632,35 @@ class _AsyncPolymorphicCM: Wrap an async context manager such that it exposes a synchronous API as well for backward compatibility. """ + _nested = threading.local() + + def _get_nesting(self): + try: + return self._nested.x + except AttributeError: + self._nested.x = 0 + return 0 + + def _update_nesting(self, n): + x = self._get_nesting() + n + self._nested.x = x + return bool(x) + def __init__(self, async_cm): self.cm = async_cm + self._loop = None + + def _close_loop(self): + reentered = self._update_nesting(0) + if not reentered: + loop = self._loop + self._loop = None + if loop is not None: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.run_until_complete( + loop.shutdown_default_executor() + ) + loop.close() def __aenter__(self, *args, **kwargs): return self.cm.__aenter__(*args, **kwargs) @@ -638,10 +669,37 @@ def __aexit__(self, *args, **kwargs): return self.cm.__aexit__(*args, **kwargs) def __enter__(self, *args, **kwargs): - return run(self.cm.__aenter__(*args, **kwargs)) + self._update_nesting(1) + coro = self.cm.__aenter__(*args, **kwargs) + # If there is already a running loop, no need to create a new one + try: + asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + self._loop = loop + try: + asyncio.set_event_loop(loop) + return _run_in_loop(loop, coro) + except BaseException: + self._close_loop() + raise + else: + return run(coro) def __exit__(self, *args, **kwargs): - return run(self.cm.__aexit__(*args, **kwargs)) + try: + self._update_nesting(-1) + coro = self.cm.__aexit__(*args, **kwargs) + loop = self._loop + if loop is None: + return run(coro) + else: + return _run_in_loop(loop, coro) + finally: + self._close_loop() + + def __del__(self): + self._close_loop() def asynccontextmanager(f):