From 1724fc75b200e199f51e48180da5463c2b7850ef Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Tue, 17 Oct 2023 11:57:02 +0100 Subject: [PATCH] handle edge case --- rasa/core/policies/flow_policy.py | 33 +++++++++++++------ tests/core/policies/test_flow_policy.py | 42 +++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 9 deletions(-) diff --git a/rasa/core/policies/flow_policy.py b/rasa/core/policies/flow_policy.py index 2ca36139f86b..31c6ca93521c 100644 --- a/rasa/core/policies/flow_policy.py +++ b/rasa/core/policies/flow_policy.py @@ -568,16 +568,31 @@ def _reset_slot( events.append(SlotSet(slot_name, initial_value)) events: List[Event] = [] + + not_resettable_slot_names = set() + for step in current_flow.steps: - # reset all slots scoped to the flow - if ( - isinstance(step, CollectInformationFlowStep) - and step.reset_after_flow_ends - ): - _reset_slot(step.collect, tracker) - elif isinstance(step, SetSlotsFlowStep): - for slot in step.slots: - _reset_slot(slot["key"], tracker) + if isinstance(step, CollectInformationFlowStep): + # reset all slots scoped to the flow + if step.reset_after_flow_ends: + _reset_slot(step.collect, tracker) + else: + not_resettable_slot_names.add(step.collect) + + # slots set by the set slots step should be reset after the flow ends + # unless they are also used in a collect step where `reset_after_flow_ends` + # is set to `False` + resettable_set_slots = [ + slot["key"] + for step in current_flow.steps + if isinstance(step, SetSlotsFlowStep) + for slot in step.slots + if slot["key"] not in not_resettable_slot_names + ] + + for name in resettable_set_slots: + _reset_slot(name, tracker) + return events def run_step( diff --git a/tests/core/policies/test_flow_policy.py b/tests/core/policies/test_flow_policy.py index a495e1430a5d..3d78c7d081f8 100644 --- a/tests/core/policies/test_flow_policy.py +++ b/tests/core/policies/test_flow_policy.py @@ -371,3 +371,45 @@ def test_flow_policy_resets_all_slots_after_flow_ends() -> None: SlotSet("foo", None), SlotSet("other_slot", None), ] + + +def test_flow_policy_set_slots_inherit_reset_from_collect_step() -> None: + """Test that `reset_after_flow_ends` is inherited from the collect step.""" + slot_name = "my_slot" + flows = flows_from_str( + f""" + flows: + foo_flow: + steps: + - id: "1" + collect: {slot_name} + reset_after_flow_ends: false + - id: "2" + set_slots: + - foo: bar + - {slot_name}: my_value + - id: "3" + action: action_listen + """ + ) + tracker = DialogueStateTracker.from_events( + "test123", + [ + SlotSet("my_slot", "my_value"), + SlotSet("foo", "bar"), + ActionExecuted("action_listen"), + ], + slots=[ + TextSlot("my_slot", mappings=[], initial_value="initial_value"), + TextSlot("foo", mappings=[]), + ], + ) + + domain = Domain.empty() + executor = FlowExecutor.from_tracker(tracker, flows, domain) + + current_flow = flows.flow_by_id("foo_flow") + events = executor._reset_scoped_slots(current_flow, tracker) + assert events == [ + SlotSet("foo", None), + ]