Skip to content

Commit

Permalink
WIP: added tests for contains_command, get_commands, validate_state_o…
Browse files Browse the repository at this point in the history
…f_commands.
  • Loading branch information
djcowley committed Oct 19, 2023
1 parent 5569080 commit ee0ece6
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rasa.shared.core.flows.flow import FlowsList
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.flows.yaml_flows_io import flows_from_str
from tests.dialogue_understanding.commands.conftest import start_bar_user_uttered
from tests.dialogue_understanding.conftest import start_bar_user_uttered


def test_properly_prepared_tracker(tracker: DialogueStateTracker):
Expand Down
File renamed without changes.
75 changes: 62 additions & 13 deletions tests/dialogue_understanding/processor/test_command_processor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import pytest

from rasa.dialogue_understanding.commands import SetSlotCommand, StartFlowCommand
from rasa.dialogue_understanding.commands import (
CancelFlowCommand,
CorrectSlotsCommand,
FreeFormAnswerCommand,
SetSlotCommand,
StartFlowCommand,
)
from rasa.dialogue_understanding.processor.command_processor import (
contains_command,
_get_commands_from_tracker,
contains_command,
find_updated_flows,
validate_state_of_commands,
)
from rasa.shared.core.events import SlotSet, UserUttered
from rasa.shared.core.flows.flow import FlowsList
from rasa.shared.core.trackers import DialogueStateTracker


Expand All @@ -30,18 +39,58 @@ def test_contains_command(commands, command_type, expected_result):
assert result == expected_result


def test_get_commands_from_tracker():
def test_get_commands_from_tracker(tracker: DialogueStateTracker):
"""Test if commands are correctly extracted from tracker."""
# Given
tracker = DialogueStateTracker.from_events(
"test",
evts=[
UserUttered("hi", {"name": "greet"}),
],
)
# use the conftest.py written by thomas in stack clean up pr.
# When
commands = _get_commands_from_tracker(tracker)
# Then
assert len(commands) == 2
assert isinstance(commands[0], SetSlotCommand)
assert isinstance(commands[0], StartFlowCommand)
assert commands[0].command() == "start flow"
assert commands[0].flow == "foo"


@pytest.mark.parametrize(
"commands",
[
[CancelFlowCommand()],
[StartFlowCommand("flow_name")],
[SetSlotCommand("slot_name", "slot_value")],
[StartFlowCommand("flow_name"), SetSlotCommand("slot_name", "slot_value")],
[FreeFormAnswerCommand(), SetSlotCommand("slot_name", "slot_value")],
[
FreeFormAnswerCommand(),
FreeFormAnswerCommand(),
StartFlowCommand("flow_name"),
],
[CorrectSlotsCommand([])],
[CorrectSlotsCommand([]), StartFlowCommand("flow_name")],
],
)
def test_validate_state_of_commands(commands):
"""Test if commands are correctly validated."""
# Then
validate_state_of_commands(commands)
# No exception should be raised


@pytest.mark.parametrize(
"commands",
[
[CancelFlowCommand(), CancelFlowCommand()],
[StartFlowCommand("flow_name"), FreeFormAnswerCommand()],
[CorrectSlotsCommand([]), CorrectSlotsCommand([])],
],
)
def test_validate_state_of_commands_raises_exception(commands):
"""Test if commands are correctly validated."""
# Then
with pytest.raises(AssertionError):
validate_state_of_commands(commands)


def test_find_updated_flows(tracker: DialogueStateTracker, all_flows: FlowsList):
"""Test if updated flows are correctly found."""
# When
updated_flows = find_updated_flows(tracker, all_flows)
# Then
assert updated_flows == {"foo"}

0 comments on commit ee0ece6

Please sign in to comment.