Skip to content

Commit

Permalink
OSS-668: Rasa Knowledge base actions are unable to query about a cert…
Browse files Browse the repository at this point in the history
…ain attribute of an object unless the user first asks to obtain a list of objects of a specific type. (#922)

* modified logic in run, slot management in _query_objects and _query_attribute functions

* added get_object_types function

* added get_object_type_dynamic function

* updated changelog

* updated changelog

* modified changelog name and content

* uncommented code snippet that was previously commented to experiment the contribution changes

* added comprehensive typing hint

* renamed  to

* replaced the 'get_object_type_dynamic' function name with the new function name

* replaced the 'get_object_type_dynamic' function name with the new function name

* modified sentences based on PR review comments

* updated the 'get_object_type_dynamic' with new name

* directly return entity without assigning it to an unused variable as suggested in PR review

* modified line 184 to safeguard against KeyErrors as suggested in PR review

* added suggestion to code change in line 185 as mentioned in the PR

* added suggestion to code change in line 185 as mentioned in the PR

* modified after running grammerly

* Added TrackerKnowledgeBase class in tracker.py to set object type slot

* modified class name

* removed passing object_type value as an argument in functions

* imported TrackerKnowledgeBase class | modified function based on changes done in utils.py | assigned extracted object type to object_type variable

* removed get_object_types from KnowledgeBase class | removed async from get_object_types

* removed await in line 136 after convering get_object_types function from async to normal

* modified code to set the object type slot temporarily using the method in Tracker class

* removing tracker.py which contained a method to temporarily set the object type slot

* removed the import from tracker.py

* removed code that gets current_object_type from the slot

* converted match_extracted_entities_to_object_types function from async to normal

* removed await for match_extracted_entities_to_object_types function

* modified action logic after assigning extracted object type to an object_type variable

* Unit test functions

* Formatted with black

* Added line

* Removed tab

* Resolved all linting issues - flake8

* Fixed Errors

* formatted with black

* removed entity:attribute appearing twice

* removed white space to meet max line charac requirement

* entered variables into new line to meet max line charac requirements

* removed print statement

* reverting the line indentation

* formatted code with black

* improved language and fixed typos

* added get_object_types method in KnowledgeBase parent class

* renamned entities_values to entity_names in line 188

* removed duplicated line in line 148

* modified return statement in function docstring

* updated the match_extracted_entities_to_object_types doc string to mention the assumption that the user message only contains reference to one object type

* renamed match_extracted_entities_to_object_types to match_extracted_entities_to_object_type

* renamed match_extracted_entities_to_object_types to match_extracted_entities_to_object_type

* reverting a change that wasn't made

* improved doc string

* changed  to

* added comment on slot object type resetting

* modified if condition in run to check whether latest message attribute

* Resolved test_actions failures

* Added unit test assertion

* Modified test_actions.py and actions.py

* modified code logic to include in one line

* formatted code with black

* Fixed formatting

* added comment about slot_object_type resetting

* formatted code with black

* fixed typos

* Reformatted actions file

* resolve lint code issue caused by inverted comma

* added a comment on replacing new_request with has_attribute_in_latest

* modified doc string in run

* Fixed trailing whitespace issue

* removed keys from entities payload that are not needed for test_utils.py

* added function docstring in get_object_types() to look up from parent class

* removed repeated info in docstring

* modified comment on resetting object type slot based on review suggestion

* replaced has_attribute_in_latest with has_attribute_in_latest_message

* Fixed linting issues

* Reformatted using black

* added a rough code on testing the tracker events

* removed import logging

* added print statements for testing

* resolved python mutable issue

* removed print statements and cleaned the code

* removed import copy

* removed trailing white space in line 157

* removed trailing white space in line 157

* added spaces between code lines

* added spaces between code lines

* added space

* Fixed formatting and lint code

* made a copy of slots using copy.deepcopy() | modified assertion to check inside actual event list

* format with black

* added comments in tracker event assertion code block

* formatted code

---------

Co-authored-by: Gajithra <[email protected]>
Co-authored-by: Gajithira Puvanendran <[email protected]>
Co-authored-by: Dilanka Sanjula <[email protected]>
Co-authored-by: Dilanka96 <[email protected]>
  • Loading branch information
5 people committed Jun 6, 2023
1 parent cc0d6ad commit 2e32306
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 37 deletions.
29 changes: 29 additions & 0 deletions changelog/922.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
## Problem
Rasa knowledge base actions cannot infer the object type of an object directly from the user message without the user first asking to list all objects related to that object type. This prevents action_query_knowledge from providing a suitable response when a user asks for a certain attribute of an object even though the knowledge base has the relevant information. That is, the knowledge base actions require the slot `object_type` to be set to one of the primary key values in the knowledge base for it to search through the objects. Here is an example:
```
Your input -> what is the price range of Berlin Burrito Company?
Sorry, I'm not sure I understand. Can you rephrase?
Your input -> list some restaurants
Found the following objects of type 'restaurant':
1: Gong Gan
2: I due forni
3: Pfefferberg
4: Lụa Restaurant
5: Donath
Your input -> what is the price range of Berlin Burrito Company?
'Berlin Burrito Company' has the value 'cheap' for attribute 'price-range'.
```

## Proposed solution
- The improvement requires changes to the classes ActionQueryKnowledgeBase and InMemoryKnowledgeBase under rasa-sdk.
- The `object_type` can be inferred by utilizing the entity extraction (DIET) where object types are used as entities to annotate object names.
This also requires changes to be made to slot management to enable dynamic inference of `object_type`.
- The scope of the suggested solution is limited to user queries where they ask for an attribute of a given object without mentioning the object type and without needing to first ask for a list of options of the corresponding object type.
- E.g.: If the user asks for ‘price range of Berlin Burrito Company’, then rasa will extract and set attribute slot value to ‘price-range’ and hotel slot value to ‘Berlin Burrito Company’. From this, it can be inferred that the user is talking about the object type ‘hotel’.

## Summary of Changes
- To enable the inference of `object_type` using the entities the following changes were made to the existing code base:
- Extract the list of object_types from our knowledge base using a new method `get_object_types()` in `storage.py` for the `InMemoryKnowledgeBase` class.
- A new method named `match_extracted_entities_to_object_type()` was added in `utils.py` to infer the object type of a given object using the entities and list of object types
- The relevant logic was added in `actions.py` to infer the object type using the above functionalities when the object type slot is not set.
- To enable dynamic inference of `object_type`, changes to slot management are also required. Currently, the change is to reset the `object_type` slot to `None` after every conversation turn.
60 changes: 41 additions & 19 deletions rasa_sdk/knowledge_base/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
SLOT_LISTED_OBJECTS,
get_object_name,
get_attribute_slots,
match_extracted_entities_to_object_type,
)
from rasa_sdk import utils
from rasa_sdk.executor import CollectingDispatcher
Expand Down Expand Up @@ -112,13 +113,12 @@ async def run(
tracker: Tracker,
domain: "DomainDict",
) -> List[Dict[Text, Any]]:
"""Executes this action.
If the user ask a question about an attribute,
"""
Executes this action. If the user ask a question about an attribute,
the knowledge base is queried for that attribute. Otherwise, if no
attribute was detected in the request or the user is talking about a new
object type, multiple objects of the requested type are returned from the
knowledge base.
attribute was detected in the latest request it assumes user is talking
about a new object type and, multiple objects of the requested type are
returned from the knowledge base.
Args:
dispatcher: the dispatcher
Expand All @@ -131,22 +131,37 @@ async def run(
object_type = tracker.get_slot(SLOT_OBJECT_TYPE)
last_object_type = tracker.get_slot(SLOT_LAST_OBJECT_TYPE)
attribute = tracker.get_slot(SLOT_ATTRIBUTE)
has_mention = tracker.get_slot(SLOT_MENTION) is not None

new_request = object_type != last_object_type
# check if attribute entity is found in latest user message. This is used
# to track whether the request is to query objects or query attributes
has_attribute_in_latest_message = any(
entity.get("entity") == "attribute"
for entity in tracker.latest_message["entities"]
)

if not object_type:
# object type always needs to be set as this is needed to query the
# knowledge base
dispatcher.utter_message(response="utter_ask_rephrase")
return []

if not attribute or new_request:
# sets the object type dynamically from entities if object_type is not
# found in user query
object_types = self.knowledge_base.get_object_types()
object_type = match_extracted_entities_to_object_type(tracker, object_types)
set_object_type_slot_event = [SlotSet(SLOT_OBJECT_TYPE, object_type)]
tracker.add_slots(
set_object_type_slot_event
) # temporarily set the `object_type_slot` to extracted value

if object_type and not has_attribute_in_latest_message:
return await self._query_objects(dispatcher, object_type, tracker)
elif attribute:
elif object_type and attribute:
return await self._query_attribute(
dispatcher, object_type, attribute, tracker
)

if last_object_type and has_mention and attribute:
return await self._query_attribute(
dispatcher, last_object_type, attribute, tracker
)

dispatcher.utter_message(response="utter_ask_rephrase")
return []

Expand All @@ -167,15 +182,13 @@ async def _query_objects(
object_attributes = await utils.call_potential_coroutine(
self.knowledge_base.get_attributes_of_object(object_type)
)

# get all set attribute slots of the object type to be able to filter the
# list of objects
attributes = get_attribute_slots(tracker, object_attributes)
# query the knowledge base
objects = await utils.call_potential_coroutine(
self.knowledge_base.get_objects(object_type, attributes)
)

await utils.call_potential_coroutine(
self.utter_objects(dispatcher, object_type, objects)
)
Expand All @@ -189,8 +202,13 @@ async def _query_objects(

last_object = None if len(objects) > 1 else objects[0][key_attribute]

# To prevent the user to first ask to list the objects for an object type,
# the object type has to be extracted while the action is executed.
# Therefore we need to reset the SLOT_OBJECT_TYPE to
# None to enable this functionality.

slots = [
SlotSet(SLOT_OBJECT_TYPE, object_type),
SlotSet(SLOT_OBJECT_TYPE, None),
SlotSet(SLOT_MENTION, None),
SlotSet(SLOT_ATTRIBUTE, None),
SlotSet(SLOT_LAST_OBJECT, last_object),
Expand Down Expand Up @@ -219,7 +237,6 @@ async def _query_attribute(
Returns: list of slots
"""

object_name = get_object_name(
tracker,
self.knowledge_base.ordinal_mention_mapping,
Expand Down Expand Up @@ -258,8 +275,13 @@ async def _query_attribute(
)
)

# To prevent the user to first ask to list the objects for an object type,
# the object type has to be extracted while the action is executed.
# Therefore we need to reset the SLOT_OBJECT_TYPE to
# None to enable this functionality.

slots = [
SlotSet(SLOT_OBJECT_TYPE, object_type),
SlotSet(SLOT_OBJECT_TYPE, None),
SlotSet(SLOT_ATTRIBUTE, None),
SlotSet(SLOT_MENTION, None),
SlotSet(SLOT_LAST_OBJECT, object_identifier),
Expand Down
12 changes: 10 additions & 2 deletions rasa_sdk/knowledge_base/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

class KnowledgeBase:
def __init__(self) -> None:

self.ordinal_mention_mapping = {
"1": lambda lst: lst[0],
"2": lambda lst: lst[1],
Expand Down Expand Up @@ -110,6 +109,12 @@ async def get_object(
"""
raise NotImplementedError("Method is not implemented.")

def get_object_types(self) -> List[Text]:
"""
Returns a list of object types from knowledge base data.
"""
raise NotImplementedError("Method is not implemented.")


class InMemoryKnowledgeBase(KnowledgeBase):
def __init__(self, data_file: Text) -> None:
Expand Down Expand Up @@ -206,7 +211,6 @@ async def get_object(
return None

objects = self.data[object_type]

key_attribute = await utils.call_potential_coroutine(
self.get_key_attribute_of_object(object_type)
)
Expand Down Expand Up @@ -242,3 +246,7 @@ async def get_object(
return None

return objects_of_interest[0]

def get_object_types(self) -> List[Text]:
"""See parent class docstring."""
return list(self.data.keys())
31 changes: 29 additions & 2 deletions rasa_sdk/knowledge_base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def resolve_mention(
listed_items = tracker.get_slot(SLOT_LISTED_OBJECTS)
last_object = tracker.get_slot(SLOT_LAST_OBJECT)
last_object_type = tracker.get_slot(SLOT_LAST_OBJECT_TYPE)
current_object_type = tracker.get_slot(SLOT_OBJECT_TYPE)

if not mention:
return None
Expand All @@ -95,7 +94,9 @@ def resolve_mention(
# for now we just assume that if the user refers to an object, for
# example via "it" or "that restaurant", they are actually referring to the last
# object that was detected.
if current_object_type == last_object_type:
# Since object type slot is reset to 'None' value, it is sufficient to only check
# whether the last_object_type is not None.
if last_object_type:
return last_object

return None
Expand Down Expand Up @@ -164,3 +165,29 @@ def reset_attribute_slots(
slots.append(SlotSet(attr, None))

return slots


def match_extracted_entities_to_object_type(
tracker: "Tracker",
object_types: List[Text],
) -> Optional[Text]:
"""
If the user ask a question about an attribute using an object name and
without specifying the object type, then this function searches the
corresponding object type. (e.g: when user asks'price range of B&B', this
function extracts the object type as 'hotel'). Here we assume that the user
message contains reference only to one object type in the knowledge base.
Args:
tracker: the tracker
object_types: list of object types in the knowledge base
Returns: the name of the object type if found, otherwise `None`.
"""
entities = tracker.latest_message.get("entities", [])
entity_names = [entity.get("entity") for entity in entities]
for entity in entity_names:
if entity in object_types:
return entity

return None
7 changes: 4 additions & 3 deletions scripts/release.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def is_valid_version_number(v: Text) -> bool:
str(current_version.next_release_candidate("major")),
]
version = questionary.select(
f"Which {version} do you want to release?", choices=choices,
f"Which {version} do you want to release?",
choices=choices,
).ask()

if version:
Expand Down Expand Up @@ -255,8 +256,8 @@ def next_version(args: argparse.Namespace) -> Version:
def generate_changelog(version: Version) -> None:
"""Call towncrier and create a changelog from all available changelog entries."""
check_call(
["towncrier", "build", "--yes", "--version", str(version)], cwd=str(project_root())

["towncrier", "build", "--yes", "--version", str(version)],
cwd=str(project_root()),
)


Expand Down
Loading

0 comments on commit 2e32306

Please sign in to comment.