Skip to content

Commit

Permalink
adapt to list format of checks, add new default action for running pr…
Browse files Browse the repository at this point in the history
…edicates
  • Loading branch information
ancalita committed Sep 18, 2023
1 parent ddc08d0 commit 07251a3
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 41 deletions.
4 changes: 4 additions & 0 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
from rasa.dialogue_understanding.patterns.correction import ActionCorrectFlowSlot
from rasa.dialogue_understanding.patterns.cancel import ActionCancelFlow
from rasa.dialogue_understanding.patterns.clarify import ActionClarifyFlows
from rasa.core.actions.evaluate_predicate_rejections_action import (
ActionEvaluatePredicateRejection,
)

return [
ActionListen(),
Expand All @@ -118,6 +121,7 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
ActionCancelFlow(),
ActionCorrectFlowSlot(),
ActionClarifyFlows(),
ActionEvaluatePredicateRejection(),
]


Expand Down
109 changes: 109 additions & 0 deletions rasa/core/actions/evaluate_predicate_rejections_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Text

import structlog
from jinja2 import Template
from pypred import Predicate

from rasa.core.actions.action import Action, create_bot_utterance
from rasa.dialogue_understanding.patterns.collect_information import (
CollectInformationPatternFlowStackFrame,
)
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
from rasa.shared.core.constants import ACTION_EVALUATE_PREDICATE_REJECTIONS_NAME
from rasa.shared.core.events import Event, SlotSet

if TYPE_CHECKING:
from rasa.core.nlg import NaturalLanguageGenerator
from rasa.core.channels.channel import OutputChannel
from rasa.shared.core.domain import Domain
from rasa.shared.core.trackers import DialogueStateTracker

structlogger = structlog.get_logger()


class ActionEvaluatePredicateRejection(Action):
"""Action which evaluates the predicate checks under rejections."""

def name(self) -> Text:
"""Return the name of the action."""
return ACTION_EVALUATE_PREDICATE_REJECTIONS_NAME

async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Run the predicate checks."""
events: List[Event] = []

dialogue_stack = DialogueStack.from_tracker(tracker)
top_frame = dialogue_stack.top()
if not isinstance(top_frame, CollectInformationPatternFlowStackFrame):
return []

if top_frame.rejections is None:
return []

slot_name = top_frame.collect_information
slot_value = tracker.get_slot(slot_name)

current_context = dialogue_stack.current_context()
current_context[slot_name] = slot_value

structlogger.debug("collect.predicate.context", context=current_context)
document = current_context.copy()

for rejection in top_frame.rejections:
check_text = rejection.get("if")
utterance = rejection.get("utter")
rendered_template = Template(check_text).render(current_context)
predicate = Predicate(rendered_template)
try:
result = predicate.evaluate(document)
structlogger.debug(
"collect.predicate.result",
result=result,
)
except (TypeError, Exception) as e:
structlogger.error(
"collect.predicate.error",
predicate=predicate,
document=document,
error=str(e),
)
continue

if result is False:
continue

if current_context.get("number_of_tries", 0) < 2:
# reset slot value that was initially filled with an invalid value
events.append(SlotSet(top_frame.collect_information, None))

if utterance is None:
structlogger.debug(
"collect.rejection.missing.utter",
predicate=predicate,
document=document,
)
break

message = await nlg.generate(
utterance,
tracker,
output_channel.name(),
)
if message is None:
structlogger.debug(
"collect.rejection.failed.finding.utter",
utterance=utterance,
)
else:
message["utter_action"] = utterance
events.append(create_bot_utterance(message))
return events

return events
30 changes: 10 additions & 20 deletions rasa/core/policies/flow_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def predict_action_probabilities(
domain: The model's domain.
rule_only_data: Slots and loops which are specific to rules and hence
should be ignored by this policy.
flows: The flows to use.
**kwargs: Depending on the specified `needs` section and the resulting
graph structure the policy can use different input to make predictions.
Expand Down Expand Up @@ -208,7 +209,7 @@ def _create_prediction_result(
domain: The model's domain.
score: The score of the predicted action.
Resturns:
Returns:
The prediction result where the score is used for one hot encoding.
"""
result = self._default_predictions(domain)
Expand Down Expand Up @@ -372,7 +373,7 @@ def render_template_variables(text: str, context: Dict[Text, Any]) -> str:
return Template(text).render(context)

def _slot_for_collect_information(self, collect_information: Text) -> Slot:
"""Find the slot for a collect information."""
"""Find the slot for the collect information step."""
for slot in self.domain.slots:
if slot.name == collect_information:
return slot
Expand Down Expand Up @@ -416,7 +417,6 @@ def advance_flows(self, tracker: DialogueStateTracker) -> ActionPrediction:
Args:
tracker: The tracker to get the next action for.
domain: The domain to get the next action for.
Returns:
The predicted action and the events to run.
Expand Down Expand Up @@ -457,7 +457,6 @@ def _select_next_action(
Args:
tracker: The tracker to get the next action for.
domain: The domain to get the next action for.
Returns:
The next action to execute, the events that should be applied to the
Expand Down Expand Up @@ -554,7 +553,7 @@ def _run_step(
if isinstance(step, CollectInformationFlowStep):
structlogger.debug("flow.step.run.collect_information")
self.trigger_pattern_ask_collect_information(
step.collect_information, step.validation
step.collect_information, step.rejections
)

# reset the slot if its already filled and the collect information shouldn't
Expand Down Expand Up @@ -583,7 +582,7 @@ def _run_step(
# the question has been asked
top_frame = self.dialogue_stack.top()
if isinstance(top_frame, CollectInformationPatternFlowStackFrame):
top_frame.number_of_retries = top_frame.number_of_retries + 1
top_frame.number_of_tries = top_frame.number_of_tries + 1

return PauseFlowReturnPrediction(ActionPrediction(action_name, 1.0))
else:
Expand Down Expand Up @@ -693,27 +692,18 @@ def trigger_pattern_completed(self, current_frame: DialogueStackFrame) -> None:
def trigger_pattern_ask_collect_information(
self,
collect_information: str,
validation: Optional[Dict[Text, Any]],
rejections: Optional[List[Dict[Text, Any]]],
) -> None:
context = self.dialogue_stack.current_context().copy()
number_of_retries = context.get("number_of_retries", 0)

if validation and "valid_message" not in validation:
validation["valid_message"] = "null"

default_validation = {
"condition": "true",
"valid_message": "null",
"invalid_message": "null",
}
number_of_tries = context.get("number_of_tries", 0)

step_validation = validation if validation else default_validation
slot_value_rejections = rejections if rejections else []

self.dialogue_stack.push(
CollectInformationPatternFlowStackFrame(
collect_information=collect_information,
number_of_retries=number_of_retries,
validation=step_validation,
number_of_tries=number_of_tries,
rejections=slot_value_rejections,
)
)

Expand Down
15 changes: 9 additions & 6 deletions rasa/dialogue_understanding/patterns/collect_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ class CollectInformationPatternFlowStackFrame(PatternFlowStackFrame):
collect_information: str = ""
"""The information that should be collected from the user.
this corresponds to the slot that will be filled."""
number_of_retries: int = 0
"""The number of times the question is being re-asked to fill the slot."""
validation: Optional[Dict[str, Any]] = None
"""The predicate validation that should be applied to the collected information."""
number_of_tries: int = 0
"""The number of times the question is being asked to fill the slot."""
rejections: Optional[List[Dict[str, Any]]] = None
"""The predicate check that should be applied to the collected information.
If a predicate check fails, its `utter` action indicated under rejections
will be executed.
"""

@classmethod
def type(cls) -> str:
Expand All @@ -44,8 +47,8 @@ def from_dict(data: Dict[str, Any]) -> CollectInformationPatternFlowStackFrame:
data["frame_id"],
step_id=data["step_id"],
collect_information=data["collect_information"],
number_of_retries=data.get("number_of_retries", 0),
validation=data.get("validation"),
number_of_tries=data.get("number_of_tries", 0),
rejections=data.get("rejections"),
)

def context_as_dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,12 @@ flows:
- id: "start"
action: action_extract_slots
next:
- if: "{{context.collect_information}} is null and {{context.number_of_retries}} < 2"
- if: "{{context.collect_information}} is null and {{context.number_of_tries}} < 2"
then: "ask_collect_information"
- if: "{{context.validation.condition}}"
then: "valid_message"
- else: "invalid_message"
- id: "valid_message"
action: "{{context.validation.valid_message}}"
- else: "evaluate_predicate_rejections"
- id: "evaluate_predicate_rejections"
action: action_evaluate_predicate_rejections
next: "validate"
- id: "invalid_message"
action: "{{context.validation.invalid_message}}"
next:
- if: "{{context.number_of_retries}} < 2"
then: "ask_collect_information"
- else: "done"
- id: "validate"
action: validate_{{context.collect_information}}
next:
Expand Down
2 changes: 2 additions & 0 deletions rasa/shared/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ACTION_CANCEL_FLOW = "action_cancel_flow"
ACTION_CLARIFY_FLOWS = "action_clarify_flows"
ACTION_CORRECT_FLOW_SLOT = "action_correct_flow_slot"
ACTION_EVALUATE_PREDICATE_REJECTIONS_NAME = "action_evaluate_predicate_rejections"


DEFAULT_ACTION_NAMES = [
Expand All @@ -60,6 +61,7 @@
ACTION_CANCEL_FLOW,
ACTION_CORRECT_FLOW_SLOT,
ACTION_CLARIFY_FLOWS,
ACTION_EVALUATE_PREDICATE_REJECTIONS_NAME,
]

ACTION_SHOULD_SEND_DOMAIN = "send_domain"
Expand Down
6 changes: 3 additions & 3 deletions rasa/shared/core/flows/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ class CollectInformationFlowStep(FlowStep):
"""Whether to always ask the question even if the slot is already filled."""
scope: CollectInformationScope = CollectInformationScope.FLOW
"""how the question is scoped, determines when to reset its value."""
validation: Optional[Dict[Text, Any]] = None
rejections: Optional[List[Dict[Text, Any]]] = None
"""how the slot value is validated using predicate evaluation."""

@classmethod
Expand All @@ -1009,7 +1009,7 @@ def from_json(cls, flow_step_config: Dict[Text, Any]) -> CollectInformationFlowS
collect_information=flow_step_config.get("collect_information", ""),
ask_before_filling=flow_step_config.get("ask_before_filling", False),
scope=CollectInformationScope.from_str(flow_step_config.get("scope")),
validation=flow_step_config.get("validation"),
rejections=flow_step_config.get("rejections"),
**base.__dict__,
)

Expand All @@ -1023,7 +1023,7 @@ def as_json(self) -> Dict[Text, Any]:
dump["collect_information"] = self.collect_information
dump["ask_before_filling"] = self.ask_before_filling
dump["scope"] = self.scope.value
dump["validation"] = self.validation
dump["rejections"] = self.rejections

return dump

Expand Down

0 comments on commit 07251a3

Please sign in to comment.