Skip to content

Commit

Permalink
Use wait_for_pending_wakeups in observe_value
Browse files Browse the repository at this point in the history
  • Loading branch information
coretl committed Nov 15, 2024
1 parent 559dcc5 commit 6f0d1e6
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 21 deletions.
2 changes: 2 additions & 0 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
get_unique,
in_micros,
wait_for_connection,
wait_for_pending_wakeups,
)

__all__ = [
Expand Down Expand Up @@ -190,4 +191,5 @@
"in_micros",
"wait_for_connection",
"completed_status",
"wait_for_pending_wakeups",
]
3 changes: 2 additions & 1 deletion src/ophyd_async/core/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Callback,
LazyMock,
T,
wait_for_pending_wakeups,
)


Expand Down Expand Up @@ -459,7 +460,7 @@ async def observe_value(
item = await asyncio.wait_for(q.get(), timeout)
# yield here in case something else is filling the queue
# like in test_observe_value_times_out_with_no_external_task()
await asyncio.sleep(0)
await wait_for_pending_wakeups()
if done_status and item is done_status:
if exc := done_status.exception():
raise exc
Expand Down
24 changes: 24 additions & 0 deletions src/ophyd_async/core/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
import warnings
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
from dataclasses import dataclass
from enum import Enum, EnumMeta
Expand Down Expand Up @@ -295,3 +296,26 @@ def __call__(self) -> Mock:
if self.parent is not None:
self.parent().attach_mock(self._mock, self.name)
return self._mock


async def wait_for_pending_wakeups(max_yields=10):
"""Allow any ready asyncio tasks to be woken up.
Used in:
- Tests to allow tasks like ``set()`` to start so that signal
puts can be tested
- `observe_value` to allow it to be wrapped in `asyncio.wait_for`
with a timeout
"""
loop = asyncio.get_event_loop()
# If anything has called loop.call_soon or is scheduled a wakeup
# then let it run
for _ in range(max_yields):
await asyncio.sleep(0)
if not loop._ready: # type: ignore # noqa: SLF001
return
warnings.warn(
f"Tasks still scheduling wakeups after {max_yields} yields",
RuntimeWarning,
stacklevel=2,
)
6 changes: 3 additions & 3 deletions tests/epics/signal/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,13 +945,13 @@ async def test_observe_ticking_signal_with_busy_loop(ioc: IOC):

async def watch():
async for val in observe_value(sig):
time.sleep(0.15)
time.sleep(0.3)
recv.append(val)

start = time.time()
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(watch(), timeout=0.2)
assert time.time() - start == pytest.approx(0.3, abs=0.05)
await asyncio.wait_for(watch(), timeout=0.4)
assert time.time() - start == pytest.approx(0.6, abs=0.1)
assert len(recv) == 2
# Don't check values as CA and PVA have different algorithms for
# dropping updates for slow callbacks
24 changes: 7 additions & 17 deletions tests/epics/test_motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,11 @@
set_mock_put_proceeds,
set_mock_value,
soft_signal_rw,
wait_for_pending_wakeups,
)
from ophyd_async.epics import motor


async def wait_for_wakeups(max_yields=10):
loop = asyncio.get_event_loop()
# If anything has called loop.call_soon or is scheduled a wakeup
# then let it run
for _ in range(max_yields):
await asyncio.sleep(0)
if not loop._ready:
return
raise RuntimeError(f"Tasks still scheduling wakeups after {max_yields} yields")


@pytest.fixture
async def sim_motor():
async with DeviceCollector(mock=True):
Expand All @@ -44,7 +34,7 @@ async def sim_motor():
async def wait_for_eq(item, attribute, comparison, timeout):
timeout_time = time.monotonic() + timeout
while getattr(item, attribute) != comparison:
await wait_for_wakeups()
await wait_for_pending_wakeups()
if time.monotonic() > timeout_time:
raise TimeoutError

Expand All @@ -56,7 +46,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None:
s.watch(watcher)
done = Mock()
s.add_callback(done)
await wait_for_wakeups()
await wait_for_pending_wakeups()
await wait_for_eq(watcher, "call_count", 1, 1)
assert watcher.call_args == call(
name="sim_motor",
Expand Down Expand Up @@ -86,7 +76,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None:
set_mock_value(sim_motor.motor_done_move, True)
set_mock_value(sim_motor.user_readback, 0.55)
set_mock_put_proceeds(sim_motor.user_setpoint, True)
await wait_for_wakeups()
await wait_for_pending_wakeups()
await wait_for_eq(s, "done", True, 1)
done.assert_called_once_with(s)

Expand All @@ -98,7 +88,7 @@ async def test_motor_moving_well_2(sim_motor: motor.Motor) -> None:
s.watch(watcher)
done = Mock()
s.add_callback(done)
await wait_for_wakeups()
await wait_for_pending_wakeups()
assert watcher.call_count == 1
assert watcher.call_args == call(
name="sim_motor",
Expand Down Expand Up @@ -126,7 +116,7 @@ async def test_motor_moving_well_2(sim_motor: motor.Motor) -> None:
time_elapsed=pytest.approx(0.1, abs=0.2),
)
set_mock_put_proceeds(sim_motor.user_setpoint, True)
await wait_for_wakeups()
await wait_for_pending_wakeups()
assert s.done
done.assert_called_once_with(s)

Expand Down Expand Up @@ -165,7 +155,7 @@ async def test_motor_moving_stopped(sim_motor: motor.Motor):
assert not s.done
await sim_motor.stop()
set_mock_put_proceeds(sim_motor.user_setpoint, True)
await wait_for_wakeups()
await wait_for_pending_wakeups()
assert s.done
assert s.success is False

Expand Down

0 comments on commit 6f0d1e6

Please sign in to comment.