-
-
Notifications
You must be signed in to change notification settings - Fork 5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend] Automatic detection of chat content format from AST #9919
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
🚀 |
Great idea with the PR @DarkLight1337 ! |
Right now I am thinking of using Jinja's AST parser and working off that. The basic idea is to detect whether |
vllm/entrypoints/chat_utils.py
Outdated
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: | ||
if isinstance(node, jinja2.nodes.Name): | ||
return node.ctx == "load" and node.name == varname | ||
|
||
return False | ||
|
||
|
||
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: | ||
if isinstance(node, jinja2.nodes.Getitem): | ||
return (node.ctx == "load" and _is_var_access(node.node, varname) | ||
and isinstance(node.arg, jinja2.nodes.Const) | ||
and node.arg.value == key) | ||
|
||
if isinstance(node, jinja2.nodes.Getattr): | ||
return (node.ctx == "load" and _is_var_access(node.node, varname) | ||
and node.attr == key) | ||
|
||
return False | ||
|
||
|
||
def _iter_nodes_define_message(chat_template_ast: jinja2.nodes.Template): | ||
# Search for {%- for message in messages -%} loops | ||
for loop_ast in chat_template_ast.find_all(jinja2.nodes.For): | ||
loop_iter = loop_ast.iter | ||
loop_target = loop_ast.target | ||
|
||
if _is_var_access(loop_iter, "messages"): | ||
assert isinstance(loop_target, jinja2.nodes.Name) | ||
yield loop_ast, loop_target.name | ||
|
||
|
||
def _iter_nodes_define_content_item(chat_template_ast: jinja2.nodes.Template): | ||
for node, message_varname in _iter_nodes_define_message(chat_template_ast): | ||
# Search for {%- for content in message['content'] -%} loops | ||
for loop_ast in node.find_all(jinja2.nodes.For): | ||
loop_iter = loop_ast.iter | ||
loop_target = loop_ast.target | ||
|
||
if _is_attr_access(loop_iter, message_varname, "content"): | ||
assert isinstance(loop_target, jinja2.nodes.Name) | ||
yield loop_iter, loop_target.name | ||
|
||
|
||
def _detect_content_format( | ||
chat_template: str, | ||
*, | ||
default: _ChatTemplateContentFormat, | ||
) -> _ChatTemplateContentFormat: | ||
try: | ||
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) | ||
jinja_ast = jinja_compiled.environment.parse(chat_template) | ||
except Exception: | ||
logger.exception("Error when compiling Jinja template") | ||
return default | ||
|
||
try: | ||
next(_iter_nodes_define_content_item(jinja_ast)) | ||
except StopIteration: | ||
return "string" | ||
else: | ||
return "openai" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This handles the most common case of iterating through OpenAI-formatted message['content']
as a list, assuming that no relevant variable reassignments are made other than those in the for loops.
Please tell me if you are aware of any chat templates that don't work with this code.
@@ -380,10 +521,7 @@ def load_chat_template( | |||
|
|||
# If opening a file fails, set chat template to be args to | |||
# ensure we decode so our escape are interpreted correctly | |||
resolved_chat_template = codecs.decode(chat_template, "unicode_escape") | |||
|
|||
logger.info("Using supplied chat template:\n%s", resolved_chat_template) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thie logging line has been moved to vllm/entrypoints/openai/api_server.py
.
chat_template: Optional[str] = Field( | ||
default=None, | ||
description=( | ||
"A Jinja template to use for this conversion. " | ||
"As of transformers v4.44, default chat template is no longer " | ||
"allowed, so you must provide a chat template if the tokenizer " | ||
"does not define one."), | ||
) | ||
chat_template_kwargs: Optional[Dict[str, Any]] = Field( | ||
default=None, | ||
description=("Additional kwargs to pass to the template renderer. " | ||
"Will be accessible by the chat template."), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These arguments are present in other chat-based APIs so I added them here as well.
--chat-template-text-format
8ce013b
to
e262745
Compare
This pull request has merge conflicts that must be resolved before it can be |
e262745
to
c37af03
Compare
loop_target = loop_ast.target | ||
|
||
for varname in message_varnames: | ||
if _is_var_or_elems_access(loop_iter, varname, "content"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this also handle cases where content is reassingned?
Pseudo code example:
for message in messages:
content = message["content"]
for c in content:
do_stuff(c)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, currently it doesn't do that. Let me think a bit about how to handle this...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wrote some code to enable this, but found that this causes false positives. In particular, tool_chat_template_mistral.jinja is detected as having OpenAI format because of L54 and L57 in the chat template.
It would be quite complicated to condition the detected content format based on message["role"]
... we might as well build a CFG, otherwise our code would be quite unmaintainable 😅
Let's keep this simple for now. I am by no means an expert in program analysis.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For future reference, here's the code I changed to handle reassignment of message["content"]
:
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index d6ab3c04e..c0edb7c24 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -204,21 +204,47 @@ def _is_var_or_elems_access(
) # yapf: enable
-def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
- # Global variable that is implicitly defined at the root
- yield root, varname
+def _iter_nodes_assign_var_or_elems(
+ root: jinja2.nodes.Node,
+ varname: str,
+ key: Optional[str] = None,
+):
+ if key is None:
+ # Global variable that is implicitly defined at the root
+ yield root, varname
related_varnames: List[str] = [varname]
for assign_ast in root.find_all(jinja2.nodes.Assign):
lhs = assign_ast.target
rhs = assign_ast.node
- if any(_is_var_or_elems_access(rhs, name) for name in related_varnames):
+ if any(_is_var_or_elems_access(rhs, related_varname, key)
+ for related_varname in related_varnames):
assert isinstance(lhs, jinja2.nodes.Name)
yield assign_ast, lhs.name
related_varnames.append(lhs.name)
+def _iter_nodes_assign_elem(
+ root: jinja2.nodes.Node,
+ varname: str,
+ key: Optional[str] = None,
+):
+ for loop_ast in root.find_all(jinja2.nodes.For):
+ loop_iter = loop_ast.iter
+ loop_target = loop_ast.target
+
+ if _is_var_or_elems_access(loop_iter, varname, key):
+ assert isinstance(loop_target, jinja2.nodes.Name)
+ yield loop_ast, loop_target.name
+ break
+
+ if key is not None:
+ for _, related_varname in _iter_nodes_assign_var_or_elems(
+ root, varname, key):
+ yield from _iter_nodes_assign_elem(root, related_varname)
+
+
# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
@@ -227,16 +253,8 @@ def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
]
- # Search for {%- for message in messages -%} loops
- for loop_ast in root.find_all(jinja2.nodes.For):
- loop_iter = loop_ast.iter
- loop_target = loop_ast.target
-
- for varname in messages_varnames:
- if _is_var_or_elems_access(loop_iter, varname):
- assert isinstance(loop_target, jinja2.nodes.Name)
- yield loop_ast, loop_target.name
- break
+ for messages_varname in messages_varnames:
+ yield from _iter_nodes_assign_elem(root, messages_varname)
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
@@ -244,16 +262,8 @@ def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
varname for _, varname in _iter_nodes_assign_messages_item(root)
]
- # Search for {%- for content in message['content'] -%} loops
- for loop_ast in root.find_all(jinja2.nodes.For):
- loop_iter = loop_ast.iter
- loop_target = loop_ast.target
-
- for varname in message_varnames:
- if _is_var_or_elems_access(loop_iter, varname, "content"):
- assert isinstance(loop_target, jinja2.nodes.Name)
- yield loop_ast, loop_target.name
- break
+ for message_varname in message_varnames:
+ yield from _iter_nodes_assign_elem(root, message_varname, "content")
def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
Signed-off-by: DarkLight1337 <[email protected]>
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: DarkLight1337 <[email protected]>
@maxdebayser does this look good to you now? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@DarkLight1337 , I've left a few comments, I think the one about the assignment search is worth of your consideration but other than that it looks good to me.
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: DarkLight1337 <[email protected]>
@DarkLight1337 looks like there's one test failure remaining |
The network is quite slow right now (HF keeps timing out for a lot of other PRs). This error comes from not being able to download the video before timeout occurs. (It passes when I run it locally.) Can you approve this PR? Then I'll retry the CI once the network returns to normal. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @DarkLight1337 @maxdebayser!
…project#9919) Signed-off-by: DarkLight1337 <[email protected]>
…project#9919) Signed-off-by: DarkLight1337 <[email protected]>
…project#9919) Signed-off-by: DarkLight1337 <[email protected]> Signed-off-by: Maxime Fournioux <[email protected]>
…project#9919) Signed-off-by: DarkLight1337 <[email protected]> Signed-off-by: rickyx <[email protected]>
…project#9919) Signed-off-by: DarkLight1337 <[email protected]> Signed-off-by: Tyler Michael Smith <[email protected]>
…project#9919) Signed-off-by: DarkLight1337 <[email protected]>
…project#9919) Signed-off-by: DarkLight1337 <[email protected]>
This PR renames
--chat-template-text-format
(introduced by #9358) to--chat-template-content-format
and moves it to the CLI parser specific to OpenAI-compatible server. Also, it removes the redundant hardcoded logic for Llama-3.2-Vision (last updated by #9393) since we can now run online inference with--chat-template-content-format openai
.To avoid causing incompatibilities with how users are currently serving Llama-3.2-Vision, I have added code to automatically detect the format to use based on the AST of the provided chat template.
cc @vrdn-23 @ywang96 @heheda12345 @alex-jw-brooks
FIX #10286