Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flow circuit breaker #12882

Merged
merged 6 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions rasa/core/policies/flow_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from structlog.contextvars import (
bound_contextvars,
)
from rasa.dialogue_understanding.patterns.internal_error import (
InternalErrorPatternFlowStackFrame,
)
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
from rasa.dialogue_understanding.stack.frames import (
BaseFlowStackFrame,
Expand All @@ -23,7 +26,10 @@
ContinueInterruptedPatternFlowStackFrame,
)
from rasa.dialogue_understanding.stack.frames.flow_stack_frame import FlowStackFrameType
from rasa.dialogue_understanding.stack.utils import top_user_flow_frame
from rasa.dialogue_understanding.stack.utils import (
end_top_user_flow,
top_user_flow_frame,
)

from rasa.core.constants import (
DEFAULT_POLICY_PRIORITY,
Expand Down Expand Up @@ -220,7 +226,18 @@ def predict_action_probabilities(
"There appears to be an infinite loop in the flows."
),
)
return self._prediction(self._default_predictions(domain))
# end the current flow and start the internal error flow
end_top_user_flow(executor.dialogue_stack)
executor.dialogue_stack.push(InternalErrorPatternFlowStackFrame())
# we retry, with the internal error frame on the stack
prediction = executor.advance_flows(tracker)
return self._create_prediction_result(
prediction.action_name,
domain,
prediction.score,
prediction.events,
prediction.metadata,
)

def _create_prediction_result(
self,
Expand Down
23 changes: 22 additions & 1 deletion rasa/dialogue_understanding/stack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from rasa.dialogue_understanding.stack.frames import BaseFlowStackFrame
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
from rasa.dialogue_understanding.stack.frames import UserFlowStackFrame
from rasa.shared.core.flows.flow import FlowsList
from rasa.dialogue_understanding.stack.frames.flow_stack_frame import FlowStackFrameType
from rasa.shared.core.flows.flow import END_STEP, ContinueFlowStep, FlowsList


def top_flow_frame(
Expand Down Expand Up @@ -106,3 +107,23 @@ def user_flows_on_the_stack(dialogue_stack: DialogueStack) -> Set[str]:
return {
f.flow_id for f in dialogue_stack.frames if isinstance(f, UserFlowStackFrame)
}


def end_top_user_flow(stack: DialogueStack) -> None:
"""Ends all frames on top of the stack including the topmost user frame.

Ends all flows until the next user flow is reached. This is useful
if you want to end all flows that are currently on the stack and
the user flow that triggered them.

Args:
stack: The dialogue stack.
"""

for frame in reversed(stack.frames):
if isinstance(frame, BaseFlowStackFrame):
frame.step_id = ContinueFlowStep.continue_step_for_id(END_STEP)
if isinstance(frame, UserFlowStackFrame):
# Making sure there are no "continue interrupts" triggered
frame.frame_type = FlowStackFrameType.REGULAR
tmbo marked this conversation as resolved.
Show resolved Hide resolved
break
28 changes: 21 additions & 7 deletions rasa/shared/importers/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,15 +405,19 @@ def load_default_pattern_flows_domain() -> Domain:

return Domain.from_path(default_flows_file)

@rasa.shared.utils.common.cached_method
def get_flows(self) -> FlowsList:
flows = self._importer.get_flows()
@classmethod
def merge_with_default_flows(cls, flows: FlowsList) -> FlowsList:
"""Merges the passed flows with the default flows.

if flows.is_empty():
# if there are no flows, we don't need to add the default flows either
return flows
If a user defined flow contains a flow with an id of a default flow,
it will overwrite the default flow.

Args:
flows: user defined flows.

default_flows = self.load_default_pattern_flows()
Returns:
Merged flows."""
default_flows = cls.load_default_pattern_flows()

user_flow_ids = [flow.id for flow in flows.underlying_flows]
missing_default_flows = [
Expand All @@ -424,6 +428,16 @@ def get_flows(self) -> FlowsList:

return flows.merge(FlowsList(missing_default_flows))

@rasa.shared.utils.common.cached_method
def get_flows(self) -> FlowsList:
flows = self._importer.get_flows()

if flows.is_empty():
# if there are no flows, we don't need to add the default flows either
return flows

return self.merge_with_default_flows(flows)
tmbo marked this conversation as resolved.
Show resolved Hide resolved

@rasa.shared.utils.common.cached_method
def get_domain(self) -> Domain:
"""Merge existing domain with properties of flows."""
Expand Down
86 changes: 85 additions & 1 deletion tests/cdu/stack/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
)
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
from rasa.dialogue_understanding.stack.frames.chit_chat_frame import ChitChatStackFrame
from rasa.dialogue_understanding.stack.frames.flow_stack_frame import UserFlowStackFrame
from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
FlowStackFrameType,
UserFlowStackFrame,
)
from rasa.dialogue_understanding.stack.utils import (
end_top_user_flow,
filled_slots_for_active_flow,
top_flow_frame,
top_user_flow_frame,
Expand Down Expand Up @@ -199,3 +203,83 @@ def test_filled_slots_for_active_flow_only_collects_till_top_most_user_flow_fram
stack = DialogueStack(frames=[another_user_frame, user_frame])

assert filled_slots_for_active_flow(stack, all_flows) == {"foo", "bar"}


def test_end_top_user_flow():
user_frame = UserFlowStackFrame(
flow_id="my_flow", step_id="collect_bar", frame_id="some-frame-id"
)
pattern_frame = CollectInformationPatternFlowStackFrame(
collect="foo", frame_id="some-other-id"
)
stack = DialogueStack(frames=[user_frame, pattern_frame])

end_top_user_flow(stack)

assert len(stack.frames) == 2

assert stack.frames[0] == UserFlowStackFrame(
flow_id="my_flow", step_id="NEXT:END", frame_id="some-frame-id"
)
assert stack.frames[1] == CollectInformationPatternFlowStackFrame(
collect="foo", frame_id="some-other-id", step_id="NEXT:END"
)


def test_end_top_user_flow_only_ends_topmost_user_frame():
user_frame = UserFlowStackFrame(
flow_id="my_flow", step_id="collect_bar", frame_id="some-frame-id"
)
other_user_frame = UserFlowStackFrame(
flow_id="my_other_flow", step_id="collect_bar2", frame_id="some-other-id"
)
stack = DialogueStack(frames=[other_user_frame, user_frame])

end_top_user_flow(stack)

assert len(stack.frames) == 2

assert stack.frames[0] == UserFlowStackFrame(
flow_id="my_other_flow", step_id="collect_bar2", frame_id="some-other-id"
)
assert stack.frames[1] == UserFlowStackFrame(
flow_id="my_flow", step_id="NEXT:END", frame_id="some-frame-id"
)


def test_end_top_user_flow_handles_interrupt_frames():
user_frame = UserFlowStackFrame(
flow_id="my_flow",
step_id="collect_bar",
frame_id="some-frame-id",
frame_type=FlowStackFrameType.INTERRUPT,
)
other_user_frame = UserFlowStackFrame(
flow_id="my_other_flow", step_id="collect_bar2", frame_id="some-other-id"
)
stack = DialogueStack(frames=[other_user_frame, user_frame])

end_top_user_flow(stack)

assert len(stack.frames) == 2

assert stack.frames[0] == UserFlowStackFrame(
flow_id="my_other_flow",
step_id="collect_bar2",
frame_id="some-other-id",
frame_type=FlowStackFrameType.REGULAR,
)

assert stack.frames[1] == UserFlowStackFrame(
flow_id="my_flow",
step_id="NEXT:END",
frame_id="some-frame-id",
frame_type=FlowStackFrameType.REGULAR,
)
tmbo marked this conversation as resolved.
Show resolved Hide resolved


def test_end_top_user_flow_handles_empty():
stack = DialogueStack(frames=[])
end_top_user_flow(stack)

assert len(stack.frames) == 0
59 changes: 58 additions & 1 deletion tests/core/policies/test_flow_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
UserFlowStackFrame,
SearchStackFrame,
)
from tests.utilities import flows_from_str
from tests.utilities import (
flows_default_domain,
flows_from_str,
flows_from_str_with_defaults,
)


@pytest.fixture()
Expand Down Expand Up @@ -237,6 +241,59 @@ def test_executor_trips_internal_circuit_breaker():
executor.select_next_action(tracker)


def test_policy_triggers_error_pattern_if_internal_circuit_breaker_is_tripped(
default_flow_policy: FlowPolicy,
):
flow_with_loop = flows_from_str_with_defaults(
"""
flows:
foo_flow:
steps:
- id: "1"
set_slot:
foo: bar
next: "2"
- id: "2"
set_slot:
foo: barbar
next: "1"
"""
)

domain = flows_default_domain()

stack = DialogueStack(
frames=[UserFlowStackFrame(flow_id="foo_flow", step_id="1", frame_id="some-id")]
)

tracker = DialogueStateTracker.from_events(
"test",
evts=[ActionExecuted(action_name="action_listen"), stack.persist_as_event()],
domain=domain,
slots=domain.slots,
)

prediction = default_flow_policy.predict_action_probabilities(
tracker=tracker, domain=domain, flows=flow_with_loop
)

assert prediction.max_confidence == 1.0

predicted_idx = prediction.max_confidence_index
assert domain.action_names_or_texts[predicted_idx] == "utter_internal_error_rasa"
# check that the stack was updated.
assert len(prediction.optional_events) == 1
assert isinstance(prediction.optional_events[0], SlotSet)

assert prediction.optional_events[0].key == "dialogue_stack"
# the user flow should be on the stack as well as the error pattern
assert len(prediction.optional_events[0].value) == 2
# the user flow should be about to end
assert prediction.optional_events[0].value[0]["step_id"] == "NEXT:END"
# the pattern should be the other frame
assert prediction.optional_events[0].value[1]["flow_id"] == "pattern_internal_error"


def test_executor_does_not_get_tripped_if_an_action_is_predicted_in_loop():
flow_with_loop = flows_from_str(
"""
Expand Down
12 changes: 12 additions & 0 deletions tests/utilities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from yarl import URL
import textwrap
from rasa.shared.core.domain import Domain
from rasa.shared.core.flows.flow import FlowsList
from rasa.shared.core.flows.yaml_flows_io import YAMLFlowsReader
from rasa.shared.importers.importer import FlowSyncImporter


def latest_request(mocked, request_type, path):
Expand All @@ -15,3 +17,13 @@ def json_of_latest_request(r):
def flows_from_str(yaml_str: str) -> FlowsList:
"""Reads flows from a YAML string."""
return YAMLFlowsReader.read_from_string(textwrap.dedent(yaml_str))


def flows_from_str_with_defaults(yaml_str: str) -> FlowsList:
"""Reads flows from a YAML string and includes buildin flows."""
return FlowSyncImporter.merge_with_default_flows(flows_from_str(yaml_str))


def flows_default_domain() -> Domain:
"""Returns the default domain for the default flows."""
return FlowSyncImporter.load_default_pattern_flows_domain()
Loading