Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Domain Payload Optimization to Action server #1108

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion rasa_sdk/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@
DEFAULT_SERVER_PORT,
)
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException
from rasa_sdk.interfaces import (
ActionExecutionRejection,
ActionNotFoundException,
ActionMissingDomainException,
)
from rasa_sdk.plugin import plugin_manager
from rasa_sdk.tracing.utils import (
get_tracer_and_context,
Expand Down Expand Up @@ -153,6 +157,10 @@ async def webhook(request: Request) -> HTTPResponse:
logger.error(e)
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=404)
except ActionMissingDomainException as e:
logger.error(e)
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=449)

set_span_attributes(span, action_call)

Expand Down
60 changes: 52 additions & 8 deletions rasa_sdk/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,29 @@
import inspect
import logging
import pkgutil
import typing
import warnings
from typing import Text, List, Dict, Any, Type, Union, Callable, Optional, Set, cast
from collections import namedtuple
import types
import sys
import os

from rasa_sdk.interfaces import Tracker, ActionNotFoundException, Action
from rasa_sdk.interfaces import (
Tracker,
ActionNotFoundException,
Action,
ActionMissingDomainException,
)

from rasa_sdk import utils

if typing.TYPE_CHECKING: # pragma: no cover
from rasa_sdk.types import ActionCall

logger = logging.getLogger(__name__)


class CollectingDispatcher:
"""Send messages back to user"""

def __init__(self) -> None:

self.messages: List[Dict[Text, Any]] = []

def utter_message(
Expand Down Expand Up @@ -162,6 +162,8 @@ def __init__(self) -> None:
self.actions: Dict[Text, Callable] = {}
self._modules: Dict[Text, TimestampModule] = {}
self._loaded: Set[Type[Action]] = set()
self.domain: Optional[Dict[Text, Any]] = None
self.domain_digest: Optional[Text] = None

def register_action(self, action: Union[Type[Action], Action]) -> None:
if inspect.isclass(action):
Expand Down Expand Up @@ -380,7 +382,49 @@ def validate_events(events: List[Dict[Text, Any]], action_name: Text):
# we won't append this to validated events -> will be ignored
return validated

async def run(self, action_call: "ActionCall") -> Optional[Dict[Text, Any]]:
def is_domain_digest_valid(self, domain_digest: Optional[Text]) -> bool:
"""Check if the domain_digest is valid
If the domain_digest is empty or different from the one provided, it is invalid.
Args:
domain_digest: latest value provided to compare the current value with.
Returns:
True if the domain_digest is valid, False otherwise.
"""
return bool(self.domain_digest) and self.domain_digest == domain_digest

def update_and_return_domain(
self, payload: Dict[Text, Any], action_name: Text
) -> Optional[Dict[Text, Any]]:
"""Validate the digest, store the domain if available, and return the domain.
This method validates the domain digest from the payload.
If the digest is invalid and no domain is provided, an exception is raised.
If domain data is available, it stores the domain and digest.
Finally, it returns the domain.
Args:
payload: Request payload containing the domain data.
action_name: Name of the action that should be executed.
Returns:
The domain dictionary.
Raises:
ActionMissingDomainException: Invalid digest and no domain data available.
"""
payload_domain = payload.get("domain")
payload_domain_digest = payload.get("domain_digest")

# If digest is invalid and no domain is available - raise the error
if (
not self.is_domain_digest_valid(payload_domain_digest)
and payload_domain is None
):
raise ActionMissingDomainException(action_name)

if payload_domain:
self.domain = payload_domain
self.domain_digest = payload_domain_digest

return self.domain

async def run(self, action_call: Dict[Text, Any]) -> Optional[Dict[Text, Any]]:
from rasa_sdk.interfaces import Tracker

action_name = action_call.get("next_action")
Expand All @@ -391,7 +435,7 @@ async def run(self, action_call: "ActionCall") -> Optional[Dict[Text, Any]]:
raise ActionNotFoundException(action_name)

tracker_json = action_call["tracker"]
domain = action_call.get("domain", {})
domain = self.update_and_return_domain(action_call, action_name)
tracker = Tracker.from_dict(tracker_json)
dispatcher = CollectingDispatcher()

Expand Down
11 changes: 11 additions & 0 deletions rasa_sdk/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,14 @@ def __init__(self, action_name: Text, message: Optional[Text] = None) -> None:

def __str__(self) -> Text:
return self.message


class ActionMissingDomainException(Exception):
"""Raising this exception when the domain is missing."""

def __init__(self, action_name: Text, message: Optional[Text] = None) -> None:
self.action_name = action_name
self.message = message or "Domain context is missing."

def __str__(self) -> Text:
return self.message
4 changes: 4 additions & 0 deletions tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def test_server_webhook_handles_action_exception(sanic_app: Sanic):
data = {
"next_action": "custom_action_exception",
"tracker": {"sender_id": "1", "conversation_id": "default"},
"domain": {},
}
request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
assert response.status == 500
Expand All @@ -76,6 +77,7 @@ def test_server_webhook_custom_action_returns_200(sanic_app: Sanic):
data = {
"next_action": "custom_action",
"tracker": {"sender_id": "1", "conversation_id": "default"},
"domain": {},
}
request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
events = response.json.get("events")
Expand All @@ -88,6 +90,7 @@ def test_server_webhook_custom_async_action_returns_200(sanic_app: Sanic):
data = {
"next_action": "custom_async_action",
"tracker": {"sender_id": "1", "conversation_id": "default"},
"domain": {},
}
request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
events = response.json.get("events")
Expand Down Expand Up @@ -148,6 +151,7 @@ def test_server_webhook_custom_action_with_dialogue_stack_returns_200(
data = {
"next_action": "custom_action_with_dialogue_stack",
"tracker": {"sender_id": "1", "conversation_id": "default", **stack_state},
"domain": {},
}
_, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
events = response.json.get("events")
Expand Down
1 change: 1 addition & 0 deletions tests/tracing/instrumentation/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_server_webhook_custom_action_is_instrumented(
"rasa_sdk.endpoint.get_tracer_provider", lambda _: tracer_provider
)
data["next_action"] = action_name
data["domain"] = {}
app = ep.create_app(action_package)

app.register_listener(
Expand Down
Loading