feat: use some handwritten dummy prompts for padding inference
zhudotexe committed Sep 20, 2024
1 parent aa4dbc2 commit 8e4048b
Showing 1 changed file with 61 additions and 16 deletions.
77 changes: 61 additions & 16 deletions kani/engines/huggingface/chat_template_pipeline.py
@@ -17,6 +17,7 @@
import warnings
from collections import defaultdict
from functools import cached_property
from typing import Iterable

from kani import AIFunction, ChatMessage, ChatRole, PromptPipeline
from kani.exceptions import MissingModelDependencies
@@ -61,22 +62,46 @@ def _chat_template_dummy_len(self) -> int:

def _infer_padding_len(self, role: ChatRole):
"""Set up _padding_len_by_role for each message type"""
try:
log.debug(f"Estimating padding len for {role}...")
conversation = super().execute([ChatMessage(role=role, content="dummy")])
conversation_len = len(self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False))
text_len = len(self.tokenizer.encode("dummy", add_special_tokens=False))
self._padding_len_by_role[role] = max(conversation_len - text_len, 0)
log.debug(f"{conversation_len=}, {text_len=}, padding estimate={conversation_len - text_len}")
except (TemplateError, IndexError) as e:
# if the template doesn't allow a bare message of this type, fall back to a heuristic
log.warning(
"Estimating message token padding with chat template application raised an error, assuming messages"
" have a padding equal to length of role name plus 4 pad tokens. If this is incorrect, please implement"
" a PromptPipeline.",
exc_info=e,
)
self._padding_len_by_role[role] = len(self.tokenizer.encode(role.value, add_special_tokens=False)) + 4
# try appending a dummy message to various bases, based on role
for base_msgs in _padding_length_inference_base(role):
try:
log.debug(f"Estimating padding len for {role}...")
msgs = base_msgs + [ChatMessage(role=role, content="dummy")]

# get token len of base messages + dummy message
conversation = super().execute(msgs)
conversation_len = len(self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False))

# get token len of just base messages
if base_msgs:
base_conversation = super().execute(base_msgs)
base_conversation_len = len(
self.tokenizer.apply_chat_template(base_conversation, add_generation_prompt=False)
)
else:
base_conversation_len = 0

# get token len of "dummy"
text_len = len(self.tokenizer.encode("dummy", add_special_tokens=False))

# padding = total - (base + dummy)
self._padding_len_by_role[role] = max(conversation_len - (text_len + base_conversation_len), 0)
log.debug(
f"{conversation_len=}, {base_conversation_len=}, {text_len=}, padding"
f" estimate={self._padding_len_by_role[role]}"
)
except (TemplateError, IndexError) as e:
log.debug("Failed to estimate message padding length", exc_info=e)
continue
else:
return
# if we never found a base conversation that works, do a best guess
log.warning(
"Estimating message token padding with chat template application raised an error, assuming messages"
" have a padding equal to length of role name plus 4 pad tokens. If this is incorrect, please implement"
" a PromptPipeline.",
)
self._padding_len_by_role[role] = len(self.tokenizer.encode(role.value, add_special_tokens=False)) + 4

def _chat_template_infer_token_reserve(self):
"""If token_reserve is not set and we have a pipeline, infer it."""
@@ -209,3 +234,23 @@ def hf_tool_use_keys(message: ChatMessage) -> dict:
data["tool_call_id"] = message.tool_call_id
return data
return {}


def _padding_length_inference_base(role: ChatRole) -> Iterable[list[ChatMessage]]:
"""Given a role, yield possible base messages used to infer padding."""
base_with_system = [ChatMessage.system("dummy"), ChatMessage.user("dummy"), ChatMessage.assistant("dummy")]
base_without_system = [ChatMessage.user("dummy"), ChatMessage.assistant("dummy")]
if role == ChatRole.USER:
yield base_with_system
yield base_without_system
elif role == ChatRole.ASSISTANT:
yield base_with_system + [ChatMessage.user("dummy")]
yield base_without_system + [ChatMessage.user("dummy")]
yield [ChatMessage.system("dummy"), ChatMessage.user("dummy")]
yield [ChatMessage.user("dummy")]
elif role == ChatRole.FUNCTION:
yield base_with_system
yield base_without_system
# SYSTEM: only empty
# default, no conversation
yield []
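For the ASSISTANT role, for example, the generator tries progressively smaller bases before the heuristic fallback is used. A short sketch of the candidate order, assuming kani is installed and this module is on the import path:

    # Prints each candidate base conversation tried for ASSISTANT before
    # the role-name-plus-4-tokens heuristic kicks in.
    from kani import ChatRole

    for i, base in enumerate(_padding_length_inference_base(ChatRole.ASSISTANT)):
        print(i, [m.role.value for m in base])
    # 0 ['system', 'user', 'assistant', 'user']
    # 1 ['user', 'assistant', 'user']
    # 2 ['system', 'user']
    # 3 ['user']
    # 4 []   (a bare message, the pre-commit behavior)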
