diff --git a/rasa/cdu/generator/command_generator.py b/rasa/cdu/generator/command_generator.py index 2a0e4b0e3e2a..5664f003931d 100644 --- a/rasa/cdu/generator/command_generator.py +++ b/rasa/cdu/generator/command_generator.py @@ -1,10 +1,13 @@ from typing import List, Optional +import structlog from rasa.cdu.commands import Command from rasa.shared.core.flows.flow import FlowsList from rasa.shared.core.trackers import DialogueStateTracker from rasa.shared.nlu.training_data.message import Message from rasa.shared.nlu.constants import COMMANDS +structlogger = structlog.get_logger() + class CommandGenerator: """A command generator. @@ -33,7 +36,13 @@ def process( The processed messages (usually this is just one during prediction). """ for message in messages: - commands = self.predict_commands(message, flows, tracker) + try: + commands = self.predict_commands(message, flows, tracker) + except Exception as e: + if isinstance(e, NotImplementedError): + raise e + structlogger.error("command_generator.predict.error", error=e) + commands = [] commands_dicts = [command.as_dict() for command in commands] message.set(COMMANDS, commands_dicts, add_to_output=True) return messages diff --git a/tests/cdu/generator/__init__.py b/tests/cdu/generator/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/cdu/generator/test_command_generator.py b/tests/cdu/generator/test_command_generator.py new file mode 100644 index 000000000000..d72b628901db --- /dev/null +++ b/tests/cdu/generator/test_command_generator.py @@ -0,0 +1,33 @@ +from typing import Optional, List + +from rasa.cdu.commands import Command +from rasa.cdu.generator.command_generator import CommandGenerator +from rasa.cdu.commands.chit_chat_answer_command import ChitChatAnswerCommand +from rasa.shared.core.flows.flow import FlowsList +from rasa.shared.core.trackers import DialogueStateTracker +from rasa.shared.nlu.constants import TEXT, COMMANDS +from rasa.shared.nlu.training_data.message import Message + + +class WackyCommandGenerator(CommandGenerator): + def predict_commands( + self, + message: Message, + flows: FlowsList, + tracker: Optional[DialogueStateTracker] = None, + ) -> List[Command]: + if message.get(TEXT) == "Hi": + raise ValueError("Message too banal - I am quitting.") + else: + return [ChitChatAnswerCommand()] + + +def test_command_generator_catches_processing_errors(): + generator = WackyCommandGenerator() + messages = [Message.build("Hi"), Message.build("What is your purpose?")] + generator.process(messages, FlowsList([])) + commands = [m.get(COMMANDS) for m in messages] + + assert len(commands[0]) == 0 + assert len(commands[1]) == 1 + assert commands[1][0]["command"] == ChitChatAnswerCommand.command()