diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index 25bc206bf9..d625985552 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -62,6 +62,7 @@ class BatchConfig { struct PerRequestInfo { int first_token_depth_in_request; + int first_token_offset_in_batch; int num_tokens_in_batch; int max_sequence_length; RequestGuid request_guid; diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 6ec077c328..ced1459b59 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -530,6 +530,8 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m, if (bc->request_completed[i]) { continue; } + assert(tokens_previous_requests == + bc->requestsInfo[i].first_token_offset_in_batch); int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + bc->requestsInfo[i].num_tokens_in_batch; diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 8b89acf3b7..fddbd252b6 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -241,7 +241,8 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, if (bc->request_completed[i]) { continue; } - + assert(tokens_previous_requests == + bc->requestsInfo[i].first_token_offset_in_batch); for (int sub_req_id = 0; sub_req_id < bc->sub_requests[i]; sub_req_id++) { // int num_new_tokens = bc->num_processing_tokens[i]; diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 0da432b732..98a9c6557a 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -181,6 +181,8 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, if (bc->request_completed[i]) { continue; } + assert(processed_tokens_in_batch == + bc->requestsInfo[i].first_token_offset_in_batch); int last_token_idx_of_the_request = processed_tokens_in_batch + bc->requestsInfo[i].num_tokens_in_batch - 1; while (processed_tokens_in_batch <= last_token_idx_of_the_request) { diff --git a/src/runtime/batch_config.cc b/src/runtime/batch_config.cc index 4781f09cab..d2fbc0883f 100644 --- a/src/runtime/batch_config.cc +++ b/src/runtime/batch_config.cc @@ -28,6 +28,7 @@ using Legion::Memory; BatchConfig::BatchConfig() : num_tokens(0) { for (int i = 0; i < MAX_NUM_REQUESTS; i++) { requestsInfo[i].first_token_depth_in_request = 0; + requestsInfo[i].first_token_offset_in_batch = 0; requestsInfo[i].num_tokens_in_batch = 0; request_completed[i] = true; } @@ -104,8 +105,10 @@ std::ostream &operator<<(std::ostream &os, BatchConfig const &bc) { for (int i = 0; i < bc.max_requests_per_batch(); i++) { if (!bc.request_completed[i]) { os << " Request " << i << ":\n"; - os << " Token start offset: " + os << " First token depth in request: " << bc.requestsInfo[i].first_token_depth_in_request << std::endl; + os << " First token offset in batch: " + << bc.requestsInfo[i].first_token_offset_in_batch << std::endl; os << " Number of tokens in batch: " << bc.requestsInfo[i].num_tokens_in_batch << std::endl; os << " GUID: " << bc.requestsInfo[i].request_guid << std::endl; diff --git a/src/runtime/beam_search_batch_config.cc b/src/runtime/beam_search_batch_config.cc index f785dc5b74..74843e9460 100644 --- a/src/runtime/beam_search_batch_config.cc +++ b/src/runtime/beam_search_batch_config.cc @@ -126,8 +126,10 @@ std::ostream &operator<<(std::ostream &os, BeamSearchBatchConfig const &bc) { for (int i = 0; i < bc.max_requests_per_batch(); i++) { if (!bc.request_completed[i]) { os << " Request " << i << ":\n"; - os << " Token start offset: " + os << " First token depth in request: " << bc.requestsInfo[i].first_token_depth_in_request << std::endl; + os << " First token offset in batch: " + << bc.requestsInfo[i].first_token_offset_in_batch << std::endl; os << " Number of tokens in batch: " << bc.requestsInfo[i].num_tokens_in_batch << std::endl; os << " GUID: " << bc.requestsInfo[i].request_guid << std::endl; diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 1c5a6ae5da..4d232b6d44 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -368,6 +368,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, pending_request_queue.pop(); // all_requests[new_request.guid] = new_request; new_bc.requestsInfo[i].first_token_depth_in_request = 0; + new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = new_request.guid; new_bc.requestsInfo[i].num_tokens_in_batch = std::min(get_max_tokens_per_batch() - new_bc.num_tokens - @@ -466,6 +467,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, } else { new_bc.request_completed[i] = false; new_bc.requestsInfo[i].first_token_depth_in_request = processed_tokens; + new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = @@ -688,6 +690,7 @@ BeamSearchBatchConfig // Normal Request Info new_bc.requestsInfo[i].first_token_depth_in_request = verified_tokens.front().second; + new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = @@ -746,6 +749,7 @@ BeamSearchBatchConfig // Normal Request Info new_bc.requestsInfo[i].first_token_depth_in_request = request.ssm_cache_size; + new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; @@ -780,6 +784,7 @@ BeamSearchBatchConfig pending_request_queue.pop(); // all_requests[new_request.guid] = new_request; new_bc.requestsInfo[i].first_token_depth_in_request = 0; + new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = new_request.guid; new_bc.requestsInfo[i].num_tokens_in_batch = std::min(get_max_tokens_per_batch() - new_bc.num_tokens, @@ -958,6 +963,7 @@ BeamSearchBatchConfig << new_bc.num_tokens; new_bc.request_completed[i] = false; new_bc.requestsInfo[i].first_token_depth_in_request = processed_tokens; + new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; @@ -1158,6 +1164,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( // Normal Request Info new_bc.requestsInfo[i].first_token_depth_in_request = dfs_tree_inputs.front().second; + new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = old_batches.at(0).requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = @@ -1265,6 +1272,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( // Normal Request Info new_bc.requestsInfo[i].first_token_depth_in_request = request.llm_cache_size; + new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = old_batches.at(0).requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = diff --git a/src/runtime/tree_verify_batch_config.cc b/src/runtime/tree_verify_batch_config.cc index 6dbcaceaa4..841c735f59 100644 --- a/src/runtime/tree_verify_batch_config.cc +++ b/src/runtime/tree_verify_batch_config.cc @@ -47,8 +47,10 @@ std::ostream &operator<<(std::ostream &os, TreeVerifyBatchConfig const &bc) { for (int i = 0; i < bc.max_requests_per_batch(); i++) { if (!bc.request_completed[i]) { os << " Request " << i << ":\n"; - os << " Token start offset: " + os << " First token depth in request: " << bc.requestsInfo[i].first_token_depth_in_request << std::endl; + os << " First token offset in batch: " + << bc.requestsInfo[i].first_token_offset_in_batch << std::endl; os << " Number of tokens in batch: " << bc.requestsInfo[i].num_tokens_in_batch << std::endl; os << " GUID: " << bc.requestsInfo[i].request_guid << std::endl;