Skip to content

Commit

Permalink
Merge branch 'dm2' into internal
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Sep 22, 2023
2 parents ab7ecf0 + ed71677 commit bb3aea8
Show file tree
Hide file tree
Showing 17 changed files with 843 additions and 38 deletions.
4 changes: 2 additions & 2 deletions rasa/cli/project_templates/tutorial/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from rasa_sdk.events import SlotSet


class ActionSufficientFunds(Action):
class ActionCheckSufficientFunds(Action):
def name(self) -> Text:
return "action_sufficient_funds"
return "action_check_sufficient_funds"

def run(
self,
Expand Down
2 changes: 1 addition & 1 deletion rasa/cli/project_templates/tutorial/domain.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ responses:
- text: "How much money would you like to send?"

utter_transfer_complete:
- text: "All done. ${amount} has been sent to {recipient}."
- text: "All done. {amount} has been sent to {recipient}."
4 changes: 4 additions & 0 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
from rasa.dialogue_understanding.patterns.correction import ActionCorrectFlowSlot
from rasa.dialogue_understanding.patterns.cancel import ActionCancelFlow
from rasa.dialogue_understanding.patterns.clarify import ActionClarifyFlows
from rasa.core.actions.action_run_slot_rejections import (
ActionRunSlotRejections,
)

return [
ActionListen(),
Expand All @@ -118,6 +121,7 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
ActionCancelFlow(),
ActionCorrectFlowSlot(),
ActionClarifyFlows(),
ActionRunSlotRejections(),
]


Expand Down
131 changes: 131 additions & 0 deletions rasa/core/actions/action_run_slot_rejections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Text

import structlog
from jinja2 import Template
from pypred import Predicate

from rasa.core.actions.action import Action, create_bot_utterance
from rasa.dialogue_understanding.patterns.collect_information import (
CollectInformationPatternFlowStackFrame,
)
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
from rasa.shared.core.constants import ACTION_RUN_SLOT_REJECTIONS_NAME
from rasa.shared.core.events import Event, SlotSet

if TYPE_CHECKING:
from rasa.core.nlg import NaturalLanguageGenerator
from rasa.core.channels.channel import OutputChannel
from rasa.shared.core.domain import Domain
from rasa.shared.core.trackers import DialogueStateTracker

structlogger = structlog.get_logger()


class ActionRunSlotRejections(Action):
"""Action which evaluates the predicate checks under rejections."""

def name(self) -> Text:
"""Return the name of the action."""
return ACTION_RUN_SLOT_REJECTIONS_NAME

async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Run the predicate checks."""
events: List[Event] = []
violation = False
utterance = None
internal_error = False

dialogue_stack = DialogueStack.from_tracker(tracker)
top_frame = dialogue_stack.top()
if not isinstance(top_frame, CollectInformationPatternFlowStackFrame):
return []

if not top_frame.rejections:
return []

slot_name = top_frame.collect_information
slot_instance = tracker.slots.get(slot_name)
if slot_instance and not slot_instance.has_been_set:
# this is the first time the assistant asks for the slot value,
# therefore we skip the predicate validation because the slot
# value has not been provided
structlogger.debug(
"first.collect.slot.not.set",
slot_name=slot_name,
slot_value=slot_instance.value,
)
return []

slot_value = tracker.get_slot(slot_name)

current_context = dialogue_stack.current_context()
current_context[slot_name] = slot_value

structlogger.debug("run.predicate.context", context=current_context)
document = current_context.copy()

for rejection in top_frame.rejections:
condition = rejection.if_
utterance = rejection.utter

try:
rendered_template = Template(condition).render(current_context)
predicate = Predicate(rendered_template)
violation = predicate.evaluate(document)
structlogger.debug(
"run.predicate.result",
predicate=predicate.description(),
violation=violation,
)
except (TypeError, Exception) as e:
structlogger.error(
"run.predicate.error",
predicate=condition,
document=document,
error=str(e),
)
violation = True
internal_error = True

if violation:
break

if not violation:
return []

# reset slot value that was initially filled with an invalid value
events.append(SlotSet(top_frame.collect_information, None))

if internal_error:
utterance = "utter_internal_error_rasa"

if not isinstance(utterance, str):
structlogger.error(
"run.rejection.missing.utter",
utterance=utterance,
)
return events

message = await nlg.generate(
utterance,
tracker,
output_channel.name(),
)

if message is None:
structlogger.error(
"run.rejection.failed.finding.utter",
utterance=utterance,
)
else:
message["utter_action"] = utterance
events.append(create_bot_utterance(message))

return events
36 changes: 25 additions & 11 deletions rasa/core/policies/flow_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
IfFlowLink,
EntryPromptFlowStep,
CollectInformationScope,
SlotRejection,
StepThatCanStartAFlow,
UserMessageStep,
LinkFlowStep,
Expand Down Expand Up @@ -171,6 +172,7 @@ def predict_action_probabilities(
domain: The model's domain.
rule_only_data: Slots and loops which are specific to rules and hence
should be ignored by this policy.
flows: The flows to use.
**kwargs: Depending on the specified `needs` section and the resulting
graph structure the policy can use different input to make predictions.
Expand Down Expand Up @@ -208,7 +210,7 @@ def _create_prediction_result(
domain: The model's domain.
score: The score of the predicted action.
Resturns:
Returns:
The prediction result where the score is used for one hot encoding.
"""
result = self._default_predictions(domain)
Expand Down Expand Up @@ -242,8 +244,9 @@ def __init__(
"""Initializes the `FlowExecutor`.
Args:
dialogue_stack_frame: State of the flow.
dialogue_stack: State of the flow.
all_flows: All flows.
domain: The domain.
"""
self.dialogue_stack = dialogue_stack
self.all_flows = all_flows
Expand All @@ -258,6 +261,7 @@ def from_tracker(
Args:
tracker: The tracker to create the `FlowExecutor` from.
flows: The flows to use.
domain: The domain to use.
Returns:
The created `FlowExecutor`.
Expand All @@ -270,7 +274,6 @@ def find_startable_flow(self, tracker: DialogueStateTracker) -> Optional[Flow]:
Args:
tracker: The tracker containing the conversation history up to now.
flows: The flows to use.
Returns:
The predicted action and the events to run.
Expand All @@ -296,7 +299,7 @@ def is_condition_satisfied(
) -> bool:
"""Evaluate a predicate condition."""

# attach context to the predicate evaluation to allow coditions using it
# attach context to the predicate evaluation to allow conditions using it
context = {"context": DialogueStack.from_tracker(tracker).current_context()}
document: Dict[str, Any] = context.copy()
for slot in self.domain.slots:
Expand Down Expand Up @@ -371,7 +374,7 @@ def render_template_variables(text: str, context: Dict[Text, Any]) -> str:
return Template(text).render(context)

def _slot_for_collect_information(self, collect_information: Text) -> Slot:
"""Find the slot for a collect information."""
"""Find the slot for the collect information step."""
for slot in self.domain.slots:
if slot.name == collect_information:
return slot
Expand Down Expand Up @@ -415,7 +418,6 @@ def advance_flows(self, tracker: DialogueStateTracker) -> ActionPrediction:
Args:
tracker: The tracker to get the next action for.
domain: The domain to get the next action for.
Returns:
The predicted action and the events to run.
Expand Down Expand Up @@ -456,7 +458,6 @@ def _select_next_action(
Args:
tracker: The tracker to get the next action for.
domain: The domain to get the next action for.
Returns:
The next action to execute, the events that should be applied to the
Expand Down Expand Up @@ -552,11 +553,14 @@ def _run_step(
"""
if isinstance(step, CollectInformationFlowStep):
structlogger.debug("flow.step.run.collect_information")
self.trigger_pattern_ask_collect_information(step.collect_information)
self.trigger_pattern_ask_collect_information(
step.collect_information, step.rejections, step.utter
)

# reset the slot if its already filled and the collect infomation shouldn't
# reset the slot if its already filled and the collect information shouldn't
# be skipped
slot = tracker.slots.get(step.collect_information, None)

if slot and slot.has_been_set and step.ask_before_filling:
events = [SlotSet(step.collect_information, slot.initial_value)]
else:
Expand All @@ -567,8 +571,10 @@ def _run_step(
elif isinstance(step, ActionFlowStep):
if not step.action:
raise FlowException(f"Action not specified for step {step}")

context = {"context": self.dialogue_stack.current_context()}
action_name = self.render_template_variables(step.action, context)

if action_name in self.domain.action_names_or_texts:
structlogger.debug("flow.step.run.action", context=context)
return PauseFlowReturnPrediction(ActionPrediction(action_name, 1.0))
Expand Down Expand Up @@ -676,10 +682,18 @@ def trigger_pattern_completed(self, current_frame: DialogueStackFrame) -> None:
)
)

def trigger_pattern_ask_collect_information(self, collect_information: str) -> None:
def trigger_pattern_ask_collect_information(
self,
collect_information: str,
rejections: List[SlotRejection],
utter: str,
) -> None:
"""Trigger the pattern to ask for a slot value."""
self.dialogue_stack.push(
CollectInformationPatternFlowStackFrame(
collect_information=collect_information
collect_information=collect_information,
utter=utter,
rejections=rejections,
)
)

Expand Down
19 changes: 18 additions & 1 deletion rasa/dialogue_understanding/patterns/collect_information.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStackFrame
from rasa.shared.constants import RASA_DEFAULT_FLOW_PATTERN_PREFIX
from rasa.dialogue_understanding.stack.frames import PatternFlowStackFrame
from rasa.shared.core.flows.flow import SlotRejection

FLOW_PATTERN_COLLECT_INFORMATION = (
RASA_DEFAULT_FLOW_PATTERN_PREFIX + "ask_collect_information"
Expand All @@ -20,6 +21,14 @@ class CollectInformationPatternFlowStackFrame(PatternFlowStackFrame):
collect_information: str = ""
"""The information that should be collected from the user.
this corresponds to the slot that will be filled."""
utter: str = ""
"""The utter action that should be executed to ask the user for the
information."""
rejections: Optional[List[SlotRejection]] = None
"""The predicate check that should be applied to the collected information.
If a predicate check fails, its `utter` action indicated under rejections
will be executed.
"""

@classmethod
def type(cls) -> str:
Expand All @@ -36,10 +45,18 @@ def from_dict(data: Dict[str, Any]) -> CollectInformationPatternFlowStackFrame:
Returns:
The created `DialogueStackFrame`.
"""
rejections = data.get("rejections")
if rejections is not None:
rejections = [
SlotRejection.from_dict(rejection) for rejection in rejections
]

return CollectInformationPatternFlowStackFrame(
data["frame_id"],
step_id=data["step_id"],
collect_information=data["collect_information"],
utter=data["utter"],
rejections=rejections,
)

def context_as_dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ flows:
description: flow used to fill a slot
steps:
- id: "start"
action: action_extract_slots
action: action_run_slot_rejections
next: "validate"
- id: "validate"
action: validate_{{context.collect_information}}
Expand All @@ -119,7 +119,7 @@ flows:
then: "done"
- else: "ask_collect_information"
- id: "ask_collect_information"
action: utter_ask_{{context.collect_information}}
action: "{{context.utter}}"
next: "listen"
- id: "listen"
action: action_listen
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,14 @@ def as_dict(self) -> Dict[str, Any]:

def custom_asdict_factory(fields: List[Tuple[str, Any]]) -> Dict[str, Any]:
"""Converts enum values to their value."""

def rename_internal(field_name: str) -> str:
return field_name[:-1] if field_name.endswith("_") else field_name

return {
field: value.value if isinstance(value, Enum) else value
rename_internal(field): value.value
if isinstance(value, Enum)
else value
for field, value in fields
}

Expand Down
2 changes: 2 additions & 0 deletions rasa/shared/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ACTION_CANCEL_FLOW = "action_cancel_flow"
ACTION_CLARIFY_FLOWS = "action_clarify_flows"
ACTION_CORRECT_FLOW_SLOT = "action_correct_flow_slot"
ACTION_RUN_SLOT_REJECTIONS_NAME = "action_run_slot_rejections"


DEFAULT_ACTION_NAMES = [
Expand All @@ -60,6 +61,7 @@
ACTION_CANCEL_FLOW,
ACTION_CORRECT_FLOW_SLOT,
ACTION_CLARIFY_FLOWS,
ACTION_RUN_SLOT_REJECTIONS_NAME,
]

ACTION_SHOULD_SEND_DOMAIN = "send_domain"
Expand Down
Loading

0 comments on commit bb3aea8

Please sign in to comment.