[Grammar] Add SetStopTokenIds for tool calling (#2767)
This PR adds SetStopTokenIds to GrammarStateMatcher, registering mlc.grammar.GrammarStateMatcherSetStopTokenIds, which overrides the matcher's existing default stop token ids with user-provided ones.
CharlieFRuan authored Aug 10, 2024
1 parent 37860b0 commit fcf055a
Showing 4 changed files with 51 additions and 1 deletion.
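
For orientation, here is a minimal usage sketch of the new Python API. It is a hypothetical example, not part of the commit: the import path, the json_grammar/token_table arguments, and the helper name are assumptions modeled on the test below; only GrammarStateMatcher, accept_token, and set_stop_token_ids come from this change.

from typing import List

from mlc_llm.grammar import BNFGrammar, GrammarStateMatcher  # import path assumed from the test suite


def accept_with_custom_stop(json_grammar: BNFGrammar, token_table: List[str],
                            input_ids: List[int], stop_token_ids: List[int]) -> bool:
    """Feed input_ids through a matcher whose default stop tokens are overridden."""
    matcher = GrammarStateMatcher(json_grammar, token_table)
    # Replace the default stop tokens (e.g. </s>) with custom ids, e.g. a tool-call end token.
    matcher.set_stop_token_ids(stop_token_ids)
    # accept_token returns False as soon as a token is rejected by the grammar.
    return all(matcher.accept_token(token_id) for token_id in input_ids)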
10 changes: 10 additions & 0 deletions cpp/grammar/grammar_state_matcher.cc
@@ -153,6 +153,10 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm
PushInitialState(kInvalidRulePosition, true);
}

void SetStopTokenIds(const std::vector<int32_t>& stop_token_ids) final {
init_ctx_->stop_token_ids = stop_token_ids;
}

private:
/*!
* \brief If is_uncertain_saved is true, find the next token in uncertain_indices. Otherwise,
@@ -609,6 +613,12 @@ TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherIsTerminated")
TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherResetState")
.set_body_typed([](GrammarStateMatcher matcher) { matcher->ResetState(); });

TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherSetStopTokenIds")
.set_body_typed([](GrammarStateMatcher matcher, IntTuple stop_token_ids) {
std::vector<int32_t> stop_token_ids_vector{stop_token_ids.begin(), stop_token_ids.end()};
matcher->SetStopTokenIds(stop_token_ids_vector);
});

/*! \brief Check if a matcher can accept the complete string, and then reach the end of the
 * grammar. Does not change the state of the GrammarStateMatcher. For test purposes. */
bool MatchCompleteString(GrammarStateMatcher matcher, String str, bool verbose) {
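The packed function registered above can also be reached directly through TVM's global function registry; the sketch below is hypothetical and simply mirrors what the Python wrapper in grammar.py (further down) does, converting a Python list into the ShapeTuple/IntTuple the C++ side expects.

import tvm

from mlc_llm.grammar import GrammarStateMatcher  # import path assumed


def set_stop_tokens_via_registry(matcher: GrammarStateMatcher, stop_token_ids: list) -> None:
    # Look up the global registered by TVM_REGISTER_GLOBAL and call it with a ShapeTuple,
    # which arrives on the C++ side as an IntTuple.
    func = tvm.get_global_func("mlc.grammar.GrammarStateMatcherSetStopTokenIds")
    func(matcher, tvm.runtime.ShapeTuple(stop_token_ids))
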
3 changes: 3 additions & 0 deletions cpp/grammar/grammar_state_matcher.h
@@ -101,6 +101,9 @@ class GrammarStateMatcherNode : public Object {
/*! \brief Reset the matcher to the initial state. */
virtual void ResetState() = 0;

/*! \brief Set the stop token ids, overriding the existing default ones. */
virtual void SetStopTokenIds(const std::vector<int32_t>& stop_token_ids) = 0;

static constexpr const char* _type_key = "mlc.grammar.GrammarStateMatcher";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
6 changes: 5 additions & 1 deletion python/mlc_llm/grammar/grammar.py
@@ -291,7 +291,7 @@ def accept_token(self, token_id: int) -> bool:
find_next_rejected_tokens operations can be performed. The termination state can be canceled
using Rollback().
"""
return _ffi_api.GrammarStateMatcherAcceptToken(self, token_id) # type: ignore # pylint: disable=no-member
return _ffi_api.GrammarStateMatcherAcceptToken(self, token_id, False) # type: ignore # pylint: disable=no-member

def find_next_rejected_tokens(self, verbose: bool = False) -> List[int]:
"""Find the ids of the rejected tokens for the next step.
@@ -400,3 +400,7 @@ def debug_match_complete_string(self, string: str, verbose: bool = False) -> boo
The string to be matched.
"""
return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string, verbose) # type: ignore # pylint: disable=no-member

def set_stop_token_ids(self, stop_token_ids: List[int]) -> None:
"""Set the stop token ids, overriding the default ones."""
_ffi_api.GrammarStateMatcherSetStopTokenIds(self, tvm.runtime.ShapeTuple(stop_token_ids)) # type: ignore # pylint: disable=no-member
33 changes: 33 additions & 0 deletions tests/python/grammar/test_grammar_state_matcher_json.py
@@ -395,6 +395,39 @@ def test_reset(json_grammar: BNFGrammar):
assert orig_result == result_after_reset


def test_set_stop_token_ids(json_grammar: BNFGrammar):
token_table = [
# fmt: off
"<s>", "</s>", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true',
# fmt: on
]
input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}", "</s>"]
input_ids = [token_table.index(t) for t in input_splitted]

# 1. With the default stop tokens, </s> is accepted as the last token
grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table)
for i in input_ids:
assert grammar_state_matcher.accept_token(i)

# 2. After overriding the stop tokens, </s> is no longer a stop token and is rejected
grammar_state_matcher.reset_state()
grammar_state_matcher.set_stop_token_ids([2])
for i in input_ids:
if i == 1:
# 1 is </s>, will be rejected
assert not grammar_state_matcher.accept_token(i)
else:
assert grammar_state_matcher.accept_token(i)

# 3. "a" (token id 2) is now accepted as the stop token
grammar_state_matcher.reset_state()
grammar_state_matcher.set_stop_token_ids([2])
input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}", "a"]
input_ids = [token_table.index(t) for t in input_splitted]
for i in input_ids:
assert grammar_state_matcher.accept_token(i)


def test_termination(json_grammar: BNFGrammar):
token_table = [
# fmt: off
