Skip to content

Commit

Permalink
Allow nested use of asyncio.run
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Apr 27, 2023
1 parent dff8e5d commit fd22cad
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 24 deletions.
16 changes: 10 additions & 6 deletions Doc/library/asyncio-runner.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,20 @@ to simplify async code usage for common wide-spread scenarios.
Running an asyncio Program
==========================

.. function:: run(coro, *, debug=None, loop_factory=None)
.. function:: run(coro, *, debug=None, loop_factory=None, running_ok=False)

Execute the :term:`coroutine` *coro* and return the result.

This function runs the passed coroutine, taking care of
If *running_ok* is ``False``, this function runs the passed coroutine, taking care of
managing the asyncio event loop, *finalizing asynchronous
generators*, and closing the executor.

This function cannot be called when another asyncio event loop is
running in the same thread.
generators*, and closing the executor. This function cannot be called when another
asyncio event loop is running in the same thread.

If *running_ok* is ``True``, this function allows running the passed coroutine even if
this code is already running in an event loop. In other words, it allows re-entering
the event loop, while an exception would be raised if *running_ok* were ``False``. If
this function is called inside an already running event loop, the same loop is used,
and it is not closed at the end.

If *debug* is ``True``, the event loop will be run in debug mode. ``False`` disables
debug mode explicitly. ``None`` is used to respect the global
Expand Down
20 changes: 11 additions & 9 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,21 +594,23 @@ def _do_shutdown(self, future):
if not self.is_closed():
self.call_soon_threadsafe(future.set_exception, ex)

def _check_running(self):
def _check_running(self, running_ok=False):
if self.is_running():
raise RuntimeError('This event loop is already running')
if events._get_running_loop() is not None:
if not running_ok and events._get_running_loop() is not None:
raise RuntimeError(
'Cannot run the event loop while another loop is running')

def run_forever(self):
def run_forever(self, running_ok=False):
"""Run until stop() is called."""
self._check_closed()
self._check_running()
self._check_running(running_ok=running_ok)
self._set_coroutine_origin_tracking(self._debug)

old_agen_hooks = sys.get_asyncgen_hooks()
try:
old_thread_id = self._thread_id
old_running_loop = events._get_running_loop()
self._thread_id = threading.get_ident()
sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
finalizer=self._asyncgen_finalizer_hook)
Expand All @@ -620,12 +622,12 @@ def run_forever(self):
break
finally:
self._stopping = False
self._thread_id = None
events._set_running_loop(None)
self._thread_id = old_thread_id
events._set_running_loop(old_running_loop)
self._set_coroutine_origin_tracking(False)
sys.set_asyncgen_hooks(*old_agen_hooks)

def run_until_complete(self, future):
def run_until_complete(self, future, running_ok=False):
"""Run until the Future is done.
If the argument is a coroutine, it is wrapped in a Task.
Expand All @@ -637,7 +639,7 @@ def run_until_complete(self, future):
Return the Future's result, or raise its exception.
"""
self._check_closed()
self._check_running()
self._check_running(running_ok=running_ok)

new_task = not futures.isfuture(future)
future = tasks.ensure_future(future, loop=self)
Expand All @@ -648,7 +650,7 @@ def run_until_complete(self, future):

future.add_done_callback(_run_until_complete_cb)
try:
self.run_forever()
self.run_forever(running_ok=running_ok)
except:
if new_task and future.done() and not future.cancelled():
# The coroutine raised a BaseException. Consume the exception
Expand Down
29 changes: 20 additions & 9 deletions Lib/asyncio/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ class Runner:

# Note: the class is final, it is not intended for inheritance.

def __init__(self, *, debug=None, loop_factory=None):
def __init__(self, *, debug=None, loop_factory=None, running_ok=False):
self._state = _State.CREATED
self._debug = debug
self._loop_factory = loop_factory
self._running_ok = running_ok
self._loop = None
self._context = None
self._interrupt_count = 0
Expand All @@ -59,7 +60,15 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
close = True
try:
events.get_running_loop()
if self._running_ok:
close = False
except:
pass
if close:
self.close()

def close(self):
"""Shutdown and close event loop."""
Expand All @@ -68,9 +77,11 @@ def close(self):
try:
loop = self._loop
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_asyncgens(), running_ok=self._running_ok)
loop.run_until_complete(
loop.shutdown_default_executor(constants.THREAD_JOIN_TIMEOUT))
loop.shutdown_default_executor(constants.THREAD_JOIN_TIMEOUT),
running_ok=self._running_ok,
)
finally:
if self._set_event_loop:
events.set_event_loop(None)
Expand All @@ -88,7 +99,7 @@ def run(self, coro, *, context=None):
if not coroutines.iscoroutine(coro):
raise ValueError("a coroutine was expected, got {!r}".format(coro))

if events._get_running_loop() is not None:
if not self._running_ok and events._get_running_loop() is not None:
# fail fast with short traceback
raise RuntimeError(
"Runner.run() cannot be called from a running event loop")
Expand All @@ -115,7 +126,7 @@ def run(self, coro, *, context=None):

self._interrupt_count = 0
try:
return self._loop.run_until_complete(task)
return self._loop.run_until_complete(task, running_ok=self._running_ok)
except exceptions.CancelledError:
if self._interrupt_count > 0:
uncancel = getattr(task, "uncancel", None)
Expand Down Expand Up @@ -157,7 +168,7 @@ def _on_sigint(self, signum, frame, main_task):
raise KeyboardInterrupt()


def run(main, *, debug=None, loop_factory=None):
def run(main, *, debug=None, loop_factory=None, running_ok=False):
"""Execute the coroutine and return the result.
This function runs the passed coroutine, taking care of
Expand Down Expand Up @@ -185,12 +196,12 @@ async def main():
asyncio.run(main())
"""
if events._get_running_loop() is not None:
if not running_ok and events._get_running_loop() is not None:
# fail fast with short traceback
raise RuntimeError(
"asyncio.run() cannot be called from a running event loop")

with Runner(debug=debug, loop_factory=loop_factory) as runner:
with Runner(debug=debug, loop_factory=loop_factory, running_ok=running_ok) as runner:
return runner.run(main)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Allow the event loop to be re-entrant, by making it possible to call
``asyncio.run(coro, running_ok=True)`` inside an already running event loop.

0 comments on commit fd22cad

Please sign in to comment.