diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py index d4c4088c9d7a..c6533a4a341e 100644 --- a/rasa/dialogue_understanding/generator/llm_command_generator.py +++ b/rasa/dialogue_understanding/generator/llm_command_generator.py @@ -4,6 +4,7 @@ from jinja2 import Template import structlog + from rasa.dialogue_understanding.stack.utils import top_flow_frame from rasa.dialogue_understanding.generator import CommandGenerator from rasa.dialogue_understanding.commands import ( @@ -47,9 +48,6 @@ "rasa.dialogue_understanding.generator", "command_prompt_template.jinja2" ) -structlogger = structlog.get_logger() - - DEFAULT_LLM_CONFIG = { "_type": "openai", "request_timeout": 7, @@ -59,6 +57,8 @@ LLM_CONFIG_KEY = "llm" +structlogger = structlog.get_logger() + @DefaultV1Recipe.register( [ @@ -67,6 +67,11 @@ is_trainable=True, ) class LLMCommandGenerator(GraphComponent, CommandGenerator): + """An LLM based command generator. + + # TODO: add description to the docstring. + + """ @staticmethod def get_default_config() -> Dict[str, Any]: """The component's default config (see parent class for full docstring).""" @@ -142,6 +147,7 @@ def predict_commands( flows: FlowsList, tracker: Optional[DialogueStateTracker] = None, ) -> List[Command]: + """TODO: add docstring""" if tracker is None or flows.is_empty(): # cannot do anything if there are no flows or no tracker return [] @@ -164,6 +170,7 @@ def predict_commands( @staticmethod def is_none_value(value: str) -> bool: + """TODO: add docstring""" return value in { "[missing information]", "[missing]", @@ -217,7 +224,9 @@ def coerce_slot_value( def parse_commands( cls, actions: Optional[str], tracker: DialogueStateTracker ) -> List[Command]: - """Parse the actions returned by the llm into intent and entities.""" + """Parse the actions returned by the llm into intent and entities. + #TODO: add arguments and returns. + """ if not actions: return [ErrorCommand()] @@ -267,6 +276,7 @@ def parse_commands( def create_template_inputs( cls, flows: FlowsList, tracker: DialogueStateTracker ) -> List[Dict[str, Any]]: + """TODO: add docstring.""" result = [] for flow in flows.underlying_flows: # TODO: check if we should filter more flows; e.g. flows that are @@ -337,6 +347,7 @@ def slot_value(tracker: DialogueStateTracker, slot_name: str) -> str: def render_template( self, message: Message, tracker: DialogueStateTracker, flows: FlowsList ) -> str: + """TODO: add docstring""" flows_without_patterns = FlowsList( [f for f in flows.underlying_flows if not f.is_handling_pattern()] ) diff --git a/tests/cdu/__init__.py b/tests/dialogue_understanding/__init__.py similarity index 100% rename from tests/cdu/__init__.py rename to tests/dialogue_understanding/__init__.py diff --git a/tests/cdu/commands/__init__.py b/tests/dialogue_understanding/commands/__init__.py similarity index 100% rename from tests/cdu/commands/__init__.py rename to tests/dialogue_understanding/commands/__init__.py diff --git a/tests/cdu/commands/test_can_not_handle_command.py b/tests/dialogue_understanding/commands/test_can_not_handle_command.py similarity index 100% rename from tests/cdu/commands/test_can_not_handle_command.py rename to tests/dialogue_understanding/commands/test_can_not_handle_command.py diff --git a/tests/cdu/commands/test_cancel_flow_command.py b/tests/dialogue_understanding/commands/test_cancel_flow_command.py similarity index 100% rename from tests/cdu/commands/test_cancel_flow_command.py rename to tests/dialogue_understanding/commands/test_cancel_flow_command.py diff --git a/tests/cdu/commands/test_chit_chat_answer_command.py b/tests/dialogue_understanding/commands/test_chit_chat_answer_command.py similarity index 100% rename from tests/cdu/commands/test_chit_chat_answer_command.py rename to tests/dialogue_understanding/commands/test_chit_chat_answer_command.py diff --git a/tests/cdu/commands/test_clarify_command.py b/tests/dialogue_understanding/commands/test_clarify_command.py similarity index 100% rename from tests/cdu/commands/test_clarify_command.py rename to tests/dialogue_understanding/commands/test_clarify_command.py diff --git a/tests/cdu/commands/test_command.py b/tests/dialogue_understanding/commands/test_command.py similarity index 100% rename from tests/cdu/commands/test_command.py rename to tests/dialogue_understanding/commands/test_command.py diff --git a/tests/cdu/commands/test_correct_slots_command.py b/tests/dialogue_understanding/commands/test_correct_slots_command.py similarity index 100% rename from tests/cdu/commands/test_correct_slots_command.py rename to tests/dialogue_understanding/commands/test_correct_slots_command.py diff --git a/tests/cdu/commands/test_error_command.py b/tests/dialogue_understanding/commands/test_error_command.py similarity index 100% rename from tests/cdu/commands/test_error_command.py rename to tests/dialogue_understanding/commands/test_error_command.py diff --git a/tests/cdu/commands/test_human_handoff_command.py b/tests/dialogue_understanding/commands/test_human_handoff_command.py similarity index 100% rename from tests/cdu/commands/test_human_handoff_command.py rename to tests/dialogue_understanding/commands/test_human_handoff_command.py diff --git a/tests/cdu/commands/test_konwledge_answer_command.py b/tests/dialogue_understanding/commands/test_konwledge_answer_command.py similarity index 100% rename from tests/cdu/commands/test_konwledge_answer_command.py rename to tests/dialogue_understanding/commands/test_konwledge_answer_command.py diff --git a/tests/cdu/commands/test_set_slot_command.py b/tests/dialogue_understanding/commands/test_set_slot_command.py similarity index 100% rename from tests/cdu/commands/test_set_slot_command.py rename to tests/dialogue_understanding/commands/test_set_slot_command.py diff --git a/tests/cdu/commands/test_start_flow_command.py b/tests/dialogue_understanding/commands/test_start_flow_command.py similarity index 100% rename from tests/cdu/commands/test_start_flow_command.py rename to tests/dialogue_understanding/commands/test_start_flow_command.py diff --git a/tests/cdu/generator/__init__.py b/tests/dialogue_understanding/generator/__init__.py similarity index 100% rename from tests/cdu/generator/__init__.py rename to tests/dialogue_understanding/generator/__init__.py diff --git a/tests/cdu/generator/test_command_generator.py b/tests/dialogue_understanding/generator/test_command_generator.py similarity index 100% rename from tests/cdu/generator/test_command_generator.py rename to tests/dialogue_understanding/generator/test_command_generator.py diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py new file mode 100644 index 000000000000..ee807f82c6e9 --- /dev/null +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -0,0 +1,242 @@ +from unittest.mock import Mock, patch + +import pytest +from langchain.llms.fake import FakeListLLM +from structlog.testing import capture_logs + +from rasa.dialogue_understanding.generator.llm_command_generator import LLMCommandGenerator +from rasa.dialogue_understanding.commands import ( + # Command, + ErrorCommand, + SetSlotCommand, + CancelFlowCommand, + StartFlowCommand, + HumanHandoffCommand, + ChitChatAnswerCommand, + KnowledgeAnswerCommand, + ClarifyCommand, +) +from rasa.engine.graph import ExecutionContext +from rasa.engine.storage.resource import Resource +from rasa.engine.storage.storage import ModelStorage +from rasa.shared.core.slots import BooleanSlot, FloatSlot, TextSlot +from rasa.shared.core.trackers import DialogueStateTracker + + +class TestLLMCommandGenerator: + """Tests for the LLMCommandGenerator.""" + + @pytest.fixture + def command_generator(self): + """Create an LLMCommandGenerator.""" + return LLMCommandGenerator.create( + config={}, resource=Mock(), model_storage=Mock(), execution_context=Mock()) + + @pytest.fixture + def mock_command_generator( + self, + default_model_storage: ModelStorage, + default_execution_context: ExecutionContext, + ) -> LLMCommandGenerator: + """Create a patched LLMCommandGenerator.""" + with patch( + "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", + Mock(return_value=FakeListLLM(responses=["StartFlow(check_balance)"])), + ) as mock_llm: + return LLMCommandGenerator.create( + config=LLMCommandGenerator.get_default_config(), + model_storage=default_model_storage, + resource=Resource("llmcommandgenerator"), + execution_context=default_execution_context) + + def test_predict_commands_with_no_flows(self, mock_command_generator: LLMCommandGenerator): + """Test that predict_commands returns an empty list when flows is None.""" + # When + predicted_commands = mock_command_generator.predict_commands(Mock(), flows=None, tracker=Mock()) + # Then + assert not predicted_commands + + def test_predict_commands_with_no_tracker(self, mock_command_generator: LLMCommandGenerator): + """Test that predict_commands returns an empty list when tracker is None.""" + # When + predicted_commands = mock_command_generator.predict_commands(Mock(), flows=Mock(), tracker=None) + # Then + assert not predicted_commands + + @patch.object(LLMCommandGenerator, "render_template", Mock(return_value="some prompt")) + @patch.object(LLMCommandGenerator, "parse_commands", Mock()) + def test_predict_commands_calls_llm_correctly(self, command_generator: LLMCommandGenerator): + """Test that predict_commands calls llm correctly.""" + # When + mock_llm = Mock() + with patch( + "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", + Mock(return_value=mock_llm), + ): + command_generator.predict_commands(Mock(), flows=Mock(), tracker=Mock()) + # Then + mock_llm.assert_called_once_with("some prompt") + + @patch.object(LLMCommandGenerator, "render_template", Mock(return_value="some prompt")) + @patch.object(LLMCommandGenerator, "parse_commands", Mock()) + def test_generate_action_list_catches_llm_exception(self, command_generator: LLMCommandGenerator): + """Test that predict_commands calls llm correctly.""" + # Given + mock_llm = Mock(side_effect=Exception("some exception")) + with patch( + "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory", + Mock(return_value=mock_llm), + ): + # When + with capture_logs() as logs: + command_generator.predict_commands(Mock(), flows=Mock(), tracker=Mock()) + # Then + print(logs) + assert len(logs) == 4 + assert isinstance(logs[1]["error"]) == isinstance(Exception("some exception")) + + + + def test_render_template(self, mock_command_generator: LLMCommandGenerator): + """Test that render_template renders a template.""" + pass + # # Given + # message = Mock() + + # tracker = Mock() + + # flows = Mock() + # # When + # rendered_template = command_generator.render_template() + + # # Then + # assert rendered_template == "template" + + # def test_generate_action_list_calls_llm_with_correct_promt(self): + # # Given + # prompt = "some prompt" + # with patch( + # "rasa.rasa.shared.utils.llm.llm_factory", + # Mock(return_value=FakeListLLM(responses=["hello"])) + # ) as mock_llm: + # LLMCommandGenerator._generate_action_list(prompt) + # mock_llm.assert_called_once_with(prompt) + + @pytest.mark.parametrize( + "input_action, expected_command", + [ + ( + None, + [ErrorCommand()] + ), + ( + "SetSlot(transfer_money_amount_of_money, )", + [SetSlotCommand(name="transfer_money_amount_of_money", value=None)] + ), + ( + "SetSlot(flow_name, some_flow)", + [StartFlowCommand(flow="some_flow")] + ), + ( + "StartFlow(check_balance)", + [StartFlowCommand(flow="check_balance")] + ), + ( + "CancelFlow()", + [CancelFlowCommand()] + ), + ( + "ChitChat()", + [ChitChatAnswerCommand()] + ), + ( + "SearchAndReply()", + [KnowledgeAnswerCommand()] + ), + ( + "HumanHandoff()", + [HumanHandoffCommand()] + ), + ( + "Clarify(transfer_money)", + [ClarifyCommand(options=["transfer_money"])] + ), + ( + "Clarify(list_contacts, add_contact, remove_contact)", + [ClarifyCommand(options=["list_contacts", "add_contact", "remove_contact"])] + ), + ]) + def test_parse_commands_identifies_correct_command(self, input_action, expected_command): + """Test that parse_commands identifies the correct commands.""" + # When + with patch.object(LLMCommandGenerator, "coerce_slot_value", Mock(return_value=None)): + parsed_commands = LLMCommandGenerator.parse_commands(input_action, Mock()) + # Then + assert parsed_commands == expected_command + + @pytest.mark.parametrize( + "slot_name, slot, slot_value, expected_coerced_value", + [ + ("some_other_slot", FloatSlot("some_float", []), None, None), + ("some_float", FloatSlot("some_float", []), 40, 40.0), + ("some_float", FloatSlot("some_float", []), 40.0, 40.0), + ("some_text", TextSlot("some_text", []),"fourty", "fourty"), + ("some_bool", BooleanSlot("some_bool", []), "True", True), + ("some_bool", BooleanSlot("some_bool", []), "false", False) + ]) + def test_coerce_slot_value(self, slot_name, slot, slot_value, expected_coerced_value): + """Test that coerce_slot_value coerces the slot value correctly.""" + # Given + tracker = DialogueStateTracker.from_events( + "test", + evts=[], + slots=[slot] + ) + # When + coerced_value = LLMCommandGenerator.coerce_slot_value(slot_value, slot_name, tracker) + # Then + assert coerced_value == expected_coerced_value + + @pytest.mark.parametrize( + "input_string, expected_string", + [ + ("text", "text"), + (" text ", "text"), + ("\"text\"", "text"), + ("'text'", "text"), + ("' \"text' \" ", "text"), + ("", "") + ]) + def test_clean_extracted_value(self, input_string, expected_string): + """Test that clean_extracted_value removes the leading and trailing whitespaces.""" + # When + cleaned_extracted_value = LLMCommandGenerator.clean_extracted_value(input_string) + # Then + assert cleaned_extracted_value == expected_string + + + + + + + + + + + + # def test_allowd_values_for_slot(self, command_generator): + # """Test that allowed_values_for_slot returns the allowed values for a slot.""" + # # When + # allowed_values = command_generator.allowed_values_for_slot("slot_name") + + # # Then + # assert allowed_values == [] + + # @pytest.mark.parametrize("input_value, expected_truthiness", + # [(None, True), + # ("", False), + + # )] + # def test_is_none_value(self): + # """Test that is_none_value returns True when the value is None.""" + # assert LLMCommandGenerator.is_none_value(None) diff --git a/tests/cdu/stack/__init__.py b/tests/dialogue_understanding/stack/__init__.py similarity index 100% rename from tests/cdu/stack/__init__.py rename to tests/dialogue_understanding/stack/__init__.py diff --git a/tests/cdu/stack/frames/__init__.py b/tests/dialogue_understanding/stack/frames/__init__.py similarity index 100% rename from tests/cdu/stack/frames/__init__.py rename to tests/dialogue_understanding/stack/frames/__init__.py diff --git a/tests/cdu/stack/frames/test_chit_chat_frame.py b/tests/dialogue_understanding/stack/frames/test_chit_chat_frame.py similarity index 100% rename from tests/cdu/stack/frames/test_chit_chat_frame.py rename to tests/dialogue_understanding/stack/frames/test_chit_chat_frame.py diff --git a/tests/cdu/stack/frames/test_dialogue_stack_frame.py b/tests/dialogue_understanding/stack/frames/test_dialogue_stack_frame.py similarity index 100% rename from tests/cdu/stack/frames/test_dialogue_stack_frame.py rename to tests/dialogue_understanding/stack/frames/test_dialogue_stack_frame.py diff --git a/tests/cdu/stack/frames/test_flow_frame.py b/tests/dialogue_understanding/stack/frames/test_flow_frame.py similarity index 100% rename from tests/cdu/stack/frames/test_flow_frame.py rename to tests/dialogue_understanding/stack/frames/test_flow_frame.py diff --git a/tests/cdu/stack/frames/test_search_frame.py b/tests/dialogue_understanding/stack/frames/test_search_frame.py similarity index 100% rename from tests/cdu/stack/frames/test_search_frame.py rename to tests/dialogue_understanding/stack/frames/test_search_frame.py diff --git a/tests/cdu/stack/test_dialogue_stack.py b/tests/dialogue_understanding/stack/test_dialogue_stack.py similarity index 100% rename from tests/cdu/stack/test_dialogue_stack.py rename to tests/dialogue_understanding/stack/test_dialogue_stack.py diff --git a/tests/cdu/stack/test_utils.py b/tests/dialogue_understanding/stack/test_utils.py similarity index 100% rename from tests/cdu/stack/test_utils.py rename to tests/dialogue_understanding/stack/test_utils.py