Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update API #253

Merged
merged 36 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
a110e62
Move updates from merge/slots; ensure passing tests
ruthenian8 Oct 13, 2023
bd67607
move run_response from Node object to Actor object; execute response …
ruthenian8 Oct 13, 2023
12eed6e
update documentation && change the signature of processing-related ac…
ruthenian8 Oct 13, 2023
43533b5
remove overwrite_current_node_in_processing method
ruthenian8 Oct 17, 2023
bb42016
allow asynchronous condition functions
ruthenian8 Oct 25, 2023
039840b
merge branch 'dev'
ruthenian8 Oct 25, 2023
5dc75a1
update tutorial
ruthenian8 Oct 25, 2023
b274383
fix merge errors
RLKRo Oct 30, 2023
dda1bc0
make import absolute
RLKRo Oct 30, 2023
2d63c03
Apply suggestions by @RLKRo
ruthenian8 Nov 2, 2023
e2d8377
Update actor & group removing context assignment statements
ruthenian8 Nov 2, 2023
5df3c55
test responses for None; use validate_label once
ruthenian8 Nov 20, 2023
34c7d1a
merge dev into feat/async_handlers
ruthenian8 Nov 20, 2023
6378aa5
return context from the 'pipeline' service group
ruthenian8 Nov 20, 2023
84947ec
update tutorial tests for script/core;
ruthenian8 Nov 20, 2023
4878de9
Remove random seed; use seed value of 42
ruthenian8 Nov 27, 2023
95a5459
Deprecate overwrite_current_node
ruthenian8 Nov 27, 2023
8b0db0c
update deprecation details
ruthenian8 Nov 28, 2023
3cb2b77
fix: try to resolve docs bug
ruthenian8 Nov 28, 2023
ed126ab
update docs
ruthenian8 Nov 28, 2023
82a8c32
Update signature typings
ruthenian8 Nov 28, 2023
2a39d80
revert tutorial changes
RLKRo Nov 30, 2023
51a8624
remove overwrite_current_node_in_processing instead of deprecating
RLKRo Nov 30, 2023
63ba0f7
remove most context returns; update docs and typing
RLKRo Nov 30, 2023
7b44375
fix actor test
RLKRo Dec 1, 2023
1afa571
codestyle
RLKRo Dec 1, 2023
58fa9aa
add parallel processing test
RLKRo Dec 4, 2023
cd41155
make global/local/node conditions run sequentially
RLKRo Dec 4, 2023
3c70317
Merge branch 'dev' into feat/async_handlers
RLKRo Dec 5, 2023
74ea26d
replace args, kwargs with update_ctx_misc
RLKRo Dec 7, 2023
223b284
codestyle
RLKRo Dec 7, 2023
ce3a8a2
fix typing
RLKRo Dec 10, 2023
48b6934
line collapse
RLKRo Dec 10, 2023
d8075f0
rename function to fit inside code blocks
RLKRo Dec 10, 2023
6e0307f
fix tutorial function signatures
RLKRo Dec 10, 2023
1079a51
Merge branch 'dev' into feat/async_handlers
RLKRo Dec 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 130 additions & 43 deletions dff/pipeline/pipeline/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,9 @@

Overall, the actor acts as a bridge between the user's input and the dialog graph,
making sure that the conversation follows the expected flow and providing a personalized experience to the user.

Below you can see a diagram of user request processing with Actor.
Both `request` and `response` are saved to :py:class:`.Context`.

.. figure:: /_static/drawio/dfe/user_actor.png
"""
import logging
import asyncio
from typing import Union, Callable, Optional, Dict, List, Any, ForwardRef
import copy

Expand All @@ -34,6 +30,7 @@
from dff.script.core.script import Script, Node
from dff.script.core.normalization import normalize_label, normalize_response
from dff.script.core.keywords import GLOBAL, LOCAL
from ..service.utils import wrap_sync_function_in_async

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -109,51 +106,51 @@ def __init__(
# NB! The following API is highly experimental and may be removed at ANY time WITHOUT FURTHER NOTICE!!
self._clean_turn_cache = True

def __call__(
async def __call__(
self, pipeline: Pipeline, ctx: Optional[Union[Context, dict, str]] = None, *args, **kwargs
) -> Union[Context, dict, str]:
# context init
ctx = self._context_init(ctx, *args, **kwargs)
self._run_handlers(ctx, pipeline, ActorStage.CONTEXT_INIT, *args, **kwargs)
await self._run_handlers(ctx, pipeline, ActorStage.CONTEXT_INIT, *args, **kwargs)

# get previous node
ctx = self._get_previous_node(ctx, *args, **kwargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we want to make transition functions asynchronous as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can leave such methods as get_previous_node and get_next_node synchronous for two reasons:

  1. all they do is get values from dictionaries, so making them asynchronous won't result in a performance improvement
  2. we would have to await them inside the actor call, i.e. use them like regular blocking functions

self._run_handlers(ctx, pipeline, ActorStage.GET_PREVIOUS_NODE, *args, **kwargs)
await self._run_handlers(ctx, pipeline, ActorStage.GET_PREVIOUS_NODE, *args, **kwargs)

# rewrite previous node
ctx = self._rewrite_previous_node(ctx, *args, **kwargs)
self._run_handlers(ctx, pipeline, ActorStage.REWRITE_PREVIOUS_NODE, *args, **kwargs)
await self._run_handlers(ctx, pipeline, ActorStage.REWRITE_PREVIOUS_NODE, *args, **kwargs)

# run pre transitions processing
ctx = self._run_pre_transitions_processing(ctx, pipeline, *args, **kwargs)
self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_TRANSITIONS_PROCESSING, *args, **kwargs)
await self._run_pre_transitions_processing(ctx, pipeline, *args, **kwargs)
ruthenian8 marked this conversation as resolved.
Show resolved Hide resolved
ruthenian8 marked this conversation as resolved.
Show resolved Hide resolved
await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_TRANSITIONS_PROCESSING, *args, **kwargs)

# get true labels for scopes (GLOBAL, LOCAL, NODE)
ctx = self._get_true_labels(ctx, pipeline, *args, **kwargs)
self._run_handlers(ctx, pipeline, ActorStage.GET_TRUE_LABELS, *args, **kwargs)
ctx = await self._get_true_labels(ctx, pipeline, *args, **kwargs)
await self._run_handlers(ctx, pipeline, ActorStage.GET_TRUE_LABELS, *args, **kwargs)

# get next node
ctx = self._get_next_node(ctx, *args, **kwargs)
self._run_handlers(ctx, pipeline, ActorStage.GET_NEXT_NODE, *args, **kwargs)
await self._run_handlers(ctx, pipeline, ActorStage.GET_NEXT_NODE, *args, **kwargs)

ctx.add_label(ctx.framework_states["actor"]["next_label"][:2])

# rewrite next node
ctx = self._rewrite_next_node(ctx, *args, **kwargs)
self._run_handlers(ctx, pipeline, ActorStage.REWRITE_NEXT_NODE, *args, **kwargs)
await self._run_handlers(ctx, pipeline, ActorStage.REWRITE_NEXT_NODE, *args, **kwargs)

# run pre response processing
ctx = self._run_pre_response_processing(ctx, pipeline, *args, **kwargs)
self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_RESPONSE_PROCESSING, *args, **kwargs)
await self._run_pre_response_processing(ctx, pipeline, *args, **kwargs)
await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_RESPONSE_PROCESSING, *args, **kwargs)

# create response
ctx.framework_states["actor"]["response"] = ctx.framework_states["actor"][
"pre_response_processed_node"
].run_response(ctx, pipeline, *args, **kwargs)
self._run_handlers(ctx, pipeline, ActorStage.CREATE_RESPONSE, *args, **kwargs)
ctx.framework_states["actor"]["response"] = await self.run_response(
ctx.framework_states["actor"]["pre_response_processed_node"].response, ctx, pipeline, *args, **kwargs
)
await self._run_handlers(ctx, pipeline, ActorStage.CREATE_RESPONSE, *args, **kwargs)
ctx.add_response(ctx.framework_states["actor"]["response"])

self._run_handlers(ctx, pipeline, ActorStage.FINISH_TURN, *args, **kwargs)
await self._run_handlers(ctx, pipeline, ActorStage.FINISH_TURN, *args, **kwargs)
if self._clean_turn_cache:
cache_clear()

Expand All @@ -177,20 +174,20 @@ def _get_previous_node(self, ctx: Context, *args, **kwargs) -> Context:
).get(ctx.framework_states["actor"]["previous_label"][1], Node())
return ctx

def _get_true_labels(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context:
async def _get_true_labels(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context:
# GLOBAL
ctx.framework_states["actor"]["global_transitions"] = (
self.script.get(GLOBAL, {}).get(GLOBAL, Node()).transitions
)
ctx.framework_states["actor"]["global_true_label"] = self._get_true_label(
global_transitions_coro = self._get_true_label(
ctx.framework_states["actor"]["global_transitions"], ctx, pipeline, GLOBAL, "global"
)

# LOCAL
ctx.framework_states["actor"]["local_transitions"] = (
self.script.get(ctx.framework_states["actor"]["previous_label"][0], {}).get(LOCAL, Node()).transitions
)
ctx.framework_states["actor"]["local_true_label"] = self._get_true_label(
local_transitions_coro = self._get_true_label(
ctx.framework_states["actor"]["local_transitions"],
ctx,
pipeline,
Expand All @@ -202,13 +199,18 @@ def _get_true_labels(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) ->
ctx.framework_states["actor"]["node_transitions"] = ctx.framework_states["actor"][
"pre_transitions_processed_node"
].transitions
ctx.framework_states["actor"]["node_true_label"] = self._get_true_label(
node_transitions_coro = self._get_true_label(
ctx.framework_states["actor"]["node_transitions"],
ctx,
pipeline,
ctx.framework_states["actor"]["previous_label"][0],
"node",
)
(
ctx.framework_states["actor"]["global_true_label"],
ctx.framework_states["actor"]["local_true_label"],
ctx.framework_states["actor"]["node_true_label"],
) = await asyncio.gather(*[global_transitions_coro, local_transitions_coro, node_transitions_coro])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should discuss how we want the components to execute (parallel vs sequential).
Currently:

  • pre-response/pre-transition functions depend on pipeline.parallelize_processing
  • Script conditions are executed in parallel
  • Labels are executed sequentially inside the groups GLOBAL, LOCAL, NODE (two labels are executed sequentially if they are from the same group, but in parallel if they are from different groups)
  • Actor handlers - in parallel
  • Service handlers - depends on ServiceGroup.asynchronous
  • Extra handlers - depends on asynchronous of the extra handlers

After discussing how we should handle the newly asynchronized functions we should record that information in the documentation (and also in #252; other than the asynchronous option, the guide should also mention other options such as the timeout option for extra handlers, @pseusys).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree

return ctx

def _get_next_node(self, ctx: Context, *args, **kwargs) -> Context:
Expand Down Expand Up @@ -262,25 +264,101 @@ def _overwrite_node(
overwritten_node.transitions = current_node.transitions
return overwritten_node

def _run_pre_transitions_processing(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context:
ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["previous_node"])
ctx = ctx.framework_states["actor"]["previous_node"].run_pre_transitions_processing(
ctx, pipeline, *args, **kwargs
async def run_response(
self,
response: Optional[Union[Message, Callable[..., Message]]],
ctx: Context,
pipeline: Pipeline,
*args,
**kwargs,
) -> Context:
ruthenian8 marked this conversation as resolved.
Show resolved Hide resolved
"""
Executes the normalized response as an asynchronous function.
See the details in the :py:func:`~normalize_response` function of `normalization.py`.
"""
response = normalize_response(response)
return await wrap_sync_function_in_async(response, ctx, pipeline, *args, **kwargs)

async def _run_processing_parallel(
ruthenian8 marked this conversation as resolved.
Show resolved Hide resolved
self, processing: dict, ctx: Context, pipeline: Pipeline, *args, **kwargs
) -> None:
"""
Execute the processing functions for a particular node simultaneously,
independent of the order.

Picked depending on the value of the :py:class:`.Pipeline`'s `parallelize_processing` flag.
"""
results = await asyncio.gather(
*[wrap_sync_function_in_async(func, ctx, pipeline, *args, **kwargs) for func in processing.values()],
return_exceptions=True,
)
for exc, (processing_name, processing_func) in zip(results, processing.items()):
if isinstance(exc, Exception):
logger.error(
f"Exception {exc} for processing_name={processing_name} and processing_func={processing_func}",
exc_info=exc,
)

async def _run_processing_sequential(
    self, processing: dict, ctx: Context, pipeline: Pipeline, *args, **kwargs
) -> None:
    """
    Execute a node's processing functions strictly one after another,
    preserving their dictionary order.

    Chosen by the :py:class:`.Pipeline` when its `parallelize_processing`
    flag is off. A failing processing function is logged and skipped, so
    the remaining functions still run.
    """

    async def _apply_one(processing_name, processing_func):
        # Wrap plain callables so sync and async processing functions are treated alike.
        try:
            await wrap_sync_function_in_async(processing_func, ctx, pipeline, *args, **kwargs)
        except Exception as exc:
            logger.error(
                f"Exception {exc} for processing_name={processing_name} and processing_func={processing_func}",
                exc_info=exc,
            )

    for processing_name, processing_func in processing.items():
        await _apply_one(processing_name, processing_func)

async def _run_pre_transitions_processing(self, ctx: Context, pipeline: Pipeline) -> None:
    """
    Run `PRE_TRANSITIONS_PROCESSING` functions for a particular node.
    Pre-transition processing functions can modify the context state
    before the direction of the next transition is determined depending on that state.

    The execution order depends on the value of the :py:class:`.Pipeline`'s
    `parallelize_processing` flag.

    :param ctx: Dialog context; its ``framework_states["actor"]`` dict is mutated in place.
    :param pipeline: Pipeline whose `parallelize_processing` flag selects the strategy.
    """
    # Deep copy: presumably processing functions mutate "processed_node" through the
    # context, and the copy keeps "previous_node" intact — TODO confirm against Node API.
    ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["previous_node"])
    pre_transitions_processing = ctx.framework_states["actor"]["previous_node"].pre_transitions_processing

    # Dispatch on the pipeline-level flag; both helpers log-and-continue on errors.
    if pipeline.parallelize_processing:
        await self._run_processing_parallel(pre_transitions_processing, ctx, pipeline)
    else:
        await self._run_processing_sequential(pre_transitions_processing, ctx, pipeline)

    # Publish the result under a stage-specific key and drop the temporary slot.
    ctx.framework_states["actor"]["pre_transitions_processed_node"] = ctx.framework_states["actor"][
        "processed_node"
    ]
    del ctx.framework_states["actor"]["processed_node"]
return ctx

def _run_pre_response_processing(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context:
async def _run_pre_response_processing(self, ctx: Context, pipeline: Pipeline) -> None:
    """
    Run `PRE_RESPONSE_PROCESSING` functions for a particular node.
    Pre-response processing functions can modify the response before it is
    returned to the user.

    The execution order depends on the value of the :py:class:`.Pipeline`'s
    `parallelize_processing` flag.

    :param ctx: Dialog context; its ``framework_states["actor"]`` dict is mutated in place.
    :param pipeline: Pipeline whose `parallelize_processing` flag selects the strategy.
    """
    # Deep copy: presumably processing functions mutate "processed_node" through the
    # context, and the copy keeps "next_node" intact — TODO confirm against Node API.
    ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["next_node"])
    # NOTE(review): the following call looks superseded by the parallel/sequential
    # mechanism below — confirm it should be removed.
    ctx = ctx.framework_states["actor"]["next_node"].run_pre_response_processing(ctx, pipeline, *args, **kwargs)
    pre_response_processing = ctx.framework_states["actor"]["next_node"].pre_response_processing

    # Dispatch on the pipeline-level flag; both helpers log-and-continue on errors.
    if pipeline.parallelize_processing:
        await self._run_processing_parallel(pre_response_processing, ctx, pipeline)
    else:
        await self._run_processing_sequential(pre_response_processing, ctx, pipeline)

    # Publish the result under a stage-specific key and drop the temporary slot.
    ctx.framework_states["actor"]["pre_response_processed_node"] = ctx.framework_states["actor"]["processed_node"]
    del ctx.framework_states["actor"]["processed_node"]
return ctx

def _get_true_label(
async def _get_true_label(
self,
transitions: dict,
ctx: Context,
Expand All @@ -291,10 +369,17 @@ def _get_true_label(
**kwargs,
) -> Optional[NodeLabel3Type]:
true_labels = []
for label, condition in transitions.items():
if self.condition_handler(condition, ctx, pipeline, *args, **kwargs):

cond_booleans = await asyncio.gather(
*(
self.condition_handler(transition[1], ctx, pipeline, *args, **kwargs)
for transition in transitions.items()
)
)
for label, cond_is_true in zip(transitions, cond_booleans):
ruthenian8 marked this conversation as resolved.
Show resolved Hide resolved
if cond_is_true:
if callable(label):
label = label(ctx, pipeline, *args, **kwargs)
label = await wrap_sync_function_in_async(label, ctx, pipeline, *args, **kwargs)
# TODO: explicit handling of errors
if label is None:
continue
ruthenian8 marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -311,8 +396,10 @@ def _get_true_label(
logger.debug(f"{transition_info} transitions sorted by priority = {true_labels}")
return true_label

def _run_handlers(self, ctx, pipeline: Pipeline, actor_stage: ActorStage, *args, **kwargs):
[handler(ctx, pipeline, *args, **kwargs) for handler in self.handlers.get(actor_stage, [])]
async def _run_handlers(self, ctx, pipeline: Pipeline, actor_stage: ActorStage):
    """
    Launch every handler registered for `actor_stage` concurrently and wait
    until all of them have completed.

    Synchronous handlers are transparently wrapped into coroutines, so both
    sync and async callbacks are supported.
    """
    coroutines = (
        wrap_sync_function_in_async(handler, ctx, pipeline)
        for handler in self.handlers.get(actor_stage, [])
    )
    await asyncio.gather(*coroutines)

def _choose_label(
self, specific_label: Optional[NodeLabel3Type], general_label: Optional[NodeLabel3Type]
Expand Down Expand Up @@ -360,7 +447,7 @@ def validate_script(self, pipeline: Pipeline, verbose: bool = True):
# validate response
response_func = normalize_response(node.response)
try:
response_result = response_func(ctx, pipeline)
response_result = asyncio.run(wrap_sync_function_in_async(response_func, ctx, pipeline))
if not isinstance(response_result, Message):
msg = (
"Expected type of response_result is `Message`.\n"
Expand Down Expand Up @@ -390,7 +477,7 @@ def validate_script(self, pipeline: Pipeline, verbose: bool = True):
return error_msgs


def default_condition_handler(
async def default_condition_handler(
condition: Callable, ctx: Context, pipeline: Pipeline, *args, **kwargs
) -> Callable[[Context, Pipeline, Any, Any], bool]:
"""
Expand All @@ -400,4 +487,4 @@ def default_condition_handler(
:param ctx: Context of current condition.
:param pipeline: Pipeline we use in this condition.
"""
return condition(ctx, pipeline, *args, **kwargs)
return await wrap_sync_function_in_async(condition, ctx, pipeline, *args, **kwargs)
6 changes: 6 additions & 0 deletions dff/pipeline/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class Pipeline:

- `_services_pipeline` is a pipeline root :py:class:`~.ServiceGroup` object,
- `actor` is a pipeline actor, found among services.
:param parallelize_processing: This flag determines whether or not the functions
defined in the ``PRE_RESPONSE_PROCESSING`` and ``PRE_TRANSITIONS_PROCESSING`` sections
of the script should be parallelized over respective groups.

"""

Expand All @@ -94,6 +97,7 @@ def __init__(
after_handler: Optional[ExtraHandlerBuilder] = None,
timeout: Optional[float] = None,
optimization_warnings: bool = False,
parallelize_processing: bool = False,
):
self.actor: Actor = None
self.messenger_interface = CLIMessengerInterface() if messenger_interface is None else messenger_interface
Expand Down Expand Up @@ -127,6 +131,8 @@ def __init__(
if optimization_warnings:
self._services_pipeline.log_optimization_warnings()

self.parallelize_processing = parallelize_processing

# NB! The following API is highly experimental and may be removed at ANY time WITHOUT FURTHER NOTICE!!
self._clean_turn_cache = True
if self._clean_turn_cache:
Expand Down
6 changes: 3 additions & 3 deletions dff/pipeline/service/service.py
ruthenian8 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def _run_handler(self, ctx: Context, pipeline: Pipeline):
else:
raise Exception(f"Too many parameters required for service '{self.name}' handler: {handler_params}!")

def _run_as_actor(self, ctx: Context, pipeline: Pipeline):
async def _run_as_actor(self, ctx: Context, pipeline: Pipeline):
"""
Method for running this service if its handler is an `Actor`.
Catches runtime exceptions.
Expand All @@ -133,7 +133,7 @@ def _run_as_actor(self, ctx: Context, pipeline: Pipeline):
:return: Context, mutated by actor.
"""
try:
ctx = pipeline.actor(pipeline, ctx)
ctx = await pipeline.actor(pipeline, ctx)
self._set_state(ctx, ComponentExecutionState.FINISHED)
except Exception as exc:
self._set_state(ctx, ComponentExecutionState.FAILED)
Expand Down Expand Up @@ -172,7 +172,7 @@ async def _run(self, ctx: Context, pipeline: Optional[Pipeline] = None) -> Optio
await self.run_extra_handler(ExtraHandlerType.BEFORE, ctx, pipeline)

if isinstance(self.handler, str) and self.handler == "ACTOR":
ctx = self._run_as_actor(ctx, pipeline)
ctx = await self._run_as_actor(ctx, pipeline)
else:
await self._run_as_service(ctx, pipeline)

Expand Down
18 changes: 0 additions & 18 deletions dff/script/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,23 +278,5 @@ def current_node(self) -> Optional[Node]:

return node

def overwrite_current_node_in_processing(self, processed_node: Node):
    """
    Replace the node currently being processed with `processed_node`.

    Only effective while a processing function (pre-response or
    pre-transition) is running; outside of that window a warning is logged
    and nothing is modified. The actual current node is not changed.

    :param processed_node: `node` to set as the current node.
    """
    actor_state = self.framework_states.get("actor", {})
    if not actor_state.get("processed_node"):
        # Guard clause: called outside a processing stage — warn and bail out.
        logger.warning(
            f"The `{self.overwrite_current_node_in_processing.__name__}` "
            "method can only be called from processing functions (either pre-response or pre-transition)."
        )
        return
    self.framework_states["actor"]["processed_node"] = Node.model_validate(processed_node)


Context.model_rebuild()
Loading