Skip to content

Commit

Permalink
Automatically render user, assistant, system variables
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jul 31, 2024
1 parent d2eeaf9 commit b47c9ba
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 12 deletions.
21 changes: 21 additions & 0 deletions docs/reference/special_tokens.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ This means that one needs to write a new prompt each time they use a new model,
only replacing these special tokens. This is error-prone and leads to duplicated
work.


## Beginning and end of sequences

`prompts` provides special variables in its templates that allows user to use special tokens in their prompts in a model-agnotic way:

```python
Expand All @@ -29,3 +32,21 @@ print(a_simple_prompt["google/gemma-2-9b"]("question"))

The registry is currently limited to a few models. Please [open an issue](https://github.com/outlines-dev/prompts/issues) if you
want to use `prompts` with a model that is not currently in the registry.


## Chat and Instruct models

`prompts` also provides special variables `user`, `assistant` and `system` that are related to chat workflows, so you can design prompts with a chat format in a model-agnostic way:

```python
import prompts


@prompts.template
def simple_prompt(favorite: str):
"""{{ bos + user.begin}} What is your favorite {{favorite + '? ' + user.end}}
{{ assistant.begin }}
"""
```

Chat templates are so idiosyncractic, however, that we recommend using the `Chat` class to format according to chat templates.
20 changes: 8 additions & 12 deletions prompts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import warnings
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Dict, Hashable, Optional, Tuple, cast
from typing import Callable, Dict, Hashable, Optional, cast

from jinja2 import Environment, StrictUndefined

from prompts.tokens import SPECIAL_TOKENS, Special


@dataclass
class Template:
Expand Down Expand Up @@ -276,17 +278,11 @@ def render(
keep_trailing_newline=True,
undefined=StrictUndefined,
)
env.globals["bos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[0]
env.globals["eos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[1]
env.globals["bos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.begin
env.globals["eos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.end
env.globals["user"] = SPECIAL_TOKENS.get(model_name, Special()).user
env.globals["assistant"] = SPECIAL_TOKENS.get(model_name, Special()).assistant
env.globals["system"] = SPECIAL_TOKENS.get(model_name, Special()).system
jinja_template = env.from_string(cleaned_template)

return jinja_template.render(**values)


# (BOS, EOS)
SPECIAL_TOKENS: Dict[Optional[str], Tuple[str, str]] = {
None: ("", ""),
"google/gemma-2-9b": ("<bos>", "<eos>"),
"openai-community/gpt2": ("", "<|endoftext|>"),
"mistralai/Mistral-7B-v0.1": ("<s>", "</s>"),
}
29 changes: 29 additions & 0 deletions prompts/tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class Limits:
begin: str = ""
end: str = ""


@dataclass
class Special:
sequence: Limits = Limits("", "")
user: Limits = Limits("", "")
assistant: Limits = Limits("", "")
system: Limits = Limits("", "")


SPECIAL_TOKENS: Dict[Optional[str], Special] = {
None: Special(),
"google/gemma-2-9b": Special(Limits("<bos>", "<eos>")),
"openai-community/gpt2": Special(Limits("", "<|endoftext|>")),
"mistralai/Mistral-7B-v0.1": Special(Limits("<s>", "</s>")),
"mistralai/Mistral-7B-Instruct-v0.1": Special(
Limits("<s>", "</s>"),
Limits("[INST]", "[/INST]"),
Limits("", "</s>"),
),
}
6 changes: 6 additions & 0 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from prompts.tokens import Special


def test_simple():
special = Special()
assert special.assistant.begin == ""

0 comments on commit b47c9ba

Please sign in to comment.