diff --git a/rasa/core/actions/flow_trigger_action.py b/rasa/core/actions/flow_trigger_action.py index 82ee017679bf..7b4271e27c50 100644 --- a/rasa/core/actions/flow_trigger_action.py +++ b/rasa/core/actions/flow_trigger_action.py @@ -10,9 +10,6 @@ from rasa.core.channels import OutputChannel from rasa.shared.constants import FLOW_PREFIX -from rasa.shared.core.constants import ( - DIALOGUE_STACK_SLOT, -) from rasa.shared.core.domain import Domain from rasa.shared.core.events import ( ActiveLoop, @@ -70,7 +67,7 @@ async def run( ] events: List[Event] = [ - SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict()) + stack.persist_as_event(), ] + slot_set_events if tracker.active_loop_name: events.append(ActiveLoop(None)) diff --git a/rasa/core/policies/flow_policy.py b/rasa/core/policies/flow_policy.py index 99d2937d3159..9fc23ee73986 100644 --- a/rasa/core/policies/flow_policy.py +++ b/rasa/core/policies/flow_policy.py @@ -7,6 +7,9 @@ from structlog.contextvars import ( bound_contextvars, ) +from rasa.dialogue_understanding.patterns.internal_error import ( + InternalErrorPatternFlowStackFrame, +) from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack from rasa.dialogue_understanding.stack.frames import ( BaseFlowStackFrame, @@ -23,7 +26,10 @@ ContinueInterruptedPatternFlowStackFrame, ) from rasa.dialogue_understanding.stack.frames.flow_stack_frame import FlowStackFrameType -from rasa.dialogue_understanding.stack.utils import top_user_flow_frame +from rasa.dialogue_understanding.stack.utils import ( + end_top_user_flow, + top_user_flow_frame, +) from rasa.core.constants import ( DEFAULT_POLICY_PRIORITY, @@ -36,7 +42,6 @@ from rasa.shared.core.constants import ( ACTION_LISTEN_NAME, ACTION_SEND_TEXT_NAME, - DIALOGUE_STACK_SLOT, ) from rasa.shared.core.events import Event, SlotSet from rasa.shared.core.flows.flow import ( @@ -61,7 +66,7 @@ StaticFlowLink, ) from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer -from rasa.core.policies.policy import Policy, PolicyPrediction, SupportedData +from rasa.core.policies.policy import Policy, PolicyPrediction from rasa.engine.graph import ExecutionContext from rasa.engine.recipes.default_recipe import DefaultV1Recipe from rasa.engine.storage.resource import Resource @@ -73,15 +78,39 @@ ) import structlog +from rasa.shared.exceptions import RasaException + structlogger = structlog.get_logger() +MAX_NUMBER_OF_STEPS = 250 -class FlowException(Exception): + +class FlowException(RasaException): """Exception that is raised when there is a problem with a flow.""" pass +class FlowCircuitBreakerTrippedException(FlowException): + """Exception that is raised when there is a problem with a flow.""" + + def __init__( + self, dialogue_stack: DialogueStack, number_of_steps_taken: int + ) -> None: + """Creates a `FlowCircuitBreakerTrippedException`. + + Args: + dialogue_stack: The dialogue stack. + number_of_steps_taken: The number of steps that were taken. + """ + super().__init__( + f"Flow circuit breaker tripped after {number_of_steps_taken} steps. " + "There appears to be an infinite loop in the flows." + ) + self.dialogue_stack = dialogue_stack + self.number_of_steps_taken = number_of_steps_taken + + @DefaultV1Recipe.register( DefaultV1Recipe.ComponentType.POLICY_WITHOUT_END_TO_END_SUPPORT, is_trainable=False ) @@ -94,7 +123,15 @@ class FlowPolicy(Policy): @staticmethod def does_support_stack_frame(frame: DialogueStackFrame) -> bool: - """Checks if the policy supports the given stack frame.""" + """Checks if the policy supports the topmost frame on the dialogue stack. + + If `False` is returned, the policy will abstain from making a prediction. + + Args: + frame: The frame to check. + + Returns: + `True` if the policy supports the frame, `False` otherwise.""" return isinstance(frame, BaseFlowStackFrame) @staticmethod @@ -106,18 +143,6 @@ def get_default_config() -> Dict[Text, Any]: POLICY_MAX_HISTORY: None, } - @staticmethod - def supported_data() -> SupportedData: - """The type of data supported by this policy. - - By default, this is only ML-based training data. If policies support rule data, - or both ML-based data and rule data, they need to override this method. - - Returns: - The data type supported by this policy (ML-based training data). - """ - return SupportedData.ML_DATA - def __init__( self, config: Dict[Text, Any], @@ -150,9 +175,6 @@ def train( A policy must return its resource locator so that potential children nodes can load the policy from the resource. """ - # currently, nothing to do here. we have access to the flows during - # prediction. we might want to store the flows in the future - # or do some preprocessing here. return self.resource def predict_action_probabilities( @@ -178,20 +200,44 @@ def predict_action_probabilities( The prediction. """ if not self.supports_current_stack_frame(tracker): + # if the policy doesn't support the current stack frame, we'll abstain return self._prediction(self._default_predictions(domain)) flows = flows or FlowsList([]) executor = FlowExecutor.from_tracker(tracker, flows, domain) # create executor and predict next action - prediction = executor.advance_flows(tracker) - return self._create_prediction_result( - prediction.action_name, - domain, - prediction.score, - prediction.events, - prediction.metadata, - ) + try: + prediction = executor.advance_flows(tracker) + return self._create_prediction_result( + prediction.action_name, + domain, + prediction.score, + prediction.events, + prediction.metadata, + ) + except FlowCircuitBreakerTrippedException as e: + structlogger.error( + "flow.circuit_breaker", + dialogue_stack=e.dialogue_stack, + number_of_steps_taken=e.number_of_steps_taken, + event_info=( + "The flow circuit breaker tripped. " + "There appears to be an infinite loop in the flows." + ), + ) + # end the current flow and start the internal error flow + end_top_user_flow(executor.dialogue_stack) + executor.dialogue_stack.push(InternalErrorPatternFlowStackFrame()) + # we retry, with the internal error frame on the stack + prediction = executor.advance_flows(tracker) + return self._create_prediction_result( + prediction.action_name, + domain, + prediction.score, + prediction.events, + prediction.metadata, + ) def _create_prediction_result( self, @@ -425,20 +471,15 @@ def advance_flows(self, tracker: DialogueStateTracker) -> ActionPrediction: return ActionPrediction(None, 0.0) else: previous_stack = DialogueStack.get_persisted_stack(tracker) - prediction = self._select_next_action(tracker) + prediction = self.select_next_action(tracker) if previous_stack != self.dialogue_stack.as_dict(): # we need to update dialogue stack to persist the state of the executor if not prediction.events: prediction.events = [] - prediction.events.append( - SlotSet( - DIALOGUE_STACK_SLOT, - self.dialogue_stack.as_dict(), - ) - ) + prediction.events.append(self.dialogue_stack.persist_as_event()) return prediction - def _select_next_action( + def select_next_action( self, tracker: DialogueStateTracker, ) -> ActionPrediction: @@ -462,7 +503,16 @@ def _select_next_action( number_of_initial_events = len(tracker.events) + number_of_steps_taken = 0 + while isinstance(step_result, ContinueFlowWithNextStep): + + number_of_steps_taken += 1 + if number_of_steps_taken > MAX_NUMBER_OF_STEPS: + raise FlowCircuitBreakerTrippedException( + self.dialogue_stack, number_of_steps_taken + ) + active_frame = self.dialogue_stack.top() if not isinstance(active_frame, BaseFlowStackFrame): # If there is no current flow, we assume that all flows are done @@ -485,7 +535,7 @@ def _select_next_action( self._advance_top_flow_on_stack(current_step.id) with bound_contextvars(step_id=current_step.id): - step_result = self._run_step( + step_result = self.run_step( current_flow, current_step, tracker ) tracker.update_with_events(step_result.events, self.domain) @@ -521,7 +571,7 @@ def _reset_scoped_slots( events.append(SlotSet(step.collect, initial_value)) return events - def _run_step( + def run_step( self, flow: Flow, step: FlowStep, diff --git a/rasa/dialogue_understanding/commands/cancel_flow_command.py b/rasa/dialogue_understanding/commands/cancel_flow_command.py index 125365142713..904fdc775b82 100644 --- a/rasa/dialogue_understanding/commands/cancel_flow_command.py +++ b/rasa/dialogue_understanding/commands/cancel_flow_command.py @@ -9,8 +9,7 @@ from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack from rasa.dialogue_understanding.stack.frames import UserFlowStackFrame -from rasa.shared.core.constants import DIALOGUE_STACK_SLOT -from rasa.shared.core.events import Event, SlotSet +from rasa.shared.core.events import Event from rasa.shared.core.flows.flow import FlowsList from rasa.shared.core.trackers import DialogueStateTracker from rasa.dialogue_understanding.stack.utils import top_user_flow_frame @@ -103,4 +102,4 @@ def run_command_on_tracker( canceled_frames=canceled_frames, ) ) - return [SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict())] + return [stack.persist_as_event()] diff --git a/rasa/dialogue_understanding/commands/chit_chat_answer_command.py b/rasa/dialogue_understanding/commands/chit_chat_answer_command.py index 38be12d210b8..e0701bc530f4 100644 --- a/rasa/dialogue_understanding/commands/chit_chat_answer_command.py +++ b/rasa/dialogue_understanding/commands/chit_chat_answer_command.py @@ -5,8 +5,7 @@ from rasa.dialogue_understanding.commands import FreeFormAnswerCommand from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack from rasa.dialogue_understanding.stack.frames.chit_chat_frame import ChitChatStackFrame -from rasa.shared.core.constants import DIALOGUE_STACK_SLOT -from rasa.shared.core.events import Event, SlotSet +from rasa.shared.core.events import Event from rasa.shared.core.flows.flow import FlowsList from rasa.shared.core.trackers import DialogueStateTracker @@ -47,4 +46,4 @@ def run_command_on_tracker( """ stack = DialogueStack.from_tracker(tracker) stack.push(ChitChatStackFrame()) - return [SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict())] + return [stack.persist_as_event()] diff --git a/rasa/dialogue_understanding/commands/clarify_command.py b/rasa/dialogue_understanding/commands/clarify_command.py index 69a413730e0f..21bbd9ec6f51 100644 --- a/rasa/dialogue_understanding/commands/clarify_command.py +++ b/rasa/dialogue_understanding/commands/clarify_command.py @@ -7,8 +7,7 @@ from rasa.dialogue_understanding.commands import Command from rasa.dialogue_understanding.patterns.clarify import ClarifyPatternFlowStackFrame from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack -from rasa.shared.core.constants import DIALOGUE_STACK_SLOT -from rasa.shared.core.events import Event, SlotSet +from rasa.shared.core.events import Event from rasa.shared.core.flows.flow import FlowsList from rasa.shared.core.trackers import DialogueStateTracker @@ -76,4 +75,4 @@ def run_command_on_tracker( relevant_flows = [all_flows.flow_by_id(opt) for opt in clean_options] names = [flow.readable_name() for flow in relevant_flows if flow is not None] stack.push(ClarifyPatternFlowStackFrame(names=names)) - return [SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict())] + return [stack.persist_as_event()] diff --git a/rasa/dialogue_understanding/commands/correct_slots_command.py b/rasa/dialogue_understanding/commands/correct_slots_command.py index 2c80ed35e531..bc29b90b6a9f 100644 --- a/rasa/dialogue_understanding/commands/correct_slots_command.py +++ b/rasa/dialogue_understanding/commands/correct_slots_command.py @@ -15,8 +15,7 @@ BaseFlowStackFrame, UserFlowStackFrame, ) -from rasa.shared.core.constants import DIALOGUE_STACK_SLOT -from rasa.shared.core.events import Event, SlotSet +from rasa.shared.core.events import Event from rasa.shared.core.flows.flow import END_STEP, ContinueFlowStep, FlowStep, FlowsList from rasa.shared.core.trackers import DialogueStateTracker import rasa.dialogue_understanding.stack.utils as utils @@ -284,4 +283,4 @@ def run_command_on_tracker( self.end_previous_correction(top_flow_frame, stack) stack.push(correction_frame, index=insertion_index) - return [SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict())] + return [stack.persist_as_event()] diff --git a/rasa/dialogue_understanding/commands/error_command.py b/rasa/dialogue_understanding/commands/error_command.py index b5e62cb90d1c..da5b3fbaf393 100644 --- a/rasa/dialogue_understanding/commands/error_command.py +++ b/rasa/dialogue_understanding/commands/error_command.py @@ -9,8 +9,7 @@ InternalErrorPatternFlowStackFrame, ) from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack -from rasa.shared.core.constants import DIALOGUE_STACK_SLOT -from rasa.shared.core.events import Event, SlotSet +from rasa.shared.core.events import Event from rasa.shared.core.flows.flow import FlowsList from rasa.shared.core.trackers import DialogueStateTracker @@ -54,4 +53,4 @@ def run_command_on_tracker( dialogue_stack = DialogueStack.from_tracker(tracker) structlogger.debug("command_executor.error", command=self) dialogue_stack.push(InternalErrorPatternFlowStackFrame()) - return [SlotSet(DIALOGUE_STACK_SLOT, dialogue_stack.as_dict())] + return [dialogue_stack.persist_as_event()] diff --git a/rasa/dialogue_understanding/commands/knowledge_answer_command.py b/rasa/dialogue_understanding/commands/knowledge_answer_command.py index 3077dd44a739..d27587082d8f 100644 --- a/rasa/dialogue_understanding/commands/knowledge_answer_command.py +++ b/rasa/dialogue_understanding/commands/knowledge_answer_command.py @@ -5,8 +5,7 @@ from rasa.dialogue_understanding.commands import FreeFormAnswerCommand from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack from rasa.dialogue_understanding.stack.frames.search_frame import SearchStackFrame -from rasa.shared.core.constants import DIALOGUE_STACK_SLOT -from rasa.shared.core.events import Event, SlotSet +from rasa.shared.core.events import Event from rasa.shared.core.flows.flow import FlowsList from rasa.shared.core.trackers import DialogueStateTracker @@ -47,4 +46,4 @@ def run_command_on_tracker( """ dialogue_stack = DialogueStack.from_tracker(tracker) dialogue_stack.push(SearchStackFrame()) - return [SlotSet(DIALOGUE_STACK_SLOT, dialogue_stack.as_dict())] + return [dialogue_stack.persist_as_event()] diff --git a/rasa/dialogue_understanding/commands/start_flow_command.py b/rasa/dialogue_understanding/commands/start_flow_command.py index cb3cd5d5166a..11ab06bb162b 100644 --- a/rasa/dialogue_understanding/commands/start_flow_command.py +++ b/rasa/dialogue_understanding/commands/start_flow_command.py @@ -14,8 +14,7 @@ top_user_flow_frame, user_flows_on_the_stack, ) -from rasa.shared.core.constants import DIALOGUE_STACK_SLOT -from rasa.shared.core.events import Event, SlotSet +from rasa.shared.core.events import Event from rasa.shared.core.flows.flow import FlowsList from rasa.shared.core.trackers import DialogueStateTracker @@ -88,4 +87,4 @@ def run_command_on_tracker( ) structlogger.debug("command_executor.start_flow", command=self) stack.push(UserFlowStackFrame(flow_id=self.flow, frame_type=frame_type)) - return [SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict())] + return [stack.persist_as_event()] diff --git a/rasa/dialogue_understanding/patterns/cancel.py b/rasa/dialogue_understanding/patterns/cancel.py index 76678285cfd5..b60df2e004cc 100644 --- a/rasa/dialogue_understanding/patterns/cancel.py +++ b/rasa/dialogue_understanding/patterns/cancel.py @@ -15,9 +15,9 @@ from rasa.core.channels.channel import OutputChannel from rasa.core.nlg.generator import NaturalLanguageGenerator from rasa.shared.constants import RASA_DEFAULT_FLOW_PATTERN_PREFIX -from rasa.shared.core.constants import ACTION_CANCEL_FLOW, DIALOGUE_STACK_SLOT +from rasa.shared.core.constants import ACTION_CANCEL_FLOW from rasa.shared.core.domain import Domain -from rasa.shared.core.events import Event, SlotSet +from rasa.shared.core.events import Event from rasa.shared.core.flows.flow import END_STEP, ContinueFlowStep from rasa.shared.core.trackers import DialogueStateTracker @@ -111,4 +111,4 @@ async def run( frame_id=canceled_frame_id, ) - return [SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict())] + return [stack.persist_as_event()] diff --git a/rasa/dialogue_understanding/patterns/clarify.py b/rasa/dialogue_understanding/patterns/clarify.py index fdf41d79c247..db877bf5ef66 100644 --- a/rasa/dialogue_understanding/patterns/clarify.py +++ b/rasa/dialogue_understanding/patterns/clarify.py @@ -12,9 +12,9 @@ from rasa.core.channels.channel import OutputChannel from rasa.core.nlg.generator import NaturalLanguageGenerator from rasa.shared.constants import RASA_DEFAULT_FLOW_PATTERN_PREFIX -from rasa.shared.core.constants import ACTION_CLARIFY_FLOWS, DIALOGUE_STACK_SLOT +from rasa.shared.core.constants import ACTION_CLARIFY_FLOWS from rasa.shared.core.domain import Domain -from rasa.shared.core.events import Event, SlotSet +from rasa.shared.core.events import Event from rasa.shared.core.trackers import DialogueStateTracker @@ -98,4 +98,4 @@ async def run( options_string = self.assemble_options_string(top.names) top.clarification_options = options_string # since we modified the stack frame, we need to update the stack - return [SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict())] + return [stack.persist_as_event()] diff --git a/rasa/dialogue_understanding/patterns/correction.py b/rasa/dialogue_understanding/patterns/correction.py index 7dbd5013e667..1409bba8fba1 100644 --- a/rasa/dialogue_understanding/patterns/correction.py +++ b/rasa/dialogue_understanding/patterns/correction.py @@ -7,9 +7,6 @@ ) from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack from rasa.shared.constants import RASA_DEFAULT_FLOW_PATTERN_PREFIX -from rasa.shared.core.constants import ( - DIALOGUE_STACK_SLOT, -) from rasa.shared.core.flows.flow import ( START_STEP, ) @@ -139,7 +136,7 @@ async def run( ContinueFlowStep.continue_step_for_id(END_STEP) ) - events: List[Event] = [SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict())] + events: List[Event] = [stack.persist_as_event()] events.extend([SlotSet(k, v) for k, v in top.corrected_slots.items()]) diff --git a/rasa/dialogue_understanding/stack/dialogue_stack.py b/rasa/dialogue_understanding/stack/dialogue_stack.py index 5059b919523d..45911b0207c6 100644 --- a/rasa/dialogue_understanding/stack/dialogue_stack.py +++ b/rasa/dialogue_understanding/stack/dialogue_stack.py @@ -6,6 +6,7 @@ from rasa.shared.core.constants import ( DIALOGUE_STACK_SLOT, ) +from rasa.shared.core.events import Event, SlotSet from rasa.shared.core.trackers import ( DialogueStateTracker, ) @@ -129,6 +130,10 @@ def get_persisted_stack(tracker: DialogueStateTracker) -> List[Dict[str, Any]]: The persisted stack as a dictionary.""" return tracker.get_slot(DIALOGUE_STACK_SLOT) or [] + def persist_as_event(self) -> Event: + """Returns the stack as a slot set event.""" + return SlotSet(DIALOGUE_STACK_SLOT, self.as_dict()) + @staticmethod def from_tracker(tracker: DialogueStateTracker) -> DialogueStack: """Creates a `DialogueStack` from a tracker. diff --git a/rasa/dialogue_understanding/stack/utils.py b/rasa/dialogue_understanding/stack/utils.py index 44c1675b82ef..71e59b90d4ba 100644 --- a/rasa/dialogue_understanding/stack/utils.py +++ b/rasa/dialogue_understanding/stack/utils.py @@ -5,7 +5,7 @@ from rasa.dialogue_understanding.stack.frames import BaseFlowStackFrame from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack from rasa.dialogue_understanding.stack.frames import UserFlowStackFrame -from rasa.shared.core.flows.flow import FlowsList +from rasa.shared.core.flows.flow import END_STEP, ContinueFlowStep, FlowsList def top_flow_frame( @@ -106,3 +106,21 @@ def user_flows_on_the_stack(dialogue_stack: DialogueStack) -> Set[str]: return { f.flow_id for f in dialogue_stack.frames if isinstance(f, UserFlowStackFrame) } + + +def end_top_user_flow(stack: DialogueStack) -> None: + """Ends all frames on top of the stack including the topmost user frame. + + Ends all flows until the next user flow is reached. This is useful + if you want to end all flows that are currently on the stack and + the user flow that triggered them. + + Args: + stack: The dialogue stack. + """ + + for frame in reversed(stack.frames): + if isinstance(frame, BaseFlowStackFrame): + frame.step_id = ContinueFlowStep.continue_step_for_id(END_STEP) + if isinstance(frame, UserFlowStackFrame): + break diff --git a/rasa/shared/importers/importer.py b/rasa/shared/importers/importer.py index a03c92298940..9cd5904d43ee 100644 --- a/rasa/shared/importers/importer.py +++ b/rasa/shared/importers/importer.py @@ -405,15 +405,19 @@ def load_default_pattern_flows_domain() -> Domain: return Domain.from_path(default_flows_file) - @rasa.shared.utils.common.cached_method - def get_flows(self) -> FlowsList: - flows = self._importer.get_flows() + @classmethod + def merge_with_default_flows(cls, flows: FlowsList) -> FlowsList: + """Merges the passed flows with the default flows. - if flows.is_empty(): - # if there are no flows, we don't need to add the default flows either - return flows + If a user defined flow contains a flow with an id of a default flow, + it will overwrite the default flow. + + Args: + flows: user defined flows. - default_flows = self.load_default_pattern_flows() + Returns: + Merged flows.""" + default_flows = cls.load_default_pattern_flows() user_flow_ids = [flow.id for flow in flows.underlying_flows] missing_default_flows = [ @@ -424,6 +428,16 @@ def get_flows(self) -> FlowsList: return flows.merge(FlowsList(missing_default_flows)) + @rasa.shared.utils.common.cached_method + def get_flows(self) -> FlowsList: + flows = self._importer.get_flows() + + if flows.is_empty(): + # if there are no flows, we don't need to add the default flows either + return flows + + return self.merge_with_default_flows(flows) + @rasa.shared.utils.common.cached_method def get_domain(self) -> Domain: """Merge existing domain with properties of flows.""" diff --git a/tests/core/policies/test_flow_policy.py b/tests/core/policies/test_flow_policy.py index 4bfcf918237b..a4b66a202ab6 100644 --- a/tests/core/policies/test_flow_policy.py +++ b/tests/core/policies/test_flow_policy.py @@ -1,13 +1,70 @@ import textwrap from typing import List, Optional, Text, Tuple +import pytest + from rasa.core.policies.flow_policy import ( + FlowCircuitBreakerTrippedException, FlowExecutor, + FlowPolicy, ) +from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack +from rasa.engine.graph import ExecutionContext +from rasa.engine.storage.resource import Resource +from rasa.engine.storage.storage import ModelStorage from rasa.shared.core.domain import Domain -from rasa.shared.core.events import ActionExecuted, Event +from rasa.shared.core.events import ActionExecuted, Event, SlotSet +from rasa.shared.core.flows.flow import FlowsList from rasa.shared.core.flows.yaml_flows_io import YAMLFlowsReader from rasa.shared.core.trackers import DialogueStateTracker +from rasa.dialogue_understanding.stack.frames import ( + UserFlowStackFrame, + SearchStackFrame, +) +from tests.utilities import ( + flows_default_domain, + flows_from_str, + flows_from_str_with_defaults, +) + + +@pytest.fixture() +def resource() -> Resource: + return Resource("flow_policy") + + +@pytest.fixture() +def default_flow_policy( + resource: Resource, + default_model_storage: ModelStorage, + default_execution_context: ExecutionContext, +) -> FlowPolicy: + return FlowPolicy( + config={}, + model_storage=default_model_storage, + resource=resource, + execution_context=default_execution_context, + ) + + +@pytest.fixture() +def default_flows() -> FlowsList: + return flows_from_str( + """ + flows: + foo_flow: + steps: + - id: "1" + action: action_listen + next: "2" + - id: "2" + action: action_unlikely_intent # some action that exists by default + bar_flow: + steps: + - id: first_step + action: action_listen + """ + ) def _run_flow_until_listen( @@ -66,3 +123,207 @@ def test_select_next_action() -> None: assert actions == ["flow_test_flow", None] assert events == [] + + +def test_flow_policy_does_support_user_flowstack_frame(): + frame = UserFlowStackFrame(flow_id="foo", step_id="first_step", frame_id="some-id") + assert FlowPolicy.does_support_stack_frame(frame) + + +def test_flow_policy_does_not_support_search_frame(): + frame = SearchStackFrame( + frame_id="some-id", + ) + assert not FlowPolicy.does_support_stack_frame(frame) + + +def test_get_default_config(): + assert FlowPolicy.get_default_config() == {"priority": 1, "max_history": None} + + +def test_predict_action_probabilities_abstains_from_unsupported_frame( + default_flow_policy: FlowPolicy, +): + domain = Domain.empty() + + stack = DialogueStack(frames=[SearchStackFrame(frame_id="some-id")]) + # create a tracker with the stack set + tracker = DialogueStateTracker.from_events( + "test abstain", + domain=domain, + slots=domain.slots, + evts=[ActionExecuted(action_name="action_listen"), stack.persist_as_event()], + ) + + prediction = default_flow_policy.predict_action_probabilities( + tracker=tracker, + domain=Domain.empty(), + ) + + # check that the policy didn't predict anything + assert prediction.max_confidence == 0.0 + + +def test_predict_action_probabilities_advances_topmost_flow( + default_flow_policy: FlowPolicy, default_flows: FlowsList +): + domain = Domain.empty() + + stack = DialogueStack( + frames=[UserFlowStackFrame(flow_id="foo_flow", step_id="1", frame_id="some-id")] + ) + + tracker = DialogueStateTracker.from_events( + "test abstain", + domain=domain, + slots=domain.slots, + evts=[ActionExecuted(action_name="action_listen"), stack.persist_as_event()], + ) + + prediction = default_flow_policy.predict_action_probabilities( + tracker=tracker, domain=Domain.empty(), flows=default_flows + ) + + assert prediction.max_confidence == 1.0 + + predicted_idx = prediction.max_confidence_index + assert domain.action_names_or_texts[predicted_idx] == "action_unlikely_intent" + # check that the stack was updated + assert prediction.optional_events == [ + SlotSet( + "dialogue_stack", + [ + { + "frame_id": "some-id", + "flow_id": "foo_flow", + "step_id": "2", + "frame_type": "regular", + "type": "flow", + } + ], + ) + ] + + +def test_executor_trips_internal_circuit_breaker(): + flow_with_loop = flows_from_str( + """ + flows: + foo_flow: + steps: + - id: "1" + set_slot: + foo: bar + next: "2" + - id: "2" + set_slot: + foo: barbar + next: "1" + """ + ) + + domain = Domain.empty() + + stack = DialogueStack( + frames=[UserFlowStackFrame(flow_id="foo_flow", step_id="1", frame_id="some-id")] + ) + + tracker = DialogueStateTracker.from_events( + "test", + evts=[ActionExecuted(action_name="action_listen"), stack.persist_as_event()], + domain=domain, + slots=domain.slots, + ) + + executor = FlowExecutor.from_tracker(tracker, flow_with_loop, domain) + + with pytest.raises(FlowCircuitBreakerTrippedException): + executor.select_next_action(tracker) + + +def test_policy_triggers_error_pattern_if_internal_circuit_breaker_is_tripped( + default_flow_policy: FlowPolicy, +): + flow_with_loop = flows_from_str_with_defaults( + """ + flows: + foo_flow: + steps: + - id: "1" + set_slot: + foo: bar + next: "2" + - id: "2" + set_slot: + foo: barbar + next: "1" + """ + ) + + domain = flows_default_domain() + + stack = DialogueStack( + frames=[UserFlowStackFrame(flow_id="foo_flow", step_id="1", frame_id="some-id")] + ) + + tracker = DialogueStateTracker.from_events( + "test", + evts=[ActionExecuted(action_name="action_listen"), stack.persist_as_event()], + domain=domain, + slots=domain.slots, + ) + + prediction = default_flow_policy.predict_action_probabilities( + tracker=tracker, domain=domain, flows=flow_with_loop + ) + + assert prediction.max_confidence == 1.0 + + predicted_idx = prediction.max_confidence_index + assert domain.action_names_or_texts[predicted_idx] == "utter_internal_error_rasa" + # check that the stack was updated. + assert len(prediction.optional_events) == 1 + assert isinstance(prediction.optional_events[0], SlotSet) + + assert prediction.optional_events[0].key == "dialogue_stack" + # the user flow should be on the stack as well as the error pattern + assert len(prediction.optional_events[0].value) == 2 + # the user flow should be about to end + assert prediction.optional_events[0].value[0]["step_id"] == "NEXT:END" + # the pattern should be the other frame + assert prediction.optional_events[0].value[1]["flow_id"] == "pattern_internal_error" + + +def test_executor_does_not_get_tripped_if_an_action_is_predicted_in_loop(): + flow_with_loop = flows_from_str( + """ + flows: + foo_flow: + steps: + - id: "1" + set_slot: + foo: bar + next: "2" + - id: "2" + action: action_listen + next: "1" + """ + ) + + domain = Domain.empty() + + stack = DialogueStack( + frames=[UserFlowStackFrame(flow_id="foo_flow", step_id="1", frame_id="some-id")] + ) + + tracker = DialogueStateTracker.from_events( + "test", + evts=[ActionExecuted(action_name="action_listen"), stack.persist_as_event()], + domain=domain, + slots=domain.slots, + ) + + executor = FlowExecutor.from_tracker(tracker, flow_with_loop, domain) + + selection = executor.select_next_action(tracker) + assert selection.action_name == "action_listen" diff --git a/tests/cdu/__init__.py b/tests/dialogue_understanding/__init__.py similarity index 100% rename from tests/cdu/__init__.py rename to tests/dialogue_understanding/__init__.py diff --git a/tests/cdu/commands/__init__.py b/tests/dialogue_understanding/commands/__init__.py similarity index 100% rename from tests/cdu/commands/__init__.py rename to tests/dialogue_understanding/commands/__init__.py diff --git a/tests/cdu/commands/conftest.py b/tests/dialogue_understanding/commands/conftest.py similarity index 100% rename from tests/cdu/commands/conftest.py rename to tests/dialogue_understanding/commands/conftest.py diff --git a/tests/cdu/commands/test_can_not_handle_command.py b/tests/dialogue_understanding/commands/test_can_not_handle_command.py similarity index 100% rename from tests/cdu/commands/test_can_not_handle_command.py rename to tests/dialogue_understanding/commands/test_can_not_handle_command.py diff --git a/tests/cdu/commands/test_cancel_flow_command.py b/tests/dialogue_understanding/commands/test_cancel_flow_command.py similarity index 100% rename from tests/cdu/commands/test_cancel_flow_command.py rename to tests/dialogue_understanding/commands/test_cancel_flow_command.py diff --git a/tests/cdu/commands/test_chit_chat_answer_command.py b/tests/dialogue_understanding/commands/test_chit_chat_answer_command.py similarity index 100% rename from tests/cdu/commands/test_chit_chat_answer_command.py rename to tests/dialogue_understanding/commands/test_chit_chat_answer_command.py diff --git a/tests/cdu/commands/test_clarify_command.py b/tests/dialogue_understanding/commands/test_clarify_command.py similarity index 100% rename from tests/cdu/commands/test_clarify_command.py rename to tests/dialogue_understanding/commands/test_clarify_command.py diff --git a/tests/cdu/commands/test_command.py b/tests/dialogue_understanding/commands/test_command.py similarity index 100% rename from tests/cdu/commands/test_command.py rename to tests/dialogue_understanding/commands/test_command.py diff --git a/tests/cdu/commands/test_command_processor.py b/tests/dialogue_understanding/commands/test_command_processor.py similarity index 97% rename from tests/cdu/commands/test_command_processor.py rename to tests/dialogue_understanding/commands/test_command_processor.py index e1b32d1b6c2b..0a6c32d6e518 100644 --- a/tests/cdu/commands/test_command_processor.py +++ b/tests/dialogue_understanding/commands/test_command_processor.py @@ -13,7 +13,7 @@ from rasa.shared.core.constants import FLOW_HASHES_SLOT from rasa.shared.core.flows.flow import FlowsList from rasa.shared.core.trackers import DialogueStateTracker -from tests.cdu.commands.conftest import start_bar_user_uttered +from tests.dialogue_understanding.commands.conftest import start_bar_user_uttered from tests.utilities import flows_from_str diff --git a/tests/cdu/commands/test_correct_slots_command.py b/tests/dialogue_understanding/commands/test_correct_slots_command.py similarity index 100% rename from tests/cdu/commands/test_correct_slots_command.py rename to tests/dialogue_understanding/commands/test_correct_slots_command.py diff --git a/tests/cdu/commands/test_error_command.py b/tests/dialogue_understanding/commands/test_error_command.py similarity index 100% rename from tests/cdu/commands/test_error_command.py rename to tests/dialogue_understanding/commands/test_error_command.py diff --git a/tests/cdu/commands/test_handle_code_change_command.py b/tests/dialogue_understanding/commands/test_handle_code_change_command.py similarity index 98% rename from tests/cdu/commands/test_handle_code_change_command.py rename to tests/dialogue_understanding/commands/test_handle_code_change_command.py index 6ab56430d9a4..fd7fbb7ad9b2 100644 --- a/tests/cdu/commands/test_handle_code_change_command.py +++ b/tests/dialogue_understanding/commands/test_handle_code_change_command.py @@ -23,7 +23,7 @@ END_STEP, ) from rasa.shared.core.trackers import DialogueStateTracker -from tests.cdu.commands.test_command_processor import ( +from tests.dialogue_understanding.commands.test_command_processor import ( start_bar_user_uttered, change_cases, ) diff --git a/tests/cdu/commands/test_human_handoff_command.py b/tests/dialogue_understanding/commands/test_human_handoff_command.py similarity index 100% rename from tests/cdu/commands/test_human_handoff_command.py rename to tests/dialogue_understanding/commands/test_human_handoff_command.py diff --git a/tests/cdu/commands/test_konwledge_answer_command.py b/tests/dialogue_understanding/commands/test_konwledge_answer_command.py similarity index 100% rename from tests/cdu/commands/test_konwledge_answer_command.py rename to tests/dialogue_understanding/commands/test_konwledge_answer_command.py diff --git a/tests/cdu/commands/test_set_slot_command.py b/tests/dialogue_understanding/commands/test_set_slot_command.py similarity index 100% rename from tests/cdu/commands/test_set_slot_command.py rename to tests/dialogue_understanding/commands/test_set_slot_command.py diff --git a/tests/cdu/commands/test_start_flow_command.py b/tests/dialogue_understanding/commands/test_start_flow_command.py similarity index 100% rename from tests/cdu/commands/test_start_flow_command.py rename to tests/dialogue_understanding/commands/test_start_flow_command.py diff --git a/tests/cdu/generator/__init__.py b/tests/dialogue_understanding/generator/__init__.py similarity index 100% rename from tests/cdu/generator/__init__.py rename to tests/dialogue_understanding/generator/__init__.py diff --git a/tests/cdu/generator/test_command_generator.py b/tests/dialogue_understanding/generator/test_command_generator.py similarity index 100% rename from tests/cdu/generator/test_command_generator.py rename to tests/dialogue_understanding/generator/test_command_generator.py diff --git a/tests/cdu/stack/__init__.py b/tests/dialogue_understanding/stack/__init__.py similarity index 100% rename from tests/cdu/stack/__init__.py rename to tests/dialogue_understanding/stack/__init__.py diff --git a/tests/cdu/stack/frames/__init__.py b/tests/dialogue_understanding/stack/frames/__init__.py similarity index 100% rename from tests/cdu/stack/frames/__init__.py rename to tests/dialogue_understanding/stack/frames/__init__.py diff --git a/tests/cdu/stack/frames/test_chit_chat_frame.py b/tests/dialogue_understanding/stack/frames/test_chit_chat_frame.py similarity index 100% rename from tests/cdu/stack/frames/test_chit_chat_frame.py rename to tests/dialogue_understanding/stack/frames/test_chit_chat_frame.py diff --git a/tests/cdu/stack/frames/test_dialogue_stack_frame.py b/tests/dialogue_understanding/stack/frames/test_dialogue_stack_frame.py similarity index 100% rename from tests/cdu/stack/frames/test_dialogue_stack_frame.py rename to tests/dialogue_understanding/stack/frames/test_dialogue_stack_frame.py diff --git a/tests/cdu/stack/frames/test_flow_frame.py b/tests/dialogue_understanding/stack/frames/test_flow_frame.py similarity index 100% rename from tests/cdu/stack/frames/test_flow_frame.py rename to tests/dialogue_understanding/stack/frames/test_flow_frame.py diff --git a/tests/cdu/stack/frames/test_search_frame.py b/tests/dialogue_understanding/stack/frames/test_search_frame.py similarity index 100% rename from tests/cdu/stack/frames/test_search_frame.py rename to tests/dialogue_understanding/stack/frames/test_search_frame.py diff --git a/tests/cdu/stack/test_dialogue_stack.py b/tests/dialogue_understanding/stack/test_dialogue_stack.py similarity index 90% rename from tests/cdu/stack/test_dialogue_stack.py rename to tests/dialogue_understanding/stack/test_dialogue_stack.py index 25769624eb0e..fa0859f3a314 100644 --- a/tests/cdu/stack/test_dialogue_stack.py +++ b/tests/dialogue_understanding/stack/test_dialogue_stack.py @@ -4,6 +4,8 @@ ) from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack from rasa.dialogue_understanding.stack.frames.flow_stack_frame import UserFlowStackFrame +from rasa.shared.core.constants import DIALOGUE_STACK_SLOT +from rasa.shared.core.events import SlotSet def test_dialogue_stack_from_dict(): @@ -75,6 +77,24 @@ def test_dialogue_stack_as_dict(): ] +def test_dialogue_stack_as_event(): + # check that the stack gets persisted as an event storing the dict + stack = DialogueStack( + frames=[ + UserFlowStackFrame( + flow_id="foo", step_id="first_step", frame_id="some-frame-id" + ), + CollectInformationPatternFlowStackFrame( + collect="foo", + frame_id="some-other-id", + utter="utter_ask_foo", + ), + ] + ) + + assert stack.persist_as_event() == SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict()) + + def test_dialogue_stack_as_dict_handles_empty(): stack = DialogueStack(frames=[]) assert stack.as_dict() == [] diff --git a/tests/cdu/stack/test_utils.py b/tests/dialogue_understanding/stack/test_utils.py similarity index 79% rename from tests/cdu/stack/test_utils.py rename to tests/dialogue_understanding/stack/test_utils.py index 4d7cf485522f..793132f17721 100644 --- a/tests/cdu/stack/test_utils.py +++ b/tests/dialogue_understanding/stack/test_utils.py @@ -5,6 +5,7 @@ from rasa.dialogue_understanding.stack.frames.chit_chat_frame import ChitChatStackFrame from rasa.dialogue_understanding.stack.frames.flow_stack_frame import UserFlowStackFrame from rasa.dialogue_understanding.stack.utils import ( + end_top_user_flow, filled_slots_for_active_flow, top_flow_frame, top_user_flow_frame, @@ -199,3 +200,52 @@ def test_filled_slots_for_active_flow_only_collects_till_top_most_user_flow_fram stack = DialogueStack(frames=[another_user_frame, user_frame]) assert filled_slots_for_active_flow(stack, all_flows) == {"foo", "bar"} + + +def test_end_top_user_flow(): + user_frame = UserFlowStackFrame( + flow_id="my_flow", step_id="collect_bar", frame_id="some-frame-id" + ) + pattern_frame = CollectInformationPatternFlowStackFrame( + collect="foo", frame_id="some-other-id" + ) + stack = DialogueStack(frames=[user_frame, pattern_frame]) + + end_top_user_flow(stack) + + assert len(stack.frames) == 2 + + assert stack.frames[0] == UserFlowStackFrame( + flow_id="my_flow", step_id="NEXT:END", frame_id="some-frame-id" + ) + assert stack.frames[1] == CollectInformationPatternFlowStackFrame( + collect="foo", frame_id="some-other-id", step_id="NEXT:END" + ) + + +def test_end_top_user_flow_only_ends_topmost_user_frame(): + user_frame = UserFlowStackFrame( + flow_id="my_flow", step_id="collect_bar", frame_id="some-frame-id" + ) + other_user_frame = UserFlowStackFrame( + flow_id="my_other_flow", step_id="collect_bar2", frame_id="some-other-id" + ) + stack = DialogueStack(frames=[other_user_frame, user_frame]) + + end_top_user_flow(stack) + + assert len(stack.frames) == 2 + + assert stack.frames[0] == UserFlowStackFrame( + flow_id="my_other_flow", step_id="collect_bar2", frame_id="some-other-id" + ) + assert stack.frames[1] == UserFlowStackFrame( + flow_id="my_flow", step_id="NEXT:END", frame_id="some-frame-id" + ) + + +def test_end_top_user_flow_handles_empty(): + stack = DialogueStack(frames=[]) + end_top_user_flow(stack) + + assert len(stack.frames) == 0 diff --git a/tests/utilities.py b/tests/utilities.py index f803ec1da0ab..c2274c4be227 100644 --- a/tests/utilities.py +++ b/tests/utilities.py @@ -1,7 +1,9 @@ from yarl import URL import textwrap +from rasa.shared.core.domain import Domain from rasa.shared.core.flows.flow import FlowsList from rasa.shared.core.flows.yaml_flows_io import YAMLFlowsReader +from rasa.shared.importers.importer import FlowSyncImporter def latest_request(mocked, request_type, path): @@ -15,3 +17,13 @@ def json_of_latest_request(r): def flows_from_str(yaml_str: str) -> FlowsList: """Reads flows from a YAML string.""" return YAMLFlowsReader.read_from_string(textwrap.dedent(yaml_str)) + + +def flows_from_str_with_defaults(yaml_str: str) -> FlowsList: + """Reads flows from a YAML string and includes buildin flows.""" + return FlowSyncImporter.merge_with_default_flows(flows_from_str(yaml_str)) + + +def flows_default_domain() -> Domain: + """Returns the default domain for the default flows.""" + return FlowSyncImporter.load_default_pattern_flows_domain()