diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h
index 108bc8d172..25bc206bf9 100644
--- a/include/flexflow/batch_config.h
+++ b/include/flexflow/batch_config.h
@@ -61,7 +61,7 @@ class BatchConfig {
   int num_tokens;
 
   struct PerRequestInfo {
-    int token_start_offset;
+    int first_token_depth_in_request;
     int num_tokens_in_batch;
    int max_sequence_length;
     RequestGuid request_guid;
diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h
index 3081aaa1c2..baf6844801 100644
--- a/include/flexflow/request_manager.h
+++ b/include/flexflow/request_manager.h
@@ -154,7 +154,7 @@ class RequestManager {
   std::vector<std::pair<BatchConfig::TokenId, int>>
       traverse_beam_tree(BeamSearchBatchConfig const &old_bc,
                          int request_index,
-                         int token_start_offset);
+                         int first_token_depth_in_request);
 
   // remove guid after put the cached tree in request
   std::vector<std::pair<BatchConfig::TokenId, int>> merge_dfs_trees(
diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp
index 562898a220..37cc986f5e 100644
--- a/src/ops/inc_multihead_self_attention.cpp
+++ b/src/ops/inc_multihead_self_attention.cpp
@@ -532,7 +532,7 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
       continue;
     }
     int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
-    int total_tokens = bc->requestsInfo[i].token_start_offset +
+    int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
                        bc->requestsInfo[i].num_tokens_in_batch;
     // bc->token_last_available_idx[i] + 1;
     // Compute (QK^T/sqrt(d_k))
diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu
index 00d45a9cfa..6ec077c328 100644
--- a/src/ops/inc_multihead_self_attention.cu
+++ b/src/ops/inc_multihead_self_attention.cu
@@ -531,7 +531,7 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
       continue;
     }
     int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
-    int total_tokens = bc->requestsInfo[i].token_start_offset +
+    int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
                        bc->requestsInfo[i].num_tokens_in_batch;
     // bc->token_last_available_idx[i] + 1;
     // Compute (QK^T/sqrt(d_k))
diff --git a/src/ops/spec_inc_multihead_self_attention.cpp b/src/ops/spec_inc_multihead_self_attention.cpp
index 173d4a5b1d..1d81ae0c11 100644
--- a/src/ops/spec_inc_multihead_self_attention.cpp
+++ b/src/ops/spec_inc_multihead_self_attention.cpp
@@ -231,7 +231,7 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
     // int total_tokens = bc->token_last_available_idx[i] + 1;
 
     int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
-    int total_tokens = bc->requestsInfo[i].token_start_offset +
+    int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
                        bc->requestsInfo[i].num_tokens_in_batch;
     // Compute (QK^T/sqrt(d_k))
     int m_ = num_new_tokens;
diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu
index 00eec96824..8b89acf3b7 100644
--- a/src/ops/spec_inc_multihead_self_attention.cu
+++ b/src/ops/spec_inc_multihead_self_attention.cu
@@ -248,7 +248,7 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
     // int total_tokens = bc->token_last_available_idx[i] + 1;
 
     int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
-    int total_tokens = bc->requestsInfo[i].token_start_offset +
+    int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
                        bc->requestsInfo[i].num_tokens_in_batch;
 
     if (num_new_tokens <= 0) {
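
Taken together, the four attention-kernel hunks above preserve one sizing
invariant: the renamed field counts how many tokens of a request are already
in the KV cache, so each GEMM spans num_new_tokens rows of Q against
first_token_depth_in_request + num_tokens_in_batch columns of K/V. A minimal
standalone sketch of that invariant (the struct mirrors the renamed
PerRequestInfo fields from the diff; the helper and driver are illustrative,
not FlexFlow code):

    #include <cassert>

    // Subset of BatchConfig::PerRequestInfo after the rename.
    struct PerRequestInfo {
      int first_token_depth_in_request; // tokens already in the KV cache
      int num_tokens_in_batch;          // tokens appended by this batch
    };

    // The attention kernels size their QK^T GEMM from these two fields.
    inline int total_tokens_for(PerRequestInfo const &info) {
      return info.first_token_depth_in_request + info.num_tokens_in_batch;
    }

    int main() {
      // A request with a 7-token prompt already cached, decoding 1 new token:
      PerRequestInfo info{7, 1};
      assert(total_tokens_for(info) == 8); // attention covers all 8 positions
      return 0;
    }

The new name also makes explicit that the value is a depth within one
request, not an offset into the batch's flat token array.
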
diff --git a/src/runtime/batch_config.cc b/src/runtime/batch_config.cc
index 72572c4e06..4781f09cab 100644
--- a/src/runtime/batch_config.cc
+++ b/src/runtime/batch_config.cc
@@ -27,7 +27,7 @@ using Legion::Memory;
 
 BatchConfig::BatchConfig() : num_tokens(0) {
   for (int i = 0; i < MAX_NUM_REQUESTS; i++) {
-    requestsInfo[i].token_start_offset = 0;
+    requestsInfo[i].first_token_depth_in_request = 0;
     requestsInfo[i].num_tokens_in_batch = 0;
     request_completed[i] = true;
   }
@@ -104,8 +104,8 @@ std::ostream &operator<<(std::ostream &os, BatchConfig const &bc) {
   for (int i = 0; i < bc.max_requests_per_batch(); i++) {
     if (!bc.request_completed[i]) {
       os << "  Request " << i << ":\n";
-      os << "    Token start offset: " << bc.requestsInfo[i].token_start_offset
-         << std::endl;
+      os << "    Token start offset: "
+         << bc.requestsInfo[i].first_token_depth_in_request << std::endl;
       os << "    Number of tokens in batch: "
          << bc.requestsInfo[i].num_tokens_in_batch << std::endl;
       os << "    GUID: " << bc.requestsInfo[i].request_guid << std::endl;
diff --git a/src/runtime/beam_search_batch_config.cc b/src/runtime/beam_search_batch_config.cc
index 811ef00ba2..f785dc5b74 100644
--- a/src/runtime/beam_search_batch_config.cc
+++ b/src/runtime/beam_search_batch_config.cc
@@ -126,8 +126,8 @@ std::ostream &operator<<(std::ostream &os, BeamSearchBatchConfig const &bc) {
   for (int i = 0; i < bc.max_requests_per_batch(); i++) {
     if (!bc.request_completed[i]) {
       os << "  Request " << i << ":\n";
-      os << "    Token start offset: " << bc.requestsInfo[i].token_start_offset
-         << std::endl;
+      os << "    Token start offset: "
+         << bc.requestsInfo[i].first_token_depth_in_request << std::endl;
       os << "    Number of tokens in batch: "
          << bc.requestsInfo[i].num_tokens_in_batch << std::endl;
       os << "    GUID: " << bc.requestsInfo[i].request_guid << std::endl;
diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc
index b5688c07e6..1c5a6ae5da 100644
--- a/src/runtime/request_manager.cc
+++ b/src/runtime/request_manager.cc
@@ -367,7 +367,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
         Request new_request = pending_request_queue.front();
         pending_request_queue.pop();
         // all_requests[new_request.guid] = new_request;
-        new_bc.requestsInfo[i].token_start_offset = 0;
+        new_bc.requestsInfo[i].first_token_depth_in_request = 0;
         new_bc.requestsInfo[i].request_guid = new_request.guid;
         new_bc.requestsInfo[i].num_tokens_in_batch =
             std::min(get_max_tokens_per_batch() - new_bc.num_tokens -
@@ -382,7 +382,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
         profile_info.start_time = Realm::Clock::current_time_in_microseconds();
         profiling_requests[new_request.guid] = profile_info;
         for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
-          int depth = new_bc.requestsInfo[i].token_start_offset + j;
+          int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
           new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
           new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
           assert(depth < new_request.tokens.size());
@@ -397,8 +397,9 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
     } else {
       assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0);
       Request &request = all_requests[old_bc.requestsInfo[i].request_guid];
-      int processed_tokens = old_bc.requestsInfo[i].token_start_offset +
-                             old_bc.requestsInfo[i].num_tokens_in_batch;
+      int processed_tokens =
+          old_bc.requestsInfo[i].first_token_depth_in_request +
+          old_bc.requestsInfo[i].num_tokens_in_batch;
       assert(processed_tokens < request.tokens.size());
       bool request_completed = false;
       // printf("model_type = %d\n", this->model_type);
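
Both prompt-loading hunks in prepare_next_batch keep the same per-token
bookkeeping: slot j of a request's chunk lands at absolute depth
first_token_depth_in_request + j, and the assert guards against reading past
the request's known tokens. A compilable sketch under those assumptions
(PerTokenInfo mirrors the tokensInfo fields in the diff; pack_tokens is a
hypothetical stand-in for the in-place loop):

    #include <cassert>
    #include <vector>

    // Subset of BatchConfig::PerTokenInfo.
    struct PerTokenInfo {
      int request_index;
      int abs_depth_in_request;
    };

    // Append one request's next chunk of prompt tokens to the flat token
    // array, the way prepare_next_batch fills new_bc.tokensInfo.
    void pack_tokens(std::vector<PerTokenInfo> &tokens_info,
                     int request_index,
                     int first_token_depth_in_request,
                     int num_tokens_in_batch,
                     int num_known_tokens) { // request.tokens.size()
      for (int j = 0; j < num_tokens_in_batch; j++) {
        int depth = first_token_depth_in_request + j;
        assert(depth < num_known_tokens); // never read past known tokens
        tokens_info.push_back({request_index, depth});
      }
    }

    int main() {
      std::vector<PerTokenInfo> tokens_info;
      // Request 0 has 4 tokens cached; load 3 more of its 10-token prompt.
      pack_tokens(tokens_info, 0, 4, 3, 10);
      assert(tokens_info.size() == 3);
      assert(tokens_info[2].abs_depth_in_request == 6);
      return 0;
    }
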
@@ -464,12 +465,12 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
 
       } else {
         new_bc.request_completed[i] = false;
-        new_bc.requestsInfo[i].token_start_offset = processed_tokens;
+        new_bc.requestsInfo[i].first_token_depth_in_request = processed_tokens;
         new_bc.requestsInfo[i].request_guid =
             old_bc.requestsInfo[i].request_guid;
         new_bc.requestsInfo[i].max_sequence_length =
             old_bc.requestsInfo[i].max_sequence_length;
-        if (new_bc.requestsInfo[i].token_start_offset + 1 ==
+        if (new_bc.requestsInfo[i].first_token_depth_in_request + 1 ==
             request.tokens.size()) {
           // Incremental phase
           new_bc.requestsInfo[i].num_tokens_in_batch = 1;
@@ -478,10 +479,10 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
           new_bc.requestsInfo[i].num_tokens_in_batch =
               std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
                        (int)request.tokens.size() -
-                           new_bc.requestsInfo[i].token_start_offset);
+                           new_bc.requestsInfo[i].first_token_depth_in_request);
         }
         for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
-          int depth = new_bc.requestsInfo[i].token_start_offset + j;
+          int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
           new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
           new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
           assert(depth < request.tokens.size());
@@ -685,7 +686,7 @@ BeamSearchBatchConfig
       new_bc.request_running[i] = true;
 
       // Normal Request Info
-      new_bc.requestsInfo[i].token_start_offset =
+      new_bc.requestsInfo[i].first_token_depth_in_request =
           verified_tokens.front().second;
       new_bc.requestsInfo[i].request_guid =
           old_bc.requestsInfo[i].request_guid;
@@ -694,9 +695,10 @@ BeamSearchBatchConfig
       new_bc.requestsInfo[i].num_tokens_in_batch = verified_tokens.size();
 
       // TODO: Beam Request Info, missing from VerifyTreeBatchConfig
-      int new_max_depth = new_bc.requestsInfo[i].max_sequence_length -
-                          new_bc.requestsInfo[i].token_start_offset -
-                          verified_tokens.size();
+      int new_max_depth =
+          new_bc.requestsInfo[i].max_sequence_length -
+          new_bc.requestsInfo[i].first_token_depth_in_request -
+          verified_tokens.size();
       new_bc.beamRequestsInfo[i].current_depth = 1;
       new_bc.beamRequestsInfo[i].beam_size =
           BeamSearchBatchConfig::MAX_BEAM_WIDTH;
@@ -742,7 +744,8 @@ BeamSearchBatchConfig
       assert(request.ssm_cache_size == request.initial_len);
 
       // Normal Request Info
-      new_bc.requestsInfo[i].token_start_offset = request.ssm_cache_size;
+      new_bc.requestsInfo[i].first_token_depth_in_request =
+          request.ssm_cache_size;
       new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid;
       new_bc.requestsInfo[i].max_sequence_length =
           old_bc.requestsInfo[i].max_sequence_length;
@@ -776,7 +779,7 @@ BeamSearchBatchConfig
         Request new_request = pending_request_queue.front();
         pending_request_queue.pop();
         // all_requests[new_request.guid] = new_request;
-        new_bc.requestsInfo[i].token_start_offset = 0;
+        new_bc.requestsInfo[i].first_token_depth_in_request = 0;
         new_bc.requestsInfo[i].request_guid = new_request.guid;
         new_bc.requestsInfo[i].num_tokens_in_batch =
             std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
@@ -806,7 +809,7 @@ BeamSearchBatchConfig
         new_bc.sub_requests[i] = 1;
 
         for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
-          int depth = new_bc.requestsInfo[i].token_start_offset + j;
+          int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
           new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
           new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
           assert(depth < new_request.tokens.size());
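
The @@ -464 and @@ -478 hunks above hinge on a single phase test: once the
next depth to process is the request's last known token, the request leaves
the prompt (prefill) phase and decodes exactly one token per batch, since
every earlier position is already cached. A hypothetical condensation of
that branch (the budget arithmetic is simplified relative to the diff):

    #include <algorithm>
    #include <cassert>

    // Decide how many tokens of a request go into the next batch.
    // num_known_tokens plays the role of request.tokens.size().
    int tokens_for_next_batch(int first_token_depth_in_request,
                              int num_known_tokens,
                              int batch_budget) {
      if (first_token_depth_in_request + 1 == num_known_tokens) {
        return 1; // incremental (decode) phase
      }
      // prompt phase: load as much of the remaining prompt as fits
      return std::min(batch_budget,
                      num_known_tokens - first_token_depth_in_request);
    }

    int main() {
      assert(tokens_for_next_batch(9, 10, 32) == 1);    // decode one token
      assert(tokens_for_next_batch(4, 100, 32) == 32);  // budget-limited
      assert(tokens_for_next_batch(90, 100, 32) == 10); // prompt tail
      return 0;
    }
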
@@ -922,7 +925,7 @@ BeamSearchBatchConfig
       // zero when beam search has reached required sequence length
       // assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0);
       Request &request = all_requests[old_bc.requestsInfo[i].request_guid];
-      int processed_tokens = old_bc.requestsInfo[i].token_start_offset +
+      int processed_tokens = old_bc.requestsInfo[i].first_token_depth_in_request +
                              old_bc.requestsInfo[i].num_tokens_in_batch;
 
       // assert(processed_tokens < request.tokens.size());
@@ -937,7 +940,8 @@ BeamSearchBatchConfig
       // // old_bc.beamRequestsInfo[i].max_depth);
       // //   // new_bc.request_completed[i] = true;
       // //   new_bc.request_completed[i] = false;
-      // //   new_bc.requestsInfo[i].token_start_offset = processed_tokens;
+      // //   new_bc.requestsInfo[i].first_token_depth_in_request =
+      //   processed_tokens;
       // //   new_bc.requestsInfo[i].request_guid =
       // //   old_bc.requestsInfo[i].request_guid;
       // //   new_bc.requestsInfo[i].max_sequence_length =
@@ -953,7 +957,7 @@ BeamSearchBatchConfig
       log_req_mgr.debug() << "num tokens: " << old_bc.num_tokens << ", "
                           << new_bc.num_tokens;
       new_bc.request_completed[i] = false;
-      new_bc.requestsInfo[i].token_start_offset = processed_tokens;
+      new_bc.requestsInfo[i].first_token_depth_in_request = processed_tokens;
       new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid;
       new_bc.requestsInfo[i].max_sequence_length =
           old_bc.requestsInfo[i].max_sequence_length;
@@ -986,7 +990,8 @@ BeamSearchBatchConfig
       // do the slot exchange to minimize the cache exchange in kernel.
       // update_beam_metadata(new_bc, request.beam_trees.at(old_bc.model_id),
       // i);
-      if (new_bc.requestsInfo[i].token_start_offset >= request.tokens.size()) {
+      if (new_bc.requestsInfo[i].first_token_depth_in_request >=
+          request.tokens.size()) {
         // Incremental phase
         if (request.status == Request::RUNNING) {
           new_bc.requestsInfo[i].num_tokens_in_batch = 1;
@@ -1006,7 +1011,7 @@ BeamSearchBatchConfig
             std::min(get_max_tokens_per_batch() - new_bc.num_tokens -
                          BatchConfig::max_requests_per_batch() + i,
                      (int)request.tokens.size() -
-                         new_bc.requestsInfo[i].token_start_offset);
+                         new_bc.requestsInfo[i].first_token_depth_in_request);
         request.ssm_cache_size += new_bc.requestsInfo[i].num_tokens_in_batch;
         if (verbose) {
           std::cout << "[ Beam Spec] " << request.guid << std::endl;
@@ -1027,7 +1032,7 @@ BeamSearchBatchConfig
 
       // register more tokens due to the beam width
       for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
-        int depth = new_bc.requestsInfo[i].token_start_offset + j;
+        int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
         for (int k = 0; k < new_bc.sub_requests[i]; k++) {
           new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
           new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
@@ -1151,7 +1156,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
       }
 
       // Normal Request Info
-      new_bc.requestsInfo[i].token_start_offset =
+      new_bc.requestsInfo[i].first_token_depth_in_request =
           dfs_tree_inputs.front().second;
       new_bc.requestsInfo[i].request_guid =
           old_batches.at(0).requestsInfo[i].request_guid;
@@ -1204,7 +1209,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
           break;
         }
 
-        new_bc.requestsInfo[i].token_start_offset = request.tokens.size() - 1;
+        new_bc.requestsInfo[i].first_token_depth_in_request =
+            request.tokens.size() - 1;
 
         // Add Tokens from the DFS Tree to the next batch
         for (int j = 1; j < dfs_tree_inputs.size(); j++) {
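
In the @@ -1027 hunk above, each depth is registered once per sub-request,
i.e. one copy per live beam. A hypothetical sketch of that fan-out (names
mirror the diff; per-beam token ids and cache plumbing are omitted):

    #include <cassert>
    #include <vector>

    struct PerTokenInfo {
      int request_index;
      int abs_depth_in_request;
    };

    // Register the same absolute depth once per live beam, the way the
    // beam-search batch duplicates token slots across sub_requests[i].
    void pack_beam_tokens(std::vector<PerTokenInfo> &tokens_info,
                          int request_index,
                          int first_token_depth_in_request,
                          int num_tokens_in_batch,
                          int num_sub_requests) {
      for (int j = 0; j < num_tokens_in_batch; j++) {
        int depth = first_token_depth_in_request + j;
        for (int k = 0; k < num_sub_requests; k++) {
          tokens_info.push_back({request_index, depth});
        }
      }
    }

    int main() {
      std::vector<PerTokenInfo> tokens_info;
      pack_beam_tokens(tokens_info, 0, 12, 2, 3);
      assert(tokens_info.size() == 6); // 2 depths x 3 beams
      return 0;
    }
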
@@ -1257,7 +1263,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
         }
 
         // Normal Request Info
-        new_bc.requestsInfo[i].token_start_offset = request.llm_cache_size;
+        new_bc.requestsInfo[i].first_token_depth_in_request =
+            request.llm_cache_size;
         new_bc.requestsInfo[i].request_guid =
             old_batches.at(0).requestsInfo[i].request_guid;
         new_bc.requestsInfo[i].max_sequence_length =
@@ -1265,9 +1272,10 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
 
         new_bc.request_completed[i] = false;
 
-        new_bc.requestsInfo[i].num_tokens_in_batch = std::min(
-            max_prompt_load_size,
-            (int)request.initial_len - new_bc.requestsInfo[i].token_start_offset);
+        new_bc.requestsInfo[i].num_tokens_in_batch =
+            std::min(max_prompt_load_size,
+                     (int)request.initial_len -
+                         new_bc.requestsInfo[i].first_token_depth_in_request);
         max_prompt_load_size -= new_bc.requestsInfo[i].num_tokens_in_batch;
 
         std::cout << "max_prompt_load_size: " << max_prompt_load_size
@@ -1673,7 +1681,7 @@ std::vector<std::pair<BatchConfig::TokenId, int>>
 std::vector<std::pair<BatchConfig::TokenId, int>>
     RequestManager::traverse_beam_tree(BeamSearchBatchConfig const &old_bc,
                                        int request_index,
-                                       int token_start_offset) {
+                                       int first_token_depth_in_request) {
   if (verbose) {
     std::cout << "[Traverse Beam Tree] request_index: " << request_index
               << "\n";
@@ -1709,7 +1717,7 @@ std::vector<std::pair<BatchConfig::TokenId, int>>
               << serializedTree.size() << "\n";
   }
   for (int k = 0; k < serializedTree.size(); k++) {
-    serializedTree.at(k).second += token_start_offset;
+    serializedTree.at(k).second += first_token_depth_in_request;
     if (verbose) {
       std::cout << "token id: " << serializedTree.at(k).first
                 << ", depth: " << serializedTree.at(k).second << "\n";
diff --git a/src/runtime/tree_verify_batch_config.cc b/src/runtime/tree_verify_batch_config.cc
index cb68ecc5f1..6dbcaceaa4 100644
--- a/src/runtime/tree_verify_batch_config.cc
+++ b/src/runtime/tree_verify_batch_config.cc
@@ -47,8 +47,8 @@ std::ostream &operator<<(std::ostream &os, TreeVerifyBatchConfig const &bc) {
   for (int i = 0; i < bc.max_requests_per_batch(); i++) {
     if (!bc.request_completed[i]) {
       os << "  Request " << i << ":\n";
-      os << "    Token start offset: " << bc.requestsInfo[i].token_start_offset
-         << std::endl;
+      os << "    Token start offset: "
+         << bc.requestsInfo[i].first_token_depth_in_request << std::endl;
       os << "    Number of tokens in batch: "
          << bc.requestsInfo[i].num_tokens_in_batch << std::endl;
       os << "    GUID: " << bc.requestsInfo[i].request_guid << std::endl;
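
The traverse_beam_tree hunks show where the new name pays off most: the
serialized speculation tree stores each token's depth relative to the tree's
own root, and adding first_token_depth_in_request rebases those depths to
absolute positions within the request. A self-contained sketch (TokenId is
aliased locally; rebase_depths is a hypothetical extraction of the loop in
the @@ -1709 hunk):

    #include <cassert>
    #include <cstddef>
    #include <utility>
    #include <vector>

    using TokenId = int; // stands in for BatchConfig::TokenId

    // Shift tree-relative depths to absolute depths within the request.
    void rebase_depths(std::vector<std::pair<TokenId, int>> &serialized_tree,
                       int first_token_depth_in_request) {
      for (std::size_t k = 0; k < serialized_tree.size(); k++) {
        serialized_tree.at(k).second += first_token_depth_in_request;
      }
    }

    int main() {
      // Three speculated tokens at tree-relative depths 0, 1, 1.
      std::vector<std::pair<TokenId, int>> tree = {{42, 0}, {7, 1}, {9, 1}};
      rebase_depths(tree, /*first_token_depth_in_request=*/15);
      assert(tree[0].second == 15 && tree[2].second == 16);
      return 0;
    }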