Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle categorical slots in command generator - [ENG 569] #12876

Closed
10 changes: 10 additions & 0 deletions rasa/dialogue_understanding/commands/set_slot_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rasa.dialogue_understanding.stack.utils import filled_slots_for_active_flow
from rasa.shared.core.events import Event, SlotSet
from rasa.shared.core.flows.flow import FlowsList
from rasa.shared.core.slots import CategoricalSlot
from rasa.shared.core.trackers import DialogueStateTracker

structlogger = structlog.get_logger()
Expand Down Expand Up @@ -75,6 +76,15 @@ def run_command_on_tracker(
"command_executor.skip_command.slot_not_asked_for", command=self
)
return []
if self.name in tracker.slots:
slot = tracker.slots[self.name]
if isinstance(slot, CategoricalSlot) and self.value not in slot.values:
varunshankar marked this conversation as resolved.
Show resolved Hide resolved
# only fill categorical slots with values that are present in the domain
structlogger.debug(
"command_executor.skip_command.categorical_slot_value_not_in_domain",
command=self,
)
return []

structlogger.debug("command_executor.set_slot", command=self)
return [SlotSet(self.name, self.value)]
55 changes: 55 additions & 0 deletions tests/cdu/commands/test_set_slot_command.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from rasa.dialogue_understanding.commands.set_slot_command import SetSlotCommand
from rasa.shared.core.constants import DIALOGUE_STACK_SLOT
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import SlotSet
from rasa.shared.core.flows.flow import FlowsList
from rasa.shared.core.trackers import DialogueStateTracker
Expand Down Expand Up @@ -226,3 +227,57 @@ def test_run_command_skips_setting_unknown_slot():
command = SetSlotCommand(name="unknown", value="unknown")

assert command.run_command_on_tracker(tracker, all_flows, tracker) == []


def test_run_command_slot_set_categorical_slot_values():
all_flows = flows_from_str(
"""
flows:
my_flow:
steps:
- id: collect_foo
collect: foo
next: collect_bar
- id: collect_bar
collect: bar
"""
)
domain = Domain.from_yaml(
"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
slots:
bar:
type: categorical
values:
- low
- high
mappings:
- type: custom
"""
)
tracker = DialogueStateTracker.from_events(
"test",
evts=[
SlotSet(
DIALOGUE_STACK_SLOT,
[
{
"type": "flow",
"flow_id": "my_flow",
"step_id": "collect_bar",
"frame_id": "some-frame-id",
},
],
),
],
slots=domain.slots,
)
# set the slot to a value that is not in the domain
command = SetSlotCommand(name="bar", value="medium")
assert command.run_command_on_tracker(tracker, all_flows, tracker) == []

# set the slot to a value present in the domain
command = SetSlotCommand(name="bar", value="low")
assert command.run_command_on_tracker(tracker, all_flows, tracker) == [
SlotSet("bar", "low")
]