Allow token arrangement to align with request index in batch (#1176)
* arrange tokens by request index in incremental decoding.

* fix logic.

* fix issues.

* format.

* undo output format change.

* format.

* remove empty line at the end of the file.
zwang86 authored Oct 16, 2023
1 parent 7b57463 commit f243b40
Showing 2 changed files with 120 additions and 119 deletions.
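
For orientation before the file diffs: the change restructures RequestManager::prepare_next_batch so that, in a single pass over the batch slots, a slot whose request just completed is immediately refilled from the pending queue, while a still-running request simply continues decoding. Because slots are visited in order, the tokens written into the batch come out grouped and ordered by request index, which is what the commit title refers to. The block below is a minimal, self-contained sketch of that loop shape; every type, constant, and value in it is a simplified stand-in for illustration, not the FlexFlow API.

```cpp
// Minimal sketch of the slot-ordered batching loop this commit introduces.
// All types, constants, and values are simplified stand-ins, not the
// FlexFlow API.
#include <algorithm>
#include <cstdio>
#include <deque>
#include <vector>

constexpr int kMaxRequestsPerBatch = 4;
constexpr int kMaxTokensPerBatch = 16;

struct Request {
  int guid;
  std::vector<int> tokens; // prompt tokens plus tokens generated so far
};

struct Slot {
  bool completed = true; // true means the slot is free for a new request
  int guid = -1;
  int start_offset = 0;
  int num_tokens_in_batch = 0;
};

struct TokenEntry {
  int request_index;
  int depth;
  int token_id;
};

int main() {
  std::deque<Request> pending = {{7, {1, 2, 3, 4, 5}}, {8, {9, 9}}};
  Slot slots[kMaxRequestsPerBatch];
  // Pretend slot 1 is still decoding an older request (guid 3) whose
  // prompt plus generated tokens currently number seven.
  Request running{3, {11, 12, 13, 14, 15, 16, 17}};
  slots[1] = {false, running.guid, 5, 1};

  std::vector<TokenEntry> batch_tokens;
  for (int i = 0; i < kMaxRequestsPerBatch; i++) {
    if (slots[i].completed) {
      // Free slot: admit a new request, reserving one token of batch
      // capacity for every slot that still comes after this one.
      if (pending.empty()) {
        continue;
      }
      int budget = kMaxTokensPerBatch - (int)batch_tokens.size() -
                   (kMaxRequestsPerBatch - (i + 1));
      if (budget <= 0) {
        continue;
      }
      Request req = pending.front();
      pending.pop_front();
      slots[i] = {false, req.guid, 0, std::min(budget, (int)req.tokens.size())};
      for (int j = 0; j < slots[i].num_tokens_in_batch; j++) {
        batch_tokens.push_back({i, j, req.tokens[j]});
      }
    } else {
      // Existing request in the incremental phase: feed exactly one token
      // (the most recently generated one) at the next depth. The real code
      // also continues the prompt phase here when a prompt did not fit in
      // a single batch; that case is omitted for brevity.
      int processed = slots[i].start_offset + slots[i].num_tokens_in_batch;
      slots[i].start_offset = processed;
      slots[i].num_tokens_in_batch = 1;
      batch_tokens.push_back({i, processed, running.tokens[processed]});
    }
  }
  // Tokens now come out grouped and ordered by request index: 0,0,0,0,0,1,2,2.
  for (TokenEntry const &t : batch_tokens) {
    std::printf("req %d depth %d token %d\n", t.request_index, t.depth,
                t.token_id);
  }
  return 0;
}
```

Compared with the previous two-pass layout (advance all existing requests first, then append new prompts at the end of the batch), visiting the slots once in index order keeps each request's tokens contiguous at its own slot position.
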
10 changes: 8 additions & 2 deletions src/ops/fused.cu
@@ -1104,14 +1104,20 @@ __host__ void
}
for (int i = 0; i < fused->op_num_weights[op]; i++) {
assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT);
weight_accessors_to_save.push_back(weight_accessor[fused->op_weight_idx[i + woff]]);
weight_accessors_to_save.push_back(
weight_accessor[fused->op_weight_idx[i + woff]]);
}
for (int i = 0; i < fused->op_num_outputs[op]; i++) {
output_accessors_to_save.push_back(output_accessor[i + ooff]);
}
assert(task->index_point.get_dim() == 1);
int shard_id = task->index_point.point_data[0];
FusedOp::save_inference_tensors_to_file(metas->meta[op], shard_id, bc, input_accessors_to_save, weight_accessors_to_save, output_accessors_to_save);
FusedOp::save_inference_tensors_to_file(metas->meta[op],
shard_id,
bc,
input_accessors_to_save,
weight_accessors_to_save,
output_accessors_to_save);
}
ioff += fused->op_num_inputs[op];
woff += fused->op_num_weights[op];
229 changes: 112 additions & 117 deletions src/runtime/request_manager.cc
@@ -338,6 +338,7 @@ BatchConfig RequestManager::prepare_next_batch_task(
BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
InferenceResult const &result) {
const std::lock_guard<std::mutex> lock(request_queue_mutex);

// Step 1: append result from previous iteration to request's tokens
for (int i = 0; i < old_bc.num_tokens; i++) {
size_t guid =
@@ -356,115 +357,11 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
// log_req_mgr.print("Output: %s", output.c_str());
}
}

// Step 2: prepare the next batch for existing requests
BatchConfig new_bc;
for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) {
if (old_bc.request_completed[i]) {
continue;
}
assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0);
Request &request = all_requests[old_bc.requestsInfo[i].request_guid];
int processed_tokens = old_bc.requestsInfo[i].token_start_offset +
old_bc.requestsInfo[i].num_tokens_in_batch;
assert(processed_tokens < request.tokens.size());
bool request_completed = false;
// printf("model_type = %d\n", this->model_type);
if (request.tokens.size() >= old_bc.requestsInfo[i].max_sequence_length) {
request_completed = true;
} else if (request.tokens.back() == eos_token_id) {
// Encounter EOS token id
request_completed = true;
}
if (request_completed) {
request.status = Request::COMPLETED;
log_req_mgr.print("[Done] guid(%zu) final_length(%zu)",
old_bc.requestsInfo[i].request_guid,
request.tokens.size());
std::string output = this->tokenizer_->Decode(request.tokens);

{
// update generation result and trigger future
GenerationResult &gr = request_generation_results[request.guid];
assert(gr.guid == request.guid);
gr.output_tokens = request.tokens;
gr.output_text = output;
}
log_req_mgr.print("Final output: %s", output.c_str());
num_processed_requests++;
ProfileInfo profile_info = profiling_requests[request.guid];
profile_info.finish_time = Realm::Clock::current_time_in_microseconds();
total_request_run_time +=
profile_info.finish_time - profile_info.start_time;
profiling_requests[request.guid] = profile_info;
log_req_mgr.print("[Profile] guid(%zu) decoding_steps(%d) start(%.1lf) "
"finish(%.1lf) latency(%.1lf)",
request.guid,
profile_info.decoding_steps,
profile_info.start_time,
profile_info.finish_time,
profile_info.finish_time - profile_info.start_time);
// Write output to file if needed:
if (!output_filepath.empty()) {
std::ofstream outputFile(output_filepath);
if (outputFile.is_open()) {
outputFile << "end-to-end latency: " << std::fixed
<< std::setprecision(3) << total_request_run_time
<< std::endl;
outputFile << "num decoding steps: " << profile_info.decoding_steps
<< std::endl;
outputFile << "token IDs: ";
for (int i = 0; i < request.tokens.size(); i++) {
outputFile << request.tokens[i];
if (i < request.tokens.size() - 1) {
outputFile << ",";
}
}
outputFile << std::endl;
outputFile << output;
outputFile.close();
} else {
std::cout << "Unable to open the output file: " << output_filepath
<< std::endl;
assert(false);
}
}

// std::cout << "print results: " << std::endl;
// for (int i = 0; i < request.tokens.size(); i++) {
// std::cout << request.tokens.at(i) << ", ";
// }
} else {
new_bc.request_completed[i] = false;
new_bc.requestsInfo[i].token_start_offset = processed_tokens;
new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid;
new_bc.requestsInfo[i].max_sequence_length =
old_bc.requestsInfo[i].max_sequence_length;
if (new_bc.requestsInfo[i].token_start_offset + 1 ==
request.tokens.size()) {
// Incremental phase
new_bc.requestsInfo[i].num_tokens_in_batch = 1;
} else {
// Prompt phase
new_bc.requestsInfo[i].num_tokens_in_batch =
std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
(int)request.tokens.size() -
new_bc.requestsInfo[i].token_start_offset);
}
for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
int depth = new_bc.requestsInfo[i].token_start_offset + j;
new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
assert(depth < request.tokens.size());
new_bc.tokensInfo[new_bc.num_tokens].token_id = request.tokens[depth];
new_bc.num_tokens++;
}
// Update profiling
profiling_requests[new_bc.requestsInfo[i].request_guid].decoding_steps++;
}
}
// Step 3: add new requests to the next batch
for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) {
if (new_bc.request_completed[i]) {
if (old_bc.request_completed[i]) { // add new requests to the next batch
if (!pending_request_queue.empty() &&
new_bc.num_tokens < get_max_tokens_per_batch()) {
Request new_request = pending_request_queue.front();
@@ -473,7 +370,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
new_bc.requestsInfo[i].token_start_offset = 0;
new_bc.requestsInfo[i].request_guid = new_request.guid;
new_bc.requestsInfo[i].num_tokens_in_batch =
std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
std::min(get_max_tokens_per_batch() - new_bc.num_tokens -
BatchConfig::max_requests_per_batch() + (i + 1),
(int)new_request.tokens.size());
new_bc.requestsInfo[i].max_sequence_length =
new_request.max_sequence_length;
@@ -496,8 +394,107 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
break;
}
}
} else {
assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0);
Request &request = all_requests[old_bc.requestsInfo[i].request_guid];
int processed_tokens = old_bc.requestsInfo[i].token_start_offset +
old_bc.requestsInfo[i].num_tokens_in_batch;
assert(processed_tokens < request.tokens.size());
bool request_completed = false;
// printf("model_type = %d\n", this->model_type);
if (request.tokens.size() >= old_bc.requestsInfo[i].max_sequence_length) {
request_completed = true;
} else if (request.tokens.back() == eos_token_id) {
// Encounter EOS token id
request_completed = true;
}
if (request_completed) {
request.status = Request::COMPLETED;
log_req_mgr.print("[Done] guid(%zu) final_length(%zu)",
old_bc.requestsInfo[i].request_guid,
request.tokens.size());
std::string output = this->tokenizer_->Decode(request.tokens);

{
// update generation result and trigger future
GenerationResult &gr = request_generation_results[request.guid];
assert(gr.guid == request.guid);
gr.output_tokens = request.tokens;
gr.output_text = output;
}
log_req_mgr.print("Final output: %s", output.c_str());
num_processed_requests++;
ProfileInfo profile_info = profiling_requests[request.guid];
profile_info.finish_time = Realm::Clock::current_time_in_microseconds();
total_request_run_time +=
profile_info.finish_time - profile_info.start_time;
profiling_requests[request.guid] = profile_info;
log_req_mgr.print("[Profile] guid(%zu) decoding_steps(%d) start(%.1lf) "
"finish(%.1lf) latency(%.1lf)",
request.guid,
profile_info.decoding_steps,
profile_info.start_time,
profile_info.finish_time,
profile_info.finish_time - profile_info.start_time);
// Write output to file if needed:
if (!output_filepath.empty()) {
std::ofstream outputFile(output_filepath, std::ios::app);
if (outputFile.is_open()) {
outputFile << "end-to-end latency: " << std::fixed
<< std::setprecision(3) << total_request_run_time
<< std::endl;
outputFile << "num decoding steps: " << profile_info.decoding_steps
<< std::endl;
outputFile << "token IDs: ";
for (int i = 0; i < request.tokens.size(); i++) {
outputFile << request.tokens[i];
if (i < request.tokens.size() - 1) {
outputFile << ",";
}
}
outputFile << std::endl;
outputFile << output;
outputFile.close();
} else {
std::cout << "Unable to open the output file: " << output_filepath
<< std::endl;
assert(false);
}
}

} else {
new_bc.request_completed[i] = false;
new_bc.requestsInfo[i].token_start_offset = processed_tokens;
new_bc.requestsInfo[i].request_guid =
old_bc.requestsInfo[i].request_guid;
new_bc.requestsInfo[i].max_sequence_length =
old_bc.requestsInfo[i].max_sequence_length;
if (new_bc.requestsInfo[i].token_start_offset + 1 ==
request.tokens.size()) {
// Incremental phase
new_bc.requestsInfo[i].num_tokens_in_batch = 1;
} else {
// Prompt phase
new_bc.requestsInfo[i].num_tokens_in_batch =
std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
(int)request.tokens.size() -
new_bc.requestsInfo[i].token_start_offset);
}
for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
int depth = new_bc.requestsInfo[i].token_start_offset + j;
new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
assert(depth < request.tokens.size());
new_bc.tokensInfo[new_bc.num_tokens].token_id = request.tokens[depth];
new_bc.num_tokens++;
}
// Update profiling
profiling_requests[new_bc.requestsInfo[i].request_guid]
.decoding_steps++;
}
}
}

return new_bc;
}
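
The hunk above also introduces a reserved-capacity term when sizing a newly admitted request's prompt slice: get_max_tokens_per_batch() - new_bc.num_tokens - BatchConfig::max_requests_per_batch() + (i + 1). This appears to keep one token of batch capacity free for every slot after slot i, so a long prompt cannot leave later requests without room for their single decoding token. A worked example with made-up numbers:

```cpp
// Worked example of the reserved-capacity term used when admitting a new
// request into slot i (all numbers are made up, not FlexFlow defaults).
#include <algorithm>
#include <cstdio>

int main() {
  int max_tokens_per_batch = 64;  // assumed batch token capacity
  int max_requests_per_batch = 8; // assumed number of request slots
  int num_tokens = 10;            // tokens already placed in this batch
  int i = 2;                      // index of the slot being filled
  int prompt_len = 100;           // the new request's prompt length

  // Leave one token of headroom for each of the (max_requests_per_batch - i - 1)
  // slots that come after slot i, so one prompt cannot take the whole batch.
  int budget =
      max_tokens_per_batch - num_tokens - max_requests_per_batch + (i + 1);
  int num_tokens_in_batch = std::min(budget, prompt_len);
  std::printf("budget = %d, prompt tokens admitted = %d\n", budget,
              num_tokens_in_batch); // prints: budget = 49, prompt tokens admitted = 49
  return 0;
}
```
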

@@ -654,11 +651,10 @@ BeamSearchBatchConfig

// Write output to file if needed:
if (!output_filepath.empty()) {
std::ofstream outputFile(output_filepath);
std::ofstream outputFile(output_filepath, std::ios::app);
if (outputFile.is_open()) {
outputFile << "end-to-end latency: " << std::fixed
<< std::setprecision(3)
<< profile_info.finish_time - profile_info.start_time
<< std::setprecision(3) << total_request_run_time
<< std::endl;
outputFile << "num decoding steps: " << profile_info.decoding_steps
<< std::endl;
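
Both results-file writes in the new code open the output file with std::ios::app instead of the default truncating mode, so records from earlier completed requests are kept rather than overwritten. A small standalone illustration (the file name is invented):

```cpp
// Standalone illustration of truncating vs. appending file opens
// (the file name is made up).
#include <fstream>

int main() {
  // std::ofstream out("inference_results.txt");            // truncates:
  //   a later completed request would overwrite earlier records.
  std::ofstream out("inference_results.txt", std::ios::app); // appends
  if (out.is_open()) {
    out << "num decoding steps: " << 42 << '\n';
    out.close();
  }
  return 0;
}
```
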
@@ -671,6 +667,7 @@ BeamSearchBatchConfig
}
outputFile << std::endl;
outputFile << output;

outputFile.close();
} else {
std::cout << "Unable to open the output file: " << output_filepath
Expand Down Expand Up @@ -1098,10 +1095,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
std::vector<BeamSearchBatchConfig> const &old_batches) {
const std::lock_guard<std::mutex> lock(request_queue_mutex);

if (verbose) {
std::cout
<< "\n############### prepare_next_batch_verify ###############\n";
}
std::cout << "\n############### prepare_next_batch_verify ###############\n";

assert(old_batches.size() > 0);

TreeVerifyBatchConfig new_bc;
@@ -1277,8 +1272,6 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(

std::cout << "max_prompt_load_size: " << max_prompt_load_size
<< std::endl;
std::cout << "new_bc.requestsInfo[i].num_tokens_in_batch: " << i << ", "
<< new_bc.requestsInfo[i].num_tokens_in_batch << std::endl;

if (request.llm_cache_size < request.initial_len) {
// Initialization (prompt) phase
@@ -1298,7 +1291,9 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
break;
}

if (new_bc.num_tokens + request.llm_cache_size >= request.initial_len) {
if (new_bc.requestsInfo[i].num_tokens_in_batch +
request.llm_cache_size >=
request.initial_len) {
// launch the request into running phase after loading all prompt
request.status = Request::RUNNING;
new_bc.request_running[i] = true;
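
The last hunk replaces the batch-wide counter new_bc.num_tokens with the per-request new_bc.requestsInfo[i].num_tokens_in_batch when deciding whether a verify request has finished loading its prompt. With tokens from several requests packed into one batch, the batch-wide count can reach request.initial_len well before this particular prompt is fully cached. A small illustration with made-up numbers:

```cpp
// Illustration of the corrected prompt-completion check (made-up values).
#include <cstdio>

int main() {
  int batch_num_tokens = 30;        // tokens from *all* requests in the batch
  int this_req_tokens_in_batch = 8; // tokens loaded for this request only
  int llm_cache_size = 4;           // prompt tokens already cached
  int initial_len = 16;             // this request's full prompt length

  // Old check (batch-wide counter): fires too early once other requests
  // contribute tokens to the same batch.
  bool old_check = batch_num_tokens + llm_cache_size >= initial_len;
  // New check (per-request counter): only fires when this prompt is loaded.
  bool new_check = this_req_tokens_in_batch + llm_cache_size >= initial_len;
  std::printf("old = %d, new = %d\n", old_check, new_check); // old = 1, new = 0
  return 0;
}
```
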
