Configure prompts through files - [ENG 510] #12859

Merged · 8 commits · Oct 10, 2023
1 change: 1 addition & 0 deletions data/test_prompt_templates/test_prompt.jinja2
@@ -0,0 +1 @@
+This is a test prompt.
11 changes: 6 additions & 5 deletions rasa/dialogue_understanding/generator/llm_command_generator.py
@@ -38,6 +38,7 @@
 from rasa.shared.nlu.training_data.training_data import TrainingData
 from rasa.shared.utils.llm import (
     DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
+    get_prompt_template,
     llm_factory,
     tracker_as_readable_transcript,
     sanitize_message_for_prompt,
@@ -70,10 +71,7 @@ class LLMCommandGenerator(GraphComponent, CommandGenerator):
     @staticmethod
     def get_default_config() -> Dict[str, Any]:
         """The component's default config (see parent class for full docstring)."""
-        return {
-            "prompt": DEFAULT_COMMAND_PROMPT_TEMPLATE,
-            LLM_CONFIG_KEY: None,
-        }
+        return {"prompt": None, LLM_CONFIG_KEY: None}

     def __init__(
         self,
@@ -82,7 +80,10 @@ def __init__(
         resource: Resource,
     ) -> None:
         self.config = {**self.get_default_config(), **config}
-        self.prompt_template = self.config["prompt"]
+        self.prompt_template = get_prompt_template(
+            config.get("prompt"),
+            DEFAULT_COMMAND_PROMPT_TEMPLATE,
+        )
         self._model_storage = model_storage
         self._resource = resource

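The net effect for users of the component: the "prompt" config key now takes a path to a Jinja2 file instead of the raw template string, and omitting it falls back to the built-in default. A minimal sketch of how the component would now be instantiated (the file path is hypothetical; model_storage and resource are the usual graph-provided objects):

    from rasa.dialogue_understanding.generator.llm_command_generator import (
        LLMCommandGenerator,
    )

    # The prompt template is read from the configured file at construction time ...
    generator = LLMCommandGenerator(
        {"prompt": "prompts/command_prompt.jinja2"},  # hypothetical path
        model_storage,
        resource,
    )

    # ... while an empty config keeps the built-in DEFAULT_COMMAND_PROMPT_TEMPLATE.
    default_generator = LLMCommandGenerator({}, model_storage, resource)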
20 changes: 20 additions & 0 deletions rasa/shared/utils/llm.py
@@ -10,6 +10,7 @@
 from rasa.shared.core.trackers import DialogueStateTracker
 from rasa.shared.core.events import BotUttered, UserUttered
 from rasa.shared.engine.caching import get_local_cache_location
+import rasa.shared.utils.io


 structlogger = structlog.get_logger()
@@ -208,3 +209,22 @@ def embedder_factory(
         return embeddings_cls(**parameters)
     else:
         raise ValueError(f"Unsupported embeddings type '{typ}'")
+
+
+def get_prompt_template(
+    jinja_file_path: Optional[Text], default_prompt_template: Text
+) -> Text:
+    """Returns the prompt template.
+
+    Args:
+        jinja_file_path: the path to the jinja file
+        default_prompt_template: the default prompt template
+
+    Returns:
+        The prompt template.
+    """
+    return (
+        rasa.shared.utils.io.read_file(jinja_file_path)
+        if jinja_file_path is not None
+        else default_prompt_template
+    )
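The new helper simply chooses between reading the configured file and returning the default string. A small usage sketch, reusing the test template added in this PR (the default string here is a shortened placeholder):

    from rasa.shared.utils.llm import get_prompt_template

    default_template = "Your task is to analyze the current conversation ..."  # placeholder default

    # With a path, the file contents become the template ...
    custom = get_prompt_template(
        "data/test_prompt_templates/test_prompt.jinja2", default_template
    )

    # ... and with None the default is returned unchanged.
    fallback = get_prompt_template(None, default_template)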
41 changes: 41 additions & 0 deletions tests/cdu/generator/test_llm_command_generator.py
@@ -0,0 +1,41 @@
+import uuid
+
+import pytest
+from _pytest.tmpdir import TempPathFactory
+
+from rasa.dialogue_understanding.generator.llm_command_generator import (
+    LLMCommandGenerator,
+)
+from rasa.engine.storage.local_model_storage import LocalModelStorage
+from rasa.engine.storage.resource import Resource
+from rasa.engine.storage.storage import ModelStorage
+
+
+@pytest.fixture(scope="session")
+def resource() -> Resource:
+    return Resource(uuid.uuid4().hex)
+
+
+@pytest.fixture(scope="session")
+def model_storage(tmp_path_factory: TempPathFactory) -> ModelStorage:
+    return LocalModelStorage(tmp_path_factory.mktemp(uuid.uuid4().hex))
+
+
+async def test_llm_command_generator_prompt_init_custom(
+    model_storage: ModelStorage, resource: Resource
+) -> None:
+    generator = LLMCommandGenerator(
+        {"prompt": "data/test_prompt_templates/test_prompt.jinja2"},
+        model_storage,
+        resource,
+    )
+    assert generator.prompt_template.startswith("This is a test prompt.")
+
+
+async def test_llm_command_generator_prompt_init_default(
+    model_storage: ModelStorage, resource: Resource
+) -> None:
+    generator = LLMCommandGenerator({}, model_storage, resource)
+    assert generator.prompt_template.startswith(
+        "Your task is to analyze the current conversation"
+    )