diff --git a/chatsky/context_storages/database.py b/chatsky/context_storages/database.py index 2169720c0..bf9223d7a 100644 --- a/chatsky/context_storages/database.py +++ b/chatsky/context_storages/database.py @@ -14,13 +14,30 @@ from pathlib import Path from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Set, Tuple, Union -from pydantic import BaseModel, Field, field_validator, validate_call +from pydantic import BaseModel, Field + from .protocol import PROTOCOLS _SUBSCRIPT_TYPE = Union[Literal["__all__"], int, Set[str]] _SUBSCRIPT_DICT = Dict[str, Union[_SUBSCRIPT_TYPE, Literal["__none__"]]] +class ContextIdFilter(BaseModel): + update_time_greater: Optional[int] = Field(default=None) + update_time_less: Optional[int] = Field(default=None) + origin_interface_whitelist: Set[str] = Field(default_factory=set) + + def filter_keys(self, contexts: Dict[str, Tuple[int, int, int, bytes, bytes]]) -> Set[str]: + if self.update_time_greater is not None: + contexts = {k: (ti, ca, ua, m, fd) for k, (ti, ca, ua, m, fd) in contexts.items() if ua > self.update_time_greater} + if self.update_time_less is not None: + contexts = {k: (ti, ca, ua, m, fd) for k, (ti, ca, ua, m, fd) in contexts.items() if ua < self.update_time_less} + if len(self.origin_interface_whitelist) > 0: + # TODO: implement whitelist once context ID is ready + pass + return contexts.keys() + + class DBContextStorage(ABC): _main_table_name: Literal["main"] = "main" _turns_table_name: Literal["turns"] = "turns" @@ -79,6 +96,29 @@ def verifier(self: "DBContextStorage", *args, **kwargs): else: return method(self, *args, **kwargs) return verifier + + @staticmethod + def _convert_id_filter(method: Callable): + def verifier(self, *args, **kwargs): + if len(args) >= 1: + args, filter_obj = list(args[1:]), args[0] + else: + filter_obj = kwargs.pop("filter", None) + if filter_obj is None: + raise ValueError(f"For method {method.__name__} argument 'filter' is not found!") + elif 
isinstance(filter_obj, Dict): + filter_obj = ContextIdFilter.model_validate(filter_obj) + elif not isinstance(filter_obj, ContextIdFilter): + raise ValueError(f"Invalid type '{type(filter_obj).__name__}' for method '{method.__name__}' argument 'filter'!") + return method(self, *args, filter=filter_obj, **kwargs) + return verifier + + @abstractmethod + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + """ + :param filter: + """ + raise NotImplementedError @abstractmethod async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: diff --git a/chatsky/context_storages/file.py b/chatsky/context_storages/file.py index 71f9b129c..060819612 100644 --- a/chatsky/context_storages/file.py +++ b/chatsky/context_storages/file.py @@ -64,6 +64,10 @@ async def _save(self, data: SerializableStorage) -> None: async def _load(self) -> SerializableStorage: raise NotImplementedError + @DBContextStorage._convert_id_filter + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + return filter.filter_keys((await self._load()).main) + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: logger.debug(f"Loading main info for {ctx_id}...") result = (await self._load()).main.get(ctx_id, None) diff --git a/chatsky/context_storages/memory.py b/chatsky/context_storages/memory.py index 58486c614..05009b008 100644 --- a/chatsky/context_storages/memory.py +++ b/chatsky/context_storages/memory.py @@ -1,6 +1,6 @@ -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT class MemoryContextStorage(DBContextStorage): @@ -30,6 +30,10 @@ def __init__( self._responses_field_name: dict(), } + @DBContextStorage._convert_id_filter + async def 
get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + return filter.filter_keys(self._main_storage) + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: return self._main_storage.get(ctx_id, None) diff --git a/chatsky/context_storages/mongo.py b/chatsky/context_storages/mongo.py index 7daf1da15..67208bbc9 100644 --- a/chatsky/context_storages/mongo.py +++ b/chatsky/context_storages/mongo.py @@ -13,18 +13,17 @@ """ import asyncio -from typing import Dict, Set, Tuple, Optional, List +from typing import Any, Dict, Set, Tuple, Optional, List, Union try: from pymongo import UpdateOne - from pymongo.collection import Collection from motor.motor_asyncio import AsyncIOMotorClient mongo_available = True except ImportError: mongo_available = False -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT from .protocol import get_protocol_install_suggestion @@ -73,6 +72,19 @@ def __init__( ) ) + @DBContextStorage._convert_id_filter + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + ftr_dct = dict() + if filter.update_time_greater is not None: + ftr_dct.setdefault(self._updated_at_column_name, dict()).update({"$gt": filter.update_time_greater}) + if filter.update_time_less is not None: + ftr_dct.setdefault(self._updated_at_column_name, dict()).update({"$lt": filter.update_time_less}) + if len(filter.origin_interface_whitelist) > 0: + # TODO: implement whitelist once context ID is ready + pass + result = await self.main_table.find(ftr_dct, [self._id_column_name]).to_list(None) + return {item[self._id_column_name] for item in result} + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: result = await self.main_table.find_one( {self._id_column_name: ctx_id}, diff --git a/chatsky/context_storages/redis.py 
b/chatsky/context_storages/redis.py index aa3eeed1a..0f213f857 100644 --- a/chatsky/context_storages/redis.py +++ b/chatsky/context_storages/redis.py @@ -14,7 +14,7 @@ """ from asyncio import gather -from typing import Callable, List, Dict, Set, Tuple, Optional +from typing import Any, List, Dict, Set, Tuple, Optional, Union try: from redis.asyncio import Redis @@ -23,7 +23,7 @@ except ImportError: redis_available = False -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT from .protocol import get_protocol_install_suggestion @@ -74,6 +74,13 @@ def _keys_to_bytes(keys: List[int]) -> List[bytes]: def _bytes_to_keys(keys: List[bytes]) -> List[int]: return [int(f.decode("utf-8")) for f in keys] + @DBContextStorage._convert_id_filter + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + context_ids = [k.decode("utf-8") for k in await self.database.keys(f"{self._main_key}:*")] + context_upd = [int(await self.database.hget(f"{self._main_key}:{k}", self._updated_at_column_name)) for k in context_ids] + partial_contexts = {k: (None, None, ua, None, None) for k, ua in zip(context_ids, context_upd)} + return filter.filter_keys(partial_contexts) + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: if await self.database.exists(f"{self._main_key}:{ctx_id}"): cti, ca, ua, msc, fd = await gather( diff --git a/chatsky/context_storages/sql.py b/chatsky/context_storages/sql.py index d021c6123..a41bd21df 100644 --- a/chatsky/context_storages/sql.py +++ b/chatsky/context_storages/sql.py @@ -222,6 +222,19 @@ def _check_availability(self): install_suggestion = get_protocol_install_suggestion("sqlite") raise ImportError("Package `sqlalchemy` and/or `aiosqlite` is missing.\n" + install_suggestion) + @DBContextStorage._convert_id_filter + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, 
Any]]) -> Set[str]: + stmt = select(self.main_table.c[self._id_column_name]) + if filter.update_time_greater is not None: + stmt = stmt.where(self.main_table.c[self._updated_at_column_name] > filter.update_time_greater) + if filter.update_time_less is not None: + stmt = stmt.where(self.main_table.c[self._updated_at_column_name] < filter.update_time_less) + if len(filter.origin_interface_whitelist) > 0: + # TODO: implement whitelist once context ID is ready + pass + async with self.engine.begin() as conn: + return {row[0] for row in (await conn.execute(stmt)).fetchall()} + @DBContextStorage._synchronously_lock(lambda s: s.is_asynchronous) async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: logger.debug(f"Loading main info for {ctx_id}...") @@ -270,7 +283,8 @@ async def delete_context(self, ctx_id: str) -> None: async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]: logger.debug(f"Loading latest items for {ctx_id}, {field_name}...") stmt = select(self.turns_table.c[self._key_column_name], self.turns_table.c[field_name]) - stmt = stmt.where((self.turns_table.c[self._id_column_name] == ctx_id) & (self.turns_table.c[field_name] != None)) + stmt = stmt.where(self.turns_table.c[self._id_column_name] == ctx_id) + stmt = stmt.where(self.turns_table.c[field_name] != None) stmt = stmt.order_by(self.turns_table.c[self._key_column_name].desc()) if isinstance(self._subscripts[field_name], int): stmt = stmt.limit(self._subscripts[field_name]) @@ -284,8 +298,10 @@ async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[in @DBContextStorage._verify_field_name @DBContextStorage._synchronously_lock(lambda s: s.is_asynchronous) async def load_field_keys(self, ctx_id: str, field_name: str) -> List[int]: + stmt = select(self.turns_table.c[self._key_column_name]) + stmt = stmt.where(self.turns_table.c[self._id_column_name] == ctx_id) + stmt = stmt.where(self.turns_table.c[field_name] != None) logger.debug(f"Loading 
field keys for {ctx_id}, {field_name}...") - stmt = select(self.turns_table.c[self._key_column_name]).where((self.turns_table.c[self._id_column_name] == ctx_id) & (self.turns_table.c[field_name] != None)) async with self.engine.begin() as conn: result = [k[0] for k in (await conn.execute(stmt)).fetchall()] logger.debug(f"Field keys loaded for {ctx_id}, {field_name}: {collapse_num_list(result)}") @@ -296,7 +312,9 @@ async def load_field_keys(self, ctx_id: str, field_name: str) -> List[int]: async def load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[bytes]: logger.debug(f"Loading field items for {ctx_id}, {field_name} ({collapse_num_list(keys)})...") stmt = select(self.turns_table.c[self._key_column_name], self.turns_table.c[field_name]) - stmt = stmt.where((self.turns_table.c[self._id_column_name] == ctx_id) & (self.turns_table.c[self._key_column_name].in_(tuple(keys))) & (self.turns_table.c[field_name] != None)) + stmt = stmt.where(self.turns_table.c[self._id_column_name] == ctx_id) + stmt = stmt.where(self.turns_table.c[self._key_column_name].in_(tuple(keys))) + stmt = stmt.where(self.turns_table.c[field_name] != None) async with self.engine.begin() as conn: result = list((await conn.execute(stmt)).fetchall()) logger.debug(f"Field items loaded for {ctx_id}, {field_name}: {collapse_num_list([k for k, _ in result])}") diff --git a/chatsky/context_storages/ydb.py b/chatsky/context_storages/ydb.py index 71771fbb2..3779e88ec 100644 --- a/chatsky/context_storages/ydb.py +++ b/chatsky/context_storages/ydb.py @@ -12,10 +12,10 @@ from asyncio import gather, run from os.path import join -from typing import Awaitable, Callable, Set, Tuple, List, Dict, Optional +from typing import Any, Awaitable, Callable, Dict, Set, Tuple, List, Optional, Union from urllib.parse import urlsplit -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT from .protocol import 
get_protocol_install_suggestion try: @@ -55,6 +55,8 @@ class YDBContextStorage(DBContextStorage): :param table_name: The name of the table to use. """ + _UPDATE_TIME_GREATER_VAR = "update_time_greater" + _UPDATE_TIME_LESS_VAR = "update_time_less" _LIMIT_VAR = "limit" _KEY_VAR = "key" @@ -133,6 +135,36 @@ async def callee(session: Session) -> None: await self.pool.retry_operation(callee) + @DBContextStorage._convert_id_filter + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + async def callee(session: Session) -> Set[str]: + declare, prepare, conditions = list(), dict(), list() + if filter.update_time_greater is not None: + declare += [f"DECLARE ${self._UPDATE_TIME_GREATER_VAR} AS Uint64;"] + prepare.update({f"${self._UPDATE_TIME_GREATER_VAR}": filter.update_time_greater}) + conditions += [f"{self._updated_at_column_name} > ${self._UPDATE_TIME_GREATER_VAR}"] + if filter.update_time_less is not None: + declare += [f"DECLARE ${self._UPDATE_TIME_LESS_VAR} AS Uint64;"] + prepare.update({f"${self._UPDATE_TIME_LESS_VAR}": filter.update_time_less}) + conditions += [f"{self._updated_at_column_name} < ${self._UPDATE_TIME_LESS_VAR}"] + if len(filter.origin_interface_whitelist) > 0: + # TODO: implement whitelist once context ID is ready + pass + where = f"WHERE {' AND '.join(conditions)}" if len(conditions) > 0 else "" + query = f""" + PRAGMA TablePathPrefix("{self.database}"); + {" ".join(declare)} + SELECT {self._id_column_name} + FROM {self.main_table} + {where}; + """ # noqa: E501 + result_sets = await session.transaction(SerializableReadWrite()).execute( + await session.prepare(query), prepare, commit_tx=True + ) + return {e[self._id_column_name] for e in result_sets[0].rows} if len(result_sets[0].rows) > 0 else set() + + return await self.pool.retry_operation(callee) + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: async def callee(session: Session) -> Optional[Tuple[int, int, 
int, bytes, bytes]]: query = f""" diff --git a/tests/context_storages/test_dbs.py b/tests/context_storages/test_dbs.py index c7d0e3c56..349dad9ba 100644 --- a/tests/context_storages/test_dbs.py +++ b/tests/context_storages/test_dbs.py @@ -168,6 +168,17 @@ def configure_context_storage( if responses_subscript is not None: context_storage._subscripts["responses"] = responses_subscript + async def test_get_context_ids(self, db, add_context): + await add_context("1") + + assert await db.get_context_ids({"update_time_greater": 0}) == {"1"} + assert await db.get_context_ids({"update_time_greater": 2}) == set() + + assert await db.get_context_ids({"update_time_less": 0}) == set() + assert await db.get_context_ids({"update_time_less": 2}) == {"1"} + + # TODO: implement whitelist once context ID is ready + async def test_add_context(self, db, add_context): # test the fixture await add_context("1")