diff --git a/dff/pipeline/__init__.py b/dff/pipeline/__init__.py index 1b345f647..bf828179d 100644 --- a/dff/pipeline/__init__.py +++ b/dff/pipeline/__init__.py @@ -13,6 +13,7 @@ ComponentExecutionState, GlobalExtraHandlerType, ExtraHandlerType, + PIPELINE_EXCEPTION_KEY, PIPELINE_STATE_KEY, StartConditionCheckerFunction, StartConditionCheckerAggregationFunction, @@ -27,6 +28,7 @@ PipelineBuilder, ) +from .pipeline.actor import LATEST_EXCEPTION_KEY, LATEST_FAILED_NODE_KEY from .pipeline.pipeline import Pipeline, ACTOR from .service.extra import BeforeHandler, AfterHandler diff --git a/dff/pipeline/pipeline/actor.py b/dff/pipeline/pipeline/actor.py index fdddf542a..589e69131 100644 --- a/dff/pipeline/pipeline/actor.py +++ b/dff/pipeline/pipeline/actor.py @@ -25,7 +25,7 @@ from __future__ import annotations import logging import asyncio -from typing import Union, Callable, Optional, Dict, List, TYPE_CHECKING +from typing import Type, Union, Callable, Optional, Dict, List, TYPE_CHECKING import copy from dff.utils.turn_caching import cache_clear @@ -37,6 +37,10 @@ from dff.script.core.normalization import normalize_label, normalize_response from dff.script.core.keywords import GLOBAL, LOCAL from dff.pipeline.service.utils import wrap_sync_function_in_async +from dff.pipeline.types import PIPELINE_EXCEPTION_KEY + +LATEST_EXCEPTION_KEY = "LATEST_EXCEPTION" +LATEST_FAILED_NODE_KEY = "LATEST_FAILED_NODE" logger = logging.getLogger(__name__) @@ -131,7 +135,7 @@ async def __call__(self, pipeline: Pipeline, ctx: Context): await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_TRANSITIONS_PROCESSING) # get true labels for scopes (GLOBAL, LOCAL, NODE) - await self._get_true_labels(ctx, pipeline) + await self._get_true_labels(ctx, pipeline, False) await self._run_handlers(ctx, pipeline, ActorStage.GET_TRUE_LABELS) # get next node @@ -161,6 +165,43 @@ async def __call__(self, pipeline: Pipeline, ctx: Context): del ctx.framework_states["actor"] + async def process_exception(self, pipeline: Pipeline, ctx: Context): + # context init + self._context_init(ctx) + + # get previous node + self._get_previous_node(ctx) + + # rewrite previous node + self._rewrite_previous_node(ctx) + + # run pre transitions processing + await self._run_pre_transitions_processing(ctx, pipeline) + + # get true labels for scopes (GLOBAL, LOCAL, NODE) + await self._get_true_labels(ctx, pipeline, True) + + # get next node + self._get_next_node(ctx) + + ctx.add_label(ctx.framework_states["actor"]["next_label"][:2]) + + # rewrite next node + self._rewrite_next_node(ctx) + + # run pre response processing + await self._run_pre_response_processing(ctx, pipeline) + + # create response + ctx.framework_states["actor"]["response"] = await self.run_response( + ctx.framework_states["actor"]["pre_response_processed_node"].response, ctx, pipeline + ) + ctx.add_response(ctx.framework_states["actor"]["response"]) + + if self._clean_turn_cache: + cache_clear() + del ctx.framework_states["actor"] + @staticmethod def _context_init(ctx: Optional[Union[Context, dict, str]] = None): ctx.framework_states["actor"] = {} @@ -173,13 +214,13 @@ def _get_previous_node(self, ctx: Context): ctx.framework_states["actor"]["previous_label"][0], {} ).get(ctx.framework_states["actor"]["previous_label"][1], Node()) - async def _get_true_labels(self, ctx: Context, pipeline: Pipeline): + async def _get_true_labels(self, ctx: Context, pipeline: Pipeline, is_exceptional: bool): # GLOBAL ctx.framework_states["actor"]["global_transitions"] = ( self.script.get(GLOBAL, {}).get(GLOBAL, Node()).transitions ) ctx.framework_states["actor"]["global_true_label"] = await self._get_true_label( - ctx.framework_states["actor"]["global_transitions"], ctx, pipeline, GLOBAL, "global" + ctx.framework_states["actor"]["global_transitions"], ctx, pipeline, GLOBAL, is_exceptional, "global" ) # LOCAL @@ -191,6 +232,7 @@ async def _get_true_labels(self, ctx: Context, pipeline: Pipeline): ctx, pipeline, ctx.framework_states["actor"]["previous_label"][0], + is_exceptional, "local", ) @@ -203,6 +245,7 @@ async def _get_true_labels(self, ctx: Context, pipeline: Pipeline): ctx, pipeline, ctx.framework_states["actor"]["previous_label"][0], + is_exceptional, "node", ) @@ -346,14 +389,21 @@ async def _get_true_label( ctx: Context, pipeline: Pipeline, flow_label: LabelType, + is_exceptional: bool, transition_info: str = "", ) -> Optional[NodeLabel3Type]: true_labels = [] - cond_booleans = await asyncio.gather( - *(self.condition_handler(condition, ctx, pipeline) for condition in transitions.values()) - ) - for label, cond_is_true in zip(transitions.keys(), cond_booleans): + cond_values = await asyncio.gather(*(self.condition_handler(condition, ctx, pipeline) for condition in transitions.values())) + cond_items = list(zip(transitions.keys(), cond_values)) + + if is_exceptional: + exception = ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_EXCEPTION_KEY] + cond_items = [(label, isinstance(exception, type(value))) for label, value in cond_items if issubclass(type(value), BaseException)] + else: + cond_items = [(label, value) for label, value in cond_items if isinstance(value, bool)] + + for label, cond_is_true in cond_items: if cond_is_true: if callable(label): label = await wrap_sync_function_in_async(label, ctx, pipeline) diff --git a/dff/pipeline/pipeline/pipeline.py b/dff/pipeline/pipeline/pipeline.py index fb548e9f5..1406a0aee 100644 --- a/dff/pipeline/pipeline/pipeline.py +++ b/dff/pipeline/pipeline/pipeline.py @@ -18,13 +18,13 @@ from typing import Union, List, Dict, Optional, Hashable, Callable from dff.context_storages import DBContextStorage -from dff.script import Script, Context, ActorStage -from dff.script import NodeLabel2Type, Message +from dff.script import Script, Context, ActorStage, NodeLabel2Type, Message from dff.utils.turn_caching import cache_clear from dff.messengers.common import MessengerInterface, CLIMessengerInterface from ..service.group import ServiceGroup from ..types import ( + ComponentExecutionState, ServiceBuilder, ServiceGroupBuilder, PipelineBuilder, @@ -32,7 +32,7 @@ ExtraHandlerFunction, ExtraHandlerBuilder, ) -from ..types import PIPELINE_STATE_KEY +from ..types import PIPELINE_EXCEPTION_KEY, PIPELINE_STATE_KEY from .utils import finalize_service_group, pretty_format_component_info_dict from dff.pipeline.pipeline.actor import Actor @@ -343,13 +343,18 @@ async def _run_pipeline( ctx.misc.update(update_ctx_misc) ctx.framework_states[PIPELINE_STATE_KEY] = {} + ctx.framework_states[PIPELINE_EXCEPTION_KEY] = {} ctx.add_request(request) result = await self._services_pipeline(ctx, self) if asyncio.iscoroutine(result): await result + if self._services_pipeline.get_state(ctx) == ComponentExecutionState.FAILED: + await self.actor.process_exception(self, ctx) + del ctx.framework_states[PIPELINE_STATE_KEY] + del ctx.framework_states[PIPELINE_EXCEPTION_KEY] if isinstance(self.context_storage, DBContextStorage): await self.context_storage.set_item_async(ctx_id, ctx) diff --git a/dff/pipeline/service/service.py b/dff/pipeline/service/service.py index 0895380f8..9e5b07f93 100644 --- a/dff/pipeline/service/service.py +++ b/dff/pipeline/service/service.py @@ -18,12 +18,14 @@ from .utils import wrap_sync_function_in_async, collect_defined_constructor_parameters_to_dict, _get_attrs_with_updates from ..types import ( + PIPELINE_EXCEPTION_KEY, ServiceBuilder, StartConditionCheckerFunction, ComponentExecutionState, ExtraHandlerBuilder, ExtraHandlerType, ) +from ..pipeline.actor import LATEST_EXCEPTION_KEY, LATEST_FAILED_NODE_KEY from ..pipeline.component import PipelineComponent logger = logging.getLogger(__name__) @@ -133,6 +135,13 @@ async def _run_as_actor(self, ctx: Context, pipeline: Pipeline) -> None: await pipeline.actor(pipeline, ctx) self._set_state(ctx, ComponentExecutionState.FINISHED) except Exception as exc: + if "actor" in ctx.framework_states: + last_label = ctx.framework_states["actor"]["next_label"] + latest_node = f"{self.name}:{last_label[0]}:{last_label[1]}" + else: + latest_node = self.name + ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_EXCEPTION_KEY] = exc + ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_FAILED_NODE_KEY] = latest_node self._set_state(ctx, ComponentExecutionState.FAILED) logger.error(f"Actor '{self.name}' execution failed!\n{exc}") @@ -152,6 +161,8 @@ async def _run_as_service(self, ctx: Context, pipeline: Pipeline) -> None: else: self._set_state(ctx, ComponentExecutionState.NOT_RUN) except Exception as e: + ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_EXCEPTION_KEY] = e + ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_FAILED_NODE_KEY] = self.name self._set_state(ctx, ComponentExecutionState.FAILED) logger.error(f"Service '{self.name}' execution failed!\n{e}") diff --git a/dff/pipeline/types.py b/dff/pipeline/types.py index 4409cc87b..b9ce428a5 100644 --- a/dff/pipeline/types.py +++ b/dff/pipeline/types.py @@ -102,6 +102,9 @@ class ExtraHandlerType(str, Enum): AFTER = "AFTER" +PIPELINE_EXCEPTION_KEY = "EXCEPTION" + + PIPELINE_STATE_KEY = "PIPELINE" """ PIPELINE: storage for services and groups execution status. diff --git a/tutorials/script/core/10_error_conditions.py b/tutorials/script/core/10_error_conditions.py new file mode 100644 index 000000000..7ea08a4f8 --- /dev/null +++ b/tutorials/script/core/10_error_conditions.py @@ -0,0 +1,136 @@ +# %% [markdown] +""" +# Core: 10. Error conditions +""" + +# %pip install dff + +# %% +from typing import Type +from dff.script import GLOBAL, TRANSITIONS, RESPONSE, Context, Message +from dff.pipeline import PIPELINE_EXCEPTION_KEY, LATEST_EXCEPTION_KEY, LATEST_FAILED_NODE_KEY, Pipeline +import dff.script.conditions as cnd +import dff.script.labels as lbl + +from dff.utils.testing.common import ( + check_happy_path, + is_interactive_mode, + run_interactive_mode, +) + + +def raise_exception(exception_class: Type[BaseException]) -> Message: + raise exception_class("Some evil cause!") + + +def print_exception(name: str, _: Pipeline, ctx: Context) -> Message: + exception = ctx.framework_states[PIPELINE_EXCEPTION_KEY].get(LATEST_EXCEPTION_KEY, None) + message = "UNKNOWN" if exception is None else str(exception) + source = ctx.framework_states[PIPELINE_EXCEPTION_KEY].get(LATEST_FAILED_NODE_KEY, None) + return Message(f"Exception type {name} with message '{message}' received from node {source}!") + + +# %% +toy_script = { + GLOBAL: { + TRANSITIONS: { + ("error_flow", "node_name_handler", 1.1): NameError, + ("error_flow", "node_buffer_handler", 1.1): BufferError, + }, + }, + "error_flow": { + "start_node": { + RESPONSE: Message(), + TRANSITIONS: { + "node_start_exceptor": cnd.exact_match(Message("start")), + }, + }, + "node_start_exceptor": { + RESPONSE: Message("Select an exception to throw!"), + TRANSITIONS: { + "node_name_thrower": cnd.exact_match(Message("name")), + "node_buffer_thrower": cnd.exact_match(Message("buffer")), + "node_file_thrower": cnd.exact_match(Message("fallback")), + }, + }, + "node_name_thrower": { + RESPONSE: lambda _, __: raise_exception(NameError), + }, + "node_buffer_thrower": { + RESPONSE: lambda _, __: raise_exception(BufferError), + }, + "node_file_thrower": { + RESPONSE: lambda _, __: raise_exception(FileNotFoundError), + }, + "node_name_handler": { + RESPONSE: lambda ctx, pipeline: print_exception("Name Error", pipeline, ctx), + TRANSITIONS: { + "node_start_exceptor": cnd.exact_match(Message("okay...")), + }, + }, + "node_buffer_handler": { + RESPONSE: lambda ctx, pipeline: print_exception("Buffer Error", pipeline, ctx), + TRANSITIONS: { + "node_start_exceptor": cnd.exact_match(Message("okay...")), + }, + }, + "fallback_node": { + RESPONSE: Message(f"Unexpected message received or an unknown exception caught!"), + TRANSITIONS: { + "node_start_exceptor": cnd.exact_match(Message("okay...")), + }, + }, + } +} + + +happy_path = ( + ( + Message("start"), + Message("Select an exception to throw!"), + ), + ( + Message("name"), + Message("Exception type Name Error with message 'Some evil cause!' received from node actor_0:error_flow:node_name_thrower!"), + ), + ( + Message("okay..."), + Message("Select an exception to throw!"), + ), + ( + Message("buffer"), + Message("Exception type Buffer Error with message 'Some evil cause!' received from node actor_0:error_flow:node_buffer_thrower!"), + ), + ( + Message("okay..."), + Message("Select an exception to throw!"), + ), + ( + Message("fallback"), + Message("Unexpected message received or an unknown exception caught!"), + ), + ( + Message("okay..."), + Message("Select an exception to throw!"), + ), + ( + Message("something"), + Message("Unexpected message received or an unknown exception caught!"), + ), +) + + +# %% +pipeline = Pipeline.from_script( + toy_script, + start_label=("error_flow", "start_node"), + fallback_label=("error_flow", "fallback_node"), + validation_stage=False, +) + +if __name__ == "__main__": + check_happy_path(pipeline, happy_path) + if is_interactive_mode(): + run_interactive_mode(pipeline) + +# %%