From ce1f5d2a31e55723fde1f522a5338eb51b840385 Mon Sep 17 00:00:00 2001 From: Shailendra Paliwal Date: Fri, 19 Jan 2024 11:56:21 +0100 Subject: [PATCH] Get Full Retrieval Intent Name (#12998) * update full retrieval intent name * format * add test for test_update_full_retrieval_intent * fix linting * fix the structure of dict --- rasa/core/processor.py | 22 +++++++++++++ tests/core/test_processor.py | 64 +++++++++++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/rasa/core/processor.py b/rasa/core/processor.py index ec6ef1260a3e..fc628c7a7247 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -66,7 +66,11 @@ ENTITIES, INTENT, INTENT_NAME_KEY, + INTENT_RESPONSE_KEY, PREDICTED_CONFIDENCE_KEY, + FULL_RETRIEVAL_INTENT_NAME_KEY, + RESPONSE_SELECTOR, + RESPONSE, TEXT, ) from rasa.utils.endpoints import EndpointConfig @@ -721,6 +725,7 @@ async def parse_message( message, tracker, only_output_properties ) + self._update_full_retrieval_intent(parse_data) structlogger.debug( "processor.message.parse", parse_data_text=copy.deepcopy(parse_data["text"]), @@ -732,6 +737,23 @@ async def parse_message( return parse_data + def _update_full_retrieval_intent(self, parse_data: Dict[Text, Any]) -> None: + """Update the parse data with the full retrieval intent. + + Args: + parse_data: Message parse data to update. + """ + intent_name = parse_data.get(INTENT, {}).get(INTENT_NAME_KEY) + response_selector = parse_data.get(RESPONSE_SELECTOR, {}) + all_retrieval_intents = response_selector.get("all_retrieval_intents", []) + if intent_name and intent_name in all_retrieval_intents: + retrieval_intent = ( + response_selector.get(intent_name, {}) + .get(RESPONSE, {}) + .get(INTENT_RESPONSE_KEY) + ) + parse_data[INTENT][FULL_RETRIEVAL_INTENT_NAME_KEY] = retrieval_intent + def _parse_message_with_graph( self, message: UserMessage, diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 392d85c29745..d0581b1800ef 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -70,7 +70,12 @@ from rasa.core.http_interpreter import RasaNLUHttpInterpreter from rasa.core.processor import MessageProcessor from rasa.shared.core.trackers import DialogueStateTracker -from rasa.shared.nlu.constants import INTENT_NAME_KEY, METADATA_MODEL_ID +from rasa.shared.nlu.constants import ( + INTENT, + INTENT_NAME_KEY, + FULL_RETRIEVAL_INTENT_NAME_KEY, + METADATA_MODEL_ID, +) from rasa.shared.nlu.training_data.message import Message from rasa.utils.endpoints import EndpointConfig from rasa.shared.core.constants import ( @@ -1928,3 +1933,60 @@ async def test_run_anonymization_pipeline_mocked_pipeline( await processor.run_anonymization_pipeline(tracker) event_diff.assert_called_once() + + +async def test_update_full_retrieval_intent( + default_processor: MessageProcessor, +) -> None: + parse_data = { + "text": "I like sunny days in berlin", + "intent": {"name": "chitchat", "confidence": 0.9}, + "entities": [], + "response_selector": { + "all_retrieval_intents": ["faq", "chitchat"], + "faq": { + "response": { + "responses": [{"text": "Our return policy lasts 30 days."}], + "confidence": 1.0, + "intent_response_key": "faq/what_is_return_policy", + "utter_action": "utter_faq/what_is_return_policy", + }, + "ranking": [ + { + "confidence": 1.0, + "intent_response_key": "faq/what_is_return_policy", + }, + { + "confidence": 2.3378809862799945e-19, + "intent_response_key": "faq/how_can_i_track_my_order", + }, + ], + }, + "chitchat": { + "response": { + "responses": [ + { + "text": "The sun is out today! Isn't that great?", + }, + ], + "confidence": 1.0, + "intent_response_key": "chitchat/ask_weather", + "utter_action": "utter_chitchat/ask_weather", + }, + "ranking": [ + { + "confidence": 1.0, + "intent_response_key": "chitchat/ask_weather", + }, + {"confidence": 0.0, "intent_response_key": "chitchat/ask_name"}, + ], + }, + }, + } + + default_processor._update_full_retrieval_intent(parse_data) + + assert parse_data[INTENT][INTENT_NAME_KEY] == "chitchat" + # assert that parse_data["intent"] has a key called response + assert FULL_RETRIEVAL_INTENT_NAME_KEY in parse_data[INTENT] + assert parse_data[INTENT][FULL_RETRIEVAL_INTENT_NAME_KEY] == "chitchat/ask_weather"