From 07251a3c8292bca72829e45c6646195ba5a9079f Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Mon, 18 Sep 2023 18:10:43 +0100 Subject: [PATCH] adapt to list format of checks, add new default action for running predicates --- rasa/core/actions/action.py | 4 + .../evaluate_predicate_rejections_action.py | 109 ++++++++++++++++++ rasa/core/policies/flow_policy.py | 30 ++--- .../patterns/collect_information.py | 15 ++- .../patterns/default_flows_for_patterns.yml | 16 +-- rasa/shared/core/constants.py | 2 + rasa/shared/core/flows/flow.py | 6 +- 7 files changed, 141 insertions(+), 41 deletions(-) create mode 100644 rasa/core/actions/evaluate_predicate_rejections_action.py diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 65f5e51cd444..e90fe8f7ac67 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -100,6 +100,9 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A from rasa.dialogue_understanding.patterns.correction import ActionCorrectFlowSlot from rasa.dialogue_understanding.patterns.cancel import ActionCancelFlow from rasa.dialogue_understanding.patterns.clarify import ActionClarifyFlows + from rasa.core.actions.evaluate_predicate_rejections_action import ( + ActionEvaluatePredicateRejection, + ) return [ ActionListen(), @@ -118,6 +121,7 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A ActionCancelFlow(), ActionCorrectFlowSlot(), ActionClarifyFlows(), + ActionEvaluatePredicateRejection(), ] diff --git a/rasa/core/actions/evaluate_predicate_rejections_action.py b/rasa/core/actions/evaluate_predicate_rejections_action.py new file mode 100644 index 000000000000..52098ad42a0a --- /dev/null +++ b/rasa/core/actions/evaluate_predicate_rejections_action.py @@ -0,0 +1,109 @@ +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Text + +import structlog +from jinja2 import Template +from pypred import Predicate + +from rasa.core.actions.action import Action, create_bot_utterance +from rasa.dialogue_understanding.patterns.collect_information import ( + CollectInformationPatternFlowStackFrame, +) +from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack +from rasa.shared.core.constants import ACTION_EVALUATE_PREDICATE_REJECTIONS_NAME +from rasa.shared.core.events import Event, SlotSet + +if TYPE_CHECKING: + from rasa.core.nlg import NaturalLanguageGenerator + from rasa.core.channels.channel import OutputChannel + from rasa.shared.core.domain import Domain + from rasa.shared.core.trackers import DialogueStateTracker + +structlogger = structlog.get_logger() + + +class ActionEvaluatePredicateRejection(Action): + """Action which evaluates the predicate checks under rejections.""" + + def name(self) -> Text: + """Return the name of the action.""" + return ACTION_EVALUATE_PREDICATE_REJECTIONS_NAME + + async def run( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "Domain", + metadata: Optional[Dict[Text, Any]] = None, + ) -> List[Event]: + """Run the predicate checks.""" + events: List[Event] = [] + + dialogue_stack = DialogueStack.from_tracker(tracker) + top_frame = dialogue_stack.top() + if not isinstance(top_frame, CollectInformationPatternFlowStackFrame): + return [] + + if top_frame.rejections is None: + return [] + + slot_name = top_frame.collect_information + slot_value = tracker.get_slot(slot_name) + + current_context = dialogue_stack.current_context() + current_context[slot_name] = slot_value + + structlogger.debug("collect.predicate.context", context=current_context) + document = current_context.copy() + + for rejection in top_frame.rejections: + check_text = rejection.get("if") + utterance = rejection.get("utter") + rendered_template = Template(check_text).render(current_context) + predicate = Predicate(rendered_template) + try: + result = predicate.evaluate(document) + structlogger.debug( + "collect.predicate.result", + result=result, + ) + except (TypeError, Exception) as e: + structlogger.error( + "collect.predicate.error", + predicate=predicate, + document=document, + error=str(e), + ) + continue + + if result is False: + continue + + if current_context.get("number_of_tries", 0) < 2: + # reset slot value that was initially filled with an invalid value + events.append(SlotSet(top_frame.collect_information, None)) + + if utterance is None: + structlogger.debug( + "collect.rejection.missing.utter", + predicate=predicate, + document=document, + ) + break + + message = await nlg.generate( + utterance, + tracker, + output_channel.name(), + ) + if message is None: + structlogger.debug( + "collect.rejection.failed.finding.utter", + utterance=utterance, + ) + else: + message["utter_action"] = utterance + events.append(create_bot_utterance(message)) + return events + + return events diff --git a/rasa/core/policies/flow_policy.py b/rasa/core/policies/flow_policy.py index 97796c6ace44..e04b8efee7c9 100644 --- a/rasa/core/policies/flow_policy.py +++ b/rasa/core/policies/flow_policy.py @@ -171,6 +171,7 @@ def predict_action_probabilities( domain: The model's domain. rule_only_data: Slots and loops which are specific to rules and hence should be ignored by this policy. + flows: The flows to use. **kwargs: Depending on the specified `needs` section and the resulting graph structure the policy can use different input to make predictions. @@ -208,7 +209,7 @@ def _create_prediction_result( domain: The model's domain. score: The score of the predicted action. - Resturns: + Returns: The prediction result where the score is used for one hot encoding. """ result = self._default_predictions(domain) @@ -372,7 +373,7 @@ def render_template_variables(text: str, context: Dict[Text, Any]) -> str: return Template(text).render(context) def _slot_for_collect_information(self, collect_information: Text) -> Slot: - """Find the slot for a collect information.""" + """Find the slot for the collect information step.""" for slot in self.domain.slots: if slot.name == collect_information: return slot @@ -416,7 +417,6 @@ def advance_flows(self, tracker: DialogueStateTracker) -> ActionPrediction: Args: tracker: The tracker to get the next action for. - domain: The domain to get the next action for. Returns: The predicted action and the events to run. @@ -457,7 +457,6 @@ def _select_next_action( Args: tracker: The tracker to get the next action for. - domain: The domain to get the next action for. Returns: The next action to execute, the events that should be applied to the @@ -554,7 +553,7 @@ def _run_step( if isinstance(step, CollectInformationFlowStep): structlogger.debug("flow.step.run.collect_information") self.trigger_pattern_ask_collect_information( - step.collect_information, step.validation + step.collect_information, step.rejections ) # reset the slot if its already filled and the collect information shouldn't @@ -583,7 +582,7 @@ def _run_step( # the question has been asked top_frame = self.dialogue_stack.top() if isinstance(top_frame, CollectInformationPatternFlowStackFrame): - top_frame.number_of_retries = top_frame.number_of_retries + 1 + top_frame.number_of_tries = top_frame.number_of_tries + 1 return PauseFlowReturnPrediction(ActionPrediction(action_name, 1.0)) else: @@ -693,27 +692,18 @@ def trigger_pattern_completed(self, current_frame: DialogueStackFrame) -> None: def trigger_pattern_ask_collect_information( self, collect_information: str, - validation: Optional[Dict[Text, Any]], + rejections: Optional[List[Dict[Text, Any]]], ) -> None: context = self.dialogue_stack.current_context().copy() - number_of_retries = context.get("number_of_retries", 0) - - if validation and "valid_message" not in validation: - validation["valid_message"] = "null" - - default_validation = { - "condition": "true", - "valid_message": "null", - "invalid_message": "null", - } + number_of_tries = context.get("number_of_tries", 0) - step_validation = validation if validation else default_validation + slot_value_rejections = rejections if rejections else [] self.dialogue_stack.push( CollectInformationPatternFlowStackFrame( collect_information=collect_information, - number_of_retries=number_of_retries, - validation=step_validation, + number_of_tries=number_of_tries, + rejections=slot_value_rejections, ) ) diff --git a/rasa/dialogue_understanding/patterns/collect_information.py b/rasa/dialogue_understanding/patterns/collect_information.py index 73361185e4ea..3c8d7c9629d2 100644 --- a/rasa/dialogue_understanding/patterns/collect_information.py +++ b/rasa/dialogue_understanding/patterns/collect_information.py @@ -20,10 +20,13 @@ class CollectInformationPatternFlowStackFrame(PatternFlowStackFrame): collect_information: str = "" """The information that should be collected from the user. this corresponds to the slot that will be filled.""" - number_of_retries: int = 0 - """The number of times the question is being re-asked to fill the slot.""" - validation: Optional[Dict[str, Any]] = None - """The predicate validation that should be applied to the collected information.""" + number_of_tries: int = 0 + """The number of times the question is being asked to fill the slot.""" + rejections: Optional[List[Dict[str, Any]]] = None + """The predicate check that should be applied to the collected information. + If a predicate check fails, its `utter` action indicated under rejections + will be executed. + """ @classmethod def type(cls) -> str: @@ -44,8 +47,8 @@ def from_dict(data: Dict[str, Any]) -> CollectInformationPatternFlowStackFrame: data["frame_id"], step_id=data["step_id"], collect_information=data["collect_information"], - number_of_retries=data.get("number_of_retries", 0), - validation=data.get("validation"), + number_of_tries=data.get("number_of_tries", 0), + rejections=data.get("rejections"), ) def context_as_dict( diff --git a/rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml b/rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml index 4edd4e540f6a..f98eebf203e5 100644 --- a/rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +++ b/rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml @@ -112,20 +112,12 @@ flows: - id: "start" action: action_extract_slots next: - - if: "{{context.collect_information}} is null and {{context.number_of_retries}} < 2" + - if: "{{context.collect_information}} is null and {{context.number_of_tries}} < 2" then: "ask_collect_information" - - if: "{{context.validation.condition}}" - then: "valid_message" - - else: "invalid_message" - - id: "valid_message" - action: "{{context.validation.valid_message}}" + - else: "evaluate_predicate_rejections" + - id: "evaluate_predicate_rejections" + action: action_evaluate_predicate_rejections next: "validate" - - id: "invalid_message" - action: "{{context.validation.invalid_message}}" - next: - - if: "{{context.number_of_retries}} < 2" - then: "ask_collect_information" - - else: "done" - id: "validate" action: validate_{{context.collect_information}} next: diff --git a/rasa/shared/core/constants.py b/rasa/shared/core/constants.py index f548b3ddd22d..052a1b311ba7 100644 --- a/rasa/shared/core/constants.py +++ b/rasa/shared/core/constants.py @@ -40,6 +40,7 @@ ACTION_CANCEL_FLOW = "action_cancel_flow" ACTION_CLARIFY_FLOWS = "action_clarify_flows" ACTION_CORRECT_FLOW_SLOT = "action_correct_flow_slot" +ACTION_EVALUATE_PREDICATE_REJECTIONS_NAME = "action_evaluate_predicate_rejections" DEFAULT_ACTION_NAMES = [ @@ -60,6 +61,7 @@ ACTION_CANCEL_FLOW, ACTION_CORRECT_FLOW_SLOT, ACTION_CLARIFY_FLOWS, + ACTION_EVALUATE_PREDICATE_REJECTIONS_NAME, ] ACTION_SHOULD_SEND_DOMAIN = "send_domain" diff --git a/rasa/shared/core/flows/flow.py b/rasa/shared/core/flows/flow.py index 72c248a7a050..9ad875ae2cf3 100644 --- a/rasa/shared/core/flows/flow.py +++ b/rasa/shared/core/flows/flow.py @@ -991,7 +991,7 @@ class CollectInformationFlowStep(FlowStep): """Whether to always ask the question even if the slot is already filled.""" scope: CollectInformationScope = CollectInformationScope.FLOW """how the question is scoped, determines when to reset its value.""" - validation: Optional[Dict[Text, Any]] = None + rejections: Optional[List[Dict[Text, Any]]] = None """how the slot value is validated using predicate evaluation.""" @classmethod @@ -1009,7 +1009,7 @@ def from_json(cls, flow_step_config: Dict[Text, Any]) -> CollectInformationFlowS collect_information=flow_step_config.get("collect_information", ""), ask_before_filling=flow_step_config.get("ask_before_filling", False), scope=CollectInformationScope.from_str(flow_step_config.get("scope")), - validation=flow_step_config.get("validation"), + rejections=flow_step_config.get("rejections"), **base.__dict__, ) @@ -1023,7 +1023,7 @@ def as_json(self) -> Dict[Text, Any]: dump["collect_information"] = self.collect_information dump["ask_before_filling"] = self.ask_before_filling dump["scope"] = self.scope.value - dump["validation"] = self.validation + dump["rejections"] = self.rejections return dump