Skip to content

Commit

Permalink
catching any command generator exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
twerkmeister committed Sep 14, 2023
1 parent 1425286 commit 03aaf6c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
11 changes: 10 additions & 1 deletion rasa/cdu/generator/command_generator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import List, Optional
import structlog
from rasa.cdu.commands import Command
from rasa.shared.core.flows.flow import FlowsList
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.constants import COMMANDS

structlogger = structlog.get_logger()


class CommandGenerator:
"""A command generator.
Expand Down Expand Up @@ -33,7 +36,13 @@ def process(
The processed messages (usually this is just one during prediction).
"""
for message in messages:
commands = self.predict_commands(message, flows, tracker)
try:
commands = self.predict_commands(message, flows, tracker)
except Exception as e:
if isinstance(e, NotImplementedError):
raise e
structlogger.error("command_generator.predict.error", error=e)
commands = []
commands_dicts = [command.as_dict() for command in commands]
message.set(COMMANDS, commands_dicts, add_to_output=True)
return messages
Expand Down
Empty file added tests/cdu/generator/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions tests/cdu/generator/test_command_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Optional, List

from rasa.cdu.commands import Command
from rasa.cdu.generator.command_generator import CommandGenerator
from rasa.cdu.commands.chit_chat_answer_command import ChitChatAnswerCommand
from rasa.shared.core.flows.flow import FlowsList
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.nlu.constants import TEXT, COMMANDS
from rasa.shared.nlu.training_data.message import Message


class WackyCommandGenerator(CommandGenerator):
def predict_commands(
self,
message: Message,
flows: FlowsList,
tracker: Optional[DialogueStateTracker] = None,
) -> List[Command]:
if message.get(TEXT) == "Hi":
raise ValueError("Message too banal - I am quitting.")
else:
return [ChitChatAnswerCommand()]


def test_command_generator_catches_processing_errors():
generator = WackyCommandGenerator()
messages = [Message.build("Hi"), Message.build("What is your purpose?")]
generator.process(messages, FlowsList([]))
commands = [m.get(COMMANDS) for m in messages]

assert len(commands[0]) == 0
assert len(commands[1]) == 1
assert commands[1][0]["command"] == ChitChatAnswerCommand.command()

0 comments on commit 03aaf6c

Please sign in to comment.