diff --git a/changelog/922.improvement.md b/changelog/922.improvement.md new file mode 100644 index 000000000..b0f0194ea --- /dev/null +++ b/changelog/922.improvement.md @@ -0,0 +1,29 @@ +## Problem +Rasa knowledge base actions cannot infer the object type of an object directly from the user message without the user first asking to list all objects related to that object type. This prevents action_query_knowledge from providing a suitable response when a user asks for a certain attribute of an object even though the knowledge base has the relevant information. That is, the knowledge base actions require the slot `object_type` to be set to one of the primary key values in the knowledge base for it to search through the objects. Here is an example: +``` +Your input -> what is the price range of Berlin Burrito Company? +Sorry, I'm not sure I understand. Can you rephrase? +Your input -> list some restaurants +Found the following objects of type 'restaurant': +1: Gong Gan +2: I due forni +3: Pfefferberg +4: Lụa Restaurant +5: Donath +Your input -> what is the price range of Berlin Burrito Company? +'Berlin Burrito Company' has the value 'cheap' for attribute 'price-range'. +``` + +## Proposed solution +- The improvement requires changes to the classes ActionQueryKnowledgeBase and InMemoryKnowledgeBase under rasa-sdk. +- The `object_type` can be inferred by utilizing the entity extraction (DIET) where object types are used as entities to annotate object names. +This also requires changes to be made to slot management to enable dynamic inference of `object_type`. +- The scope of the suggested solution is limited to user queries where they ask for an attribute of a given object without mentioning the object type and without needing to first ask for a list of options of the corresponding object type. 
+- E.g.: If the user asks for ‘price range of Berlin Burrito Company’, then Rasa will extract and set attribute slot value to ‘price-range’ and restaurant slot value to ‘Berlin Burrito Company’. From this, it can be inferred that the user is talking about the object type ‘restaurant’. + +## Summary of Changes +- To enable the inference of `object_type` using the entities, the following changes were made to the existing code base: + - Extract the list of object_types from our knowledge base using a new method `get_object_types()` in `storage.py` for the `InMemoryKnowledgeBase` class. + - A new method named `match_extracted_entities_to_object_type()` was added in `utils.py` to infer the object type of a given object using the entities and list of object types. + - The relevant logic was added in `actions.py` to infer the object type using the above functionalities when the object type slot is not set. + - To enable dynamic inference of `object_type`, changes to slot management are also required. Currently, the change is to reset the `object_type` slot to `None` after every conversation turn. \ No newline at end of file diff --git a/rasa_sdk/knowledge_base/actions.py b/rasa_sdk/knowledge_base/actions.py index af61e94ae..097bd5031 100644 --- a/rasa_sdk/knowledge_base/actions.py +++ b/rasa_sdk/knowledge_base/actions.py @@ -13,6 +13,7 @@ SLOT_LISTED_OBJECTS, get_object_name, get_attribute_slots, + match_extracted_entities_to_object_type, ) from rasa_sdk import utils from rasa_sdk.executor import CollectingDispatcher @@ -112,13 +113,12 @@ async def run( tracker: Tracker, domain: "DomainDict", ) -> List[Dict[Text, Any]]: - """Executes this action. - - If the user ask a question about an attribute, + """ + Executes this action. If the user ask a question about an attribute, the knowledge base is queried for that attribute. 
Otherwise, if no - attribute was detected in the request or the user is talking about a new - object type, multiple objects of the requested type are returned from the - knowledge base. + attribute was detected in the latest request it assumes user is talking + about a new object type and, multiple objects of the requested type are + returned from the knowledge base. Args: dispatcher: the dispatcher @@ -131,22 +131,37 @@ async def run( object_type = tracker.get_slot(SLOT_OBJECT_TYPE) last_object_type = tracker.get_slot(SLOT_LAST_OBJECT_TYPE) attribute = tracker.get_slot(SLOT_ATTRIBUTE) + has_mention = tracker.get_slot(SLOT_MENTION) is not None - new_request = object_type != last_object_type + # check if attribute entity is found in latest user message. This is used + # to track whether the request is to query objects or query attributes + has_attribute_in_latest_message = any( + entity.get("entity") == "attribute" + for entity in tracker.latest_message["entities"] + ) if not object_type: - # object type always needs to be set as this is needed to query the - # knowledge base - dispatcher.utter_message(response="utter_ask_rephrase") - return [] - - if not attribute or new_request: + # sets the object type dynamically from entities if object_type is not + # found in user query + object_types = self.knowledge_base.get_object_types() + object_type = match_extracted_entities_to_object_type(tracker, object_types) + set_object_type_slot_event = [SlotSet(SLOT_OBJECT_TYPE, object_type)] + tracker.add_slots( + set_object_type_slot_event + ) # temporarily set the `object_type_slot` to extracted value + + if object_type and not has_attribute_in_latest_message: return await self._query_objects(dispatcher, object_type, tracker) - elif attribute: + elif object_type and attribute: return await self._query_attribute( dispatcher, object_type, attribute, tracker ) + if last_object_type and has_mention and attribute: + return await self._query_attribute( + dispatcher, 
last_object_type, attribute, tracker + ) + dispatcher.utter_message(response="utter_ask_rephrase") return [] @@ -167,7 +182,6 @@ async def _query_objects( object_attributes = await utils.call_potential_coroutine( self.knowledge_base.get_attributes_of_object(object_type) ) - # get all set attribute slots of the object type to be able to filter the # list of objects attributes = get_attribute_slots(tracker, object_attributes) @@ -175,7 +189,6 @@ async def _query_objects( objects = await utils.call_potential_coroutine( self.knowledge_base.get_objects(object_type, attributes) ) - await utils.call_potential_coroutine( self.utter_objects(dispatcher, object_type, objects) ) @@ -189,8 +202,13 @@ async def _query_objects( last_object = None if len(objects) > 1 else objects[0][key_attribute] + # To prevent the user to first ask to list the objects for an object type, + # the object type has to be extracted while the action is executed. + # Therefore we need to reset the SLOT_OBJECT_TYPE to + # None to enable this functionality. + slots = [ - SlotSet(SLOT_OBJECT_TYPE, object_type), + SlotSet(SLOT_OBJECT_TYPE, None), SlotSet(SLOT_MENTION, None), SlotSet(SLOT_ATTRIBUTE, None), SlotSet(SLOT_LAST_OBJECT, last_object), @@ -219,7 +237,6 @@ async def _query_attribute( Returns: list of slots """ - object_name = get_object_name( tracker, self.knowledge_base.ordinal_mention_mapping, @@ -258,8 +275,13 @@ async def _query_attribute( ) ) + # To prevent the user to first ask to list the objects for an object type, + # the object type has to be extracted while the action is executed. + # Therefore we need to reset the SLOT_OBJECT_TYPE to + # None to enable this functionality. 
+ slots = [ - SlotSet(SLOT_OBJECT_TYPE, object_type), + SlotSet(SLOT_OBJECT_TYPE, None), SlotSet(SLOT_ATTRIBUTE, None), SlotSet(SLOT_MENTION, None), SlotSet(SLOT_LAST_OBJECT, object_identifier), diff --git a/rasa_sdk/knowledge_base/storage.py b/rasa_sdk/knowledge_base/storage.py index 3b4a0c89b..22f31e853 100644 --- a/rasa_sdk/knowledge_base/storage.py +++ b/rasa_sdk/knowledge_base/storage.py @@ -12,7 +12,6 @@ class KnowledgeBase: def __init__(self) -> None: - self.ordinal_mention_mapping = { "1": lambda lst: lst[0], "2": lambda lst: lst[1], @@ -110,6 +109,12 @@ async def get_object( """ raise NotImplementedError("Method is not implemented.") + def get_object_types(self) -> List[Text]: + """ + Returns a list of object types from knowledge base data. + """ + raise NotImplementedError("Method is not implemented.") + class InMemoryKnowledgeBase(KnowledgeBase): def __init__(self, data_file: Text) -> None: @@ -206,7 +211,6 @@ async def get_object( return None objects = self.data[object_type] - key_attribute = await utils.call_potential_coroutine( self.get_key_attribute_of_object(object_type) ) @@ -242,3 +246,7 @@ async def get_object( return None return objects_of_interest[0] + + def get_object_types(self) -> List[Text]: + """See parent class docstring.""" + return list(self.data.keys()) diff --git a/rasa_sdk/knowledge_base/utils.py b/rasa_sdk/knowledge_base/utils.py index 7594b9c96..24011ff89 100644 --- a/rasa_sdk/knowledge_base/utils.py +++ b/rasa_sdk/knowledge_base/utils.py @@ -82,7 +82,6 @@ def resolve_mention( listed_items = tracker.get_slot(SLOT_LISTED_OBJECTS) last_object = tracker.get_slot(SLOT_LAST_OBJECT) last_object_type = tracker.get_slot(SLOT_LAST_OBJECT_TYPE) - current_object_type = tracker.get_slot(SLOT_OBJECT_TYPE) if not mention: return None @@ -95,7 +94,9 @@ def resolve_mention( # for now we just assume that if the user refers to an object, for # example via "it" or "that restaurant", they are actually referring to the last # object that was detected. 
- if current_object_type == last_object_type: + # Since object type slot is reset to 'None' value, it is sufficient to only check + # whether the last_object_type is not None. + if last_object_type: return last_object return None @@ -164,3 +165,29 @@ def reset_attribute_slots( slots.append(SlotSet(attr, None)) return slots + + +def match_extracted_entities_to_object_type( + tracker: "Tracker", + object_types: List[Text], +) -> Optional[Text]: + """ + If the user ask a question about an attribute using an object name and + without specifying the object type, then this function searches the + corresponding object type. (e.g: when user asks'price range of B&B', this + function extracts the object type as 'hotel'). Here we assume that the user + message contains reference only to one object type in the knowledge base. + + Args: + tracker: the tracker + object_types: list of object types in the knowledge base + + Returns: the name of the object type if found, otherwise `None`. + """ + entities = tracker.latest_message.get("entities", []) + entity_names = [entity.get("entity") for entity in entities] + for entity in entity_names: + if entity in object_types: + return entity + + return None diff --git a/scripts/release.py b/scripts/release.py index 01af33b53..6e7010cfe 100644 --- a/scripts/release.py +++ b/scripts/release.py @@ -161,7 +161,8 @@ def is_valid_version_number(v: Text) -> bool: str(current_version.next_release_candidate("major")), ] version = questionary.select( - f"Which {version} do you want to release?", choices=choices, + f"Which {version} do you want to release?", + choices=choices, ).ask() if version: @@ -255,8 +256,8 @@ def next_version(args: argparse.Namespace) -> Version: def generate_changelog(version: Version) -> None: """Call towncrier and create a changelog from all available changelog entries.""" check_call( - ["towncrier", "build", "--yes", "--version", str(version)], cwd=str(project_root()) - + ["towncrier", "build", "--yes", "--version", 
str(version)], + cwd=str(project_root()), ) diff --git a/tests/knowledge_base/test_actions.py b/tests/knowledge_base/test_actions.py index 6ae808348..b962ac70e 100644 --- a/tests/knowledge_base/test_actions.py +++ b/tests/knowledge_base/test_actions.py @@ -1,4 +1,5 @@ import pytest +import copy from rasa_sdk import Tracker from rasa_sdk.events import SlotSet @@ -28,9 +29,16 @@ def compare_slots(slot_list_1, slot_list_2): @pytest.mark.parametrize( - "slots,expected_slots", + "latest_message,slots,expected_slots", [ ( + { + "entities": [ + { + "entity": "object_type", + }, + ], + }, { SLOT_MENTION: None, SLOT_ATTRIBUTE: None, @@ -42,26 +50,36 @@ def compare_slots(slot_list_1, slot_list_2): [ SlotSet(SLOT_MENTION, None), SlotSet(SLOT_ATTRIBUTE, None), - SlotSet(SLOT_OBJECT_TYPE, "restaurant"), + SlotSet(SLOT_OBJECT_TYPE, None), SlotSet(SLOT_LAST_OBJECT, None), SlotSet(SLOT_LAST_OBJECT_TYPE, "restaurant"), SlotSet(SLOT_LISTED_OBJECTS, [3, 2, 1]), ], ), ( + { + "entities": [ + { + "entity": "object_type", + }, + { + "entity": "cuisine", + }, + ], + }, { SLOT_MENTION: None, SLOT_ATTRIBUTE: None, SLOT_OBJECT_TYPE: "restaurant", SLOT_LISTED_OBJECTS: None, SLOT_LAST_OBJECT: None, - SLOT_LAST_OBJECT_TYPE: "restaurant", + SLOT_LAST_OBJECT_TYPE: None, "cuisine": "Italian", }, [ SlotSet(SLOT_MENTION, None), SlotSet(SLOT_ATTRIBUTE, None), - SlotSet(SLOT_OBJECT_TYPE, "restaurant"), + SlotSet(SLOT_OBJECT_TYPE, None), SlotSet(SLOT_LAST_OBJECT, None), SlotSet(SLOT_LAST_OBJECT_TYPE, "restaurant"), SlotSet(SLOT_LISTED_OBJECTS, [3, 1]), @@ -69,10 +87,20 @@ def compare_slots(slot_list_1, slot_list_2): ], ), ( + { + "entities": [ + { + "entity": "attribute", + }, + { + "entity": " mention", + }, + ], + }, { SLOT_MENTION: "2", SLOT_ATTRIBUTE: "cuisine", - SLOT_OBJECT_TYPE: "restaurant", + SLOT_OBJECT_TYPE: None, SLOT_LISTED_OBJECTS: [1, 2, 3], SLOT_LAST_OBJECT: None, SLOT_LAST_OBJECT_TYPE: "restaurant", @@ -80,16 +108,26 @@ def compare_slots(slot_list_1, slot_list_2): [ 
SlotSet(SLOT_MENTION, None), SlotSet(SLOT_ATTRIBUTE, None), - SlotSet(SLOT_OBJECT_TYPE, "restaurant"), + SlotSet(SLOT_OBJECT_TYPE, None), SlotSet(SLOT_LAST_OBJECT, 2), SlotSet(SLOT_LAST_OBJECT_TYPE, "restaurant"), ], ), ( + { + "entities": [ + { + "entity": "attribute", + }, + { + "entity": "restaurant", + }, + ], + }, { SLOT_MENTION: None, SLOT_ATTRIBUTE: "cuisine", - SLOT_OBJECT_TYPE: "restaurant", + SLOT_OBJECT_TYPE: None, SLOT_LISTED_OBJECTS: [1, 2, 3], SLOT_LAST_OBJECT: None, SLOT_LAST_OBJECT_TYPE: "restaurant", @@ -98,12 +136,15 @@ def compare_slots(slot_list_1, slot_list_2): [ SlotSet(SLOT_MENTION, None), SlotSet(SLOT_ATTRIBUTE, None), - SlotSet(SLOT_OBJECT_TYPE, "restaurant"), + SlotSet(SLOT_OBJECT_TYPE, None), SlotSet(SLOT_LAST_OBJECT, 1), SlotSet(SLOT_LAST_OBJECT_TYPE, "restaurant"), ], ), ( + { + "entities": [], + }, { SLOT_MENTION: None, SLOT_ATTRIBUTE: None, @@ -116,12 +157,20 @@ def compare_slots(slot_list_1, slot_list_2): ), ], ) -async def test_action_run(data_file, slots, expected_slots): +async def test_action_run(data_file, latest_message, slots, expected_slots): knowledge_base = InMemoryKnowledgeBase(data_file) action = ActionQueryKnowledgeBase(knowledge_base) dispatcher = CollectingDispatcher() - tracker = Tracker("default", slots, {}, [], False, None, {}, "action_listen") + + tracker = Tracker( + "default", slots, latest_message, [], False, None, {}, "action_listen" + ) + + # To prevent unintended modifications, create a copy of the slots dictionary + # before passing it to the Tracker object, as dictionaries in + # Python are mutable and passed by reference. + initial_slots = copy.deepcopy(slots) actual_slots = await action.run(dispatcher, tracker, {}) @@ -131,7 +180,6 @@ async def test_action_run(data_file, slots, expected_slots): # Check that utterances produced by action are correct. 
if slots[SLOT_ATTRIBUTE]: if slots.get("restaurant") is not None: - name = slots["restaurant"] attr = slots[SLOT_ATTRIBUTE] obj = await knowledge_base.get_object("restaurant", name) @@ -152,3 +200,44 @@ async def test_action_run(data_file, slots, expected_slots): actual_msg = dispatcher.messages[0]["text"] assert actual_msg == expected_msg + + # Check that temporary slot setting by action is correct. + if any(initial_slots.values()): + # The condition block below checks for user message + # such as "what is the price range of Pasta Bar?". + # This user message example is denoted by test case `4`. + if ( + initial_slots.get(SLOT_OBJECT_TYPE) is None + and initial_slots.get(SLOT_MENTION) is None + ): + # Since Pasta Bar belongs to `restaurant` object type + # the tracker event passed to set the slot temporarily + # should look like this. + expected_tracker_event = { + "event": "slot", + "timestamp": None, + "name": "object_type", + "value": "restaurant", + } + + assert expected_tracker_event in tracker.events + + # The condition block below checks for user message + # such as "what is the cuisine of second one?". + # This user message example is denoted by test case `3`. + elif ( + initial_slots.get(SLOT_OBJECT_TYPE) is None + and initial_slots.get(SLOT_MENTION) is not None + ): + # Since there is no `restaurant` entity in the user message, + # the `object_type` will be `None`. + # Therefore, the tracker event passed to set the slot temporarily + # should look like this. 
+ expected_tracker_event = { + "event": "slot", + "timestamp": None, + "name": "object_type", + "value": None, + } + + assert expected_tracker_event in tracker.events diff --git a/tests/knowledge_base/test_utils.py b/tests/knowledge_base/test_utils.py index 48d9ab67a..e5438939d 100644 --- a/tests/knowledge_base/test_utils.py +++ b/tests/knowledge_base/test_utils.py @@ -3,6 +3,7 @@ from rasa_sdk import Tracker from rasa_sdk.events import SlotSet from rasa_sdk.knowledge_base.utils import ( + match_extracted_entities_to_object_type, get_attribute_slots, reset_attribute_slots, get_object_name, @@ -129,3 +130,28 @@ def test_get_object_name(slots, use_last_object_mention, expected_object_name): ) assert actual_object_name == expected_object_name + + +@pytest.mark.parametrize( + "latest_message,object_types,expected_object_name", + [ + ( + { + "entities": [ + {"entity": "attribute"}, + {"entity": "restaurant"}, + ], + }, + ["hotel", "restaurant"], + "restaurant", + ), + ], +) +def test_match_extracted_entities_to_object_type( + latest_message, object_types, expected_object_name +): + tracker = Tracker( + "default", {}, latest_message, [], False, None, {}, "action_listen" + ) + actual_object_name = match_extracted_entities_to_object_type(tracker, object_types) + assert actual_object_name == expected_object_name