From 1b914c11ff6726ec41a8b9f586e2a53efefd6e5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 26 Sep 2024 15:30:43 +0200 Subject: [PATCH] Make funtion templates real functions Prompts remplates are currently contained in the docstring of decorated functions. The main issue with this is that prompt templates cannot be composed. In this commit we instead require users to return the prompt template from the function. The template will then automatically be rendered using the values passed to the function. This is very flexible: some variables can be used inside the functions and not be present in the Jinja2 template that is returned, for instance: ```python import prompts @prompts.template def my_template(a, b): prompt = f'This is a first variable {a}' return prompt + "and a second {{b}}" ``` --- README.md | 2 +- docs/reference/template.md | 20 +++++++++---------- prompts/templates.py | 40 +++++++++++++++++--------------------- tests/test_templates.py | 30 +++++++--------------------- 4 files changed, 36 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index 50402a0..55301b3 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ from prompts import template @template def few_shots(instructions, examples, question): - """{{ instructions }} + return """{{ instructions }} Examples -------- diff --git a/docs/reference/template.md b/docs/reference/template.md index 14baaf1..0a6fd2c 100644 --- a/docs/reference/template.md +++ b/docs/reference/template.md @@ -38,7 +38,7 @@ will pass to the prompt function. @prompts.template def greetings(name, question): - """Hello, {{ name }}! + return """Hello, {{ name }}! {{ question }} """ @@ -62,7 +62,7 @@ If a variable is missing in the function's arguments, Jinja2 will throw an `Unde @prompts.template def greetings(name): - """Hello, {{ surname }}!""" + return """Hello, {{ surname }}!""" prompt = greetings("user") ``` @@ -94,7 +94,7 @@ Prompt functions are functions, and thus can be imported from other modules: @prompts.template def greetings(name, question): - """Hello, {{ name }}! + return """Hello, {{ name }}! {{ question }} """ ``` @@ -128,7 +128,7 @@ keys `question` and `answer` to the prompt function: @prompts.template def few_shots(instructions, examples, question): - """{{ instructions }} + return """{{ instructions }} Examples -------- @@ -207,12 +207,12 @@ below does not matter for formatting: @prompts.template def prompt1(): - """My prompt + return """My prompt """ @prompts.template def prompt2(): - """ + return """ My prompt """ @@ -236,20 +236,20 @@ Indentation is relative to the second line of the docstring, and leading spaces @prompts.template def example1(): - """First line + return """First line Second line """ @prompts.template def example2(): - """ + return """ Second line Third line """ @prompts.template def example3(): - """ + return """ Second line Third line """ @@ -285,7 +285,7 @@ You can use the backslash `\` to break a long line of text. It will render as a @prompts.template def example(): - """ + return """ Break in \ several lines \ But respect the indentation diff --git a/prompts/templates.py b/prompts/templates.py index a7f5a8c..56188f8 100644 --- a/prompts/templates.py +++ b/prompts/templates.py @@ -3,7 +3,7 @@ import warnings from dataclasses import dataclass, field from functools import lru_cache -from typing import Callable, Dict, Hashable, Optional, cast +from typing import Callable, Dict, Hashable, Optional from jinja2 import Environment, StrictUndefined @@ -15,7 +15,7 @@ class Template: """Represents a prompt template. A prompt template is a callable that, given a Jinja2 template and a set of values, - renders the template using those values. It is recommended to instantiate `Temaplate` + renders the template using those values. It is recommended to instantiate `Template` using the `template` decorator, which extracts the template from the function's docstring and its variables from the function's signature. @@ -40,11 +40,15 @@ class Template: """ - template: str signature: inspect.Signature + fn: Callable model: Optional[str] = None registry: Dict[str, Callable] = field(default_factory=dict) + def __init__(self, fn: Callable): + self.fn = fn + self.signature = inspect.signature(fn) + def __call__(self, *args, **kwargs) -> str: """Render and return the template. @@ -55,7 +59,10 @@ def __call__(self, *args, **kwargs) -> str: """ bound_arguments = self.signature.bind(*args, **kwargs) bound_arguments.apply_defaults() - return render(self.template, self.model, **bound_arguments.arguments) + + template = self.fn(**bound_arguments.arguments) + + return render(template, self.model, **bound_arguments.arguments) def __str__(self): return self.template @@ -104,11 +111,11 @@ def template(fn: Callable) -> Template: manipulation by providing some degree of encapsulation. It uses the `render` function internally to render templates. - >>> import outlines + >>> import prompts >>> - >>> @outlines.prompt + >>> @prompts.template >>> def build_prompt(question): - ... "I have a ${question}" + ... return "I have a {{question}}" ... >>> prompt = build_prompt("How are you?") @@ -116,12 +123,11 @@ def template(fn: Callable) -> Template: are set when the agent is initialized and never modified later. In this situation we can partially apply the prompt function at initialization. - >>> import outlines - >>> import functools as ft + >>> import prompts ... - >>> @outlines.prompt + >>> @prompts.template ... def solve_task(name: str, objective: str, task: str): - ... '''Your name is {{name}}. + ... return '''Your name is {{name}}. .. Your overall objective is to {{objective}}. ... Please solve the following task: {{task}}''' ... @@ -132,17 +138,7 @@ def template(fn: Callable) -> Template: A `Prompt` callable class which will render the template when called. """ - signature = inspect.signature(fn) - - # The docstring contains the template that will be rendered to be used - # as a prompt to the language model. - docstring = fn.__doc__ - if docstring is None: - raise TypeError("Could not find a template in the function's docstring.") - - template = cast(str, docstring) - - return Template(template, signature) + return Template(fn) @lru_cache diff --git a/tests/test_templates.py b/tests/test_templates.py index 8bd295d..4784134 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -129,9 +129,8 @@ def test_render_jinja(): def test_prompt_basic(): @prompts.template def test_tpl(variable): - """{{variable}} test""" + return """{{variable}} test""" - assert test_tpl.template == "{{variable}} test" assert list(test_tpl.signature.parameters.keys()) == ["variable"] with pytest.raises(TypeError): @@ -145,7 +144,7 @@ def test_tpl(variable): @prompts.template def test_single_quote_tpl(variable): - "${variable} test" + return "{{variable}} test" p = test_tpl("test") assert p == "test test" @@ -154,9 +153,8 @@ def test_single_quote_tpl(variable): def test_prompt_kwargs(): @prompts.template def test_kwarg_tpl(var, other_var="other"): - """{{var}} and {{other_var}}""" + return """{{var}} and {{other_var}}""" - assert test_kwarg_tpl.template == "{{var}} and {{other_var}}" assert list(test_kwarg_tpl.signature.parameters.keys()) == ["var", "other_var"] p = test_kwarg_tpl("test") @@ -169,30 +167,16 @@ def test_kwarg_tpl(var, other_var="other"): assert p == "test and test" -def test_no_prompt(): - with pytest.raises(TypeError, match="template"): - - @prompts.template - def test_empty(variable): - pass - - with pytest.raises(TypeError, match="template"): - - @prompts.template - def test_only_code(variable): - return variable - - @pytest.mark.filterwarnings("ignore: The model") def test_dispatch(): @prompts.template def simple_prompt(query: str): - """{{ query }}""" + return """{{ query }}""" @simple_prompt.register("provider/name") def simple_prompt_name(query: str): - """name: {{ query }}""" + return """name: {{ query }}""" assert list(simple_prompt.registry.keys()) == ["provider/name"] assert callable(simple_prompt) @@ -214,7 +198,7 @@ def test_special_tokens(): @prompts.template def simple_prompt(query: str): - """{{ bos + query + eos }}""" + return """{{ bos + query + eos }}""" assert simple_prompt("test") == "test" assert simple_prompt["openai-community/gpt2"]("test") == "test<|endoftext|>" @@ -225,7 +209,7 @@ def test_warn(): @prompts.template def simple_prompt(): - """test""" + return """test""" with pytest.warns(UserWarning, match="not present in the special token"): simple_prompt["non-existent-model"]()