diff --git a/tensorflow_federated/python/common_libs/BUILD b/tensorflow_federated/python/common_libs/BUILD index 4dff8a5646..adfbe0ce81 100644 --- a/tensorflow_federated/python/common_libs/BUILD +++ b/tensorflow_federated/python/common_libs/BUILD @@ -26,7 +26,10 @@ py_library( name = "async_utils", srcs = ["async_utils.py"], srcs_version = "PY3", - deps = ["@absl_py//absl/logging"], + deps = [ + ":tracing", + "@absl_py//absl/logging", + ], ) py_test( diff --git a/tensorflow_federated/python/common_libs/async_utils.py b/tensorflow_federated/python/common_libs/async_utils.py index 158df647c6..2f9b2fb614 100644 --- a/tensorflow_federated/python/common_libs/async_utils.py +++ b/tensorflow_federated/python/common_libs/async_utils.py @@ -16,10 +16,13 @@ import asyncio import contextlib import sys +import threading from typing import Callable from absl import logging +from tensorflow_federated.python.common_libs import tracing + async def _log_error(awaitable): try: @@ -172,3 +175,54 @@ async def waiter(): return self._result return waiter().__await__() + + +class AsyncThreadRunner(): + """Class which bridges async and synchronous synchronous interfaces. + + This class serves as a resource and logic container, starting an event loop + in a separate thread and managing dispatching of coroutine functions to this + event loop in both synchronous and asynchronous interfaces. + + There are two main uses of this class. First, this class can be used to wrap + interfaces which use `asyncio` in a synchronous 'run this coroutine' + interface in a manner which is compatible with integrating with other async + libraries. This feature is generally useful for backwards-compatibility (e.g., + introducing asyncio in some component which sits on top of the synchronous + function calls this interface exposes), but should generally be viewed as + suboptimal--it is preferable in a situation like this to simply expose the + underlying async interfaces. + + Second, this class can be used to delegate asynchronous work from one thread + to another, using its asynchronous interface. + """ + + def __init__(self): + self._event_loop = asyncio.new_event_loop() + self._event_loop.set_task_factory( + tracing.propagate_trace_context_task_factory) + + def target_fn(): + self._event_loop.run_forever() + + self._thread = threading.Thread(target=target_fn, daemon=True) + self._thread.start() + + def finalizer(loop, thread): + loop.call_soon_threadsafe(loop.stop) + thread.join() + + self._finalizer = finalizer + + def __del__(self): + self._finalizer(self._event_loop, self._thread) + + def run_coro_and_return_result(self, coro): + """Runs coroutine in the managed event loop, returning the result.""" + future = asyncio.run_coroutine_threadsafe(coro, self._event_loop) + return future.result() + + async def await_coro_and_return_result(self, coro): + """Runs coroutine in the managed event loop, returning the result.""" + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, self._event_loop)) diff --git a/tensorflow_federated/python/core/impl/execution_contexts/BUILD b/tensorflow_federated/python/core/impl/execution_contexts/BUILD index 3df16ae050..55b46279eb 100644 --- a/tensorflow_federated/python/core/impl/execution_contexts/BUILD +++ b/tensorflow_federated/python/core/impl/execution_contexts/BUILD @@ -108,6 +108,7 @@ py_library( srcs_version = "PY3", deps = [ ":cpp_async_execution_context", + "//tensorflow_federated/python/common_libs:async_utils", "//tensorflow_federated/python/core/impl/context_stack:context_base", "//tensorflow_federated/python/core/impl/executors:cardinalities_utils", ], @@ -119,8 +120,8 @@ py_library( srcs_version = "PY3", deps = [ ":async_execution_context", + "//tensorflow_federated/python/common_libs:async_utils", "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:tracing", "//tensorflow_federated/python/core/impl/computation:computation_base", "//tensorflow_federated/python/core/impl/context_stack:context_base", "//tensorflow_federated/python/core/impl/executors:cardinalities_utils", @@ -158,6 +159,7 @@ py_library( deps = [ ":async_execution_context", ":compiler_pipeline", + "//tensorflow_federated/python/common_libs:async_utils", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/impl/compiler:building_blocks", diff --git a/tensorflow_federated/python/core/impl/execution_contexts/cpp_sync_execution_context.py b/tensorflow_federated/python/core/impl/execution_contexts/cpp_sync_execution_context.py index 1f2bd7ab37..90d3e11fbe 100644 --- a/tensorflow_federated/python/core/impl/execution_contexts/cpp_sync_execution_context.py +++ b/tensorflow_federated/python/core/impl/execution_contexts/cpp_sync_execution_context.py @@ -13,8 +13,7 @@ # limitations under the License. """A context for execution based on an embedded executor instance.""" -import asyncio - +from tensorflow_federated.python.common_libs import async_utils from tensorflow_federated.python.core.impl.context_stack import context_base from tensorflow_federated.python.core.impl.execution_contexts import cpp_async_execution_context from tensorflow_federated.python.core.impl.executors import cardinalities_utils @@ -32,8 +31,8 @@ def __init__( .CardinalityInferenceFnType = cardinalities_utils.infer_cardinalities): self._async_execution_context = cpp_async_execution_context.AsyncSerializeAndExecuteCPPContext( factory, compiler_fn, cardinality_inference_fn=cardinality_inference_fn) - self._loop = asyncio.new_event_loop() + self._async_runner = async_utils.AsyncThreadRunner() def invoke(self, comp, arg): - return self._loop.run_until_complete( + return self._async_runner.run_coro_and_return_result( self._async_execution_context.invoke(comp, arg)) diff --git a/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py b/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py index e379a78aac..b60b20e581 100644 --- a/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py +++ b/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py @@ -24,7 +24,7 @@ from typing import Any, Callable, List, Optional, Sequence, Union import attr - +from tensorflow_federated.python.common_libs import async_utils from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.impl.compiler import building_blocks @@ -459,11 +459,12 @@ def __init__(self, executor_factories: Sequence[executor_factory.ExecutorFactory], compiler_fn: Optional[Callable[[computation_base.Computation], MergeableCompForm]] = None): + self._async_runner = async_utils.AsyncThreadRunner() self._async_execution_contexts = [ async_execution_context.AsyncExecutionContext(ex_factory) for ex_factory in executor_factories ] - self._event_loop = asyncio.new_event_loop() + if compiler_fn is not None: self._compiler_pipeline = compiler_pipeline.CompilerPipeline(compiler_fn) else: @@ -491,7 +492,7 @@ def invoke(self, len(self._async_execution_contexts)) return type_conversions.type_to_py_container( - self._event_loop.run_until_complete( + self._async_runner.run_coro_and_return_result( _invoke_mergeable_comp_form(comp, arg, self._async_execution_contexts)), comp.after_merge.type_signature.result) diff --git a/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context.py b/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context.py index 2f1e6ac149..c79ab3ba08 100644 --- a/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context.py +++ b/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context.py @@ -18,13 +18,11 @@ # information. """A context for execution based on an embedded executor instance.""" -import asyncio from typing import Any from typing import Callable from typing import Optional - +from tensorflow_federated.python.common_libs import async_utils from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.common_libs import tracing from tensorflow_federated.python.core.impl.computation import computation_base from tensorflow_federated.python.core.impl.context_stack import context_base from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context @@ -59,15 +57,12 @@ def __init__( executor_fn=executor_fn, compiler_fn=compiler_fn, cardinality_inference_fn=cardinality_inference_fn) - - self._event_loop = asyncio.new_event_loop() - self._event_loop.set_task_factory( - tracing.propagate_trace_context_task_factory) + self._async_runner = async_utils.AsyncThreadRunner() @property def executor_factory(self): return self._executor_factory def invoke(self, comp, arg): - return self._event_loop.run_until_complete( + return self._async_runner.run_coro_and_return_result( self._async_context.invoke(comp, arg)) diff --git a/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context_test.py b/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context_test.py index b92c5ba33c..92c31581d3 100644 --- a/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context_test.py +++ b/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import collections from absl.testing import absltest @@ -213,6 +214,23 @@ def identity(x): with self.assertRaises(executors_errors.CardinalityError): identity(data) + def test_sync_interface_interops_with_asyncio(self): + + @tensorflow_computation.tf_computation(tf.int32) + def add_one(x): + return x + 1 + + async def sleep_and_add_one(x): + await asyncio.sleep(0.1) + return add_one(x) + + factory = executor_stacks.local_executor_factory() + context = sync_execution_context.ExecutionContext( + factory, cardinality_inference_fn=lambda x, y: {placements.CLIENTS: 1}) + with context_stack_impl.context_stack.install(context): + one = asyncio.run(sleep_and_add_one(0)) + self.assertEqual(one, 1) + if __name__ == '__main__': absltest.main() diff --git a/tensorflow_federated/python/core/impl/executors/BUILD b/tensorflow_federated/python/core/impl/executors/BUILD index 9bebc55a5b..f7c55e6fbc 100644 --- a/tensorflow_federated/python/core/impl/executors/BUILD +++ b/tensorflow_federated/python/core/impl/executors/BUILD @@ -878,6 +878,7 @@ py_library( deps = [ ":executor_base", ":executor_value_base", + "//tensorflow_federated/python/common_libs:async_utils", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/common_libs:tracing", diff --git a/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py b/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py index 57095ad761..eb98fbeedd 100644 --- a/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py +++ b/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py @@ -18,14 +18,9 @@ # information. """A concurrent executor that does work asynchronously in multiple threads.""" -import asyncio -import functools -import threading from typing import Optional -import weakref - -from absl import logging +from tensorflow_federated.python.common_libs import async_utils from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.common_libs import tracing @@ -33,18 +28,18 @@ from tensorflow_federated.python.core.impl.executors import executor_value_base as evb -def _delegate_with_trace_ctx(coro, event_loop): +def _delegate_with_trace_ctx(coro, async_runner): coro_with_trace_ctx = tracing.wrap_coroutine_in_current_trace_context(coro) - return asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro_with_trace_ctx, event_loop)) + return async_runner.await_coro_and_return_result(coro_with_trace_ctx) class ThreadDelegatingExecutorValue(evb.ExecutorValue): """An ExecutorValue which delegates `compute` to an external event loop.""" - def __init__(self, value: evb.ExecutorValue, event_loop): + def __init__(self, value: evb.ExecutorValue, + async_runner: async_utils.AsyncThreadRunner): self._value = value - self._event_loop = event_loop + self._async_runner = async_runner @property def internal_representation(self) -> evb.ExecutorValue: @@ -56,7 +51,7 @@ def type_signature(self): async def compute(self): return await _delegate_with_trace_ctx(self._value.compute(), - self._event_loop) + self._async_runner) class ThreadDelegatingExecutor(eb.Executor): @@ -77,25 +72,7 @@ def __init__(self, target_executor: eb.Executor): """ py_typecheck.check_type(target_executor, eb.Executor) self._target_executor = target_executor - self._event_loop = asyncio.new_event_loop() - self._event_loop.set_task_factory( - tracing.propagate_trace_context_task_factory) - - def run_loop(loop): - loop.run_forever() - loop.close() - - self._thread = threading.Thread( - target=functools.partial(run_loop, self._event_loop), daemon=True) - self._thread.start() - - def finalizer(loop, thread): - logging.debug('Finalizing, joining thread.') - loop.call_soon_threadsafe(loop.stop) - thread.join() - logging.debug('Thread joined.') - - weakref.finalize(self, finalizer, self._event_loop, self._thread) + self._async_runner = async_utils.AsyncThreadRunner() def close(self): # Close does not clean up the event loop or thread. @@ -107,8 +84,8 @@ def close(self): async def _delegate(self, coro): """Runs a coroutine which returns an executor value on the event loop.""" - result_value = await _delegate_with_trace_ctx(coro, self._event_loop) - return ThreadDelegatingExecutorValue(result_value, self._event_loop) + result_value = await _delegate_with_trace_ctx(coro, self._async_runner) + return ThreadDelegatingExecutorValue(result_value, self._async_runner) @tracing.trace async def create_value(self, value, type_spec=None) -> evb.ExecutorValue: