diff --git a/python/xoscar/backends/context.py b/python/xoscar/backends/context.py
index 23fe3f7..e6926e0 100644
--- a/python/xoscar/backends/context.py
+++ b/python/xoscar/backends/context.py
@@ -28,8 +28,8 @@
 from ..errors import CannotCancelTask
 from ..utils import dataslots, fix_all_zero_ip
 from .allocate_strategy import AddressSpecified, AllocateStrategy
-from .communication import Client, DummyClient, UCXClient
-from .core import ActorCaller
+from .communication import UCXClient
+from .core import ActorCallerThreaded, CallerClient
 from .message import (
     DEFAULT_PROTOCOL,
     ActorRefMessage,
@@ -59,17 +59,16 @@ class ProfilingContext:


 class IndigenActorContext(BaseActorContext):
-    __slots__ = ("_caller", "_lock")
+    __slots__ = "_caller"

     support_allocate_strategy = True

     def __init__(self, address: str | None = None):
         BaseActorContext.__init__(self, address)
-        self._caller = ActorCaller()
-        self._lock = asyncio.Lock()
+        self._caller = ActorCallerThreaded()

     def __del__(self):
-        self._caller.cancel_tasks()
+        self._caller.stop_nonblock()

     async def _call(
         self, address: str, message: _MessageBase, wait: bool = True
@@ -78,22 +77,6 @@ async def _call(
             Router.get_instance_or_empty(), address, message, wait=wait
         )

-    async def _call_with_client(
-        self, client: Client, message: _MessageBase, wait: bool = True
-    ) -> Union[ResultMessage, ErrorMessage, asyncio.Future]:
-        return await self._caller.call_with_client(client, message, wait)
-
-    async def _call_send_buffers(
-        self,
-        client: UCXClient,
-        local_buffers: list,
-        meta_message: _MessageBase,
-        wait: bool = True,
-    ) -> Union[ResultMessage, ErrorMessage, asyncio.Future]:
-        return await self._caller.call_send_buffers(
-            client, local_buffers, meta_message, wait
-        )
-
     @staticmethod
     def _process_result_message(message: Union[ResultMessage, ErrorMessage]):
         if isinstance(message, ResultMessage):
@@ -294,29 +277,6 @@ def _gen_copy_to_buffers_message(content: Any):
     def _gen_copy_to_fileobjs_message(content: Any):
         return CopyToFileObjectsMessage(message_id=new_message_id(), content=content)  # type: ignore

-    async def _get_copy_to_client(self, router, address) -> Client:
-        client = await self._caller.get_client(router, address)
-        if isinstance(client, DummyClient) or hasattr(client, "send_buffers"):
-            return client
-        client_types = router.get_all_client_types(address)
-        # For inter-process communication, the ``self._caller.get_client`` interface would not look for UCX Client,
-        # we still try to find UCXClient for this case.
-        try:
-            client_type = next(
-                client_type
-                for client_type in client_types
-                if hasattr(client_type, "send_buffers")
-            )
-        except StopIteration:
-            return client
-        else:
-            return await self._caller.get_client_via_type(router, address, client_type)
-
-    async def _get_client(self, address: str) -> Client:
-        router = Router.get_instance()
-        assert router is not None, "`copy_to` can only be used inside pools"
-        return await self._get_copy_to_client(router, address)
-
     async def copy_to_buffers(
         self,
         local_buffers: list,
@@ -324,52 +284,65 @@ async def copy_to_buffers(
         block_size: Optional[int] = None,
     ):
         address = remote_buffer_refs[0].address
-        client = await self._get_client(address)
-        if isinstance(client, UCXClient):
-            message = [(buf.address, buf.uid) for buf in remote_buffer_refs]
-            await self._call_send_buffers(
-                client,
-                local_buffers,
-                self._gen_switch_to_copy_to_control_message(message),
-            )
-        else:
-            # ``local_buffers`` will be divided into buffers of the specified block size for transmission.
-            # Smaller buffers will be accumulated and sent together,
-            # while larger buffers will be divided and sent.
-            current_buf_size = 0
-            one_block_data = []
-            block_size = block_size or DEFAULT_TRANSFER_BLOCK_SIZE
-            for i, (l_buf, r_buf) in enumerate(zip(local_buffers, remote_buffer_refs)):
-                if current_buf_size + len(l_buf) < block_size:
-                    one_block_data.append(
-                        (r_buf.address, r_buf.uid, 0, len(l_buf), l_buf)
-                    )
-                    current_buf_size += len(l_buf)
-                    continue
-                last_start = 0
-                while current_buf_size + len(l_buf) > block_size:
-                    remain = block_size - current_buf_size
-                    one_block_data.append(
-                        (r_buf.address, r_buf.uid, last_start, remain, l_buf[:remain])
-                    )
-                    await self._call_with_client(
+        client = await self._caller.get_copy_to_client(address)
+        client.add_ref()  # Prevent close while idle
+        try:
+            assert isinstance(client, CallerClient)
+            if isinstance(client._inner, UCXClient):
+                message = [(buf.address, buf.uid) for buf in remote_buffer_refs]
+                await self._caller.call_send_buffers(
+                    client,
+                    local_buffers,
+                    self._gen_switch_to_copy_to_control_message(message),
+                )
+            else:
+                # ``local_buffers`` will be divided into buffers of the specified block size for transmission.
+                # Smaller buffers will be accumulated and sent together,
+                # while larger buffers will be divided and sent.
+                current_buf_size = 0
+                one_block_data = []
+                block_size = block_size or DEFAULT_TRANSFER_BLOCK_SIZE
+                for i, (l_buf, r_buf) in enumerate(
+                    zip(local_buffers, remote_buffer_refs)
+                ):
+                    if current_buf_size + len(l_buf) < block_size:
+                        one_block_data.append(
+                            (r_buf.address, r_buf.uid, 0, len(l_buf), l_buf)
+                        )
+                        current_buf_size += len(l_buf)
+                        continue
+                    last_start = 0
+                    while current_buf_size + len(l_buf) > block_size:
+                        remain = block_size - current_buf_size
+                        one_block_data.append(
+                            (
+                                r_buf.address,
+                                r_buf.uid,
+                                last_start,
+                                remain,
+                                l_buf[:remain],
+                            )
+                        )
+                        await self._caller.call_with_client(
+                            client, self._gen_copy_to_buffers_message(one_block_data)
+                        )
+                        one_block_data = []
+                        current_buf_size = 0
+                        last_start += remain
+                        l_buf = l_buf[remain:]
+
+                    if len(l_buf) > 0:
+                        one_block_data.append(
+                            (r_buf.address, r_buf.uid, last_start, len(l_buf), l_buf)
+                        )
+                        current_buf_size = len(l_buf)
+
+                if one_block_data:
+                    await self._caller.call_with_client(
                         client, self._gen_copy_to_buffers_message(one_block_data)
                     )
-                    one_block_data = []
-                    current_buf_size = 0
-                    last_start += remain
-                    l_buf = l_buf[remain:]
-
-                if len(l_buf) > 0:
-                    one_block_data.append(
-                        (r_buf.address, r_buf.uid, last_start, len(l_buf), l_buf)
-                    )
-                    current_buf_size = len(l_buf)
-
-                if one_block_data:
-                    await self._call_with_client(
-                        client, self._gen_copy_to_buffers_message(one_block_data)
-                    )
+        finally:
+            client.de_ref()

     async def copy_to_fileobjs(
         self,
@@ -378,27 +351,32 @@ async def copy_to_fileobjs(
         block_size: Optional[int] = None,
     ):
         address = remote_fileobj_refs[0].address
-        client = await self._get_client(address)
-        block_size = block_size or DEFAULT_TRANSFER_BLOCK_SIZE
-        one_block_data = []
-        current_file_size = 0
-        for file_obj, remote_ref in zip(local_fileobjs, remote_fileobj_refs):
-            while True:
-                file_data = await file_obj.read(block_size)  # type: ignore
-                if file_data:
-                    one_block_data.append(
-                        (remote_ref.address, remote_ref.uid, file_data)
-                    )
-                    current_file_size += len(file_data)
-                    if current_file_size >= block_size:
-                        message = self._gen_copy_to_fileobjs_message(one_block_data)
-                        await self._call_with_client(client, message)
-                        one_block_data.clear()
-                        current_file_size = 0
-                else:
-                    break
-
-        if current_file_size > 0:
-            message = self._gen_copy_to_fileobjs_message(one_block_data)
-            await self._call_with_client(client, message)
-            one_block_data.clear()
+        client = await self._caller.get_copy_to_client(address)
+        client.add_ref()  # Prevent close while idle
+        try:
+            assert isinstance(client, CallerClient)
+            block_size = block_size or DEFAULT_TRANSFER_BLOCK_SIZE
+            one_block_data = []
+            current_file_size = 0
+            for file_obj, remote_ref in zip(local_fileobjs, remote_fileobj_refs):
+                while True:
+                    file_data = await file_obj.read(block_size)  # type: ignore
+                    if file_data:
+                        one_block_data.append(
+                            (remote_ref.address, remote_ref.uid, file_data)
+                        )
+                        current_file_size += len(file_data)
+                        if current_file_size >= block_size:
+                            message = self._gen_copy_to_fileobjs_message(one_block_data)
+                            await self._caller.call_with_client(client, message)
+                            one_block_data.clear()
+                            current_file_size = 0
+                    else:
+                        break
+
+            if current_file_size > 0:
+                message = self._gen_copy_to_fileobjs_message(one_block_data)
+                await self._caller.call_with_client(client, message)
+                one_block_data.clear()
+        finally:
+            client.de_ref()
diff --git a/python/xoscar/backends/core.py b/python/xoscar/backends/core.py
index c8fb54f..797060b 100644
--- a/python/xoscar/backends/core.py
+++ b/python/xoscar/backends/core.py
@@ -18,12 +18,16 @@
 import asyncio
 import copy
 import logging
-from typing import Type, Union
+import os
+import threading
+import time
+from typing import Dict, List, Optional, Tuple, Type, Union

 from .._utils import Timer
+from ..constants import XOSCAR_IDLE_TIMEOUT
 from ..errors import ServerClosed
 from ..profiling import get_profiling_data
-from .communication import Client, UCXClient
+from .communication import Client, DummyClient, UCXClient
 from .message import DeserializeMessageFailed, ErrorMessage, ResultMessage, _MessageBase
 from .router import Router

@@ -31,119 +35,351 @@
 logger = logging.getLogger(__name__)


+class ActorCallerThreaded:
+    """
+    Just a proxy class to ActorCaller.
+    Each thread has its own ActorCaller, in case we run in a multi-threaded env.
+    NOTE that it does not clean up when a thread exits.
+    """
+
+    def __init__(self):
+        self.local = threading.local()
+
+    def _get_local(self) -> ActorCaller:
+        try:
+            return self.local.caller
+        except AttributeError:
+            caller = ActorCaller()
+            self.local.caller = caller
+            return caller
+
+    async def get_copy_to_client(self, address: str) -> CallerClient:
+        caller = self._get_local()
+        return await caller.get_copy_to_client(address)
+
+    async def call_with_client(
+        self, client: CallerClient, message: _MessageBase, wait: bool = True
+    ) -> ResultMessage | ErrorMessage | asyncio.Future:
+        """
+        Although we've already wrapped CallerClient in get_client(), this may
+        still be called directly from the API (not recommended); kept for
+        compatibility with old usage.
+        """
+        caller = self._get_local()
+        return await caller.call_with_client(client, message, wait)
+
+    async def call_send_buffers(
+        self,
+        client: CallerClient,
+        local_buffers: List,
+        meta_message: _MessageBase,
+        wait: bool = True,
+    ) -> ResultMessage | ErrorMessage | asyncio.Future:
+        """
+        Although we've already wrapped CallerClient in get_client(), this may
+        still be called directly from the API (not recommended); kept for
+        compatibility with old usage.
+ """ + caller = self._get_local() + return await caller.call_send_buffers(client, local_buffers, meta_message, wait) + + async def call( + self, + router: Router, + dest_address: str, + message: _MessageBase, + wait: bool = True, + ) -> ResultMessage | ErrorMessage | asyncio.Future: + caller = self._get_local() + return await caller.call(router, dest_address, message, wait) + + async def stop(self): + local_caller = self._get_local() + await local_caller.stop() + + def stop_nonblock(self): + local_caller = self._get_local() + local_caller.stop_nonblock() + + +class CallerClient: + """ + A proxy class for under layer client, keep track for its ref_count. + """ + + _inner: Client + _client_to_message_futures: dict[bytes, asyncio.Future] + _ref_count: int + _last_used: float + _listen_task: Optional[asyncio.Task] + _dest_address: str + + def __init__(self, client: Client, dest_address: str): + self._inner = client + self._ref_count = 0 + self._last_used = 0 + self._dest_address = dest_address + self._listen_task = None + self._client_to_message_futures = dict() + + def start_receiver(self): + self._listen_task = asyncio.create_task(self._listen()) + + def __repr__(self) -> str: + return self._inner.__repr__() + + def abort(self): + if self._listen_task is None: + return + try: + self._listen_task.cancel() + except: + pass + self._listen_task = None + # Since listen task is aborted, need someone to cleanup + self._cleanup(ServerClosed("Connection abort")) + + async def send( + self, + message: _MessageBase, + wait_response: asyncio.Future, + local_buffers: Optional[list] = None, + ): + self.add_ref() + self._client_to_message_futures[message.message_id] = wait_response + try: + if local_buffers is None: + await self._inner.send(message) + else: + assert isinstance(self._inner, UCXClient) + await self._inner.send_buffers(local_buffers, message) + self._last_used = time.time() + except ConnectionError: + try: + # listen task will be notify by connection to exit and cleanup + await self._inner.close() + except ConnectionError: + # close failed, ignore it + pass + raise ServerClosed(f"{self} closed") + except: + try: + # listen task will be notify by connection to exit and cleanup + await self._inner.close() + except ConnectionError: + # close failed, ignore it + pass + raise + + async def close(self): + """ + Close connection. 
+ """ + self.abort() + if not self.closed: + await self._inner.close() + + @property + def closed(self) -> bool: + return self._inner.closed + + def _cleanup(self, e): + message_futures = self._client_to_message_futures + self._client_to_message_futures = dict() + if e is None: + e = ServerClosed(f"Remote server {self._inner.dest_address} closed") + for future in message_futures.values(): + future.set_exception(copy.copy(e)) + + async def _listen(self): + client = self._inner + while not client.closed: + try: + try: + message: _MessageBase = await client.recv() + self._last_used = time.time() + except (EOFError, ConnectionError, BrokenPipeError) as e: + logger.debug(f"{client.dest_address} close due to {e}") + # remote server closed, close client and raise ServerClosed + try: + await client.close() + except (ConnectionError, BrokenPipeError): + # close failed, ignore it + pass + raise ServerClosed( + f"Remote server {client.dest_address} closed: {e}" + ) from None + future = self._client_to_message_futures.pop(message.message_id) + if not future.done(): + future.set_result(message) + except DeserializeMessageFailed as e: + message_id = e.message_id + future = self._client_to_message_futures.pop(message_id) + future.set_exception(e.__cause__) # type: ignore + except Exception as e: # noqa: E722 # pylint: disable=bare-except + self._cleanup(e) + logger.debug(f"{e}", exc_info=True) + finally: + # Counter part of self.add_ref() in send() + self.de_ref() + # message may have Ray ObjectRef, delete it early in case next loop doesn't run + # as soon as expected. + try: + del message + except NameError: + pass + try: + del future + except NameError: + pass + await asyncio.sleep(0) + self._cleanup(None) + + def add_ref(self): + self._ref_count += 1 + + def de_ref(self): + self._ref_count -= 1 + self._last_used = time.time() + + def get_ref(self) -> int: + return self._ref_count + + def is_idle(self, idle_timeout: int) -> bool: + return self.get_ref() == 0 and time.time() > idle_timeout + self._last_used + + class ActorCaller: - __slots__ = "_client_to_message_futures", "_clients", "_profiling_data" - _client_to_message_futures: dict[Client, dict[bytes, asyncio.Future]] - _clients: dict[Client, asyncio.Task] + _clients: Dict[Client, CallerClient] + # _addr_to_clients: A cache from dest_address to Caller, (regardless what mapping router did), + # if multiple ClientType only keep one + _addr_to_clients: Dict[Tuple[str, Optional[Type[Client]]], CallerClient] + _check_task: asyncio.Task def __init__(self): - self._client_to_message_futures = dict() self._clients = dict() + self._addr_to_clients = dict() self._profiling_data = get_profiling_data() + self._check_task = None # Due to cython env If start task here will not get the shared copy of self. 
+        self._default_idle_timeout = int(
+            os.environ.get("XOSCAR_IDLE_TIMEOUT", XOSCAR_IDLE_TIMEOUT)
+        )
+        self._loop = asyncio.get_running_loop()
+
+    async def periodic_check(self):
+        try:
+            while True:
+                router = Router.get_instance_or_empty()
+                config = router.get_config()
+                idle_timeout = config.get(
+                    "idle_timeout",
+                    self._default_idle_timeout,
+                )
+                await asyncio.sleep(idle_timeout)
+                try_remove = []
+                to_remove = []
+                for client in self._clients.values():
+                    if client.closed:
+                        to_remove.append(client)
+                    elif client.is_idle(idle_timeout):
+                        try_remove.append(client)
+                for client in to_remove:
+                    self._force_remove_client(client)
+                for client in try_remove:
+                    await self._try_remove_client(client, idle_timeout)
+                logger.debug("periodic_check: %d clients left", len(self._clients))
-    def _listen_client(self, client: Client):
-        if client not in self._clients:
-            self._clients[client] = asyncio.create_task(self._listen(client))
-            self._client_to_message_futures[client] = dict()
-            client_count = len(self._clients)
-            if client_count >= 100:  # pragma: no cover
-                if (client_count - 100) % 10 == 0:  # pragma: no cover
-                    logger.warning(
-                        "Actor caller has created too many clients (%s >= 100), "
-                        "the global router may not be set.",
-                        client_count,
-                    )
+                addr_to_remove = []
+                for key, client in self._addr_to_clients.items():
+                    if client.closed:
+                        addr_to_remove.append(key)
+                for key in addr_to_remove:
+                    self._addr_to_clients.pop(key, None)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+
+    async def _try_remove_client(self, client: CallerClient, idle_timeout):
+        if client.closed:
+            self._force_remove_client(client)
+            logger.debug(f"Removed closed client {client}")
+        elif client.is_idle(idle_timeout):
+            self._force_remove_client(client)
+            logger.debug(f"Removed idle client {client}")
+            await client.close()
+
+    def _force_remove_client(self, client: CallerClient):
+        """
+        Force-remove the client because it is closed.
+        """
+        self._clients.pop(client._inner, None)
+        client.abort()
+        # although dest_address is not necessarily in _addr_to_clients, popping both keys is double insurance
+        self._addr_to_clients.pop((client._dest_address, None), None)
+        self._addr_to_clients.pop((client._dest_address, client._inner.__class__), None)
+
+    def _add_client(
+        self, dest_address: str, client_type: Optional[Type[Client]], client: Client
+    ) -> CallerClient:
+        caller_client = CallerClient(client, dest_address)
+        caller_client.start_receiver()
+        self._addr_to_clients[(dest_address, client_type)] = caller_client
+        self._clients[caller_client._inner] = caller_client
+        if self._check_task is None:
+            # Delay the start of the background task so that it gets a ref of self
+            self._check_task = asyncio.create_task(self.periodic_check())
+        return caller_client
+
+    async def get_copy_to_client(self, address: str) -> CallerClient:
+        router = Router.get_instance()
+        assert router is not None, "`copy_to` can only be used inside pools"
+        client = await self.get_client(router, address)
+        if isinstance(client._inner, DummyClient) or hasattr(
+            client._inner, "send_buffers"
+        ):
+            return client
+        client_types = router.get_all_client_types(address)
+        # For inter-process communication, the ``self._caller.get_client`` interface would not look for a UCX client,
+        # so we still try to find UCXClient for this case.
+        try:
+            client_type = next(
+                client_type
+                for client_type in client_types
+                if hasattr(client_type, "send_buffers")
+            )
+        except StopIteration:
+            return client
+        else:
+            return await self.get_client_via_type(router, address, client_type)

     async def get_client_via_type(
         self, router: Router, dest_address: str, client_type: Type[Client]
-    ) -> Client:
-        client = await router.get_client_via_type(
-            dest_address, client_type, from_who=self
-        )
-        self._listen_client(client)
+    ) -> CallerClient:
+        client = self._addr_to_clients.get((dest_address, client_type), None)
+        if client is None or client.closed:
+            _client = await router.get_client_via_type(dest_address, client_type)
+            client = self._add_client(dest_address, client_type, _client)
         return client

-    async def get_client(self, router: Router, dest_address: str) -> Client:
-        client = await router.get_client(dest_address, from_who=self)
-        self._listen_client(client)
+    async def get_client(self, router: Router, dest_address: str) -> CallerClient:
+        client = self._addr_to_clients.get((dest_address, None), None)
+        if client is None or client.closed:
+            _client = await router.get_client(dest_address)
+            client = self._add_client(dest_address, None, _client)
         return client

-    async def _listen(self, client: Client):
-        try:
-            while not client.closed:
-                try:
-                    try:
-                        message: _MessageBase = await client.recv()
-                    except (EOFError, ConnectionError, BrokenPipeError) as e:
-                        # AssertionError is from get_header
-                        # remote server closed, close client and raise ServerClosed
-                        logger.debug(f"{client.dest_address} close due to {e}")
-                        try:
-                            await client.close()
-                        except (ConnectionError, BrokenPipeError):
-                            # close failed, ignore it
-                            pass
-                        raise ServerClosed(
-                            f"Remote server {client.dest_address} closed: {e}"
-                        ) from None
-                    future = self._client_to_message_futures[client].pop(
-                        message.message_id
-                    )
-                    if not future.done():
-                        future.set_result(message)
-                except DeserializeMessageFailed as e:
-                    message_id = e.message_id
-                    future = self._client_to_message_futures[client].pop(message_id)
-                    future.set_exception(e.__cause__)  # type: ignore
-                except Exception as e:  # noqa: E722  # pylint: disable=bare-except
-                    message_futures = self._client_to_message_futures[client]
-                    self._client_to_message_futures[client] = dict()
-                    for future in message_futures.values():
-                        future.set_exception(copy.copy(e))
-                finally:
-                    # message may have Ray ObjectRef, delete it early in case next loop doesn't run
-                    # as soon as expected.
-                    try:
-                        del message
-                    except NameError:
-                        pass
-                    try:
-                        del future
-                    except NameError:
-                        pass
-                    await asyncio.sleep(0)
-
-            message_futures = self._client_to_message_futures[client]
-            self._client_to_message_futures[client] = dict()
-            error = ServerClosed(f"Remote server {client.dest_address} closed")
-            for future in message_futures.values():
-                future.set_exception(copy.copy(error))
-        finally:
-            try:
-                await client.close()
-            except:  # noqa: E722  # nosec  # pylint: disable=bare-except
-                # ignore all error if fail to close at last
-                pass
-
     async def call_with_client(
-        self, client: Client, message: _MessageBase, wait: bool = True
+        self, client: CallerClient, message: _MessageBase, wait: bool = True
     ) -> ResultMessage | ErrorMessage | asyncio.Future:
+        """
+        Although we've already wrapped CallerClient in get_client(), this may
+        still be called directly from the API (not recommended); kept for
+        compatibility with old usage.
+ """ loop = asyncio.get_running_loop() wait_response = loop.create_future() - self._client_to_message_futures[client][message.message_id] = wait_response - with Timer() as timer: - try: - await client.send(message) - except ConnectionError: - try: - await client.close() - except ConnectionError: - # close failed, ignore it - pass - raise ServerClosed(f"Remote server {client.dest_address} closed") - + await client.send(message, wait_response) + # NOTE: When send raise exception, we should not _force_remove_client, + # let _listen_task exit normally on client close, + # and set all futures in client to exception if not wait: r = wait_response else: @@ -154,26 +390,23 @@ async def call_with_client( async def call_send_buffers( self, - client: UCXClient, + client: CallerClient, local_buffers: list, meta_message: _MessageBase, wait: bool = True, ) -> ResultMessage | ErrorMessage | asyncio.Future: + """ + Althoght we've already wrapped CallerClient in get_client(), + might directly call from api (not recommended), Compatible with old usage. + """ + loop = asyncio.get_running_loop() wait_response = loop.create_future() - self._client_to_message_futures[client][meta_message.message_id] = wait_response - with Timer() as timer: - try: - await client.send_buffers(local_buffers, meta_message) - except ConnectionError: # pragma: no cover - try: - await client.close() - except ConnectionError: - # close failed, ignore it - pass - raise ServerClosed(f"Remote server {client.dest_address} closed") - + await client.send(meta_message, wait_response, local_buffers=local_buffers) + # NOTE: When send raise exception, we should not _force_remove_client, + # let _listen_task exit normally on client close, + # and set all futures in client to exception if not wait: # pragma: no cover r = wait_response else: @@ -193,12 +426,23 @@ async def call( return await self.call_with_client(client, message, wait) async def stop(self): + """Gracefully stop all client connections and background task""" try: await asyncio.gather(*[client.close() for client in self._clients]) except (ConnectionError, ServerClosed): pass - self.cancel_tasks() + self.stop_nonblock() - def cancel_tasks(self): - # cancel listening for all clients - _ = [task.cancel() for task in self._clients.values()] + def stop_nonblock(self): + """Clear all client without async closing, abort background task + Use in non-async context""" + if self._check_task is not None: + try: + self._check_task.cancel() + except: + pass + self._check_task = None + for client in self._clients.values(): + client.abort() + self._clients = {} + self._addr_to_clients = {} diff --git a/python/xoscar/backends/indigen/tests/test_indigen_actor_context.py b/python/xoscar/backends/indigen/tests/test_indigen_actor_context.py index 173dc3c..be64a16 100644 --- a/python/xoscar/backends/indigen/tests/test_indigen_actor_context.py +++ b/python/xoscar/backends/indigen/tests/test_indigen_actor_context.py @@ -20,6 +20,8 @@ import time import traceback from collections import deque +from typing import Any, Dict +from unittest import mock import pandas as pd import pytest @@ -27,6 +29,7 @@ import xoscar as mo from ....backends.allocate_strategy import RandomSubPool +from ....backends.communication.dummy import DummyChannel from ....core import ActorRef, LocalActorRef from ....debug import DebugOptions, get_debug_options, set_debug_options from ...router import Router @@ -409,41 +412,59 @@ async def test_indigen_batch_method(actor_pool): await ref1.add_ret.batch(ref1.add_ret.delay(1), 


-@pytest.mark.asyncio
-async def test_gather_exception(actor_pool):
-    try:
-        Router.get_instance_or_empty()._cache.clear()
-        ref1 = await mo.create_actor(DummyActor, 1, address=actor_pool.external_address)
-        router = Router.get_instance_or_empty()
-        client = next(iter(router._cache.values()))
+class FakeChannel(DummyChannel):

-        future = asyncio.Future()
-        client_channel = client.channel
+    mock_recv: Dict[str, Any] = {}

-        class FakeChannel(type(client_channel)):
-            def __init__(self):
-                pass
+    def __init__(self, origin_channel):
+        self.origin_channel = origin_channel

-            def __getattr__(self, item):
-                return getattr(client_channel, item)
+    @classmethod
+    def set_exception(cls, e):
+        cls.mock_recv["exception"] = e

-            async def recv(self):
-                return await future
+    def __getattr__(self, item):
+        return getattr(self.origin_channel, item)

-        client.channel = FakeChannel()
+    async def recv(self):
+        exception = self.mock_recv.get("exception")
+        if exception is not None:
+            raise exception
+        else:
+            return await self.origin_channel.recv()

-        class MyException(Exception):
-            pass

-        await ref1.add(1)
-        tasks = [ref1.add(i) for i in range(200)]
-        future.set_exception(MyException("Test recv exception!!"))
-        with pytest.raises(MyException) as ex:
-            await asyncio.gather(*tasks)
-        s = traceback.format_tb(ex.tb)
-        assert 10 > "\n".join(s).count("send") > 0
-    finally:
-        Router.get_instance_or_empty()._cache.clear()
+origin_get_client = Router.get_client
+
+
+async def fake_get_client(external_address: str, **kw):
+    # XXX the patched method is not bound, so it does not receive `self`;
+    # fetch the Router instance instead
+    self = Router.get_instance()
+    assert self is not None
+    client = await origin_get_client(self, external_address, **kw)
+    client.channel = FakeChannel(client.channel)
+    return client
+
+
+@pytest.mark.asyncio
+@mock.patch.object(Router, "get_client", side_effect=fake_get_client)
+async def test_gather_exception(fake_get_client, actor_pool):
+    dest_address = actor_pool.external_address
+    ref1 = await mo.create_actor(DummyActor, 1, address=dest_address)
+
+    class MyException(Exception):
+        pass
+
+    await ref1.add(1)
+    tasks = [ref1.add(i) for i in range(200)]
+
+    FakeChannel.set_exception(MyException("Test recv exception!!"))
+    with pytest.raises(MyException) as ex:
+        await asyncio.gather(*tasks)
+    s = traceback.format_tb(ex.tb)
+    assert 10 > "\n".join(s).count("send") > 0
+    # clear
+    FakeChannel.set_exception(None)


 @pytest.mark.asyncio
diff --git a/python/xoscar/backends/indigen/tests/test_pool.py b/python/xoscar/backends/indigen/tests/test_pool.py
index fbf0324..5d5a68c 100644
--- a/python/xoscar/backends/indigen/tests/test_pool.py
+++ b/python/xoscar/backends/indigen/tests/test_pool.py
@@ -893,9 +893,9 @@ async def test_server_closed():
     process.kill()
     process.join()

-    with pytest.raises(ServerClosed):
+    with pytest.raises(Exception):
         # process already been killed,
-        # ServerClosed will be raised
+        # ServerClosed or ConnectionError will be raised
         await task

     assert not process.is_alive()
diff --git a/python/xoscar/backends/pool.py b/python/xoscar/backends/pool.py
index 93fd164..ed9ebbb 100644
--- a/python/xoscar/backends/pool.py
+++ b/python/xoscar/backends/pool.py
@@ -52,7 +52,7 @@
 )
 from .communication.errors import ChannelClosed
 from .config import ActorPoolConfig
-from .core import ActorCaller, ResultMessageType
+from .core import ActorCallerThreaded, ResultMessageType
 from .message import (
     DEFAULT_PROTOCOL,
     ActorRefMessage,
@@ -186,7 +186,7 @@ def __init__(
         self._process_messages = dict()

         # manage async actor callers
-        self._caller = ActorCaller()
+        self._caller = ActorCallerThreaded()
         self._asyncio_task_timeout_detector_task = (
             register_asyncio_task_timeout_detector()
         )
diff --git a/python/xoscar/backends/router.py b/python/xoscar/backends/router.py
index 8a58a4f..d1d8626 100644
--- a/python/xoscar/backends/router.py
+++ b/python/xoscar/backends/router.py
@@ -15,9 +15,7 @@

 from __future__ import annotations

-import asyncio
-import threading
-from typing import Any, Dict, List, Optional, Type
+from typing import Dict, List, Optional, Type

 from .communication import Client, get_client_type

@@ -32,7 +30,6 @@ class Router:
         "_local_mapping",
         "_mapping",
         "_comm_config",
-        "_cache_local",
     )

     _instance: "Router" | None = None
@@ -65,34 +62,18 @@ def __init__(
             mapping = dict()
         self._mapping = mapping
         self._comm_config = comm_config or dict()
-        self._cache_local = threading.local()
-
-    @property
-    def _cache(self) -> dict[tuple[str, Any, Optional[Type[Client]]], Client]:
-        try:
-            return self._cache_local.cache
-        except AttributeError:
-            cache = self._cache_local.cache = dict()
-            return cache
-
-    @property
-    def _lock(self) -> asyncio.Lock:
-        try:
-            return self._cache_local.lock
-        except AttributeError:
-            lock = self._cache_local.lock = asyncio.Lock()
-            return lock

     def set_mapping(self, mapping: dict[str, str]):
         self._mapping = mapping
-        self._cache_local = threading.local()
+
+    def get_config(self):
+        return self._comm_config

     def add_router(self, router: "Router"):
         self._curr_external_addresses.extend(router._curr_external_addresses)
         self._local_mapping.update(router._local_mapping)
         self._mapping.update(router._mapping)
         self._comm_config.update(router._comm_config)
-        self._cache_local = threading.local()

     def remove_router(self, router: "Router"):
         for external_address in router._curr_external_addresses:
@@ -104,7 +85,6 @@ def remove_router(self, router: "Router"):
             self._local_mapping.pop(addr, None)
         for addr in router._mapping:
             self._mapping.pop(addr, None)
-        self._cache_local = threading.local()

     @property
     def external_address(self):
@@ -122,28 +102,15 @@ def get_internal_address(self, external_address: str) -> str | None:
     async def get_client(
         self,
         external_address: str,
-        from_who: Any = None,
-        cached: bool = True,
         **kw,
     ) -> Client:
-        async with self._lock:
-            if cached and (external_address, from_who, None) in self._cache:
-                cached_client = self._cache[external_address, from_who, None]
-                if cached_client.closed:
-                    # closed before, ignore it
-                    del self._cache[external_address, from_who, None]
-                else:
-                    return cached_client
-
-            address = self.get_internal_address(external_address)
-            if address is None:
-                # no inner address, just use external address
-                address = external_address
-            client_type: Type[Client] = get_client_type(address)
-            client = await self._create_client(client_type, address, **kw)
-            if cached:
-                self._cache[external_address, from_who, None] = client
-            return client
+        address = self.get_internal_address(external_address)
+        if address is None:
+            # no inner address, just use external address
+            address = external_address
+        client_type: Type[Client] = get_client_type(address)
+        client = await self._create_client(client_type, address, **kw)
+        return client

     async def _create_client(
         self, client_type: Type[Client], address: str, **kw
@@ -180,28 +147,13 @@ async def get_client_via_type(
         self,
         external_address: str,
         client_type: Type[Client],
-        from_who: Any = None,
-        cached: bool = True,
         **kw,
     ) -> Client:
-        async with self._lock:
-            if cached and (external_address, from_who, client_type) in self._cache:
-                cached_client = self._cache[external_address, from_who, client_type]
-                if cached_client.closed:  # pragma: no cover
-                    # closed before, ignore it
-                    del self._cache[external_address, from_who, client_type]
-                else:
-                    return cached_client
-
-            client_type_to_addresses = self._get_client_type_to_addresses(
-                external_address
+        client_type_to_addresses = self._get_client_type_to_addresses(external_address)
+        if client_type not in client_type_to_addresses:  # pragma: no cover
+            raise ValueError(
+                f"Client type({client_type}) is not supported for {external_address}"
             )
-            if client_type not in client_type_to_addresses:  # pragma: no cover
-                raise ValueError(
-                    f"Client type({client_type}) is not supported for {external_address}"
-                )
-            address = client_type_to_addresses[client_type]
-            client = await self._create_client(client_type, address, **kw)
-            if cached:
-                self._cache[external_address, from_who, client_type] = client
-            return client
+        address = client_type_to_addresses[client_type]
+        client = await self._create_client(client_type, address, **kw)
+        return client
diff --git a/python/xoscar/constants.py b/python/xoscar/constants.py
index 5a94057..913ffab 100644
--- a/python/xoscar/constants.py
+++ b/python/xoscar/constants.py
@@ -21,3 +21,5 @@
 XOSCAR_UNIX_SOCKET_DIR = XOSCAR_TEMP_DIR / "socket"

 XOSCAR_CONNECT_TIMEOUT = 8
+
+XOSCAR_IDLE_TIMEOUT = 30
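
Reviewer note (not part of the patch): below is a minimal sketch of how the idle-timeout cleanup introduced above could be exercised end to end. It relies only on the public xoscar API already used in the tests (`create_actor_pool`, `create_actor`, `Actor`); `EchoActor` and the 5-second timeout are hypothetical, chosen just for illustration. `ActorCaller.periodic_check` reads `idle_timeout` from the router config, falling back to the `XOSCAR_IDLE_TIMEOUT` environment variable and finally to the new `XOSCAR_IDLE_TIMEOUT = 30` constant.

    import asyncio
    import os

    # Must be set before any ActorCaller is created; the value is read once
    # in ActorCaller.__init__.
    os.environ["XOSCAR_IDLE_TIMEOUT"] = "5"

    import xoscar as mo


    class EchoActor(mo.Actor):
        # Hypothetical actor, defined only for this example.
        async def echo(self, value):
            return value


    async def main():
        pool = await mo.create_actor_pool("127.0.0.1", n_process=1)
        async with pool:
            ref = await mo.create_actor(EchoActor, address=pool.external_address)
            assert await ref.echo(42) == 42
            # With no traffic and a zero ref count, the CallerClient becomes
            # is_idle() after ~5s; periodic_check then closes and removes it.
            await asyncio.sleep(12)
            # The next call transparently opens a fresh client via get_client().
            assert await ref.echo(7) == 7


    asyncio.run(main())

Because `periodic_check` only sweeps once per `idle_timeout`, an idle client may survive up to roughly twice the configured timeout before removal, hence the generous sleep in the sketch.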