Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sbalandi committed Nov 21, 2024
1 parent 55469e8 commit a2dbeba
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions src/cpp/src/llm_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
ov::InferRequest m_model_runner;

bool is_chat_conversation = false;
bool m_is_cache_empty = true;
bool m_history_available = false;
std::optional<int32_t> m_selected_beam = std::nullopt;
ChatHistory m_history;
std::string m_templated_chat_history = {};
Expand Down Expand Up @@ -112,11 +112,11 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
// Do not add special tokens in chat scenario to be aligned with HF.
auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
if (m_is_cache_empty) {
encoded_input = new_chat_tokens;
} else {
if (m_history_available) {
auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
} else {
encoded_input = new_chat_tokens;
}
m_templated_chat_history = new_templated_chat_history;
// TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
Expand Down Expand Up @@ -218,12 +218,15 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));
}
ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
bool kv_history_available = m_selected_beam.has_value();
bool is_cache_empty = !m_selected_beam.has_value();

size_t kv_cache_len = 0;
ov::Tensor concatenated_attention_mask;
if (is_chat_conversation && !m_is_cache_empty) {
if (kv_history_available) {
if (is_chat_conversation && m_history_available) {
if (is_cache_empty) {
attention_mask = ov::genai::utils::init_attention_mask(tokenized_chat_history);
concatenated_attention_mask = attention_mask;
} else {
OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
// If history is saved in KV cache, concatenate new attention_mask with the already existing.
// Between subsequent runs attention_mask should not be modified.
Expand All @@ -238,9 +241,6 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
new_atten_mask.data<int64_t>() + kv_cache_len);
concatenated_attention_mask = new_atten_mask;
} else {
attention_mask = ov::genai::utils::init_attention_mask(tokenized_chat_history);
concatenated_attention_mask = attention_mask;
}
} else {
concatenated_attention_mask = attention_mask;
Expand All @@ -249,7 +249,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
bool position_ids_available = (num_inputs == 4);
std::optional<ov::Tensor> position_ids = std::nullopt;
if (position_ids_available) {
if (is_chat_conversation && !kv_history_available) {
if (is_chat_conversation && is_cache_empty) {
position_ids = ov::Tensor{ov::element::i64, tokenized_chat_history.get_shape()};
} else {
position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()};
Expand Down Expand Up @@ -285,7 +285,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
Sampler sampler = Sampler(m_tokenizer);
// we can't properly refer to history in case of chat scenario with beam search, so reset_kv_state and use the whole history for each new propmt
auto input_tokens = input_ids;
if (is_chat_conversation && !kv_history_available) {
if (is_chat_conversation && is_cache_empty) {
input_tokens = tokenized_chat_history;
}
result = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask, streamer_ptr, sampler, requests, position_ids, std::nullopt);
Expand All @@ -297,7 +297,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
}

if (is_chat_conversation) {
m_is_cache_empty = false;
m_history_available = true;
}

if (is_chat_conversation) {
Expand All @@ -320,9 +320,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
void start_chat(const std::string& system_message) override {
is_chat_conversation = true;
m_selected_beam = std::nullopt;
if (!m_is_cache_empty) {
if (m_history_available) {
reset_kv_state();
m_is_cache_empty = true;
m_history_available = false;
m_history = {};
m_templated_chat_history = "";
m_tokenized_chat_history = {};
Expand All @@ -339,9 +339,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
void finish_chat() override {
is_chat_conversation = false;
m_selected_beam = std::nullopt;
if (!m_is_cache_empty) {
if (m_history_available) {
reset_kv_state();
m_is_cache_empty = true;
m_history_available = false;
m_history.clear();
m_templated_chat_history.clear();
m_tokenized_chat_history = {};
Expand Down

0 comments on commit a2dbeba

Please sign in to comment.