Skip to content

Commit

Permalink
Feat/script parser (#385)
Browse files Browse the repository at this point in the history
# Description

Added option to initialize Pipeline from yaml/json files.

## Changelog

### Breaking changes

- Removed `proc.ExtractAll` -- the function is unsafe as it overwrites
the entire slot storage. it is still available as method of the slot
manager

### Features

- Add Pipeline from file import
- Add function that creates an index of commonly-used Chatsky objects
- Added imports to some `__init__` files
- Allow initializing NodeLabel from a list of two strings

### Bug fixes

- Slot extraction will now not write the value to the slot storage if
value was not successfully extracted. Can be changed via the
`success_only` flag

### Devel
- Add aliases to script keywords

# Checklist

- [x] I have performed a self-review of the changes

# To Consider

- Add tests (if functionality is changed)
- Update API reference / tutorials / guides
- Update CONTRIBUTING.md (if devel workflow is changed)
- Update `.ignore` files, scripts (such as `lint`), distribution
manifest (if files are added/deleted)
- Search for references to changed entities in the codebase

---------

Co-authored-by: Ramimashkouk <[email protected]>
Co-authored-by: Ramimashkouk <[email protected]>
  • Loading branch information
3 people authored Sep 7, 2024
1 parent 029838c commit 40981ef
Show file tree
Hide file tree
Showing 36 changed files with 1,100 additions and 66 deletions.
1 change: 1 addition & 0 deletions chatsky/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@
import chatsky.responses as rsp
import chatsky.processing as proc


import chatsky.__rebuild_pydantic_models__
22 changes: 21 additions & 1 deletion chatsky/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,27 @@
"""

from chatsky.core.context import Context
from chatsky.core.message import Message, MessageInitTypes
from chatsky.core.message import (
Message,
MessageInitTypes,
Attachment,
CallbackQuery,
Location,
Contact,
Invoice,
PollOption,
Poll,
DataAttachment,
Audio,
Video,
Animation,
Image,
Sticker,
Document,
VoiceMessage,
VideoMessage,
MediaGroup,
)
from chatsky.core.pipeline import Pipeline
from chatsky.core.script import Node, Flow, Script
from chatsky.core.script_function import BaseCondition, BaseResponse, BaseDestination, BaseProcessing, BasePriority
Expand Down
8 changes: 6 additions & 2 deletions chatsky/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
It only contains types and properties that are compatible with most messaging services.
"""

from typing import Literal, Optional, List, Union, Dict, Any
from __future__ import annotations
from typing import Literal, Optional, List, Union, Dict, Any, TYPE_CHECKING
from typing_extensions import TypeAlias, Annotated
from pathlib import Path
from urllib.request import urlopen
Expand All @@ -16,7 +17,6 @@
from pydantic import Field, FilePath, HttpUrl, model_validator, field_validator, field_serializer
from pydantic_core import Url

from chatsky.messengers.common.interface import MessengerInterfaceWithAttachments
from chatsky.utils.devel import (
json_pickle_validator,
json_pickle_serializer,
Expand All @@ -25,6 +25,9 @@
JSONSerializableExtras,
)

if TYPE_CHECKING:
from chatsky.messengers.common.interface import MessengerInterfaceWithAttachments


class DataModel(JSONSerializableExtras):
"""
Expand Down Expand Up @@ -283,6 +286,7 @@ class level variables to store message information.
VoiceMessage,
VideoMessage,
MediaGroup,
DataModel,
]
]
] = None
Expand Down
12 changes: 8 additions & 4 deletions chatsky/core/node_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from __future__ import annotations

from typing import Optional, Union, Tuple, TYPE_CHECKING
from typing import Optional, Union, Tuple, List, TYPE_CHECKING
from typing_extensions import TypeAlias, Annotated

from pydantic import BaseModel, model_validator, ValidationInfo
Expand Down Expand Up @@ -47,26 +47,29 @@ def validate_from_str_or_tuple(cls, data, info: ValidationInfo):
Allow instantiating of this class from:
- A single string (node name). Also attempt to get the current flow name from context.
- A tuple of two strings (flow and node name).
- A tuple or list of two strings (flow and node name).
"""
if isinstance(data, str):
flow_name = None
context = info.context
if isinstance(context, dict):
flow_name = _get_current_flow_name(context.get("ctx"))
return {"flow_name": flow_name, "node_name": data}
elif isinstance(data, tuple):
elif isinstance(data, (tuple, list)):
if len(data) == 2 and isinstance(data[0], str) and isinstance(data[1], str):
return {"flow_name": data[0], "node_name": data[1]}
else:
raise ValueError(f"Cannot validate NodeLabel from {data!r}: tuple should contain 2 strings.")
raise ValueError(
f"Cannot validate NodeLabel from {data!r}: {type(data).__name__} should contain 2 strings."
)
return data


NodeLabelInitTypes: TypeAlias = Union[
NodeLabel,
Annotated[str, "node_name, flow name equal to current flow's name"],
Tuple[Annotated[str, "flow_name"], Annotated[str, "node_name"]],
Annotated[List[str], "list of two strings (flow_name and node_name)"],
Annotated[dict, "dict following the NodeLabel data model"],
]
"""Types that :py:class:`~.NodeLabel` can be validated from."""
Expand Down Expand Up @@ -124,6 +127,7 @@ def check_node_exists(self, info: ValidationInfo):
AbsoluteNodeLabel,
NodeLabel,
Tuple[Annotated[str, "flow_name"], Annotated[str, "node_name"]],
Annotated[List[str], "list of two strings (flow_name and node_name)"],
Annotated[dict, "dict following the AbsoluteNodeLabel data model"],
]
"""Types that :py:class:`~.AbsoluteNodeLabel` can be validated from."""
26 changes: 26 additions & 0 deletions chatsky/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .utils import finalize_service_group
from chatsky.core.service.actor import Actor
from chatsky.core.node_label import AbsoluteNodeLabel, AbsoluteNodeLabelInitTypes
from chatsky.core.script_parsing import JSONImporter, Path

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -167,6 +168,31 @@ def __init__(
super().__init__(**init_dict)
self.services_pipeline # cache services

@classmethod
def from_file(
cls,
file: Union[str, Path],
custom_dir: Union[str, Path] = "custom",
**overrides,
) -> "Pipeline":
"""
Create Pipeline by importing it from a file.
A file (json or yaml) should contain a dictionary with keys being a subset of pipeline init parameters.
See :py:meth:`.JSONImporter.import_pipeline_file` for more information.
:param file: Path to a file containing pipeline init parameters.
:param custom_dir: Path to a directory containing custom code.
Defaults to "./custom".
If ``file`` does not use custom code, this parameter will not have any effect.
:param overrides: You can pass init parameters to override those imported from the ``file``.
"""
pipeline = JSONImporter(custom_dir=custom_dir).import_pipeline_file(file)

pipeline.update(overrides)

return cls(**pipeline)

@computed_field
@cached_property
def actor(self) -> Actor:
Expand Down
42 changes: 26 additions & 16 deletions chatsky/core/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import logging
from typing import List, Optional, Dict

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, AliasChoices

from chatsky.core.script_function import AnyResponse, BaseProcessing
from chatsky.core.node_label import AbsoluteNodeLabel
Expand All @@ -21,28 +21,34 @@
logger = logging.getLogger(__name__)


class Node(BaseModel):
class Node(BaseModel, extra="forbid"):
"""
Node is a basic element of the dialog graph.
Usually used to represent a specific state of a conversation.
"""

transitions: List[Transition] = Field(default_factory=list)
transitions: List[Transition] = Field(
validation_alias=AliasChoices("transitions", "TRANSITIONS"), default_factory=list
)
"""List of transitions possible from this node."""
response: Optional[AnyResponse] = Field(default=None)
response: Optional[AnyResponse] = Field(validation_alias=AliasChoices("response", "RESPONSE"), default=None)
"""Response produced when this node is entered."""
pre_transition: Dict[str, BaseProcessing] = Field(default_factory=dict)
pre_transition: Dict[str, BaseProcessing] = Field(
validation_alias=AliasChoices("pre_transition", "PRE_TRANSITION"), default_factory=dict
)
"""
A dictionary of :py:class:`.BaseProcessing` functions that are executed before transitions are processed.
Keys of the dictionary act as names for the processing functions.
"""
pre_response: Dict[str, BaseProcessing] = Field(default_factory=dict)
pre_response: Dict[str, BaseProcessing] = Field(
validation_alias=AliasChoices("pre_response", "PRE_RESPONSE"), default_factory=dict
)
"""
A dictionary of :py:class:`.BaseProcessing` functions that are executed before response is processed.
Keys of the dictionary act as names for the processing functions.
"""
misc: dict = Field(default_factory=dict)
misc: dict = Field(validation_alias=AliasChoices("misc", "MISC"), default_factory=dict)
"""
A dictionary that is used to store metadata about the node.
Expand Down Expand Up @@ -72,7 +78,9 @@ class Flow(BaseModel, extra="allow"):
This is used to group them by a specific purpose.
"""

local_node: Node = Field(alias="local", default_factory=Node)
local_node: Node = Field(
validation_alias=AliasChoices("local", "LOCAL", "local_node", "LOCAL_NODE"), default_factory=Node
)
"""Node from which all other nodes in this Flow inherit properties according to :py:meth:`Node.merge`."""
__pydantic_extra__: Dict[str, Node]

Expand Down Expand Up @@ -100,7 +108,9 @@ class Script(BaseModel, extra="allow"):
It represents an entire dialog graph.
"""

global_node: Node = Field(alias="global", default_factory=Node)
global_node: Node = Field(
validation_alias=AliasChoices("global", "GLOBAL", "global_node", "GLOBAL_NODE"), default_factory=Node
)
"""Node from which all other nodes in this Script inherit properties according to :py:meth:`Node.merge`."""
__pydantic_extra__: Dict[str, Flow]

Expand Down Expand Up @@ -157,17 +167,17 @@ def get_inherited_node(self, label: AbsoluteNodeLabel) -> Optional[Node]:
return inheritant_node.merge(self.global_node).merge(flow.local_node).merge(node)


GLOBAL = "global"
GLOBAL = "GLOBAL"
"""Key for :py:attr:`~chatsky.core.script.Script.global_node`."""
LOCAL = "local"
LOCAL = "LOCAL"
"""Key for :py:attr:`~chatsky.core.script.Flow.local_node`."""
TRANSITIONS = "transitions"
TRANSITIONS = "TRANSITIONS"
"""Key for :py:attr:`~chatsky.core.script.Node.transitions`."""
RESPONSE = "response"
RESPONSE = "RESPONSE"
"""Key for :py:attr:`~chatsky.core.script.Node.response`."""
MISC = "misc"
MISC = "MISC"
"""Key for :py:attr:`~chatsky.core.script.Node.misc`."""
PRE_RESPONSE = "pre_response"
PRE_RESPONSE = "PRE_RESPONSE"
"""Key for :py:attr:`~chatsky.core.script.Node.pre_response`."""
PRE_TRANSITION = "pre_transition"
PRE_TRANSITION = "PRE_TRANSITION"
"""Key for :py:attr:`~chatsky.core.script.Node.pre_transition`."""
Loading

0 comments on commit 40981ef

Please sign in to comment.