Skip to content

Commit

Permalink
handle edge case
Browse files Browse the repository at this point in the history
  • Loading branch information
ancalita committed Oct 17, 2023
1 parent a90c952 commit 1724fc7
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 9 deletions.
33 changes: 24 additions & 9 deletions rasa/core/policies/flow_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,16 +568,31 @@ def _reset_slot(
events.append(SlotSet(slot_name, initial_value))

events: List[Event] = []

not_resettable_slot_names = set()

for step in current_flow.steps:
# reset all slots scoped to the flow
if (
isinstance(step, CollectInformationFlowStep)
and step.reset_after_flow_ends
):
_reset_slot(step.collect, tracker)
elif isinstance(step, SetSlotsFlowStep):
for slot in step.slots:
_reset_slot(slot["key"], tracker)
if isinstance(step, CollectInformationFlowStep):
# reset all slots scoped to the flow
if step.reset_after_flow_ends:
_reset_slot(step.collect, tracker)
else:
not_resettable_slot_names.add(step.collect)

# slots set by the set slots step should be reset after the flow ends
# unless they are also used in a collect step where `reset_after_flow_ends`
# is set to `False`
resettable_set_slots = [
slot["key"]
for step in current_flow.steps
if isinstance(step, SetSlotsFlowStep)
for slot in step.slots
if slot["key"] not in not_resettable_slot_names
]

for name in resettable_set_slots:
_reset_slot(name, tracker)

return events

def run_step(
Expand Down
42 changes: 42 additions & 0 deletions tests/core/policies/test_flow_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,45 @@ def test_flow_policy_resets_all_slots_after_flow_ends() -> None:
SlotSet("foo", None),
SlotSet("other_slot", None),
]


def test_flow_policy_set_slots_inherit_reset_from_collect_step() -> None:
"""Test that `reset_after_flow_ends` is inherited from the collect step."""
slot_name = "my_slot"
flows = flows_from_str(
f"""
flows:
foo_flow:
steps:
- id: "1"
collect: {slot_name}
reset_after_flow_ends: false
- id: "2"
set_slots:
- foo: bar
- {slot_name}: my_value
- id: "3"
action: action_listen
"""
)
tracker = DialogueStateTracker.from_events(
"test123",
[
SlotSet("my_slot", "my_value"),
SlotSet("foo", "bar"),
ActionExecuted("action_listen"),
],
slots=[
TextSlot("my_slot", mappings=[], initial_value="initial_value"),
TextSlot("foo", mappings=[]),
],
)

domain = Domain.empty()
executor = FlowExecutor.from_tracker(tracker, flows, domain)

current_flow = flows.flow_by_id("foo_flow")
events = executor._reset_scoped_slots(current_flow, tracker)
assert events == [
SlotSet("foo", None),
]

0 comments on commit 1724fc7

Please sign in to comment.