Skip to content

Commit

Permalink
Implement once option to watch API
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Jan 29, 2024
1 parent 64ae685 commit bae864b
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 73 deletions.
15 changes: 13 additions & 2 deletions etcd_client.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ class Client:

class Watch:
""" """

async def __aiter__(self) -> AsyncIterator["Watch"]:
""" """
async def __anext__(self) -> "Event":
""" """

class CondVar:
""" """

def __init__(self) -> None:
""" """
async def wait(self) -> None:
Expand All @@ -48,16 +50,25 @@ class Communicator:
async def replace(self, key: str, initial_value: str, new_value: str) -> bool:
""" """
def watch(
self, key: str, *, ready_event: Optional["CondVar"] = None
self,
key: str,
*,
once: Optional[bool] = False,
ready_event: Optional["CondVar"] = None,
) -> "Watch":
""" """
def watch_prefix(
self, key: str, *, ready_event: Optional["CondVar"] = None
self,
key: str,
*,
once: Optional[bool] = False,
ready_event: Optional["CondVar"] = None,
) -> "Watch":
""" """

class Event:
""" """

key: str
value: str
event_type: "EventType"
Expand Down
8 changes: 6 additions & 2 deletions src/communicator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,21 +159,25 @@ impl PyCommunicator {
fn watch(
&self,
key: String,
once: Option<bool>,
ready_event: Option<PyCondVar>,
cleanup_event: Option<PyCondVar>,
) -> PyWatch {
let client = self.0.clone();
PyWatch::new(client, key, None, ready_event, cleanup_event)
let once = once.unwrap_or(false);
PyWatch::new(client, key, once, None, ready_event, cleanup_event)
}

fn watch_prefix(
&self,
key: String,
once: Option<bool>,
ready_event: Option<PyCondVar>,
cleanup_event: Option<PyCondVar>,
) -> PyWatch {
let client = self.0.clone();
let once = once.unwrap_or(false);
let options = WatchOptions::new().with_prefix();
PyWatch::new(client, key, Some(options), ready_event, cleanup_event)
PyWatch::new(client, key, once, Some(options), ready_event, cleanup_event)
}
}
14 changes: 9 additions & 5 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,24 @@ pub struct PyEventStream {
stream: WatchStream,
events: Vec<PyEvent>,
index: usize,
once: bool
}

impl PyEventStream {
pub fn new(stream: WatchStream) -> Self {
pub fn new(stream: WatchStream, once: bool) -> Self {
Self {
stream,
events: Vec::new(),
index: 0,
once,
}
}

pub async fn next(&mut self) -> Option<Result<PyEvent, Error>> {
if self.once && self.index > 0 {
return None;
}

if self.index < self.events.len() {
let event = self.events[self.index].clone();
self.index += 1;
Expand All @@ -42,10 +48,8 @@ impl PyEventStream {
None
}
}
Some(Err(error)) => {
Some(Err(Error(error)))
}
None => None
Some(Err(error)) => Some(Err(Error(error))),
None => None,
}
}
}
56 changes: 38 additions & 18 deletions src/watch.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use etcd_client::Client as EtcdClient;
use etcd_client::WatchOptions;
use etcd_client::Watcher;
use pyo3::exceptions::PyStopAsyncIteration;
use pyo3::prelude::*;
use pyo3_asyncio::tokio::future_into_py;
use tokio::sync::Notify;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::Notify;

use crate::condvar::PyCondVar;
use crate::error::Error;
Expand All @@ -16,7 +17,9 @@ use crate::stream::PyEventStream;
pub struct PyWatch {
client: Arc<Mutex<EtcdClient>>,
key: String,
once: bool,
options: Option<WatchOptions>,
watcher: Arc<Mutex<Option<Watcher>>>,
event_stream_init_notifier: Arc<Notify>,
event_stream: Arc<Mutex<Option<PyEventStream>>>,
ready_event: Option<PyCondVar>,
Expand All @@ -27,16 +30,19 @@ impl PyWatch {
pub fn new(
client: Arc<Mutex<EtcdClient>>,
key: String,
once: bool,
options: Option<WatchOptions>,
ready_event: Option<PyCondVar>,
cleanup_event: Option<PyCondVar>,
) -> Self {
Self {
client,
key,
once,
options,
event_stream_init_notifier: Arc::new(Notify::new()),
event_stream: Arc::new(Mutex::new(None)),
watcher: Arc::new(Mutex::new(None)),
ready_event,
cleanup_event,
}
Expand All @@ -54,16 +60,18 @@ impl PyWatch {
let mut client = self.client.lock().await;

match client.watch(self.key.clone(), self.options.clone()).await {
Ok((_, stream)) => {
*event_stream = Some(PyEventStream::new(stream));
Ok((watcher, stream)) => {
*event_stream = Some(PyEventStream::new(stream, self.once));
*self.watcher.lock().await = Some(watcher);

event_stream_init_notifier.notify_waiters();

if let Some(ready_event) = &self.ready_event {
ready_event._notify_waiters().await;
}
Ok(())
}
Err(error) => return Err(Error(error)),
Err(error) => Err(Error(error)),
}
}
}
Expand All @@ -77,25 +85,37 @@ impl PyWatch {
fn __anext__<'a>(&'a mut self, py: Python<'a>) -> PyResult<Option<PyObject>> {
let watch = Arc::new(Mutex::new(self.clone()));
let event_stream_init_notifier = self.event_stream_init_notifier.clone();
let watcher = self.watcher.clone();
let once = self.once;

let result = future_into_py(py, async move {
let mut watch = watch.lock().await;
watch.init().await?;
Ok(Some(
future_into_py(py, async move {
let mut watch = watch.lock().await;
watch.init().await?;

let mut event_stream = watch.event_stream.lock().await;
let mut event_stream = watch.event_stream.lock().await;

if event_stream.is_none() {
event_stream_init_notifier.notified().await;
}
if event_stream.is_none() {
event_stream_init_notifier.notified().await;
}

let event_stream = event_stream.as_mut().unwrap();

let event_stream = event_stream.as_mut().unwrap();
let event = match event_stream.next().await {
Some(result) => {
if once {
let mut watcher = watcher.lock().await;
watcher.as_mut().unwrap().cancel().await.unwrap();
}

Ok(match event_stream.next().await {
Some(result) => result,
None => return Err(PyStopAsyncIteration::new_err(()))
}?)
});
result
},
None => return Err(PyStopAsyncIteration::new_err(())),
}?;

Ok(Some(result.unwrap().into()))
Ok(event)
})?
.into(),
))
}
}
93 changes: 47 additions & 46 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async def _record_prefix():
return

async with (
asyncio.timeout(5),
asyncio.timeout(10),
asyncio.TaskGroup() as tg,
):
tg.create_task(_record())
Expand Down Expand Up @@ -257,48 +257,49 @@ async def _record_prefix():
assert records_prefix[3].value == ""


# @pytest.mark.asyncio
# async def test_watch_once(etcd: AsyncEtcd) -> None:
# records = []
# records_prefix = []
# r_ready = asyncio.Event()
# rp_ready = asyncio.Event()

# async def _record():
# recv_count = 0
# async for ev in etcd.watch("wow", once=True, ready_event=r_ready):
# records.append(ev)
# recv_count += 1
# if recv_count == 1:
# return

# async def _record_prefix():
# recv_count = 0
# async for ev in etcd.watch_prefix("wow/city", once=True, ready_event=rp_ready):
# records_prefix.append(ev)
# recv_count += 1
# if recv_count == 1:
# return

# async with (
# asyncio.timeout(10),
# asyncio.TaskGroup() as tg,
# ):
# tg.create_task(_record())
# tg.create_task(_record_prefix())

# await r_ready.wait()
# await rp_ready.wait()

# await etcd.put("wow/city1", "seoul")
# await etcd.put("wow/city2", "daejeon")
# await etcd.put("wow", "korea")
# await etcd.delete_prefix("wow")

# assert records[0].key == "wow"
# assert records[0].event == EventType.PUT
# assert records[0].value == "korea"

# assert records_prefix[0].key == "wow/city1"
# assert records_prefix[0].event == EventType.PUT
# assert records_prefix[0].value == "seoul"
@pytest.mark.asyncio
async def test_watch_once() -> None:
records = []
records_prefix = []
r_ready = CondVar()
rp_ready = CondVar()

async with etcd_client.connect() as etcd:
async def _record():
recv_count = 0
# Below watcher returns after first event
async for ev in etcd.watch("wow", once=True, ready_event=r_ready):
records.append(ev)
recv_count += 1

async def _record_prefix():
recv_count = 0
# Below watcher returns after first event
async for ev in etcd.watch_prefix("wow/city", once=True, ready_event=rp_ready):
records_prefix.append(ev)
recv_count += 1

async with (
asyncio.timeout(10),
asyncio.TaskGroup() as tg,
):
tg.create_task(_record())
tg.create_task(_record_prefix())

await r_ready.wait()
await rp_ready.wait()

await etcd.put("wow/city1", "seoul")
await etcd.put("wow/city2", "daejeon")
await etcd.put("wow", "korea")
await etcd.delete_prefix("wow")

assert len(records) == 1

assert records[0].key == "wow"
assert records[0].event == EventType.PUT
assert records[0].value == "korea"

assert records_prefix[0].key == "wow/city1"
assert records_prefix[0].event == EventType.PUT
assert records_prefix[0].value == "seoul"

0 comments on commit bae864b

Please sign in to comment.