Skip to content

Commit

Permalink
improve handling categorical slot rejection
Browse files Browse the repository at this point in the history
  • Loading branch information
varunshankar committed Oct 16, 2023
1 parent edfbe68 commit e36c1f5
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 125 deletions.
97 changes: 56 additions & 41 deletions rasa/core/actions/action_run_slot_rejections.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
from rasa.shared.core.constants import ACTION_RUN_SLOT_REJECTIONS_NAME
from rasa.shared.core.events import Event, SlotSet
from rasa.shared.core.slots import CategoricalSlot

if TYPE_CHECKING:
from rasa.core.nlg import NaturalLanguageGenerator
Expand Down Expand Up @@ -47,9 +48,6 @@ async def run(
if not isinstance(top_frame, CollectInformationPatternFlowStackFrame):
return []

if not top_frame.rejections:
return []

slot_name = top_frame.collect
slot_instance = tracker.slots.get(slot_name)
if slot_instance and not slot_instance.has_been_set:
Expand All @@ -65,46 +63,63 @@ async def run(

slot_value = tracker.get_slot(slot_name)

current_context = dialogue_stack.current_context()
current_context[slot_name] = slot_value

structlogger.debug("run.predicate.context", context=current_context)
document = current_context.copy()

for rejection in top_frame.rejections:
condition = rejection.if_
utterance = rejection.utter

try:
rendered_template = Template(condition).render(current_context)
predicate = Predicate(rendered_template)
violation = predicate.evaluate(document)
if top_frame.rejections:
# run validation when rejections are defined
current_context = dialogue_stack.current_context()
current_context[slot_name] = slot_value

structlogger.debug("run.predicate.context", context=current_context)
document = current_context.copy()

for rejection in top_frame.rejections:
condition = rejection.if_
utterance = rejection.utter

try:
rendered_template = Template(condition).render(current_context)
predicate = Predicate(rendered_template)
violation = predicate.evaluate(document)
structlogger.debug(
"run.predicate.result",
predicate=predicate.description(),
violation=violation,
)
except (TypeError, Exception) as e:
structlogger.error(
"run.predicate.error",
predicate=condition,
document=document,
error=str(e),
)
violation = True
internal_error = True

if violation:
break

if not violation:
return []

# reset slot value that was initially filled with an invalid value
events.append(SlotSet(top_frame.collect, None))

if internal_error:
utterance = "utter_internal_error_rasa"
else:
# run implicit validations.
if (
isinstance(slot_instance, CategoricalSlot)
and slot_value not in slot_instance.values
):
# only fill categorical slots with values that are present in the domain
structlogger.debug(
"run.predicate.result",
predicate=predicate.description(),
violation=violation,
"run.rejection.categorical_slot_value_not_in_domain",
command=self,
)
except (TypeError, Exception) as e:
structlogger.error(
"run.predicate.error",
predicate=condition,
document=document,
error=str(e),
)
violation = True
internal_error = True

if violation:
break

if not violation:
return []

# reset slot value that was initially filled with an invalid value
events.append(SlotSet(top_frame.collect, None))

if internal_error:
utterance = "utter_internal_error_rasa"
events.append(SlotSet(slot_name, None))
utterance = "utter_default_slot_rejection"
else:
return []

if not isinstance(utterance, str):
structlogger.error(
Expand Down
10 changes: 0 additions & 10 deletions rasa/dialogue_understanding/commands/set_slot_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from rasa.dialogue_understanding.stack.utils import filled_slots_for_active_flow
from rasa.shared.core.events import Event, SlotSet
from rasa.shared.core.flows.flow import FlowsList
from rasa.shared.core.slots import CategoricalSlot
from rasa.shared.core.trackers import DialogueStateTracker

structlogger = structlog.get_logger()
Expand Down Expand Up @@ -76,15 +75,6 @@ def run_command_on_tracker(
"command_executor.skip_command.slot_not_asked_for", command=self
)
return []
if self.name in tracker.slots:
slot = tracker.slots[self.name]
if isinstance(slot, CategoricalSlot) and self.value not in slot.values:
# only fill categorical slots with values that are present in the domain
structlogger.debug(
"command_executor.skip_command.categorical_slot_value_not_in_domain",
command=self,
)
return []

structlogger.debug("command_executor.set_slot", command=self)
return [SlotSet(self.name, self.value)]
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ responses:
metadata:
rephrase: True
template: jinja


utter_default_slot_rejection:
- text: Sorry, you requested an option that is not valid. Please select one of the available options.
metadata:
rephrase: True

slots:
confirm_correction:
Expand Down
55 changes: 0 additions & 55 deletions tests/cdu/commands/test_set_slot_command.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
from rasa.dialogue_understanding.commands.set_slot_command import SetSlotCommand
from rasa.shared.core.constants import DIALOGUE_STACK_SLOT
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import SlotSet
from rasa.shared.core.flows.flow import FlowsList
from rasa.shared.core.trackers import DialogueStateTracker
Expand Down Expand Up @@ -227,57 +226,3 @@ def test_run_command_skips_setting_unknown_slot():
command = SetSlotCommand(name="unknown", value="unknown")

assert command.run_command_on_tracker(tracker, all_flows, tracker) == []


def test_run_command_slot_set_categorical_slot_values():
all_flows = flows_from_str(
"""
flows:
my_flow:
steps:
- id: collect_foo
collect: foo
next: collect_bar
- id: collect_bar
collect: bar
"""
)
domain = Domain.from_yaml(
"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
slots:
bar:
type: categorical
values:
- low
- high
mappings:
- type: custom
"""
)
tracker = DialogueStateTracker.from_events(
"test",
evts=[
SlotSet(
DIALOGUE_STACK_SLOT,
[
{
"type": "flow",
"flow_id": "my_flow",
"step_id": "collect_bar",
"frame_id": "some-frame-id",
},
],
),
],
slots=domain.slots,
)
# set the slot to a value that is not in the domain
command = SetSlotCommand(name="bar", value="medium")
assert command.run_command_on_tracker(tracker, all_flows, tracker) == []

# set the slot to a value present in the domain
command = SetSlotCommand(name="bar", value="low")
assert command.run_command_on_tracker(tracker, all_flows, tracker) == [
SlotSet("bar", "low")
]
Loading

0 comments on commit e36c1f5

Please sign in to comment.