From 73f9ab5a2364a02ce5d7823936bbbd37ddfee50c Mon Sep 17 00:00:00 2001 From: Roman Zlobin Date: Thu, 24 Oct 2024 23:11:45 +0300 Subject: [PATCH 1/3] Reapply "key filter implementation" This reverts commit e61b1b7feac960b32994942c7aafd50eb3e077ae. --- chatsky/context_storages/database.py | 41 ++++++++++++++++++++++++++-- chatsky/context_storages/file.py | 8 ++++-- chatsky/context_storages/memory.py | 8 ++++-- chatsky/context_storages/redis.py | 9 ++++-- 4 files changed, 58 insertions(+), 8 deletions(-) diff --git a/chatsky/context_storages/database.py b/chatsky/context_storages/database.py index 563d7a175..bd319a40b 100644 --- a/chatsky/context_storages/database.py +++ b/chatsky/context_storages/database.py @@ -10,11 +10,10 @@ from abc import ABC, abstractmethod from importlib import import_module -from inspect import signature from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union -from pydantic import BaseModel, Field, field_validator, validate_call +from pydantic import BaseModel, Field from .protocol import PROTOCOLS @@ -22,6 +21,21 @@ _SUBSCRIPT_DICT = Dict[str, Union[_SUBSCRIPT_TYPE, Literal["__none__"]]] +class ContextIdFilter(BaseModel): + update_time_greater: Optional[int] = Field(default=None) + update_time_less: Optional[int] = Field(default=None) + origin_interface_whitelist: Set[str] = Field(default_factory=set) + + def filter_keys(self, keys: Set[str]) -> Set[str]: + if self.update_time_greater is not None: + keys = {k for k in keys if k > self.update_time_greater} + if self.update_time_less is not None: + keys = {k for k in keys if k < self.update_time_greater} + if len(self.origin_interface_whitelist) > 0: + keys = {k for k in keys if k in self.origin_interface_whitelist} + return keys + + class DBContextStorage(ABC): _main_table_name: Literal["main"] = "main" _turns_table_name: Literal["turns"] = "turns" @@ -72,6 +86,29 @@ def verifier(self, *args, **kwargs): else: return method(self, *args, **kwargs) return verifier + + @staticmethod + def _convert_id_filter(method: Callable): + def verifier(self, *args, **kwargs): + if len(args) >= 1: + args, filter_obj = [args[0]] + args[1:], args[1] + else: + filter_obj = kwargs.pop("filter", None) + if filter_obj is None: + raise ValueError(f"For method {method.__name__} argument 'filter' is not found!") + elif isinstance(filter_obj, Dict): + filter_obj = ContextIdFilter.validate_model(filter_obj) + elif not isinstance(filter_obj, ContextIdFilter): + raise ValueError(f"Invalid type '{type(filter_obj).__name__}' for method '{method.__name__}' argument 'filter'!") + return method(self, *args, filter=filter_obj, **kwargs) + return verifier + + @abstractmethod + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> List[str]: + """ + :param filter: + """ + raise NotImplementedError @abstractmethod async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: diff --git a/chatsky/context_storages/file.py b/chatsky/context_storages/file.py index 8cebb9118..d1ca0b853 100644 --- a/chatsky/context_storages/file.py +++ b/chatsky/context_storages/file.py @@ -10,11 +10,11 @@ import asyncio from pickle import loads, dumps from shelve import DbfilenameShelf -from typing import List, Set, Tuple, Dict, Optional +from typing import Any, List, Set, Tuple, Dict, Optional, Union from pydantic import BaseModel, Field -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT try: from aiofiles import open @@ -61,6 +61,10 @@ async def _save(self, data: SerializableStorage) -> None: async def _load(self) -> SerializableStorage: raise NotImplementedError + @DBContextStorage._verify_field_name + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + return filter.filter_keys(set((await self._load()).main.keys())) + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: return (await self._load()).main.get(ctx_id, None) diff --git a/chatsky/context_storages/memory.py b/chatsky/context_storages/memory.py index b8bbb2e71..805310d53 100644 --- a/chatsky/context_storages/memory.py +++ b/chatsky/context_storages/memory.py @@ -1,6 +1,6 @@ -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT class MemoryContextStorage(DBContextStorage): @@ -32,6 +32,10 @@ def __init__( self._responses_field_name: dict(), } + @DBContextStorage._verify_field_name + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + return filter.filter_keys(set(self._main_storage.keys())) + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: return self._main_storage.get(ctx_id, None) diff --git a/chatsky/context_storages/redis.py b/chatsky/context_storages/redis.py index 99e57ad7f..bf4fcea37 100644 --- a/chatsky/context_storages/redis.py +++ b/chatsky/context_storages/redis.py @@ -14,7 +14,7 @@ """ from asyncio import gather -from typing import Callable, List, Dict, Set, Tuple, Optional +from typing import Any, List, Dict, Set, Tuple, Optional, Union try: from redis.asyncio import Redis @@ -23,7 +23,7 @@ except ImportError: redis_available = False -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT from .protocol import get_protocol_install_suggestion @@ -76,6 +76,11 @@ def _keys_to_bytes(keys: List[int]) -> List[bytes]: def _bytes_to_keys(keys: List[bytes]) -> List[int]: return [int(f.decode("utf-8")) for f in keys] + @DBContextStorage._verify_field_name + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + context_ids = {k.decode("utf-8") for k in await self.database.keys(f"{self._main_key}:*")} + return filter.filter_keys(context_ids) + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: if await self.database.exists(f"{self._main_key}:{ctx_id}"): cti, ca, ua, msc, fd = await gather( From 3531008aa751a628b9372eb873f9e4ea1aaf4987 Mon Sep 17 00:00:00 2001 From: pseusys Date: Sat, 2 Nov 2024 18:51:01 +0800 Subject: [PATCH 2/3] filtering function finished --- chatsky/context_storages/database.py | 11 +++++----- chatsky/context_storages/file.py | 3 +-- chatsky/context_storages/memory.py | 3 +-- chatsky/context_storages/mongo.py | 17 +++++++++++++--- chatsky/context_storages/redis.py | 7 ++++--- chatsky/context_storages/sql.py | 27 ++++++++++++++++++++----- chatsky/context_storages/ydb.py | 30 ++++++++++++++++++++++++++-- 7 files changed, 76 insertions(+), 22 deletions(-) diff --git a/chatsky/context_storages/database.py b/chatsky/context_storages/database.py index bd319a40b..bf3f75b57 100644 --- a/chatsky/context_storages/database.py +++ b/chatsky/context_storages/database.py @@ -26,14 +26,15 @@ class ContextIdFilter(BaseModel): update_time_less: Optional[int] = Field(default=None) origin_interface_whitelist: Set[str] = Field(default_factory=set) - def filter_keys(self, keys: Set[str]) -> Set[str]: + def filter_keys(self, contexts: Dict[str, Tuple[int, int, int, bytes, bytes]]) -> Set[str]: if self.update_time_greater is not None: - keys = {k for k in keys if k > self.update_time_greater} + contexts = {k: (ti, ca, ua, m, fd) for k, (ti, ca, ua, m, fd) in contexts.items() if ua > self.update_time_greater} if self.update_time_less is not None: - keys = {k for k in keys if k < self.update_time_greater} + contexts = {k: (ti, ca, ua, m, fd) for k, (ti, ca, ua, m, fd) in contexts.items() if ua < self.update_time_less} if len(self.origin_interface_whitelist) > 0: - keys = {k for k in keys if k in self.origin_interface_whitelist} - return keys + # TODO: implement whitelist once context ID is + pass + return contexts.keys() class DBContextStorage(ABC): diff --git a/chatsky/context_storages/file.py b/chatsky/context_storages/file.py index d1ca0b853..ae82544fb 100644 --- a/chatsky/context_storages/file.py +++ b/chatsky/context_storages/file.py @@ -61,9 +61,8 @@ async def _save(self, data: SerializableStorage) -> None: async def _load(self) -> SerializableStorage: raise NotImplementedError - @DBContextStorage._verify_field_name async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: - return filter.filter_keys(set((await self._load()).main.keys())) + return filter.filter_keys((await self._load()).main) async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: return (await self._load()).main.get(ctx_id, None) diff --git a/chatsky/context_storages/memory.py b/chatsky/context_storages/memory.py index 805310d53..870d3885e 100644 --- a/chatsky/context_storages/memory.py +++ b/chatsky/context_storages/memory.py @@ -32,9 +32,8 @@ def __init__( self._responses_field_name: dict(), } - @DBContextStorage._verify_field_name async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: - return filter.filter_keys(set(self._main_storage.keys())) + return filter.filter_keys(self._main_storage) async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: return self._main_storage.get(ctx_id, None) diff --git a/chatsky/context_storages/mongo.py b/chatsky/context_storages/mongo.py index c1e01ddbd..00ca560fa 100644 --- a/chatsky/context_storages/mongo.py +++ b/chatsky/context_storages/mongo.py @@ -13,18 +13,17 @@ """ import asyncio -from typing import Dict, Set, Tuple, Optional, List +from typing import Any, Dict, Set, Tuple, Optional, List, Union try: from pymongo import UpdateOne - from pymongo.collection import Collection from motor.motor_asyncio import AsyncIOMotorClient mongo_available = True except ImportError: mongo_available = False -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT from .protocol import get_protocol_install_suggestion @@ -75,6 +74,18 @@ def __init__( ) ) + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + ftr_dct = dict() + if filter.update_time_greater is not None: + ftr_dct.setdefault(self._updated_at_column_name, dict()).update({"$gt": filter.update_time_greater}) + if filter.update_time_less is not None: + ftr_dct.setdefault(self._updated_at_column_name, dict()).update({"$lt": filter.update_time_less}) + if len(filter.origin_interface_whitelist) > 0: + # TODO: implement whitelist once context ID is + pass + result = await self.main_table.find(ftr_dct, [self._id_column_name]).to_list(None) + return {item[self._key_column_name] for item in result} + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: result = await self.main_table.find_one( {self._id_column_name: ctx_id}, diff --git a/chatsky/context_storages/redis.py b/chatsky/context_storages/redis.py index bf4fcea37..d8920bc0b 100644 --- a/chatsky/context_storages/redis.py +++ b/chatsky/context_storages/redis.py @@ -76,10 +76,11 @@ def _keys_to_bytes(keys: List[int]) -> List[bytes]: def _bytes_to_keys(keys: List[bytes]) -> List[int]: return [int(f.decode("utf-8")) for f in keys] - @DBContextStorage._verify_field_name async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: - context_ids = {k.decode("utf-8") for k in await self.database.keys(f"{self._main_key}:*")} - return filter.filter_keys(context_ids) + context_ids = [k.decode("utf-8") for k in await self.database.keys(f"{self._main_key}:*")] + context_upd = [int(await self.database.hget(f"{self._main_key}:{k}", self._updated_at_column_name)) for k in context_ids] + partial_contexts = {k: (None, None, ua, None, None) for k, ua in zip(context_ids, context_upd)} + return filter.filter_keys(partial_contexts) async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: if await self.database.exists(f"{self._main_key}:{ctx_id}"): diff --git a/chatsky/context_storages/sql.py b/chatsky/context_storages/sql.py index dc0b0fb8d..3bedfc5a4 100644 --- a/chatsky/context_storages/sql.py +++ b/chatsky/context_storages/sql.py @@ -17,9 +17,9 @@ import asyncio from importlib import import_module from os import getenv -from typing import Callable, Collection, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Union -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT from .protocol import get_protocol_install_suggestion try: @@ -214,6 +214,18 @@ def _check_availability(self): install_suggestion = get_protocol_install_suggestion("sqlite") raise ImportError("Package `sqlalchemy` and/or `aiosqlite` is missing.\n" + install_suggestion) + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + stmt = select(self.main_table.c[self._id_column_name]) + if filter.update_time_greater is not None: + stmt.where(self.main_table.c[self._updated_at_column_name] > filter.update_time_greater) + if filter.update_time_less is not None: + stmt.where(self.main_table.c[self._updated_at_column_name] < filter.update_time_less) + if len(filter.origin_interface_whitelist) > 0: + # TODO: implement whitelist once context ID is + pass + async with self.engine.begin() as conn: + return set((await conn.execute(stmt)).fetchone()) + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: stmt = select(self.main_table).where(self.main_table.c[self._id_column_name] == ctx_id) async with self.engine.begin() as conn: @@ -251,7 +263,8 @@ async def delete_context(self, ctx_id: str) -> None: @DBContextStorage._verify_field_name async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]: stmt = select(self.turns_table.c[self._key_column_name], self.turns_table.c[field_name]) - stmt = stmt.where((self.turns_table.c[self._id_column_name] == ctx_id) & (self.turns_table.c[field_name] != None)) + stmt = stmt.where(self.turns_table.c[self._id_column_name] == ctx_id) + stmt = stmt.where(self.turns_table.c[field_name] != None) stmt = stmt.order_by(self.turns_table.c[self._key_column_name].desc()) if isinstance(self._subscripts[field_name], int): stmt = stmt.limit(self._subscripts[field_name]) @@ -262,14 +275,18 @@ async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[in @DBContextStorage._verify_field_name async def load_field_keys(self, ctx_id: str, field_name: str) -> List[int]: - stmt = select(self.turns_table.c[self._key_column_name]).where((self.turns_table.c[self._id_column_name] == ctx_id) & (self.turns_table.c[field_name] != None)) + stmt = select(self.turns_table.c[self._key_column_name]) + stmt = stmt.where(self.turns_table.c[self._id_column_name] == ctx_id) + stmt = stmt.where(self.turns_table.c[field_name] != None) async with self.engine.begin() as conn: return [k[0] for k in (await conn.execute(stmt)).fetchall()] @DBContextStorage._verify_field_name async def load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[bytes]: stmt = select(self.turns_table.c[self._key_column_name], self.turns_table.c[field_name]) - stmt = stmt.where((self.turns_table.c[self._id_column_name] == ctx_id) & (self.turns_table.c[self._key_column_name].in_(tuple(keys))) & (self.turns_table.c[field_name] != None)) + stmt = stmt.where(self.turns_table.c[self._id_column_name] == ctx_id) + stmt = stmt.where(self.turns_table.c[self._key_column_name].in_(tuple(keys))) + stmt = stmt.where(self.turns_table.c[field_name] != None) async with self.engine.begin() as conn: return list((await conn.execute(stmt)).fetchall()) diff --git a/chatsky/context_storages/ydb.py b/chatsky/context_storages/ydb.py index 34fd063fe..4a1680d36 100644 --- a/chatsky/context_storages/ydb.py +++ b/chatsky/context_storages/ydb.py @@ -12,10 +12,10 @@ from asyncio import gather, run from os.path import join -from typing import Awaitable, Callable, Set, Tuple, List, Dict, Optional +from typing import Any, Awaitable, Callable, Dict, Set, Tuple, List, Optional, Union from urllib.parse import urlsplit -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT from .protocol import get_protocol_install_suggestion try: @@ -132,6 +132,32 @@ async def callee(session: Session) -> None: await self.pool.retry_operation(callee) + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + async def callee(session: Session) -> Set[str]: + where_stmt = "" + conditions = list() + if filter.update_time_greater is not None: + conditions += [f"{self._updated_at_column_name} > {filter.update_time_greater}"] + if filter.update_time_less is not None: + conditions += [f"{self._updated_at_column_name} < {filter.update_time_less}"] + if len(filter.origin_interface_whitelist) > 0: + # TODO: implement whitelist once context ID is + pass + if len(conditions) > 0: + where_stmt = f"WHERE {' AND '.join(conditions)}" + query = f""" + PRAGMA TablePathPrefix("{self.database}"); + SELECT {self._id_column_name} + FROM {self.main_table} + {where_stmt}; + """ # noqa: E501 + result_sets = await session.transaction(SerializableReadWrite()).execute( + await session.prepare(query), dict(), commit_tx=True + ) + return {e[self._id_column_name] for e in result_sets[0].rows} if len(result_sets[0].rows) > 0 else set() + + return await self.pool.retry_operation(callee) + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: async def callee(session: Session) -> Optional[Tuple[int, int, int, bytes, bytes]]: query = f""" From de2632f29f48350360112cdf94bddb42ab676657 Mon Sep 17 00:00:00 2001 From: pseusys Date: Tue, 12 Nov 2024 20:39:57 +0800 Subject: [PATCH 3/3] get_context_ids tests added --- chatsky/context_storages/database.py | 10 +++++----- chatsky/context_storages/file.py | 1 + chatsky/context_storages/memory.py | 1 + chatsky/context_storages/mongo.py | 3 ++- chatsky/context_storages/redis.py | 1 + chatsky/context_storages/sql.py | 3 ++- chatsky/context_storages/ydb.py | 25 ++++++++++++++++--------- tests/context_storages/test_dbs.py | 11 +++++++++++ 8 files changed, 39 insertions(+), 16 deletions(-) diff --git a/chatsky/context_storages/database.py b/chatsky/context_storages/database.py index bf3f75b57..7c14d9c28 100644 --- a/chatsky/context_storages/database.py +++ b/chatsky/context_storages/database.py @@ -32,7 +32,7 @@ def filter_keys(self, contexts: Dict[str, Tuple[int, int, int, bytes, bytes]]) - if self.update_time_less is not None: contexts = {k: (ti, ca, ua, m, fd) for k, (ti, ca, ua, m, fd) in contexts.items() if ua < self.update_time_less} if len(self.origin_interface_whitelist) > 0: - # TODO: implement whitelist once context ID is + # TODO: implement whitelist once context ID is ready pass return contexts.keys() @@ -79,7 +79,7 @@ def __init__( @staticmethod def _verify_field_name(method: Callable): def verifier(self, *args, **kwargs): - field_name = args[1] if len(args) >= 1 else kwargs.get("field_name", None) + field_name = args[1] if len(args) >= 2 else kwargs.get("field_name", None) if field_name is None: raise ValueError(f"For method {method.__name__} argument 'field_name' is not found!") elif field_name not in (self._labels_field_name, self._requests_field_name, self._responses_field_name): @@ -92,20 +92,20 @@ def verifier(self, *args, **kwargs): def _convert_id_filter(method: Callable): def verifier(self, *args, **kwargs): if len(args) >= 1: - args, filter_obj = [args[0]] + args[1:], args[1] + args, filter_obj = list(args[1:]), args[0] else: filter_obj = kwargs.pop("filter", None) if filter_obj is None: raise ValueError(f"For method {method.__name__} argument 'filter' is not found!") elif isinstance(filter_obj, Dict): - filter_obj = ContextIdFilter.validate_model(filter_obj) + filter_obj = ContextIdFilter.model_validate(filter_obj) elif not isinstance(filter_obj, ContextIdFilter): raise ValueError(f"Invalid type '{type(filter_obj).__name__}' for method '{method.__name__}' argument 'filter'!") return method(self, *args, filter=filter_obj, **kwargs) return verifier @abstractmethod - async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> List[str]: + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: """ :param filter: """ diff --git a/chatsky/context_storages/file.py b/chatsky/context_storages/file.py index ae82544fb..5d7215abf 100644 --- a/chatsky/context_storages/file.py +++ b/chatsky/context_storages/file.py @@ -61,6 +61,7 @@ async def _save(self, data: SerializableStorage) -> None: async def _load(self) -> SerializableStorage: raise NotImplementedError + @DBContextStorage._convert_id_filter async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: return filter.filter_keys((await self._load()).main) diff --git a/chatsky/context_storages/memory.py b/chatsky/context_storages/memory.py index 870d3885e..55efde037 100644 --- a/chatsky/context_storages/memory.py +++ b/chatsky/context_storages/memory.py @@ -32,6 +32,7 @@ def __init__( self._responses_field_name: dict(), } + @DBContextStorage._convert_id_filter async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: return filter.filter_keys(self._main_storage) diff --git a/chatsky/context_storages/mongo.py b/chatsky/context_storages/mongo.py index 00ca560fa..2788ff420 100644 --- a/chatsky/context_storages/mongo.py +++ b/chatsky/context_storages/mongo.py @@ -74,6 +74,7 @@ def __init__( ) ) + @DBContextStorage._convert_id_filter async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: ftr_dct = dict() if filter.update_time_greater is not None: @@ -81,7 +82,7 @@ async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) if filter.update_time_less is not None: ftr_dct.setdefault(self._updated_at_column_name, dict()).update({"$lt": filter.update_time_less}) if len(filter.origin_interface_whitelist) > 0: - # TODO: implement whitelist once context ID is + # TODO: implement whitelist once context ID is ready pass result = await self.main_table.find(ftr_dct, [self._id_column_name]).to_list(None) return {item[self._key_column_name] for item in result} diff --git a/chatsky/context_storages/redis.py b/chatsky/context_storages/redis.py index d8920bc0b..0a3bda341 100644 --- a/chatsky/context_storages/redis.py +++ b/chatsky/context_storages/redis.py @@ -76,6 +76,7 @@ def _keys_to_bytes(keys: List[int]) -> List[bytes]: def _bytes_to_keys(keys: List[bytes]) -> List[int]: return [int(f.decode("utf-8")) for f in keys] + @DBContextStorage._convert_id_filter async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: context_ids = [k.decode("utf-8") for k in await self.database.keys(f"{self._main_key}:*")] context_upd = [int(await self.database.hget(f"{self._main_key}:{k}", self._updated_at_column_name)) for k in context_ids] diff --git a/chatsky/context_storages/sql.py b/chatsky/context_storages/sql.py index 3bedfc5a4..f1f666f19 100644 --- a/chatsky/context_storages/sql.py +++ b/chatsky/context_storages/sql.py @@ -214,6 +214,7 @@ def _check_availability(self): install_suggestion = get_protocol_install_suggestion("sqlite") raise ImportError("Package `sqlalchemy` and/or `aiosqlite` is missing.\n" + install_suggestion) + @DBContextStorage._convert_id_filter async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: stmt = select(self.main_table.c[self._id_column_name]) if filter.update_time_greater is not None: @@ -221,7 +222,7 @@ async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) if filter.update_time_less is not None: stmt.where(self.main_table.c[self._updated_at_column_name] < filter.update_time_less) if len(filter.origin_interface_whitelist) > 0: - # TODO: implement whitelist once context ID is + # TODO: implement whitelist once context ID is ready pass async with self.engine.begin() as conn: return set((await conn.execute(stmt)).fetchone()) diff --git a/chatsky/context_storages/ydb.py b/chatsky/context_storages/ydb.py index 4a1680d36..1b0b20d19 100644 --- a/chatsky/context_storages/ydb.py +++ b/chatsky/context_storages/ydb.py @@ -55,6 +55,9 @@ class YDBContextStorage(DBContextStorage): :param table_name: The name of the table to use. """ + _UPDATE_TIME_GREATER_VAR = "update_time_greater" + _UPDATE_TIME_LESS_VAR = "update_time_less" + is_asynchronous = True def __init__( @@ -132,27 +135,31 @@ async def callee(session: Session) -> None: await self.pool.retry_operation(callee) + @DBContextStorage._convert_id_filter async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: async def callee(session: Session) -> Set[str]: - where_stmt = "" - conditions = list() + declare, prepare, conditions = list(), dict(), list() if filter.update_time_greater is not None: - conditions += [f"{self._updated_at_column_name} > {filter.update_time_greater}"] + declare += [f"DECLARE ${self._UPDATE_TIME_GREATER_VAR} AS Uint64;"] + prepare.update({f"${self._UPDATE_TIME_GREATER_VAR}": filter.update_time_greater}) + conditions += [f"{self._updated_at_column_name} > ${self._UPDATE_TIME_GREATER_VAR}"] if filter.update_time_less is not None: - conditions += [f"{self._updated_at_column_name} < {filter.update_time_less}"] + declare += [f"DECLARE ${self._UPDATE_TIME_LESS_VAR} AS Uint64;"] + prepare.update({f"${self._UPDATE_TIME_LESS_VAR}": filter.update_time_less}) + conditions += [f"{self._updated_at_column_name} < ${self._UPDATE_TIME_LESS_VAR}"] if len(filter.origin_interface_whitelist) > 0: - # TODO: implement whitelist once context ID is + # TODO: implement whitelist once context ID is ready pass - if len(conditions) > 0: - where_stmt = f"WHERE {' AND '.join(conditions)}" + where = f"WHERE {' AND '.join(conditions)}" if len(conditions) > 0 else "" query = f""" PRAGMA TablePathPrefix("{self.database}"); + {" ".join(declare)} SELECT {self._id_column_name} FROM {self.main_table} - {where_stmt}; + {where}; """ # noqa: E501 result_sets = await session.transaction(SerializableReadWrite()).execute( - await session.prepare(query), dict(), commit_tx=True + await session.prepare(query), prepare, commit_tx=True ) return {e[self._id_column_name] for e in result_sets[0].rows} if len(result_sets[0].rows) > 0 else set() diff --git a/tests/context_storages/test_dbs.py b/tests/context_storages/test_dbs.py index 67629abc6..f72aa47d4 100644 --- a/tests/context_storages/test_dbs.py +++ b/tests/context_storages/test_dbs.py @@ -167,6 +167,17 @@ def configure_context_storage( if responses_subscript is not None: context_storage._subscripts["responses"] = responses_subscript + async def test_get_context_ids(self, db, add_context): + await add_context("1") + + assert await db.get_context_ids({"update_time_greater": 0}) == {"1"} + assert await db.get_context_ids({"update_time_greater": 2}) == set() + + assert await db.get_context_ids({"update_time_less": 0}) == set() + assert await db.get_context_ids({"update_time_less": 2}) == {"1"} + + # TODO: implement whitelist once context ID is ready + async def test_add_context(self, db, add_context): # test the fixture await add_context("1")