From 192733ed48a1530cdb848e95baae25f3a82d0408 Mon Sep 17 00:00:00 2001
From: Varun Shankar S
Date: Tue, 26 Sep 2023 16:50:19 +0200
Subject: [PATCH] add tests

---
 data/test_prompt_templates/test_prompt.jinja2 |  1 +
 .../generator/llm_command_generator.py        |  2 +-
 .../generator/test_llm_command_generator.py   | 42 +++++++++++++++++++
 3 files changed, 44 insertions(+), 1 deletion(-)
 create mode 100644 data/test_prompt_templates/test_prompt.jinja2
 create mode 100644 tests/cdu/generator/test_llm_command_generator.py

diff --git a/data/test_prompt_templates/test_prompt.jinja2 b/data/test_prompt_templates/test_prompt.jinja2
new file mode 100644
index 000000000000..503ea5757402
--- /dev/null
+++ b/data/test_prompt_templates/test_prompt.jinja2
@@ -0,0 +1 @@
+This is a test prompt.
\ No newline at end of file
diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py
index 6a36cfbb507c..b2f092729c77 100644
--- a/rasa/dialogue_understanding/generator/llm_command_generator.py
+++ b/rasa/dialogue_understanding/generator/llm_command_generator.py
@@ -85,7 +85,7 @@ def __init__(
         self.config = {**self.get_default_config(), **config}
         self.prompt_template = get_prompt_template(
             DEFAULT_COMMAND_PROMPT_TEMPLATE,
-            self.config.get("prompt"),
+            config.get("prompt"),
         )
         self._model_storage = model_storage
         self._resource = resource
diff --git a/tests/cdu/generator/test_llm_command_generator.py b/tests/cdu/generator/test_llm_command_generator.py
new file mode 100644
index 000000000000..67ad8618f937
--- /dev/null
+++ b/tests/cdu/generator/test_llm_command_generator.py
@@ -0,0 +1,42 @@
+import uuid
+from typing import Dict
+
+import pytest
+from _pytest.tmpdir import TempPathFactory
+
+from rasa.dialogue_understanding.generator.llm_command_generator import (
+    LLMCommandGenerator,
+)
+from rasa.engine.storage.local_model_storage import LocalModelStorage
+from rasa.engine.storage.resource import Resource
+from rasa.engine.storage.storage import ModelStorage
+
+
+@pytest.fixture(scope="session")
+def resource() -> Resource:
+    return Resource(uuid.uuid4().hex)
+
+
+@pytest.fixture(scope="session")
+def model_storage(tmp_path_factory: TempPathFactory) -> ModelStorage:
+    return LocalModelStorage(tmp_path_factory.mktemp(uuid.uuid4().hex))
+
+
+@pytest.mark.parametrize(
+    "config, expected",
+    [
+        (
+            {"prompt": "data/test_prompt_templates/test_prompt.jinja2"},
+            "This is a test prompt.",
+        ),
+        (
+            {},
+            "Your task is to analyze the current conversation",
+        ),
+    ],
+)
+async def test_llm_command_generator_prompt_initialisation(
+    model_storage: ModelStorage, resource: Resource, config: Dict, expected: str
+):
+    generator = LLMCommandGenerator(config, model_storage, resource)
+    assert generator.prompt_template.startswith(expected)
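
Note on the behaviour the new tests pin down (illustration only, not part of the
patch to apply): with {"prompt": "data/test_prompt_templates/test_prompt.jinja2"}
the generator's prompt template should come from that file, and with an empty
config it should fall back to the default command prompt. The helper below is a
hypothetical stand-in, not Rasa's get_prompt_template; its argument order merely
mirrors the call site in the diff (default template first, optional override
path second).

    from typing import Optional


    def resolve_prompt_template(
        default_template: str, prompt_path: Optional[str]
    ) -> str:
        """Hypothetical helper sketching the fallback the tests assert."""
        if prompt_path is None:
            # No "prompt" key in the user-supplied config: keep the default.
            return default_template
        # A "prompt" entry points at a template file, e.g. the .jinja2 fixture.
        with open(prompt_path, "r", encoding="utf-8") as f:
            return f.read()


    # {"prompt": "data/test_prompt_templates/test_prompt.jinja2"}
    #   -> starts with "This is a test prompt."
    # {} -> starts with "Your task is to analyze the current conversation"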