Move beam search in case of chat scenario to sampler.cpp #1215

Open · wants to merge 3 commits into master · Changes from all commits

455 changes: 0 additions & 455 deletions src/cpp/src/group_beam_searcher.cpp

This file was deleted.

197 changes: 126 additions & 71 deletions src/cpp/src/llm_pipeline.cpp
@@ -21,28 +21,30 @@
#include "sampler.hpp"
#include "lm_encoding.hpp"

namespace {

enum class GenerationChatInputsType {
UNDEF = 0, // Default value, type of inputs is not defined
STRING = 1, // Type of inputs is StringInputs
ENCODED_INPUTS = 2, // Type of inputs is EncodedInputs
};

} // namespace

namespace ov {
namespace genai {

std::pair<EncodedResults, int32_t> beam_search(
ov::InferRequest& lm,
ov::Tensor prompts,
ov::Tensor attention_mask,
GenerationConfig config,
std::optional<ov::Tensor> position_ids,
std::optional<int32_t> selected_beam_idx
);

class StatefulLLMPipeline final : public LLMPipelineImplBase {
public:
ov::InferRequest m_model_runner;

bool is_chat_conversation = false;
bool m_history_available = false;
bool m_is_cache_empty = true;
std::optional<int32_t> m_selected_beam = std::nullopt;
ChatHistory m_history;
std::string m_templated_chat_history = {};
TokenizedInputs m_tokenized_chat_history;
std::vector<int64_t> m_tokenized_chat_history;
GenerationChatInputsType m_chat_input_type = GenerationChatInputsType::UNDEF;
std::optional<int64_t> m_last_disappeared_token = std::nullopt;

StatefulLLMPipeline(
const ov::InferRequest& request,
@@ -94,6 +96,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
OptionalGenerationConfig generation_config,
StreamerVariant streamer
) override {
if (is_chat_conversation && m_chat_input_type == GenerationChatInputsType::UNDEF)
m_chat_input_type = GenerationChatInputsType::STRING;

if (is_chat_conversation)
OPENVINO_ASSERT(m_chat_input_type != GenerationChatInputsType::ENCODED_INPUTS,
"Chat doesn't support switching between input types. Please, continue using EncodedInputs or restart the chat.");

auto start_time = std::chrono::steady_clock::now();
GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
TokenizedInputs encoded_input;
Expand All @@ -119,14 +128,23 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
// Do not add special tokens in chat scenario to be aligned with HF.
auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
if (m_is_cache_empty) {
encoded_input = new_chat_tokens;
} else {
if (m_history_available) {
auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
} else {
encoded_input = new_chat_tokens;
}
m_templated_chat_history = new_templated_chat_history;
m_tokenized_chat_history = new_chat_tokens;

m_tokenized_chat_history.clear();
std::copy(new_chat_tokens.input_ids.data<int64_t>(), new_chat_tokens.input_ids.data<int64_t>() + new_chat_tokens.input_ids.get_size(),
std::back_inserter(m_tokenized_chat_history));

// no need to add m_last_disappeared_token to encoded_input, it was kept by subtract_chat_tokenized_inputs
if (m_last_disappeared_token.has_value() && *m_last_disappeared_token == encoded_input.input_ids.data<int64_t>()[0]) {
m_last_disappeared_token = std::nullopt;
}

// TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
} else {
encoded_input = m_tokenizer.encode(prompt);
Review comment (Contributor):

I'm not sure that m_templated_chat_history.append(answer); is valid for all chat templates, because between the current history and the assistant answer there can be extra tokens / words (e.g. ' ' in the example below):

{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}

Let's discuss it at the GenAI meeting.
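
A minimal sketch of this concern (not part of the diff), assuming the public ov::genai::Tokenizer / ChatHistory API already used in this file; the header name and model directory path are assumptions:

#include <iostream>
#include <string>
#include "openvino/genai/tokenizer.hpp"  // assumed public header for ov::genai::Tokenizer / ChatHistory

int main() {
    ov::genai::Tokenizer tokenizer("<path-to-model-dir>");  // placeholder path
    ov::genai::ChatHistory history = {{{"role", "user"}, {"content", "Hi"}}};

    // What the model was actually prompted with (generation prompt appended).
    std::string prompted = tokenizer.apply_chat_template(history, /*add_generation_prompt=*/true);
    std::string answer = "Hello there.";
    std::string naive = prompted + answer;  // effect of m_templated_chat_history.append(answer)

    // What the template produces once the assistant turn is part of the history.
    history.push_back({{"role", "assistant"}, {"content", answer}});
    std::string templated = tokenizer.apply_chat_template(history, /*add_generation_prompt=*/false);

    // For templates like the one quoted above (assistant turn rendered as ' ' + content + ' ' + eos_token),
    // `naive` and `templated` differ, so the later subtract_chat_tokenized_inputs step can misalign.
    std::cout << (naive == templated ? "histories match" : "histories differ") << std::endl;
    return 0;
}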

@@ -180,6 +198,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
OptionalGenerationConfig generation_config,
StreamerVariant streamer
) override {
if (is_chat_conversation && m_chat_input_type == GenerationChatInputsType::UNDEF)
m_chat_input_type = GenerationChatInputsType::ENCODED_INPUTS;

if (is_chat_conversation)
// if chat was run in StringInputs mode, but generate() was then called with EncodedInputs, the last m_history entry will have the assistant role
OPENVINO_ASSERT(m_chat_input_type == GenerationChatInputsType::ENCODED_INPUTS || m_history.back()["role"] == "user",
"Chat doesn't support switching between input types. Please, continue using StringInputs or restart the chat.");

auto start_time = std::chrono::steady_clock::now();
ov::Tensor input_ids;
ov::Tensor attention_mask;
@@ -191,6 +217,16 @@
attention_mask = data->attention_mask;
}

if (is_chat_conversation && m_chat_input_type == GenerationChatInputsType::ENCODED_INPUTS) {
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));
}

// Tail of previous output in chat mode is missing in KV cache.
if (is_chat_conversation && m_last_disappeared_token.has_value()) {
attention_mask = ov::genai::utils::push_front_inputs(attention_mask, std::vector<int64_t>{1});
input_ids = ov::genai::utils::push_front_inputs(input_ids, std::vector<int64_t>{*m_last_disappeared_token});
}

GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;

// If eos_token_id was not provided, take value from default m_generation_config
@@ -222,77 +258,92 @@
"(input_ids, attention_mask, position_ids, beam_idx) "
"but you have '" + std::to_string(num_inputs) + "' inputs");


ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
size_t kv_cache_len = 0;
ov::Tensor concatenated_attention_mask;
if (is_chat_conversation && !m_is_cache_empty) {
OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
// If history is saved in KV cache, concatenate new attention_mask with the already existing.
// Between subsequent runs attention_mask should not be modified.
auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
auto prompt_len = attention_mask.get_shape()[1];
kv_cache_len = atten_mask_history.get_shape()[1];

ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
auto start_atten_hst = atten_mask_history.data<int64_t>() + kv_cache_len * (*m_selected_beam);
std::copy(start_atten_hst, start_atten_hst + kv_cache_len,
new_atten_mask.data<int64_t>());
std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
new_atten_mask.data<int64_t>() + kv_cache_len);
concatenated_attention_mask = new_atten_mask;
if (is_chat_conversation && m_history_available) {
if (m_is_cache_empty) {
attention_mask = ov::genai::utils::init_attention_mask(tokenized_chat_history);
concatenated_attention_mask = attention_mask;
} else {
OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
// If history is saved in KV cache, concatenate new attention_mask with the already existing.
// Between subsequent runs attention_mask should not be modified.
auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
auto prompt_len = attention_mask.get_shape()[1];
kv_cache_len = atten_mask_history.get_shape()[1];

ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
auto start_atten_hst = atten_mask_history.data<int64_t>();
std::copy(start_atten_hst, start_atten_hst + kv_cache_len,
new_atten_mask.data<int64_t>());
std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
new_atten_mask.data<int64_t>() + kv_cache_len);
concatenated_attention_mask = new_atten_mask;
}
} else {
concatenated_attention_mask = attention_mask;
}

bool position_ids_available = (num_inputs == 4);
std::optional<ov::Tensor> position_ids = std::nullopt;
if (position_ids_available) {
position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()};
utils::initialize_position_ids(*position_ids, attention_mask, kv_cache_len);
if (is_chat_conversation && m_is_cache_empty) {
position_ids = ov::Tensor{ov::element::i64, tokenized_chat_history.get_shape()};
} else {
position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()};
}
utils::initialize_position_ids(*position_ids, attention_mask, kv_cache_len);
}

if(m_adapter_controller) {
m_adapter_controller->apply(m_model_runner, config.adapters);
}

ov::genai::EncodedResults result;
if (config.is_beam_search() && is_chat_conversation) {
std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask,
config, position_ids, m_selected_beam);
} else {
std::vector<SequenceGroup::Ptr> requests;
size_t block_size = 1;
bool enable_prefix_caching = false;

config.stop_token_ids.insert(config.eos_token_id);
for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation && !m_is_cache_empty) {
sequence_group = std::make_shared<SequenceGroup>(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
const int64_t* prompt_start = input_ids.data<const int64_t>() + batch_offset;
std::vector<int64_t> tokenized_prompt(prompt_start, prompt_start + seq_len);

sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_prompt, config, block_size, enable_prefix_caching);
}
std::vector<SequenceGroup::Ptr> requests;
size_t block_size = 1;
bool enable_prefix_caching = false;
for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation) {
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
const int64_t* prompt_start = input_ids.data<const int64_t>() + batch_offset;
std::vector<int64_t> tokenized_prompt(prompt_start, prompt_start + seq_len);

sequence_group->set_sequence_group_ptr(sequence_group);
requests.push_back(sequence_group);
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_prompt, config, block_size, enable_prefix_caching);
}

Sampler sampler = Sampler(m_tokenizer);
std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr,
sampler, requests, position_ids, std::nullopt, m_selected_beam);
sequence_group->set_sequence_group_ptr(sequence_group);
requests.push_back(sequence_group);
}

if (!is_chat_conversation) {
Sampler sampler = Sampler(m_tokenizer);
// we can't properly refer to history in case of chat scenario with beam search, so reset_kv_state and use the whole history for each new prompt
auto input_tokens = input_ids;
if (is_chat_conversation && m_is_cache_empty) {
input_tokens = tokenized_chat_history;
}
ov::genai::EncodedResults result = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask,
streamer_ptr, sampler, requests, position_ids, std::nullopt);

m_is_cache_empty = false;
if (!is_chat_conversation || config.is_beam_search()) {
reset_kv_state();
m_selected_beam = std::nullopt;
} else {
m_is_cache_empty = false;
m_is_cache_empty = true;
}

if (is_chat_conversation) {
m_history_available = true;
m_last_disappeared_token = result.tokens[0].back();
Review comment (@ilya-lavrenov, Contributor, Nov 27, 2024):

I suppose EOS can also be added here, but it's OK for such a token to disappear.

Can we extract the reason the sequence finished from the sampler and add the last token only if it ended for a reason other than EOS?
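
A rough sketch of that suggestion (hypothetical, not in this PR): if the sampler exposed a per-sequence finish reason, the trailing token would only need to be remembered when generation stopped for a reason other than EOS.

// Hypothetical sketch of the suggestion above; GenerationFinishReason and
// Sampler::get_finish_reason() are assumed names, not APIs defined in this PR or the sampler.
enum class GenerationFinishReason { EOS, MAX_LENGTH, STOP_STRING };

if (is_chat_conversation) {
    m_history_available = true;
    GenerationFinishReason reason = sampler.get_finish_reason(requests[0]);  // assumed accessor
    if (reason == GenerationFinishReason::EOS) {
        // The EOS token may safely disappear from the KV cache; nothing to re-inject next turn.
        m_last_disappeared_token = std::nullopt;
    } else {
        // Keep the trailing token so it can be pushed back in front of the next input_ids.
        m_last_disappeared_token = result.tokens[0].back();
    }
}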

}

if (is_chat_conversation && m_chat_input_type == GenerationChatInputsType::ENCODED_INPUTS) {
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
}

auto stop_time = std::chrono::steady_clock::now();

// If is called without tokenization then that stat will not be reported.
@@ -306,10 +357,12 @@

void start_chat(const std::string& system_message) override {
is_chat_conversation = true;
m_selected_beam = std::nullopt;
if (!m_is_cache_empty) {
m_is_cache_empty = true;
m_chat_input_type = GenerationChatInputsType::UNDEF;
m_tokenized_chat_history = {};
if (m_history_available) {
reset_kv_state();
m_is_cache_empty = true;
m_history_available = false;
m_history = {};
m_templated_chat_history = "";
}
@@ -324,10 +377,12 @@

void finish_chat() override {
is_chat_conversation = false;
m_selected_beam = std::nullopt;
if (!m_is_cache_empty) {
m_is_cache_empty = true;
m_chat_input_type = GenerationChatInputsType::UNDEF;
m_tokenized_chat_history = {};
if (m_history_available) {
reset_kv_state();
m_is_cache_empty = true;
m_history_available = false;
m_history.clear();
m_templated_chat_history.clear();
}