From 9c697185536ce83e1550b6c158f0c6948d91ce72 Mon Sep 17 00:00:00 2001 From: danc Date: Wed, 20 Sep 2023 14:06:47 +0100 Subject: [PATCH 01/19] chore: renamed cdu folder in tests as dialogue_understanding. --- .../generator/llm_command_generator.py | 19 +- .../__init__.py | 0 .../commands/__init__.py | 0 .../commands/conftest.py | 0 .../commands/test_can_not_handle_command.py | 0 .../commands/test_cancel_flow_command.py | 0 .../commands/test_chit_chat_answer_command.py | 0 .../commands/test_clarify_command.py | 0 .../commands/test_command.py | 0 .../commands/test_command_processor.py | 0 .../commands/test_correct_slots_command.py | 0 .../commands/test_error_command.py | 0 .../test_handle_code_change_command.py | 0 .../commands/test_human_handoff_command.py | 0 .../commands/test_konwledge_answer_command.py | 0 .../commands/test_set_slot_command.py | 0 .../commands/test_start_flow_command.py | 0 .../generator/__init__.py | 0 .../generator/test_command_generator.py | 0 .../generator/test_llm_command_generator.py | 242 ++++++++++++++++++ .../stack/__init__.py | 0 .../stack/frames/__init__.py | 0 .../stack/frames/test_chit_chat_frame.py | 0 .../stack/frames/test_dialogue_stack_frame.py | 0 .../stack/frames/test_flow_frame.py | 0 .../stack/frames/test_search_frame.py | 0 .../stack/test_dialogue_stack.py | 0 .../stack/test_utils.py | 0 28 files changed, 257 insertions(+), 4 deletions(-) rename tests/{cdu => dialogue_understanding}/__init__.py (100%) rename tests/{cdu => dialogue_understanding}/commands/__init__.py (100%) rename tests/{cdu => dialogue_understanding}/commands/conftest.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_can_not_handle_command.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_cancel_flow_command.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_chit_chat_answer_command.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_clarify_command.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_command.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_command_processor.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_correct_slots_command.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_error_command.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_handle_code_change_command.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_human_handoff_command.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_konwledge_answer_command.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_set_slot_command.py (100%) rename tests/{cdu => dialogue_understanding}/commands/test_start_flow_command.py (100%) rename tests/{cdu => dialogue_understanding}/generator/__init__.py (100%) rename tests/{cdu => dialogue_understanding}/generator/test_command_generator.py (100%) create mode 100644 tests/dialogue_understanding/generator/test_llm_command_generator.py rename tests/{cdu => dialogue_understanding}/stack/__init__.py (100%) rename tests/{cdu => dialogue_understanding}/stack/frames/__init__.py (100%) rename tests/{cdu => dialogue_understanding}/stack/frames/test_chit_chat_frame.py (100%) rename tests/{cdu => dialogue_understanding}/stack/frames/test_dialogue_stack_frame.py (100%) rename tests/{cdu => dialogue_understanding}/stack/frames/test_flow_frame.py (100%) rename tests/{cdu => dialogue_understanding}/stack/frames/test_search_frame.py (100%) rename tests/{cdu => dialogue_understanding}/stack/test_dialogue_stack.py (100%) rename tests/{cdu => dialogue_understanding}/stack/test_utils.py (100%) diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py index d5ee1ac8bcc3..45e8ccce9a5b 100644 --- a/rasa/dialogue_understanding/generator/llm_command_generator.py +++ b/rasa/dialogue_understanding/generator/llm_command_generator.py @@ -4,6 +4,7 @@ from jinja2 import Template import structlog + from rasa.dialogue_understanding.stack.utils import top_flow_frame from rasa.dialogue_understanding.generator import CommandGenerator from rasa.dialogue_understanding.commands import ( @@ -47,9 +48,6 @@ "rasa.dialogue_understanding.generator", "command_prompt_template.jinja2" ) -structlogger = structlog.get_logger() - - DEFAULT_LLM_CONFIG = { "_type": "openai", "request_timeout": 7, @@ -59,6 +57,8 @@ LLM_CONFIG_KEY = "llm" +structlogger = structlog.get_logger() + @DefaultV1Recipe.register( [ @@ -67,6 +67,11 @@ is_trainable=True, ) class LLMCommandGenerator(GraphComponent, CommandGenerator): + """An LLM based command generator. + + # TODO: add description to the docstring. + + """ @staticmethod def get_default_config() -> Dict[str, Any]: """The component's default config (see parent class for full docstring).""" @@ -142,6 +147,7 @@ def predict_commands( flows: FlowsList, tracker: Optional[DialogueStateTracker] = None, ) -> List[Command]: + """TODO: add docstring""" if tracker is None or flows.is_empty(): # cannot do anything if there are no flows or no tracker return [] @@ -164,6 +170,7 @@ def predict_commands( @staticmethod def is_none_value(value: str) -> bool: + """TODO: add docstring""" return value in { "[missing information]", "[missing]", @@ -217,7 +224,9 @@ def coerce_slot_value( def parse_commands( cls, actions: Optional[str], tracker: DialogueStateTracker ) -> List[Command]: - """Parse the actions returned by the llm into intent and entities.""" + """Parse the actions returned by the llm into intent and entities. + #TODO: add arguments and returns. + """ if not actions: return [ErrorCommand()] @@ -267,6 +276,7 @@ def parse_commands( def create_template_inputs( cls, flows: FlowsList, tracker: DialogueStateTracker ) -> List[Dict[str, Any]]: + """TODO: add docstring.""" result = [] for flow in flows.underlying_flows: # TODO: check if we should filter more flows; e.g. flows that are @@ -337,6 +347,7 @@ def slot_value(tracker: DialogueStateTracker, slot_name: str) -> str: def render_template( self, message: Message, tracker: DialogueStateTracker, flows: FlowsList ) -> str: + """TODO: add docstring""" flows_without_patterns = FlowsList( [f for f in flows.underlying_flows if not f.is_handling_pattern()] ) diff --git a/tests/cdu/__init__.py b/tests/dialogue_understanding/__init__.py similarity index 100% rename from tests/cdu/__init__.py rename to tests/dialogue_understanding/__init__.py diff --git a/tests/cdu/commands/__init__.py b/tests/dialogue_understanding/commands/__init__.py similarity index 100% rename from tests/cdu/commands/__init__.py rename to tests/dialogue_understanding/commands/__init__.py diff --git a/tests/cdu/commands/conftest.py b/tests/dialogue_understanding/commands/conftest.py similarity index 100% rename from tests/cdu/commands/conftest.py rename to tests/dialogue_understanding/commands/conftest.py diff --git a/tests/cdu/commands/test_can_not_handle_command.py b/tests/dialogue_understanding/commands/test_can_not_handle_command.py similarity index 100% rename from tests/cdu/commands/test_can_not_handle_command.py rename to tests/dialogue_understanding/commands/test_can_not_handle_command.py diff --git a/tests/cdu/commands/test_cancel_flow_command.py b/tests/dialogue_understanding/commands/test_cancel_flow_command.py similarity index 100% rename from tests/cdu/commands/test_cancel_flow_command.py rename to tests/dialogue_understanding/commands/test_cancel_flow_command.py diff --git a/tests/cdu/commands/test_chit_chat_answer_command.py b/tests/dialogue_understanding/commands/test_chit_chat_answer_command.py similarity index 100% rename from tests/cdu/commands/test_chit_chat_answer_command.py rename to tests/dialogue_understanding/commands/test_chit_chat_answer_command.py diff --git a/tests/cdu/commands/test_clarify_command.py b/tests/dialogue_understanding/commands/test_clarify_command.py similarity index 100% rename from tests/cdu/commands/test_clarify_command.py rename to tests/dialogue_understanding/commands/test_clarify_command.py diff --git a/tests/cdu/commands/test_command.py b/tests/dialogue_understanding/commands/test_command.py similarity index 100% rename from tests/cdu/commands/test_command.py rename to tests/dialogue_understanding/commands/test_command.py diff --git a/tests/cdu/commands/test_command_processor.py b/tests/dialogue_understanding/commands/test_command_processor.py similarity index 100% rename from tests/cdu/commands/test_command_processor.py rename to tests/dialogue_understanding/commands/test_command_processor.py diff --git a/tests/cdu/commands/test_correct_slots_command.py b/tests/dialogue_understanding/commands/test_correct_slots_command.py similarity index 100% rename from tests/cdu/commands/test_correct_slots_command.py rename to tests/dialogue_understanding/commands/test_correct_slots_command.py diff --git a/tests/cdu/commands/test_error_command.py b/tests/dialogue_understanding/commands/test_error_command.py similarity index 100% rename from tests/cdu/commands/test_error_command.py rename to tests/dialogue_understanding/commands/test_error_command.py diff --git a/tests/cdu/commands/test_handle_code_change_command.py b/tests/dialogue_understanding/commands/test_handle_code_change_command.py similarity index 100% rename from tests/cdu/commands/test_handle_code_change_command.py rename to tests/dialogue_understanding/commands/test_handle_code_change_command.py diff --git a/tests/cdu/commands/test_human_handoff_command.py b/tests/dialogue_understanding/commands/test_human_handoff_command.py similarity index 100% rename from tests/cdu/commands/test_human_handoff_command.py rename to tests/dialogue_understanding/commands/test_human_handoff_command.py diff --git a/tests/cdu/commands/test_konwledge_answer_command.py b/tests/dialogue_understanding/commands/test_konwledge_answer_command.py similarity index 100% rename from tests/cdu/commands/test_konwledge_answer_command.py rename to tests/dialogue_understanding/commands/test_konwledge_answer_command.py diff --git a/tests/cdu/commands/test_set_slot_command.py b/tests/dialogue_understanding/commands/test_set_slot_command.py similarity index 100% rename from tests/cdu/commands/test_set_slot_command.py rename to tests/dialogue_understanding/commands/test_set_slot_command.py diff --git a/tests/cdu/commands/test_start_flow_command.py b/tests/dialogue_understanding/commands/test_start_flow_command.py similarity index 100% rename from tests/cdu/commands/test_start_flow_command.py rename to tests/dialogue_understanding/commands/test_start_flow_command.py diff --git a/tests/cdu/generator/__init__.py b/tests/dialogue_understanding/generator/__init__.py similarity index 100% rename from tests/cdu/generator/__init__.py rename to tests/dialogue_understanding/generator/__init__.py diff --git a/tests/cdu/generator/test_command_generator.py b/tests/dialogue_understanding/generator/test_command_generator.py similarity index 100% rename from tests/cdu/generator/test_command_generator.py rename to tests/dialogue_understanding/generator/test_command_generator.py diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py new file mode 100644 index 000000000000..ee807f82c6e9 --- /dev/null +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -0,0 +1,242 @@ +from unittest.mock import Mock, patch + +import pytest +from langchain.llms.fake import FakeListLLM +from structlog.testing import capture_logs + +from rasa.dialogue_understanding.generator.llm_command_generator import LLMCommandGenerator +from rasa.dialogue_understanding.commands import ( + # Command, + ErrorCommand, + SetSlotCommand, + CancelFlowCommand, + StartFlowCommand, + HumanHandoffCommand, + ChitChatAnswerCommand, + KnowledgeAnswerCommand, + ClarifyCommand, +) +from rasa.engine.graph import ExecutionContext +from rasa.engine.storage.resource import Resource +from rasa.engine.storage.storage import ModelStorage +from rasa.shared.core.slots import BooleanSlot, FloatSlot, TextSlot +from rasa.shared.core.trackers import DialogueStateTracker + + +class TestLLMCommandGenerator: + """Tests for the LLMCommandGenerator.""" + + @pytest.fixture + def command_generator(self): + """Create an LLMCommandGenerator.""" + return LLMCommandGenerator.create( + config={}, resource=Mock(), model_storage=Mock(), execution_context=Mock()) + + @pytest.fixture + def mock_command_generator( + self, + default_model_storage: ModelStorage, + default_execution_context: ExecutionContext, + ) -> LLMCommandGenerator: + """Create a patched LLMCommandGenerator.""" + with patch( + "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", + Mock(return_value=FakeListLLM(responses=["StartFlow(check_balance)"])), + ) as mock_llm: + return LLMCommandGenerator.create( + config=LLMCommandGenerator.get_default_config(), + model_storage=default_model_storage, + resource=Resource("llmcommandgenerator"), + execution_context=default_execution_context) + + def test_predict_commands_with_no_flows(self, mock_command_generator: LLMCommandGenerator): + """Test that predict_commands returns an empty list when flows is None.""" + # When + predicted_commands = mock_command_generator.predict_commands(Mock(), flows=None, tracker=Mock()) + # Then + assert not predicted_commands + + def test_predict_commands_with_no_tracker(self, mock_command_generator: LLMCommandGenerator): + """Test that predict_commands returns an empty list when tracker is None.""" + # When + predicted_commands = mock_command_generator.predict_commands(Mock(), flows=Mock(), tracker=None) + # Then + assert not predicted_commands + + @patch.object(LLMCommandGenerator, "render_template", Mock(return_value="some prompt")) + @patch.object(LLMCommandGenerator, "parse_commands", Mock()) + def test_predict_commands_calls_llm_correctly(self, command_generator: LLMCommandGenerator): + """Test that predict_commands calls llm correctly.""" + # When + mock_llm = Mock() + with patch( + "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", + Mock(return_value=mock_llm), + ): + command_generator.predict_commands(Mock(), flows=Mock(), tracker=Mock()) + # Then + mock_llm.assert_called_once_with("some prompt") + + @patch.object(LLMCommandGenerator, "render_template", Mock(return_value="some prompt")) + @patch.object(LLMCommandGenerator, "parse_commands", Mock()) + def test_generate_action_list_catches_llm_exception(self, command_generator: LLMCommandGenerator): + """Test that predict_commands calls llm correctly.""" + # Given + mock_llm = Mock(side_effect=Exception("some exception")) + with patch( + "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", + Mock(return_value=mock_llm), + ): + # When + with capture_logs() as logs: + command_generator.predict_commands(Mock(), flows=Mock(), tracker=Mock()) + # Then + print(logs) + assert len(logs) == 4 + assert isinstance(logs[1]["error"]) == isinstance(Exception("some exception")) + + + + def test_render_template(self, mock_command_generator: LLMCommandGenerator): + """Test that render_template renders a template.""" + pass + # # Given + # message = Mock() + + # tracker = Mock() + + # flows = Mock() + # # When + # rendered_template = command_generator.render_template() + + # # Then + # assert rendered_template == "template" + + # def test_generate_action_list_calls_llm_with_correct_promt(self): + # # Given + # prompt = "some prompt" + # with patch( + # "rasa.rasa.shared.utils.llm.llm_factory", + # Mock(return_value=FakeListLLM(responses=["hello"])) + # ) as mock_llm: + # LLMCommandGenerator._generate_action_list(prompt) + # mock_llm.assert_called_once_with(prompt) + + @pytest.mark.parametrize( + "input_action, expected_command", + [ + ( + None, + [ErrorCommand()] + ), + ( + "SetSlot(transfer_money_amount_of_money, )", + [SetSlotCommand(name="transfer_money_amount_of_money", value=None)] + ), + ( + "SetSlot(flow_name, some_flow)", + [StartFlowCommand(flow="some_flow")] + ), + ( + "StartFlow(check_balance)", + [StartFlowCommand(flow="check_balance")] + ), + ( + "CancelFlow()", + [CancelFlowCommand()] + ), + ( + "ChitChat()", + [ChitChatAnswerCommand()] + ), + ( + "SearchAndReply()", + [KnowledgeAnswerCommand()] + ), + ( + "HumanHandoff()", + [HumanHandoffCommand()] + ), + ( + "Clarify(transfer_money)", + [ClarifyCommand(options=["transfer_money"])] + ), + ( + "Clarify(list_contacts, add_contact, remove_contact)", + [ClarifyCommand(options=["list_contacts", "add_contact", "remove_contact"])] + ), + ]) + def test_parse_commands_identifies_correct_command(self, input_action, expected_command): + """Test that parse_commands identifies the correct commands.""" + # When + with patch.object(LLMCommandGenerator, "coerce_slot_value", Mock(return_value=None)): + parsed_commands = LLMCommandGenerator.parse_commands(input_action, Mock()) + # Then + assert parsed_commands == expected_command + + @pytest.mark.parametrize( + "slot_name, slot, slot_value, expected_coerced_value", + [ + ("some_other_slot", FloatSlot("some_float", []), None, None), + ("some_float", FloatSlot("some_float", []), 40, 40.0), + ("some_float", FloatSlot("some_float", []), 40.0, 40.0), + ("some_text", TextSlot("some_text", []),"fourty", "fourty"), + ("some_bool", BooleanSlot("some_bool", []), "True", True), + ("some_bool", BooleanSlot("some_bool", []), "false", False) + ]) + def test_coerce_slot_value(self, slot_name, slot, slot_value, expected_coerced_value): + """Test that coerce_slot_value coerces the slot value correctly.""" + # Given + tracker = DialogueStateTracker.from_events( + "test", + evts=[], + slots=[slot] + ) + # When + coerced_value = LLMCommandGenerator.coerce_slot_value(slot_value, slot_name, tracker) + # Then + assert coerced_value == expected_coerced_value + + @pytest.mark.parametrize( + "input_string, expected_string", + [ + ("text", "text"), + (" text ", "text"), + ("\"text\"", "text"), + ("'text'", "text"), + ("' \"text' \" ", "text"), + ("", "") + ]) + def test_clean_extracted_value(self, input_string, expected_string): + """Test that clean_extracted_value removes the leading and trailing whitespaces.""" + # When + cleaned_extracted_value = LLMCommandGenerator.clean_extracted_value(input_string) + # Then + assert cleaned_extracted_value == expected_string + + + + + + + + + + + + # def test_allowd_values_for_slot(self, command_generator): + # """Test that allowed_values_for_slot returns the allowed values for a slot.""" + # # When + # allowed_values = command_generator.allowed_values_for_slot("slot_name") + + # # Then + # assert allowed_values == [] + + # @pytest.mark.parametrize("input_value, expected_truthiness", + # [(None, True), + # ("", False), + + # )] + # def test_is_none_value(self): + # """Test that is_none_value returns True when the value is None.""" + # assert LLMCommandGenerator.is_none_value(None) diff --git a/tests/cdu/stack/__init__.py b/tests/dialogue_understanding/stack/__init__.py similarity index 100% rename from tests/cdu/stack/__init__.py rename to tests/dialogue_understanding/stack/__init__.py diff --git a/tests/cdu/stack/frames/__init__.py b/tests/dialogue_understanding/stack/frames/__init__.py similarity index 100% rename from tests/cdu/stack/frames/__init__.py rename to tests/dialogue_understanding/stack/frames/__init__.py diff --git a/tests/cdu/stack/frames/test_chit_chat_frame.py b/tests/dialogue_understanding/stack/frames/test_chit_chat_frame.py similarity index 100% rename from tests/cdu/stack/frames/test_chit_chat_frame.py rename to tests/dialogue_understanding/stack/frames/test_chit_chat_frame.py diff --git a/tests/cdu/stack/frames/test_dialogue_stack_frame.py b/tests/dialogue_understanding/stack/frames/test_dialogue_stack_frame.py similarity index 100% rename from tests/cdu/stack/frames/test_dialogue_stack_frame.py rename to tests/dialogue_understanding/stack/frames/test_dialogue_stack_frame.py diff --git a/tests/cdu/stack/frames/test_flow_frame.py b/tests/dialogue_understanding/stack/frames/test_flow_frame.py similarity index 100% rename from tests/cdu/stack/frames/test_flow_frame.py rename to tests/dialogue_understanding/stack/frames/test_flow_frame.py diff --git a/tests/cdu/stack/frames/test_search_frame.py b/tests/dialogue_understanding/stack/frames/test_search_frame.py similarity index 100% rename from tests/cdu/stack/frames/test_search_frame.py rename to tests/dialogue_understanding/stack/frames/test_search_frame.py diff --git a/tests/cdu/stack/test_dialogue_stack.py b/tests/dialogue_understanding/stack/test_dialogue_stack.py similarity index 100% rename from tests/cdu/stack/test_dialogue_stack.py rename to tests/dialogue_understanding/stack/test_dialogue_stack.py diff --git a/tests/cdu/stack/test_utils.py b/tests/dialogue_understanding/stack/test_utils.py similarity index 100% rename from tests/cdu/stack/test_utils.py rename to tests/dialogue_understanding/stack/test_utils.py From 25ce5fe45470436d4364500ad02d321be0c4137b Mon Sep 17 00:00:00 2001 From: danc Date: Thu, 21 Sep 2023 00:16:15 +0100 Subject: [PATCH 02/19] added render_template test. --- .../generator/llm_command_generator.py | 2 +- .../generator/rendered_prompt.txt | 1 + .../generator/test_llm_command_generator.py | 165 ++++++++++++------ 3 files changed, 115 insertions(+), 53 deletions(-) create mode 100644 tests/dialogue_understanding/generator/rendered_prompt.txt diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py index 45e8ccce9a5b..07ed70c33dfe 100644 --- a/rasa/dialogue_understanding/generator/llm_command_generator.py +++ b/rasa/dialogue_understanding/generator/llm_command_generator.py @@ -239,7 +239,7 @@ def parse_commands( cancel_flow_re = re.compile(r"CancelFlow\(\)") chitchat_re = re.compile(r"ChitChat\(\)") knowledge_re = re.compile(r"SearchAndReply\(\)") - humand_handoff_re = re.compile(r"HumandHandoff\(\)") + humand_handoff_re = re.compile(r"HumanHandoff\(\)") clarify_re = re.compile(r"Clarify\(([a-zA-Z0-9_, ]+)\)") for action in actions.strip().splitlines(): diff --git a/tests/dialogue_understanding/generator/rendered_prompt.txt b/tests/dialogue_understanding/generator/rendered_prompt.txt new file mode 100644 index 000000000000..ea92e3f93387 --- /dev/null +++ b/tests/dialogue_understanding/generator/rendered_prompt.txt @@ -0,0 +1 @@ +Your task is to analyze the current conversation context and generate a list of actions to start new business processes that we call flows, to extract slots, or respond to small talk and knowledge requests.\n\nThese are the flows that can be started, with their description and slots:\n\ntest_flow: some description\n slot: test_slot\n \n\n===\nHere is what happened previously in the conversation:\nUSER: Hello\nAI: Hi\nUSER: some message\n\n===\n\nYou are currently not in any flow and so there are no active slots.\nThis means you can only set a slot if you first start a flow that requires that slot.\n\nIf you start a flow, first start the flow and then optionally fill that flow\'s slots with information the user provided in their message.\n\nThe user just said """some message""".\n\n===\nBased on this information generate a list of actions you want to take. Your job is to start flows and to fill slots where appropriate. Any logic of what happens afterwards is handled by the flow engine. These are your available actions:\n* Slot setting, described by "SetSlot(slot_name, slot_value)". An example would be "SetSlot(recipient, Freddy)"\n* Starting another flow, described by "StartFlow(flow_name)". An example would be "StartFlow(transfer_money)"\n* Cancelling the current flow, described by "CancelFlow()"\n* Clarifying which flow should be started. An example would be Clarify(list_contacts, add_contact, remove_contact) if the user just wrote "contacts" and there are multiple potential candidates. It also works with a single flow name to confirm you understood correctly, as in Clarify(transfer_money).\n* Responding to knowledge-oriented user messages, described by "SearchAndReply()"\n* Responding to a casual, non-task-oriented user message, described by "ChitChat()".\n* Handing off to a human, in case the user seems frustrated or explicitly asks to speak to one, described by "HumanHandoff()".\n\n===\nWrite out the actions you want to take, one per line, in the order they should take place.\nDo not fill slots with abstract values or placeholders.\nOnly use information provided by the user.\nOnly start a flow if it\'s completely clear what the user wants. Imagine you were a person reading this message. If it\'s not 100% clear, clarify the next step.\nDon\'t be overly confident. Take a conservative approach and clarify before proceeding.\nIf the user asks for two things which seem contradictory, clarify before starting a flow.\nStrictly adhere to the provided action types listed above.\nFocus on the last message and take it one step at a time.\nUse the previous conversation steps only to aid understanding.\n\nYour action list: \ No newline at end of file diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py index ee807f82c6e9..02d13ecd049e 100644 --- a/tests/dialogue_understanding/generator/test_llm_command_generator.py +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -16,13 +16,19 @@ KnowledgeAnswerCommand, ClarifyCommand, ) -from rasa.engine.graph import ExecutionContext -from rasa.engine.storage.resource import Resource -from rasa.engine.storage.storage import ModelStorage +# from rasa.engine.graph import ExecutionContext +# from rasa.engine.storage.resource import Resource +# from rasa.engine.storage.storage import ModelStorage +from rasa.shared.core.events import BotUttered, UserUttered +from rasa.shared.core.flows.flow import FlowsList from rasa.shared.core.slots import BooleanSlot, FloatSlot, TextSlot from rasa.shared.core.trackers import DialogueStateTracker +from rasa.shared.nlu.training_data.message import Message +from tests.utilities import flows_from_str +TEST_PROMPT_PATH = "./tests/dialogue_understanding/generator/rendered_prompt.txt" + class TestLLMCommandGenerator: """Tests for the LLMCommandGenerator.""" @@ -32,64 +38,109 @@ def command_generator(self): return LLMCommandGenerator.create( config={}, resource=Mock(), model_storage=Mock(), execution_context=Mock()) + # @pytest.fixture + # def mock_command_generator( + # self, + # default_model_storage: ModelStorage, + # default_execution_context: ExecutionContext, + # ) -> LLMCommandGenerator: + # """Create a patched LLMCommandGenerator.""" + # with patch( + # "rasa.shared.utils.llm.llm_factory", + # Mock(return_value=FakeListLLM(responses=["StartFlow(check_balance)"])), + # ) as mock_llm: + # return LLMCommandGenerator.create( + # config=LLMCommandGenerator.get_default_config(), + # model_storage=default_model_storage, + # resource=Resource("llmcommandgenerator"), + # execution_context=default_execution_context) + @pytest.fixture - def mock_command_generator( - self, - default_model_storage: ModelStorage, - default_execution_context: ExecutionContext, - ) -> LLMCommandGenerator: - """Create a patched LLMCommandGenerator.""" - with patch( - "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", - Mock(return_value=FakeListLLM(responses=["StartFlow(check_balance)"])), - ) as mock_llm: - return LLMCommandGenerator.create( - config=LLMCommandGenerator.get_default_config(), - model_storage=default_model_storage, - resource=Resource("llmcommandgenerator"), - execution_context=default_execution_context) - - def test_predict_commands_with_no_flows(self, mock_command_generator: LLMCommandGenerator): + def test_flows(self) -> FlowsList: + """Create a FlowsList.""" + return flows_from_str( + """ + flows: + test_flow: + steps: + - id: first_step + action: action_listen + """ + ) + + + def test_predict_commands_with_no_flows( + self, + mock_command_generator: LLMCommandGenerator + ): """Test that predict_commands returns an empty list when flows is None.""" + # Given + empty_flows = FlowsList([]) # When - predicted_commands = mock_command_generator.predict_commands(Mock(), flows=None, tracker=Mock()) + predicted_commands = mock_command_generator.predict_commands( + Mock(), + flows=empty_flows, + tracker=Mock() + ) # Then assert not predicted_commands - def test_predict_commands_with_no_tracker(self, mock_command_generator: LLMCommandGenerator): + def test_predict_commands_with_no_tracker( + self, + mock_command_generator: LLMCommandGenerator + ): """Test that predict_commands returns an empty list when tracker is None.""" # When - predicted_commands = mock_command_generator.predict_commands(Mock(), flows=Mock(), tracker=None) + predicted_commands = mock_command_generator.predict_commands( + Mock(), + flows=Mock(), + tracker=None + ) # Then assert not predicted_commands - @patch.object(LLMCommandGenerator, "render_template", Mock(return_value="some prompt")) + @patch.object( + LLMCommandGenerator, + "render_template", + Mock(return_value="some prompt") + ) @patch.object(LLMCommandGenerator, "parse_commands", Mock()) - def test_predict_commands_calls_llm_correctly(self, command_generator: LLMCommandGenerator): + def test_predict_commands_calls_llm_correctly( + self, + command_generator: LLMCommandGenerator, + test_flows: FlowsList + ): """Test that predict_commands calls llm correctly.""" # When mock_llm = Mock() with patch( - "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", - Mock(return_value=mock_llm), + "rasa.shared.utils.llm.llm_factory", + Mock(return_value=mock_llm) ): - command_generator.predict_commands(Mock(), flows=Mock(), tracker=Mock()) + command_generator.predict_commands(Mock(), flows=test_flows, tracker=Mock()) # Then mock_llm.assert_called_once_with("some prompt") - @patch.object(LLMCommandGenerator, "render_template", Mock(return_value="some prompt")) + + @patch.object( + LLMCommandGenerator, + "render_template", + Mock(return_value="some prompt") + ) @patch.object(LLMCommandGenerator, "parse_commands", Mock()) - def test_generate_action_list_catches_llm_exception(self, command_generator: LLMCommandGenerator): - """Test that predict_commands calls llm correctly.""" + def test_generate_action_list_catches_llm_exception(self, + command_generator: LLMCommandGenerator, + test_flows: FlowsList): + """Test that predict_commands catches llm exceptions.""" # Given mock_llm = Mock(side_effect=Exception("some exception")) with patch( - "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", + "rasa.shared.utils.llm.llm_factory", Mock(return_value=mock_llm), ): # When with capture_logs() as logs: - command_generator.predict_commands(Mock(), flows=Mock(), tracker=Mock()) + command_generator.predict_commands(Mock(), flows=test_flows, tracker=Mock()) # Then print(logs) assert len(logs) == 4 @@ -97,30 +148,40 @@ def test_generate_action_list_catches_llm_exception(self, command_generator: LLM - def test_render_template(self, mock_command_generator: LLMCommandGenerator): + def test_render_template(self, command_generator: LLMCommandGenerator): """Test that render_template renders a template.""" - pass - # # Given - # message = Mock() - - # tracker = Mock() - - # flows = Mock() + # Given + test_message = Message.build(text="some message") + test_slot = TextSlot( + name="test_slot", mappings=[{}], initial_value=None, influence_conversation=False + ) + test_tracker = DialogueStateTracker.from_events( + sender_id="test", + evts=[UserUttered("Hello"), BotUttered("Hi")], + slots=[test_slot] + ) + test_flows = flows_from_str( + """ + flows: + test_flow: + description: some description + steps: + - id: first_step + collect_information: test_slot + """ + ) + with open(TEST_PROMPT_PATH, "r", encoding='unicode_escape') as f: + expected_template = f.read() # # When - # rendered_template = command_generator.render_template() + rendered_template = command_generator.render_template( + message=test_message, + tracker=test_tracker, + flows=test_flows + ) # # Then - # assert rendered_template == "template" + assert rendered_template == expected_template - # def test_generate_action_list_calls_llm_with_correct_promt(self): - # # Given - # prompt = "some prompt" - # with patch( - # "rasa.rasa.shared.utils.llm.llm_factory", - # Mock(return_value=FakeListLLM(responses=["hello"])) - # ) as mock_llm: - # LLMCommandGenerator._generate_action_list(prompt) - # mock_llm.assert_called_once_with(prompt) @pytest.mark.parametrize( "input_action, expected_command", From 4c82cbe0f8ca31a780bdd3892d39abdefe28e24a Mon Sep 17 00:00:00 2001 From: danc Date: Sun, 24 Sep 2023 16:00:06 +0100 Subject: [PATCH 03/19] added is extractable tests and tidy up. --- .../generator/llm_command_generator.py | 8 +- .../generator/test_llm_command_generator.py | 369 ++++++++++++------ 2 files changed, 251 insertions(+), 126 deletions(-) diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py index 07ed70c33dfe..8938e0859aa3 100644 --- a/rasa/dialogue_understanding/generator/llm_command_generator.py +++ b/rasa/dialogue_understanding/generator/llm_command_generator.py @@ -300,7 +300,7 @@ def create_template_inputs( @staticmethod def is_extractable( - q: CollectInformationFlowStep, + info_step: CollectInformationFlowStep, tracker: DialogueStateTracker, current_step: Optional[FlowStep] = None, ) -> bool: @@ -309,20 +309,20 @@ def is_extractable( A collect slot can only be filled if the slot exist and either the collect has been asked already or the slot has been filled already.""" - slot = tracker.slots.get(q.collect) + slot = tracker.slots.get(info_step.collect_information) if slot is None: return False return ( # we can fill because this is a slot that can be filled ahead of time - not q.ask_before_filling + not info_step.ask_before_filling # we can fill because the slot has been filled already or slot.has_been_set # we can fill because the is currently getting asked or ( current_step is not None and isinstance(current_step, CollectInformationFlowStep) - and current_step.collect == q.collect + and current_step.collect_information == info_step.collect_information ) ) diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py index 02d13ecd049e..c69f6c891977 100644 --- a/tests/dialogue_understanding/generator/test_llm_command_generator.py +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -1,12 +1,14 @@ +from typing import Optional from unittest.mock import Mock, patch import pytest -from langchain.llms.fake import FakeListLLM from structlog.testing import capture_logs -from rasa.dialogue_understanding.generator.llm_command_generator import LLMCommandGenerator +from rasa.dialogue_understanding.generator.llm_command_generator import ( + LLMCommandGenerator +) from rasa.dialogue_understanding.commands import ( - # Command, + Command, ErrorCommand, SetSlotCommand, CancelFlowCommand, @@ -16,18 +18,21 @@ KnowledgeAnswerCommand, ClarifyCommand, ) -# from rasa.engine.graph import ExecutionContext -# from rasa.engine.storage.resource import Resource -# from rasa.engine.storage.storage import ModelStorage -from rasa.shared.core.events import BotUttered, UserUttered -from rasa.shared.core.flows.flow import FlowsList -from rasa.shared.core.slots import BooleanSlot, FloatSlot, TextSlot +from rasa.shared.core.events import BotUttered, SlotSet, UserUttered +from rasa.shared.core.flows.flow import CollectInformationFlowStep, FlowsList +from rasa.shared.core.slots import ( + Slot, + BooleanSlot, + CategoricalSlot, + FloatSlot, + TextSlot +) from rasa.shared.core.trackers import DialogueStateTracker from rasa.shared.nlu.training_data.message import Message from tests.utilities import flows_from_str -TEST_PROMPT_PATH = "./tests/dialogue_understanding/generator/rendered_prompt.txt" +EXPECTED_PROMPT_PATH = "./tests/dialogue_understanding/generator/rendered_prompt.txt" class TestLLMCommandGenerator: """Tests for the LLMCommandGenerator.""" @@ -37,26 +42,10 @@ def command_generator(self): """Create an LLMCommandGenerator.""" return LLMCommandGenerator.create( config={}, resource=Mock(), model_storage=Mock(), execution_context=Mock()) - - # @pytest.fixture - # def mock_command_generator( - # self, - # default_model_storage: ModelStorage, - # default_execution_context: ExecutionContext, - # ) -> LLMCommandGenerator: - # """Create a patched LLMCommandGenerator.""" - # with patch( - # "rasa.shared.utils.llm.llm_factory", - # Mock(return_value=FakeListLLM(responses=["StartFlow(check_balance)"])), - # ) as mock_llm: - # return LLMCommandGenerator.create( - # config=LLMCommandGenerator.get_default_config(), - # model_storage=default_model_storage, - # resource=Resource("llmcommandgenerator"), - # execution_context=default_execution_context) - + + @pytest.fixture - def test_flows(self) -> FlowsList: + def flows(self) -> FlowsList: """Create a FlowsList.""" return flows_from_str( """ @@ -68,16 +57,15 @@ def test_flows(self) -> FlowsList: """ ) - def test_predict_commands_with_no_flows( - self, - mock_command_generator: LLMCommandGenerator - ): + self, + command_generator: LLMCommandGenerator + ): """Test that predict_commands returns an empty list when flows is None.""" # Given empty_flows = FlowsList([]) # When - predicted_commands = mock_command_generator.predict_commands( + predicted_commands = command_generator.predict_commands( Mock(), flows=empty_flows, tracker=Mock() @@ -86,12 +74,12 @@ def test_predict_commands_with_no_flows( assert not predicted_commands def test_predict_commands_with_no_tracker( - self, - mock_command_generator: LLMCommandGenerator - ): + self, + command_generator: LLMCommandGenerator + ): """Test that predict_commands returns an empty list when tracker is None.""" # When - predicted_commands = mock_command_generator.predict_commands( + predicted_commands = command_generator.predict_commands( Mock(), flows=Mock(), tracker=None @@ -99,57 +87,43 @@ def test_predict_commands_with_no_tracker( # Then assert not predicted_commands - @patch.object( - LLMCommandGenerator, - "render_template", - Mock(return_value="some prompt") - ) - @patch.object(LLMCommandGenerator, "parse_commands", Mock()) - def test_predict_commands_calls_llm_correctly( + # def test_generate_action_list_calls_llm_correctly( + # self, + # command_generator: LLMCommandGenerator, + # ): + # """Test that _generate_action_list calls llm correctly.""" + # # When + # with patch( + # "rasa.shared.utils.llm.llm_factory", + # Mock() + # ) as mock_llm_factory: + # command_generator._generate_action_list_using_llm("some prompt") + # # Then + # mock_llm_factory.assert_called() + + # def test_generate_action_list_catches_llm_exception( + # self, + # command_generator: LLMCommandGenerator, + # ): + # """Test that _generate_action_list calls llm correctly.""" + # # When + # mock_llm = Mock(side_effect=Exception("some exception")) + # with patch( + # "rasa.shared.utils.llm.llm_factory", + # Mock(return_value=mock_llm) + # ): + # with capture_logs() as logs: + # command_generator._generate_action_list_using_llm("some prompt") + # # Then + # print(logs) + # assert len(logs) == 4 + # # assert logs[1]["error"] == "some expection" + + def test_render_template( self, command_generator: LLMCommandGenerator, - test_flows: FlowsList ): - """Test that predict_commands calls llm correctly.""" - # When - mock_llm = Mock() - with patch( - "rasa.shared.utils.llm.llm_factory", - Mock(return_value=mock_llm) - ): - command_generator.predict_commands(Mock(), flows=test_flows, tracker=Mock()) - # Then - mock_llm.assert_called_once_with("some prompt") - - - @patch.object( - LLMCommandGenerator, - "render_template", - Mock(return_value="some prompt") - ) - @patch.object(LLMCommandGenerator, "parse_commands", Mock()) - def test_generate_action_list_catches_llm_exception(self, - command_generator: LLMCommandGenerator, - test_flows: FlowsList): - """Test that predict_commands catches llm exceptions.""" - # Given - mock_llm = Mock(side_effect=Exception("some exception")) - with patch( - "rasa.shared.utils.llm.llm_factory", - Mock(return_value=mock_llm), - ): - # When - with capture_logs() as logs: - command_generator.predict_commands(Mock(), flows=test_flows, tracker=Mock()) - # Then - print(logs) - assert len(logs) == 4 - assert isinstance(logs[1]["error"]) == isinstance(Exception("some exception")) - - - - def test_render_template(self, command_generator: LLMCommandGenerator): - """Test that render_template renders a template.""" + """Test that render_template renders the correct template string.""" # Given test_message = Message.build(text="some message") test_slot = TextSlot( @@ -170,7 +144,7 @@ def test_render_template(self, command_generator: LLMCommandGenerator): collect_information: test_slot """ ) - with open(TEST_PROMPT_PATH, "r", encoding='unicode_escape') as f: + with open(EXPECTED_PROMPT_PATH, "r", encoding='unicode_escape') as f: expected_template = f.read() # # When rendered_template = command_generator.render_template( @@ -182,7 +156,6 @@ def test_render_template(self, command_generator: LLMCommandGenerator): # # Then assert rendered_template == expected_template - @pytest.mark.parametrize( "input_action, expected_command", [ @@ -224,19 +197,31 @@ def test_render_template(self, command_generator: LLMCommandGenerator): ), ( "Clarify(list_contacts, add_contact, remove_contact)", - [ClarifyCommand(options=["list_contacts", "add_contact", "remove_contact"])] + [ClarifyCommand(options=[ + "list_contacts", + "add_contact", + "remove_contact" + ])] ), ]) - def test_parse_commands_identifies_correct_command(self, input_action, expected_command): + def test_parse_commands_identifies_correct_command( + self, + input_action: Optional[str], + expected_command: Command, + ): """Test that parse_commands identifies the correct commands.""" # When - with patch.object(LLMCommandGenerator, "coerce_slot_value", Mock(return_value=None)): + with patch.object( + LLMCommandGenerator, + "coerce_slot_value", + Mock(return_value=None) + ): parsed_commands = LLMCommandGenerator.parse_commands(input_action, Mock()) # Then assert parsed_commands == expected_command @pytest.mark.parametrize( - "slot_name, slot, slot_value, expected_coerced_value", + "slot_name, slot, slot_value, expected_output", [ ("some_other_slot", FloatSlot("some_float", []), None, None), ("some_float", FloatSlot("some_float", []), 40, 40.0), @@ -245,7 +230,13 @@ def test_parse_commands_identifies_correct_command(self, input_action, expected_ ("some_bool", BooleanSlot("some_bool", []), "True", True), ("some_bool", BooleanSlot("some_bool", []), "false", False) ]) - def test_coerce_slot_value(self, slot_name, slot, slot_value, expected_coerced_value): + def test_coerce_slot_value( + self, + slot_name: str, + slot: Slot, + slot_value: Optional[str|int|float|bool], + expected_output: Optional[str|int|float|bool], + ): """Test that coerce_slot_value coerces the slot value correctly.""" # Given tracker = DialogueStateTracker.from_events( @@ -256,10 +247,10 @@ def test_coerce_slot_value(self, slot_name, slot, slot_value, expected_coerced_v # When coerced_value = LLMCommandGenerator.coerce_slot_value(slot_value, slot_name, tracker) # Then - assert coerced_value == expected_coerced_value + assert coerced_value == expected_output @pytest.mark.parametrize( - "input_string, expected_string", + "input_value, expected_output", [ ("text", "text"), (" text ", "text"), @@ -268,36 +259,170 @@ def test_coerce_slot_value(self, slot_name, slot, slot_value, expected_coerced_v ("' \"text' \" ", "text"), ("", "") ]) - def test_clean_extracted_value(self, input_string, expected_string): - """Test that clean_extracted_value removes the leading and trailing whitespaces.""" + def test_clean_extracted_value(self, input_value: str, expected_output: str): + """Test that clean_extracted_value removes + the leading and trailing whitespaces. + """ # When - cleaned_extracted_value = LLMCommandGenerator.clean_extracted_value(input_string) + cleaned_value = LLMCommandGenerator.clean_extracted_value(input_value) # Then - assert cleaned_extracted_value == expected_string + assert cleaned_value == expected_output - - - - - - - - - - - # def test_allowd_values_for_slot(self, command_generator): - # """Test that allowed_values_for_slot returns the allowed values for a slot.""" - # # When - # allowed_values = command_generator.allowed_values_for_slot("slot_name") + @pytest.mark.parametrize( + "input_value, expected_truthiness", + [ + ("", False), + (" ", False), + ("none", False), + ("some text", False), + ("[missing information]", True), + ("[missing]", True), + ("None", True), + ("undefined",True), + ("null", True) + ]) + def test_is_none_value(self, input_value: str, expected_truthiness: bool): + """Test that is_none_value returns True when the value is None.""" + assert LLMCommandGenerator.is_none_value(input_value) == expected_truthiness - # # Then - # assert allowed_values == [] + @pytest.mark.parametrize( + "slot, slot_name, expected_output", + [ + (TextSlot("test_slot", [], initial_value="hello"), "test_slot", "hello"), + (TextSlot("test_slot", []), "some_other_slot", "undefined"), + ] + ) + def test_slot_value(self, slot: Slot, slot_name: str, expected_output: str): + """Test that slot_value returns the correct string.""" + # Given + tracker = DialogueStateTracker.from_events( + "test", + evts=[], + slots=[slot] + ) + # When + slot_value = LLMCommandGenerator.slot_value(tracker, slot_name) + + assert slot_value == expected_output + + @pytest.mark.parametrize( + "input_slot, expected_slot_values", + [ + (FloatSlot("test_slot", []), None), + (TextSlot("test_slot", []), None), + (BooleanSlot("test_slot", []), "[True, False]"), + (CategoricalSlot( + "test_slot", + [], + values=["Value1", "Value2"] ), "['value1', 'value2']"), + ]) + def test_allowed_values_for_slot( + self, + command_generator: LLMCommandGenerator, + input_slot: Slot, + expected_slot_values: Optional[str] + ): + """Test that allowed_values_for_slot returns the correct values.""" + # When + allowed_values = command_generator.allowed_values_for_slot(input_slot) + # Then + assert allowed_values == expected_slot_values + + @pytest.fixture + def collect_info_step(self) -> CollectInformationFlowStep: + """Create a CollectInformationFlowStep.""" + return CollectInformationFlowStep( + collect_information="test_slot", + ask_before_filling=True, + id="collect_information", + description="test_slot", + metadata={}, + next="next_step" + ) + + def test_is_extractable_with_no_slot( + self, + command_generator: LLMCommandGenerator, + collect_info_step: CollectInformationFlowStep + ): + """Test that is_extractable returns False + when there are no slots to be filled. + """ + # Given + tracker = DialogueStateTracker.from_events(sender_id="test", evts=[], slots=[]) + # When + is_extractable = command_generator.is_extractable( + collect_info_step, + tracker) + # Then + assert not is_extractable + + def test_is_extractable_when_slot_can_be_filled_without_asking( + self, + command_generator: LLMCommandGenerator, + ): + """Test that is_extractable returns True when collect_information can be filled.""" + # Given + tracker = DialogueStateTracker.from_events( + sender_id="test", + evts=[], + slots=[TextSlot(name="test_slot", mappings=[])] + ) + collect_info_step = CollectInformationFlowStep( + collect_information="test_slot", + ask_before_filling=False, + id="collect_information", + description="test_slot", + metadata={}, + next="next_step" + ) + # When + is_extractable = command_generator.is_extractable( + collect_info_step, + tracker) + # Then + assert is_extractable - # @pytest.mark.parametrize("input_value, expected_truthiness", - # [(None, True), - # ("", False), + def test_is_extractable_when_slot_has_already_been_set( + self, + command_generator: LLMCommandGenerator, + collect_info_step: CollectInformationFlowStep + ): + """Test that is_extractable returns True + when collect_information can be filled. + """ + # Given + slot = TextSlot(name="test_slot", mappings=[]) + tracker = DialogueStateTracker.from_events( + sender_id="test", + evts=[SlotSet("test_slot", "hello")], + slots=[slot] + ) + # When + is_extractable = command_generator.is_extractable( + collect_info_step, + tracker) + # Then + assert is_extractable - # )] - # def test_is_none_value(self): - # """Test that is_none_value returns True when the value is None.""" - # assert LLMCommandGenerator.is_none_value(None) + def test_is_extractable_with_current_step( + self, + command_generator: LLMCommandGenerator, + collect_info_step: CollectInformationFlowStep + ): + """Test that is_extractable returns True when the current step is a collect + information step and matches the information step. + """ + # Given + tracker = DialogueStateTracker.from_events( + sender_id="test", + evts=[UserUttered("Hello"), BotUttered("Hi")], + slots=[TextSlot(name="test_slot", mappings=[])] + ) + # When + is_extractable = command_generator.is_extractable( + collect_info_step, + tracker, + current_step=collect_info_step) + # Then + assert is_extractable \ No newline at end of file From 23329f7cfbe2c0abc0e360a0a78e5ba15a36f346 Mon Sep 17 00:00:00 2001 From: danc Date: Tue, 26 Sep 2023 10:52:25 +0100 Subject: [PATCH 04/19] fixed generate commnands tests and tidy up doc strings. --- .../generator/llm_command_generator.py | 310 ++++++++++-------- .../generator/test_llm_command_generator.py | 84 +++-- 2 files changed, 231 insertions(+), 163 deletions(-) diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py index 8938e0859aa3..d13b9e16fe19 100644 --- a/rasa/dialogue_understanding/generator/llm_command_generator.py +++ b/rasa/dialogue_understanding/generator/llm_command_generator.py @@ -67,11 +67,7 @@ is_trainable=True, ) class LLMCommandGenerator(GraphComponent, CommandGenerator): - """An LLM based command generator. - - # TODO: add description to the docstring. - - """ + """An LLM based command generator.""" @staticmethod def get_default_config() -> Dict[str, Any]: """The component's default config (see parent class for full docstring).""" @@ -102,9 +98,6 @@ def create( """Creates a new untrained component (see parent class for full docstring).""" return cls(config, model_storage, resource) - def persist(self) -> None: - pass - @classmethod def load( cls, @@ -117,37 +110,30 @@ def load( """Loads trained component (see parent class for full docstring).""" return cls(config, model_storage, resource) + def persist(self) -> None: + pass + def train(self, training_data: TrainingData) -> Resource: """Train the intent classifier on a data set.""" self.persist() return self._resource - def _generate_action_list_using_llm(self, prompt: str) -> Optional[str]: - """Use LLM to generate a response. - - Args: - prompt: the prompt to send to the LLM - - Returns: - generated text - """ - llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG) - - try: - return llm(prompt) - except Exception as e: - # unfortunately, langchain does not wrap LLM exceptions which means - # we have to catch all exceptions here - structlogger.error("llm_command_generator.llm.error", error=e) - return None - def predict_commands( self, message: Message, flows: FlowsList, tracker: Optional[DialogueStateTracker] = None, ) -> List[Command]: - """TODO: add docstring""" + """Predict commands using the LLM. + + Args: + message: The message from the user. + flows: The flows available to the user. + tracker: The tracker containing the current state of the conversation. + + Returns: + The commands generated by the llm. + """ if tracker is None or flows.is_empty(): # cannot do anything if there are no flows or no tracker return [] @@ -168,64 +154,96 @@ def predict_commands( return commands - @staticmethod - def is_none_value(value: str) -> bool: - """TODO: add docstring""" - return value in { - "[missing information]", - "[missing]", - "None", - "undefined", - "null", - } + def render_template( + self, message: Message, tracker: DialogueStateTracker, flows: FlowsList + ) -> str: + """Render the jinja template to create the prompt for the LLM. - @staticmethod - def clean_extracted_value(value: str) -> str: - """Clean up the extracted value from the llm.""" - # replace any combination of single quotes, double quotes, and spaces - # from the beginning and end of the string - return re.sub(r"^['\"\s]+|['\"\s]+$", "", value) + Args: + message: The current message from the user. + tracker: The tracker containing the current state of the conversation. + flows: The flows available to the user. - @classmethod - def coerce_slot_value( - cls, slot_value: str, slot_name: str, tracker: DialogueStateTracker - ) -> Union[str, bool, float, None]: - """Coerce the slot value to the correct type. + Returns: + The rendered prompt template. + """ + flows_without_patterns = FlowsList( + [f for f in flows.underlying_flows if not f.is_handling_pattern()] + ) + top_relevant_frame = top_flow_frame(DialogueStack.from_tracker(tracker)) + top_flow = top_relevant_frame.flow(flows) if top_relevant_frame else None + current_step = top_relevant_frame.step(flows) if top_relevant_frame else None + if top_flow is not None: + flow_slots = [ + { + "name": info_step.collect_information, + "value": self.slot_value(tracker, info_step.collect_information), + "type": tracker.slots[info_step.collect_information].type_name, + "allowed_values": self.allowed_values_for_slot( + tracker.slots[info_step.collect_information] + ), + "description": info_step.description, + } + for info_step in top_flow.get_collect_information_steps() + if self.is_extractable(info_step, tracker, current_step) + ] + else: + flow_slots = [] - Tries to coerce the slot value to the correct type. If the - conversion fails, `None` is returned. + collect_information, collect_information_description = ( + (current_step.collect_information, current_step.description) + if isinstance(current_step, CollectInformationFlowStep) + else (None, None) + ) + current_conversation = tracker_as_readable_transcript(tracker) + latest_user_message = sanitize_message_for_prompt(message.get(TEXT)) + current_conversation += f"\nUSER: {latest_user_message}" + + inputs = { + "available_flows": self.create_template_inputs( + flows_without_patterns, tracker + ), + "current_conversation": current_conversation, + "flow_slots": flow_slots, + "current_flow": top_flow.id if top_flow is not None else None, + "collect_information": collect_information, + "collect_information_description": collect_information_description, + "user_message": latest_user_message, + } + + return Template(self.prompt_template).render(**inputs) + + def _generate_action_list_using_llm(self, prompt: str) -> Optional[str]: + """Use LLM to generate a response. Args: - value: the value to coerce - slot_name: the name of the slot - tracker: the tracker + prompt: The prompt to send to the LLM. Returns: - the coerced value or `None` if the conversion failed.""" - nullable_value = slot_value if not cls.is_none_value(slot_value) else None - if slot_name not in tracker.slots: - return nullable_value + The generated text. + """ + llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG) - slot = tracker.slots[slot_name] - if isinstance(slot, BooleanSlot): - try: - return bool_from_any(nullable_value) - except (ValueError, TypeError): - return None - elif isinstance(slot, FloatSlot): - try: - return float(nullable_value) - except (ValueError, TypeError): - return None - else: - return nullable_value + try: + return llm(prompt) + except Exception as e: + # unfortunately, langchain does not wrap LLM exceptions which means + # we have to catch all exceptions here + structlogger.error("llm_command_generator.llm.error", error=e) + return None @classmethod def parse_commands( cls, actions: Optional[str], tracker: DialogueStateTracker ) -> List[Command]: """Parse the actions returned by the llm into intent and entities. - #TODO: add arguments and returns. + + Args: + actions: The actions returned by the llm. + tracker: The tracker containing the current state of the conversation. + + Returns: + The parsed commands. """ if not actions: return [ErrorCommand()] @@ -243,9 +261,9 @@ def parse_commands( clarify_re = re.compile(r"Clarify\(([a-zA-Z0-9_, ]+)\)") for action in actions.strip().splitlines(): - if m := slot_set_re.search(action): - slot_name = m.group(1).strip() - slot_value = cls.clean_extracted_value(m.group(2)) + if match := slot_set_re.search(action): + slot_name = match.group(1).strip() + slot_value = cls.clean_extracted_value(match.group(2)) # error case where the llm tries to start a flow using a slot set if slot_name == "flow_name": commands.append(StartFlowCommand(flow=slot_value)) @@ -256,8 +274,8 @@ def parse_commands( commands.append( SetSlotCommand(name=slot_name, value=typed_slot_value) ) - elif m := start_flow_re.search(action): - commands.append(StartFlowCommand(flow=m.group(1).strip())) + elif match := start_flow_re.search(action): + commands.append(StartFlowCommand(flow=match.group(1).strip())) elif cancel_flow_re.search(action): commands.append(CancelFlowCommand()) elif chitchat_re.search(action): @@ -266,17 +284,78 @@ def parse_commands( commands.append(KnowledgeAnswerCommand()) elif humand_handoff_re.search(action): commands.append(HumanHandoffCommand()) - elif m := clarify_re.search(action): - options = [opt.strip() for opt in m.group(1).split(",")] + elif match := clarify_re.search(action): + options = [opt.strip() for opt in match.group(1).split(",")] commands.append(ClarifyCommand(options)) return commands + @staticmethod + def is_none_value(value: str) -> bool: + """Check if the value is a none value.""" + return value in { + "[missing information]", + "[missing]", + "None", + "undefined", + "null", + } + + @staticmethod + def clean_extracted_value(value: str) -> str: + """Clean up the extracted value from the llm.""" + # replace any combination of single quotes, double quotes, and spaces + # from the beginning and end of the string + return re.sub(r"^['\"\s]+|['\"\s]+$", "", value) + + @classmethod + def coerce_slot_value( + cls, slot_value: str, slot_name: str, tracker: DialogueStateTracker + ) -> Union[str, bool, float, None]: + """Coerce the slot value to the correct type. + + Tries to coerce the slot value to the correct type. If the + conversion fails, `None` is returned. + + Args: + value: The value to coerce. + slot_name: The name of the slot. + tracker: The tracker containing the current state of the conversation. + + Returns: + The coerced value or `None` if the conversion failed. + """ + nullable_value = slot_value if not cls.is_none_value(slot_value) else None + if slot_name not in tracker.slots: + return nullable_value + + slot = tracker.slots[slot_name] + if isinstance(slot, BooleanSlot): + try: + return bool_from_any(nullable_value) + except (ValueError, TypeError): + return None + elif isinstance(slot, FloatSlot): + try: + return float(nullable_value) + except (ValueError, TypeError): + return None + else: + return nullable_value + @classmethod def create_template_inputs( cls, flows: FlowsList, tracker: DialogueStateTracker ) -> List[Dict[str, Any]]: - """TODO: add docstring.""" + """Create the template inputs for the flows. + + Args: + flows: The flows available to the user. + tracker: The tracker containing the current state of the conversation. + + Returns: + The inputs for the prompt template. + """ result = [] for flow in flows.underlying_flows: # TODO: check if we should filter more flows; e.g. flows that are @@ -308,7 +387,16 @@ def is_extractable( A collect slot can only be filled if the slot exist and either the collect has been asked already or the - slot has been filled already.""" + slot has been filled already. + + Args: + info_step: The collect_information step. + tracker: The tracker containing the current state of the conversation. + current_step: The current step in the flow. + + Returns: + `True` if the slot can be filled, `False` otherwise. + """ slot = tracker.slots.get(info_step.collect_information) if slot is None: return False @@ -337,59 +425,17 @@ def allowed_values_for_slot(self, slot: Slot) -> Optional[str]: @staticmethod def slot_value(tracker: DialogueStateTracker, slot_name: str) -> str: - """Get the slot value from the tracker.""" + """Get the slot value from the tracker. + + Args: + tracker: The tracker containing the current state of the conversation. + slot_name: The name of the slot. + + Returns: + The slot value as a string. + """ slot_value = tracker.get_slot(slot_name) if slot_value is None: return "undefined" else: return str(slot_value) - - def render_template( - self, message: Message, tracker: DialogueStateTracker, flows: FlowsList - ) -> str: - """TODO: add docstring""" - flows_without_patterns = FlowsList( - [f for f in flows.underlying_flows if not f.is_handling_pattern()] - ) - top_relevant_frame = top_flow_frame(DialogueStack.from_tracker(tracker)) - top_flow = top_relevant_frame.flow(flows) if top_relevant_frame else None - current_step = top_relevant_frame.step(flows) if top_relevant_frame else None - if top_flow is not None: - flow_slots = [ - { - "name": q.collect, - "value": self.slot_value(tracker, q.collect), - "type": tracker.slots[q.collect].type_name, - "allowed_values": self.allowed_values_for_slot( - tracker.slots[q.collect] - ), - "description": q.description, - } - for q in top_flow.get_collect_steps() - if self.is_extractable(q, tracker, current_step) - ] - else: - flow_slots = [] - - collect, collect_description = ( - (current_step.collect, current_step.description) - if isinstance(current_step, CollectInformationFlowStep) - else (None, None) - ) - current_conversation = tracker_as_readable_transcript(tracker) - latest_user_message = sanitize_message_for_prompt(message.get(TEXT)) - current_conversation += f"\nUSER: {latest_user_message}" - - inputs = { - "available_flows": self.create_template_inputs( - flows_without_patterns, tracker - ), - "current_conversation": current_conversation, - "flow_slots": flow_slots, - "current_flow": top_flow.id if top_flow is not None else None, - "collect": collect, - "collect_description": collect_description, - "user_message": latest_user_message, - } - - return Template(self.prompt_template).render(**inputs) diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py index c69f6c891977..10c2239bf25a 100644 --- a/tests/dialogue_understanding/generator/test_llm_command_generator.py +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -43,7 +43,6 @@ def command_generator(self): return LLMCommandGenerator.create( config={}, resource=Mock(), model_storage=Mock(), execution_context=Mock()) - @pytest.fixture def flows(self) -> FlowsList: """Create a FlowsList.""" @@ -87,37 +86,60 @@ def test_predict_commands_with_no_tracker( # Then assert not predicted_commands - # def test_generate_action_list_calls_llm_correctly( - # self, - # command_generator: LLMCommandGenerator, - # ): - # """Test that _generate_action_list calls llm correctly.""" - # # When - # with patch( - # "rasa.shared.utils.llm.llm_factory", - # Mock() - # ) as mock_llm_factory: - # command_generator._generate_action_list_using_llm("some prompt") - # # Then - # mock_llm_factory.assert_called() + def test_generate_action_list_calls_llm_factory_correctly( + self, + command_generator: LLMCommandGenerator, + ): + """Test that _generate_action_list calls llm correctly.""" + # Given + llm_config = { + "_type": "openai", + "request_timeout": 7, + "temperature": 0.0, + "model_name": "gpt-4", + } + # When + with patch( + "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", + Mock() + ) as mock_llm_factory: + command_generator._generate_action_list_using_llm("some prompt") + # Then + mock_llm_factory.assert_called_once_with(None, llm_config) + + def test_generate_action_list_calls_llm_correctly( + self, + command_generator: LLMCommandGenerator, + ): + """Test that _generate_action_list calls llm correctly.""" + # Given + with patch( + "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", + Mock() + ) as mock_llm_factory: + mock_llm_factory.return_value = Mock() + # When + command_generator._generate_action_list_using_llm("some prompt") + # Then + mock_llm_factory.return_value.assert_called_once_with("some prompt") - # def test_generate_action_list_catches_llm_exception( - # self, - # command_generator: LLMCommandGenerator, - # ): - # """Test that _generate_action_list calls llm correctly.""" - # # When - # mock_llm = Mock(side_effect=Exception("some exception")) - # with patch( - # "rasa.shared.utils.llm.llm_factory", - # Mock(return_value=mock_llm) - # ): - # with capture_logs() as logs: - # command_generator._generate_action_list_using_llm("some prompt") - # # Then - # print(logs) - # assert len(logs) == 4 - # # assert logs[1]["error"] == "some expection" + def test_generate_action_list_catches_llm_exception( + self, + command_generator: LLMCommandGenerator, + ): + """Test that _generate_action_list calls llm correctly.""" + # When + mock_llm = Mock(side_effect=Exception("some exception")) + with patch( + "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", + Mock(return_value=mock_llm) + ): + with capture_logs() as logs: + command_generator._generate_action_list_using_llm("some prompt") + # Then + print(logs) + assert len(logs) == 1 + assert logs[0]["event"] == "llm_command_generator.llm.error" def test_render_template( self, From 97565f8d012f59f1353ed4e9f27a82deb5205c7d Mon Sep 17 00:00:00 2001 From: danc Date: Tue, 26 Sep 2023 11:41:16 +0100 Subject: [PATCH 05/19] fixed collect information flow steps and applied black formatter. --- .../generator/llm_command_generator.py | 17 +- .../generator/test_llm_command_generator.py | 287 ++++++++---------- 2 files changed, 136 insertions(+), 168 deletions(-) diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py index d13b9e16fe19..4b48931efc95 100644 --- a/rasa/dialogue_understanding/generator/llm_command_generator.py +++ b/rasa/dialogue_understanding/generator/llm_command_generator.py @@ -68,6 +68,7 @@ ) class LLMCommandGenerator(GraphComponent, CommandGenerator): """An LLM based command generator.""" + @staticmethod def get_default_config() -> Dict[str, Any]: """The component's default config (see parent class for full docstring).""" @@ -130,7 +131,7 @@ def predict_commands( message: The message from the user. flows: The flows available to the user. tracker: The tracker containing the current state of the conversation. - + Returns: The commands generated by the llm. """ @@ -237,11 +238,11 @@ def parse_commands( cls, actions: Optional[str], tracker: DialogueStateTracker ) -> List[Command]: """Parse the actions returned by the llm into intent and entities. - + Args: actions: The actions returned by the llm. tracker: The tracker containing the current state of the conversation. - + Returns: The parsed commands. """ @@ -318,7 +319,7 @@ def coerce_slot_value( conversion fails, `None` is returned. Args: - value: The value to coerce. + value: The value to coerce. slot_name: The name of the slot. tracker: The tracker containing the current state of the conversation. @@ -352,7 +353,7 @@ def create_template_inputs( Args: flows: The flows available to the user. tracker: The tracker containing the current state of the conversation. - + Returns: The inputs for the prompt template. """ @@ -388,7 +389,7 @@ def is_extractable( A collect slot can only be filled if the slot exist and either the collect has been asked already or the slot has been filled already. - + Args: info_step: The collect_information step. tracker: The tracker containing the current state of the conversation. @@ -426,11 +427,11 @@ def allowed_values_for_slot(self, slot: Slot) -> Optional[str]: @staticmethod def slot_value(tracker: DialogueStateTracker, slot_name: str) -> str: """Get the slot value from the tracker. - + Args: tracker: The tracker containing the current state of the conversation. slot_name: The name of the slot. - + Returns: The slot value as a string. """ diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py index 10c2239bf25a..de66e616fa13 100644 --- a/tests/dialogue_understanding/generator/test_llm_command_generator.py +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -5,7 +5,7 @@ from structlog.testing import capture_logs from rasa.dialogue_understanding.generator.llm_command_generator import ( - LLMCommandGenerator + LLMCommandGenerator, ) from rasa.dialogue_understanding.commands import ( Command, @@ -19,13 +19,17 @@ ClarifyCommand, ) from rasa.shared.core.events import BotUttered, SlotSet, UserUttered -from rasa.shared.core.flows.flow import CollectInformationFlowStep, FlowsList +from rasa.shared.core.flows.flow import ( + CollectInformationFlowStep, + FlowsList, + SlotRejection, +) from rasa.shared.core.slots import ( Slot, BooleanSlot, CategoricalSlot, FloatSlot, - TextSlot + TextSlot, ) from rasa.shared.core.trackers import DialogueStateTracker from rasa.shared.nlu.training_data.message import Message @@ -34,6 +38,7 @@ EXPECTED_PROMPT_PATH = "./tests/dialogue_understanding/generator/rendered_prompt.txt" + class TestLLMCommandGenerator: """Tests for the LLMCommandGenerator.""" @@ -41,7 +46,8 @@ class TestLLMCommandGenerator: def command_generator(self): """Create an LLMCommandGenerator.""" return LLMCommandGenerator.create( - config={}, resource=Mock(), model_storage=Mock(), execution_context=Mock()) + config={}, resource=Mock(), model_storage=Mock(), execution_context=Mock() + ) @pytest.fixture def flows(self) -> FlowsList: @@ -57,31 +63,25 @@ def flows(self) -> FlowsList: ) def test_predict_commands_with_no_flows( - self, - command_generator: LLMCommandGenerator + self, command_generator: LLMCommandGenerator ): """Test that predict_commands returns an empty list when flows is None.""" # Given empty_flows = FlowsList([]) # When predicted_commands = command_generator.predict_commands( - Mock(), - flows=empty_flows, - tracker=Mock() + Mock(), flows=empty_flows, tracker=Mock() ) # Then assert not predicted_commands def test_predict_commands_with_no_tracker( - self, - command_generator: LLMCommandGenerator + self, command_generator: LLMCommandGenerator ): """Test that predict_commands returns an empty list when tracker is None.""" # When predicted_commands = command_generator.predict_commands( - Mock(), - flows=Mock(), - tracker=None + Mock(), flows=Mock(), tracker=None ) # Then assert not predicted_commands @@ -101,7 +101,7 @@ def test_generate_action_list_calls_llm_factory_correctly( # When with patch( "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", - Mock() + Mock(), ) as mock_llm_factory: command_generator._generate_action_list_using_llm("some prompt") # Then @@ -115,7 +115,7 @@ def test_generate_action_list_calls_llm_correctly( # Given with patch( "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", - Mock() + Mock(), ) as mock_llm_factory: mock_llm_factory.return_value = Mock() # When @@ -132,7 +132,7 @@ def test_generate_action_list_catches_llm_exception( mock_llm = Mock(side_effect=Exception("some exception")) with patch( "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", - Mock(return_value=mock_llm) + Mock(return_value=mock_llm), ): with capture_logs() as logs: command_generator._generate_action_list_using_llm("some prompt") @@ -149,12 +149,15 @@ def test_render_template( # Given test_message = Message.build(text="some message") test_slot = TextSlot( - name="test_slot", mappings=[{}], initial_value=None, influence_conversation=False + name="test_slot", + mappings=[{}], + initial_value=None, + influence_conversation=False, ) test_tracker = DialogueStateTracker.from_events( sender_id="test", evts=[UserUttered("Hello"), BotUttered("Hi")], - slots=[test_slot] + slots=[test_slot], ) test_flows = flows_from_str( """ @@ -166,66 +169,41 @@ def test_render_template( collect_information: test_slot """ ) - with open(EXPECTED_PROMPT_PATH, "r", encoding='unicode_escape') as f: + with open(EXPECTED_PROMPT_PATH, "r", encoding="unicode_escape") as f: expected_template = f.read() # # When rendered_template = command_generator.render_template( - message=test_message, - tracker=test_tracker, - flows=test_flows + message=test_message, tracker=test_tracker, flows=test_flows ) # # Then assert rendered_template == expected_template @pytest.mark.parametrize( - "input_action, expected_command", - [ - ( - None, - [ErrorCommand()] - ), - ( - "SetSlot(transfer_money_amount_of_money, )", - [SetSlotCommand(name="transfer_money_amount_of_money", value=None)] - ), - ( - "SetSlot(flow_name, some_flow)", - [StartFlowCommand(flow="some_flow")] - ), - ( - "StartFlow(check_balance)", - [StartFlowCommand(flow="check_balance")] - ), - ( - "CancelFlow()", - [CancelFlowCommand()] - ), - ( - "ChitChat()", - [ChitChatAnswerCommand()] - ), - ( - "SearchAndReply()", - [KnowledgeAnswerCommand()] - ), - ( - "HumanHandoff()", - [HumanHandoffCommand()] - ), - ( - "Clarify(transfer_money)", - [ClarifyCommand(options=["transfer_money"])] - ), - ( - "Clarify(list_contacts, add_contact, remove_contact)", - [ClarifyCommand(options=[ - "list_contacts", - "add_contact", - "remove_contact" - ])] - ), - ]) + "input_action, expected_command", + [ + (None, [ErrorCommand()]), + ( + "SetSlot(transfer_money_amount_of_money, )", + [SetSlotCommand(name="transfer_money_amount_of_money", value=None)], + ), + ("SetSlot(flow_name, some_flow)", [StartFlowCommand(flow="some_flow")]), + ("StartFlow(check_balance)", [StartFlowCommand(flow="check_balance")]), + ("CancelFlow()", [CancelFlowCommand()]), + ("ChitChat()", [ChitChatAnswerCommand()]), + ("SearchAndReply()", [KnowledgeAnswerCommand()]), + ("HumanHandoff()", [HumanHandoffCommand()]), + ("Clarify(transfer_money)", [ClarifyCommand(options=["transfer_money"])]), + ( + "Clarify(list_contacts, add_contact, remove_contact)", + [ + ClarifyCommand( + options=["list_contacts", "add_contact", "remove_contact"] + ) + ], + ), + ], + ) def test_parse_commands_identifies_correct_command( self, input_action: Optional[str], @@ -234,94 +212,89 @@ def test_parse_commands_identifies_correct_command( """Test that parse_commands identifies the correct commands.""" # When with patch.object( - LLMCommandGenerator, - "coerce_slot_value", - Mock(return_value=None) + LLMCommandGenerator, "coerce_slot_value", Mock(return_value=None) ): parsed_commands = LLMCommandGenerator.parse_commands(input_action, Mock()) # Then assert parsed_commands == expected_command @pytest.mark.parametrize( - "slot_name, slot, slot_value, expected_output", - [ - ("some_other_slot", FloatSlot("some_float", []), None, None), - ("some_float", FloatSlot("some_float", []), 40, 40.0), - ("some_float", FloatSlot("some_float", []), 40.0, 40.0), - ("some_text", TextSlot("some_text", []),"fourty", "fourty"), - ("some_bool", BooleanSlot("some_bool", []), "True", True), - ("some_bool", BooleanSlot("some_bool", []), "false", False) - ]) + "slot_name, slot, slot_value, expected_output", + [ + ("some_other_slot", FloatSlot("some_float", []), None, None), + ("some_float", FloatSlot("some_float", []), 40, 40.0), + ("some_float", FloatSlot("some_float", []), 40.0, 40.0), + ("some_text", TextSlot("some_text", []), "fourty", "fourty"), + ("some_bool", BooleanSlot("some_bool", []), "True", True), + ("some_bool", BooleanSlot("some_bool", []), "false", False), + ], + ) def test_coerce_slot_value( self, slot_name: str, slot: Slot, - slot_value: Optional[str|int|float|bool], - expected_output: Optional[str|int|float|bool], + slot_value: Optional[str | int | float | bool], + expected_output: Optional[str | int | float | bool], ): """Test that coerce_slot_value coerces the slot value correctly.""" # Given - tracker = DialogueStateTracker.from_events( - "test", - evts=[], - slots=[slot] - ) + tracker = DialogueStateTracker.from_events("test", evts=[], slots=[slot]) # When - coerced_value = LLMCommandGenerator.coerce_slot_value(slot_value, slot_name, tracker) + coerced_value = LLMCommandGenerator.coerce_slot_value( + slot_value, slot_name, tracker + ) # Then assert coerced_value == expected_output @pytest.mark.parametrize( - "input_value, expected_output", - [ - ("text", "text"), - (" text ", "text"), - ("\"text\"", "text"), - ("'text'", "text"), - ("' \"text' \" ", "text"), - ("", "") - ]) + "input_value, expected_output", + [ + ("text", "text"), + (" text ", "text"), + ('"text"', "text"), + ("'text'", "text"), + ("' \"text' \" ", "text"), + ("", ""), + ], + ) def test_clean_extracted_value(self, input_value: str, expected_output: str): - """Test that clean_extracted_value removes - the leading and trailing whitespaces. + """Test that clean_extracted_value removes + the leading and trailing whitespaces. """ # When - cleaned_value = LLMCommandGenerator.clean_extracted_value(input_value) + cleaned_value = LLMCommandGenerator.clean_extracted_value(input_value) # Then assert cleaned_value == expected_output @pytest.mark.parametrize( - "input_value, expected_truthiness", - [ - ("", False), - (" ", False), - ("none", False), - ("some text", False), - ("[missing information]", True), - ("[missing]", True), - ("None", True), - ("undefined",True), - ("null", True) - ]) + "input_value, expected_truthiness", + [ + ("", False), + (" ", False), + ("none", False), + ("some text", False), + ("[missing information]", True), + ("[missing]", True), + ("None", True), + ("undefined", True), + ("null", True), + ], + ) def test_is_none_value(self, input_value: str, expected_truthiness: bool): """Test that is_none_value returns True when the value is None.""" assert LLMCommandGenerator.is_none_value(input_value) == expected_truthiness @pytest.mark.parametrize( - "slot, slot_name, expected_output", - [ - (TextSlot("test_slot", [], initial_value="hello"), "test_slot", "hello"), - (TextSlot("test_slot", []), "some_other_slot", "undefined"), - ] + "slot, slot_name, expected_output", + [ + (TextSlot("test_slot", [], initial_value="hello"), "test_slot", "hello"), + (TextSlot("test_slot", []), "some_other_slot", "undefined"), + ], ) def test_slot_value(self, slot: Slot, slot_name: str, expected_output: str): """Test that slot_value returns the correct string.""" # Given - tracker = DialogueStateTracker.from_events( - "test", - evts=[], - slots=[slot] - ) + tracker = DialogueStateTracker.from_events("test", evts=[], slots=[slot]) # When slot_value = LLMCommandGenerator.slot_value(tracker, slot_name) @@ -333,16 +306,17 @@ def test_slot_value(self, slot: Slot, slot_name: str, expected_output: str): (FloatSlot("test_slot", []), None), (TextSlot("test_slot", []), None), (BooleanSlot("test_slot", []), "[True, False]"), - (CategoricalSlot( - "test_slot", - [], - values=["Value1", "Value2"] ), "['value1', 'value2']"), - ]) + ( + CategoricalSlot("test_slot", [], values=["Value1", "Value2"]), + "['value1', 'value2']", + ), + ], + ) def test_allowed_values_for_slot( self, command_generator: LLMCommandGenerator, input_slot: Slot, - expected_slot_values: Optional[str] + expected_slot_values: Optional[str], ): """Test that allowed_values_for_slot returns the correct values.""" # When @@ -356,26 +330,26 @@ def collect_info_step(self) -> CollectInformationFlowStep: return CollectInformationFlowStep( collect_information="test_slot", ask_before_filling=True, + utter="hello", + rejections=[SlotRejection("test_slot", "some rejection")], id="collect_information", description="test_slot", metadata={}, - next="next_step" + next="next_step", ) def test_is_extractable_with_no_slot( self, command_generator: LLMCommandGenerator, - collect_info_step: CollectInformationFlowStep + collect_info_step: CollectInformationFlowStep, ): """Test that is_extractable returns False - when there are no slots to be filled. + when there are no slots to be filled. """ # Given tracker = DialogueStateTracker.from_events(sender_id="test", evts=[], slots=[]) # When - is_extractable = command_generator.is_extractable( - collect_info_step, - tracker) + is_extractable = command_generator.is_extractable(collect_info_step, tracker) # Then assert not is_extractable @@ -386,65 +360,58 @@ def test_is_extractable_when_slot_can_be_filled_without_asking( """Test that is_extractable returns True when collect_information can be filled.""" # Given tracker = DialogueStateTracker.from_events( - sender_id="test", - evts=[], - slots=[TextSlot(name="test_slot", mappings=[])] + sender_id="test", evts=[], slots=[TextSlot(name="test_slot", mappings=[])] ) collect_info_step = CollectInformationFlowStep( collect_information="test_slot", ask_before_filling=False, + utter="hello", + rejections=[SlotRejection("test_slot", "some rejection")], id="collect_information", description="test_slot", metadata={}, - next="next_step" + next="next_step", ) # When - is_extractable = command_generator.is_extractable( - collect_info_step, - tracker) + is_extractable = command_generator.is_extractable(collect_info_step, tracker) # Then assert is_extractable def test_is_extractable_when_slot_has_already_been_set( self, command_generator: LLMCommandGenerator, - collect_info_step: CollectInformationFlowStep + collect_info_step: CollectInformationFlowStep, ): - """Test that is_extractable returns True - when collect_information can be filled. + """Test that is_extractable returns True + when collect_information can be filled. """ # Given slot = TextSlot(name="test_slot", mappings=[]) tracker = DialogueStateTracker.from_events( - sender_id="test", - evts=[SlotSet("test_slot", "hello")], - slots=[slot] + sender_id="test", evts=[SlotSet("test_slot", "hello")], slots=[slot] ) # When - is_extractable = command_generator.is_extractable( - collect_info_step, - tracker) + is_extractable = command_generator.is_extractable(collect_info_step, tracker) # Then assert is_extractable def test_is_extractable_with_current_step( self, command_generator: LLMCommandGenerator, - collect_info_step: CollectInformationFlowStep + collect_info_step: CollectInformationFlowStep, ): - """Test that is_extractable returns True when the current step is a collect - information step and matches the information step. + """Test that is_extractable returns True when the current step is a collect + information step and matches the information step. """ # Given tracker = DialogueStateTracker.from_events( sender_id="test", evts=[UserUttered("Hello"), BotUttered("Hi")], - slots=[TextSlot(name="test_slot", mappings=[])] + slots=[TextSlot(name="test_slot", mappings=[])], ) # When is_extractable = command_generator.is_extractable( - collect_info_step, - tracker, - current_step=collect_info_step) + collect_info_step, tracker, current_step=collect_info_step + ) # Then - assert is_extractable \ No newline at end of file + assert is_extractable From f4b67fe9ec0b5dde26ff95e3029b1d1baac5cd10 Mon Sep 17 00:00:00 2001 From: danc Date: Tue, 26 Sep 2023 11:52:09 +0100 Subject: [PATCH 06/19] fixed missed long line in in tests. --- .../generator/test_llm_command_generator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py index de66e616fa13..1f46ed181b79 100644 --- a/tests/dialogue_understanding/generator/test_llm_command_generator.py +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -357,7 +357,9 @@ def test_is_extractable_when_slot_can_be_filled_without_asking( self, command_generator: LLMCommandGenerator, ): - """Test that is_extractable returns True when collect_information can be filled.""" + """Test that is_extractable returns True when + collect_information slot can be filled. + """ # Given tracker = DialogueStateTracker.from_events( sender_id="test", evts=[], slots=[TextSlot(name="test_slot", mappings=[])] From 6fb3dd825e44d1f1a034ecac7e98b5a917e6768a Mon Sep 17 00:00:00 2001 From: danc Date: Tue, 26 Sep 2023 12:03:37 +0100 Subject: [PATCH 07/19] fixed single trailing whitespace. --- .../generator/test_llm_command_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py index 1f46ed181b79..dd0769303a04 100644 --- a/tests/dialogue_understanding/generator/test_llm_command_generator.py +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -357,7 +357,7 @@ def test_is_extractable_when_slot_can_be_filled_without_asking( self, command_generator: LLMCommandGenerator, ): - """Test that is_extractable returns True when + """Test that is_extractable returns True when collect_information slot can be filled. """ # Given From a0885311fbf8380f77878d6dea177abab17b1845 Mon Sep 17 00:00:00 2001 From: danc Date: Tue, 26 Sep 2023 12:27:14 +0100 Subject: [PATCH 08/19] fixed use of pipe operator. --- .../generator/test_llm_command_generator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py index dd0769303a04..78f73843708a 100644 --- a/tests/dialogue_understanding/generator/test_llm_command_generator.py +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Any from unittest.mock import Mock, patch import pytest @@ -233,8 +233,8 @@ def test_coerce_slot_value( self, slot_name: str, slot: Slot, - slot_value: Optional[str | int | float | bool], - expected_output: Optional[str | int | float | bool], + slot_value: Any, + expected_output: Any, ): """Test that coerce_slot_value coerces the slot value correctly.""" # Given From 5cb9634bcd1b147f84dec65c40e4837ca90d2408 Mon Sep 17 00:00:00 2001 From: danc Date: Wed, 4 Oct 2023 14:50:50 +0100 Subject: [PATCH 09/19] chore: fix unit tests after rebase. --- .../generator/llm_command_generator.py | 14 +++++++------- .../generator/rendered_prompt.txt | 2 +- .../generator/test_llm_command_generator.py | 10 ++++++---- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py index 4b48931efc95..95ef6962b467 100644 --- a/rasa/dialogue_understanding/generator/llm_command_generator.py +++ b/rasa/dialogue_understanding/generator/llm_command_generator.py @@ -177,11 +177,11 @@ def render_template( if top_flow is not None: flow_slots = [ { - "name": info_step.collect_information, - "value": self.slot_value(tracker, info_step.collect_information), - "type": tracker.slots[info_step.collect_information].type_name, + "name": info_step.collect, + "value": self.slot_value(tracker, info_step.collect), + "type": tracker.slots[info_step.collect].type_name, "allowed_values": self.allowed_values_for_slot( - tracker.slots[info_step.collect_information] + tracker.slots[info_step.collect] ), "description": info_step.description, } @@ -192,7 +192,7 @@ def render_template( flow_slots = [] collect_information, collect_information_description = ( - (current_step.collect_information, current_step.description) + (current_step.collect, current_step.description) if isinstance(current_step, CollectInformationFlowStep) else (None, None) ) @@ -398,7 +398,7 @@ def is_extractable( Returns: `True` if the slot can be filled, `False` otherwise. """ - slot = tracker.slots.get(info_step.collect_information) + slot = tracker.slots.get(info_step.collect) if slot is None: return False @@ -411,7 +411,7 @@ def is_extractable( or ( current_step is not None and isinstance(current_step, CollectInformationFlowStep) - and current_step.collect_information == info_step.collect_information + and current_step.collect == info_step.collect ) ) diff --git a/tests/dialogue_understanding/generator/rendered_prompt.txt b/tests/dialogue_understanding/generator/rendered_prompt.txt index ea92e3f93387..c4571d42b794 100644 --- a/tests/dialogue_understanding/generator/rendered_prompt.txt +++ b/tests/dialogue_understanding/generator/rendered_prompt.txt @@ -1 +1 @@ -Your task is to analyze the current conversation context and generate a list of actions to start new business processes that we call flows, to extract slots, or respond to small talk and knowledge requests.\n\nThese are the flows that can be started, with their description and slots:\n\ntest_flow: some description\n slot: test_slot\n \n\n===\nHere is what happened previously in the conversation:\nUSER: Hello\nAI: Hi\nUSER: some message\n\n===\n\nYou are currently not in any flow and so there are no active slots.\nThis means you can only set a slot if you first start a flow that requires that slot.\n\nIf you start a flow, first start the flow and then optionally fill that flow\'s slots with information the user provided in their message.\n\nThe user just said """some message""".\n\n===\nBased on this information generate a list of actions you want to take. Your job is to start flows and to fill slots where appropriate. Any logic of what happens afterwards is handled by the flow engine. These are your available actions:\n* Slot setting, described by "SetSlot(slot_name, slot_value)". An example would be "SetSlot(recipient, Freddy)"\n* Starting another flow, described by "StartFlow(flow_name)". An example would be "StartFlow(transfer_money)"\n* Cancelling the current flow, described by "CancelFlow()"\n* Clarifying which flow should be started. An example would be Clarify(list_contacts, add_contact, remove_contact) if the user just wrote "contacts" and there are multiple potential candidates. It also works with a single flow name to confirm you understood correctly, as in Clarify(transfer_money).\n* Responding to knowledge-oriented user messages, described by "SearchAndReply()"\n* Responding to a casual, non-task-oriented user message, described by "ChitChat()".\n* Handing off to a human, in case the user seems frustrated or explicitly asks to speak to one, described by "HumanHandoff()".\n\n===\nWrite out the actions you want to take, one per line, in the order they should take place.\nDo not fill slots with abstract values or placeholders.\nOnly use information provided by the user.\nOnly start a flow if it\'s completely clear what the user wants. Imagine you were a person reading this message. If it\'s not 100% clear, clarify the next step.\nDon\'t be overly confident. Take a conservative approach and clarify before proceeding.\nIf the user asks for two things which seem contradictory, clarify before starting a flow.\nStrictly adhere to the provided action types listed above.\nFocus on the last message and take it one step at a time.\nUse the previous conversation steps only to aid understanding.\n\nYour action list: \ No newline at end of file +Your task is to analyze the current conversation context and generate a list of actions to start new business processes that we call flows, to extract slots, or respond to small talk and knowledge requests.\n\nThese are the flows that can be started, with their description and slots:\n\ntest_flow: some description\n \n\n===\nHere is what happened previously in the conversation:\nUSER: Hello\nAI: Hi\nUSER: some message\n\n===\n\nYou are currently not in any flow and so there are no active slots.\nThis means you can only set a slot if you first start a flow that requires that slot.\n\nIf you start a flow, first start the flow and then optionally fill that flow\'s slots with information the user provided in their message.\n\nThe user just said """some message""".\n\n===\nBased on this information generate a list of actions you want to take. Your job is to start flows and to fill slots where appropriate. Any logic of what happens afterwards is handled by the flow engine. These are your available actions:\n* Slot setting, described by "SetSlot(slot_name, slot_value)". An example would be "SetSlot(recipient, Freddy)"\n* Starting another flow, described by "StartFlow(flow_name)". An example would be "StartFlow(transfer_money)"\n* Cancelling the current flow, described by "CancelFlow()"\n* Clarifying which flow should be started. An example would be Clarify(list_contacts, add_contact, remove_contact) if the user just wrote "contacts" and there are multiple potential candidates. It also works with a single flow name to confirm you understood correctly, as in Clarify(transfer_money).\n* Responding to knowledge-oriented user messages, described by "SearchAndReply()"\n* Responding to a casual, non-task-oriented user message, described by "ChitChat()".\n* Handing off to a human, in case the user seems frustrated or explicitly asks to speak to one, described by "HumanHandoff()".\n\n===\nWrite out the actions you want to take, one per line, in the order they should take place.\nDo not fill slots with abstract values or placeholders.\nOnly use information provided by the user.\nOnly start a flow if it\'s completely clear what the user wants. Imagine you were a person reading this message. If it\'s not 100% clear, clarify the next step.\nDon\'t be overly confident. Take a conservative approach and clarify before proceeding.\nIf the user asks for two things which seem contradictory, clarify before starting a flow.\nStrictly adhere to the provided action types listed above.\nFocus on the last message and take it one step at a time.\nUse the previous conversation steps only to aid understanding.\n\nYour action list: \ No newline at end of file diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py index 78f73843708a..216621439a58 100644 --- a/tests/dialogue_understanding/generator/test_llm_command_generator.py +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -328,11 +328,12 @@ def test_allowed_values_for_slot( def collect_info_step(self) -> CollectInformationFlowStep: """Create a CollectInformationFlowStep.""" return CollectInformationFlowStep( - collect_information="test_slot", + collect="test_slot", + idx=0, ask_before_filling=True, utter="hello", rejections=[SlotRejection("test_slot", "some rejection")], - id="collect_information", + custom_id="collect", description="test_slot", metadata={}, next="next_step", @@ -365,11 +366,12 @@ def test_is_extractable_when_slot_can_be_filled_without_asking( sender_id="test", evts=[], slots=[TextSlot(name="test_slot", mappings=[])] ) collect_info_step = CollectInformationFlowStep( - collect_information="test_slot", + collect="test_slot", ask_before_filling=False, utter="hello", rejections=[SlotRejection("test_slot", "some rejection")], - id="collect_information", + custom_id="collect_information", + idx=0, description="test_slot", metadata={}, next="next_step", From 8194f0b3776845b54c278bf75114ef500bc01687 Mon Sep 17 00:00:00 2001 From: danc Date: Thu, 5 Oct 2023 09:11:34 +0100 Subject: [PATCH 10/19] chore: fix collect_information. --- rasa/dialogue_understanding/generator/llm_command_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py index 95ef6962b467..2d5592e30dd8 100644 --- a/rasa/dialogue_understanding/generator/llm_command_generator.py +++ b/rasa/dialogue_understanding/generator/llm_command_generator.py @@ -185,7 +185,7 @@ def render_template( ), "description": info_step.description, } - for info_step in top_flow.get_collect_information_steps() + for info_step in top_flow.get_collect_steps() if self.is_extractable(info_step, tracker, current_step) ] else: From ff6de9f9dbbfc791f421be3b4a44d69ebbe34225 Mon Sep 17 00:00:00 2001 From: danc Date: Sun, 8 Oct 2023 12:08:45 +0100 Subject: [PATCH 11/19] chore: responded to review comments. adding FlowList.user_flows, updating FLowList.user_flow_ids, making FlowLists itterable. --- .../commands/start_flow_command.py | 2 +- .../generator/llm_command_generator.py | 121 ++++++++++-------- rasa/shared/core/flows/flow.py | 29 +++-- 3 files changed, 86 insertions(+), 66 deletions(-) diff --git a/rasa/dialogue_understanding/commands/start_flow_command.py b/rasa/dialogue_understanding/commands/start_flow_command.py index cb3cd5d5166a..952e8174a152 100644 --- a/rasa/dialogue_understanding/commands/start_flow_command.py +++ b/rasa/dialogue_understanding/commands/start_flow_command.py @@ -71,7 +71,7 @@ def run_command_on_tracker( "command_executor.skip_command.already_started_flow", command=self ) return [] - elif self.flow not in all_flows.non_pattern_flows(): + elif self.flow not in all_flows.user_flow_ids: structlogger.debug( "command_executor.skip_command.start_invalid_flow_id", command=self ) diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py index 2d5592e30dd8..238574cbed29 100644 --- a/rasa/dialogue_understanding/generator/llm_command_generator.py +++ b/rasa/dialogue_understanding/generator/llm_command_generator.py @@ -23,7 +23,12 @@ from rasa.engine.recipes.default_recipe import DefaultV1Recipe from rasa.engine.storage.resource import Resource from rasa.engine.storage.storage import ModelStorage -from rasa.shared.core.flows.flow import FlowStep, FlowsList, CollectInformationFlowStep +from rasa.shared.core.flows.flow import ( + Flow, + FlowStep, + FlowsList, + CollectInformationFlowStep, +) from rasa.shared.core.trackers import DialogueStateTracker from rasa.shared.core.slots import ( BooleanSlot, @@ -67,7 +72,7 @@ is_trainable=True, ) class LLMCommandGenerator(GraphComponent, CommandGenerator): - """An LLM based command generator.""" + """An LLM-based command generator.""" @staticmethod def get_default_config() -> Dict[str, Any]: @@ -168,47 +173,29 @@ def render_template( Returns: The rendered prompt template. """ - flows_without_patterns = FlowsList( - [f for f in flows.underlying_flows if not f.is_handling_pattern()] - ) top_relevant_frame = top_flow_frame(DialogueStack.from_tracker(tracker)) top_flow = top_relevant_frame.flow(flows) if top_relevant_frame else None current_step = top_relevant_frame.step(flows) if top_relevant_frame else None - if top_flow is not None: - flow_slots = [ - { - "name": info_step.collect, - "value": self.slot_value(tracker, info_step.collect), - "type": tracker.slots[info_step.collect].type_name, - "allowed_values": self.allowed_values_for_slot( - tracker.slots[info_step.collect] - ), - "description": info_step.description, - } - for info_step in top_flow.get_collect_steps() - if self.is_extractable(info_step, tracker, current_step) - ] - else: - flow_slots = [] - collect_information, collect_information_description = ( - (current_step.collect, current_step.description) - if isinstance(current_step, CollectInformationFlowStep) - else (None, None) + flow_slots = self.prepare_current_flow_slots_for_template( + top_flow, current_step, tracker + ) + current_slot, current_slot_description = self.prepare_current_slot_for_template( + current_step ) current_conversation = tracker_as_readable_transcript(tracker) latest_user_message = sanitize_message_for_prompt(message.get(TEXT)) current_conversation += f"\nUSER: {latest_user_message}" inputs = { - "available_flows": self.create_template_inputs( - flows_without_patterns, tracker + "available_flows": self.prepare_flows_for_template( + flows.user_flows, tracker ), "current_conversation": current_conversation, "flow_slots": flow_slots, "current_flow": top_flow.id if top_flow is not None else None, - "collect_information": collect_information, - "collect_information_description": collect_information_description, + "collect_information": current_slot, + "collect_information_description": current_slot_description, "user_message": latest_user_message, } @@ -307,7 +294,7 @@ def clean_extracted_value(value: str) -> str: """Clean up the extracted value from the llm.""" # replace any combination of single quotes, double quotes, and spaces # from the beginning and end of the string - return re.sub(r"^['\"\s]+|['\"\s]+$", "", value) + return value.strip("'\" ") @classmethod def coerce_slot_value( @@ -345,10 +332,10 @@ def coerce_slot_value( return nullable_value @classmethod - def create_template_inputs( + def prepare_flows_for_template( cls, flows: FlowsList, tracker: DialogueStateTracker ) -> List[Dict[str, Any]]: - """Create the template inputs for the flows. + """Format data on available flows for insertion into the prompt template. Args: flows: The flows available to the user. @@ -358,29 +345,24 @@ def create_template_inputs( The inputs for the prompt template. """ result = [] - for flow in flows.underlying_flows: - # TODO: check if we should filter more flows; e.g. flows that are - # linked to by other flows and that shouldn't be started directly. - # we might need a separate flag for that. - if not flow.is_rasa_default_flow(): - - slots_with_info = [ - {"name": q.collect, "description": q.description} - for q in flow.get_collect_steps() - if cls.is_extractable(q, tracker) - ] - result.append( - { - "name": flow.id, - "description": flow.description, - "slots": slots_with_info, - } - ) + for flow in flows.user_flows: + slots_with_info = [ + {"name": q.collect, "description": q.description} + for q in flow.get_collect_steps() + if cls.is_extractable(q, tracker) + ] + result.append( + { + "name": flow.id, + "description": flow.description, + "slots": slots_with_info, + } + ) return result @staticmethod def is_extractable( - info_step: CollectInformationFlowStep, + collect_step: CollectInformationFlowStep, tracker: DialogueStateTracker, current_step: Optional[FlowStep] = None, ) -> bool: @@ -391,27 +373,27 @@ def is_extractable( slot has been filled already. Args: - info_step: The collect_information step. + collect_step: The collect_information step. tracker: The tracker containing the current state of the conversation. current_step: The current step in the flow. Returns: `True` if the slot can be filled, `False` otherwise. """ - slot = tracker.slots.get(info_step.collect) + slot = tracker.slots.get(collect_step.collect) if slot is None: return False return ( # we can fill because this is a slot that can be filled ahead of time - not info_step.ask_before_filling + not collect_step.ask_before_filling # we can fill because the slot has been filled already or slot.has_been_set # we can fill because the is currently getting asked or ( current_step is not None and isinstance(current_step, CollectInformationFlowStep) - and current_step.collect == info_step.collect + and current_step.collect == collect_step.collect ) ) @@ -440,3 +422,32 @@ def slot_value(tracker: DialogueStateTracker, slot_name: str) -> str: return "undefined" else: return str(slot_value) + + def prepare_current_flow_slots_for_template( + self, top_flow: Flow, current_step: FlowStep, tracker: DialogueStateTracker + ) -> List[Dict[str, Any]]: + if top_flow is not None: + flow_slots = [ + { + "name": collect_step.collect, + "value": self.slot_value(tracker, collect_step.collect), + "type": tracker.slots[collect_step.collect].type_name, + "allowed_values": self.allowed_values_for_slot( + tracker.slots[collect_step.collect] + ), + "description": collect_step.description, + } + for collect_step in top_flow.get_collect_steps() + if self.is_extractable(collect_step, tracker, current_step) + ] + else: + flow_slots = [] + return flow_slots + + def prepare_current_slot_for_template(self, current_step: FlowStep): + """Prepare the current slot for the template.""" + return ( + (current_step.collect, current_step.description) + if isinstance(current_step, CollectInformationFlowStep) + else (None, None) + ) diff --git a/rasa/shared/core/flows/flow.py b/rasa/shared/core/flows/flow.py index 4b12f93f36f9..f749b4beadf8 100644 --- a/rasa/shared/core/flows/flow.py +++ b/rasa/shared/core/flows/flow.py @@ -179,6 +179,10 @@ def __init__(self, flows: List[Flow]) -> None: """ self.underlying_flows = flows + def __iter__(self) -> Generator[Flow, None, None]: + """Iterates over the flows.""" + yield from self.underlying_flows + def is_empty(self) -> bool: """Returns whether the flows list is empty.""" return len(self.underlying_flows) == 0 @@ -254,15 +258,23 @@ def validate(self) -> None: for flow in self.underlying_flows: flow.validate() - def non_pattern_flows(self) -> List[str]: - """Get all flows that can be started. + @property + def user_flow_ids(self) -> List[str]: + """Get all ids of flows that can be started by a user. - Args: - all_flows: All flows. + Returns: + The ids of all flows that can be started by a user.""" + return [f.id for f in self.user_flows] + + @property + def user_flows(self) -> FlowsList: + """Get all flows that can be started by a user. Returns: - All flows that can be started.""" - return [f.id for f in self.underlying_flows if not f.is_handling_pattern()] + All flows that can be started by a user.""" + return FlowsList( + [f for f in self.underlying_flows if not f.is_rasa_default_flow] + ) @dataclass @@ -495,10 +507,6 @@ def _previously_asked_collect( return _previously_asked_collect(step_id or START_STEP, set()) - def is_handling_pattern(self) -> bool: - """Returns whether the flow is handling a pattern.""" - return self.id.startswith(RASA_DEFAULT_FLOW_PATTERN_PREFIX) - def get_trigger_intents(self) -> Set[str]: """Returns the trigger intents of the flow""" results: Set[str] = set() @@ -519,6 +527,7 @@ def is_user_triggerable(self) -> bool: """Test whether a user can trigger the flow with an intent.""" return len(self.get_trigger_intents()) > 0 + @property def is_rasa_default_flow(self) -> bool: """Test whether something is a rasa default flow.""" return self.id.startswith(RASA_DEFAULT_FLOW_PATTERN_PREFIX) From 5221173288b5ae2b64d8ae2198b9d8ab691dab0e Mon Sep 17 00:00:00 2001 From: danc Date: Mon, 9 Oct 2023 16:21:43 +0100 Subject: [PATCH 12/19] chore: added testing for FlowsList.user_flows including __eq__ method on FLowsList class. --- rasa/shared/core/flows/flow.py | 7 +++++ tests/core/flows/test_flow.py | 55 +++++++++++++++++++++++++++------- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/rasa/shared/core/flows/flow.py b/rasa/shared/core/flows/flow.py index f749b4beadf8..1d3a89c238f1 100644 --- a/rasa/shared/core/flows/flow.py +++ b/rasa/shared/core/flows/flow.py @@ -183,6 +183,13 @@ def __iter__(self) -> Generator[Flow, None, None]: """Iterates over the flows.""" yield from self.underlying_flows + def __eq__(self, other: Any) -> bool: + """Compares the flows.""" + return ( + isinstance(other, FlowsList) + and self.underlying_flows == other.underlying_flows + ) + def is_empty(self) -> bool: """Returns whether the flows list is empty.""" return len(self.underlying_flows) == 0 diff --git a/tests/core/flows/test_flow.py b/tests/core/flows/test_flow.py index 4192ce132962..fa489776f8ba 100644 --- a/tests/core/flows/test_flow.py +++ b/tests/core/flows/test_flow.py @@ -1,9 +1,12 @@ -from rasa.shared.core.flows.flow import FlowsList +import pytest + +from rasa.shared.core.flows.flow import Flow, FlowsList from tests.utilities import flows_from_str -def test_non_pattern_flows(): - all_flows = flows_from_str( +@pytest.fixture +def user_flows_and_patterns() -> FlowsList: + return flows_from_str( """ flows: foo: @@ -16,15 +19,11 @@ def test_non_pattern_flows(): action: action_listen """ ) - assert all_flows.non_pattern_flows() == ["foo"] - - -def test_non_pattern_handles_empty(): - assert FlowsList(flows=[]).non_pattern_flows() == [] -def test_non_pattern_flows_handles_patterns_only(): - all_flows = flows_from_str( +@pytest.fixture +def only_patterns() -> FlowsList: + return flows_from_str( """ flows: pattern_bar: @@ -33,4 +32,38 @@ def test_non_pattern_flows_handles_patterns_only(): action: action_listen """ ) - assert all_flows.non_pattern_flows() == [] + + +@pytest.fixture +def empty_flowlist() -> FlowsList: + return FlowsList(flows=[]) + + +def test_user_flow_ids(user_flows_and_patterns: FlowsList): + assert user_flows_and_patterns.user_flow_ids == ["foo"] + + +def test_user_flow_ids_handles_empty(empty_flowlist: FlowsList): + assert empty_flowlist.user_flow_ids == [] + + +def test_user_flow_ids_handles_patterns_only(only_patterns: FlowsList): + assert only_patterns.user_flow_ids == [] + + +def test_user_flows(user_flows_and_patterns: FlowsList): + user_flows = user_flows_and_patterns.user_flows + expected_user_flows = FlowsList( + [Flow.from_json("foo", {"steps": [{"id": "first", "action": "action_listen"}]})] + ) + assert user_flows == expected_user_flows + + +def test_user_flows_handles_empty(empty_flowlist: FlowsList): + assert empty_flowlist.user_flows == empty_flowlist + + +def test_user_flows_handles_patterns_only( + only_patterns: FlowsList, empty_flowlist: FlowsList +): + assert only_patterns.user_flows == empty_flowlist From 6b13c0ad1f93fa3088a89ba6b6f8de65d10fe66b Mon Sep 17 00:00:00 2001 From: danc Date: Tue, 10 Oct 2023 18:22:58 +0100 Subject: [PATCH 13/19] chore: addressed comments on tests. --- .../generator/command_prompt_template.jinja2 | 4 ++-- .../generator/llm_command_generator.py | 16 ++++++++-------- .../generator/test_llm_command_generator.py | 14 +++++++++++++- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/rasa/dialogue_understanding/generator/command_prompt_template.jinja2 b/rasa/dialogue_understanding/generator/command_prompt_template.jinja2 index bb46fdcdf663..039ef52dc360 100644 --- a/rasa/dialogue_understanding/generator/command_prompt_template.jinja2 +++ b/rasa/dialogue_understanding/generator/command_prompt_template.jinja2 @@ -14,8 +14,8 @@ Here is what happened previously in the conversation: === {% if current_flow != None %} -You are currently in the flow "{{ current_flow }}", which {{ current_flow.description }} -You have just asked the user for the slot "{{ collect }}"{% if collect_description %} ({{ collect_description }}){% endif %}. +You are currently in the flow "{{ current_flow }}". +You have just asked the user for the slot "{{ current_slot }}"{% if current_slot_description %} ({{ current_slot_description }}){% endif %}. {% if flow_slots|length > 0 %} Here are the slots of the currently active flow: diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py index 238574cbed29..fcdd80ee8804 100644 --- a/rasa/dialogue_understanding/generator/llm_command_generator.py +++ b/rasa/dialogue_understanding/generator/llm_command_generator.py @@ -188,14 +188,12 @@ def render_template( current_conversation += f"\nUSER: {latest_user_message}" inputs = { - "available_flows": self.prepare_flows_for_template( - flows.user_flows, tracker - ), + "available_flows": self.prepare_flows_for_template(flows, tracker), "current_conversation": current_conversation, "flow_slots": flow_slots, "current_flow": top_flow.id if top_flow is not None else None, - "collect_information": current_slot, - "collect_information_description": current_slot_description, + "current_slot": current_slot, + "current_slot_description": current_slot_description, "user_message": latest_user_message, } @@ -407,7 +405,7 @@ def allowed_values_for_slot(self, slot: Slot) -> Optional[str]: return None @staticmethod - def slot_value(tracker: DialogueStateTracker, slot_name: str) -> str: + def get_slot_value(tracker: DialogueStateTracker, slot_name: str) -> str: """Get the slot value from the tracker. Args: @@ -430,7 +428,7 @@ def prepare_current_flow_slots_for_template( flow_slots = [ { "name": collect_step.collect, - "value": self.slot_value(tracker, collect_step.collect), + "value": self.get_slot_value(tracker, collect_step.collect), "type": tracker.slots[collect_step.collect].type_name, "allowed_values": self.allowed_values_for_slot( tracker.slots[collect_step.collect] @@ -444,7 +442,9 @@ def prepare_current_flow_slots_for_template( flow_slots = [] return flow_slots - def prepare_current_slot_for_template(self, current_step: FlowStep): + def prepare_current_slot_for_template( + self, current_step: FlowStep + ) -> tuple[Optional[str], Optional[str]]: """Prepare the current slot for the template.""" return ( (current_step.collect, current_step.description) diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py index 216621439a58..387df58f82bd 100644 --- a/tests/dialogue_understanding/generator/test_llm_command_generator.py +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -202,6 +202,18 @@ def test_render_template( ) ], ), + ( + "Here is a list of commands:\nSetSlot(flow_name, some_flow)\n", + [StartFlowCommand(flow="some_flow")], + ), + ( + """SetSlot(flow_name, some_flow) + SetSlot(transfer_money_amount_of_money,)""", + [ + StartFlowCommand(flow="some_flow"), + SetSlotCommand(name="transfer_money_amount_of_money", value=None), + ], + ), ], ) def test_parse_commands_identifies_correct_command( @@ -296,7 +308,7 @@ def test_slot_value(self, slot: Slot, slot_name: str, expected_output: str): # Given tracker = DialogueStateTracker.from_events("test", evts=[], slots=[slot]) # When - slot_value = LLMCommandGenerator.slot_value(tracker, slot_name) + slot_value = LLMCommandGenerator.get_slot_value(tracker, slot_name) assert slot_value == expected_output From 6eca72a92d45091c9b0307818291c7d099dea9fb Mon Sep 17 00:00:00 2001 From: danc Date: Mon, 16 Oct 2023 16:06:35 +0100 Subject: [PATCH 14/19] chore: fixed type hints for python 3.8. --- .../generator/llm_command_generator.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py index da3a863fac75..aad8e01e90d0 100644 --- a/rasa/dialogue_understanding/generator/llm_command_generator.py +++ b/rasa/dialogue_understanding/generator/llm_command_generator.py @@ -1,6 +1,6 @@ import importlib.resources import re -from typing import Dict, Any, Optional, List, Union +from typing import Dict, Any, List, Optional, Tuple, Union from jinja2 import Template import structlog @@ -396,7 +396,7 @@ def is_extractable( ) ) - def allowed_values_for_slot(self, slot: Slot) -> Optional[str]: + def allowed_values_for_slot(self, slot: Slot) -> Union[str, None]: """Get the allowed values for a slot.""" if isinstance(slot, BooleanSlot): return str([True, False]) @@ -425,6 +425,16 @@ def get_slot_value(tracker: DialogueStateTracker, slot_name: str) -> str: def prepare_current_flow_slots_for_template( self, top_flow: Flow, current_step: FlowStep, tracker: DialogueStateTracker ) -> List[Dict[str, Any]]: + """Prepare the current flow slots for the template. + + Args: + top_flow: The top flow. + current_step: The current step in the flow. + tracker: The tracker containing the current state of the conversation. + + Returns: + The slots with values, types, allowed values and a description. + """ if top_flow is not None: flow_slots = [ { @@ -445,7 +455,7 @@ def prepare_current_flow_slots_for_template( def prepare_current_slot_for_template( self, current_step: FlowStep - ) -> tuple[Optional[str], Optional[str]]: + ) -> Tuple[Union[str, None], Union[str, None]]: """Prepare the current slot for the template.""" return ( (current_step.collect, current_step.description) From 7ec98a215cedeeeefdc09f4479fd4462f3c4d16c Mon Sep 17 00:00:00 2001 From: danc Date: Mon, 16 Oct 2023 16:38:28 +0100 Subject: [PATCH 15/19] chore: updated schema for flow in test. --- .../generator/test_llm_command_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py index 387df58f82bd..5dbd9ed6b2d7 100644 --- a/tests/dialogue_understanding/generator/test_llm_command_generator.py +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -166,7 +166,7 @@ def test_render_template( description: some description steps: - id: first_step - collect_information: test_slot + collect: test_slot """ ) with open(EXPECTED_PROMPT_PATH, "r", encoding="unicode_escape") as f: From 09a56a10dd983d17bcd443839b4ca84c6b33d425 Mon Sep 17 00:00:00 2001 From: danc Date: Wed, 18 Oct 2023 09:00:31 +0100 Subject: [PATCH 16/19] chore: updated render prompt test to test line by line. --- .../generator/rendered_prompt.txt | 46 ++++++++++++++++++- .../generator/test_llm_command_generator.py | 12 +++-- 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/tests/dialogue_understanding/generator/rendered_prompt.txt b/tests/dialogue_understanding/generator/rendered_prompt.txt index c4571d42b794..d9e9fd4c14e7 100644 --- a/tests/dialogue_understanding/generator/rendered_prompt.txt +++ b/tests/dialogue_understanding/generator/rendered_prompt.txt @@ -1 +1,45 @@ -Your task is to analyze the current conversation context and generate a list of actions to start new business processes that we call flows, to extract slots, or respond to small talk and knowledge requests.\n\nThese are the flows that can be started, with their description and slots:\n\ntest_flow: some description\n \n\n===\nHere is what happened previously in the conversation:\nUSER: Hello\nAI: Hi\nUSER: some message\n\n===\n\nYou are currently not in any flow and so there are no active slots.\nThis means you can only set a slot if you first start a flow that requires that slot.\n\nIf you start a flow, first start the flow and then optionally fill that flow\'s slots with information the user provided in their message.\n\nThe user just said """some message""".\n\n===\nBased on this information generate a list of actions you want to take. Your job is to start flows and to fill slots where appropriate. Any logic of what happens afterwards is handled by the flow engine. These are your available actions:\n* Slot setting, described by "SetSlot(slot_name, slot_value)". An example would be "SetSlot(recipient, Freddy)"\n* Starting another flow, described by "StartFlow(flow_name)". An example would be "StartFlow(transfer_money)"\n* Cancelling the current flow, described by "CancelFlow()"\n* Clarifying which flow should be started. An example would be Clarify(list_contacts, add_contact, remove_contact) if the user just wrote "contacts" and there are multiple potential candidates. It also works with a single flow name to confirm you understood correctly, as in Clarify(transfer_money).\n* Responding to knowledge-oriented user messages, described by "SearchAndReply()"\n* Responding to a casual, non-task-oriented user message, described by "ChitChat()".\n* Handing off to a human, in case the user seems frustrated or explicitly asks to speak to one, described by "HumanHandoff()".\n\n===\nWrite out the actions you want to take, one per line, in the order they should take place.\nDo not fill slots with abstract values or placeholders.\nOnly use information provided by the user.\nOnly start a flow if it\'s completely clear what the user wants. Imagine you were a person reading this message. If it\'s not 100% clear, clarify the next step.\nDon\'t be overly confident. Take a conservative approach and clarify before proceeding.\nIf the user asks for two things which seem contradictory, clarify before starting a flow.\nStrictly adhere to the provided action types listed above.\nFocus on the last message and take it one step at a time.\nUse the previous conversation steps only to aid understanding.\n\nYour action list: \ No newline at end of file +Your task is to analyze the current conversation context and generate a list of actions to start new business processes that we call flows, to extract slots, or respond to small talk and knowledge requests. + +These are the flows that can be started, with their description and slots: + +test_flow: some description + slot: test_slot + + +=== +Here is what happened previously in the conversation: +USER: Hello +AI: Hi +USER: some message + +=== + +You are currently not in any flow and so there are no active slots. +This means you can only set a slot if you first start a flow that requires that slot. + +If you start a flow, first start the flow and then optionally fill that flow's slots with information the user provided in their message. + +The user just said """some message""". + +=== +Based on this information generate a list of actions you want to take. Your job is to start flows and to fill slots where appropriate. Any logic of what happens afterwards is handled by the flow engine. These are your available actions: +* Slot setting, described by "SetSlot(slot_name, slot_value)". An example would be "SetSlot(recipient, Freddy)" +* Starting another flow, described by "StartFlow(flow_name)". An example would be "StartFlow(transfer_money)" +* Cancelling the current flow, described by "CancelFlow()" +* Clarifying which flow should be started. An example would be Clarify(list_contacts, add_contact, remove_contact) if the user just wrote "contacts" and there are multiple potential candidates. It also works with a single flow name to confirm you understood correctly, as in Clarify(transfer_money). +* Responding to knowledge-oriented user messages, described by "SearchAndReply()" +* Responding to a casual, non-task-oriented user message, described by "ChitChat()". +* Handing off to a human, in case the user seems frustrated or explicitly asks to speak to one, described by "HumanHandoff()". + +=== +Write out the actions you want to take, one per line, in the order they should take place. +Do not fill slots with abstract values or placeholders. +Only use information provided by the user. +Only start a flow if it's completely clear what the user wants. Imagine you were a person reading this message. If it's not 100% clear, clarify the next step. +Don't be overly confident. Take a conservative approach and clarify before proceeding. +If the user asks for two things which seem contradictory, clarify before starting a flow. +Strictly adhere to the provided action types listed above. +Focus on the last message and take it one step at a time. +Use the previous conversation steps only to aid understanding. + +Your action list: \ No newline at end of file diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py index 5dbd9ed6b2d7..cbd73e6afde6 100644 --- a/tests/dialogue_understanding/generator/test_llm_command_generator.py +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -170,14 +170,16 @@ def test_render_template( """ ) with open(EXPECTED_PROMPT_PATH, "r", encoding="unicode_escape") as f: - expected_template = f.read() - # # When + expected_template = f.readlines() + # When rendered_template = command_generator.render_template( message=test_message, tracker=test_tracker, flows=test_flows ) - - # # Then - assert rendered_template == expected_template + # Then + for rendered_line, expected_line in zip( + rendered_template.splitlines(True), expected_template + ): + assert rendered_line == expected_line @pytest.mark.parametrize( "input_action, expected_command", From d754306dddf9625b6dde9b7b92fe925e192b031c Mon Sep 17 00:00:00 2001 From: Alan Nichol Date: Thu, 19 Oct 2023 16:48:57 -0400 Subject: [PATCH 17/19] remove unneeded slot mapping defs --- examples/money_transfer/domain.yml | 8 -------- rasa/cli/project_templates/tutorial/domain.yml | 4 ---- 2 files changed, 12 deletions(-) diff --git a/examples/money_transfer/domain.yml b/examples/money_transfer/domain.yml index 6d2cd6f1304c..a7b79ad598f0 100644 --- a/examples/money_transfer/domain.yml +++ b/examples/money_transfer/domain.yml @@ -3,20 +3,12 @@ version: "3.1" slots: recipient: type: text - mappings: - - type: custom amount: type: float - mappings: - - type: custom final_confirmation: type: bool - mappings: - - type: custom has_sufficient_funds: type: bool - mappings: - - type: custom responses: utter_ask_recipient: diff --git a/rasa/cli/project_templates/tutorial/domain.yml b/rasa/cli/project_templates/tutorial/domain.yml index dc2854563712..4c04450e3094 100644 --- a/rasa/cli/project_templates/tutorial/domain.yml +++ b/rasa/cli/project_templates/tutorial/domain.yml @@ -3,12 +3,8 @@ version: "3.1" slots: recipient: type: text - mappings: - - type: custom amount: type: float - mappings: - - type: custom responses: utter_ask_recipient: From 0f05c4455bee9d503deed8fea7152c1fcdc1db48 Mon Sep 17 00:00:00 2001 From: Alan Nichol Date: Thu, 19 Oct 2023 16:54:22 -0400 Subject: [PATCH 18/19] add description to amount step --- rasa/cli/project_templates/tutorial/data/flows.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/rasa/cli/project_templates/tutorial/data/flows.yml b/rasa/cli/project_templates/tutorial/data/flows.yml index d849432a00e6..1ad56dc32fe0 100644 --- a/rasa/cli/project_templates/tutorial/data/flows.yml +++ b/rasa/cli/project_templates/tutorial/data/flows.yml @@ -4,4 +4,5 @@ flows: steps: - collect: recipient - collect: amount + description: the number of US dollars to send - action: utter_transfer_complete From 9ae34405dbd47fa700f749ae459b49c815eba31a Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Fri, 20 Oct 2023 19:16:48 +0100 Subject: [PATCH 19/19] Reset slot value for `set_slots` once flow ends (#12918) * implement changes + unit test * handle edge case --- rasa/core/policies/flow_policy.py | 40 +++++++++--- tests/core/policies/test_flow_policy.py | 86 +++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 8 deletions(-) diff --git a/rasa/core/policies/flow_policy.py b/rasa/core/policies/flow_policy.py index 74e74d4b074a..31c6ca93521c 100644 --- a/rasa/core/policies/flow_policy.py +++ b/rasa/core/policies/flow_policy.py @@ -559,16 +559,40 @@ def _reset_scoped_slots( self, current_flow: Flow, tracker: DialogueStateTracker ) -> List[Event]: """Reset all scoped slots.""" + + def _reset_slot( + slot_name: Text, dialogue_tracker: DialogueStateTracker + ) -> None: + slot = dialogue_tracker.slots.get(slot_name, None) + initial_value = slot.initial_value if slot else None + events.append(SlotSet(slot_name, initial_value)) + events: List[Event] = [] + + not_resettable_slot_names = set() + for step in current_flow.steps: - # reset all slots scoped to the flow - if ( - isinstance(step, CollectInformationFlowStep) - and step.reset_after_flow_ends - ): - slot = tracker.slots.get(step.collect, None) - initial_value = slot.initial_value if slot else None - events.append(SlotSet(step.collect, initial_value)) + if isinstance(step, CollectInformationFlowStep): + # reset all slots scoped to the flow + if step.reset_after_flow_ends: + _reset_slot(step.collect, tracker) + else: + not_resettable_slot_names.add(step.collect) + + # slots set by the set slots step should be reset after the flow ends + # unless they are also used in a collect step where `reset_after_flow_ends` + # is set to `False` + resettable_set_slots = [ + slot["key"] + for step in current_flow.steps + if isinstance(step, SetSlotsFlowStep) + for slot in step.slots + if slot["key"] not in not_resettable_slot_names + ] + + for name in resettable_set_slots: + _reset_slot(name, tracker) + return events def run_step( diff --git a/tests/core/policies/test_flow_policy.py b/tests/core/policies/test_flow_policy.py index 6575a4f679e6..5a84a9dcca1c 100644 --- a/tests/core/policies/test_flow_policy.py +++ b/tests/core/policies/test_flow_policy.py @@ -16,6 +16,7 @@ from rasa.shared.core.events import ActionExecuted, Event, SlotSet from rasa.shared.core.flows.flow import FlowsList from rasa.shared.core.flows.yaml_flows_io import YAMLFlowsReader +from rasa.shared.core.slots import TextSlot from rasa.shared.core.trackers import DialogueStateTracker from rasa.dialogue_understanding.stack.frames import ( UserFlowStackFrame, @@ -328,3 +329,88 @@ def test_executor_does_not_get_tripped_if_an_action_is_predicted_in_loop(): selection = executor.select_next_action(tracker) assert selection.action_name == "action_listen" + + +def test_flow_policy_resets_all_slots_after_flow_ends() -> None: + flows = flows_from_str( + """ + flows: + foo_flow: + steps: + - id: "1" + collect: my_slot + - id: "2" + set_slots: + - foo: bar + - other_slot: other_value + - id: "3" + action: action_listen + """ + ) + tracker = DialogueStateTracker.from_events( + "test", + [ + SlotSet("my_slot", "my_value"), + SlotSet("foo", "bar"), + SlotSet("other_slot", "other_value"), + ActionExecuted("action_listen"), + ], + slots=[ + TextSlot("my_slot", mappings=[], initial_value="initial_value"), + TextSlot("foo", mappings=[]), + TextSlot("other_slot", mappings=[]), + ], + ) + + domain = Domain.empty() + executor = FlowExecutor.from_tracker(tracker, flows, domain) + + current_flow = flows.flow_by_id("foo_flow") + events = executor._reset_scoped_slots(current_flow, tracker) + assert events == [ + SlotSet("my_slot", "initial_value"), + SlotSet("foo", None), + SlotSet("other_slot", None), + ] + + +def test_flow_policy_set_slots_inherit_reset_from_collect_step() -> None: + """Test that `reset_after_flow_ends` is inherited from the collect step.""" + slot_name = "my_slot" + flows = flows_from_str( + f""" + flows: + foo_flow: + steps: + - id: "1" + collect: {slot_name} + reset_after_flow_ends: false + - id: "2" + set_slots: + - foo: bar + - {slot_name}: my_value + - id: "3" + action: action_listen + """ + ) + tracker = DialogueStateTracker.from_events( + "test123", + [ + SlotSet("my_slot", "my_value"), + SlotSet("foo", "bar"), + ActionExecuted("action_listen"), + ], + slots=[ + TextSlot("my_slot", mappings=[], initial_value="initial_value"), + TextSlot("foo", mappings=[]), + ], + ) + + domain = Domain.empty() + executor = FlowExecutor.from_tracker(tracker, flows, domain) + + current_flow = flows.flow_by_id("foo_flow") + events = executor._reset_scoped_slots(current_flow, tracker) + assert events == [ + SlotSet("foo", None), + ]