Skip to content

Commit

Permalink
Remove special token management
Browse files: browse the repository at this point in the history.
  • Loading branch information
rlouf committed Sep 26, 2024
1 parent 82e8d17 commit fb5e355
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 79 deletions.
23 changes: 0 additions & 23 deletions prompts/templates.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import inspect
import re
import warnings
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Dict, Hashable, Optional

from jinja2 import Environment, StrictUndefined

from prompts.tokens import SPECIAL_TOKENS, Special


@dataclass
class Template:
Expand Down Expand Up
@@ -151,10 +148,6 @@ def render(
allow users to enter prompts more naturally than if they used Python's
constructs directly. See the examples for a detailed explanation.
We also define the `bos` and `eos` special variables which, when used, will
be replaced by the model's BOS and EOS tokens respectively. This allows you
to write prompts that are model-agnostic.
Examples
--------
Expand Down Expand Up
@@ -255,28 +248,12 @@ def render(
# used to continue to the next line without linebreak.
cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template)

# Warn the user when the model is not present in the special token registry
if model_name not in SPECIAL_TOKENS:
warnings.warn(
UserWarning(
f"The model {model_name} is not present in the special token registry."
"As a result, EOS and BOS tokens will be rendered as the empty string."
"Please open an issue: https://github.com/outlines-dev/prompts/issues"
"And ask for the model to be added to the registry."
)
)

env = Environment(
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=StrictUndefined,
)
env.globals["bos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.begin
env.globals["eos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.end
env.globals["user"] = SPECIAL_TOKENS.get(model_name, Special()).user
env.globals["assistant"] = SPECIAL_TOKENS.get(model_name, Special()).assistant
env.globals["system"] = SPECIAL_TOKENS.get(model_name, Special()).system
jinja_template = env.from_string(cleaned_template)

return jinja_template.render(**values)
29 changes: 0 additions & 29 deletions prompts/tokens.py

This file was deleted.

21 changes: 0 additions & 21 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up
@@ -192,24 +192,3 @@ def simple_prompt_name(query: str):
assert simple_prompt("test") == "test"
assert simple_prompt["gpt2"]("test") == "test"
assert simple_prompt["provider/name"]("test") == "name: test"


def test_special_tokens():

@prompts.template
def simple_prompt(query: str):
return """{{ bos + query + eos }}"""

assert simple_prompt("test") == "test"
assert simple_prompt["openai-community/gpt2"]("test") == "test<|endoftext|>"
assert simple_prompt["mistralai/Mistral-7B-v0.1"]("test") == "<s>test</s>"


def test_warn():

@prompts.template
def simple_prompt():
return """test"""

with pytest.warns(UserWarning, match="not present in the special token"):
simple_prompt["non-existent-model"]()
6 changes: 0 additions & 6 deletions tests/test_tokens.py

This file was deleted.

0 comments on commit fb5e355

Please sign in to comment.