Skip to content

Commit

Permalink
Automatically render special EOS and BOS tokens
Browse files Browse the repository at this point in the history
We define Jinja2 `eos` and `bos` global variables that are rendered
as EOS and BOS tokens.
  • Loading branch information
rlouf committed Jul 30, 2024
1 parent 32c4883 commit ec36ba8
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 3 deletions.
31 changes: 31 additions & 0 deletions docs/reference/special_tokens.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Handle special tokens

Tokens that indicate the beginnning of a sequence, an end of sequence, that
delineate user and assistant turns in a conversation, etc. are model-specific.
This means that one needs to write a new prompt each time they use a new model,
only replacing these special tokens. This is error-prone and leads to duplicated
work.

`prompts` provides special variables in its templates that allows user to use special tokens in their prompts in a model-agnotic way:

```python
import prompts


@prompts.template
def a_simple_prompt(query: str):
"""{{ bos + query + eos }}"""


print(a_simple_prompt["mistralai/Mistral-7B-v0.1"]("question"))
# <s>question</s>

print(a_simple_prompt["google/gemma-2-9b"]("question"))
# <bos>question<eos>
```


!!! note "Registry"

The registry is currently limited to a few models. Please [open an issue](https://github.com/outlines-dev/prompts/issues) if you
want to use `prompts` with a model that is not currently in the registry.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,4 @@ nav:
- reference/index.md
- Prompt template: reference/template.md
- Dispatch: reference/dispatch.md
- Special tokens: reference/special_tokens.md
42 changes: 39 additions & 3 deletions prompts/templates.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import inspect
import re
import warnings
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Dict, Hashable, Optional, cast
from typing import Callable, Dict, Hashable, Optional, Tuple, cast

from jinja2 import Environment, StrictUndefined

Expand All @@ -29,6 +30,8 @@ class Template:
The template to render.
signature
The prompt function's signature.
model
The model the `Template` is associated with. Defaults to `None`.
registry
Registry that maps function names to their respective `Template`
instances.
Expand All @@ -50,7 +53,7 @@ def __call__(self, *args, **kwargs) -> str:
"""
bound_arguments = self.signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return render(self.template, **bound_arguments.arguments)
return render(self.template, self.model, **bound_arguments.arguments)

def __str__(self):
return self.template
Expand All @@ -74,6 +77,7 @@ def __getitem__(self, model_name: str):
try:
return self.registry[model_name]
except KeyError:
self.model = model_name
return self

def register(self, model_name: str):
Expand Down Expand Up @@ -140,13 +144,21 @@ def template(fn: Callable) -> Template:


@lru_cache
def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str:
def render(
template: str,
model_name: Optional[str] = None,
**values: Optional[Dict[str, Hashable]],
) -> str:
r"""Parse a Jinaj2 template and translate it into an Outlines graph.
This function removes extra whitespaces and linebreaks from templates to
allow users to enter prompts more naturally than if they used Python's
constructs directly. See the examples for a detailed explanation.
We also define the `bos` and `eos` special variables which, when used, will
be replaced by the model's BOS and EOS tokens respectively. This allows you
to write prompts that are model-agnostic.
Examples
--------
Expand Down Expand Up @@ -223,6 +235,8 @@ def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str:
----------
template
A string that contains a template written with the Jinja2 syntax.
model_name
The name of the model to which the rendered string will be passed.
**values
Map from the variables in the template to their value.
Expand All @@ -245,12 +259,34 @@ def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str:
# used to continue to the next line without linebreak.
cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template)

# Warn the user when the model is not present in the special token registry
if model_name not in SPECIAL_TOKENS:
warnings.warn(
UserWarning(
f"The model {model_name} is not present in the special token registry."
"As a result, EOS and BOS tokens will be rendered as the empty string."
"Please open an issue: https://github.com/outlines-dev/prompts/issues"
"And ask for the model to be added to the registry."
)
)

env = Environment(
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=StrictUndefined,
)
env.globals["bos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[0]
env.globals["eos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[1]
jinja_template = env.from_string(cleaned_template)

return jinja_template.render(**values)


# (BOS, EOS)
SPECIAL_TOKENS: Dict[Optional[str], Tuple[str, str]] = {
None: ("", ""),
"google/gemma-2-9b": ("<bos>", "<eos>"),
"openai-community/gpt2": ("", "<|endoftext|>"),
"mistralai/Mistral-7B-v0.1": ("<s>", "</s>"),
}
22 changes: 22 additions & 0 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def test_only_code(variable):
return variable


@pytest.mark.filterwarnings("ignore: The model")
def test_dispatch():

@prompts.template
Expand All @@ -207,3 +208,24 @@ def simple_prompt_name(query: str):
assert simple_prompt("test") == "test"
assert simple_prompt["gpt2"]("test") == "test"
assert simple_prompt["provider/name"]("test") == "name: test"


def test_special_tokens():

@prompts.template
def simple_prompt(query: str):
"""{{ bos + query + eos }}"""

assert simple_prompt("test") == "test"
assert simple_prompt["openai-community/gpt2"]("test") == "test<|endoftext|>"
assert simple_prompt["mistralai/Mistral-7B-v0.1"]("test") == "<s>test</s>"


def test_warn():

@prompts.template
def simple_prompt():
"""test"""

with pytest.warns(UserWarning, match="not present in the special token"):
simple_prompt["non-existent-model"]()

0 comments on commit ec36ba8

Please sign in to comment.