Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_context_ids method #399

Draft
wants to merge 4 commits into
base: feat/partial_context_updates
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion chatsky/context_storages/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,30 @@
from pathlib import Path
from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Set, Tuple, Union

from pydantic import BaseModel, Field, field_validator, validate_call
from pydantic import BaseModel, Field

from .protocol import PROTOCOLS

_SUBSCRIPT_TYPE = Union[Literal["__all__"], int, Set[str]]
_SUBSCRIPT_DICT = Dict[str, Union[_SUBSCRIPT_TYPE, Literal["__none__"]]]


class ContextIdFilter(BaseModel):
    """
    Criteria used to select context IDs from a context storage.

    Every criterion is optional: a bound left as ``None`` (or an empty whitelist)
    does not restrict the result.
    """

    # Exclusive lower bound for the "updated at" value of a context; ``None`` disables it.
    update_time_greater: Optional[int] = Field(default=None)
    # Exclusive upper bound for the "updated at" value of a context; ``None`` disables it.
    update_time_less: Optional[int] = Field(default=None)
    # Origin interfaces to keep; accepted but not applied yet (see the TODO in `filter_keys`).
    origin_interface_whitelist: Set[str] = Field(default_factory=set)

    def filter_keys(self, contexts: Dict[str, Tuple[int, int, int, bytes, bytes]]) -> Set[str]:
        """
        Select the IDs of the contexts that satisfy this filter.

        :param contexts: Mapping from context ID to its main info tuple.
            Only the third element (the "updated at" value) is inspected,
            and only when at least one time bound is set.
        :return: A new set with the IDs of all matching contexts.
        """
        lower, upper = self.update_time_greater, self.update_time_less
        # TODO: implement whitelist once context ID is ready
        # A single pass collects matching IDs into a real set (the annotated return
        # type) instead of rebuilding the whole mapping once per bound.
        return {
            ctx_id
            for ctx_id, info in contexts.items()
            if (lower is None or info[2] > lower) and (upper is None or info[2] < upper)
        }


class DBContextStorage(ABC):
_main_table_name: Literal["main"] = "main"
_turns_table_name: Literal["turns"] = "turns"
Expand Down Expand Up @@ -79,6 +96,29 @@ def verifier(self: "DBContextStorage", *args, **kwargs):
else:
return method(self, *args, **kwargs)
return verifier

@staticmethod
def _convert_id_filter(method: Callable):
def verifier(self, *args, **kwargs):
if len(args) >= 1:
args, filter_obj = list(args[1:]), args[0]
else:
filter_obj = kwargs.pop("filter", None)
if filter_obj is None:
raise ValueError(f"For method {method.__name__} argument 'filter' is not found!")
elif isinstance(filter_obj, Dict):
filter_obj = ContextIdFilter.model_validate(filter_obj)
elif not isinstance(filter_obj, ContextIdFilter):
raise ValueError(f"Invalid type '{type(filter_obj).__name__}' for method '{method.__name__}' argument 'filter'!")
return method(self, *args, filter=filter_obj, **kwargs)
return verifier

@abstractmethod
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
    """
    Get the IDs of the stored contexts that match the given filter.

    :param filter: Selection criteria, either a :py:class:`ContextIdFilter`
        or a dictionary describing one.
    :return: Set of matching context IDs.
    """
    raise NotImplementedError

@abstractmethod
async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
Expand Down
4 changes: 4 additions & 0 deletions chatsky/context_storages/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ async def _save(self, data: SerializableStorage) -> None:
async def _load(self) -> SerializableStorage:
raise NotImplementedError

@DBContextStorage._convert_id_filter
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
    """
    Load the storage and return the IDs of the contexts matching the filter.

    :param filter: Selection criteria; converted to a :py:class:`ContextIdFilter` by the decorator.
    :return: IDs of the matching contexts.
    """
    storage = await self._load()
    return filter.filter_keys(storage.main)

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
logger.debug(f"Loading main info for {ctx_id}...")
result = (await self._load()).main.get(ctx_id, None)
Expand Down
8 changes: 6 additions & 2 deletions chatsky/context_storages/memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE
from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT


class MemoryContextStorage(DBContextStorage):
Expand Down Expand Up @@ -30,6 +30,10 @@ def __init__(
self._responses_field_name: dict(),
}

@DBContextStorage._convert_id_filter
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
    """
    Return the IDs of the in-memory contexts matching the filter.

    :param filter: Selection criteria; converted to a :py:class:`ContextIdFilter` by the decorator.
    :return: IDs of the matching contexts.
    """
    main_info_by_id = self._main_storage
    return filter.filter_keys(main_info_by_id)

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
return self._main_storage.get(ctx_id, None)

Expand Down
18 changes: 15 additions & 3 deletions chatsky/context_storages/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,17 @@
"""

import asyncio
from typing import Dict, Set, Tuple, Optional, List
from typing import Any, Dict, Set, Tuple, Optional, List, Union

try:
from pymongo import UpdateOne
from pymongo.collection import Collection
from motor.motor_asyncio import AsyncIOMotorClient

mongo_available = True
except ImportError:
mongo_available = False

from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE
from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT
from .protocol import get_protocol_install_suggestion


Expand Down Expand Up @@ -73,6 +72,19 @@ def __init__(
)
)

@DBContextStorage._convert_id_filter
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
    """
    Query the main collection for the IDs of the contexts matching the filter.

    :param filter: Selection criteria; converted to a :py:class:`ContextIdFilter` by the decorator.
    :return: IDs of the matching contexts.
    """
    time_bounds = dict()
    if filter.update_time_greater is not None:
        time_bounds["$gt"] = filter.update_time_greater
    if filter.update_time_less is not None:
        time_bounds["$lt"] = filter.update_time_less
    # An empty query document matches every context.
    query = {self._updated_at_column_name: time_bounds} if time_bounds else dict()
    if len(filter.origin_interface_whitelist) > 0:
        # TODO: implement whitelist once context ID is ready
        pass
    documents = await self.main_table.find(query, [self._id_column_name]).to_list(None)
    # Only the ID column is projected, so that is the field that must be read back.
    return {document[self._id_column_name] for document in documents}

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
result = await self.main_table.find_one(
{self._id_column_name: ctx_id},
Expand Down
11 changes: 9 additions & 2 deletions chatsky/context_storages/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

from asyncio import gather
from typing import Callable, List, Dict, Set, Tuple, Optional
from typing import Any, List, Dict, Set, Tuple, Optional, Union

try:
from redis.asyncio import Redis
Expand All @@ -23,7 +23,7 @@
except ImportError:
redis_available = False

from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE
from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT
from .protocol import get_protocol_install_suggestion


Expand Down Expand Up @@ -74,6 +74,13 @@ def _keys_to_bytes(keys: List[int]) -> List[bytes]:
def _bytes_to_keys(keys: List[bytes]) -> List[int]:
return [int(f.decode("utf-8")) for f in keys]

@DBContextStorage._convert_id_filter
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
    """
    Scan the main hashes and return the IDs of the contexts matching the filter.

    :param filter: Selection criteria; converted to a :py:class:`ContextIdFilter` by the decorator.
    :return: IDs of the matching contexts.
    """
    prefix = f"{self._main_key}:"
    # KEYS returns full key names, so the prefix is cut off to obtain bare context IDs.
    context_ids = [k.decode("utf-8")[len(prefix):] for k in await self.database.keys(f"{prefix}*")]
    update_times = await gather(
        *(self.database.hget(f"{prefix}{ctx_id}", self._updated_at_column_name) for ctx_id in context_ids)
    )
    # Only the "updated at" slot of the main info tuple is needed for filtering.
    # Entries without that field (e.g. removed between the two requests) are skipped.
    partial_contexts = {
        ctx_id: (None, None, int(ua), None, None)
        for ctx_id, ua in zip(context_ids, update_times)
        if ua is not None
    }
    return filter.filter_keys(partial_contexts)

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
if await self.database.exists(f"{self._main_key}:{ctx_id}"):
cti, ca, ua, msc, fd = await gather(
Expand Down
24 changes: 21 additions & 3 deletions chatsky/context_storages/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,19 @@ def _check_availability(self):
install_suggestion = get_protocol_install_suggestion("sqlite")
raise ImportError("Package `sqlalchemy` and/or `aiosqlite` is missing.\n" + install_suggestion)

@DBContextStorage._convert_id_filter
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
    """
    Select the IDs of the contexts matching the filter from the main table.

    :param filter: Selection criteria; converted to a :py:class:`ContextIdFilter` by the decorator.
    :return: IDs of the matching contexts.
    """
    # NOTE(review): the other query methods of this storage are also wrapped in
    # `_synchronously_lock`; confirm whether this one needs it as well.
    updated_at = self.main_table.c[self._updated_at_column_name]
    stmt = select(self.main_table.c[self._id_column_name])
    # `where` returns a new statement, so the result has to be reassigned.
    if filter.update_time_greater is not None:
        stmt = stmt.where(updated_at > filter.update_time_greater)
    if filter.update_time_less is not None:
        stmt = stmt.where(updated_at < filter.update_time_less)
    if len(filter.origin_interface_whitelist) > 0:
        # TODO: implement whitelist once context ID is ready
        pass
    async with self.engine.begin() as conn:
        rows = (await conn.execute(stmt)).fetchall()
    # Every row holds a single column: the context ID.
    return {row[0] for row in rows}

@DBContextStorage._synchronously_lock(lambda s: s.is_asynchronous)
async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
logger.debug(f"Loading main info for {ctx_id}...")
Expand Down Expand Up @@ -270,7 +283,8 @@ async def delete_context(self, ctx_id: str) -> None:
async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]:
logger.debug(f"Loading latest items for {ctx_id}, {field_name}...")
stmt = select(self.turns_table.c[self._key_column_name], self.turns_table.c[field_name])
stmt = stmt.where((self.turns_table.c[self._id_column_name] == ctx_id) & (self.turns_table.c[field_name] != None))
stmt = stmt.where(self.turns_table.c[self._id_column_name] == ctx_id)
stmt = stmt.where(self.turns_table.c[field_name] != None)
stmt = stmt.order_by(self.turns_table.c[self._key_column_name].desc())
if isinstance(self._subscripts[field_name], int):
stmt = stmt.limit(self._subscripts[field_name])
Expand All @@ -284,8 +298,10 @@ async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[in
@DBContextStorage._verify_field_name
@DBContextStorage._synchronously_lock(lambda s: s.is_asynchronous)
async def load_field_keys(self, ctx_id: str, field_name: str) -> List[int]:
stmt = select(self.turns_table.c[self._key_column_name])
stmt = stmt.where(self.turns_table.c[self._id_column_name] == ctx_id)
stmt = stmt.where(self.turns_table.c[field_name] != None)
logger.debug(f"Loading field keys for {ctx_id}, {field_name}...")
stmt = select(self.turns_table.c[self._key_column_name]).where((self.turns_table.c[self._id_column_name] == ctx_id) & (self.turns_table.c[field_name] != None))
async with self.engine.begin() as conn:
result = [k[0] for k in (await conn.execute(stmt)).fetchall()]
logger.debug(f"Field keys loaded for {ctx_id}, {field_name}: {collapse_num_list(result)}")
Expand All @@ -296,7 +312,9 @@ async def load_field_keys(self, ctx_id: str, field_name: str) -> List[int]:
async def load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[bytes]:
logger.debug(f"Loading field items for {ctx_id}, {field_name} ({collapse_num_list(keys)})...")
stmt = select(self.turns_table.c[self._key_column_name], self.turns_table.c[field_name])
stmt = stmt.where((self.turns_table.c[self._id_column_name] == ctx_id) & (self.turns_table.c[self._key_column_name].in_(tuple(keys))) & (self.turns_table.c[field_name] != None))
stmt = stmt.where(self.turns_table.c[self._id_column_name] == ctx_id)
stmt = stmt.where(self.turns_table.c[self._key_column_name].in_(tuple(keys)))
stmt = stmt.where(self.turns_table.c[field_name] != None)
async with self.engine.begin() as conn:
result = list((await conn.execute(stmt)).fetchall())
logger.debug(f"Field items loaded for {ctx_id}, {field_name}: {collapse_num_list([k for k, _ in result])}")
Expand Down
36 changes: 34 additions & 2 deletions chatsky/context_storages/ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@

from asyncio import gather, run
from os.path import join
from typing import Awaitable, Callable, Set, Tuple, List, Dict, Optional
from typing import Any, Awaitable, Callable, Dict, Set, Tuple, List, Optional, Union
from urllib.parse import urlsplit

from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE
from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT
from .protocol import get_protocol_install_suggestion

try:
Expand Down Expand Up @@ -55,6 +55,8 @@ class YDBContextStorage(DBContextStorage):
:param table_name: The name of the table to use.
"""

_UPDATE_TIME_GREATER_VAR = "update_time_greater"
_UPDATE_TIME_LESS_VAR = "update_time_less"
_LIMIT_VAR = "limit"
_KEY_VAR = "key"

Expand Down Expand Up @@ -133,6 +135,36 @@ async def callee(session: Session) -> None:

await self.pool.retry_operation(callee)

@DBContextStorage._convert_id_filter
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
    """
    Select the IDs of the contexts matching the filter from the main table.

    :param filter: Selection criteria; converted to a :py:class:`ContextIdFilter` by the decorator.
    :return: IDs of the matching contexts.
    """

    async def callee(session: Session) -> Set[str]:
        # Declarations, parameter values and WHERE clauses are collected side by side.
        declarations, parameters, clauses = [], {}, []
        if filter.update_time_greater is not None:
            declarations.append(f"DECLARE ${self._UPDATE_TIME_GREATER_VAR} AS Uint64;")
            parameters[f"${self._UPDATE_TIME_GREATER_VAR}"] = filter.update_time_greater
            clauses.append(f"{self._updated_at_column_name} > ${self._UPDATE_TIME_GREATER_VAR}")
        if filter.update_time_less is not None:
            declarations.append(f"DECLARE ${self._UPDATE_TIME_LESS_VAR} AS Uint64;")
            parameters[f"${self._UPDATE_TIME_LESS_VAR}"] = filter.update_time_less
            clauses.append(f"{self._updated_at_column_name} < ${self._UPDATE_TIME_LESS_VAR}")
        if len(filter.origin_interface_whitelist) > 0:
            # TODO: implement whitelist once context ID is ready
            pass
        declare = declarations
        where = ""
        if clauses:
            where = f"WHERE {' AND '.join(clauses)}"
        query = f"""
PRAGMA TablePathPrefix("{self.database}");
{" ".join(declare)}
SELECT {self._id_column_name}
FROM {self.main_table}
{where};
"""  # noqa: E501
        prepared = await session.prepare(query)
        result_sets = await session.transaction(SerializableReadWrite()).execute(
            prepared, parameters, commit_tx=True
        )
        # An empty row list yields an empty set.
        return {row[self._id_column_name] for row in result_sets[0].rows}

    return await self.pool.retry_operation(callee)

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
async def callee(session: Session) -> Optional[Tuple[int, int, int, bytes, bytes]]:
query = f"""
Expand Down
11 changes: 11 additions & 0 deletions tests/context_storages/test_dbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ def configure_context_storage(
if responses_subscript is not None:
context_storage._subscripts["responses"] = responses_subscript

async def test_get_context_ids(self, db, add_context):
    """Check that the update time bounds of the context ID filter are applied."""
    await add_context("1")

    expectations = [
        ({"update_time_greater": 0}, {"1"}),
        ({"update_time_greater": 2}, set()),
        ({"update_time_less": 0}, set()),
        ({"update_time_less": 2}, {"1"}),
    ]
    for id_filter, expected_ids in expectations:
        assert await db.get_context_ids(id_filter) == expected_ids

    # TODO: implement whitelist once context ID is ready

async def test_add_context(self, db, add_context):
# test the fixture
await add_context("1")
Expand Down