From a498f88c997a8f39cd373d6909fb1777ef4e9ac6 Mon Sep 17 00:00:00 2001 From: Varun Shankar S Date: Wed, 18 Oct 2023 11:41:31 +0200 Subject: [PATCH] add tests --- tests/shared/utils/test_llm.py | 43 ++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/shared/utils/test_llm.py b/tests/shared/utils/test_llm.py index 44e00ca917c1..175408a16dd9 100644 --- a/tests/shared/utils/test_llm.py +++ b/tests/shared/utils/test_llm.py @@ -59,6 +59,49 @@ def test_tracker_as_readable_transcript_handles_tracker_with_events_and_max_turn assert tracker_as_readable_transcript(tracker, max_turns=1) == ("""AI: hi""") +def test_tracker_as_readable_transcript_and_discard_excess_turns_with_default_max_turns( + domain: Domain, +): + tracker = DialogueStateTracker(sender_id="test", slots=domain.slots) + tracker.update_with_events( + [ + UserUttered("A0"), + BotUttered("B1"), + UserUttered("C2"), + BotUttered("D3"), + UserUttered("E4"), + BotUttered("F5"), + UserUttered("G6"), + BotUttered("H7"), + UserUttered("I8"), + BotUttered("J9"), + UserUttered("K10"), + BotUttered("L11"), + UserUttered("M12"), + BotUttered("N13"), + UserUttered("O14"), + BotUttered("P15"), + UserUttered("Q16"), + BotUttered("R17"), + UserUttered("S18"), + BotUttered("T19"), + UserUttered("U20"), + BotUttered("V21"), + UserUttered("W22"), + BotUttered("X23"), + UserUttered("Y24"), + ], + domain, + ) + response = tracker_as_readable_transcript(tracker) + assert response == ( + """AI: F5\nUSER: G6\nAI: H7\nUSER: I8\nAI: J9\nUSER: K10\nAI: L11\n""" + """USER: M12\nAI: N13\nUSER: O14\nAI: P15\nUSER: Q16\nAI: R17\nUSER: S18\n""" + """AI: T19\nUSER: U20\nAI: V21\nUSER: W22\nAI: X23\nUSER: Y24""" + ) + assert response.count("\n") == 19 + + def test_sanitize_message_for_prompt_handles_none(): assert sanitize_message_for_prompt(None) == ""