From fcf055af438e39686bd294e6ae9aaa82b5fd10f2 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Sat, 10 Aug 2024 13:24:13 -0400 Subject: [PATCH] [Grammar] Add SetStopTokenIds for tool calling (#2767) This PR adds SetStopTokenIds to GrammarStateMatcher, registering mlc.grammar.GrammarStateMatcherSetStopTokenIds that overrides existing default stop token ids of the matcher with user-provided ones. --- cpp/grammar/grammar_state_matcher.cc | 10 ++++++ cpp/grammar/grammar_state_matcher.h | 3 ++ python/mlc_llm/grammar/grammar.py | 6 +++- .../test_grammar_state_matcher_json.py | 33 +++++++++++++++++++ 4 files changed, 51 insertions(+), 1 deletion(-) diff --git a/cpp/grammar/grammar_state_matcher.cc b/cpp/grammar/grammar_state_matcher.cc index 3c37b8806b..660b8a5e3d 100644 --- a/cpp/grammar/grammar_state_matcher.cc +++ b/cpp/grammar/grammar_state_matcher.cc @@ -153,6 +153,10 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm PushInitialState(kInvalidRulePosition, true); } + void SetStopTokenIds(const std::vector<int32_t>& stop_token_ids) final { + init_ctx_->stop_token_ids = stop_token_ids; + } + private: /*! * \brief If is_uncertain_saved is true, find the next token in uncertain_indices. Otherwise, @@ -609,6 +613,12 @@ TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherIsTerminated") TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherResetState") .set_body_typed([](GrammarStateMatcher matcher) { matcher->ResetState(); }); +TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherSetStopTokenIds") + .set_body_typed([](GrammarStateMatcher matcher, IntTuple stop_token_ids) { + std::vector<int32_t> stop_token_ids_vector{stop_token_ids.begin(), stop_token_ids.end()}; + matcher->SetStopTokenIds(stop_token_ids_vector); + }); + /*! \brief Check if a matcher can accept the complete string, and then reach the end of the * grammar. Does not change the state of the GrammarStateMatcher. For test purpose. 
*/ bool MatchCompleteString(GrammarStateMatcher matcher, String str, bool verbose) { diff --git a/cpp/grammar/grammar_state_matcher.h b/cpp/grammar/grammar_state_matcher.h index f961a59fcd..98fda522d0 100644 --- a/cpp/grammar/grammar_state_matcher.h +++ b/cpp/grammar/grammar_state_matcher.h @@ -101,6 +101,9 @@ class GrammarStateMatcherNode : public Object { /*! \brief Reset the matcher to the initial state. */ virtual void ResetState() = 0; + /*! \brief Set the stop token ids, overriding the existing default ones. */ + virtual void SetStopTokenIds(const std::vector<int32_t>& stop_token_ids) = 0; + static constexpr const char* _type_key = "mlc.grammar.GrammarStateMatcher"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; diff --git a/python/mlc_llm/grammar/grammar.py b/python/mlc_llm/grammar/grammar.py index e1365e2b78..97cb30c719 100644 --- a/python/mlc_llm/grammar/grammar.py +++ b/python/mlc_llm/grammar/grammar.py @@ -291,7 +291,7 @@ def accept_token(self, token_id: int) -> bool: find_next_rejected_tokens operations can be performed. The termination state can be canceled using Rollback(). """ - return _ffi_api.GrammarStateMatcherAcceptToken(self, token_id) # type: ignore # pylint: disable=no-member + return _ffi_api.GrammarStateMatcherAcceptToken(self, token_id, False) # type: ignore # pylint: disable=no-member def find_next_rejected_tokens(self, verbose: bool = False) -> List[int]: """Find the ids of the rejected tokens for the next step. @@ -400,3 +400,7 @@ def debug_match_complete_string(self, string: str, verbose: bool = False) -> boo The string to be matched. 
""" return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string, verbose) # type: ignore # pylint: disable=no-member + + def set_stop_token_ids(self, stop_token_ids: List[int]) -> None: + """Set the stop token ids, overriding the default ones.""" + _ffi_api.GrammarStateMatcherSetStopTokenIds(self, tvm.runtime.ShapeTuple(stop_token_ids)) # type: ignore # pylint: disable=no-member diff --git a/tests/python/grammar/test_grammar_state_matcher_json.py b/tests/python/grammar/test_grammar_state_matcher_json.py index bf912243e0..333e36a283 100644 --- a/tests/python/grammar/test_grammar_state_matcher_json.py +++ b/tests/python/grammar/test_grammar_state_matcher_json.py @@ -395,6 +395,39 @@ def test_reset(json_grammar: BNFGrammar): assert orig_result == result_after_reset +def test_set_stop_token_ids(json_grammar: BNFGrammar): + token_table = [ + # fmt: off + "<s>", "</s>", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', + # fmt: on + ] + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}", "</s>"] + input_ids = [token_table.index(t) for t in input_splitted] + + # 1. Will accept </s> as last token for stop token + grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) + for i in input_ids: + assert grammar_state_matcher.accept_token(i) + + # 2. Will reject </s> as last token for stop token + grammar_state_matcher.reset_state() + grammar_state_matcher.set_stop_token_ids([2]) + for i in input_ids: + if i == 1: + # 1 is </s>, will be rejected + assert not grammar_state_matcher.accept_token(i) + else: + assert grammar_state_matcher.accept_token(i) + + # 3. 
Will accept "a" as stop token + grammar_state_matcher.reset_state() + grammar_state_matcher.set_stop_token_ids([2]) + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}", "a"] + input_ids = [token_table.index(t) for t in input_splitted] + for i in input_ids: + assert grammar_state_matcher.accept_token(i) + + def test_termination(json_grammar: BNFGrammar): token_table = [ # fmt: off