Skip to content

Commit

Permalink
wrap up and tests for flow trigger action
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Oct 24, 2023
1 parent f71a675 commit 22a66ef
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 24 deletions.
4 changes: 2 additions & 2 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,9 @@ def action_for_name_or_text(
return FormAction(action_name_or_text, action_endpoint)

if action_name_or_text.startswith(FLOW_PREFIX):
from rasa.core.actions.flows import FlowTriggerAction
from rasa.core.actions.action_trigger_flow import ActionTriggerFlow

return FlowTriggerAction(action_name_or_text)
return ActionTriggerFlow(action_name_or_text)
return RemoteAction(action_name_or_text, action_endpoint)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,54 +22,81 @@
structlogger = structlog.get_logger(__name__)


class FlowTriggerAction(action.Action):
"""Action which implements and executes the form logic."""
class ActionTriggerFlow(action.Action):
"""Action which triggers a flow by putting it on the dialogue stack."""

def __init__(self, flow_action_name: Text) -> None:
"""Creates a `FlowTriggerAction`.
"""Creates a `ActionTriggerFlow`.
Args:
flow_action_name: Name of the flow.
"""
super().__init__()

if not flow_action_name.startswith(FLOW_PREFIX):
raise ValueError(
f"Flow action name '{flow_action_name}' needs to start with "
f"'{FLOW_PREFIX}'."
)

self._flow_name = flow_action_name[len(FLOW_PREFIX) :]
self._flow_action_name = flow_action_name

def name(self) -> Text:
"""Return the flow name."""
return self._flow_action_name

async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Trigger the flow."""
def create_event_to_start_flow(self, tracker: DialogueStateTracker) -> Event:
"""Create an event to start the flow.
Args:
tracker: The tracker to start the flow on.
Returns:
The event to start the flow."""
stack = DialogueStack.from_tracker(tracker)
if not stack.is_empty():
frame_type = FlowStackFrameType.INTERRUPT
else:
frame_type = FlowStackFrameType.REGULAR
frame_type = (
FlowStackFrameType.REGULAR
if stack.is_empty()
else FlowStackFrameType.INTERRUPT
)

stack.push(
UserFlowStackFrame(
flow_id=self._flow_name,
frame_type=frame_type,
)
)
return stack.persist_as_event()

def create_events_to_set_flow_slots(self, metadata: Dict[str, Any]) -> List[Event]:
"""Create events to set the flow slots.
Set additional slots to prefill information for the flow.
Args:
metadata: The metadata to set the slots from.
Returns:
The events to set the flow slots.
"""
slots_to_be_set = metadata.get("slots", {}) if metadata else {}
slot_set_events: List[Event] = [
SlotSet(key, value) for key, value in slots_to_be_set.items()
]
return [SlotSet(key, value) for key, value in slots_to_be_set.items()]

async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Trigger the flow."""
events: List[Event] = [self.create_event_to_start_flow(tracker)]
events.extend(self.create_events_to_set_flow_slots(metadata))

events: List[Event] = [
stack.persist_as_event(),
] + slot_set_events
if tracker.active_loop_name:
# end any active loop to ensure we are progressing the started flow
events.append(ActiveLoop(None))

return events
91 changes: 91 additions & 0 deletions tests/core/actions/test_action_trigger_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import pytest
from rasa.core.actions.action_trigger_flow import ActionTriggerFlow
from rasa.core.channels import CollectingOutputChannel
from rasa.core.nlg import TemplatedNaturalLanguageGenerator
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
FlowStackFrameType,
UserFlowStackFrame,
)
from rasa.shared.core.constants import DIALOGUE_STACK_SLOT
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import ActiveLoop, SlotSet
from rasa.shared.core.trackers import DialogueStateTracker


async def test_action_trigger_flow():
tracker = DialogueStateTracker.from_events("test", [])
action = ActionTriggerFlow("flow_foo")
channel = CollectingOutputChannel()
nlg = TemplatedNaturalLanguageGenerator({})
events = await action.run(channel, nlg, tracker, Domain.empty())
assert len(events) == 1
event = events[0]
assert isinstance(event, SlotSet)
assert event.key == DIALOGUE_STACK_SLOT
assert len(event.value) == 1
assert event.value[0]["type"] == UserFlowStackFrame.type()
assert event.value[0]["flow_id"] == "foo"


async def test_action_trigger_flow_with_slots():
tracker = DialogueStateTracker.from_events("test", [])
action = ActionTriggerFlow("flow_foo")
channel = CollectingOutputChannel()
nlg = TemplatedNaturalLanguageGenerator({})
events = await action.run(
channel, nlg, tracker, Domain.empty(), metadata={"slots": {"foo": "bar"}}
)

event = events[0]
assert isinstance(event, SlotSet)
assert event.key == DIALOGUE_STACK_SLOT
assert len(event.value) == 1
assert event.value[0]["type"] == UserFlowStackFrame.type()
assert event.value[0]["flow_id"] == "foo"

assert len(events) == 2
event = events[1]
assert isinstance(event, SlotSet)
assert event.key == "foo"
assert event.value == "bar"


async def test_action_trigger_fails_if_name_is_invalid():
with pytest.raises(ValueError):
ActionTriggerFlow("foo")


async def test_action_trigger_ends_an_active_loop_on_the_tracker():
tracker = DialogueStateTracker.from_events("test", [ActiveLoop("loop_foo")])
action = ActionTriggerFlow("flow_foo")
channel = CollectingOutputChannel()
nlg = TemplatedNaturalLanguageGenerator({})
events = await action.run(channel, nlg, tracker, Domain.empty())

assert len(events) == 2
assert isinstance(events[1], ActiveLoop)
assert events[1].name is None


async def test_action_trigger_uses_interrupt_flow_type_if_stack_already_contains_flow():
user_frame = UserFlowStackFrame(
flow_id="my_flow", step_id="collect_bar", frame_id="some-frame-id"
)
stack = DialogueStack(frames=[user_frame])
tracker = DialogueStateTracker.from_events("test", [stack.persist_as_event()])

action = ActionTriggerFlow("flow_foo")
channel = CollectingOutputChannel()
nlg = TemplatedNaturalLanguageGenerator({})

events = await action.run(channel, nlg, tracker, Domain.empty())

assert len(events) == 1
event = events[0]
assert isinstance(event, SlotSet)
assert event.key == DIALOGUE_STACK_SLOT
assert len(event.value) == 2
assert event.value[1]["type"] == UserFlowStackFrame.type()
assert event.value[1]["flow_id"] == "foo"
assert event.value[1]["frame_type"] == FlowStackFrameType.INTERRUPT.value

0 comments on commit 22a66ef

Please sign in to comment.