Commit

fix logic.
zwang86 committed Oct 6, 2023
1 parent 3dc8bee commit 2bbb987
Showing 1 changed file with 96 additions and 93 deletions.
src/runtime/request_manager.cc: 189 changes (96 additions & 93 deletions)
@@ -361,7 +361,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
   // Step 2: prepare the next batch for existing requests
   BatchConfig new_bc;
   for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) {
-    if (old_bc.request_completed[i]) {
+    if (old_bc.request_completed[i]) { // add new requests to the next batch
       if (!pending_request_queue.empty() &&
           new_bc.num_tokens < get_max_tokens_per_batch()) {
         Request new_request = pending_request_queue.front();
@@ -394,105 +394,108 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
           break;
         }
       }
-    }
-    assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0);
-    Request &request = all_requests[old_bc.requestsInfo[i].request_guid];
-    int processed_tokens = old_bc.requestsInfo[i].token_start_offset +
-                           old_bc.requestsInfo[i].num_tokens_in_batch;
-    assert(processed_tokens < request.tokens.size());
-    bool request_completed = false;
-    // printf("model_type = %d\n", this->model_type);
-    if (request.tokens.size() >= old_bc.requestsInfo[i].max_sequence_length) {
-      request_completed = true;
-    } else if (request.tokens.back() == eos_token_id) {
-      // Encounter EOS token id
-      request_completed = true;
-    }
-    if (request_completed) {
-      request.status = Request::COMPLETED;
-      log_req_mgr.print("[Done] guid(%zu) final_length(%zu)",
-                        old_bc.requestsInfo[i].request_guid,
-                        request.tokens.size());
-      std::string output = this->tokenizer_->Decode(request.tokens);
-
-      {
-        // update generation result and trigger future
-        GenerationResult &gr = request_generation_results[request.guid];
-        assert(gr.guid == request.guid);
-        gr.output_tokens = request.tokens;
-        gr.output_text = output;
-      }
-      log_req_mgr.print("Final output: %s", output.c_str());
-      num_processed_requests++;
-      ProfileInfo profile_info = profiling_requests[request.guid];
-      profile_info.finish_time = Realm::Clock::current_time_in_microseconds();
-      total_request_run_time +=
-          profile_info.finish_time - profile_info.start_time;
-      profiling_requests[request.guid] = profile_info;
-      log_req_mgr.print("[Profile] guid(%zu) decoding_steps(%d) start(%.1lf) "
-                        "finish(%.1lf) latency(%.1lf)",
-                        request.guid,
-                        profile_info.decoding_steps,
-                        profile_info.start_time,
-                        profile_info.finish_time,
-                        profile_info.finish_time - profile_info.start_time);
-      // Write output to file if needed:
-      if (!output_filepath.empty()) {
-        std::ofstream outputFile(output_filepath);
-        if (outputFile.is_open()) {
-          outputFile << "end-to-end latency: " << std::fixed
-                     << std::setprecision(3) << total_request_run_time
-                     << std::endl;
-          outputFile << "num decoding steps: " << profile_info.decoding_steps
-                     << std::endl;
-          outputFile << "token IDs: ";
-          for (int i = 0; i < request.tokens.size(); i++) {
-            outputFile << request.tokens[i];
-            if (i < request.tokens.size() - 1) {
-              outputFile << ",";
-            }
-          }
-          outputFile << std::endl;
-          outputFile << output;
-          outputFile.close();
-        } else {
-          std::cout << "Unable to open the output file: " << output_filepath
-                    << std::endl;
-          assert(false);
-        }
-      }
-
-      // std::cout << "print results: " << std::endl;
-      // for (int i = 0; i < request.tokens.size(); i++) {
-      //   std::cout << request.tokens.at(i) << ", ";
-      // }
-    } else {
-      new_bc.request_completed[i] = false;
-      new_bc.requestsInfo[i].token_start_offset = processed_tokens;
-      new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid;
-      new_bc.requestsInfo[i].max_sequence_length =
-          old_bc.requestsInfo[i].max_sequence_length;
-      if (new_bc.requestsInfo[i].token_start_offset + 1 ==
-          request.tokens.size()) {
-        // Incremental phase
-        new_bc.requestsInfo[i].num_tokens_in_batch = 1;
-      } else {
-        // Prompt phase
-        new_bc.requestsInfo[i].num_tokens_in_batch =
-            std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
-                     (int)request.tokens.size() -
-                         new_bc.requestsInfo[i].token_start_offset);
-      }
-      for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
-        int depth = new_bc.requestsInfo[i].token_start_offset + j;
-        new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
-        new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
-        assert(depth < request.tokens.size());
-        new_bc.tokensInfo[new_bc.num_tokens].token_id = request.tokens[depth];
-        new_bc.num_tokens++;
-      }
-      // Update profiling
-      profiling_requests[new_bc.requestsInfo[i].request_guid].decoding_steps++;
-    }
+    } else {
+      assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0);
+      Request &request = all_requests[old_bc.requestsInfo[i].request_guid];
+      int processed_tokens = old_bc.requestsInfo[i].token_start_offset +
+                             old_bc.requestsInfo[i].num_tokens_in_batch;
+      assert(processed_tokens < request.tokens.size());
+      bool request_completed = false;
+      // printf("model_type = %d\n", this->model_type);
+      if (request.tokens.size() >= old_bc.requestsInfo[i].max_sequence_length) {
+        request_completed = true;
+      } else if (request.tokens.back() == eos_token_id) {
+        // Encounter EOS token id
+        request_completed = true;
+      }
+      if (request_completed) {
+        request.status = Request::COMPLETED;
+        log_req_mgr.print("[Done] guid(%zu) final_length(%zu)",
+                          old_bc.requestsInfo[i].request_guid,
+                          request.tokens.size());
+        std::string output = this->tokenizer_->Decode(request.tokens);
+
+        {
+          // update generation result and trigger future
+          GenerationResult &gr = request_generation_results[request.guid];
+          assert(gr.guid == request.guid);
+          gr.output_tokens = request.tokens;
+          gr.output_text = output;
+        }
+        log_req_mgr.print("Final output: %s", output.c_str());
+        num_processed_requests++;
+        ProfileInfo profile_info = profiling_requests[request.guid];
+        profile_info.finish_time = Realm::Clock::current_time_in_microseconds();
+        total_request_run_time +=
+            profile_info.finish_time - profile_info.start_time;
+        profiling_requests[request.guid] = profile_info;
+        log_req_mgr.print("[Profile] guid(%zu) decoding_steps(%d) start(%.1lf) "
+                          "finish(%.1lf) latency(%.1lf)",
+                          request.guid,
+                          profile_info.decoding_steps,
+                          profile_info.start_time,
+                          profile_info.finish_time,
+                          profile_info.finish_time - profile_info.start_time);
+        // Write output to file if needed:
+        if (!output_filepath.empty()) {
+          std::ofstream outputFile(output_filepath);
+          if (outputFile.is_open()) {
+            outputFile << "end-to-end latency: " << std::fixed
+                       << std::setprecision(3) << total_request_run_time
+                       << std::endl;
+            outputFile << "num decoding steps: " << profile_info.decoding_steps
+                       << std::endl;
+            outputFile << "token IDs: ";
+            for (int i = 0; i < request.tokens.size(); i++) {
+              outputFile << request.tokens[i];
+              if (i < request.tokens.size() - 1) {
+                outputFile << ",";
+              }
+            }
+            outputFile << std::endl;
+            outputFile << output;
+            outputFile.close();
+          } else {
+            std::cout << "Unable to open the output file: " << output_filepath
+                      << std::endl;
+            assert(false);
+          }
+        }
+
+        // std::cout << "print results: " << std::endl;
+        // for (int i = 0; i < request.tokens.size(); i++) {
+        //   std::cout << request.tokens.at(i) << ", ";
+        // }
+      } else {
+        new_bc.request_completed[i] = false;
+        new_bc.requestsInfo[i].token_start_offset = processed_tokens;
+        new_bc.requestsInfo[i].request_guid =
+            old_bc.requestsInfo[i].request_guid;
+        new_bc.requestsInfo[i].max_sequence_length =
+            old_bc.requestsInfo[i].max_sequence_length;
+        if (new_bc.requestsInfo[i].token_start_offset + 1 ==
+            request.tokens.size()) {
+          // Incremental phase
+          new_bc.requestsInfo[i].num_tokens_in_batch = 1;
+        } else {
+          // Prompt phase
+          new_bc.requestsInfo[i].num_tokens_in_batch =
+              std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
+                       (int)request.tokens.size() -
+                           new_bc.requestsInfo[i].token_start_offset);
+        }
+        for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
+          int depth = new_bc.requestsInfo[i].token_start_offset + j;
+          new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
+          new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
+          assert(depth < request.tokens.size());
+          new_bc.tokensInfo[new_bc.num_tokens].token_id = request.tokens[depth];
+          new_bc.num_tokens++;
+        }
+        // Update profiling
+        profiling_requests[new_bc.requestsInfo[i].request_guid]
+            .decoding_steps++;
+      }
     }
   }
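
For orientation, here is a minimal standalone sketch of the control-flow change above. This is not FlexFlow code: SlotInfo, kMaxSlots, and the plain int queue are hypothetical stand-ins for BatchConfig's request slots and pending_request_queue. The point is the shape of the loop after the commit: the "free slot, admit a pending request" path stays under the if, and the per-request bookkeeping for still-running requests moves into an else, so old_bc.requestsInfo[i] is only consulted for slots that actually hold a live request. The assert comment below describes this sketch; whether the pre-fix code could really reach the bookkeeping for a completed slot depends on lines hidden from this hunk.

#include <cassert>
#include <cstdio>
#include <queue>
#include <vector>

// Hypothetical stand-in for the per-slot info kept in BatchConfig.
struct SlotInfo {
  int num_tokens_in_batch = 0; // 0 for slots whose request already finished
};

int main() {
  int const kMaxSlots = 4;
  std::vector<bool> request_completed = {true, false, true, false};
  std::vector<SlotInfo> slots(kMaxSlots);
  slots[1].num_tokens_in_batch = 3;
  slots[3].num_tokens_in_batch = 1;
  std::queue<int> pending_request_queue; // empty: nothing new to schedule

  for (int i = 0; i < kMaxSlots; i++) {
    if (request_completed[i]) { // free slot: add new requests to the next batch
      if (!pending_request_queue.empty()) {
        std::printf("slot %d: admit request %d\n", i,
                    pending_request_queue.front());
        pending_request_queue.pop();
      }
    } else { // the commit's change: per-request bookkeeping only for live slots
      // In this sketch, a completed slot would trip this assert without the else.
      assert(slots[i].num_tokens_in_batch > 0);
      std::printf("slot %d: continue decoding, %d token(s) in batch\n", i,
                  slots[i].num_tokens_in_batch);
    }
  }
  return 0;
}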

