Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ATO-1652] Add unit tests for GRPCActionServerWebhook #1112

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion proto/action_webhook.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ service ActionService {
message ActionsRequest {}

message ActionsResponse {
map<string, string> actions = 1;
repeated google.protobuf.Struct actions = 1;
}

message Tracker {
Expand Down
46 changes: 21 additions & 25 deletions rasa_sdk/grpc_py/action_webhook_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 2 additions & 9 deletions rasa_sdk/grpc_py/action_webhook_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,9 @@ class ActionsRequest(_message.Message):

class ActionsResponse(_message.Message):
__slots__ = ["actions"]
class ActionsEntry(_message.Message):
__slots__ = ["key", "value"]
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
key: str
value: str
def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ...
ACTIONS_FIELD_NUMBER: _ClassVar[int]
actions: _containers.ScalarMap[str, str]
def __init__(self, actions: _Optional[_Mapping[str, str]] = ...) -> None: ...
actions: _containers.RepeatedCompositeFieldContainer[_struct_pb2.Struct]
def __init__(self, actions: _Optional[_Iterable[_Union[_struct_pb2.Struct, _Mapping]]] = ...) -> None: ...

class Tracker(_message.Message):
__slots__ = ["sender_id", "slots", "latest_message", "events", "paused", "followup_action", "active_loop", "latest_action_name", "stack"]
Expand Down
15 changes: 12 additions & 3 deletions rasa_sdk/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def __init__(
self.auto_reload = auto_reload
self.executor = executor

async def Actions(self, request: ActionsRequest, context) -> ActionsResponse:
async def Actions(
self,
request: ActionsRequest,
context: grpc.aio.ServicerContext,
) -> ActionsResponse:
"""Handle RPC request for the actions.

Args:
Expand All @@ -107,9 +111,14 @@ async def Actions(self, request: ActionsRequest, context) -> ActionsResponse:
if self.auto_reload:
self.executor.reload()

actions = self.executor.list_actions()
actions = [action.model_dump() for action in self.executor.list_actions()]
response = ActionsResponse()
return ParseDict(actions, response)
return ParseDict(
{
"actions": actions,
},
response,
)

async def Webhook(
self,
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest

import rasa_sdk
from rasa_sdk import Action, FormValidationAction, Tracker, ValidationAction
from rasa_sdk.events import SlotSet
from rasa_sdk.executor import CollectingDispatcher
Expand All @@ -13,6 +14,7 @@


def get_stack():
"""Return a dialogue stack."""
dialogue_stack = [
{
"frame_id": "CP6JP9GQ",
Expand Down Expand Up @@ -147,3 +149,15 @@ def name(self):
class SubclassTestActionB(SubclassTestActionA):
def name(self):
return "subclass_test_action_b"


@pytest.fixture
def current_rasa_version() -> str:
"""Return current Rasa version."""
return rasa_sdk.__version__


@pytest.fixture
def previous_rasa_version() -> str:
"""Return previous Rasa version."""
return "1.0.0"
245 changes: 245 additions & 0 deletions tests/test_grpc_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
from typing import Union, Any, Dict, Text, List
from unittest.mock import MagicMock, AsyncMock

import grpc
import pytest
from google.protobuf.json_format import MessageToDict, ParseDict

from rasa_sdk import ActionExecutionRejection
from rasa_sdk.executor import ActionName, ActionExecutor
from rasa_sdk.grpc_errors import (
ActionExecutionFailed,
ResourceNotFound,
ResourceNotFoundType,
)
from rasa_sdk.grpc_py import action_webhook_pb2
from rasa_sdk.grpc_server import GRPCActionServerWebhook
from rasa_sdk.interfaces import ActionMissingDomainException, ActionNotFoundException


@pytest.fixture
def sender_id() -> str:
"""Return sender id."""
return "test_sender_id"


@pytest.fixture
def action_name() -> str:
"""Return action name."""
return "action_listen"


@pytest.fixture
def grpc_webhook_request(
sender_id: str,
action_name: str,
current_rasa_version: str,
) -> action_webhook_pb2.WebhookRequest:
"""Create a webhook request."""
return action_webhook_pb2.WebhookRequest(
next_action=action_name,
sender_id=sender_id,
tracker=action_webhook_pb2.Tracker(
sender_id=sender_id,
slots={},
latest_message={},
events=[],
paused=False,
followup_action="",
active_loop={},
latest_action_name="",
stack={},
),
domain=action_webhook_pb2.Domain(
config={},
session_config={},
intents=[],
entities=[],
slots={},
responses={},
actions=[],
forms={},
e2e_actions=[],
),
version=current_rasa_version,
domain_digest="",
)


@pytest.fixture
def mock_executor() -> AsyncMock:
"""Create a mock action executor."""
return AsyncMock(spec=ActionExecutor)


@pytest.fixture
def mock_grpc_service_context() -> MagicMock:
"""Create a mock gRPC service context."""
return MagicMock(spec=grpc.aio.ServicerContext)


@pytest.fixture
def grpc_action_server_webhook(mock_executor: AsyncMock) -> GRPCActionServerWebhook:
"""Create a GRPCActionServerWebhook instance with a mock executor."""
return GRPCActionServerWebhook(executor=mock_executor)


@pytest.fixture
def executor_response() -> Dict[Text, Any]:
"""Create an executor response."""
return {
"events": [{"event": "slot", "name": "test", "value": "foo"}],
"responses": [{"utter": "Hi"}],
}


@pytest.fixture
def expected_grpc_webhook_response(
executor_response: Dict[Text, Any],
) -> action_webhook_pb2.WebhookResponse:
"""Create a gRPC webhook response."""
result = action_webhook_pb2.WebhookResponse()
return ParseDict(executor_response, result)


def action_names() -> List[ActionName]:
"""Create a list of action names."""
return [
ActionName(name="action_listen"),
ActionName(name="action_restart"),
ActionName(name="action_session_start"),
]


def expected_grpc_actions_response() -> action_webhook_pb2.ActionsResponse:
"""Create a gRPC actions response."""
actions = [action.model_dump() for action in action_names()]
result = action_webhook_pb2.ActionsResponse()
return ParseDict(
{
"actions": actions,
},
result,
)


@pytest.mark.parametrize(
"auto_reload, expected_reload_call_count", [(True, 1), (False, 0)]
)
async def test_grpc_action_server_webhook_no_errors(
auto_reload: bool,
expected_reload_call_count: int,
grpc_action_server_webhook: GRPCActionServerWebhook,
grpc_webhook_request: action_webhook_pb2.WebhookRequest,
mock_executor: AsyncMock,
mock_grpc_service_context: MagicMock,
executor_response: Dict[Text, Any],
expected_grpc_webhook_response: action_webhook_pb2.WebhookResponse,
):
"""Test that the gRPC action server webhook can handle a request without errors."""
grpc_action_server_webhook.auto_reload = auto_reload
mock_executor.run.return_value = executor_response
response = await grpc_action_server_webhook.Webhook(
grpc_webhook_request,
mock_grpc_service_context,
)

assert response == expected_grpc_webhook_response

mock_grpc_service_context.set_code.assert_not_called()
mock_grpc_service_context.set_details.assert_not_called()

assert mock_executor.reload.call_count == expected_reload_call_count

expected_action_call = MessageToDict(
grpc_webhook_request,
preserving_proto_field_name=True,
)
mock_executor.run.assert_called_once_with(expected_action_call)


@pytest.mark.parametrize(
"exception, expected_status_code, expected_body",
[
(
ActionExecutionRejection("action_name", "message"),
grpc.StatusCode.INTERNAL,
ActionExecutionFailed(
action_name="action_name", message="message"
).model_dump_json(),
),
(
ActionNotFoundException("action_name", "message"),
grpc.StatusCode.NOT_FOUND,
ResourceNotFound(
action_name="action_name",
message="message",
resource_type=ResourceNotFoundType.ACTION,
).model_dump_json(),
),
(
ActionMissingDomainException("action_name", "message"),
grpc.StatusCode.NOT_FOUND,
ResourceNotFound(
action_name="action_name",
message="message",
resource_type=ResourceNotFoundType.DOMAIN,
).model_dump_json(),
),
],
)
async def test_grpc_action_server_webhook_action_execution_rejected(
exception: Union[
ActionExecutionRejection, ActionNotFoundException, ActionMissingDomainException
],
expected_status_code: grpc.StatusCode,
expected_body: str,
grpc_action_server_webhook: GRPCActionServerWebhook,
grpc_webhook_request: action_webhook_pb2.WebhookRequest,
mock_executor: AsyncMock,
mock_grpc_service_context: MagicMock,
):
"""Test that the gRPC action server webhook can handle a request with an action execution rejection.""" # noqa: E501
mock_executor.run.side_effect = exception
response = await grpc_action_server_webhook.Webhook(
grpc_webhook_request,
mock_grpc_service_context,
)

assert response == action_webhook_pb2.WebhookResponse()

mock_grpc_service_context.set_code.assert_called_once_with(expected_status_code)
mock_grpc_service_context.set_details.assert_called_once_with(expected_body)


@pytest.mark.parametrize(
"given_action_names, expected_grpc_actions_response",
[
(
[],
action_webhook_pb2.ActionsResponse(),
),
(
action_names(),
expected_grpc_actions_response(),
),
],
)
async def test_grpc_action_server_actions(
given_action_names: List[ActionName],
expected_grpc_actions_response: action_webhook_pb2.ActionsResponse,
grpc_action_server_webhook: GRPCActionServerWebhook,
mock_grpc_service_context: MagicMock,
mock_executor: AsyncMock,
):
"""Test that the gRPC action server webhook can handle a request for actions."""
mock_executor.list_actions.return_value = given_action_names

response = await grpc_action_server_webhook.Actions(
action_webhook_pb2.ActionsRequest(), mock_grpc_service_context
)

assert response == expected_grpc_actions_response

mock_grpc_service_context.set_code.assert_not_called()
mock_grpc_service_context.set_details.assert_not_called()
Loading