-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update API #253
Update API #253
Changes from 7 commits
a110e62
bd67607
12eed6e
43533b5
bb42016
039840b
5dc75a1
b274383
dda1bc0
2d63c03
e2d8377
5df3c55
34c7d1a
6378aa5
84947ec
4878de9
95a5459
8b0db0c
3cb2b77
ed126ab
82a8c32
2a39d80
51a8624
63ba0f7
7b44375
1afa571
58fa9aa
cd41155
3c70317
74ea26d
223b284
ce3a8a2
48b6934
d8075f0
6e0307f
1079a51
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,13 +16,9 @@ | |
|
||
Overall, the actor acts as a bridge between the user's input and the dialog graph, | ||
making sure that the conversation follows the expected flow and providing a personalized experience to the user. | ||
|
||
Below you can see a diagram of user request processing with Actor. | ||
Both `request` and `response` are saved to :py:class:`.Context`. | ||
|
||
.. figure:: /_static/drawio/dfe/user_actor.png | ||
""" | ||
import logging | ||
import asyncio | ||
from typing import Union, Callable, Optional, Dict, List, Any, ForwardRef | ||
import copy | ||
|
||
|
@@ -34,6 +30,7 @@ | |
from dff.script.core.script import Script, Node | ||
from dff.script.core.normalization import normalize_label, normalize_response | ||
from dff.script.core.keywords import GLOBAL, LOCAL | ||
from ..service.utils import wrap_sync_function_in_async | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -109,51 +106,51 @@ def __init__( | |
# NB! The following API is highly experimental and may be removed at ANY time WITHOUT FURTHER NOTICE!! | ||
self._clean_turn_cache = True | ||
|
||
def __call__( | ||
async def __call__( | ||
self, pipeline: Pipeline, ctx: Optional[Union[Context, dict, str]] = None, *args, **kwargs | ||
) -> Union[Context, dict, str]: | ||
# context init | ||
ctx = self._context_init(ctx, *args, **kwargs) | ||
self._run_handlers(ctx, pipeline, ActorStage.CONTEXT_INIT, *args, **kwargs) | ||
await self._run_handlers(ctx, pipeline, ActorStage.CONTEXT_INIT, *args, **kwargs) | ||
|
||
# get previous node | ||
ctx = self._get_previous_node(ctx, *args, **kwargs) | ||
self._run_handlers(ctx, pipeline, ActorStage.GET_PREVIOUS_NODE, *args, **kwargs) | ||
await self._run_handlers(ctx, pipeline, ActorStage.GET_PREVIOUS_NODE, *args, **kwargs) | ||
|
||
# rewrite previous node | ||
ctx = self._rewrite_previous_node(ctx, *args, **kwargs) | ||
self._run_handlers(ctx, pipeline, ActorStage.REWRITE_PREVIOUS_NODE, *args, **kwargs) | ||
await self._run_handlers(ctx, pipeline, ActorStage.REWRITE_PREVIOUS_NODE, *args, **kwargs) | ||
|
||
# run pre transitions processing | ||
ctx = self._run_pre_transitions_processing(ctx, pipeline, *args, **kwargs) | ||
self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_TRANSITIONS_PROCESSING, *args, **kwargs) | ||
await self._run_pre_transitions_processing(ctx, pipeline, *args, **kwargs) | ||
ruthenian8 marked this conversation as resolved.
Show resolved
Hide resolved
ruthenian8 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_TRANSITIONS_PROCESSING, *args, **kwargs) | ||
|
||
# get true labels for scopes (GLOBAL, LOCAL, NODE) | ||
ctx = self._get_true_labels(ctx, pipeline, *args, **kwargs) | ||
self._run_handlers(ctx, pipeline, ActorStage.GET_TRUE_LABELS, *args, **kwargs) | ||
ctx = await self._get_true_labels(ctx, pipeline, *args, **kwargs) | ||
await self._run_handlers(ctx, pipeline, ActorStage.GET_TRUE_LABELS, *args, **kwargs) | ||
|
||
# get next node | ||
ctx = self._get_next_node(ctx, *args, **kwargs) | ||
self._run_handlers(ctx, pipeline, ActorStage.GET_NEXT_NODE, *args, **kwargs) | ||
await self._run_handlers(ctx, pipeline, ActorStage.GET_NEXT_NODE, *args, **kwargs) | ||
|
||
ctx.add_label(ctx.framework_states["actor"]["next_label"][:2]) | ||
|
||
# rewrite next node | ||
ctx = self._rewrite_next_node(ctx, *args, **kwargs) | ||
self._run_handlers(ctx, pipeline, ActorStage.REWRITE_NEXT_NODE, *args, **kwargs) | ||
await self._run_handlers(ctx, pipeline, ActorStage.REWRITE_NEXT_NODE, *args, **kwargs) | ||
|
||
# run pre response processing | ||
ctx = self._run_pre_response_processing(ctx, pipeline, *args, **kwargs) | ||
self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_RESPONSE_PROCESSING, *args, **kwargs) | ||
await self._run_pre_response_processing(ctx, pipeline, *args, **kwargs) | ||
await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_RESPONSE_PROCESSING, *args, **kwargs) | ||
|
||
# create response | ||
ctx.framework_states["actor"]["response"] = ctx.framework_states["actor"][ | ||
"pre_response_processed_node" | ||
].run_response(ctx, pipeline, *args, **kwargs) | ||
self._run_handlers(ctx, pipeline, ActorStage.CREATE_RESPONSE, *args, **kwargs) | ||
ctx.framework_states["actor"]["response"] = await self.run_response( | ||
ctx.framework_states["actor"]["pre_response_processed_node"].response, ctx, pipeline, *args, **kwargs | ||
) | ||
await self._run_handlers(ctx, pipeline, ActorStage.CREATE_RESPONSE, *args, **kwargs) | ||
ctx.add_response(ctx.framework_states["actor"]["response"]) | ||
|
||
self._run_handlers(ctx, pipeline, ActorStage.FINISH_TURN, *args, **kwargs) | ||
await self._run_handlers(ctx, pipeline, ActorStage.FINISH_TURN, *args, **kwargs) | ||
if self._clean_turn_cache: | ||
cache_clear() | ||
|
||
|
@@ -177,20 +174,20 @@ def _get_previous_node(self, ctx: Context, *args, **kwargs) -> Context: | |
).get(ctx.framework_states["actor"]["previous_label"][1], Node()) | ||
return ctx | ||
|
||
def _get_true_labels(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: | ||
async def _get_true_labels(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: | ||
# GLOBAL | ||
ctx.framework_states["actor"]["global_transitions"] = ( | ||
self.script.get(GLOBAL, {}).get(GLOBAL, Node()).transitions | ||
) | ||
ctx.framework_states["actor"]["global_true_label"] = self._get_true_label( | ||
global_transitions_coro = self._get_true_label( | ||
ctx.framework_states["actor"]["global_transitions"], ctx, pipeline, GLOBAL, "global" | ||
) | ||
|
||
# LOCAL | ||
ctx.framework_states["actor"]["local_transitions"] = ( | ||
self.script.get(ctx.framework_states["actor"]["previous_label"][0], {}).get(LOCAL, Node()).transitions | ||
) | ||
ctx.framework_states["actor"]["local_true_label"] = self._get_true_label( | ||
local_transitions_coro = self._get_true_label( | ||
ctx.framework_states["actor"]["local_transitions"], | ||
ctx, | ||
pipeline, | ||
|
@@ -202,13 +199,18 @@ def _get_true_labels(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> | |
ctx.framework_states["actor"]["node_transitions"] = ctx.framework_states["actor"][ | ||
"pre_transitions_processed_node" | ||
].transitions | ||
ctx.framework_states["actor"]["node_true_label"] = self._get_true_label( | ||
node_transitions_coro = self._get_true_label( | ||
ctx.framework_states["actor"]["node_transitions"], | ||
ctx, | ||
pipeline, | ||
ctx.framework_states["actor"]["previous_label"][0], | ||
"node", | ||
) | ||
( | ||
ctx.framework_states["actor"]["global_true_label"], | ||
ctx.framework_states["actor"]["local_true_label"], | ||
ctx.framework_states["actor"]["node_true_label"], | ||
) = await asyncio.gather(*[global_transitions_coro, local_transitions_coro, node_transitions_coro]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should discuss how we want the components to execute (parallel vs sequential).
After discussing how we should handle the newly asynchronized functions we should record that information in the documentation (and also in #252; other than the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree |
||
return ctx | ||
|
||
def _get_next_node(self, ctx: Context, *args, **kwargs) -> Context: | ||
|
@@ -262,25 +264,101 @@ def _overwrite_node( | |
overwritten_node.transitions = current_node.transitions | ||
return overwritten_node | ||
|
||
def _run_pre_transitions_processing(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: | ||
ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["previous_node"]) | ||
ctx = ctx.framework_states["actor"]["previous_node"].run_pre_transitions_processing( | ||
ctx, pipeline, *args, **kwargs | ||
async def run_response( | ||
self, | ||
response: Optional[Union[Message, Callable[..., Message]]], | ||
ctx: Context, | ||
pipeline: Pipeline, | ||
*args, | ||
**kwargs, | ||
) -> Context: | ||
ruthenian8 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Executes the normalized response as an asynchronous function. | ||
See the details in the :py:func:`~normalize_response` function of `normalization.py`. | ||
""" | ||
response = normalize_response(response) | ||
return await wrap_sync_function_in_async(response, ctx, pipeline, *args, **kwargs) | ||
|
||
async def _run_processing_parallel( | ||
ruthenian8 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self, processing: dict, ctx: Context, pipeline: Pipeline, *args, **kwargs | ||
) -> None: | ||
""" | ||
Execute the processing functions for a particular node simultaneously, | ||
independent of the order. | ||
|
||
Picked depending on the value of the :py:class:`.Pipeline`'s `parallelize_processing` flag. | ||
""" | ||
results = await asyncio.gather( | ||
*[wrap_sync_function_in_async(func, ctx, pipeline, *args, **kwargs) for func in processing.values()], | ||
return_exceptions=True, | ||
) | ||
for exc, (processing_name, processing_func) in zip(results, processing.items()): | ||
if isinstance(exc, Exception): | ||
logger.error( | ||
f"Exception {exc} for processing_name={processing_name} and processing_func={processing_func}", | ||
exc_info=exc, | ||
) | ||
|
||
async def _run_processing_sequential( | ||
self, processing: dict, ctx: Context, pipeline: Pipeline, *args, **kwargs | ||
) -> None: | ||
""" | ||
Execute the processing functions for a particular node in-order. | ||
|
||
Picked depending on the value of the :py:class:`.Pipeline`'s `parallelize_processing` flag. | ||
""" | ||
for processing_name, processing_func in processing.items(): | ||
try: | ||
await wrap_sync_function_in_async(processing_func, ctx, pipeline, *args, **kwargs) | ||
except Exception as exc: | ||
logger.error( | ||
f"Exception {exc} for processing_name={processing_name} and processing_func={processing_func}", | ||
exc_info=exc, | ||
) | ||
|
||
async def _run_pre_transitions_processing(self, ctx: Context, pipeline: Pipeline) -> None: | ||
""" | ||
Run `PRE_TRANSITIONS_PROCESSING` functions for a particular node. | ||
Pre-transition processing functions can modify the context state | ||
before the direction of the next transition is determined depending on that state. | ||
|
||
The execution order depends on the value of the :py:class:`.Pipeline`'s | ||
`parallelize_processing` flag. | ||
""" | ||
ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["previous_node"]) | ||
pre_transitions_processing = ctx.framework_states["actor"]["previous_node"].pre_transitions_processing | ||
|
||
if pipeline.parallelize_processing: | ||
await self._run_processing_parallel(pre_transitions_processing, ctx, pipeline) | ||
else: | ||
await self._run_processing_sequential(pre_transitions_processing, ctx, pipeline) | ||
|
||
ctx.framework_states["actor"]["pre_transitions_processed_node"] = ctx.framework_states["actor"][ | ||
"processed_node" | ||
] | ||
del ctx.framework_states["actor"]["processed_node"] | ||
return ctx | ||
|
||
def _run_pre_response_processing(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: | ||
async def _run_pre_response_processing(self, ctx: Context, pipeline: Pipeline) -> None: | ||
""" | ||
Run `PRE_RESPONSE_PROCESSING` functions for a particular node. | ||
Pre-response processing functions can modify the response before it is | ||
returned to the user. | ||
|
||
The execution order depends on the value of the :py:class:`.Pipeline`'s | ||
`parallelize_processing` flag. | ||
""" | ||
ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["next_node"]) | ||
ctx = ctx.framework_states["actor"]["next_node"].run_pre_response_processing(ctx, pipeline, *args, **kwargs) | ||
pre_response_processing = ctx.framework_states["actor"]["next_node"].pre_response_processing | ||
|
||
if pipeline.parallelize_processing: | ||
await self._run_processing_parallel(pre_response_processing, ctx, pipeline) | ||
else: | ||
await self._run_processing_sequential(pre_response_processing, ctx, pipeline) | ||
|
||
ctx.framework_states["actor"]["pre_response_processed_node"] = ctx.framework_states["actor"]["processed_node"] | ||
del ctx.framework_states["actor"]["processed_node"] | ||
return ctx | ||
|
||
def _get_true_label( | ||
async def _get_true_label( | ||
self, | ||
transitions: dict, | ||
ctx: Context, | ||
|
@@ -291,10 +369,17 @@ def _get_true_label( | |
**kwargs, | ||
) -> Optional[NodeLabel3Type]: | ||
true_labels = [] | ||
for label, condition in transitions.items(): | ||
if self.condition_handler(condition, ctx, pipeline, *args, **kwargs): | ||
|
||
cond_booleans = await asyncio.gather( | ||
*( | ||
self.condition_handler(transition[1], ctx, pipeline, *args, **kwargs) | ||
for transition in transitions.items() | ||
) | ||
) | ||
for label, cond_is_true in zip(transitions, cond_booleans): | ||
ruthenian8 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if cond_is_true: | ||
if callable(label): | ||
label = label(ctx, pipeline, *args, **kwargs) | ||
label = await wrap_sync_function_in_async(label, ctx, pipeline, *args, **kwargs) | ||
# TODO: explicit handling of errors | ||
if label is None: | ||
continue | ||
ruthenian8 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
@@ -311,8 +396,10 @@ def _get_true_label( | |
logger.debug(f"{transition_info} transitions sorted by priority = {true_labels}") | ||
return true_label | ||
|
||
def _run_handlers(self, ctx, pipeline: Pipeline, actor_stage: ActorStage, *args, **kwargs): | ||
[handler(ctx, pipeline, *args, **kwargs) for handler in self.handlers.get(actor_stage, [])] | ||
async def _run_handlers(self, ctx, pipeline: Pipeline, actor_stage: ActorStage): | ||
stage_handlers = self.handlers.get(actor_stage, []) | ||
async_handlers = [wrap_sync_function_in_async(handler, ctx, pipeline) for handler in stage_handlers] | ||
await asyncio.gather(*async_handlers) | ||
|
||
def _choose_label( | ||
self, specific_label: Optional[NodeLabel3Type], general_label: Optional[NodeLabel3Type] | ||
|
@@ -360,7 +447,7 @@ def validate_script(self, pipeline: Pipeline, verbose: bool = True): | |
# validate responsing | ||
response_func = normalize_response(node.response) | ||
try: | ||
response_result = response_func(ctx, pipeline) | ||
response_result = asyncio.run(wrap_sync_function_in_async(response_func, ctx, pipeline)) | ||
if not isinstance(response_result, Message): | ||
msg = ( | ||
"Expected type of response_result is `Message`.\n" | ||
|
@@ -390,7 +477,7 @@ def validate_script(self, pipeline: Pipeline, verbose: bool = True): | |
return error_msgs | ||
|
||
|
||
def default_condition_handler( | ||
async def default_condition_handler( | ||
condition: Callable, ctx: Context, pipeline: Pipeline, *args, **kwargs | ||
) -> Callable[[Context, Pipeline, Any, Any], bool]: | ||
""" | ||
|
@@ -400,4 +487,4 @@ def default_condition_handler( | |
:param ctx: Context of current condition. | ||
:param pipeline: Pipeline we use in this condition. | ||
""" | ||
return condition(ctx, pipeline, *args, **kwargs) | ||
return await wrap_sync_function_in_async(condition, ctx, pipeline, *args, **kwargs) |
ruthenian8 marked this conversation as resolved.
Show resolved
Hide resolved
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't we want to make transition functions asynchronous as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can leave such methods as
get_previous_node
andget_next_node
synchronous for two reasons: