Add multiple EOS support (#1219)
Ticket 157352
Wovchena authored Nov 21, 2024
2 parents bf84989 + 9d41787 commit 7b372a8
Showing 4 changed files with 39 additions and 10 deletions.
8 changes: 4 additions & 4 deletions src/cpp/include/openvino/genai/generation_config.hpp
@@ -38,9 +38,9 @@ enum class StopCriteria { EARLY, HEURISTIC, NEVER };
  * @param eos_token_id token_id of <eos> (end of sentence)
  * @param min_new_tokens set 0 probability for eos_token_id for the first eos_token_id generated tokens. Ignored for non continuous batching.
  *
- * @param stop_strings vector of strings that will cause pipeline to stop generating further tokens. Ignored for non continuous batching.
+ * @param stop_strings A set of strings that will cause pipeline to stop generating further tokens.
  * @param include_stop_str_in_output if set to true stop string that matched generation will be included in generation output (default: false)
- * @param stop_token_ids vector of tokens that will cause pipeline to stop generating further tokens. Ignored for non continuous batching.
+ * @param stop_token_ids A set of tokens that will cause pipeline to stop generating further tokens.
  * @param echo if set to true, output will include user prompt (default: false).
  * @param logprobs number of top logprobs computed for each position, if set to 0, logprobs are not computed and value 0.0 is returned.
  *        Currently only single top logprob can be returned, so any logprobs > 1 is treated as logprobs == 1. (default: 0).
@@ -154,9 +154,9 @@ static constexpr ov::Property<size_t> max_new_tokens{"max_new_tokens"};
 static constexpr ov::Property<size_t> max_length{"max_length"};
 static constexpr ov::Property<bool> ignore_eos{"ignore_eos"};
 static constexpr ov::Property<size_t> min_new_tokens{"min_new_tokens"};
-static constexpr ov::Property<std::vector<std::string>> stop_strings{"stop_strings"};
+static constexpr ov::Property<std::set<std::string>> stop_strings{"stop_strings"};
 static constexpr ov::Property<bool> include_stop_str_in_output{"include_stop_str_in_output"};
-static constexpr ov::Property<std::vector<std::vector<int64_t>>> stop_token_ids{"stop_token_ids"};
+static constexpr ov::Property<std::set<int64_t>> stop_token_ids{"stop_token_ids"};

 static constexpr ov::Property<size_t> num_beam_groups{"num_beam_groups"};
 static constexpr ov::Property<size_t> num_beams{"num_beams"};
7 changes: 5 additions & 2 deletions src/python/py_generation_config.cpp
@@ -41,9 +41,9 @@ char generation_config_docstring[] = R"(
     ignore_eos: if set to true, then generation will not stop even if <eos> token is met.
     eos_token_id: token_id of <eos> (end of sentence)
     min_new_tokens: set 0 probability for eos_token_id for the first eos_token_id generated tokens. Ignored for non continuous batching.
-    stop_strings: list of strings that will cause pipeline to stop generating further tokens. Ignored for non continuous batching.
+    stop_strings: a set of strings that will cause pipeline to stop generating further tokens.
     include_stop_str_in_output: if set to true stop string that matched generation will be included in generation output (default: false)
-    stop_token_ids: list of tokens that will cause pipeline to stop generating further tokens. Ignored for non continuous batching.
+    stop_token_ids: a set of tokens that will cause pipeline to stop generating further tokens.
     echo: if set to true, the model will echo the prompt in the output.
     logprobs: number of top logprobs computed for each position, if set to 0, logprobs are not computed and value 0.0 is returned.
         Currently only single top logprob can be returned, so any logprobs > 1 is treated as logprobs == 1. (default: 0).
@@ -87,6 +87,9 @@ void init_generation_config(py::module_& m) {
         .def_readwrite("max_length", &GenerationConfig::max_length)
         .def_readwrite("ignore_eos", &GenerationConfig::ignore_eos)
         .def_readwrite("min_new_tokens", &GenerationConfig::min_new_tokens)
+        .def_readwrite("stop_strings", &GenerationConfig::stop_strings)
+        .def_readwrite("include_stop_str_in_output", &GenerationConfig::include_stop_str_in_output)
+        .def_readwrite("stop_token_ids", &GenerationConfig::stop_token_ids)
         .def_readwrite("num_beam_groups", &GenerationConfig::num_beam_groups)
         .def_readwrite("num_beams", &GenerationConfig::num_beams)
         .def_readwrite("diversity_penalty", &GenerationConfig::diversity_penalty)
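With these bindings, the set-typed fields become plain read/write attributes on a Python GenerationConfig. A minimal usage sketch, assuming the openvino_genai package and an LLMPipeline built from a local model directory — the path, device, stop strings, and token ids below are illustrative placeholders, not values from the commit:

import openvino_genai

pipe = openvino_genai.LLMPipeline("./tiny-random-phi3", "CPU")  # hypothetical model dir

config = pipe.get_generation_config()
config.max_new_tokens = 32
config.stop_strings = {"\n\n", "###"}        # plain Python set, mapped to std::set<std::string>
config.include_stop_str_in_output = False    # matched stop string is dropped from the output
config.stop_token_ids = {2, 32000}           # any listed id ends generation, like extra <eos> tokens

print(pipe.generate("Briefly, what is OpenVINO?", config))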
8 changes: 4 additions & 4 deletions src/python/py_utils.cpp
@@ -248,7 +248,7 @@ ov::genai::OptionalGenerationConfig update_config_from_kwargs(const ov::genai::O
             res_config.stop_strings = py::cast<std::set<std::string>>(value);
         } else if (key == "include_stop_str_in_output") {
             res_config.include_stop_str_in_output = py::cast<bool>(value);
-        } else if (key == "include_stop_str_in_output") {
+        } else if (key == "stop_token_ids") {
             res_config.stop_token_ids = py::cast<std::set<int64_t>>(value);
         } else if (key == "max_length") {
             res_config.max_length = py::cast<int>(item.second);
@@ -311,11 +311,11 @@ bool generation_config_param_to_property(std::string key, py::object value, ov::
     } else if (key == "min_new_tokens") {
         map.insert(ov::genai::min_new_tokens(py::cast<int>(value)));
     } else if (key == "stop_strings") {
-        map.insert(ov::genai::stop_strings(py::cast<std::vector<std::string>>(value)));
+        map.insert(ov::genai::stop_strings(py::cast<std::set<std::string>>(value)));
     } else if (key == "include_stop_str_in_output") {
         map.insert(ov::genai::include_stop_str_in_output(py::cast<bool>(value)));
-    } else if (key == "include_stop_str_in_output") {
-        map.insert(ov::genai::stop_token_ids(py::cast<std::vector<std::vector<int64_t>>>(value)));
+    } else if (key == "stop_token_ids") {
+        map.insert(ov::genai::stop_token_ids(py::cast<std::set<int64_t>>(value)));
     } else if (key == "num_beam_groups") {
         map.insert(ov::genai::num_beam_groups(py::cast<int>(value)));
     } else if (key == "num_beams") {
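These two conversion paths mean the same keys also work as bare keyword arguments to generate(): update_config_from_kwargs() serves overloads that take a GenerationConfig, while generation_config_param_to_property() serves the property-map overloads. A hedged sketch of the kwargs route, with a hypothetical model directory and illustrative stop values:

import openvino_genai

pipe = openvino_genai.LLMPipeline("./tiny-random-phi3", "CPU")  # hypothetical model dir

# Each kwarg is cast as in py_utils.cpp: a set of str becomes std::set<std::string>,
# a set of int becomes std::set<int64_t>.
text = pipe.generate(
    "Count to ten:",
    max_new_tokens=64,
    stop_strings={"ten", "10"},       # generation stops at whichever match comes first
    include_stop_str_in_output=True,  # keep the matched string in the returned text
    stop_token_ids={2},               # illustrative id, treated as an additional EOS
)
print(text)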
26 changes: 26 additions & 0 deletions tests/python_tests/test_generate_api.py
@@ -848,3 +848,29 @@ def test_batch_switch():
     pipe = read_model(('katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')))[4]
     pipe.generate(["a"], max_new_tokens=2)
     pipe.generate(["1", "2"], max_new_tokens=2)
+
+
+@pytest.mark.precommit
+@pytest.mark.nightly
+def test_stop_token_ids():
+    pipe = read_model(('katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')))[4]
+    res = pipe.generate(
+        ov.Tensor([(1,)]),
+        max_new_tokens=3,
+        stop_token_ids={-1, 9935},
+        include_stop_str_in_output=False
+    )
+    assert 2 == len(res.tokens[0])
+    assert 9935 in res.tokens[0]
+
+
+@pytest.mark.precommit
+@pytest.mark.nightly
+def test_stop_strings():
+    pipe = read_model(('katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')))[4]
+    res = pipe.generate(
+        "",
+        max_new_tokens=5,
+        stop_strings={"ignored", "боль"}
+    )
+    assert "боль" not in res
