Skip to content

Commit

Permalink
update llm command generator
Browse files Browse the repository at this point in the history
  • Loading branch information
varunshankar committed Sep 24, 2023
1 parent ed71677 commit c1d1325
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from rasa.shared.nlu.training_data.training_data import TrainingData
from rasa.shared.utils.llm import (
DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
get_prompt_template,
llm_factory,
tracker_as_readable_transcript,
sanitize_message_for_prompt,
Expand Down Expand Up @@ -82,7 +83,10 @@ def __init__(
resource: Resource,
) -> None:
self.config = {**self.get_default_config(), **config}
self.prompt_template = self.config["prompt"]
self.prompt_template = get_prompt_template(
DEFAULT_COMMAND_PROMPT_TEMPLATE,
self.config.get("prompt"),
)
self._model_storage = model_storage
self._resource = resource

Expand Down
20 changes: 20 additions & 0 deletions rasa/shared/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.events import BotUttered, UserUttered
from rasa.shared.engine.caching import get_local_cache_location
import rasa.shared.utils.io


structlogger = structlog.get_logger()
Expand Down Expand Up @@ -208,3 +209,22 @@ def embedder_factory(
return embeddings_cls(**parameters)
else:
raise ValueError(f"Unsupported embeddings type '{typ}'")


def get_prompt_template(
default_prompt_template: Text, jinja_file_path: Optional[Text]
) -> Text:
"""Returns the prompt template.
Args:
default_prompt_template: the default prompt template
jinja_file_path: the path to the jinja file
Returns:
The prompt template.
"""
return (
rasa.shared.utils.io.read_file(jinja_file_path)
if jinja_file_path is not None
else default_prompt_template
)

0 comments on commit c1d1325

Please sign in to comment.