Skip to content

Commit

Permalink
Introduces container class to run coroutines in a dedicated thread.
Browse files Browse the repository at this point in the history
Implements TFF's sync-wrapping-async pattern with an instance of this class, allowing for interop between synchronous TFF code and asyncio at higher layers.

Also adds dependency from the Python ThreadDelegatingExecutor on this class, unifying these two usages.

Represents one work-around for python/cpython#66435

PiperOrigin-RevId: 443189960
  • Loading branch information
jkr26 authored and tensorflow-copybara committed Jun 1, 2022
1 parent b267217 commit 63dfd31
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 50 deletions.
5 changes: 4 additions & 1 deletion tensorflow_federated/python/common_libs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ py_library(
name = "async_utils",
srcs = ["async_utils.py"],
srcs_version = "PY3",
deps = ["@absl_py//absl/logging"],
deps = [
":tracing",
"@absl_py//absl/logging",
],
)

py_test(
Expand Down
54 changes: 54 additions & 0 deletions tensorflow_federated/python/common_libs/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
import asyncio
import contextlib
import sys
import threading
from typing import Callable

from absl import logging

from tensorflow_federated.python.common_libs import tracing


async def _log_error(awaitable):
try:
Expand Down Expand Up @@ -172,3 +175,54 @@ async def waiter():
return self._result

return waiter().__await__()


class AsyncThreadRunner():
"""Class which bridges async and synchronous synchronous interfaces.
This class serves as a resource and logic container, starting an event loop
in a separate thread and managing dispatching of coroutine functions to this
event loop in both synchronous and asynchronous interfaces.
There are two main uses of this class. First, this class can be used to wrap
interfaces which use `asyncio` in a synchronous 'run this coroutine'
interface in a manner which is compatible with integrating with other async
libraries. This feature is generally useful for backwards-compatibility (e.g.,
introducing asyncio in some component which sits on top of the synchronous
function calls this interface exposes), but should generally be viewed as
suboptimal--it is preferable in a situation like this to simply expose the
underlying async interfaces.
Second, this class can be used to delegate asynchronous work from one thread
to another, using its asynchronous interface.
"""

def __init__(self):
self._event_loop = asyncio.new_event_loop()
self._event_loop.set_task_factory(
tracing.propagate_trace_context_task_factory)

def target_fn():
self._event_loop.run_forever()

self._thread = threading.Thread(target=target_fn, daemon=True)
self._thread.start()

def finalizer(loop, thread):
loop.call_soon_threadsafe(loop.stop)
thread.join()

self._finalizer = finalizer

def __del__(self):
self._finalizer(self._event_loop, self._thread)

def run_coro_and_return_result(self, coro):
"""Runs coroutine in the managed event loop, returning the result."""
future = asyncio.run_coroutine_threadsafe(coro, self._event_loop)
return future.result()

async def await_coro_and_return_result(self, coro):
"""Runs coroutine in the managed event loop, returning the result."""
return await asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(coro, self._event_loop))
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ py_library(
srcs_version = "PY3",
deps = [
":cpp_async_execution_context",
"//tensorflow_federated/python/common_libs:async_utils",
"//tensorflow_federated/python/core/impl/context_stack:context_base",
"//tensorflow_federated/python/core/impl/executors:cardinalities_utils",
],
Expand All @@ -119,8 +120,8 @@ py_library(
srcs_version = "PY3",
deps = [
":async_execution_context",
"//tensorflow_federated/python/common_libs:async_utils",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:tracing",
"//tensorflow_federated/python/core/impl/computation:computation_base",
"//tensorflow_federated/python/core/impl/context_stack:context_base",
"//tensorflow_federated/python/core/impl/executors:cardinalities_utils",
Expand Down Expand Up @@ -158,6 +159,7 @@ py_library(
deps = [
":async_execution_context",
":compiler_pipeline",
"//tensorflow_federated/python/common_libs:async_utils",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/compiler:building_blocks",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
# limitations under the License.
"""A context for execution based on an embedded executor instance."""

import asyncio

from tensorflow_federated.python.common_libs import async_utils
from tensorflow_federated.python.core.impl.context_stack import context_base
from tensorflow_federated.python.core.impl.execution_contexts import cpp_async_execution_context
from tensorflow_federated.python.core.impl.executors import cardinalities_utils
Expand All @@ -32,8 +31,8 @@ def __init__(
.CardinalityInferenceFnType = cardinalities_utils.infer_cardinalities):
self._async_execution_context = cpp_async_execution_context.AsyncSerializeAndExecuteCPPContext(
factory, compiler_fn, cardinality_inference_fn=cardinality_inference_fn)
self._loop = asyncio.new_event_loop()
self._async_runner = async_utils.AsyncThreadRunner()

def invoke(self, comp, arg):
return self._loop.run_until_complete(
return self._async_runner.run_coro_and_return_result(
self._async_execution_context.invoke(comp, arg))
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from typing import Any, Callable, List, Optional, Sequence, Union

import attr

from tensorflow_federated.python.common_libs import async_utils
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.compiler import building_blocks
Expand Down Expand Up @@ -459,11 +459,12 @@ def __init__(self,
executor_factories: Sequence[executor_factory.ExecutorFactory],
compiler_fn: Optional[Callable[[computation_base.Computation],
MergeableCompForm]] = None):
self._async_runner = async_utils.AsyncThreadRunner()
self._async_execution_contexts = [
async_execution_context.AsyncExecutionContext(ex_factory)
for ex_factory in executor_factories
]
self._event_loop = asyncio.new_event_loop()

if compiler_fn is not None:
self._compiler_pipeline = compiler_pipeline.CompilerPipeline(compiler_fn)
else:
Expand Down Expand Up @@ -491,7 +492,7 @@ def invoke(self,
len(self._async_execution_contexts))

return type_conversions.type_to_py_container(
self._event_loop.run_until_complete(
self._async_runner.run_coro_and_return_result(
_invoke_mergeable_comp_form(comp, arg,
self._async_execution_contexts)),
comp.after_merge.type_signature.result)
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@
# information.
"""A context for execution based on an embedded executor instance."""

import asyncio
from typing import Any
from typing import Callable
from typing import Optional

from tensorflow_federated.python.common_libs import async_utils
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import tracing
from tensorflow_federated.python.core.impl.computation import computation_base
from tensorflow_federated.python.core.impl.context_stack import context_base
from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context
Expand Down Expand Up @@ -59,15 +57,12 @@ def __init__(
executor_fn=executor_fn,
compiler_fn=compiler_fn,
cardinality_inference_fn=cardinality_inference_fn)

self._event_loop = asyncio.new_event_loop()
self._event_loop.set_task_factory(
tracing.propagate_trace_context_task_factory)
self._async_runner = async_utils.AsyncThreadRunner()

@property
def executor_factory(self):
return self._executor_factory

def invoke(self, comp, arg):
return self._event_loop.run_until_complete(
return self._async_runner.run_coro_and_return_result(
self._async_context.invoke(comp, arg))
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import collections

from absl.testing import absltest
Expand Down Expand Up @@ -213,6 +214,23 @@ def identity(x):
with self.assertRaises(executors_errors.CardinalityError):
identity(data)

def test_sync_interface_interops_with_asyncio(self):

@tensorflow_computation.tf_computation(tf.int32)
def add_one(x):
return x + 1

async def sleep_and_add_one(x):
await asyncio.sleep(0.1)
return add_one(x)

factory = executor_stacks.local_executor_factory()
context = sync_execution_context.ExecutionContext(
factory, cardinality_inference_fn=lambda x, y: {placements.CLIENTS: 1})
with context_stack_impl.context_stack.install(context):
one = asyncio.run(sleep_and_add_one(0))
self.assertEqual(one, 1)


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions tensorflow_federated/python/core/impl/executors/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,7 @@ py_library(
deps = [
":executor_base",
":executor_value_base",
"//tensorflow_federated/python/common_libs:async_utils",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/common_libs:tracing",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,28 @@
# information.
"""A concurrent executor that does work asynchronously in multiple threads."""

import asyncio
import functools
import threading
from typing import Optional
import weakref

from absl import logging

from tensorflow_federated.python.common_libs import async_utils
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.common_libs import tracing
from tensorflow_federated.python.core.impl.executors import executor_base as eb
from tensorflow_federated.python.core.impl.executors import executor_value_base as evb


def _delegate_with_trace_ctx(coro, event_loop):
def _delegate_with_trace_ctx(coro, async_runner):
coro_with_trace_ctx = tracing.wrap_coroutine_in_current_trace_context(coro)
return asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(coro_with_trace_ctx, event_loop))
return async_runner.await_coro_and_return_result(coro_with_trace_ctx)


class ThreadDelegatingExecutorValue(evb.ExecutorValue):
"""An ExecutorValue which delegates `compute` to an external event loop."""

def __init__(self, value: evb.ExecutorValue, event_loop):
def __init__(self, value: evb.ExecutorValue,
async_runner: async_utils.AsyncThreadRunner):
self._value = value
self._event_loop = event_loop
self._async_runner = async_runner

@property
def internal_representation(self) -> evb.ExecutorValue:
Expand All @@ -56,7 +51,7 @@ def type_signature(self):

async def compute(self):
return await _delegate_with_trace_ctx(self._value.compute(),
self._event_loop)
self._async_runner)


class ThreadDelegatingExecutor(eb.Executor):
Expand All @@ -77,25 +72,7 @@ def __init__(self, target_executor: eb.Executor):
"""
py_typecheck.check_type(target_executor, eb.Executor)
self._target_executor = target_executor
self._event_loop = asyncio.new_event_loop()
self._event_loop.set_task_factory(
tracing.propagate_trace_context_task_factory)

def run_loop(loop):
loop.run_forever()
loop.close()

self._thread = threading.Thread(
target=functools.partial(run_loop, self._event_loop), daemon=True)
self._thread.start()

def finalizer(loop, thread):
logging.debug('Finalizing, joining thread.')
loop.call_soon_threadsafe(loop.stop)
thread.join()
logging.debug('Thread joined.')

weakref.finalize(self, finalizer, self._event_loop, self._thread)
self._async_runner = async_utils.AsyncThreadRunner()

def close(self):
# Close does not clean up the event loop or thread.
Expand All @@ -107,8 +84,8 @@ def close(self):

async def _delegate(self, coro):
"""Runs a coroutine which returns an executor value on the event loop."""
result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
return ThreadDelegatingExecutorValue(result_value, self._event_loop)
result_value = await _delegate_with_trace_ctx(coro, self._async_runner)
return ThreadDelegatingExecutorValue(result_value, self._async_runner)

@tracing.trace
async def create_value(self, value, type_spec=None) -> evb.ExecutorValue:
Expand Down

0 comments on commit 63dfd31

Please sign in to comment.