Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Jan 29, 2024
1 parent 816d030 commit 47c1e6a
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 82 deletions.
1 change: 1 addition & 0 deletions etcd_client.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Compare:
def with_prefix() -> "Compare": ...

class Txn:
def __init__(self) -> None: ...
def when(self, compares: list["Compare"]) -> "Txn": ...
def and_then(self, operations: list["TxnOp"]) -> "Txn": ...
def or_else(self, operations: list["TxnOp"]) -> "Txn": ...
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
maturin==1.3.2
pytest==7.3.1
pytest==7.3.1
trafaret~=2.1
5 changes: 5 additions & 0 deletions src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ pub struct PyTxn(pub Txn);

#[pymethods]
impl PyTxn {
#[new]
fn new() -> Self {
PyTxn(Txn::new())
}

fn when(&self, compares: Vec<PyCompare>) -> PyResult<Self> {
let compares = compares.into_iter().map(|c| c.0).collect::<Vec<_>>();
Ok(PyTxn(self.0.clone().when(compares)))
Expand Down
195 changes: 114 additions & 81 deletions tests/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@
from urllib.parse import quote as _quote
from urllib.parse import unquote

import grpc # pants: no-infer-dep (etcetra)
import trafaret as t
from etcd_client import EtcdCommunicator, WatchEvent
from etcd_client.client import EtcdClient, EtcdTransactionAction
from etcd_client.types import CompareKey, EtcdCredential
from etcd_client.types import HostPortPair as EtcetraHostPortPair

from etcd_client import Communicator as EtcdCommunicator, CompareOp, CondVar, WatchEvent
from etcd_client import Client as EtcdClient, Txn as EtcdTransactionAction, TxnOp
from etcd_client import Compare as CompareKey


class QueueSentinel(enum.Enum):
Expand Down Expand Up @@ -129,11 +128,11 @@ def _slash(v: str):
class AsyncEtcd:
etcd: EtcdClient

_creds: Optional[EtcdCredential]
# _creds: Optional[EtcdCredential]

def __init__(
self,
addr: HostPortPair | EtcetraHostPortPair,
addr: HostPortPair,
namespace: str,
scope_prefix_map: Mapping[ConfigScopes, str],
*,
Expand All @@ -146,19 +145,21 @@ def __init__(
t.Key(ConfigScopes.SGROUP, optional=True): t.String,
t.Key(ConfigScopes.NODE, optional=True): t.String,
}).check(scope_prefix_map)
if credentials is not None:
self._creds = EtcdCredential(credentials["user"], credentials["password"])
else:
self._creds = None

# if credentials is not None:
# self._creds = EtcdCredential(credentials["user"], credentials["password"])
# else:
self._creds = None

self.ns = namespace
log.info('using etcd cluster from {} with namespace "{}"', addr, namespace)
self.encoding = encoding
self.watch_reconnect_intvl = watch_reconnect_intvl

self.etcd = EtcdClient(
EtcetraHostPortPair(str(addr.host), addr.port),
credentials=self._creds,
encoding=self.encoding,
["http://localhost:2379"],
# credentials=self._creds,
# encoding=self.encoding,
)

async def close(self):
Expand Down Expand Up @@ -244,12 +245,18 @@ def _flatten(prefix: str, inner_dict: NestedStrKeyedDict) -> None:

_flatten(key, cast(NestedStrKeyedDict, dict_obj))

def _txn(action: EtcdTransactionAction):
for k, v in flattened_dict.items():
action.put(self._mangle_key(f"{_slash(scope_prefix)}{k}"), str(v))
# def _txn(action: EtcdTransactionAction):
# for k, v in flattened_dict.items():
# action.put(self._mangle_key(f"{_slash(scope_prefix)}{k}"), str(v))

# TODO: Test below transaction codes
async with self.etcd.connect() as communicator:
await communicator.txn(_txn)
actions = []
for k, v in flattened_dict.items():
actions.append(TxnOp.put(self._mangle_key(f"{_slash(scope_prefix)}{k}"), str(v)))

txn = EtcdTransactionAction()
await communicator.txn(txn.and_then(actions).or_else([]))

async def put_dict(
self,
Expand All @@ -270,12 +277,21 @@ async def put_dict(
"""
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]

def _pipe(txn: EtcdTransactionAction):
for k, v in flattened_dict_obj.items():
txn.put(self._mangle_key(f"{_slash(scope_prefix)}{k}"), str(v))
# TODO: Test below transaction codes
# def _pipe(txn: EtcdTransactionAction):
# for k, v in flattened_dict_obj.items():
# txn.put(self._mangle_key(f"{_slash(scope_prefix)}{k}"), str(v))

# async with self.etcd.connect() as communicator:
# await communicator.txn(_pipe)

actions = []
for k, v in flattened_dict_obj.items():
actions.append(TxnOp.put(self._mangle_key(f"{_slash(scope_prefix)}{k}"), str(v)))
txn = EtcdTransactionAction()

async with self.etcd.connect() as communicator:
await communicator.txn(_pipe)
await communicator.txn(txn.and_then(actions).or_else([]))

async def get(
self,
Expand Down Expand Up @@ -409,17 +425,27 @@ async def replace(
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
mangled_key = self._mangle_key(f"{_slash(scope_prefix)}{key}")

def _txn(success: EtcdTransactionAction, _):
success.put(mangled_key, new_val)
# def _txn(success: EtcdTransactionAction, _):
# success.put(mangled_key, new_val)

# async with self.etcd.connect() as communicator:
# _, success = await communicator.txn_compare(
# [
# CompareKey(mangled_key).value == initial_val,
# ],
# _txn,
# )
# return success

# TODO: Test below transaction codes
async with self.etcd.connect() as communicator:
_, success = await communicator.txn_compare(
[
CompareKey(mangled_key).value == initial_val,
],
_txn,
)
return success
put_action = TxnOp.put(mangled_key, new_val)

txn = EtcdTransactionAction()

communicator.txn(txn.when([
CompareKey.value(mangled_key, CompareOp.EQUAL, initial_val),
]).and_then(list[put_action]).or_else([]))

async def delete(
self,
Expand All @@ -443,11 +469,17 @@ async def delete_multi(
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
async with self.etcd.connect() as communicator:

def _txn(action: EtcdTransactionAction):
for k in keys:
action.delete(self._mangle_key(f"{_slash(scope_prefix)}{k}"))
# def _txn(action: EtcdTransactionAction):
# for k in keys:
# action.delete(self._mangle_key(f"{_slash(scope_prefix)}{k}"))

await communicator.txn(_txn)
# await communicator.txn(_txn)
# TODO: Test below transaction codes
actions = []
for k in keys:
actions.append(TxnOp.delete(k))
txn = EtcdTransactionAction()
communicator.txn(txn.and_then(actions).or_else([]))

async def delete_prefix(
self,
Expand All @@ -466,7 +498,7 @@ async def _watch_impl(
iterator_factory: Callable[[EtcdCommunicator], AsyncIterator[WatchEvent]],
scope_prefix_len: int,
once: bool,
cleanup_event: Optional[asyncio.Event] = None,
cleanup_event: Optional[CondVar] = None,
wait_timeout: Optional[float] = None,
) -> AsyncGenerator[Union[QueueSentinel, Event], None]:
try:
Expand All @@ -483,7 +515,8 @@ async def _watch_impl(
return
finally:
if cleanup_event:
cleanup_event.set()
cleanup_event.notify_all()
# cleanup_event.set()

async def watch(
self,
Expand All @@ -492,8 +525,8 @@ async def watch(
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
once: bool = False,
ready_event: asyncio.Event = None,
cleanup_event: asyncio.Event = None,
ready_event: Optional[CondVar] = None,
cleanup_event: Optional[CondVar] = None,
wait_timeout: float = None,
) -> AsyncGenerator[Union[QueueSentinel, Event], None]:
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
Expand All @@ -502,26 +535,26 @@ async def watch(
ended_without_error = False

while not ended_without_error:
try:
async for ev in self._watch_impl(
lambda communicator: communicator.watch(
mangled_key,
ready_event=ready_event,
),
scope_prefix_len,
once,
cleanup_event=cleanup_event,
wait_timeout=wait_timeout,
):
yield ev
ended_without_error = True
except grpc.aio.AioRpcError as e:
if e.code() == grpc.StatusCode.UNAVAILABLE:
log.warn("watch(): error while connecting to Etcd server, retrying...")
await asyncio.sleep(self.watch_reconnect_intvl)
ended_without_error = False
else:
raise
# try:
async for ev in self._watch_impl(
lambda communicator: communicator.watch(
mangled_key,
ready_event=ready_event,
),
scope_prefix_len,
once,
cleanup_event=cleanup_event,
wait_timeout=wait_timeout,
):
yield ev
ended_without_error = True
# except grpc.aio.AioRpcError as e:
# if e.code() == grpc.StatusCode.UNAVAILABLE:
# log.warn("watch(): error while connecting to Etcd server, retrying...")
# await asyncio.sleep(self.watch_reconnect_intvl)
# ended_without_error = False
# else:
# raise

async def watch_prefix(
self,
Expand All @@ -530,8 +563,8 @@ async def watch_prefix(
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
once: bool = False,
ready_event: asyncio.Event = None,
cleanup_event: asyncio.Event = None,
ready_event: Optional[CondVar] = None,
cleanup_event: Optional[CondVar] = None,
wait_timeout: float = None,
) -> AsyncGenerator[Union[QueueSentinel, Event], None]:
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
Expand All @@ -540,23 +573,23 @@ async def watch_prefix(
ended_without_error = False

while not ended_without_error:
try:
async for ev in self._watch_impl(
lambda communicator: communicator.watch_prefix(
mangled_key_prefix,
ready_event=ready_event,
),
scope_prefix_len,
once,
cleanup_event=cleanup_event,
wait_timeout=wait_timeout,
):
yield ev
ended_without_error = True
except grpc.aio.AioRpcError as e:
if e.code() == grpc.StatusCode.UNAVAILABLE:
log.warn("watch_prefix(): error while connecting to Etcd server, retrying...")
await asyncio.sleep(self.watch_reconnect_intvl)
ended_without_error = False
else:
raise e
# try:
async for ev in self._watch_impl(
lambda communicator: communicator.watch_prefix(
mangled_key_prefix,
ready_event=ready_event,
),
scope_prefix_len,
once,
cleanup_event=cleanup_event,
wait_timeout=wait_timeout,
):
yield ev
ended_without_error = True
# except grpc.aio.AioRpcError as e:
# if e.code() == grpc.StatusCode.UNAVAILABLE:
# log.warn("watch_prefix(): error while connecting to Etcd server, retrying...")
# await asyncio.sleep(self.watch_reconnect_intvl)
# ended_without_error = False
# else:
# raise e

0 comments on commit 47c1e6a

Please sign in to comment.