Move beam search in the chat scenario to sampler.cpp
- fill VLM prompt with real tokens
- spread length penalty to multinomial sampling
sbalandi committed Nov 14, 2024
1 parent 96bcffe commit d213acc
Showing 11 changed files with 114 additions and 563 deletions.
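
The scoring that the second commit note refers to is the length-penalized cumulative log-probability that the diff below exposes as get_cumulative_score_with_length_penalty and now also applies to the results returned from multinomial sampling in lm_encoding.cpp. A minimal stand-alone sketch of that formula (an illustrative helper, not part of the GenAI API):

#include <cmath>
#include <cstddef>

// Length-penalized sequence score: the cumulative log-probability of the
// generated tokens divided by generated_len^length_penalty. Values of
// length_penalty above 1.0 favor longer sequences, values below 1.0 shorter ones.
float length_penalized_score(float cumulative_log_prob,
                             std::size_t generated_len,
                             float length_penalty) {
    return cumulative_log_prob /
           std::pow(static_cast<float>(generated_len), length_penalty);
}

With length_penalty = 1.0 this reduces to the mean log-probability per generated token.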
455 changes: 0 additions & 455 deletions src/cpp/src/group_beam_searcher.cpp

This file was deleted.

92 changes: 43 additions & 49 deletions src/cpp/src/llm_pipeline.cpp
@@ -24,22 +24,12 @@
namespace ov {
namespace genai {

std::pair<EncodedResults, int32_t> beam_search(
ov::InferRequest& lm,
ov::Tensor prompts,
ov::Tensor attention_mask,
GenerationConfig config,
std::optional<ov::Tensor> position_ids,
std::optional<int32_t> selected_beam_idx
);

class StatefulLLMPipeline final : public LLMPipelineImplBase {
public:
ov::InferRequest m_model_runner;

bool is_chat_conversation = false;
bool m_is_cache_empty = true;
std::optional<int32_t> m_selected_beam = std::nullopt;
ChatHistory m_history;
std::string m_templated_chat_history = {};
TokenizedInputs m_tokenized_chat_history;
@@ -234,7 +224,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
kv_cache_len = atten_mask_history.get_shape()[1];

ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
auto start_atten_hst = atten_mask_history.data<int64_t>() + kv_cache_len * (*m_selected_beam);
auto start_atten_hst = atten_mask_history.data<int64_t>();
std::copy(start_atten_hst, start_atten_hst + kv_cache_len,
new_atten_mask.data<int64_t>());
std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
@@ -247,57 +237,63 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
bool position_ids_available = (num_inputs == 4);
std::optional<ov::Tensor> position_ids = std::nullopt;
if (position_ids_available) {
position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()};
utils::initialize_position_ids(*position_ids, attention_mask, kv_cache_len);
if (is_chat_conversation && config.is_beam_search()) {
position_ids = ov::Tensor{ov::element::i64, m_tokenized_chat_history.input_ids.get_shape()};
size_t start_pos = kv_cache_len - (m_tokenized_chat_history.input_ids.get_shape().at(1) - input_ids.get_shape().at(1));
size_t seq_length = m_tokenized_chat_history.input_ids.get_shape().at(1);

utils::initialize_position_ids(*position_ids, concatenated_attention_mask, seq_length, start_pos);
} else {
position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()};
utils::initialize_position_ids(*position_ids, attention_mask, attention_mask.get_shape()[1], kv_cache_len);
}
}

if(m_adapter_controller) {
m_adapter_controller->apply(m_model_runner, config.adapters);
}

ov::genai::EncodedResults result;
if (config.is_beam_search() && is_chat_conversation) {
std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask,
config, position_ids, m_selected_beam);
} else {
std::vector<SequenceGroup::Ptr> requests;
size_t block_size = 1;
bool enable_prefix_caching = false;

config.stop_token_ids.insert(config.eos_token_id);
for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation && !m_is_cache_empty) {
sequence_group = std::make_shared<SequenceGroup>(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching);
sequence_group->update_processed_tokens_num(m_tokenized_chat_history.input_ids.get_shape().at(1) - 1);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
const int64_t* prompt_start = input_ids.data<const int64_t>() + batch_offset;
std::vector<int64_t> tokenized_prompt(prompt_start, prompt_start + seq_len);
// in case of multi batch scenario, remove eos_token_id at start of prompt
auto real_prompt_start = std::find_if(tokenized_prompt.begin(), tokenized_prompt.end(), [&config](int64_t token) { return token != config.eos_token_id; });
tokenized_prompt.erase(tokenized_prompt.begin(), real_prompt_start);

sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_prompt, config, block_size, enable_prefix_caching);
sequence_group->update_processed_tokens_num(tokenized_prompt.size() - 1);
}
std::vector<SequenceGroup::Ptr> requests;
size_t block_size = 1;
bool enable_prefix_caching = false;

sequence_group->set_sequence_group_ptr(sequence_group);
requests.push_back(sequence_group);
config.stop_token_ids.insert(config.eos_token_id);
for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation) {
sequence_group = std::make_shared<SequenceGroup>(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching);
sequence_group->update_processed_tokens_num(m_tokenized_chat_history.input_ids.get_shape().at(1) - 1);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
const int64_t* prompt_start = input_ids.data<const int64_t>() + batch_offset;
std::vector<int64_t> tokenized_prompt(prompt_start, prompt_start + seq_len);
// in case of multi batch scenario, remove eos_token_id at start of prompt
auto real_prompt_start = std::find_if(tokenized_prompt.begin(), tokenized_prompt.end(), [&config](int64_t token) { return token != config.eos_token_id; });
tokenized_prompt.erase(tokenized_prompt.begin(), real_prompt_start);

sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_prompt, config, block_size, enable_prefix_caching);
sequence_group->update_processed_tokens_num(tokenized_prompt.size() - 1);
}

Sampler sampler = Sampler(m_tokenizer);
std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr,
sampler, requests, position_ids, std::nullopt, m_selected_beam);
sequence_group->set_sequence_group_ptr(sequence_group);
requests.push_back(sequence_group);
}

if (!is_chat_conversation) {
Sampler sampler = Sampler(m_tokenizer);
// we can't properly refer to the history in a chat scenario with beam search, so reset the KV state and re-feed the whole history for each new prompt
auto input_tokens = (is_chat_conversation && config.is_beam_search()) ? m_tokenized_chat_history.input_ids : input_ids;
result = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask, streamer_ptr, sampler, requests, position_ids, std::nullopt);

if (!is_chat_conversation || config.is_beam_search()) {
reset_kv_state();
m_selected_beam = std::nullopt;
} else {
}

if (is_chat_conversation) {
m_is_cache_empty = false;
}

auto stop_time = std::chrono::steady_clock::now();

// If called without tokenization, that stat will not be reported.
@@ -311,7 +307,6 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

void start_chat(const std::string& system_message) override {
is_chat_conversation = true;
m_selected_beam = std::nullopt;
if (!m_is_cache_empty) {
reset_kv_state();
m_is_cache_empty = true;
@@ -329,7 +324,6 @@

void finish_chat() override {
is_chat_conversation = false;
m_selected_beam = std::nullopt;
if (!m_is_cache_empty) {
reset_kv_state();
m_is_cache_empty = true;
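To summarize the llm_pipeline.cpp change above: beam search now goes through the shared Sampler and get_lm_encoded_results path, and because beam search reorders the KV cache per beam, a chat turn that used beam search re-feeds the whole tokenized history and resets the KV cache afterwards. A self-contained sketch of just that decision logic (hypothetical helper; the real code lives inside StatefulLLMPipeline::generate):

// Captures the two decisions made above: which tokens to feed and whether to
// reset the KV cache once the turn finishes.
struct TurnPlan {
    bool feed_full_history;  // feed m_tokenized_chat_history instead of only the new input_ids
    bool reset_kv_cache;     // call reset_kv_state() after generation
};

TurnPlan plan_turn(bool is_chat_conversation, bool is_beam_search) {
    TurnPlan plan{};
    // The KV cache cannot be referred to across beam-search chat turns,
    // so the full tokenized history is re-fed instead.
    plan.feed_full_history = is_chat_conversation && is_beam_search;
    // Mirrors `if (!is_chat_conversation || config.is_beam_search()) reset_kv_state();`.
    plan.reset_kv_cache = !is_chat_conversation || is_beam_search;
    return plan;
}
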
77 changes: 37 additions & 40 deletions src/cpp/src/lm_encoding.cpp
@@ -12,13 +12,22 @@
#include "lm_encoding.hpp"
#include "openvino/genai/perf_metrics.hpp"

#include "debug_utils.hpp"

#include "utils.hpp"

#include "debug_utils.hpp"

namespace ov {
namespace genai {

void reset_all_inputs_to_empty_tensors(ov::InferRequest& request, bool is_vlm = false) {
if (!is_vlm)
request.set_tensor("input_ids", ov::Tensor(ov::element::i64, {0, 0}));

request.set_tensor("beam_idx", ov::Tensor(ov::element::i32, {0}));
if (request.get_compiled_model().inputs().size() == 4)
request.set_tensor("position_ids", ov::Tensor(ov::element::i64, {0, 0}));
}

void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask) {
const size_t batch_size = attention_mask.get_shape().at(0);
const size_t sequence_length = attention_mask.get_shape().at(1);
@@ -51,22 +60,16 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector<i
}


std::pair<EncodedResults, int32_t> get_lm_encoded_results(
EncodedResults get_lm_encoded_results(
ov::InferRequest& m_llm,
const ov::Tensor& input_ids,
const ov::Tensor& attention_mask,
const std::shared_ptr<StreamerBase>& streamer_ptr,
Sampler& sampler,
std::vector<SequenceGroup::Ptr> sequence_groups,
std::optional<ov::Tensor> position_ids,
std::optional<EmbeddingsModel> m_embedding,
std::optional<int32_t> selected_beam_idx
std::optional<EmbeddingsModel> m_embedding
) {
std::vector<GenerationHandle> generations;
for (SequenceGroup::Ptr sequence_group : sequence_groups) {
generations.push_back(std::make_shared<GenerationHandleImpl>(sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters()));
}

ov::Shape prompts_shape = input_ids.get_shape();
const size_t batch_size = prompts_shape[0];

@@ -88,10 +91,7 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(

ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size});
auto beam_data = beam_idx.data<int32_t>();
if (selected_beam_idx.has_value())
beam_data[0] = *selected_beam_idx;
else
std::fill_n(beam_data, batch_size, 0);
std::fill_n(beam_data, batch_size, 0);
m_llm.set_tensor("beam_idx", beam_idx);

const auto infer_start = std::chrono::steady_clock::now();
@@ -177,18 +177,26 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
m_llm.get_tensor("inputs_embeds").set_shape(embed_prompt_tensor.get_shape());
m_llm.set_tensor("inputs_embeds", embed_prompt_tensor);
} else {
m_llm.get_tensor("input_ids").set_shape(new_input_ids.get_shape());
m_llm.get_tensor("input_ids").set_shape({total_num_tokens, 1});
m_llm.set_tensor("input_ids", new_input_ids);
}

m_llm.get_tensor("beam_idx").set_shape({ total_num_tokens });
m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()});

update_attention_mask_with_beams(m_llm.get_tensor("attention_mask"), next_beams);

if (position_ids.has_value()) {
update_position_ids(m_llm.get_tensor("position_ids"), m_llm.get_tensor("attention_mask"));
}

m_llm.get_tensor("beam_idx").set_shape({ total_num_tokens });
m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()});
if (streamer_ptr) {
// stream data from first sequence
int64_t out_token = sequence_groups.at(0).get()->operator[](0)->get_generated_ids().back();
if (streamer_ptr->put(out_token)) {
break;
}
}

const auto infer_start = std::chrono::steady_clock::now();
m_llm.infer();
@@ -199,14 +207,6 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
raw_perf_counters.m_new_token_times.emplace_back(infer_end);
raw_perf_counters.m_batch_sizes.emplace_back(batch_size);

if (streamer_ptr) {
// stream data from first sequence
int64_t out_token = sequence_groups.at(0).get()->operator[](0)->get_generated_ids().back();
if (streamer_ptr->put(out_token)) {
break;
}
}

sampler_output = sampler.sample(active_sequence_groups, m_llm.get_tensor("logits"));

active_sequence_groups.erase(std::remove_if(active_sequence_groups.begin(),
@@ -220,26 +220,23 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
streamer_ptr->put(out_token);
streamer_ptr->end();
}

size_t next_selected_beam = 0;
for (size_t i = 0; i < sequence_groups.size(); i++) {
auto request = sequence_groups[i];
auto generation_outputs = generations[i]->read_all();

std::sort(generation_outputs.begin(), generation_outputs.end(), [] (const GenerationOutput& r1, const GenerationOutput& r2) {
return r1.score > r2.score;
});
reset_all_inputs_to_empty_tensors(m_llm, m_embedding.has_value());

auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, generation_outputs.size());
for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) {
const auto& generation_output = generation_outputs[generation_output_idx];
results.tokens.push_back(std::move(generation_output.generated_ids));
results.scores.push_back(generation_output.score);
for (size_t i = 0; i < sequence_groups.size(); i++) {
auto request = sequence_groups[i];
// sequences are sorted by cumulative_log_prob with length_penalty applied
auto outputs = request->get_finished_sequences();

auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, outputs.size());
for (size_t output_idx = 0; output_idx < num_outputs; ++output_idx) {
const auto& output = outputs[output_idx];
results.tokens.push_back(std::move(output->get_generated_ids()));
results.scores.push_back(output->get_cumulative_score_with_length_penalty(request->get_sampling_parameters()));
}
// next_selected_beam = sampler.last_selected_beam(request);
}

return {results, next_selected_beam};
return results;
}

} // namespace genai
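In the generation loop above, beam_idx tells the model which KV-cache rows each surviving sequence continues from, and update_attention_mask_with_beams gathers the attention-mask rows the same way before marking the newly generated token as attended. A conceptual, vector-based version of that gather (the real function in lm_encoding.cpp operates on ov::Tensor and may differ in detail):

#include <cstdint>
#include <cstddef>
#include <vector>

// For every entry in next_beams, copy the parent row of the attention mask and
// append 1 so the token produced this step is attended on the next iteration.
std::vector<std::vector<int64_t>> gather_attention_mask(
        const std::vector<std::vector<int64_t>>& mask,  // [running_batch][seq_len]
        const std::vector<int32_t>& next_beams) {       // parent row per surviving sequence
    std::vector<std::vector<int64_t>> gathered;
    gathered.reserve(next_beams.size());
    for (int32_t parent : next_beams) {
        std::vector<int64_t> row = mask[static_cast<std::size_t>(parent)];
        row.push_back(1);
        gathered.push_back(std::move(row));
    }
    return gathered;
}

For greedy and multinomial requests next_beams effectively maps each row to itself, so the same code path serves all sampling modes once beam search is handled by the Sampler.
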
4 changes: 2 additions & 2 deletions src/cpp/src/lm_encoding.hpp
@@ -8,9 +8,9 @@
namespace ov {
namespace genai {

std::pair<EncodedResults, int32_t> get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask,
EncodedResults get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask,
const std::shared_ptr<StreamerBase>& streamer_ptr, Sampler& sampler, std::vector<SequenceGroup::Ptr> sequence_groups,
std::optional<ov::Tensor> position_ids, std::optional<EmbeddingsModel> m_embedding, std::optional<int32_t> selected_beam_idx);
std::optional<ov::Tensor> position_ids, std::optional<EmbeddingsModel> m_embedding);

void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector<int32_t> next_beams);

1 change: 0 additions & 1 deletion src/cpp/src/sampler.cpp
@@ -230,7 +230,6 @@ Sampler::GroupBeamSearcher::GroupBeamSearcher(SequenceGroup::Ptr sequence_group,
}
}


std::map<size_t, int32_t> Sampler::GroupBeamSearcher::get_beam_idxs() {
std::map<size_t, int32_t> next_beams;

6 changes: 3 additions & 3 deletions src/cpp/src/sequence_group.hpp
@@ -167,7 +167,7 @@ class Sequence {
m_generated_log_probs[idx] = log_prob;
}

float get_beam_search_score(const ov::genai::GenerationConfig& sampling_params) const {
float get_cumulative_score_with_length_penalty(const ov::genai::GenerationConfig& sampling_params) const {
float cumulative_log_prob = get_cumulative_log_probs(), current_length = get_generated_len();
float score = cumulative_log_prob / std::pow(current_length, sampling_params.length_penalty);
return score;
@@ -339,7 +339,7 @@ class SequenceGroup {

// do we need to sort sequences here or sampler can handle it for us?
std::sort(finished_seqs.begin(), finished_seqs.end(), [=] (Sequence::CPtr s1, Sequence::CPtr s2) {
return s1->get_beam_search_score(m_sampling_params) > s2->get_beam_search_score(m_sampling_params);
return s1->get_cumulative_score_with_length_penalty(m_sampling_params) > s2->get_cumulative_score_with_length_penalty(m_sampling_params);
});

return finished_seqs;
@@ -586,7 +586,7 @@ class SequenceGroup {
output.generated_ids.insert(output.generated_ids.begin(), m_prompt_ids.begin(), m_prompt_ids.end());
output.generated_log_probs.insert(output.generated_log_probs.begin(), m_prompt_log_probs.begin(), m_prompt_log_probs.end());
}
output.score = m_sampling_params.is_beam_search() ? sequence->get_beam_search_score(m_sampling_params) : sequence->get_cumulative_log_probs();
output.score = m_sampling_params.is_beam_search() ? sequence->get_cumulative_score_with_length_penalty(m_sampling_params) : sequence->get_cumulative_log_probs();
output.finish_reason = sequence->get_finish_reason();
outputs.emplace(sequence->get_grouped_id(), output);
}
5 changes: 2 additions & 3 deletions src/cpp/src/utils.cpp
@@ -61,7 +61,7 @@ int64_t argmax(const ov::Tensor& logits, const size_t batch_idx) {
/**
* Initializes position ids based on attention mask and starting position
*/
void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, int64_t start_pos) {
void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, size_t seq_length, int64_t start_pos) {
OPENVINO_ASSERT(position_ids.get_element_type() == ov::element::i64,
"position_ids tensor element type should be an i64");
OPENVINO_ASSERT(position_ids.get_shape().size() == 2,
@@ -72,7 +72,6 @@ void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attenti
"attention_mask tensor should of rank 2 with shape [batch_size, seq_len]");

const size_t batch_size = attention_mask.get_shape()[0];
const size_t seq_length = attention_mask.get_shape()[1];

const int64_t* attention_mask_data = attention_mask.data<int64_t>();
int64_t* position_ids_data = position_ids.data<int64_t>();
@@ -97,7 +96,7 @@ void initialize_beam_inputs(const ov::Tensor& input_ids, const ov::Tensor& atten

ov::Tensor position_ids = request.get_tensor("position_ids");
position_ids.set_shape(input_shape);
initialize_position_ids(position_ids, attention_mask);
initialize_position_ids(position_ids, attention_mask, attention_mask.get_shape()[1]);

ov::Tensor beam_idx = request.get_tensor("beam_idx");
beam_idx.set_shape({input_shape.at(0)});
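The utils.cpp hunk above adds an explicit seq_length parameter so position ids can be initialized for a sub-sequence placed inside a longer chat history. A stand-alone illustration of that kind of computation (an assumption about the intent; the real ov::Tensor-based function asserts shapes and element types and may differ in detail):

#include <cstdint>
#include <cstddef>
#include <vector>

// Position ids follow the running count of attended tokens, offset by start_pos;
// seq_length is passed explicitly instead of being derived from the mask shape.
std::vector<int64_t> make_position_ids(const std::vector<int64_t>& attention_mask_row,
                                       std::size_t seq_length,
                                       int64_t start_pos = 0) {
    std::vector<int64_t> position_ids(seq_length, 0);
    int64_t running_position = start_pos;
    for (std::size_t i = 0; i < seq_length && i < attention_mask_row.size(); ++i) {
        position_ids[i] = running_position;
        if (attention_mask_row[i] == 1)
            ++running_position;
    }
    return position_ids;
}
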
2 changes: 1 addition & 1 deletion src/cpp/src/utils.hpp
@@ -18,7 +18,7 @@ void print_tensor(const ov::Tensor& tensor);

int64_t argmax(const ov::Tensor& logits, const size_t batch_idx);

void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, int64_t start_pos = 0);
void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, size_t seq_length, int64_t start_pos = 0);

ov::Tensor extend_attention(ov::Tensor attention_mask);
