diff --git a/README.md b/README.md
index 50402a0..55301b3 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ from prompts import template
 
 @template
 def few_shots(instructions, examples, question):
-    """{{ instructions }}
+    return """{{ instructions }}
 
     Examples
     --------
diff --git a/docs/reference/template.md b/docs/reference/template.md
index 14baaf1..0651280 100644
--- a/docs/reference/template.md
+++ b/docs/reference/template.md
@@ -25,6 +25,11 @@ allows to easily compose complex prompts.
 Prompt functions are opinionated when it comes to prompt rendering. These opinions
 are meant to avoid common prompting errors, but can have unintended consequences if
 you are doing something unusual. We advise to always print the prompt before using
 it. You can also [read the reference](#formatting-conventions) section if you want
 to know more.
+
+!!! note "Performance"
+
+    Prompt templates introduce some overhead compared to standard Python functions, although the rendering time is still very reasonable. In the unlikely scenario where template rendering is a bottleneck, you can replace the templates with functions that use standard string manipulation.
+
 
 ## Your first prompt
 
 The following snippet showcases a very simple prompt. The variables between
@@ -38,7 +43,7 @@ will pass to the prompt function.
 
 @prompts.template
 def greetings(name, question):
-    """Hello, {{ name }}!
+    return """Hello, {{ name }}!
 
     {{ question }}
     """
@@ -62,7 +67,7 @@ If a variable is missing in the function's arguments, Jinja2 will throw an `Unde
 
 @prompts.template
 def greetings(name):
-    """Hello, {{ surname }}!"""
+    return """Hello, {{ surname }}!"""
 
 prompt = greetings("user")
 ```
@@ -94,7 +99,7 @@ Prompt functions are functions, and thus can be imported from other modules:
 @prompts.template
 def greetings(name, question):
-    """Hello, {{ name }}!
+    return """Hello, {{ name }}!
 
     {{ question }}
     """
 ```
@@ -128,7 +133,7 @@ keys `question` and `answer` to the prompt function:
 
 @prompts.template
 def few_shots(instructions, examples, question):
-    """{{ instructions }}
+    return """{{ instructions }}
 
     Examples
     --------
@@ -207,12 +212,12 @@ below does not matter for formatting:
 
 @prompts.template
 def prompt1():
-    """My prompt
+    return """My prompt
     """
 
 
 @prompts.template
 def prompt2():
-    """
+    return """
     My prompt
     """
@@ -236,20 +241,20 @@ Indentation is relative to the second line of the docstring, and leading spaces
 
 @prompts.template
 def example1():
-    """First line
+    return """First line
     Second line
     """
 
 @prompts.template
 def example2():
-    """
+    return """
     Second line
     Third line
     """
 
 @prompts.template
 def example3():
-    """
+    return """
     Second line
        Third line
     """
@@ -285,7 +290,7 @@ You can use the backslash `\` to break a long line of text. It will render as a
 
 @prompts.template
 def example():
-    """
+    return """
     Break in \
     several lines \
     But respect the indentation
diff --git a/prompts/templates.py b/prompts/templates.py
index a7f5a8c..8ad7154 100644
--- a/prompts/templates.py
+++ b/prompts/templates.py
@@ -1,23 +1,25 @@
 import inspect
 import re
-import warnings
 from dataclasses import dataclass, field
 from functools import lru_cache
-from typing import Callable, Dict, Hashable, Optional, cast
+from typing import Callable, Dict, Hashable, Optional
 
 from jinja2 import Environment, StrictUndefined
 
-from prompts.tokens import SPECIAL_TOKENS, Special
-
 
 @dataclass
 class Template:
     """Represents a prompt template.
 
-    A prompt template is a callable that, given a Jinja2 template and a set of values,
-    renders the template using those values. It is recommended to instantiate `Temaplate`
-    using the `template` decorator, which extracts the template from the function's
-    docstring and its variables from the function's signature.
+    A prompt template is a callable which renders the template returned by the
+    function using the values that are passed to it. It is recommended to
+    instantiate `Template` using the `template` decorator.
+
+    >>> import prompts
+    >>>
+    >>> @prompts.template
+    ... def prompt(name: str) -> str:
+    ...     return "My name is {{name}}"
 
     It is not uncommon that, for the same taks, different models will perform
     better with different prompt. Here we thus allow to dispatch to associate a
@@ -25,13 +27,24 @@ class Template:
     `Template` instance is thus also a registry that associates model names to
     other templates.
 
+    >>> @prompt.register("gpt2")
+    ... def prompt_gpt2(name: str) -> str:
+    ...     return "Hi GPT2! My name is {{name}}"
+
+    The template registered for a given model can then be retrieved by indexing
+    the `Template` instance with the model name; calling the result renders the
+    registered template with the values of the arguments:
+
+    >>> prompt["gpt2"](name="Dan")
+    "Hi GPT2! My name is Dan"
+
     Attributes
     ----------
-    template
-        The template to render.
+    fn
+        The function that returns a template.
     signature
-        The prompt function's signature.
+        The function's signature.
     model
         The model the `Template` is associated with. Defaults to `None`.
     registry
@@ -40,7 +53,7 @@ class Template:
 
     """
 
-    template: str
+    fn: Callable
     signature: inspect.Signature
     model: Optional[str] = None
     registry: Dict[str, Callable] = field(default_factory=dict)
@@ -55,10 +68,10 @@ def __call__(self, *args, **kwargs) -> str:
 
         """
         bound_arguments = self.signature.bind(*args, **kwargs)
         bound_arguments.apply_defaults()
-        return render(self.template, self.model, **bound_arguments.arguments)
 
-    def __str__(self):
-        return self.template
+        template = self.fn(**bound_arguments.arguments)
+
+        return render(template, self.model, **bound_arguments.arguments)
 
     def __getitem__(self, model_name: str):
         """Get the prompt template corresponding to a model name.
@@ -84,7 +97,7 @@ def __getitem__(self, model_name: str):
 
     def register(self, model_name: str):
         """Register the prompt template, as represented by a prompt function,
-        for the model name.
+        for a given model `model_name`.
 
         """
 
@@ -98,17 +111,17 @@ def wrapper(fn: Callable):
 
 def template(fn: Callable) -> Template:
-    """Decorate a function that contains a prompt template.
+    """Decorate a function that returns a prompt template.
 
-    This allows to define prompts in the docstring of a function and simplify their
-    manipulation by providing some degree of encapsulation. It uses the `render`
-    function internally to render templates.
+    This allows prompts to be defined as the return value of a function and
+    simplifies their manipulation by providing some degree of encapsulation.
+    It uses the `render` function internally to render templates.
 
-    >>> import outlines
+    >>> import prompts
     >>>
-    >>> @outlines.prompt
+    >>> @prompts.template
     >>> def build_prompt(question):
-    ...     "I have a ${question}"
+    ...     return "I have a {{question}}"
     ...
     >>> prompt = build_prompt("How are you?")
 
@@ -116,12 +129,11 @@ def template(fn: Callable) -> Template:
     are set when the agent is initialized and never modified later. In this
     situation we can partially apply the prompt function at initialization.
 
-    >>> import outlines
-    >>> import functools as ft
+    >>> import prompts
     ...
-    >>> @outlines.prompt
+    >>> @prompts.template
     ... def solve_task(name: str, objective: str, task: str):
-    ...     '''Your name is {{name}}.
+    ...     return '''Your name is {{name}}.
     ..      Your overall objective is to {{objective}}.
     ...     Please solve the following task: {{task}}'''
     ...
 
@@ -129,20 +141,12 @@ def template(fn: Callable) -> Template:
 
     Returns
     -------
-    A `Prompt` callable class which will render the template when called.
+    A `Template` instance which will render the template when called.
 
     """
     signature = inspect.signature(fn)
 
-    # The docstring contains the template that will be rendered to be used
-    # as a prompt to the language model.
-    docstring = fn.__doc__
-    if docstring is None:
-        raise TypeError("Could not find a template in the function's docstring.")
-
-    template = cast(str, docstring)
-
-    return Template(template, signature)
+    return Template(fn, signature)
 
 
 @lru_cache
@@ -151,28 +155,24 @@ def render(
     model_name: Optional[str] = None,
     **values: Optional[Dict[str, Hashable]],
 ) -> str:
-    r"""Parse a Jinaj2 template and translate it into an Outlines graph.
+    r"""Parse a Jinja2 template and render it using the passed values.
 
     This function removes extra whitespaces and linebreaks from templates to
     allow users to enter prompts more naturally than if they used Python's
     constructs directly. See the examples for a detailed explanation.
 
-    We also define the `bos` and `eos` special variables which, when used, will
-    be replaced by the model's BOS and EOS tokens respectively. This allows you
-    to write prompts that are model-agnostic.
-
     Examples
     --------
 
     Outlines follow Jinja2's syntax
 
-    >>> import outlines
-    >>> outline = outlines.render("I like {{food}} and {{sport}}", food="tomatoes", sport="tennis")
+    >>> from prompts import render
+    >>> render("I like {{food}} and {{sport}}", food="tomatoes", sport="tennis")
     I like tomatoes and tennis
 
     If the first line of the template is empty, `render` removes it
 
-    >>> from outlines import render
+    >>> from prompts import render
     >>>
     >>> tpl = '''
     ...        A new string'''
@@ -261,28 +261,12 @@ def render(
     # used to continue to the next line without linebreak.
     cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template)
 
-    # Warn the user when the model is not present in the special token registry
-    if model_name not in SPECIAL_TOKENS:
-        warnings.warn(
-            UserWarning(
-                f"The model {model_name} is not present in the special token registry."
-                "As a result, EOS and BOS tokens will be rendered as the empty string."
-                "Please open an issue: https://github.com/outlines-dev/prompts/issues"
-                "And ask for the model to be added to the registry."
-            )
-        )
-
     env = Environment(
         trim_blocks=True,
         lstrip_blocks=True,
         keep_trailing_newline=True,
         undefined=StrictUndefined,
     )
-    env.globals["bos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.begin
-    env.globals["eos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.end
-    env.globals["user"] = SPECIAL_TOKENS.get(model_name, Special()).user
-    env.globals["assistant"] = SPECIAL_TOKENS.get(model_name, Special()).assistant
-    env.globals["system"] = SPECIAL_TOKENS.get(model_name, Special()).system
 
     jinja_template = env.from_string(cleaned_template)
     return jinja_template.render(**values)
diff --git a/prompts/tokens.py b/prompts/tokens.py
deleted file mode 100644
index c9f3a98..0000000
--- a/prompts/tokens.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Dict, Optional
-
-
-@dataclass
-class Limits:
-    begin: str = ""
-    end: str = ""
-
-
-@dataclass
-class Special:
-    sequence: Limits = field(default_factory=lambda: Limits())
-    user: Limits = field(default_factory=lambda: Limits())
-    assistant: Limits = field(default_factory=lambda: Limits())
-    system: Limits = field(default_factory=lambda: Limits())
-
-
-SPECIAL_TOKENS: Dict[Optional[str], Special] = {
-    None: Special(),
-    "google/gemma-2-9b": Special(Limits("<bos>", "<eos>")),
-    "openai-community/gpt2": Special(Limits("", "<|endoftext|>")),
-    "mistralai/Mistral-7B-v0.1": Special(Limits("<s>", "</s>")),
-    "mistralai/Mistral-7B-Instruct-v0.1": Special(
-        Limits("<s>", "</s>"),
-        Limits("[INST]", "[/INST]"),
-        Limits("", ""),
-    ),
-}
diff --git a/pyproject.toml b/pyproject.toml
index b7a049f..326a183 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ packages = ["prompts"]
 write_to = "prompts/_version.py"
 
 [project.optional-dependencies]
-test = ["pre-commit", "pytest"]
+test = ["pre-commit", "pytest", "pytest-benchmark"]
 docs = [
     "mkdocs",
     "mkdocs-material",
diff --git a/tests/test_templates.py b/tests/test_templates.py
index 8bd295d..9e4e8f9 100644
--- a/tests/test_templates.py
+++ b/tests/test_templates.py
@@ -1,3 +1,6 @@
+import random
+import string
+
 import pytest
 
 import prompts
@@ -129,9 +132,8 @@ def test_render_jinja():
 def test_prompt_basic():
     @prompts.template
     def test_tpl(variable):
-        """{{variable}} test"""
+        return """{{variable}} test"""
 
-    assert test_tpl.template == "{{variable}} test"
     assert list(test_tpl.signature.parameters.keys()) == ["variable"]
 
     with pytest.raises(TypeError):
@@ -145,7 +147,7 @@ def test_tpl(variable):
 
     @prompts.template
     def test_single_quote_tpl(variable):
-        "${variable} test"
+        return "{{variable}} test"
 
     p = test_tpl("test")
     assert p == "test test"
@@ -154,9 +156,8 @@ def test_prompt_kwargs():
 
     @prompts.template
     def test_kwarg_tpl(var, other_var="other"):
-        """{{var}} and {{other_var}}"""
+        return """{{var}} and {{other_var}}"""
 
-    assert test_kwarg_tpl.template == "{{var}} and {{other_var}}"
     assert list(test_kwarg_tpl.signature.parameters.keys()) == ["var", "other_var"]
 
     p = test_kwarg_tpl("test")
@@ -169,30 +170,16 @@ def test_kwarg_tpl(var, other_var="other"):
     assert p == "test and test"
 
 
-def test_no_prompt():
-    with pytest.raises(TypeError, match="template"):
-
-        @prompts.template
-        def test_empty(variable):
-            pass
-
-    with pytest.raises(TypeError, match="template"):
-
-        @prompts.template
-        def test_only_code(variable):
-            return variable
-
-
 @pytest.mark.filterwarnings("ignore: The model")
 def test_dispatch():
     @prompts.template
     def simple_prompt(query: str):
-        """{{ query }}"""
+        return """{{ query }}"""
 
     @simple_prompt.register("provider/name")
     def simple_prompt_name(query: str):
-        """name: {{ query }}"""
+        return """name: {{ query }}"""
 
     assert list(simple_prompt.registry.keys()) == ["provider/name"]
     assert callable(simple_prompt)
@@ -210,22 +197,40 @@ def simple_prompt_name(query: str):
     assert simple_prompt["provider/name"]("test") == "name: test"
 
 
-def test_special_tokens():
-
-    @prompts.template
-    def simple_prompt(query: str):
-        """{{ bos + query + eos }}"""
-
-    assert simple_prompt("test") == "test"
-    assert simple_prompt["openai-community/gpt2"]("test") == "test<|endoftext|>"
-    assert simple_prompt["mistralai/Mistral-7B-v0.1"]("test") == "<s>test</s>"
-
-
-def test_warn():
+def test_benchmark_template_render(benchmark):
     @prompts.template
-    def simple_prompt():
-        """test"""
-
-    with pytest.warns(UserWarning, match="not present in the special token"):
-        simple_prompt["non-existent-model"]()
+    def test_tpl(var0, var1):
+        prompt = var0
+        return prompt + """{{var1}} test"""
+
+    def setup():
+        """We generate random strings to make sure we don't hit any potential cache."""
+        length = 10
+        var0 = "".join(
+            random.choice(string.ascii_uppercase + string.digits) for _ in range(length)
+        )
+        var1 = "".join(
+            random.choice(string.ascii_uppercase + string.digits) for _ in range(length)
+        )
+        return (var0, var1), {}
+
+    benchmark.pedantic(test_tpl, setup=setup, rounds=500)
+
+
+def test_benchmark_template_function(benchmark):
+
+    def test_tpl(var0, var1):
+        return var0 + f"{var1} test"
+
+    def setup():
+        length = 10
+        var0 = "".join(
+            random.choice(string.ascii_uppercase + string.digits) for _ in range(length)
+        )
+        var1 = "".join(
+            random.choice(string.ascii_uppercase + string.digits) for _ in range(length)
+        )
+        return (var0, var1), {}
+
+    benchmark.pedantic(test_tpl, setup=setup, rounds=500)
diff --git a/tests/test_tokens.py b/tests/test_tokens.py
deleted file mode 100644
index c9391e9..0000000
--- a/tests/test_tokens.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from prompts.tokens import Special
-
-
-def test_simple():
-    special = Special()
-    assert special.assistant.begin == ""
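
Below is a quick end-to-end sketch of the new return-value API, for anyone reviewing this diff who wants to try it locally. It only uses pieces that appear in the diff (the `prompts.template` decorator, `Template.register`, and indexing a template with a model name); the model identifier `"provider/name"` is the placeholder used in the test suite, not a real provider.

```python
import prompts


@prompts.template
def greetings(name, question):
    # The template is now the function's return value, not its docstring.
    return """Hello, {{ name }}!

    {{ question }}
    """


# A Template doubles as a registry: model-specific variants are attached
# with `register` and retrieved by indexing with the model name.
@greetings.register("provider/name")
def greetings_specialized(name, question):
    return """name: Hello, {{ name }}! {{ question }}"""


print(greetings("user", "How are you?"))
# Hello, user!
#
# How are you?

print(greetings["provider/name"]("user", "How are you?"))
# name: Hello, user! How are you?
```

With the test extra installed (`pip install -e ".[test]"`), the new benchmarks can be run selectively with `pytest tests/test_templates.py --benchmark-only`. Note that `render` is wrapped in `lru_cache`, which is why the benchmarks' `setup` generates fresh random strings on every round instead of reusing fixed values.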