
Commit

pre-load all prompt file contents instead of just storing a path and load on demand
Benjoyo committed Apr 27, 2024
1 parent 99180e6 commit 0f4c511
Showing 3 changed files with 31 additions and 52 deletions.
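
In effect, Prompt.from_file now globs for the default <name>.prompt file and any <name>.<llm>.prompt variants next to the calling module and reads them all at construction time, so a later format(llm_name) resolves the template from an in-memory dict instead of touching the filesystem. A minimal usage sketch of the behavior after this commit; the prompt name "extract", its files, and the template variables are hypothetical and not part of the commit:

from bpm_ai_core.prompt.prompt import Prompt

# Hypothetical files next to the calling module (assumed for illustration):
#   extract.prompt         - default template
#   extract.openai.prompt  - OpenAI-specific variant
# from_file() now reads both files once and stores them keyed by basename.
prompt = Prompt.from_file("extract", name="Alice", task="Summarize the text")

# Resolves "extract.openai.prompt" from the preloaded dict, falling back to
# "extract.prompt"; no file I/O happens at this point anymore.
messages = prompt.format("openai")

# Without an LLM name, only the default "extract.prompt" template is used.
default_messages = prompt.format()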
35 changes: 22 additions & 13 deletions bpm-ai-core/bpm_ai_core/prompt/prompt.py
@@ -1,3 +1,4 @@
+import glob
 import inspect
 import os
 import re
@@ -17,10 +18,11 @@

 class Prompt:

-    def __init__(self, kwargs: Dict[str, Any], path: str | None = None, template_str: str | None = None) -> None:
-        self.path = path
+    def __init__(self, kwargs: dict[str, Any], template_str: str = None, path: str = None, prompt_templates: dict = None) -> None:
         self.template_str = template_str
         self.template_vars = kwargs
+        self.path = path
+        self.prompt_templates = prompt_templates

     @classmethod
     def from_file(cls, path: str, **kwargs):
@@ -32,14 +34,23 @@ def from_file(cls, path: str, **kwargs):
         current_dir = os.path.dirname(os.path.abspath(caller_filename))
         file_path = os.path.join(current_dir, path)

-        return cls(kwargs, path=file_path)
+        prompt_templates = {}
+
+        default_file_path = f"{file_path}.prompt"
+        specific_file_path = f"{file_path}.*.prompt"
+        prompt_files = glob.glob(specific_file_path) + [default_file_path]
+        for prompt_file in prompt_files:
+            with open(prompt_file, 'r') as p:
+                prompt_templates[os.path.basename(prompt_file)] = p.read()
+
+        return cls(kwargs, path=path, prompt_templates=prompt_templates)

     @classmethod
     def from_string(cls, template: str, **kwargs):
         return cls(kwargs, template_str=template)

     def format(self, llm_name: str = "") -> List[ChatMessage]:
-        template = self.load_template(self.path, llm_name) if self.path else Template(self.template_str)
+        template = self.load_template(self.path, llm_name)
         full_prompt = template.render(self.template_vars)

         regex = r'\[#\s*(user|assistant|system|tool_result:.*|)\s*#\]'
@@ -130,15 +141,13 @@ def format(self, llm_name: str = "") -> List[ChatMessage]:

         return [m for m in messages if m]

-    @staticmethod
-    def load_template(path: str, llm_name: str) -> Template:
-        default_path = f"{path}.prompt"
-        llm_specific_path = f"{path}.{llm_name}.prompt"
-        filename = llm_specific_path if os.path.exists(llm_specific_path) else default_path
-        if not os.path.exists(filename):
-            raise FileNotFoundError(f"No prompt file found at {filename}")
-        with open(filename, 'r') as f:
-            return Template(f.read())
+    def load_template(self, path: str, llm_name: str) -> Template:
+        default_prompt = f"{path}.prompt"
+        llm_specific_prompt = f"{path}.{llm_name}.prompt"
+        prompt = self.prompt_templates.get(llm_specific_prompt, self.prompt_templates.get(default_prompt))
+        if not prompt:
+            raise FileNotFoundError(f"No prompt file {path} found for llm {llm_name}")
+        return Template(prompt)

     def __repr__(self):
         return f"{self.__class__.__qualname__}(template_vars={self.template_vars}, path={self.path}, template_str={self.template_str})"
39 changes: 1 addition & 38 deletions bpm-ai-core/tests/test.prompt
@@ -1,38 +1 @@
-[# system #]
-You are a smart assistant.
-[# blob {{image_url}} #]
-Go!
-
-[# user #]
-What is one plus one?
-
-[# assistant #]
-I will call some tools.
-[# tool_call: foo (foo_id) #]
-x
-[# tool_call: bar #]
-y
-
-[# tool_result: foo_id #]
-the result
-
-[# assistant #]
-Looks good, now another one:
-
-[# tool_call: other (other_id) #]
-z
-
-[# tool_result: other_id #]
-the result 2
-
-[# assistant #]
-[# tool_call: another #]
-123
-
-[# assistant #]
-That's that.
-
-[# user #]
-Here is an image:
-[# blob {{image_url}} #]
-{{task}}
+foo
9 changes: 8 additions & 1 deletion bpm-ai-core/tests/test_prompt.py
@@ -11,7 +11,7 @@ def test_prompt_format():
         "task": "Do the task"
     }
     prompt = Prompt.from_file("test", **template_vars)
-    messages = prompt.format()
+    messages = prompt.format("openai")

     assert len(messages) == 9

@@ -103,6 +103,13 @@ def test_prompt_format():
     assert messages[8].content[2] == "Do the task"


+def test_prompt_format_default_prompt():
+    prompt = Prompt.from_file("test")
+    messages = prompt.format()
+
+    assert len(messages) == 1
+    assert messages[0].content == "foo"
+
 def test_prompt_filter():
     input_dict = {
         "text": "What is your name?",
