Skip to content

Commit

Permalink
Make funtion templates real functions
Browse files Browse the repository at this point in the history
Prompts remplates are currently contained in the docstring of decorated
functions. The main issue with this is that prompt templates cannot be
composed.

In this commit we instead require users to return the prompt
template from the function. The template will then automatically be
rendered using the values passed to the function. This is very flexible:
some variables can be used inside the functions and not be present in
the Jinja2 template that is returned, for instance:

```python
import prompts

@prompts.template
def my_template(a, b):
    prompt = f'This is a first variable {a}'
    return prompt + "and a second {{b}}"
```
  • Loading branch information
rlouf committed Sep 26, 2024
1 parent 05c9d5e commit 1b914c1
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ from prompts import template

@template
def few_shots(instructions, examples, question):
"""{{ instructions }}
return """{{ instructions }}
Examples
--------
Expand Down
20 changes: 10 additions & 10 deletions docs/reference/template.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ will pass to the prompt function.

@prompts.template
def greetings(name, question):
"""Hello, {{ name }}!
return """Hello, {{ name }}!
{{ question }}
"""

Expand All @@ -62,7 +62,7 @@ If a variable is missing in the function's arguments, Jinja2 will throw an `Unde

@prompts.template
def greetings(name):
"""Hello, {{ surname }}!"""
return """Hello, {{ surname }}!"""

prompt = greetings("user")
```
Expand Down Expand Up @@ -94,7 +94,7 @@ Prompt functions are functions, and thus can be imported from other modules:

@prompts.template
def greetings(name, question):
"""Hello, {{ name }}!
return """Hello, {{ name }}!
{{ question }}
"""
```
Expand Down Expand Up @@ -128,7 +128,7 @@ keys `question` and `answer` to the prompt function:

@prompts.template
def few_shots(instructions, examples, question):
"""{{ instructions }}
return """{{ instructions }}

Examples
--------
Expand Down Expand Up @@ -207,12 +207,12 @@ below does not matter for formatting:

@prompts.template
def prompt1():
"""My prompt
return """My prompt
"""

@prompts.template
def prompt2():
"""
return """
My prompt
"""

Expand All @@ -236,20 +236,20 @@ Indentation is relative to the second line of the docstring, and leading spaces

@prompts.template
def example1():
"""First line
return """First line
Second line
"""

@prompts.template
def example2():
"""
return """
Second line
Third line
"""

@prompts.template
def example3():
"""
return """
Second line
Third line
"""
Expand Down Expand Up @@ -285,7 +285,7 @@ You can use the backslash `\` to break a long line of text. It will render as a

@prompts.template
def example():
"""
return """
Break in \
several lines \
But respect the indentation
Expand Down
40 changes: 18 additions & 22 deletions prompts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Dict, Hashable, Optional, cast
from typing import Callable, Dict, Hashable, Optional

from jinja2 import Environment, StrictUndefined

Expand All @@ -15,7 +15,7 @@ class Template:
"""Represents a prompt template.
A prompt template is a callable that, given a Jinja2 template and a set of values,
renders the template using those values. It is recommended to instantiate `Temaplate`
renders the template using those values. It is recommended to instantiate `Template`
using the `template` decorator, which extracts the template from the function's
docstring and its variables from the function's signature.
Expand All @@ -40,11 +40,15 @@ class Template:
"""

template: str
signature: inspect.Signature
fn: Callable
model: Optional[str] = None
registry: Dict[str, Callable] = field(default_factory=dict)

def __init__(self, fn: Callable):
self.fn = fn
self.signature = inspect.signature(fn)

def __call__(self, *args, **kwargs) -> str:
"""Render and return the template.
Expand All @@ -55,7 +59,10 @@ def __call__(self, *args, **kwargs) -> str:
"""
bound_arguments = self.signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return render(self.template, self.model, **bound_arguments.arguments)

template = self.fn(**bound_arguments.arguments)

return render(template, self.model, **bound_arguments.arguments)

def __str__(self):
return self.template
Expand Down Expand Up @@ -104,24 +111,23 @@ def template(fn: Callable) -> Template:
manipulation by providing some degree of encapsulation. It uses the `render`
function internally to render templates.
>>> import outlines
>>> import prompts
>>>
>>> @outlines.prompt
>>> @prompts.template
>>> def build_prompt(question):
... "I have a ${question}"
... return "I have a {{question}}"
...
>>> prompt = build_prompt("How are you?")
This API can also be helpful in an "agent" context where parts of the prompt
are set when the agent is initialized and never modified later. In this situation
we can partially apply the prompt function at initialization.
>>> import outlines
>>> import functools as ft
>>> import prompts
...
>>> @outlines.prompt
>>> @prompts.template
... def solve_task(name: str, objective: str, task: str):
... '''Your name is {{name}}.
... return '''Your name is {{name}}.
.. Your overall objective is to {{objective}}.
... Please solve the following task: {{task}}'''
...
Expand All @@ -132,17 +138,7 @@ def template(fn: Callable) -> Template:
A `Prompt` callable class which will render the template when called.
"""
signature = inspect.signature(fn)

# The docstring contains the template that will be rendered to be used
# as a prompt to the language model.
docstring = fn.__doc__
if docstring is None:
raise TypeError("Could not find a template in the function's docstring.")

template = cast(str, docstring)

return Template(template, signature)
return Template(fn)


@lru_cache
Expand Down
30 changes: 7 additions & 23 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,8 @@ def test_render_jinja():
def test_prompt_basic():
@prompts.template
def test_tpl(variable):
"""{{variable}} test"""
return """{{variable}} test"""

assert test_tpl.template == "{{variable}} test"
assert list(test_tpl.signature.parameters.keys()) == ["variable"]

with pytest.raises(TypeError):
Expand All @@ -145,7 +144,7 @@ def test_tpl(variable):

@prompts.template
def test_single_quote_tpl(variable):
"${variable} test"
return "{{variable}} test"

p = test_tpl("test")
assert p == "test test"
Expand All @@ -154,9 +153,8 @@ def test_single_quote_tpl(variable):
def test_prompt_kwargs():
@prompts.template
def test_kwarg_tpl(var, other_var="other"):
"""{{var}} and {{other_var}}"""
return """{{var}} and {{other_var}}"""

assert test_kwarg_tpl.template == "{{var}} and {{other_var}}"
assert list(test_kwarg_tpl.signature.parameters.keys()) == ["var", "other_var"]

p = test_kwarg_tpl("test")
Expand All @@ -169,30 +167,16 @@ def test_kwarg_tpl(var, other_var="other"):
assert p == "test and test"


def test_no_prompt():
with pytest.raises(TypeError, match="template"):

@prompts.template
def test_empty(variable):
pass

with pytest.raises(TypeError, match="template"):

@prompts.template
def test_only_code(variable):
return variable


@pytest.mark.filterwarnings("ignore: The model")
def test_dispatch():

@prompts.template
def simple_prompt(query: str):
"""{{ query }}"""
return """{{ query }}"""

@simple_prompt.register("provider/name")
def simple_prompt_name(query: str):
"""name: {{ query }}"""
return """name: {{ query }}"""

assert list(simple_prompt.registry.keys()) == ["provider/name"]
assert callable(simple_prompt)
Expand All @@ -214,7 +198,7 @@ def test_special_tokens():

@prompts.template
def simple_prompt(query: str):
"""{{ bos + query + eos }}"""
return """{{ bos + query + eos }}"""

assert simple_prompt("test") == "test"
assert simple_prompt["openai-community/gpt2"]("test") == "test<|endoftext|>"
Expand All @@ -225,7 +209,7 @@ def test_warn():

@prompts.template
def simple_prompt():
"""test"""
return """test"""

with pytest.warns(UserWarning, match="not present in the special token"):
simple_prompt["non-existent-model"]()

0 comments on commit 1b914c1

Please sign in to comment.