Skip to content

Commit

Permalink
Add unit tests for GRPCActionServerWebhook
Browse files Browse the repository at this point in the history
  • Loading branch information
radovanZRasa committed Jun 19, 2024
1 parent 8b8ce99 commit 3698521
Show file tree
Hide file tree
Showing 6 changed files with 325 additions and 39 deletions.
2 changes: 1 addition & 1 deletion proto/action_webhook.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ service ActionService {
message ActionsRequest {}

message ActionsResponse {
map<string, string> actions = 1;
repeated google.protobuf.Struct actions = 1;
}

message Tracker {
Expand Down
46 changes: 21 additions & 25 deletions rasa_sdk/grpc_py/action_webhook_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 2 additions & 9 deletions rasa_sdk/grpc_py/action_webhook_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,9 @@ class ActionsRequest(_message.Message):

class ActionsResponse(_message.Message):
__slots__ = ["actions"]
class ActionsEntry(_message.Message):
__slots__ = ["key", "value"]
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
key: str
value: str
def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ...
ACTIONS_FIELD_NUMBER: _ClassVar[int]
actions: _containers.ScalarMap[str, str]
def __init__(self, actions: _Optional[_Mapping[str, str]] = ...) -> None: ...
actions: _containers.RepeatedCompositeFieldContainer[_struct_pb2.Struct]
def __init__(self, actions: _Optional[_Iterable[_Union[_struct_pb2.Struct, _Mapping]]] = ...) -> None: ...

class Tracker(_message.Message):
__slots__ = ["sender_id", "slots", "latest_message", "events", "paused", "followup_action", "active_loop", "latest_action_name", "stack"]
Expand Down
15 changes: 12 additions & 3 deletions rasa_sdk/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def __init__(
self.auto_reload = auto_reload
self.executor = executor

async def Actions(self, request: ActionsRequest, context) -> ActionsResponse:
async def Actions(
self,
request: ActionsRequest,
context: grpc.aio.ServicerContext,
) -> ActionsResponse:
"""Handle RPC request for the actions.
Args:
Expand All @@ -107,9 +111,14 @@ async def Actions(self, request: ActionsRequest, context) -> ActionsResponse:
if self.auto_reload:
self.executor.reload()

actions = self.executor.list_actions()
actions = [action.model_dump() for action in self.executor.list_actions()]
response = ActionsResponse()
return ParseDict(actions, response)
return ParseDict(
{
"actions": actions,
},
response,
)

async def Webhook(
self,
Expand Down
150 changes: 149 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from typing import List, Dict, Text, Any
from unittest.mock import MagicMock, AsyncMock

import grpc
from google.protobuf.json_format import ParseDict
from sanic import Sanic

import pytest

import rasa_sdk
from rasa_sdk import Action, FormValidationAction, Tracker, ValidationAction
from rasa_sdk.events import SlotSet
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.executor import CollectingDispatcher, ActionExecutor, ActionName
from rasa_sdk.grpc_py import action_webhook_pb2
from rasa_sdk.grpc_server import GRPCActionServerWebhook
from rasa_sdk.types import DomainDict

Sanic.test_mode = True
Expand Down Expand Up @@ -147,3 +153,145 @@ def name(self):
class SubclassTestActionB(SubclassTestActionA):
def name(self):
return "subclass_test_action_b"


@pytest.fixture
def grpc_domain() -> action_webhook_pb2.Domain:
"""Create a gRPC domain."""
return action_webhook_pb2.Domain(
config={},
session_config={},
intents=[],
entities=[],
slots={},
responses={},
actions=[],
forms={},
e2e_actions=[],
)


@pytest.fixture
def sender_id() -> str:
return "test_sender_id"


@pytest.fixture
def grpc_tracker(sender_id: str) -> action_webhook_pb2.Tracker:
"""Create a gRPC tracker."""
return action_webhook_pb2.Tracker(
sender_id=sender_id,
slots={},
latest_message={},
events=[],
paused=False,
followup_action="",
active_loop={},
latest_action_name="",
stack={},
)


@pytest.fixture
def current_rasa_version() -> str:
"""Return current Rasa version."""
return rasa_sdk.__version__


@pytest.fixture
def previous_rasa_version() -> str:
"""Return previous Rasa version."""
return "1.0.0"


@pytest.fixture
def action_name() -> str:
"""Return action name."""
return "action_listen"


@pytest.fixture
def grpc_webhook_request(
sender_id: str,
action_name: str,
grpc_tracker: action_webhook_pb2.Tracker,
grpc_domain: action_webhook_pb2.Domain,
current_rasa_version: str,
) -> action_webhook_pb2.WebhookRequest:
"""Create a webhook request."""
return action_webhook_pb2.WebhookRequest(
next_action=action_name,
sender_id=sender_id,
tracker=grpc_tracker,
domain=grpc_domain,
version=current_rasa_version,
domain_digest="",
)


@pytest.fixture
def mock_executor() -> AsyncMock:
"""Create a mock action executor."""
return AsyncMock(spec=ActionExecutor)


@pytest.fixture
def mock_grpc_service_context() -> MagicMock:
"""Create a mock gRPC service context."""
return MagicMock(spec=grpc.aio.ServicerContext)


@pytest.fixture
def grpc_action_server_webhook(mock_executor: AsyncMock) -> GRPCActionServerWebhook:
"""Create a GRPCActionServerWebhook instance with a mock executor."""
return GRPCActionServerWebhook(executor=mock_executor)


@pytest.fixture
def response_events() -> List[Dict[Text, Any]]:
"""Create a list of response events."""
return [{"event": "slot", "name": "test", "value": "foo"}]


@pytest.fixture
def responses() -> List[Dict[Text, Any]]:
"""Create a gRPC webhook response."""
return [{"utter": "Hi"}]


@pytest.fixture
def executor_response(
response_events: List[Dict[Text, Any]], responses: List[Dict[Text, Any]]
) -> Dict[Text, Any]:
"""Create an executor response."""
return {"events": response_events, "responses": responses}


@pytest.fixture
def grpc_webhook_response(
executor_response: Dict[Text, Any],
) -> action_webhook_pb2.WebhookResponse:
"""Create a gRPC webhook response."""
result = action_webhook_pb2.WebhookResponse()
return ParseDict(executor_response, result)


def action_names() -> List[ActionName]:
"""Create a list of action names."""
return [
ActionName(name="action_listen"),
ActionName(name="action_restart"),
ActionName(name="action_session_start"),
]


def grpc_actions_response() -> action_webhook_pb2.ActionsResponse:
"""Create a gRPC actions response."""
actions = [action.model_dump() for action in action_names()]
result = action_webhook_pb2.ActionsResponse()
return ParseDict(
{
"actions": actions,
},
result,
)
Loading

0 comments on commit 3698521

Please sign in to comment.