From 2e32306fe26a89d6438d25049742e025026acca0 Mon Sep 17 00:00:00 2001 From: Thirunayan <56014038+Thirunayan22@users.noreply.github.com> Date: Tue, 6 Jun 2023 20:33:05 +0530 Subject: [PATCH] OSS-668: Rasa Knowledge base actions are unable to query about a certain attribute of an object unless the user first asks to obtain a list of objects of a specific type. (#922) * modified logic in run, slot management in _query_objects and _query_attribute functions * added get_object_types function * added get_object_type_dynamic function * updated changelog * updated changelog * modified changelog name and content * uncommented code snippet that was previously commented to experiment the contribution changes * added comprehensive typing hint * renamed to * replaced the 'get_object_type_dynamic' function name with the new function name * replaced the 'get_object_type_dynamic' function name with the new function name * modified sentences based on PR review comments * updated the 'get_object_type_dynamic' with new name * directly return entity without assigning it to an unused variable as suggested in PR review * modified line 184 to safeguard against KeyErrors as suggested in PR review * added suggestion to code change in line 185 as mentioned in the PR * added suggestion to code change in line 185 as mentioned in the PR * modified after running grammerly * Added TrackerKnowledgeBase class in tracker.py to set object type slot * modified class name * removed passing object_type value as an argument in functions * imported TrackerKnowledgeBase class | modified function based on changes done in utils.py | assigned extracted object type to object_type variable * removed get_object_types from KnowledgeBase class | removed async from get_object_types * removed await in line 136 after convering get_object_types function from async to normal * modified code to set the object type slot temporarily using the method in Tracker class * removing tracker.py which contained a method to 
temporarily set the object type slot * removed the import from tracker.py * removed code that gets current_object_type from the slot * converted match_extracted_entities_to_object_types function from async to normal * removed await for match_extracted_entities_to_object_types function * modified action logic after assigning extracted object type to an object_type variable * Unit test functions * Formatted with black * Added line * Removed tab * Resolved all linting issues - flake8 * Fixed Errors * formatted with black * removed entity:attribute appearing twice * removed white space to meet max line charac requirement * entered variables into new line to meet max line charac requirements * removed print statement * reverting the line indentation * formatted code with black * improved language and fixed typos * added get_object_types method in KnowledgeBase parent class * renamned entities_values to entity_names in line 188 * removed duplicated line in line 148 * modified return statement in function docstring * updated the match_extracted_entities_to_object_types doc string to mention the assumption that the user message only contains reference to one object type * renamed match_extracted_entities_to_object_types to match_extracted_entities_to_object_type * renamed match_extracted_entities_to_object_types to match_extracted_entities_to_object_type * reverting a change that wasn't made * improved doc string * changed to * added comment on slot object type resetting * modified if condition in run to check whether latest message attribute * Resolved test_actions failures * Added unit test assertion * Modified test_actions.py and actions.py * modified code logic to include in one line * formatted code with black * Fixed formatting * added comment about slot_object_type resetting * formatted code with black * fixed typos * Reformatted actions file * resolve lint code issue caused by inverted comma * added a comment on replacing new_request with has_attribute_in_latest * 
modified doc string in run * Fixed trailing whitespace issue * removed keys from entities payload that are not needed for test_utils.py * added function docstring in get_object_types() to look up from parent class * removed repeated info in docstring * modified comment on resetting object type slot based on review suggestion * replaced has_attribute_in_latest with has_attribute_in_latest_message * Fixed linting issues * Reformatted using black * added a rough code on testing the tracker events * removed import logging * added print statements for testing * resolved python mutable issue * removed print statements and cleaned the code * removed import copy * removed trailing white space in line 157 * removed trailing white space in line 157 * added spaces between code lines * added spaces between code lines * added space * Fixed formatting and lint code * made a copy of slots using copy.deepcopy() | modified assertion to check inside actual event list * format with black * added comments in tracker event assertion code block * formatted code --------- Co-authored-by: Gajithra Co-authored-by: Gajithira Puvanendran <66731983+Gajithra@users.noreply.github.com> Co-authored-by: Dilanka Sanjula <103400313+Dilanka96@users.noreply.github.com> Co-authored-by: Dilanka96 --- changelog/922.improvement.md | 29 +++++++ rasa_sdk/knowledge_base/actions.py | 60 ++++++++++----- rasa_sdk/knowledge_base/storage.py | 12 ++- rasa_sdk/knowledge_base/utils.py | 31 +++++++- scripts/release.py | 7 +- tests/knowledge_base/test_actions.py | 111 ++++++++++++++++++++++++--- tests/knowledge_base/test_utils.py | 26 +++++++ 7 files changed, 239 insertions(+), 37 deletions(-) create mode 100644 changelog/922.improvement.md diff --git a/changelog/922.improvement.md b/changelog/922.improvement.md new file mode 100644 index 000000000..b0f0194ea --- /dev/null +++ b/changelog/922.improvement.md @@ -0,0 +1,29 @@ +## Problem +Rasa knowledge base actions cannot infer the object type of an object directly 
from the user message without the user first asking to list all objects related to that object type. This prevents action_query_knowledge from providing a suitable response when a user asks for a certain attribute of an object even though the knowledge base has the relevant information. That is, the knowledge base actions require the slot `object_type` to be set to one of the primary key values in the knowledge base for it to search through the objects. Here is an example: +``` +Your input -> what is the price range of Berlin Burrito Company? +Sorry, I'm not sure I understand. Can you rephrase? +Your input -> list some restaurants +Found the following objects of type 'restaurant': +1: Gong Gan +2: I due forni +3: Pfefferberg +4: Lụa Restaurant +5: Donath +Your input -> what is the price range of Berlin Burrito Company? +'Berlin Burrito Company' has the value 'cheap' for attribute 'price-range'. +``` + +## Proposed solution +- The improvement requires changes to the classes ActionQueryKnowledgeBase and InMemoryKnowledgeBase under rasa-sdk. +- The `object_type` can be inferred by utilizing the entity extraction (DIET) where object types are used as entities to annotate object names. +This also requires changes to be made to slot management to enable dynamic inference of `object_type`. +- The scope of the suggested solution is limited to user queries where they ask for an attribute of a given object without mentioning the object type and without needing to first ask for a list of options of the corresponding object type. +- E.g.: If the user asks for ‘price range of Berlin Burrito Company’, then rasa will extract and set attribute slot value to ‘price-range’ and restaurant slot value to ‘Berlin Burrito Company’. From this, it can be inferred that the user is talking about the object type ‘restaurant’. 
+ +## Summary of Changes +- To enable the inference of `object_type` using the entities the following changes were made to the existing code base: + - Extract the list of object_types from our knowledge base using a new method `get_object_types()` in `storage.py` for the `InMemoryKnowledgeBase` class. + - A new method named `match_extracted_entities_to_object_type()` was added in `utils.py` to infer the object type of a given object using the entities and list of object types + - The relevant logic was added in `actions.py` to infer the object type using the above functionalities when the object type slot is not set. + - To enable dynamic inference of `object_type`, changes to slot management are also required. Currently, the change is to reset the `object_type` slot to `None` after every conversation turn. \ No newline at end of file diff --git a/rasa_sdk/knowledge_base/actions.py b/rasa_sdk/knowledge_base/actions.py index af61e94ae..097bd5031 100644 --- a/rasa_sdk/knowledge_base/actions.py +++ b/rasa_sdk/knowledge_base/actions.py @@ -13,6 +13,7 @@ SLOT_LISTED_OBJECTS, get_object_name, get_attribute_slots, + match_extracted_entities_to_object_type, ) from rasa_sdk import utils from rasa_sdk.executor import CollectingDispatcher @@ -112,13 +113,12 @@ async def run( tracker: Tracker, domain: "DomainDict", ) -> List[Dict[Text, Any]]: - """Executes this action. - - If the user ask a question about an attribute, + """ + Executes this action. If the user ask a question about an attribute, the knowledge base is queried for that attribute. Otherwise, if no - attribute was detected in the request or the user is talking about a new - object type, multiple objects of the requested type are returned from the - knowledge base. + attribute was detected in the latest request it assumes user is talking + about a new object type and, multiple objects of the requested type are + returned from the knowledge base. 
Args: dispatcher: the dispatcher @@ -131,22 +131,37 @@ async def run( object_type = tracker.get_slot(SLOT_OBJECT_TYPE) last_object_type = tracker.get_slot(SLOT_LAST_OBJECT_TYPE) attribute = tracker.get_slot(SLOT_ATTRIBUTE) + has_mention = tracker.get_slot(SLOT_MENTION) is not None - new_request = object_type != last_object_type + # check if attribute entity is found in latest user message. This is used + # to track whether the request is to query objects or query attributes + has_attribute_in_latest_message = any( + entity.get("entity") == "attribute" + for entity in tracker.latest_message["entities"] + ) if not object_type: - # object type always needs to be set as this is needed to query the - # knowledge base - dispatcher.utter_message(response="utter_ask_rephrase") - return [] - - if not attribute or new_request: + # sets the object type dynamically from entities if object_type is not + # found in user query + object_types = self.knowledge_base.get_object_types() + object_type = match_extracted_entities_to_object_type(tracker, object_types) + set_object_type_slot_event = [SlotSet(SLOT_OBJECT_TYPE, object_type)] + tracker.add_slots( + set_object_type_slot_event + ) # temporarily set the `object_type_slot` to extracted value + + if object_type and not has_attribute_in_latest_message: return await self._query_objects(dispatcher, object_type, tracker) - elif attribute: + elif object_type and attribute: return await self._query_attribute( dispatcher, object_type, attribute, tracker ) + if last_object_type and has_mention and attribute: + return await self._query_attribute( + dispatcher, last_object_type, attribute, tracker + ) + dispatcher.utter_message(response="utter_ask_rephrase") return [] @@ -167,7 +182,6 @@ async def _query_objects( object_attributes = await utils.call_potential_coroutine( self.knowledge_base.get_attributes_of_object(object_type) ) - # get all set attribute slots of the object type to be able to filter the # list of objects attributes = 
get_attribute_slots(tracker, object_attributes) @@ -175,7 +189,6 @@ async def _query_objects( objects = await utils.call_potential_coroutine( self.knowledge_base.get_objects(object_type, attributes) ) - await utils.call_potential_coroutine( self.utter_objects(dispatcher, object_type, objects) ) @@ -189,8 +202,13 @@ async def _query_objects( last_object = None if len(objects) > 1 else objects[0][key_attribute] + # To prevent the user to first ask to list the objects for an object type, + # the object type has to be extracted while the action is executed. + # Therefore we need to reset the SLOT_OBJECT_TYPE to + # None to enable this functionality. + slots = [ - SlotSet(SLOT_OBJECT_TYPE, object_type), + SlotSet(SLOT_OBJECT_TYPE, None), SlotSet(SLOT_MENTION, None), SlotSet(SLOT_ATTRIBUTE, None), SlotSet(SLOT_LAST_OBJECT, last_object), @@ -219,7 +237,6 @@ async def _query_attribute( Returns: list of slots """ - object_name = get_object_name( tracker, self.knowledge_base.ordinal_mention_mapping, @@ -258,8 +275,13 @@ async def _query_attribute( ) ) + # To prevent the user to first ask to list the objects for an object type, + # the object type has to be extracted while the action is executed. + # Therefore we need to reset the SLOT_OBJECT_TYPE to + # None to enable this functionality. 
+ slots = [ - SlotSet(SLOT_OBJECT_TYPE, object_type), + SlotSet(SLOT_OBJECT_TYPE, None), SlotSet(SLOT_ATTRIBUTE, None), SlotSet(SLOT_MENTION, None), SlotSet(SLOT_LAST_OBJECT, object_identifier), diff --git a/rasa_sdk/knowledge_base/storage.py b/rasa_sdk/knowledge_base/storage.py index 3b4a0c89b..22f31e853 100644 --- a/rasa_sdk/knowledge_base/storage.py +++ b/rasa_sdk/knowledge_base/storage.py @@ -12,7 +12,6 @@ class KnowledgeBase: def __init__(self) -> None: - self.ordinal_mention_mapping = { "1": lambda lst: lst[0], "2": lambda lst: lst[1], @@ -110,6 +109,12 @@ async def get_object( """ raise NotImplementedError("Method is not implemented.") + def get_object_types(self) -> List[Text]: + """ + Returns a list of object types from knowledge base data. + """ + raise NotImplementedError("Method is not implemented.") + class InMemoryKnowledgeBase(KnowledgeBase): def __init__(self, data_file: Text) -> None: @@ -206,7 +211,6 @@ async def get_object( return None objects = self.data[object_type] - key_attribute = await utils.call_potential_coroutine( self.get_key_attribute_of_object(object_type) ) @@ -242,3 +246,7 @@ async def get_object( return None return objects_of_interest[0] + + def get_object_types(self) -> List[Text]: + """See parent class docstring.""" + return list(self.data.keys()) diff --git a/rasa_sdk/knowledge_base/utils.py b/rasa_sdk/knowledge_base/utils.py index 7594b9c96..24011ff89 100644 --- a/rasa_sdk/knowledge_base/utils.py +++ b/rasa_sdk/knowledge_base/utils.py @@ -82,7 +82,6 @@ def resolve_mention( listed_items = tracker.get_slot(SLOT_LISTED_OBJECTS) last_object = tracker.get_slot(SLOT_LAST_OBJECT) last_object_type = tracker.get_slot(SLOT_LAST_OBJECT_TYPE) - current_object_type = tracker.get_slot(SLOT_OBJECT_TYPE) if not mention: return None @@ -95,7 +94,9 @@ def resolve_mention( # for now we just assume that if the user refers to an object, for # example via "it" or "that restaurant", they are actually referring to the last # object that was detected. 
- if current_object_type == last_object_type: + # Since object type slot is reset to 'None' value, it is sufficient to only check + # whether the last_object_type is not None. + if last_object_type: return last_object return None @@ -164,3 +165,29 @@ def reset_attribute_slots( slots.append(SlotSet(attr, None)) return slots + + +def match_extracted_entities_to_object_type( + tracker: "Tracker", + object_types: List[Text], +) -> Optional[Text]: + """ + If the user ask a question about an attribute using an object name and + without specifying the object type, then this function searches the + corresponding object type. (e.g: when user asks'price range of B&B', this + function extracts the object type as 'hotel'). Here we assume that the user + message contains reference only to one object type in the knowledge base. + + Args: + tracker: the tracker + object_types: list of object types in the knowledge base + + Returns: the name of the object type if found, otherwise `None`. + """ + entities = tracker.latest_message.get("entities", []) + entity_names = [entity.get("entity") for entity in entities] + for entity in entity_names: + if entity in object_types: + return entity + + return None diff --git a/scripts/release.py b/scripts/release.py index 01af33b53..6e7010cfe 100644 --- a/scripts/release.py +++ b/scripts/release.py @@ -161,7 +161,8 @@ def is_valid_version_number(v: Text) -> bool: str(current_version.next_release_candidate("major")), ] version = questionary.select( - f"Which {version} do you want to release?", choices=choices, + f"Which {version} do you want to release?", + choices=choices, ).ask() if version: @@ -255,8 +256,8 @@ def next_version(args: argparse.Namespace) -> Version: def generate_changelog(version: Version) -> None: """Call towncrier and create a changelog from all available changelog entries.""" check_call( - ["towncrier", "build", "--yes", "--version", str(version)], cwd=str(project_root()) - + ["towncrier", "build", "--yes", "--version", 
str(version)], + cwd=str(project_root()), ) diff --git a/tests/knowledge_base/test_actions.py b/tests/knowledge_base/test_actions.py index 6ae808348..b962ac70e 100644 --- a/tests/knowledge_base/test_actions.py +++ b/tests/knowledge_base/test_actions.py @@ -1,4 +1,5 @@ import pytest +import copy from rasa_sdk import Tracker from rasa_sdk.events import SlotSet @@ -28,9 +29,16 @@ def compare_slots(slot_list_1, slot_list_2): @pytest.mark.parametrize( - "slots,expected_slots", + "latest_message,slots,expected_slots", [ ( + { + "entities": [ + { + "entity": "object_type", + }, + ], + }, { SLOT_MENTION: None, SLOT_ATTRIBUTE: None, @@ -42,26 +50,36 @@ def compare_slots(slot_list_1, slot_list_2): [ SlotSet(SLOT_MENTION, None), SlotSet(SLOT_ATTRIBUTE, None), - SlotSet(SLOT_OBJECT_TYPE, "restaurant"), + SlotSet(SLOT_OBJECT_TYPE, None), SlotSet(SLOT_LAST_OBJECT, None), SlotSet(SLOT_LAST_OBJECT_TYPE, "restaurant"), SlotSet(SLOT_LISTED_OBJECTS, [3, 2, 1]), ], ), ( + { + "entities": [ + { + "entity": "object_type", + }, + { + "entity": "cuisine", + }, + ], + }, { SLOT_MENTION: None, SLOT_ATTRIBUTE: None, SLOT_OBJECT_TYPE: "restaurant", SLOT_LISTED_OBJECTS: None, SLOT_LAST_OBJECT: None, - SLOT_LAST_OBJECT_TYPE: "restaurant", + SLOT_LAST_OBJECT_TYPE: None, "cuisine": "Italian", }, [ SlotSet(SLOT_MENTION, None), SlotSet(SLOT_ATTRIBUTE, None), - SlotSet(SLOT_OBJECT_TYPE, "restaurant"), + SlotSet(SLOT_OBJECT_TYPE, None), SlotSet(SLOT_LAST_OBJECT, None), SlotSet(SLOT_LAST_OBJECT_TYPE, "restaurant"), SlotSet(SLOT_LISTED_OBJECTS, [3, 1]), @@ -69,10 +87,20 @@ def compare_slots(slot_list_1, slot_list_2): ], ), ( + { + "entities": [ + { + "entity": "attribute", + }, + { + "entity": " mention", + }, + ], + }, { SLOT_MENTION: "2", SLOT_ATTRIBUTE: "cuisine", - SLOT_OBJECT_TYPE: "restaurant", + SLOT_OBJECT_TYPE: None, SLOT_LISTED_OBJECTS: [1, 2, 3], SLOT_LAST_OBJECT: None, SLOT_LAST_OBJECT_TYPE: "restaurant", @@ -80,16 +108,26 @@ def compare_slots(slot_list_1, slot_list_2): [ 
SlotSet(SLOT_MENTION, None), SlotSet(SLOT_ATTRIBUTE, None), - SlotSet(SLOT_OBJECT_TYPE, "restaurant"), + SlotSet(SLOT_OBJECT_TYPE, None), SlotSet(SLOT_LAST_OBJECT, 2), SlotSet(SLOT_LAST_OBJECT_TYPE, "restaurant"), ], ), ( + { + "entities": [ + { + "entity": "attribute", + }, + { + "entity": "restaurant", + }, + ], + }, { SLOT_MENTION: None, SLOT_ATTRIBUTE: "cuisine", - SLOT_OBJECT_TYPE: "restaurant", + SLOT_OBJECT_TYPE: None, SLOT_LISTED_OBJECTS: [1, 2, 3], SLOT_LAST_OBJECT: None, SLOT_LAST_OBJECT_TYPE: "restaurant", @@ -98,12 +136,15 @@ def compare_slots(slot_list_1, slot_list_2): [ SlotSet(SLOT_MENTION, None), SlotSet(SLOT_ATTRIBUTE, None), - SlotSet(SLOT_OBJECT_TYPE, "restaurant"), + SlotSet(SLOT_OBJECT_TYPE, None), SlotSet(SLOT_LAST_OBJECT, 1), SlotSet(SLOT_LAST_OBJECT_TYPE, "restaurant"), ], ), ( + { + "entities": [], + }, { SLOT_MENTION: None, SLOT_ATTRIBUTE: None, @@ -116,12 +157,20 @@ def compare_slots(slot_list_1, slot_list_2): ), ], ) -async def test_action_run(data_file, slots, expected_slots): +async def test_action_run(data_file, latest_message, slots, expected_slots): knowledge_base = InMemoryKnowledgeBase(data_file) action = ActionQueryKnowledgeBase(knowledge_base) dispatcher = CollectingDispatcher() - tracker = Tracker("default", slots, {}, [], False, None, {}, "action_listen") + + tracker = Tracker( + "default", slots, latest_message, [], False, None, {}, "action_listen" + ) + + # To prevent unintended modifications, create a copy of the slots dictionary + # before passing it to the Tracker object, as dictionaries in + # Python are mutable and passed by reference. + initial_slots = copy.deepcopy(slots) actual_slots = await action.run(dispatcher, tracker, {}) @@ -131,7 +180,6 @@ async def test_action_run(data_file, slots, expected_slots): # Check that utterances produced by action are correct. 
if slots[SLOT_ATTRIBUTE]: if slots.get("restaurant") is not None: - name = slots["restaurant"] attr = slots[SLOT_ATTRIBUTE] obj = await knowledge_base.get_object("restaurant", name) @@ -152,3 +200,44 @@ async def test_action_run(data_file, slots, expected_slots): actual_msg = dispatcher.messages[0]["text"] assert actual_msg == expected_msg + + # Check that temporary slot setting by action is correct. + if any(initial_slots.values()): + # The condition block below checks for user message + # such as "what is the price range of Pasta Bar?". + # This user message example is denoted by test case `4`. + if ( + initial_slots.get(SLOT_OBJECT_TYPE) is None + and initial_slots.get(SLOT_MENTION) is None + ): + # Since Pasta Bar belongs to `restaurant` object type + # the tracker event passed to set the slot temporarily + # should look like this. + expected_tracker_event = { + "event": "slot", + "timestamp": None, + "name": "object_type", + "value": "restaurant", + } + + assert expected_tracker_event in tracker.events + + # The condition block below checks for user message + # such as "what is the cuisine of second one?". + # This user message example is denoted by test case `3`. + elif ( + initial_slots.get(SLOT_OBJECT_TYPE) is None + and initial_slots.get(SLOT_MENTION) is not None + ): + # Since there is no `restaurant` entity in the user message, + # the `object_type` will be `None`. + # Therefore, the tracker event passed to set the slot temporarily + # should look like this. 
+ expected_tracker_event = { + "event": "slot", + "timestamp": None, + "name": "object_type", + "value": None, + } + + assert expected_tracker_event in tracker.events diff --git a/tests/knowledge_base/test_utils.py b/tests/knowledge_base/test_utils.py index 48d9ab67a..e5438939d 100644 --- a/tests/knowledge_base/test_utils.py +++ b/tests/knowledge_base/test_utils.py @@ -3,6 +3,7 @@ from rasa_sdk import Tracker from rasa_sdk.events import SlotSet from rasa_sdk.knowledge_base.utils import ( + match_extracted_entities_to_object_type, get_attribute_slots, reset_attribute_slots, get_object_name, @@ -129,3 +130,28 @@ def test_get_object_name(slots, use_last_object_mention, expected_object_name): ) assert actual_object_name == expected_object_name + + +@pytest.mark.parametrize( + "latest_message,object_types,expected_object_name", + [ + ( + { + "entities": [ + {"entity": "attribute"}, + {"entity": "restaurant"}, + ], + }, + ["hotel", "restaurant"], + "restaurant", + ), + ], +) +def test_match_extracted_entities_to_object_type( + latest_message, object_types, expected_object_name +): + tracker = Tracker( + "default", {}, latest_message, [], False, None, {}, "action_listen" + ) + actual_object_name = match_extracted_entities_to_object_type(tracker, object_types) + assert actual_object_name == expected_object_name