Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow unserializable #408

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions chatsky/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import logging
import asyncio
from uuid import UUID, uuid4
from typing import Any, Optional, Union, Dict, TYPE_CHECKING
from typing import Optional, Union, Dict, TYPE_CHECKING

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, JsonValue

from chatsky.core.message import Message, MessageInitTypes
from chatsky.slots.slots import SlotManager
Expand Down Expand Up @@ -87,7 +87,7 @@ class FrameworkData(BaseModel, arbitrary_types_allowed=True):
Instance of the pipeline that manages this context.
Can be used to obtain run configuration such as script or fallback label.
"""
stats: Dict[str, Any] = Field(default_factory=dict)
stats: Dict[str, Union[BaseModel, JsonValue]] = Field(default_factory=dict)
"Enables complex stats collection across multiple turns."
slot_manager: SlotManager = Field(default_factory=SlotManager)
"Stores extracted slots."
Expand Down Expand Up @@ -133,7 +133,7 @@ class Context(BaseModel):
First response is stored at key ``1``.
IDs go up by ``1`` after that.
"""
misc: Dict[str, Any] = Field(default_factory=dict)
misc: Dict[str, Union[BaseModel, JsonValue]] = Field(default_factory=dict)
"""
``misc`` stores any custom data. The framework doesn't use this dictionary,
so storage of any data won't reflect on the work of the internal Chatsky functions.
Expand Down
68 changes: 9 additions & 59 deletions chatsky/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,21 @@
"""

from __future__ import annotations
from typing import Literal, Optional, List, Union, Dict, Any, TYPE_CHECKING
from typing import Literal, Optional, List, Union, Dict, TYPE_CHECKING
from typing_extensions import TypeAlias, Annotated
from pathlib import Path
from urllib.request import urlopen
import uuid
import abc

from pydantic import Field, FilePath, HttpUrl, model_validator, field_validator, field_serializer
from pydantic import BaseModel, Field, FilePath, HttpUrl, JsonValue, model_validator
from pydantic_core import Url

from chatsky.utils.devel import (
json_pickle_validator,
json_pickle_serializer,
pickle_serializer,
pickle_validator,
JSONSerializableExtras,
)

if TYPE_CHECKING:
from chatsky.messengers.common.interface import MessengerInterfaceWithAttachments


class DataModel(JSONSerializableExtras):
class DataModel(BaseModel, extra="allow"):
"""
This class is a Pydantic BaseModel that can have any type and number of extras.
"""
Expand Down Expand Up @@ -290,9 +282,9 @@ class level variables to store message information.
]
]
] = None
annotations: Optional[Dict[str, Any]] = None
misc: Optional[Dict[str, Any]] = None
original_message: Optional[Any] = None
annotations: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None
misc: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None
Copy link
Member

@RLKRo RLKRo Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change the type annotation (currently `Union[BaseModel, JsonValue]`) so that it allows deeper BaseModel usage, e.g. a dictionary or a list with BaseModel values.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PydanticValue: TypeAlias = Union[
    List["PydanticValue"],
    Dict[str, "PydanticValue"],
    BaseModel,
    str,
    bool,
    int,
    float,
    None,
]

original_message: Optional[Union[BaseModel, JsonValue]] = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge dev into this branch so that the changes from #398 are included here.


def __init__( # this allows initializing Message with string as positional argument
self,
Expand All @@ -318,9 +310,9 @@ def __init__( # this allows initializing Message with string as positional argu
]
]
] = None,
annotations: Optional[Dict[str, Any]] = None,
misc: Optional[Dict[str, Any]] = None,
original_message: Optional[Any] = None,
annotations: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None,
misc: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None,
original_message: Optional[Union[BaseModel, JsonValue]] = None,
**kwargs,
):
super().__init__(
Expand All @@ -332,48 +324,6 @@ def __init__( # this allows initializing Message with string as positional argu
**kwargs,
)

@field_serializer("annotations", "misc", when_used="json")
def pickle_serialize_dicts(self, value):
"""
Serialize values that are not json-serializable via pickle.
Allows storing arbitrary data in misc/annotations when using context storages.
"""
if isinstance(value, dict):
return json_pickle_serializer(value)
return value

@field_validator("annotations", "misc", mode="before")
@classmethod
def pickle_validate_dicts(cls, value):
"""Restore values serialized with :py:meth:`pickle_serialize_dicts`."""
if isinstance(value, dict):
return json_pickle_validator(value)
return value

@field_serializer("original_message", when_used="json")
def pickle_serialize_original_message(self, value):
"""
Cast :py:attr:`original_message` to string via pickle.
Allows storing arbitrary data in this field when using context storages.
"""
if value is not None:
return pickle_serializer(value)
return value

@field_validator("original_message", mode="before")
@classmethod
def pickle_validate_original_message(cls, value):
"""
Restore :py:attr:`original_message` after being processed with
:py:meth:`pickle_serialize_original_message`.
"""
if value is not None:
return pickle_validator(value)
return value

def __str__(self) -> str:
return " ".join([f"{key}='{value}'" for key, value in self.model_dump(exclude_none=True).items()])

@model_validator(mode="before")
@classmethod
def validate_from_str(cls, data):
Expand Down
2 changes: 1 addition & 1 deletion chatsky/messengers/telegram/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ async def _on_event(self, update: Update, _: Any, create_message: Callable[[Upda
data_available = update.message is not None or update.callback_query is not None
if update.effective_chat is not None and data_available:
message = create_message(update)
message.original_message = update
message.original_message = update.to_dict(recursive=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also needs to apply to extra fields in Attachments.
Add a validator for Attachment and Message extras that converts an extra field via `to_dict` if the field's value is a `TelegramObject` instance.

AFAIK, if the extra field value is a dictionary produced by `to_dict`, it should still work with the Telegram bot methods.

resp = await self._pipeline_runner(message, update.effective_chat.id)
if resp.last_response is not None:
await self.cast_message_to_telegram_and_send(
Expand Down
38 changes: 8 additions & 30 deletions chatsky/slots/slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@
import asyncio
import re
from abc import ABC, abstractmethod
from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Optional, Dict
from typing import Callable, Awaitable, TYPE_CHECKING, Union, Optional, Dict
from typing_extensions import TypeAlias, Annotated
import logging
from functools import reduce
from string import Formatter

from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator
from pydantic import BaseModel, JsonValue, model_validator, Field

from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async
from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator

if TYPE_CHECKING:
from chatsky.core import Context, Message
Expand Down Expand Up @@ -117,29 +116,8 @@ class ExtractedValueSlot(ExtractedSlot):
"""Value extracted from :py:class:`~.ValueSlot`."""

is_slot_extracted: bool
extracted_value: Any
default_value: Any = None

@field_serializer("extracted_value", "default_value", when_used="json")
def pickle_serialize_values(self, value):
"""
Cast values to string via pickle.
Allows storing arbitrary data in these fields when using context storages.
"""
if value is not None:
return pickle_serializer(value)
return value

@field_validator("extracted_value", "default_value", mode="before")
@classmethod
def pickle_validate_values(cls, value):
"""
Restore values after being processed with
:py:meth:`pickle_serialize_values`.
"""
if value is not None:
return pickle_validator(value)
return value
extracted_value: Union[BaseModel, JsonValue]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slots store exceptions on failure, and exceptions are not serializable.
I think we should store the exception's representation instead.

default_value: Optional[Union[BaseModel, JsonValue]] = None

@property
def __slot_extracted__(self) -> bool:
Expand Down Expand Up @@ -219,10 +197,10 @@ class ValueSlot(BaseSlot, frozen=True):
Subclass it, if you want to declare your own slot type.
"""

default_value: Any = None
default_value: Union[BaseModel, JsonValue] = None

@abstractmethod
async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]:
async def extract_value(self, ctx: Context) -> Union[Union[BaseModel, JsonValue], SlotNotExtracted]:
"""
Return value extracted from context.

Expand Down Expand Up @@ -328,9 +306,9 @@ class FunctionSlot(ValueSlot, frozen=True):
Uses a user-defined `func` to extract slot value from the :py:attr:`~.Context.last_request` Message.
"""

func: Callable[[Message], Union[Awaitable[Union[Any, SlotNotExtracted]], Any, SlotNotExtracted]]
func: Callable[[Message], Union[Awaitable[Union[Union[BaseModel, JsonValue], SlotNotExtracted]], Union[BaseModel, JsonValue], SlotNotExtracted]]

async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]:
async def extract_value(self, ctx: Context) -> Union[Union[BaseModel, JsonValue], SlotNotExtracted]:
return await wrap_sync_function_in_async(self.func, ctx.last_request)


Expand Down
154 changes: 0 additions & 154 deletions chatsky/utils/devel/json_serialization.py

This file was deleted.

Loading
Loading