Skip to content

Commit

Permalink
utils/asyn: Ensure the async generators inside context managers are f…
Browse files Browse the repository at this point in the history
…ully consumed on a single event loop

Context managers implemented using contextlib.asynccontextmanager() rely
on an async generators. This async generator must be fully consumed on a
single event loop, otherwise:

1. The async generator will be closed after the __enter__ run() call, as
   event loops close all async generators upon closing.

2. The async generator would end up being migrated from one event loop
   to another, which is not going to work.

Sidestep both issues by ensuring that the same event loop is in use for
both __enter__() and __exit__(). If necessary, _AsyncPolymorphicCM()
creates and manages its own event loop to ensure this is the case.
  • Loading branch information
douglas-raillard-arm committed Jul 12, 2024
1 parent a1677c1 commit d071e9d
Showing 1 changed file with 82 additions and 24 deletions.
106 changes: 82 additions & 24 deletions devlib/utils/asyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,30 +546,34 @@ def run(coro):
# blocking function boundaries thanks to _Genlet
return asyncio.run(_do_allow_nested_run(coro))
else:
# Increase the odds that in the future, we have a wrapped coroutine in
# our callstack to avoid the _run_in_thread() path.
_install_task_factory(loop)
return _run_in_loop(loop, coro)

if loop.is_running():
g = _Genlet.get_enclosing()
if g is None:
# If we are not running under a wrapped coroutine, we don't
# have a choice and we need to run in a separate event loop. We
# cannot just create another event loop and install it, as
# asyncio forbids that, so the only choice is doing this in a
# separate thread that we fully control.
return _run_in_thread(coro)
else:
# This requires that we have an coroutine wrapped with
# allow_nested_run() higher in the callstack, that we will be
# able to use as a conduit to yield the futures.
return g.consume_coro(coro, None)

def _run_in_loop(loop, coro):
# Increase the odds that in the future, we have a wrapped coroutine in
# our callstack to avoid the _run_in_thread() path.
_install_task_factory(loop)

if loop.is_running():
g = _Genlet.get_enclosing()
if g is None:
# If we are not running under a wrapped coroutine, we don't
# have a choice and we need to run in a separate event loop. We
# cannot just create another event loop and install it, as
# asyncio forbids that, so the only choice is doing this in a
# separate thread that we fully control.
return _run_in_thread(coro)
else:
# In the odd case a loop was installed but is not running, we just
# use it. With _install_task_factory(), we should have the
# top-level Task run an instrumented coroutine (wrapped with
# allow_nested_run())
return loop.run_until_complete(coro)
# This requires that we have an coroutine wrapped with
# allow_nested_run() higher in the callstack, that we will be
# able to use as a conduit to yield the futures.
return g.consume_coro(coro, None)
else:
# In the odd case a loop was installed but is not running, we just
# use it. With _install_task_factory(), we should have the
# top-level Task run an instrumented coroutine (wrapped with
# allow_nested_run())
return loop.run_until_complete(coro)


def asyncf(f):
Expand Down Expand Up @@ -628,8 +632,35 @@ class _AsyncPolymorphicCM:
Wrap an async context manager such that it exposes a synchronous API as
well for backward compatibility.
"""
_nested = threading.local()

def _get_nesting(self):
try:
return self._nested.x
except AttributeError:
self._nested.x = 0
return 0

def _update_nesting(self, n):
x = self._get_nesting() + n
self._nested.x = x
return bool(x)

def __init__(self, async_cm):
self.cm = async_cm
self._loop = None

def _close_loop(self):
reentered = self._update_nesting(0)
if not reentered:
loop = self._loop
self._loop = None
if loop is not None:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(
loop.shutdown_default_executor()
)
loop.close()

def __aenter__(self, *args, **kwargs):
return self.cm.__aenter__(*args, **kwargs)
Expand All @@ -638,10 +669,37 @@ def __aexit__(self, *args, **kwargs):
return self.cm.__aexit__(*args, **kwargs)

def __enter__(self, *args, **kwargs):
return run(self.cm.__aenter__(*args, **kwargs))
self._update_nesting(1)
coro = self.cm.__aenter__(*args, **kwargs)
# If there is already a running loop, no need to create a new one
try:
asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
self._loop = loop
try:
asyncio.set_event_loop(loop)
return _run_in_loop(loop, coro)
except BaseException:
self._close_loop()
raise
else:
return run(coro)

def __exit__(self, *args, **kwargs):
return run(self.cm.__aexit__(*args, **kwargs))
try:
self._update_nesting(-1)
coro = self.cm.__aexit__(*args, **kwargs)
loop = self._loop
if loop is None:
return run(coro)
else:
return _run_in_loop(loop, coro)
finally:
self._close_loop()

def __del__(self):
self._close_loop()


def asynccontextmanager(f):
Expand Down

0 comments on commit d071e9d

Please sign in to comment.