Make function templates return templates to be rendered #22

Merged
merged 4 commits on Sep 27, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -24,7 +24,7 @@ from prompts import template

@template
def few_shots(instructions, examples, question):
"""{{ instructions }}
return """{{ instructions }}

Examples
--------
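For readers skimming the diff, here is a minimal sketch of the convention this change adopts: the decorated function now returns the Jinja2 template instead of carrying it in its docstring. The `greet` function and the rendered output are illustrative, not part of the diff.

```python
from prompts import template

@template
def greet(name):
    # The Jinja2 source is now the function's return value,
    # not its docstring.
    return """Hello, {{ name }}!"""

# Calling the decorated function still renders the template directly.
print(greet("Mark"))  # Hello, Mark!
```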
25 changes: 15 additions & 10 deletions docs/reference/template.md
@@ -25,7 +25,11 @@ allows to easily compose complex prompts.
Prompt functions are opinionated when it comes to prompt rendering. These opinions are meant to avoid common prompting errors, but can have unintended consequences if you are doing something unusual. We advise you to always print the prompt before using it. You can also [read the
reference](#formatting-conventions) section if you want to know more.


!!! note "Performance"

Prompt templates introduce some overhead compared to standard Python functions, although rendering time remains very reasonable. In the unlikely scenario where template rendering is a bottleneck, you can replace templates with functions that use standard string manipulation.
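As a rough illustration of the escape hatch mentioned above, a template can be swapped for a plain function that builds the same string with standard Python string formatting; the function below is an illustrative sketch, not part of the documented API.

```python
# Plain-Python replacement for a simple template, for the rare case
# where template rendering shows up as a bottleneck when profiling.
def greetings_plain(name: str, question: str) -> str:
    return f"Hello, {name}!\n{question}\n"
```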

## Your first prompt

The following snippet showcases a very simple prompt. The variables between
@@ -38,7 +43,7 @@ will pass to the prompt function.

@prompts.template
def greetings(name, question):
"""Hello, {{ name }}!
return """Hello, {{ name }}!
{{ question }}
"""

@@ -62,7 +67,7 @@ If a variable is missing in the function's arguments, Jinja2 will throw an `Unde

@prompts.template
def greetings(name):
"""Hello, {{ surname }}!"""
return """Hello, {{ surname }}!"""

prompt = greetings("user")
```
@@ -94,7 +99,7 @@ Prompt functions are functions, and thus can be imported from other modules:

@prompts.template
def greetings(name, question):
"""Hello, {{ name }}!
return """Hello, {{ name }}!
{{ question }}
"""
```
@@ -128,7 +133,7 @@ keys `question` and `answer` to the prompt function:

@prompts.template
def few_shots(instructions, examples, question):
"""{{ instructions }}
return """{{ instructions }}

Examples
--------
@@ -207,12 +212,12 @@ below does not matter for formatting:

@prompts.template
def prompt1():
"""My prompt
return """My prompt
"""

@prompts.template
def prompt2():
"""
return """
My prompt
"""

@@ -236,20 +241,20 @@ Indentation is relative to the second line of the docstring, and leading spaces

@prompts.template
def example1():
"""First line
return """First line
Second line
"""

@prompts.template
def example2():
"""
return """
Second line
Third line
"""

@prompts.template
def example3():
"""
return """
Second line
Third line
"""
@@ -285,7 +290,7 @@ You can use the backslash `\` to break a long line of text. It will render as a

@prompts.template
def example():
"""
return """
Break in \
several lines \
But respect the indentation
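As an illustration of the behaviour described above, the backslash continuations collapse into a single line when the template is rendered; the exact output shown in the comment is an assumption, since the end of the example is truncated in this diff.

```python
print(example())
# Break in several lines But respect the indentation
```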
106 changes: 45 additions & 61 deletions prompts/templates.py
@@ -1,37 +1,50 @@
import inspect
import re
import warnings
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Dict, Hashable, Optional, cast
from typing import Callable, Dict, Hashable, Optional

from jinja2 import Environment, StrictUndefined

from prompts.tokens import SPECIAL_TOKENS, Special


@dataclass
class Template:
"""Represents a prompt template.

A prompt template is a callable that, given a Jinja2 template and a set of values,
renders the template using those values. It is recommended to instantiate `Temaplate`
using the `template` decorator, which extracts the template from the function's
docstring and its variables from the function's signature.
A prompt template is a callable which renders the template returned by the
function using the values that are passed to it. It is recommended to
instantiate `Template` using the `template` decorator.

>>> import prompts
...
... @prompts.template
... def prompt(name: str) -> str:
... return "My name is {{name}}"

It is not uncommon that, for the same task, different models will perform
better with different prompts. We thus allow you to associate a prompt with a
task and dispatch it based on the model being used; a `Template` instance is
thus also a registry that associates model names with other templates.

>>> @prompt.register("gpt2")
... def prompt_gpt2(name: str) -> str:
... return "Hi GPT2! My name is {{name}}"

The name of the model can then be passed to the `render` function along with
the values of the arguments:

>>> from prompts import render
...
... render(prompt, "gpt2", name="Dan")
>>> "Hi GPT2! My name is Dan"

Attributes
----------
template
The template to render.
fn
The function that returns a template.
signature
The prompt function's signature.
The function's signature.
model
The model the `Template` is associated with. Defaults to `None`.
registry
@@ -40,7 +53,7 @@ class Template:

"""

template: str
fn: Callable
signature: inspect.Signature
model: Optional[str] = None
registry: Dict[str, Callable] = field(default_factory=dict)
@@ -55,10 +68,10 @@ def __call__(self, *args, **kwargs) -> str:
"""
bound_arguments = self.signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return render(self.template, self.model, **bound_arguments.arguments)

def __str__(self):
return self.template
template = self.fn(**bound_arguments.arguments)

return render(template, self.model, **bound_arguments.arguments)

def __getitem__(self, model_name: str):
"""Get the prompt template corresponding to a model name.
@@ -84,7 +97,7 @@ def __getitem__(self, model_name: str):

def register(self, model_name: str):
"""Register the prompt template, as represented by a prompt function,
for the model name.
for a given model `model_name`.

"""

@@ -98,51 +111,42 @@ def wrapper(fn: Callable):


def template(fn: Callable) -> Template:
"""Decorate a function that contains a prompt template.
"""Decorate a function that returns a prompt template.

This allows to define prompts in the docstring of a function and simplify their
manipulation by providing some degree of encapsulation. It uses the `render`
function internally to render templates.
This allows you to define prompts as the return value of a function and
simplifies their manipulation by providing some degree of encapsulation. It
uses the `render` function internally to render templates.

>>> import outlines
>>> import prompts
>>>
>>> @outlines.prompt
>>> @prompts.template
>>> def build_prompt(question):
... "I have a ${question}"
... return "I have a {{question}}"
...
>>> prompt = build_prompt("How are you?")

This API can also be helpful in an "agent" context where parts of the prompt
are set when the agent is initialized and never modified later. In this situation
we can partially apply the prompt function at initialization.

>>> import outlines
>>> import functools as ft
>>> import prompts
...
>>> @outlines.prompt
>>> @prompts.template
... def solve_task(name: str, objective: str, task: str):
... '''Your name is {{name}}.
... return '''Your name is {{name}}.
... Your overall objective is to {{objective}}.
... Please solve the following task: {{task}}'''
...
>>> hal = ft.partial(solve_task, "HAL", "Travel to Jupiter")

Returns
-------
A `Prompt` callable class which will render the template when called.
A `Template` callable class which will render the template when called.

"""
signature = inspect.signature(fn)

# The docstring contains the template that will be rendered to be used
# as a prompt to the language model.
docstring = fn.__doc__
if docstring is None:
raise TypeError("Could not find a template in the function's docstring.")

template = cast(str, docstring)

return Template(template, signature)
return Template(fn, signature)


@lru_cache
@@ -151,28 +155,24 @@ def render(
model_name: Optional[str] = None,
**values: Optional[Dict[str, Hashable]],
) -> str:
r"""Parse a Jinaj2 template and translate it into an Outlines graph.
r"""Parse a Jinaj2 template and renders it using the passed values.

This function removes extra whitespaces and linebreaks from templates to
allow users to enter prompts more naturally than if they used Python's
constructs directly. See the examples for a detailed explanation.

We also define the `bos` and `eos` special variables which, when used, will
be replaced by the model's BOS and EOS tokens respectively. This allows you
to write prompts that are model-agnostic.

Examples
--------

Outlines follows Jinja2's syntax

>>> import outlines
>>> outline = outlines.render("I like {{food}} and {{sport}}", food="tomatoes", sport="tennis")
>>> from prompts import render
>>> render("I like {{food}} and {{sport}}", food="tomatoes", sport="tennis")
I like tomatoes and tennis

If the first line of the template is empty, `render` removes it

>>> from outlines import render
>>> from prompts import render
>>>
>>> tpl = '''
... A new string'''
@@ -261,28 +261,12 @@ def render(
# used to continue to the next line without linebreak.
cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template)

# Warn the user when the model is not present in the special token registry
if model_name not in SPECIAL_TOKENS:
warnings.warn(
UserWarning(
f"The model {model_name} is not present in the special token registry."
"As a result, EOS and BOS tokens will be rendered as the empty string."
"Please open an issue: https://github.com/outlines-dev/prompts/issues"
"And ask for the model to be added to the registry."
)
)

env = Environment(
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=StrictUndefined,
)
env.globals["bos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.begin
env.globals["eos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.end
env.globals["user"] = SPECIAL_TOKENS.get(model_name, Special()).user
env.globals["assistant"] = SPECIAL_TOKENS.get(model_name, Special()).assistant
env.globals["system"] = SPECIAL_TOKENS.get(model_name, Special()).system
jinja_template = env.from_string(cleaned_template)

return jinja_template.render(**values)
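Putting the `templates.py` changes together: the `Template` dataclass now stores the decorated function rather than a template string, and `__call__` first evaluates the function to obtain the Jinja2 source before rendering it with the bound arguments. The sketch below, adapted from the docstring examples above, shows the new flow; the printed output is an assumption.

```python
import prompts

@prompts.template
def prompt(name: str) -> str:
    # The template is the function's return value; Template.__call__
    # evaluates it first, then renders it with the bound arguments.
    return "My name is {{name}}"

@prompt.register("gpt2")
def prompt_gpt2(name: str) -> str:
    return "Hi GPT2! My name is {{name}}"

print(prompt("Dan"))  # My name is Dan
```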
29 changes: 0 additions & 29 deletions prompts/tokens.py

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -20,7 +20,7 @@ packages = ["prompts"]
write_to = "prompts/_version.py"

[project.optional-dependencies]
test = ["pre-commit", "pytest"]
test = ["pre-commit", "pytest", "pytest-benchmark"]
docs = [
"mkdocs",
"mkdocs-material",
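The new `pytest-benchmark` test dependency suggests rendering-speed tests along these lines; the file name and the benchmarked template are assumptions, not part of this diff.

```python
# tests/test_benchmark.py (hypothetical)
import prompts

@prompts.template
def greetings(name, question):
    return """Hello, {{ name }}!
    {{ question }}
    """

def test_render_benchmark(benchmark):
    # pytest-benchmark injects the `benchmark` fixture, which times
    # repeated calls to the function under test and returns its result.
    result = benchmark(greetings, "user", "How are you?")
    assert "Hello, user!" in result
```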