Skip to content

Commit

Permalink
test: Adapte changes to test_gather_exception
Browse files Browse the repository at this point in the history
Router has no cache any more, just get_client() will get new connection.
Rewrite with mock
  • Loading branch information
frostyplanet committed Nov 25, 2024
1 parent 7f3d31d commit d16d9e9
Showing 1 changed file with 49 additions and 28 deletions.
77 changes: 49 additions & 28 deletions python/xoscar/backends/indigen/tests/test_indigen_actor_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
import time
import traceback
from collections import deque
from typing import Any, Dict
from unittest import mock

import pandas as pd
import pytest

import xoscar as mo

from ....backends.allocate_strategy import RandomSubPool
from ....backends.communication.dummy import DummyChannel
from ....core import ActorRef, LocalActorRef
from ....debug import DebugOptions, get_debug_options, set_debug_options
from ...router import Router
Expand Down Expand Up @@ -409,41 +412,59 @@ async def test_indigen_batch_method(actor_pool):
await ref1.add_ret.batch(ref1.add_ret.delay(1), ref1.add.delay(2))


@pytest.mark.asyncio
async def test_gather_exception(actor_pool):
try:
Router.get_instance_or_empty()._cache.clear()
ref1 = await mo.create_actor(DummyActor, 1, address=actor_pool.external_address)
router = Router.get_instance_or_empty()
client = next(iter(router._cache.values()))
class FakeChannel(DummyChannel):

future = asyncio.Future()
client_channel = client.channel
mock_recv: Dict[str, Any] = {}

class FakeChannel(type(client_channel)):
def __init__(self):
pass
def __init__(self, origin_channel):
self.origin_channel = origin_channel

def __getattr__(self, item):
return getattr(client_channel, item)
@classmethod
def set_exception(cls, e):
cls.mock_recv["exception"] = e

async def recv(self):
return await future
def __getattr__(self, item):
return getattr(self.origin_channel, item)

client.channel = FakeChannel()
async def recv(self):
exception = self.mock_recv.get("exception")
if exception is not None:
raise exception
else:
return await self.origin_channel.recv()

class MyException(Exception):
pass

await ref1.add(1)
tasks = [ref1.add(i) for i in range(200)]
future.set_exception(MyException("Test recv exception!!"))
with pytest.raises(MyException) as ex:
await asyncio.gather(*tasks)
s = traceback.format_tb(ex.tb)
assert 10 > "\n".join(s).count("send") > 0
finally:
Router.get_instance_or_empty()._cache.clear()
origin_get_client = Router.get_client


async def fake_get_client(external_address: str, **kw):
# XXX patched method cannot get self?
self = Router.get_instance()
assert self is not None
client = await origin_get_client(self, external_address, **kw)
client.channel = FakeChannel(client.channel)
return client


@pytest.mark.asyncio
@mock.patch.object(Router, "get_client", side_effect=fake_get_client)
async def test_gather_exception(fake_get_client, actor_pool):
dest_address = actor_pool.external_address
ref1 = await mo.create_actor(DummyActor, 1, address=dest_address)

class MyException(Exception):
pass

await ref1.add(1)
tasks = [ref1.add(i) for i in range(200)]

FakeChannel.set_exception(MyException("Test recv exception!!"))
with pytest.raises(MyException) as ex:
await asyncio.gather(*tasks)
s = traceback.format_tb(ex.tb)
assert 10 > "\n".join(s).count("send") > 0
# clear
FakeChannel.set_exception(None)


@pytest.mark.asyncio
Expand Down

0 comments on commit d16d9e9

Please sign in to comment.