Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
varunshankar committed Sep 26, 2023
1 parent c1d1325 commit 192733e
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
1 change: 1 addition & 0 deletions data/test_prompt_templates/test_prompt.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This is a test prompt.
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
self.config = {**self.get_default_config(), **config}
self.prompt_template = get_prompt_template(
DEFAULT_COMMAND_PROMPT_TEMPLATE,
self.config.get("prompt"),
config.get("prompt"),
)
self._model_storage = model_storage
self._resource = resource
Expand Down
42 changes: 42 additions & 0 deletions tests/cdu/generator/test_llm_command_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import uuid
from typing import Dict

import pytest
from _pytest.tmpdir import TempPathFactory

from rasa.dialogue_understanding.generator.llm_command_generator import (
LLMCommandGenerator,
)
from rasa.engine.storage.local_model_storage import LocalModelStorage
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage


@pytest.fixture(scope="session")
def resource() -> Resource:
return Resource(uuid.uuid4().hex)


@pytest.fixture(scope="session")
def model_storage(tmp_path_factory: TempPathFactory) -> ModelStorage:
return LocalModelStorage(tmp_path_factory.mktemp(uuid.uuid4().hex))


@pytest.mark.parametrize(
"config, expected",
[
(
{"prompt": "data/test_prompt_templates/test_prompt.jinja2"},
"This is a test prompt.",
),
(
{},
"Your task is to analyze the current conversation",
),
],
)
async def test_llm_command_generator_prompt_initialisation(
model_storage: ModelStorage, resource: Resource, config: Dict, expected: str
):
generator = LLMCommandGenerator(config, model_storage, resource)
assert generator.prompt_template.startswith(expected)

0 comments on commit 192733e

Please sign in to comment.