diff --git a/docs/reference/special_tokens.md b/docs/reference/special_tokens.md
index b98c736..9137ead 100644
--- a/docs/reference/special_tokens.md
+++ b/docs/reference/special_tokens.md
@@ -6,6 +6,9 @@
 This means that one needs to write a new prompt each time they use a new
 model, only replacing these special tokens. This is error-prone and leads to
 duplicated work.
+
+## Beginning and end of sequences
+
 `prompts` provides special variables in its templates that allow users to use
 special tokens in their prompts in a model-agnostic way:
 ```python
@@ -29,3 +32,21 @@ print(a_simple_prompt["google/gemma-2-9b"]("question"))
 The registry is currently limited to a few models. Please [open an
 issue](https://github.com/outlines-dev/prompts/issues) if you want to use
 `prompts` with a model that is not currently in the registry.
+
+
+## Chat and Instruct models
+
+`prompts` also provides the special variables `user`, `assistant` and `system` for chat workflows, so you can design prompts with a chat format in a model-agnostic way:
+
+```python
+import prompts
+
+
+@prompts.template
+def simple_prompt(favorite: str):
+    """{{ bos + user.begin }} What is your favorite {{ favorite + '? ' + user.end }}
+    {{ assistant.begin }}
+    """
+```
+
+Chat templates are so idiosyncratic, however, that we recommend using the `Chat` class to format prompts according to a model's chat template.
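To make the new chat variables concrete: rendered against the `mistralai/Mistral-7B-Instruct-v0.1` entry that `prompts/tokens.py` registers below, the docs example should expand `bos`, `user` and `assistant` into Mistral's instruct delimiters. A minimal sketch mirroring the snippet above (the exact whitespace depends on how `render` dedents the docstring, so the output shown in the comment is approximate):

```python
import prompts


@prompts.template
def simple_prompt(favorite: str):
    """{{ bos + user.begin }} What is your favorite {{ favorite + '? ' + user.end }}
    {{ assistant.begin }}
    """


# For this model `bos` is "<s>", `user` is Limits("[INST]", "[/INST]")
# and `assistant.begin` is "", so the rendered prompt is roughly:
#   <s>[INST] What is your favorite color? [/INST]
print(simple_prompt["mistralai/Mistral-7B-Instruct-v0.1"]("color"))
```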
diff --git a/prompts/templates.py b/prompts/templates.py
index 393db19..a7f5a8c 100644
--- a/prompts/templates.py
+++ b/prompts/templates.py
@@ -3,10 +3,12 @@
 import warnings
 from dataclasses import dataclass, field
 from functools import lru_cache
-from typing import Callable, Dict, Hashable, Optional, Tuple, cast
+from typing import Callable, Dict, Hashable, Optional, cast
 
 from jinja2 import Environment, StrictUndefined
 
+from prompts.tokens import SPECIAL_TOKENS, Special
+
 
 @dataclass
 class Template:
@@ -276,17 +278,11 @@ def render(
         keep_trailing_newline=True,
         undefined=StrictUndefined,
     )
-    env.globals["bos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[0]
-    env.globals["eos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[1]
+    env.globals["bos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.begin
+    env.globals["eos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.end
+    env.globals["user"] = SPECIAL_TOKENS.get(model_name, Special()).user
+    env.globals["assistant"] = SPECIAL_TOKENS.get(model_name, Special()).assistant
+    env.globals["system"] = SPECIAL_TOKENS.get(model_name, Special()).system
 
     jinja_template = env.from_string(cleaned_template)
     return jinja_template.render(**values)
-
-
-# (BOS, EOS)
-SPECIAL_TOKENS: Dict[Optional[str], Tuple[str, str]] = {
-    None: ("", ""),
-    "google/gemma-2-9b": ("<bos>", "<eos>"),
-    "openai-community/gpt2": ("", "<|endoftext|>"),
-    "mistralai/Mistral-7B-v0.1": ("<s>", "</s>"),
-}
diff --git a/prompts/tokens.py b/prompts/tokens.py
new file mode 100644
index 0000000..953ad9f
--- /dev/null
+++ b/prompts/tokens.py
@@ -0,0 +1,31 @@
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+
+@dataclass
+class Limits:
+    """Tokens that mark the beginning and the end of a span of text."""
+    begin: str = ""
+    end: str = ""
+
+
+@dataclass
+class Special:
+    """Beginning and end tokens for the sequence and each chat role."""
+    sequence: Limits = Limits("", "")
+    user: Limits = Limits("", "")
+    assistant: Limits = Limits("", "")
+    system: Limits = Limits("", "")
+
+
+SPECIAL_TOKENS: Dict[Optional[str], Special] = {
+    None: Special(),
+    "google/gemma-2-9b": Special(Limits("<bos>", "<eos>")),
+    "openai-community/gpt2": Special(Limits("", "<|endoftext|>")),
+    "mistralai/Mistral-7B-v0.1": Special(Limits("<s>", "</s>")),
+    "mistralai/Mistral-7B-Instruct-v0.1": Special(
+        Limits("<s>", "</s>"),
+        Limits("[INST]", "[/INST]"),
+        Limits("", "</s>"),
+    ),
+}
diff --git a/tests/test_tokens.py b/tests/test_tokens.py
new file mode 100644
index 0000000..c9391e9
--- /dev/null
+++ b/tests/test_tokens.py
@@ -0,0 +1,6 @@
+from prompts.tokens import Special
+
+
+def test_simple():
+    special = Special()
+    assert special.assistant.begin == ""
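The registry introduced in `prompts/tokens.py` is plain data, so it can be exercised directly. A minimal sketch, assuming the token values shown above; the model-less fallback mirrors what `render` does via `SPECIAL_TOKENS.get(model_name, Special())`, and `acme/unregistered-model` is a made-up name:

```python
from prompts.tokens import SPECIAL_TOKENS, Special

# Registered model: chat delimiters come straight from the registry.
mistral = SPECIAL_TOKENS["mistralai/Mistral-7B-Instruct-v0.1"]
assert mistral.sequence.begin == "<s>"
assert mistral.user.begin == "[INST]"
assert mistral.user.end == "[/INST]"
assert mistral.assistant.end == "</s>"

# Unregistered model: the default Special() has empty Limits everywhere,
# so templates render with no special tokens instead of raising.
fallback = SPECIAL_TOKENS.get("acme/unregistered-model", Special())
assert fallback.sequence.begin == ""
assert fallback.user.end == ""
```

Defaulting to empty strings keeps `render` total over model names, which matches the docs' advice to open an issue for missing models rather than failing hard.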