Skip to content

Commit

Permalink
update some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ancalita committed May 14, 2024
1 parent 0f24207 commit 85e3233
Show file tree
Hide file tree
Showing 12 changed files with 189 additions and 166 deletions.
58 changes: 0 additions & 58 deletions tests/test_actions.py

This file was deleted.

File renamed without changes.
132 changes: 132 additions & 0 deletions tests/test_actions/test_actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from typing import List, Dict, Text, Any

import pytest

from rasa_sdk import Action, FormValidationAction, Tracker, ValidationAction
from rasa_sdk.events import SlotSet
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.types import DomainDict


class CustomAsyncAction(Action):
def name(cls) -> Text:
return "custom_async_action"

async def run(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[Dict[Text, Any]]:
return [SlotSet("test", "foo"), SlotSet("test2", "boo")]


class CustomAction(Action):
def name(cls) -> Text:
return "custom_action"

def run(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[Dict[Text, Any]]:
return [SlotSet("test", "bar")]


class CustomActionRaisingException(Action):
def name(cls) -> Text:
return "custom_action_exception"

def run(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[Dict[Text, Any]]:
raise Exception("test exception")


class CustomActionWithDialogueStack(Action):
def name(cls) -> Text:
return "custom_action_with_dialogue_stack"

def run(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[Dict[Text, Any]]:
return [SlotSet("stack", tracker.stack)]


class MockFormValidationAction(FormValidationAction):
def __init__(self) -> None:
self.fail_if_undefined("run")

def fail_if_undefined(self, method_name: str) -> None:
if not (
hasattr(self.__class__.__base__, method_name)
and callable(getattr(self.__class__.__base__, method_name))
):
pytest.fail(
f"method '{method_name}' not found in {self.__class__.__base__}. "
f"This likely means the method was renamed, which means the "
f"instrumentation needs to be adapted!"
)

async def _extract_validation_events(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
return tracker.events

def name(self) -> str:
return "mock_form_validation_action"


class MockValidationAction(ValidationAction):
def __init__(self) -> None:
self.fail_if_undefined("run")

def fail_if_undefined(self, method_name: Text) -> None:
if not (
hasattr(self.__class__.__base__, method_name)
and callable(getattr(self.__class__.__base__, method_name))
):
pytest.fail(
f"method '{method_name}' not found in {self.__class__.__base__}. "
f"This likely means the method was renamed, which means the "
f"instrumentation needs to be adapted!"
)

async def run(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
pass

def name(self) -> Text:
return "mock_validation_action"

async def _extract_validation_events(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
return tracker.events


class SubclassTestActionA(Action):
def name(self):
return "subclass_test_action_a"


class SubclassTestActionB(SubclassTestActionA):
def name(self):
return "subclass_test_action_b"
51 changes: 30 additions & 21 deletions tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,74 +15,81 @@


@pytest.fixture
def app():
return ep.create_app("tests.test_actions")
def sanic_app():
return ep.create_app("tests")


def test_endpoint_exit_for_unknown_actions_package():
with pytest.raises(SystemExit):
ep.create_app("non-existing-actions-package")


def test_server_health_returns_200(app: Sanic):
request, response = app.test_client.get("/health")
def test_server_health_returns_200(sanic_app: Sanic):
request, response = sanic_app.test_client.get("/health")
assert response.status == 200
assert response.json == {"status": "ok"}


def test_server_list_actions_returns_200(app: Sanic):
request, response = app.test_client.get("/actions")
def test_server_list_actions_returns_200(sanic_app: Sanic):
request, response = sanic_app.test_client.get("/actions")
assert response.status == 200
assert len(response.json) == 4

assert len(response.json) == 9
print(response.json)
expected = [
# defined in tests/test_actions.py
# defined in tests/test_actions
{"name": "custom_async_action"},
{"name": "custom_action"},
{"name": "custom_action_exception"},
{"name": "custom_action_with_dialogue_stack"},
{"name": "subclass_test_action_a"},
{"name": "mock_validation_action"},
{"name": "mock_form_validation_action"},
# defined in tests/test_forms.py
{"name": "some_form"},
# defined in tests/test_actions
{"name": "subclass_test_action_b"},
]
assert response.json == expected


def test_server_webhook_unknown_action_returns_404(app: Sanic):
def test_server_webhook_unknown_action_returns_404(sanic_app: Sanic):
data = {
"next_action": "test_action_1",
"tracker": {"sender_id": "1", "conversation_id": "default"},
}
request, response = app.test_client.post("/webhook", data=json.dumps(data))
request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
assert response.status == 404


def test_server_webhook_handles_action_exception(app: Sanic):
def test_server_webhook_handles_action_exception(sanic_app: Sanic):
data = {
"next_action": "custom_action_exception",
"tracker": {"sender_id": "1", "conversation_id": "default"},
}
request, response = app.test_client.post("/webhook", data=json.dumps(data))
request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
assert response.status == 500
assert response.json.get("error") == "test exception"
assert response.json.get("request_body") == data


def test_server_webhook_custom_action_returns_200(app: Sanic):
def test_server_webhook_custom_action_returns_200(sanic_app: Sanic):
data = {
"next_action": "custom_action",
"tracker": {"sender_id": "1", "conversation_id": "default"},
}
request, response = app.test_client.post("/webhook", data=json.dumps(data))
request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
events = response.json.get("events")

assert events == [SlotSet("test", "bar")]
assert response.status == 200


def test_server_webhook_custom_async_action_returns_200(app: Sanic):
def test_server_webhook_custom_async_action_returns_200(sanic_app: Sanic):
data = {
"next_action": "custom_async_action",
"tracker": {"sender_id": "1", "conversation_id": "default"},
}
request, response = app.test_client.post("/webhook", data=json.dumps(data))
request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
events = response.json.get("events")

assert events == [SlotSet("test", "foo"), SlotSet("test2", "boo")]
Expand All @@ -108,14 +115,14 @@ def test_arg_parser_actions_params_module_style():
assert cmdline_args.actions == "actions.act"


def test_server_webhook_custom_action_encoded_data_returns_200(app: Sanic):
def test_server_webhook_custom_action_encoded_data_returns_200(sanic_app: Sanic):
data = {
"next_action": "custom_action",
"tracker": {"sender_id": "1", "conversation_id": "default"},
"domain": {"intents": ["greet", "goodbye"]},
}

request, response = app.test_client.post(
request, response = sanic_app.test_client.post(
"/webhook",
data=zlib.compress(json.dumps(data).encode()),
headers={"Content-encoding": "deflate"},
Expand All @@ -134,13 +141,15 @@ def test_server_webhook_custom_action_encoded_data_returns_200(app: Sanic):
],
)
def test_server_webhook_custom_action_with_dialogue_stack_returns_200(
stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]], app: Sanic
stack_state: Dict[Text, Any],
dialogue_stack: List[Dict[Text, Any]],
sanic_app: Sanic,
):
data = {
"next_action": "custom_action_with_dialogue_stack",
"tracker": {"sender_id": "1", "conversation_id": "default", **stack_state},
}
_, response = app.test_client.post("/webhook", data=json.dumps(data))
_, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
events = response.json.get("events")

assert events == [SlotSet("stack", dialogue_stack)]
Expand Down
12 changes: 1 addition & 11 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import string
import time

from rasa_sdk import Action
from typing import Text, Optional, Generator

import pytest
from rasa_sdk.executor import ActionExecutor, CollectingDispatcher
from tests.test_actions.test_actions import SubclassTestActionA, SubclassTestActionB

TEST_PACKAGE_BASE = "tests/executor_test_packages"

Expand Down Expand Up @@ -237,16 +237,6 @@ async def test_reload_module(
}


class SubclassTestActionA(Action):
def name(self):
return "subclass_test_action_a"


class SubclassTestActionB(SubclassTestActionA):
def name(self):
return "subclass_test_action_b"


def test_load_subclasses(executor: ActionExecutor):
executor.register_action(SubclassTestActionB)
assert list(executor.actions) == ["subclass_test_action_b"]
Expand Down
5 changes: 3 additions & 2 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_plugin_attach_sanic_app_extension(
monkeypatch.setattr(
manager.hook, "attach_sanic_app_extensions", MagicMock(return_value=None)
)
monkeypatch.setattr("rasa_sdk.endpoint.Sanic.serve", MagicMock(return_value=None))
app_mock = MagicMock()

# Create a MagicMock object to replace the create_app() method
Expand All @@ -42,8 +43,8 @@ def test_plugin_attach_sanic_app_extension(
# Set the create_app() method to return create_app_mock
monkeypatch.setattr("rasa_sdk.endpoint.create_app", create_app_mock)

# Set the return value of app_mock.run() to None
app_mock.run.return_value = None
# Set the return value of app_mock.prepare() to None
app_mock.prepare.return_value = None

with warnings.catch_warnings():
warnings.simplefilter("error")
Expand Down
Empty file.
Loading

0 comments on commit 85e3233

Please sign in to comment.