Automatically render user, assistant, system, special tokens #8

Merged · 1 commit · Jul 31, 2024
21 changes: 21 additions & 0 deletions docs/reference/special_tokens.md
@@ -6,6 +6,9 @@
This means that a new prompt must be written for each new model, changing only
these special tokens. This is error-prone and leads to duplicated work.


## Beginning and end of sequences

`prompts` provides special variables in its templates that allow users to use special tokens in their prompts in a model-agnostic way:

```python
print(a_simple_prompt["google/gemma-2-9b"]("question"))
```

The registry is currently limited to a few models. Please [open an issue](https://github.com/outlines-dev/prompts/issues) if you
want to use `prompts` with a model that is not currently in the registry.
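
To make the idea concrete, here is a stdlib-only sketch of the substitution the `bos`/`eos` variables perform. The dictionary and `render_with_tokens` helper below are illustrative stand-ins, not the `prompts` API:

```python
# Illustrative stand-in for the registry lookup: map a model name to its
# (BOS, EOS) pair, falling back to empty strings for unknown models.
SPECIAL_TOKENS = {
    "google/gemma-2-9b": ("<bos>", "<eos>"),
    "openai-community/gpt2": ("", "<|endoftext|>"),
    "mistralai/Mistral-7B-v0.1": ("<s>", "</s>"),
}

def render_with_tokens(model_name: str, body: str) -> str:
    """Wrap `body` in the model's special tokens, model-agnostically."""
    bos, eos = SPECIAL_TOKENS.get(model_name, ("", ""))
    return f"{bos}{body}{eos}"

print(render_with_tokens("google/gemma-2-9b", "question"))  # <bos>question<eos>
```

The prompt text itself never changes; only the lookup key does, which is exactly the duplication the special variables remove.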


## Chat and Instruct models

`prompts` also provides the special variables `user`, `assistant`, and `system` for chat workflows, so you can design chat-format prompts in a model-agnostic way:

```python
import prompts


@prompts.template
def simple_prompt(favorite: str):
    """{{ bos + user.begin }} What is your favorite {{ favorite + '? ' + user.end }}
    {{ assistant.begin }}
    """
```

Chat templates are so idiosyncratic, however, that we recommend using the `Chat` class to format prompts according to chat templates.
20 changes: 8 additions & 12 deletions prompts/templates.py
@@ -3,10 +3,12 @@
import warnings
from dataclasses import dataclass, field
from functools import lru_cache
-from typing import Callable, Dict, Hashable, Optional, Tuple, cast
+from typing import Callable, Dict, Hashable, Optional, cast

from jinja2 import Environment, StrictUndefined

+from prompts.tokens import SPECIAL_TOKENS, Special


@dataclass
class Template:
@@ -276,17 +278,11 @@ def render(
keep_trailing_newline=True,
undefined=StrictUndefined,
)
-env.globals["bos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[0]
-env.globals["eos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[1]
+env.globals["bos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.begin
+env.globals["eos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.end
+env.globals["user"] = SPECIAL_TOKENS.get(model_name, Special()).user
+env.globals["assistant"] = SPECIAL_TOKENS.get(model_name, Special()).assistant
+env.globals["system"] = SPECIAL_TOKENS.get(model_name, Special()).system
jinja_template = env.from_string(cleaned_template)

return jinja_template.render(**values)


-# (BOS, EOS)
-SPECIAL_TOKENS: Dict[Optional[str], Tuple[str, str]] = {
-    None: ("", ""),
-    "google/gemma-2-9b": ("<bos>", "<eos>"),
-    "openai-community/gpt2": ("", "<|endoftext|>"),
-    "mistralai/Mistral-7B-v0.1": ("<s>", "</s>"),
-}
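
The effect of registering these globals can be approximated without jinja2. The naive `substitute` helper below is only a sketch of what the real `Environment` does, handling just the exact `{{ name }}` spelling:

```python
# Naive stand-in for jinja2 variable substitution, showing what the
# registered globals buy the template author: every template sees the
# right tokens for the model it is rendered against.
def substitute(template: str, variables: dict) -> str:
    for name, value in variables.items():
        template = template.replace("{{ " + name + " }}", value)
    return template

tokens = {"bos": "<bos>", "eos": "<eos>"}
print(substitute("{{ bos }}question{{ eos }}", tokens))  # <bos>question<eos>
```

In the real code the same mapping is installed via `env.globals`, so templates can also use expressions such as `{{ bos + user.begin }}`, which plain string replacement cannot.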
29 changes: 29 additions & 0 deletions prompts/tokens.py
@@ -0,0 +1,29 @@
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class Limits:
    begin: str = ""
    end: str = ""


@dataclass
class Special:
    # `field(default_factory=...)` avoids sharing a single mutable Limits
    # instance across Special instances (and a ValueError on Python 3.11+,
    # which rejects bare mutable dataclass defaults).
    sequence: Limits = field(default_factory=Limits)
    user: Limits = field(default_factory=Limits)
    assistant: Limits = field(default_factory=Limits)
    system: Limits = field(default_factory=Limits)


SPECIAL_TOKENS: Dict[Optional[str], Special] = {
    None: Special(),
    "google/gemma-2-9b": Special(Limits("<bos>", "<eos>")),
    "openai-community/gpt2": Special(Limits("", "<|endoftext|>")),
    "mistralai/Mistral-7B-v0.1": Special(Limits("<s>", "</s>")),
    "mistralai/Mistral-7B-Instruct-v0.1": Special(
        Limits("<s>", "</s>"),
        Limits("[INST]", "[/INST]"),
        Limits("", "</s>"),
    ),
}
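
As a usage sketch, the snippet below replicates the `Limits`/`Special` dataclasses so it is self-contained, then shows how a chat prompt picks up a model's role delimiters from the registry. The `registry` name and the final f-string are illustrative, not library code:

```python
from dataclasses import dataclass, field

# Self-contained replica of Limits/Special from prompts/tokens.py.
@dataclass
class Limits:
    begin: str = ""
    end: str = ""

@dataclass
class Special:
    sequence: Limits = field(default_factory=Limits)
    user: Limits = field(default_factory=Limits)
    assistant: Limits = field(default_factory=Limits)
    system: Limits = field(default_factory=Limits)

# Hypothetical registry with one instruct model; positional arguments
# fill sequence, user, assistant in order.
registry = {
    "mistralai/Mistral-7B-Instruct-v0.1": Special(
        Limits("<s>", "</s>"),
        Limits("[INST]", "[/INST]"),
        Limits("", "</s>"),
    ),
}

special = registry.get("mistralai/Mistral-7B-Instruct-v0.1", Special())
prompt = f"{special.sequence.begin}{special.user.begin} Hello {special.user.end}"
print(prompt)  # <s>[INST] Hello [/INST]
```

An unknown model falls back to `Special()`, whose delimiters are all empty strings, so the same template degrades gracefully to plain text.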
6 changes: 6 additions & 0 deletions tests/test_tokens.py
@@ -0,0 +1,6 @@
from prompts.tokens import Special


def test_simple():
    special = Special()
    assert special.assistant.begin == ""