Hugging Face Chat Templates #45

Merged
merged 29 commits into main from hf-concrete-base on Sep 23, 2024
Commits (29)
dc8a69f
feat(huggingface): use tokenizer.apply_chat_template
zhudotexe Feb 2, 2024
c2b0bfd
chore: bump deps
zhudotexe Feb 2, 2024
6e4522d
chore: bump version
zhudotexe Feb 2, 2024
82709d4
chore: format
zhudotexe Feb 2, 2024
d63fe2e
docs: fix docstring error with pydantic update
zhudotexe Feb 2, 2024
34d57c6
feat: ensure abstract attrs are set in concrete HuggingEngine
zhudotexe Feb 2, 2024
649e2f2
hack(huggingface): use a dummy user message to figure out the len of …
zhudotexe Feb 3, 2024
1957501
hack(huggingface): hack to count tokens of prompt starting with assis…
zhudotexe Feb 3, 2024
4024f50
fix: hack was broken because I can't read
zhudotexe Feb 3, 2024
e64cc90
Merge branch 'main' into hf-concrete-base
zhudotexe Feb 3, 2024
e0e3164
Merge branch 'main' into hf-concrete-base
zhudotexe Sep 5, 2024
6a167fe
chore: fix merge
zhudotexe Sep 5, 2024
1d536f3
chore: black
zhudotexe Sep 5, 2024
b751cb9
chore: isort
zhudotexe Sep 5, 2024
318e3c0
feat: chat template impl 1
zhudotexe Sep 10, 2024
ee1cea1
refactor: move chat template stuff to pipeline subclass
zhudotexe Sep 11, 2024
dd56c91
build: add tests for chattemplatepromptpipeline
zhudotexe Sep 12, 2024
59e1c3d
build: test promptpipeline equivalence
zhudotexe Sep 15, 2024
88490df
build: include HF_TOKEN
zhudotexe Sep 15, 2024
f218de1
chore(deps): llama requires protobuf
zhudotexe Sep 15, 2024
d537485
chore: more visible error message on padding inference fail
zhudotexe Sep 15, 2024
a0612dc
docs: concrete HF info
zhudotexe Sep 15, 2024
523022a
docs: more notes on chat templating
zhudotexe Sep 15, 2024
e1fe72d
chore: add basic module for chatting with HuggingEngine
zhudotexe Sep 16, 2024
4652156
feat: actually load chattemplatepromptpipeline if not passed in hfengine
zhudotexe Sep 16, 2024
aa4dbc2
feat: default tool use keys for hf chat templates
zhudotexe Sep 18, 2024
8e4048b
feat: use some handwritten dummy prompts for padding inference
zhudotexe Sep 20, 2024
a2d5680
chore(huggingface): better auto gpu mapping
zhudotexe Sep 20, 2024
80e262c
chore(huggingface): only auto set device_map if accelerate is installed
zhudotexe Sep 20, 2024
3 changes: 3 additions & 0 deletions .github/workflows/pytest.yml
@@ -7,6 +7,9 @@ on: [ push, pull_request ]
# group: ${{ github.workflow }}-${{ github.ref }}
# cancel-in-progress: true

env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}

jobs:
build:
runs-on: ubuntu-latest
9 changes: 7 additions & 2 deletions docs/engines/huggingface.rst
@@ -3,9 +3,14 @@ HuggingFace
If your language model backend is available on HuggingFace or is compatible with ``transformers``'
``AutoModelForCausalLM`` interface, kani includes a base engine that implements a prediction pipeline.

.. versionadded:: 1.2.0
For most models that use a chat format, you won't even need to create a new engine class - kani will automatically
use a `Chat Template <https://huggingface.co/docs/transformers/main/en/chat_templating>`_ if a model has one
included.

.. versionadded:: 1.0.0
For most models that use a chat format, you won't even need to create a new engine class - instead, you can pass
a :class:`.PromptPipeline` to the :class:`.HuggingEngine`.
For more control over the prompting of a chat model, you can pass a :class:`.PromptPipeline` to
the :class:`.HuggingEngine`.

If you do create a new engine, instead of having to implement the prediction logic, all you have to do is subclass
:class:`.HuggingEngine` and implement :meth:`~.HuggingEngine.build_prompt` and :meth:`~.BaseEngine.message_len`.
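For the subclassing path described in the docs above, a minimal sketch might look like the following. The prompt format, the build_prompt signature, and the token counting are illustrative assumptions rather than kani's exact API; consult the HuggingEngine reference before relying on them.

# Illustrative sketch only: a toy "<role>: <content>" format, not any real model's template.
from kani.models import ChatMessage
from kani.engines.huggingface import HuggingEngine


class MyPromptedEngine(HuggingEngine):
    def build_prompt(self, messages, functions=None):
        # concatenate the chat history into the (hypothetical) format the model expects
        prompt = "\n".join(f"{m.role.value}: {m.text}" for m in messages)
        # leave the assistant turn open so the model generates the reply
        return prompt + "\nassistant:"

    def message_len(self, message: ChatMessage) -> int:
        # number of tokens one rendered message occupies in the prompt
        return len(self.tokenizer.encode(f"{message.role.value}: {message.text}"))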
14 changes: 8 additions & 6 deletions docs/shared/engine_table.rst
@@ -1,11 +1,11 @@
+----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
| Model Name | Extra | Capabilities | Engine |
+========================================+====================================+==============================+======================================================================+
| GPT-3.5-turbo, GPT-4 | ``openai`` | |function| |api| | :class:`kani.engines.openai.OpenAIEngine` |
| GPT-* | ``openai`` | |function| |api| | :class:`kani.engines.openai.OpenAIEngine` |
+----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
| Claude, Claude Instant | ``anthropic`` | |function| |api| | :class:`kani.engines.anthropic.AnthropicEngine` |
| Claude-* | ``anthropic`` | |function| |api| | :class:`kani.engines.anthropic.AnthropicEngine` |
+----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
| |:hugging:| transformers\ [#runtime]_ | ``huggingface``\ [#torch]_ | (runtime) | :class:`kani.engines.huggingface.HuggingEngine` |
| |:hugging:| transformers\ [#hf]_ | ``huggingface``\ [#torch]_ | (model-specific) | :class:`kani.engines.huggingface.HuggingEngine` |
+----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
| |:hugging:| |:llama:| LLaMA 3 | ``huggingface, llama``\ [#torch]_ | |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.HuggingEngine`\ [#zoo]_ |
+----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
@@ -17,7 +17,7 @@
+----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
| |:hugging:| |:llama:| Vicuna v1.3 | ``huggingface, llama``\ [#torch]_ | |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.vicuna.VicunaEngine` |
+----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
| llama.cpp\ [#runtime]_ | ``cpp`` | (runtime) | :class:`kani.engines.llamacpp.LlamaCppEngine` |
| llama.cpp\ [#runtime]_ | ``cpp`` | (model-specific) | :class:`kani.engines.llamacpp.LlamaCppEngine` |
+----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
| |:llama:| LLaMA v2 (GGUF) | ``cpp`` | |oss| |cpu| |gpu| | :class:`kani.engines.llamacpp.LlamaCppEngine` |
+----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+
@@ -41,8 +41,10 @@ models!
.. |api| replace:: :abbr:`📡 (hosted API)`

.. [#zoo] See the `model zoo <https://github.com/zhudotexe/kani/blob/main/examples/4_engines_zoo.py>`_ for a code sample
to initialize this model with the given engine.
to initialize this model with the given engine.
.. [#torch] You will also need to install `PyTorch <https://pytorch.org/get-started/locally/>`_ manually.
.. [#abstract] This is an abstract class of models; kani includes a couple concrete implementations for
reference.
.. [#runtime] This is a model runtime that can support multiple models using a :class:`.PromptPipeline`.
.. [#runtime] This is a model runtime that can support multiple models using a :class:`.PromptPipeline`.
.. [#hf] The HuggingEngine can run most models directly from HuggingFace using Chat Templates. For more fine-grained
control over prompting, see :class:`.PromptPipeline`.
4 changes: 4 additions & 0 deletions examples/4_engines_zoo.py
@@ -18,6 +18,10 @@
engine = AnthropicEngine(api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-opus-20240229")

# ========== Hugging Face ==========
# ---- Any Model (Chat Templates) ----
from kani.engines.huggingface import HuggingEngine
engine = HuggingEngine(model_id="org-id/model-id")

# ---- LLaMA v3 (Hugging Face) ----
import torch
from kani.engines.huggingface import HuggingEngine
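The zoo entry above relies on the new default behaviour, where the HuggingEngine builds a ChatTemplatePromptPipeline itself. As a sketch, the equivalent explicit form (with a placeholder model id) constructs the pipeline by hand and passes it in:

# Sketch: passing the chat-template pipeline explicitly; "org-id/model-id" is a placeholder.
from transformers import AutoTokenizer

from kani.engines.huggingface import HuggingEngine
from kani.engines.huggingface.chat_template_pipeline import ChatTemplatePromptPipeline

tokenizer = AutoTokenizer.from_pretrained("org-id/model-id")
pipeline = ChatTemplatePromptPipeline(tokenizer)  # wraps the tokenizer's apply_chat_template
engine = HuggingEngine(model_id="org-id/model-id", prompt_pipeline=pipeline)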
22 changes: 22 additions & 0 deletions kani/engines/huggingface/__main__.py
@@ -0,0 +1,22 @@
"""
For internal testing:
python -m kani.engines.huggingface org-id/model-id
Equivalent to initializing a HuggingEngine and calling chat_in_terminal.
"""

import sys

from kani import Kani, chat_in_terminal
from kani.engines.huggingface import HuggingEngine


def basic_chat_with_model_id(model_id: str):
engine = HuggingEngine(model_id)
ai = Kani(engine)
chat_in_terminal(ai)


if __name__ == "__main__":
if len(sys.argv) < 2:
sys.exit("Usage: python -m kani.engines.huggingface <org-id/model-id>")
basic_chat_with_model_id(sys.argv[1])
52 changes: 43 additions & 9 deletions kani/engines/huggingface/base.py
@@ -7,6 +7,7 @@
from kani.exceptions import MissingModelDependencies
from kani.models import ChatMessage
from kani.prompts.pipeline import PromptPipeline
from .chat_template_pipeline import ChatTemplatePromptPipeline
from ..base import BaseCompletion, BaseEngine, Completion

try:
@@ -18,7 +19,15 @@
"You will also need to install PyTorch manually."
) from None

try:
import accelerate

has_accelerate = True
except ImportError:
has_accelerate = False

log = logging.getLogger(__name__)
has_cuda = torch.backends.cuda.is_built()


class HuggingEngine(BaseEngine):
@@ -28,6 +37,11 @@ class HuggingEngine(BaseEngine):
``AutoModelForCausalLM``. As most models use model-specific chat templates, this base class accepts a
:class:`.PromptPipeline` to translate kani ChatMessages into a model-specific string.

.. versionadded:: 1.2.0
By default, the ``HuggingEngine`` uses models' bundled chat template to build the prompt
for chat-based models available on Hugging Face. See
https://huggingface.co/docs/transformers/main/en/chat_templating for more information.

**GPU Support**

By default, the HuggingEngine loads the model on GPU if CUDA is detected on your system. To override the device
@@ -42,22 +56,28 @@ def __init__(
max_context_size: int = None,
prompt_pipeline: PromptPipeline[str | torch.Tensor] = None,
*,
# hf args
token=None,
device: str | None = None,
tokenizer_kwargs: dict = None,
model_load_kwargs: dict = None,
# kani args
token_reserve: int = 0,
**hyperparams,
):
"""
:param model_id: The ID of the model to load from HuggingFace.
:param max_context_size: The context size of the model. If not given, will be set from the model's config.
:param prompt_pipeline: The pipeline to translate a list of kani ChatMessages into the model-specific chat
format (see :class:`.PromptPipeline`).
format (see :class:`.PromptPipeline`). If not passed, uses the Hugging Face chat template if available.
:param token: The Hugging Face access token (for gated models). Pass True to load from huggingface-cli.
:param device: The hardware device to use. If not specified, uses CUDA if available; otherwise uses CPU.
:param tokenizer_kwargs: Additional arguments to pass to ``AutoTokenizer.from_pretrained()``.
:param model_load_kwargs: Additional arguments to pass to ``AutoModelForCausalLM.from_pretrained()``.
:param hyperparams: Additional arguments to supply the model during generation.
:param token_reserve: The number of tokens to reserve for internal engine mechanisms (e.g. if there is a
generation template after the last user message). If not passed, kani will attempt to infer this from a
prompt pipeline.
"""
if tokenizer_kwargs is None:
tokenizer_kwargs = {}
@@ -66,18 +86,26 @@ def __init__(

tokenizer_kwargs.setdefault("token", hyperparams.get("use_auth_token", token))
model_load_kwargs.setdefault("token", hyperparams.pop("use_auth_token", token))
model_load_kwargs.setdefault("torch_dtype", "auto")
if has_cuda and has_accelerate:
model_load_kwargs.setdefault("device_map", "auto")

self.model_id = model_id
self.max_context_size = max_context_size
self.pipeline = prompt_pipeline

self.tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)
self.model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_kwargs)
self.hyperparams = hyperparams
self.token_reserve = token_reserve

# load the pipeline
if prompt_pipeline is None:
prompt_pipeline = ChatTemplatePromptPipeline(self.tokenizer)
self.pipeline = prompt_pipeline

# ensure model is on correct device
if device is None:
device = "cuda" if torch.backends.cuda.is_built() else "cpu"
device = "cuda" if has_cuda else "cpu"
self.device = device
if self.model.device.type != self.device:
self.model.to(device)
@@ -100,8 +128,8 @@ def __init__(
elif self.max_context_size > 1e20:
warnings.warn(
f"The inferred max context size of this model is extremely large ({self.max_context_size}). This"
" may mean that the model has not configured their model_max_len correctly (or you are still using"
" my code in 2050). Please pass the `max_context_size` arg to use the correct model size."
" may mean that the model has not configured their model_max_len correctly. Please pass the"
" `max_context_size` arg to use the correct model size."
)

# infer the token reserve from the pipeline
@@ -118,6 +146,11 @@ def _infer_token_reserve(self):
return len(tokenized)

def message_len(self, message: ChatMessage) -> int:
"""Return the length, in tokens, of the given chat message.

The HuggingEngine's default implementation renders the message with ``apply_chat_template`` if no
``prompt_pipeline`` is supplied.
"""
# default concrete base behaviour:
if self.pipeline is None:
raise NotImplementedError(
@@ -131,6 +164,8 @@
return len(tokenized)

def function_token_reserve(self, functions: list[AIFunction]) -> int:
if not functions:
return 0
# default concrete base behaviour:
if self.pipeline is None:
raise NotImplementedError(
@@ -145,7 +180,7 @@ def function_token_reserve(self, functions: list[AIFunction]) -> int:
toklen = len(tokenized)

# warn if there are functions but no tokens
if functions and toklen == 0:
if toklen == 0:
warnings.warn(
"Functions were given to the model, but the function prompt returned 0 tokens! This model may not"
" support function calling, or you may need to implement"
@@ -221,9 +256,8 @@ async def predict(
# decode to tokens
# the completion shouldn't include the prompt or stop token
content = self.tokenizer.decode(output[0][input_len:], **decode_kwargs).strip()
return Completion(
ChatMessage.assistant(content), prompt_tokens=input_len, completion_tokens=len(output[0]) - (input_len + 1)
)
output_len = len(output[0]) - (input_len + 1)
return Completion(ChatMessage.assistant(content), prompt_tokens=input_len, completion_tokens=output_len)

async def stream(
self,
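The constructor changes above default torch_dtype to "auto" and, when CUDA and accelerate are both available, default device_map to "auto". A hedged sketch of overriding those defaults at construction time (the model id is a placeholder):

# Sketch: overriding the automatic dtype/device-map defaults; the model id is a placeholder.
import torch

from kani.engines.huggingface import HuggingEngine

engine = HuggingEngine(
    model_id="org-id/model-id",
    token=True,  # read the token saved by `huggingface-cli login` (for gated models)
    model_load_kwargs={"torch_dtype": torch.bfloat16, "device_map": "auto"},
)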