diff --git a/proto/action_webhook.proto b/proto/action_webhook.proto index 5d4e9da9..55d1bf7e 100644 --- a/proto/action_webhook.proto +++ b/proto/action_webhook.proto @@ -11,7 +11,7 @@ service ActionService { message ActionsRequest {} message ActionsResponse { - map actions = 1; + repeated google.protobuf.Struct actions = 1; } message Tracker { diff --git a/rasa_sdk/grpc_py/action_webhook_pb2.py b/rasa_sdk/grpc_py/action_webhook_pb2.py index 26f550a6..385d0ed6 100644 --- a/rasa_sdk/grpc_py/action_webhook_pb2.py +++ b/rasa_sdk/grpc_py/action_webhook_pb2.py @@ -14,7 +14,7 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n%rasa_sdk/grpc_py/action_webhook.proto\x12\x15\x61\x63tion_server_webhook\x1a\x1cgoogle/protobuf/struct.proto\"\x10\n\x0e\x41\x63tionsRequest\"\x87\x01\n\x0f\x41\x63tionsResponse\x12\x44\n\x07\x61\x63tions\x18\x01 \x03(\x0b\x32\x33.action_server_webhook.ActionsResponse.ActionsEntry\x1a.\n\x0c\x41\x63tionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb8\x03\n\x07Tracker\x12\x11\n\tsender_id\x18\x01 \x01(\t\x12&\n\x05slots\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12/\n\x0elatest_message\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\'\n\x06\x65vents\x18\x04 \x03(\x0b\x32\x17.google.protobuf.Struct\x12\x0e\n\x06paused\x18\x05 \x01(\x08\x12\x1c\n\x0f\x66ollowup_action\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x43\n\x0b\x61\x63tive_loop\x18\x07 \x03(\x0b\x32..action_server_webhook.Tracker.ActiveLoopEntry\x12\x1f\n\x12latest_action_name\x18\x08 \x01(\tH\x01\x88\x01\x01\x12&\n\x05stack\x18\t \x03(\x0b\x32\x17.google.protobuf.Struct\x1a\x31\n\x0f\x41\x63tiveLoopEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\x12\n\x10_followup_actionB\x15\n\x13_latest_action_name\"K\n\x06Intent\x12\x14\n\x0cstring_value\x18\x01 \x01(\t\x12+\n\ndict_value\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"K\n\x06\x45ntity\x12\x14\n\x0cstring_value\x18\x01 \x01(\t\x12+\n\ndict_value\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"K\n\x06\x41\x63tion\x12\x14\n\x0cstring_value\x18\x01 \x01(\t\x12+\n\ndict_value\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\x9d\x03\n\x06\x44omain\x12\'\n\x06\x63onfig\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12/\n\x0esession_config\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12.\n\x07intents\x18\x03 \x03(\x0b\x32\x1d.action_server_webhook.Intent\x12/\n\x08\x65ntities\x18\x04 \x03(\x0b\x32\x1d.action_server_webhook.Entity\x12&\n\x05slots\x18\x05 \x01(\x0b\x32\x17.google.protobuf.Struct\x12*\n\tresponses\x18\x06 \x01(\x0b\x32\x17.google.protobuf.Struct\x12.\n\x07\x61\x63tions\x18\x07 \x03(\x0b\x32\x1d.action_server_webhook.Action\x12&\n\x05\x66orms\x18\x08 \x01(\x0b\x32\x17.google.protobuf.Struct\x12,\n\x0b\x65\x32\x65_actions\x18\t \x03(\x0b\x32\x17.google.protobuf.Struct\"\xd7\x01\n\x0eWebhookRequest\x12\x13\n\x0bnext_action\x18\x01 \x01(\t\x12\x11\n\tsender_id\x18\x02 \x01(\t\x12/\n\x07tracker\x18\x03 \x01(\x0b\x32\x1e.action_server_webhook.Tracker\x12-\n\x06\x64omain\x18\x04 \x01(\x0b\x32\x1d.action_server_webhook.Domain\x12\x0f\n\x07version\x18\x05 \x01(\t\x12\x1a\n\rdomain_digest\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\x10\n\x0e_domain_digest\"f\n\x0fWebhookResponse\x12\'\n\x06\x65vents\x18\x01 \x03(\x0b\x32\x17.google.protobuf.Struct\x12*\n\tresponses\x18\x02 \x03(\x0b\x32\x17.google.protobuf.Struct2\xc3\x01\n\rActionService\x12X\n\x07Webhook\x12%.action_server_webhook.WebhookRequest\x1a&.action_server_webhook.WebhookResponse\x12X\n\x07\x41\x63tions\x12%.action_server_webhook.ActionsRequest\x1a&.action_server_webhook.ActionsResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n%rasa_sdk/grpc_py/action_webhook.proto\x12\x15\x61\x63tion_server_webhook\x1a\x1cgoogle/protobuf/struct.proto\"\x10\n\x0e\x41\x63tionsRequest\";\n\x0f\x41\x63tionsResponse\x12(\n\x07\x61\x63tions\x18\x01 \x03(\x0b\x32\x17.google.protobuf.Struct\"\xb8\x03\n\x07Tracker\x12\x11\n\tsender_id\x18\x01 \x01(\t\x12&\n\x05slots\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12/\n\x0elatest_message\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\'\n\x06\x65vents\x18\x04 \x03(\x0b\x32\x17.google.protobuf.Struct\x12\x0e\n\x06paused\x18\x05 \x01(\x08\x12\x1c\n\x0f\x66ollowup_action\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x43\n\x0b\x61\x63tive_loop\x18\x07 \x03(\x0b\x32..action_server_webhook.Tracker.ActiveLoopEntry\x12\x1f\n\x12latest_action_name\x18\x08 \x01(\tH\x01\x88\x01\x01\x12&\n\x05stack\x18\t \x03(\x0b\x32\x17.google.protobuf.Struct\x1a\x31\n\x0f\x41\x63tiveLoopEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\x12\n\x10_followup_actionB\x15\n\x13_latest_action_name\"K\n\x06Intent\x12\x14\n\x0cstring_value\x18\x01 \x01(\t\x12+\n\ndict_value\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"K\n\x06\x45ntity\x12\x14\n\x0cstring_value\x18\x01 \x01(\t\x12+\n\ndict_value\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"K\n\x06\x41\x63tion\x12\x14\n\x0cstring_value\x18\x01 \x01(\t\x12+\n\ndict_value\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\x9d\x03\n\x06\x44omain\x12\'\n\x06\x63onfig\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12/\n\x0esession_config\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12.\n\x07intents\x18\x03 \x03(\x0b\x32\x1d.action_server_webhook.Intent\x12/\n\x08\x65ntities\x18\x04 \x03(\x0b\x32\x1d.action_server_webhook.Entity\x12&\n\x05slots\x18\x05 \x01(\x0b\x32\x17.google.protobuf.Struct\x12*\n\tresponses\x18\x06 \x01(\x0b\x32\x17.google.protobuf.Struct\x12.\n\x07\x61\x63tions\x18\x07 \x03(\x0b\x32\x1d.action_server_webhook.Action\x12&\n\x05\x66orms\x18\x08 \x01(\x0b\x32\x17.google.protobuf.Struct\x12,\n\x0b\x65\x32\x65_actions\x18\t \x03(\x0b\x32\x17.google.protobuf.Struct\"\xd7\x01\n\x0eWebhookRequest\x12\x13\n\x0bnext_action\x18\x01 \x01(\t\x12\x11\n\tsender_id\x18\x02 \x01(\t\x12/\n\x07tracker\x18\x03 \x01(\x0b\x32\x1e.action_server_webhook.Tracker\x12-\n\x06\x64omain\x18\x04 \x01(\x0b\x32\x1d.action_server_webhook.Domain\x12\x0f\n\x07version\x18\x05 \x01(\t\x12\x1a\n\rdomain_digest\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\x10\n\x0e_domain_digest\"f\n\x0fWebhookResponse\x12\'\n\x06\x65vents\x18\x01 \x03(\x0b\x32\x17.google.protobuf.Struct\x12*\n\tresponses\x18\x02 \x03(\x0b\x32\x17.google.protobuf.Struct2\xc3\x01\n\rActionService\x12X\n\x07Webhook\x12%.action_server_webhook.WebhookRequest\x1a&.action_server_webhook.WebhookResponse\x12X\n\x07\x41\x63tions\x12%.action_server_webhook.ActionsRequest\x1a&.action_server_webhook.ActionsResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -22,32 +22,28 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _ACTIONSRESPONSE_ACTIONSENTRY._options = None - _ACTIONSRESPONSE_ACTIONSENTRY._serialized_options = b'8\001' _TRACKER_ACTIVELOOPENTRY._options = None _TRACKER_ACTIVELOOPENTRY._serialized_options = b'8\001' _globals['_ACTIONSREQUEST']._serialized_start=94 _globals['_ACTIONSREQUEST']._serialized_end=110 - _globals['_ACTIONSRESPONSE']._serialized_start=113 - _globals['_ACTIONSRESPONSE']._serialized_end=248 - _globals['_ACTIONSRESPONSE_ACTIONSENTRY']._serialized_start=202 - _globals['_ACTIONSRESPONSE_ACTIONSENTRY']._serialized_end=248 - _globals['_TRACKER']._serialized_start=251 - _globals['_TRACKER']._serialized_end=691 - _globals['_TRACKER_ACTIVELOOPENTRY']._serialized_start=599 - _globals['_TRACKER_ACTIVELOOPENTRY']._serialized_end=648 - _globals['_INTENT']._serialized_start=693 - _globals['_INTENT']._serialized_end=768 - _globals['_ENTITY']._serialized_start=770 - _globals['_ENTITY']._serialized_end=845 - _globals['_ACTION']._serialized_start=847 - _globals['_ACTION']._serialized_end=922 - _globals['_DOMAIN']._serialized_start=925 - _globals['_DOMAIN']._serialized_end=1338 - _globals['_WEBHOOKREQUEST']._serialized_start=1341 - _globals['_WEBHOOKREQUEST']._serialized_end=1556 - _globals['_WEBHOOKRESPONSE']._serialized_start=1558 - _globals['_WEBHOOKRESPONSE']._serialized_end=1660 - _globals['_ACTIONSERVICE']._serialized_start=1663 - _globals['_ACTIONSERVICE']._serialized_end=1858 + _globals['_ACTIONSRESPONSE']._serialized_start=112 + _globals['_ACTIONSRESPONSE']._serialized_end=171 + _globals['_TRACKER']._serialized_start=174 + _globals['_TRACKER']._serialized_end=614 + _globals['_TRACKER_ACTIVELOOPENTRY']._serialized_start=522 + _globals['_TRACKER_ACTIVELOOPENTRY']._serialized_end=571 + _globals['_INTENT']._serialized_start=616 + _globals['_INTENT']._serialized_end=691 + _globals['_ENTITY']._serialized_start=693 + _globals['_ENTITY']._serialized_end=768 + _globals['_ACTION']._serialized_start=770 + _globals['_ACTION']._serialized_end=845 + _globals['_DOMAIN']._serialized_start=848 + _globals['_DOMAIN']._serialized_end=1261 + _globals['_WEBHOOKREQUEST']._serialized_start=1264 + _globals['_WEBHOOKREQUEST']._serialized_end=1479 + _globals['_WEBHOOKRESPONSE']._serialized_start=1481 + _globals['_WEBHOOKRESPONSE']._serialized_end=1583 + _globals['_ACTIONSERVICE']._serialized_start=1586 + _globals['_ACTIONSERVICE']._serialized_end=1781 # @@protoc_insertion_point(module_scope) diff --git a/rasa_sdk/grpc_py/action_webhook_pb2.pyi b/rasa_sdk/grpc_py/action_webhook_pb2.pyi index c8c43431..e871cfc4 100644 --- a/rasa_sdk/grpc_py/action_webhook_pb2.pyi +++ b/rasa_sdk/grpc_py/action_webhook_pb2.pyi @@ -12,16 +12,9 @@ class ActionsRequest(_message.Message): class ActionsResponse(_message.Message): __slots__ = ["actions"] - class ActionsEntry(_message.Message): - __slots__ = ["key", "value"] - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... ACTIONS_FIELD_NUMBER: _ClassVar[int] - actions: _containers.ScalarMap[str, str] - def __init__(self, actions: _Optional[_Mapping[str, str]] = ...) -> None: ... + actions: _containers.RepeatedCompositeFieldContainer[_struct_pb2.Struct] + def __init__(self, actions: _Optional[_Iterable[_Union[_struct_pb2.Struct, _Mapping]]] = ...) -> None: ... class Tracker(_message.Message): __slots__ = ["sender_id", "slots", "latest_message", "events", "paused", "followup_action", "active_loop", "latest_action_name", "stack"] diff --git a/rasa_sdk/grpc_server.py b/rasa_sdk/grpc_server.py index f3f875d8..1b71060e 100644 --- a/rasa_sdk/grpc_server.py +++ b/rasa_sdk/grpc_server.py @@ -94,7 +94,11 @@ def __init__( self.auto_reload = auto_reload self.executor = executor - async def Actions(self, request: ActionsRequest, context) -> ActionsResponse: + async def Actions( + self, + request: ActionsRequest, + context: grpc.aio.ServicerContext, + ) -> ActionsResponse: """Handle RPC request for the actions. Args: @@ -107,9 +111,14 @@ async def Actions(self, request: ActionsRequest, context) -> ActionsResponse: if self.auto_reload: self.executor.reload() - actions = self.executor.list_actions() + actions = [action.model_dump() for action in self.executor.list_actions()] response = ActionsResponse() - return ParseDict(actions, response) + return ParseDict( + { + "actions": actions, + }, + response, + ) async def Webhook( self, diff --git a/tests/conftest.py b/tests/conftest.py index 22eac5ff..1d82a438 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import pytest +import rasa_sdk from rasa_sdk import Action, FormValidationAction, Tracker, ValidationAction from rasa_sdk.events import SlotSet from rasa_sdk.executor import CollectingDispatcher @@ -13,6 +14,7 @@ def get_stack(): + """Return a dialogue stack.""" dialogue_stack = [ { "frame_id": "CP6JP9GQ", @@ -147,3 +149,15 @@ def name(self): class SubclassTestActionB(SubclassTestActionA): def name(self): return "subclass_test_action_b" + + +@pytest.fixture +def current_rasa_version() -> str: + """Return current Rasa version.""" + return rasa_sdk.__version__ + + +@pytest.fixture +def previous_rasa_version() -> str: + """Return previous Rasa version.""" + return "1.0.0" diff --git a/tests/test_grpc_server.py b/tests/test_grpc_server.py new file mode 100644 index 00000000..9c9f02b9 --- /dev/null +++ b/tests/test_grpc_server.py @@ -0,0 +1,245 @@ +from typing import Union, Any, Dict, Text, List +from unittest.mock import MagicMock, AsyncMock + +import grpc +import pytest +from google.protobuf.json_format import MessageToDict, ParseDict + +from rasa_sdk import ActionExecutionRejection +from rasa_sdk.executor import ActionName, ActionExecutor +from rasa_sdk.grpc_errors import ( + ActionExecutionFailed, + ResourceNotFound, + ResourceNotFoundType, +) +from rasa_sdk.grpc_py import action_webhook_pb2 +from rasa_sdk.grpc_server import GRPCActionServerWebhook +from rasa_sdk.interfaces import ActionMissingDomainException, ActionNotFoundException + + +@pytest.fixture +def sender_id() -> str: + """Return sender id.""" + return "test_sender_id" + + +@pytest.fixture +def action_name() -> str: + """Return action name.""" + return "action_listen" + + +@pytest.fixture +def grpc_webhook_request( + sender_id: str, + action_name: str, + current_rasa_version: str, +) -> action_webhook_pb2.WebhookRequest: + """Create a webhook request.""" + return action_webhook_pb2.WebhookRequest( + next_action=action_name, + sender_id=sender_id, + tracker=action_webhook_pb2.Tracker( + sender_id=sender_id, + slots={}, + latest_message={}, + events=[], + paused=False, + followup_action="", + active_loop={}, + latest_action_name="", + stack={}, + ), + domain=action_webhook_pb2.Domain( + config={}, + session_config={}, + intents=[], + entities=[], + slots={}, + responses={}, + actions=[], + forms={}, + e2e_actions=[], + ), + version=current_rasa_version, + domain_digest="", + ) + + +@pytest.fixture +def mock_executor() -> AsyncMock: + """Create a mock action executor.""" + return AsyncMock(spec=ActionExecutor) + + +@pytest.fixture +def mock_grpc_service_context() -> MagicMock: + """Create a mock gRPC service context.""" + return MagicMock(spec=grpc.aio.ServicerContext) + + +@pytest.fixture +def grpc_action_server_webhook(mock_executor: AsyncMock) -> GRPCActionServerWebhook: + """Create a GRPCActionServerWebhook instance with a mock executor.""" + return GRPCActionServerWebhook(executor=mock_executor) + + +@pytest.fixture +def executor_response() -> Dict[Text, Any]: + """Create an executor response.""" + return { + "events": [{"event": "slot", "name": "test", "value": "foo"}], + "responses": [{"utter": "Hi"}], + } + + +@pytest.fixture +def expected_grpc_webhook_response( + executor_response: Dict[Text, Any], +) -> action_webhook_pb2.WebhookResponse: + """Create a gRPC webhook response.""" + result = action_webhook_pb2.WebhookResponse() + return ParseDict(executor_response, result) + + +def action_names() -> List[ActionName]: + """Create a list of action names.""" + return [ + ActionName(name="action_listen"), + ActionName(name="action_restart"), + ActionName(name="action_session_start"), + ] + + +def expected_grpc_actions_response() -> action_webhook_pb2.ActionsResponse: + """Create a gRPC actions response.""" + actions = [action.model_dump() for action in action_names()] + result = action_webhook_pb2.ActionsResponse() + return ParseDict( + { + "actions": actions, + }, + result, + ) + + +@pytest.mark.parametrize( + "auto_reload, expected_reload_call_count", [(True, 1), (False, 0)] +) +async def test_grpc_action_server_webhook_no_errors( + auto_reload: bool, + expected_reload_call_count: int, + grpc_action_server_webhook: GRPCActionServerWebhook, + grpc_webhook_request: action_webhook_pb2.WebhookRequest, + mock_executor: AsyncMock, + mock_grpc_service_context: MagicMock, + executor_response: Dict[Text, Any], + expected_grpc_webhook_response: action_webhook_pb2.WebhookResponse, +): + """Test that the gRPC action server webhook can handle a request without errors.""" + grpc_action_server_webhook.auto_reload = auto_reload + mock_executor.run.return_value = executor_response + response = await grpc_action_server_webhook.Webhook( + grpc_webhook_request, + mock_grpc_service_context, + ) + + assert response == expected_grpc_webhook_response + + mock_grpc_service_context.set_code.assert_not_called() + mock_grpc_service_context.set_details.assert_not_called() + + assert mock_executor.reload.call_count == expected_reload_call_count + + expected_action_call = MessageToDict( + grpc_webhook_request, + preserving_proto_field_name=True, + ) + mock_executor.run.assert_called_once_with(expected_action_call) + + +@pytest.mark.parametrize( + "exception, expected_status_code, expected_body", + [ + ( + ActionExecutionRejection("action_name", "message"), + grpc.StatusCode.INTERNAL, + ActionExecutionFailed( + action_name="action_name", message="message" + ).model_dump_json(), + ), + ( + ActionNotFoundException("action_name", "message"), + grpc.StatusCode.NOT_FOUND, + ResourceNotFound( + action_name="action_name", + message="message", + resource_type=ResourceNotFoundType.ACTION, + ).model_dump_json(), + ), + ( + ActionMissingDomainException("action_name", "message"), + grpc.StatusCode.NOT_FOUND, + ResourceNotFound( + action_name="action_name", + message="message", + resource_type=ResourceNotFoundType.DOMAIN, + ).model_dump_json(), + ), + ], +) +async def test_grpc_action_server_webhook_action_execution_rejected( + exception: Union[ + ActionExecutionRejection, ActionNotFoundException, ActionMissingDomainException + ], + expected_status_code: grpc.StatusCode, + expected_body: str, + grpc_action_server_webhook: GRPCActionServerWebhook, + grpc_webhook_request: action_webhook_pb2.WebhookRequest, + mock_executor: AsyncMock, + mock_grpc_service_context: MagicMock, +): + """Test that the gRPC action server webhook can handle a request with an action execution rejection.""" # noqa: E501 + mock_executor.run.side_effect = exception + response = await grpc_action_server_webhook.Webhook( + grpc_webhook_request, + mock_grpc_service_context, + ) + + assert response == action_webhook_pb2.WebhookResponse() + + mock_grpc_service_context.set_code.assert_called_once_with(expected_status_code) + mock_grpc_service_context.set_details.assert_called_once_with(expected_body) + + +@pytest.mark.parametrize( + "given_action_names, expected_grpc_actions_response", + [ + ( + [], + action_webhook_pb2.ActionsResponse(), + ), + ( + action_names(), + expected_grpc_actions_response(), + ), + ], +) +async def test_grpc_action_server_actions( + given_action_names: List[ActionName], + expected_grpc_actions_response: action_webhook_pb2.ActionsResponse, + grpc_action_server_webhook: GRPCActionServerWebhook, + mock_grpc_service_context: MagicMock, + mock_executor: AsyncMock, +): + """Test that the gRPC action server webhook can handle a request for actions.""" + mock_executor.list_actions.return_value = given_action_names + + response = await grpc_action_server_webhook.Actions( + action_webhook_pb2.ActionsRequest(), mock_grpc_service_context + ) + + assert response == expected_grpc_actions_response + + mock_grpc_service_context.set_code.assert_not_called() + mock_grpc_service_context.set_details.assert_not_called()