Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exception handling rework #329

Draft
wants to merge 1 commit into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dff/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ComponentExecutionState,
GlobalExtraHandlerType,
ExtraHandlerType,
PIPELINE_EXCEPTION_KEY,
PIPELINE_STATE_KEY,
StartConditionCheckerFunction,
StartConditionCheckerAggregationFunction,
Expand All @@ -27,6 +28,7 @@
PipelineBuilder,
)

from .pipeline.actor import LATEST_EXCEPTION_KEY, LATEST_FAILED_NODE_KEY
from .pipeline.pipeline import Pipeline, ACTOR

from .service.extra import BeforeHandler, AfterHandler
Expand Down
66 changes: 58 additions & 8 deletions dff/pipeline/pipeline/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from __future__ import annotations
import logging
import asyncio
from typing import Union, Callable, Optional, Dict, List, TYPE_CHECKING
from typing import Type, Union, Callable, Optional, Dict, List, TYPE_CHECKING
import copy

from dff.utils.turn_caching import cache_clear
Expand All @@ -37,6 +37,10 @@
from dff.script.core.normalization import normalize_label, normalize_response
from dff.script.core.keywords import GLOBAL, LOCAL
from dff.pipeline.service.utils import wrap_sync_function_in_async
from dff.pipeline.types import PIPELINE_EXCEPTION_KEY

LATEST_EXCEPTION_KEY = "LATEST_EXCEPTION"
LATEST_FAILED_NODE_KEY = "LATEST_FAILED_NODE"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -131,7 +135,7 @@ async def __call__(self, pipeline: Pipeline, ctx: Context):
await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_TRANSITIONS_PROCESSING)

# get true labels for scopes (GLOBAL, LOCAL, NODE)
await self._get_true_labels(ctx, pipeline)
await self._get_true_labels(ctx, pipeline, False)
await self._run_handlers(ctx, pipeline, ActorStage.GET_TRUE_LABELS)

# get next node
Expand Down Expand Up @@ -161,6 +165,43 @@ async def __call__(self, pipeline: Pipeline, ctx: Context):

del ctx.framework_states["actor"]

async def process_exception(self, pipeline: Pipeline, ctx: Context):
# context init
self._context_init(ctx)

# get previous node
self._get_previous_node(ctx)

# rewrite previous node
self._rewrite_previous_node(ctx)

# run pre transitions processing
await self._run_pre_transitions_processing(ctx, pipeline)

# get true labels for scopes (GLOBAL, LOCAL, NODE)
await self._get_true_labels(ctx, pipeline, True)

# get next node
self._get_next_node(ctx)

ctx.add_label(ctx.framework_states["actor"]["next_label"][:2])

# rewrite next node
self._rewrite_next_node(ctx)

# run pre response processing
await self._run_pre_response_processing(ctx, pipeline)

# create response
ctx.framework_states["actor"]["response"] = await self.run_response(
ctx.framework_states["actor"]["pre_response_processed_node"].response, ctx, pipeline
)
ctx.add_response(ctx.framework_states["actor"]["response"])

if self._clean_turn_cache:
cache_clear()
del ctx.framework_states["actor"]

@staticmethod
def _context_init(ctx: Optional[Union[Context, dict, str]] = None):
ctx.framework_states["actor"] = {}
Expand All @@ -173,13 +214,13 @@ def _get_previous_node(self, ctx: Context):
ctx.framework_states["actor"]["previous_label"][0], {}
).get(ctx.framework_states["actor"]["previous_label"][1], Node())

async def _get_true_labels(self, ctx: Context, pipeline: Pipeline):
async def _get_true_labels(self, ctx: Context, pipeline: Pipeline, is_exceptional: bool):
# GLOBAL
ctx.framework_states["actor"]["global_transitions"] = (
self.script.get(GLOBAL, {}).get(GLOBAL, Node()).transitions
)
ctx.framework_states["actor"]["global_true_label"] = await self._get_true_label(
ctx.framework_states["actor"]["global_transitions"], ctx, pipeline, GLOBAL, "global"
ctx.framework_states["actor"]["global_transitions"], ctx, pipeline, GLOBAL, is_exceptional, "global"
)

# LOCAL
Expand All @@ -191,6 +232,7 @@ async def _get_true_labels(self, ctx: Context, pipeline: Pipeline):
ctx,
pipeline,
ctx.framework_states["actor"]["previous_label"][0],
is_exceptional,
"local",
)

Expand All @@ -203,6 +245,7 @@ async def _get_true_labels(self, ctx: Context, pipeline: Pipeline):
ctx,
pipeline,
ctx.framework_states["actor"]["previous_label"][0],
is_exceptional,
"node",
)

Expand Down Expand Up @@ -346,14 +389,21 @@ async def _get_true_label(
ctx: Context,
pipeline: Pipeline,
flow_label: LabelType,
is_exceptional: bool,
transition_info: str = "",
) -> Optional[NodeLabel3Type]:
true_labels = []

cond_booleans = await asyncio.gather(
*(self.condition_handler(condition, ctx, pipeline) for condition in transitions.values())
)
for label, cond_is_true in zip(transitions.keys(), cond_booleans):
cond_values = await asyncio.gather(*(self.condition_handler(condition, ctx, pipeline) for condition in transitions.values()))
cond_items = list(zip(transitions.keys(), cond_values))

if is_exceptional:
exception = ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_EXCEPTION_KEY]
cond_items = [(label, isinstance(exception, type(value))) for label, value in cond_items if issubclass(type(value), BaseException)]
else:
cond_items = [(label, value) for label, value in cond_items if isinstance(value, bool)]

for label, cond_is_true in cond_items:
if cond_is_true:
if callable(label):
label = await wrap_sync_function_in_async(label, ctx, pipeline)
Expand Down
11 changes: 8 additions & 3 deletions dff/pipeline/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@
from typing import Union, List, Dict, Optional, Hashable, Callable

from dff.context_storages import DBContextStorage
from dff.script import Script, Context, ActorStage
from dff.script import NodeLabel2Type, Message
from dff.script import Script, Context, ActorStage, NodeLabel2Type, Message
from dff.utils.turn_caching import cache_clear

from dff.messengers.common import MessengerInterface, CLIMessengerInterface
from ..service.group import ServiceGroup
from ..types import (
ComponentExecutionState,
ServiceBuilder,
ServiceGroupBuilder,
PipelineBuilder,
GlobalExtraHandlerType,
ExtraHandlerFunction,
ExtraHandlerBuilder,
)
from ..types import PIPELINE_STATE_KEY
from ..types import PIPELINE_EXCEPTION_KEY, PIPELINE_STATE_KEY
from .utils import finalize_service_group, pretty_format_component_info_dict
from dff.pipeline.pipeline.actor import Actor

Expand Down Expand Up @@ -343,13 +343,18 @@ async def _run_pipeline(
ctx.misc.update(update_ctx_misc)

ctx.framework_states[PIPELINE_STATE_KEY] = {}
ctx.framework_states[PIPELINE_EXCEPTION_KEY] = {}
ctx.add_request(request)
result = await self._services_pipeline(ctx, self)

if asyncio.iscoroutine(result):
await result

if self._services_pipeline.get_state(ctx) == ComponentExecutionState.FAILED:
await self.actor.process_exception(self, ctx)

del ctx.framework_states[PIPELINE_STATE_KEY]
del ctx.framework_states[PIPELINE_EXCEPTION_KEY]

if isinstance(self.context_storage, DBContextStorage):
await self.context_storage.set_item_async(ctx_id, ctx)
Expand Down
11 changes: 11 additions & 0 deletions dff/pipeline/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@

from .utils import wrap_sync_function_in_async, collect_defined_constructor_parameters_to_dict, _get_attrs_with_updates
from ..types import (
PIPELINE_EXCEPTION_KEY,
ServiceBuilder,
StartConditionCheckerFunction,
ComponentExecutionState,
ExtraHandlerBuilder,
ExtraHandlerType,
)
from ..pipeline.actor import LATEST_EXCEPTION_KEY, LATEST_FAILED_NODE_KEY
from ..pipeline.component import PipelineComponent

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -133,6 +135,13 @@ async def _run_as_actor(self, ctx: Context, pipeline: Pipeline) -> None:
await pipeline.actor(pipeline, ctx)
self._set_state(ctx, ComponentExecutionState.FINISHED)
except Exception as exc:
if "actor" in ctx.framework_states:
last_label = ctx.framework_states["actor"]["next_label"]
latest_node = f"{self.name}:{last_label[0]}:{last_label[1]}"
else:
latest_node = self.name
ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_EXCEPTION_KEY] = exc
ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_FAILED_NODE_KEY] = latest_node
self._set_state(ctx, ComponentExecutionState.FAILED)
logger.error(f"Actor '{self.name}' execution failed!\n{exc}")

Expand All @@ -152,6 +161,8 @@ async def _run_as_service(self, ctx: Context, pipeline: Pipeline) -> None:
else:
self._set_state(ctx, ComponentExecutionState.NOT_RUN)
except Exception as e:
ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_EXCEPTION_KEY] = e
ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_FAILED_NODE_KEY] = self.name
self._set_state(ctx, ComponentExecutionState.FAILED)
logger.error(f"Service '{self.name}' execution failed!\n{e}")

Expand Down
3 changes: 3 additions & 0 deletions dff/pipeline/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ class ExtraHandlerType(str, Enum):
AFTER = "AFTER"


PIPELINE_EXCEPTION_KEY = "EXCEPTION"


PIPELINE_STATE_KEY = "PIPELINE"
"""
PIPELINE: storage for services and groups execution status.
Expand Down
136 changes: 136 additions & 0 deletions tutorials/script/core/10_error_conditions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# %% [markdown]
"""
# Core: 10. Error conditions
"""

# %pip install dff

# %%
from typing import Type
from dff.script import GLOBAL, TRANSITIONS, RESPONSE, Context, Message
from dff.pipeline import PIPELINE_EXCEPTION_KEY, LATEST_EXCEPTION_KEY, LATEST_FAILED_NODE_KEY, Pipeline
import dff.script.conditions as cnd
import dff.script.labels as lbl

from dff.utils.testing.common import (
check_happy_path,
is_interactive_mode,
run_interactive_mode,
)


def raise_exception(exception_class: Type[BaseException]) -> Message:
raise exception_class("Some evil cause!")


def print_exception(name: str, _: Pipeline, ctx: Context) -> Message:
exception = ctx.framework_states[PIPELINE_EXCEPTION_KEY].get(LATEST_EXCEPTION_KEY, None)
message = "UNKNOWN" if exception is None else str(exception)
source = ctx.framework_states[PIPELINE_EXCEPTION_KEY].get(LATEST_FAILED_NODE_KEY, None)
return Message(f"Exception type {name} with message '{message}' received from node {source}!")


# %%
toy_script = {
GLOBAL: {
TRANSITIONS: {
("error_flow", "node_name_handler", 1.1): NameError,
("error_flow", "node_buffer_handler", 1.1): BufferError,
},
},
"error_flow": {
"start_node": {
RESPONSE: Message(),
TRANSITIONS: {
"node_start_exceptor": cnd.exact_match(Message("start")),
},
},
"node_start_exceptor": {
RESPONSE: Message("Select an exception to throw!"),
TRANSITIONS: {
"node_name_thrower": cnd.exact_match(Message("name")),
"node_buffer_thrower": cnd.exact_match(Message("buffer")),
"node_file_thrower": cnd.exact_match(Message("fallback")),
},
},
"node_name_thrower": {
RESPONSE: lambda _, __: raise_exception(NameError),
},
"node_buffer_thrower": {
RESPONSE: lambda _, __: raise_exception(BufferError),
},
"node_file_thrower": {
RESPONSE: lambda _, __: raise_exception(FileNotFoundError),
},
"node_name_handler": {
RESPONSE: lambda ctx, pipeline: print_exception("Name Error", pipeline, ctx),
TRANSITIONS: {
"node_start_exceptor": cnd.exact_match(Message("okay...")),
},
},
"node_buffer_handler": {
RESPONSE: lambda ctx, pipeline: print_exception("Buffer Error", pipeline, ctx),
TRANSITIONS: {
"node_start_exceptor": cnd.exact_match(Message("okay...")),
},
},
"fallback_node": {
RESPONSE: Message(f"Unexpected message received or an unknown exception caught!"),
TRANSITIONS: {
"node_start_exceptor": cnd.exact_match(Message("okay...")),
},
},
}
}


happy_path = (
(
Message("start"),
Message("Select an exception to throw!"),
),
(
Message("name"),
Message("Exception type Name Error with message 'Some evil cause!' received from node actor_0:error_flow:node_name_thrower!"),
),
(
Message("okay..."),
Message("Select an exception to throw!"),
),
(
Message("buffer"),
Message("Exception type Buffer Error with message 'Some evil cause!' received from node actor_0:error_flow:node_buffer_thrower!"),
),
(
Message("okay..."),
Message("Select an exception to throw!"),
),
(
Message("fallback"),
Message("Unexpected message received or an unknown exception caught!"),
),
(
Message("okay..."),
Message("Select an exception to throw!"),
),
(
Message("something"),
Message("Unexpected message received or an unknown exception caught!"),
),
)


# %%
pipeline = Pipeline.from_script(
toy_script,
start_label=("error_flow", "start_node"),
fallback_label=("error_flow", "fallback_node"),
validation_stage=False,
)

if __name__ == "__main__":
check_happy_path(pipeline, happy_path)
if is_interactive_mode():
run_interactive_mode(pipeline)

# %%
Loading