Skip to content

Commit

Permalink
Add first_token_offset_in_batch to indicate the offset of the request's first token in a `BatchConfig` (flexflow#1197)
Browse files Browse the repository at this point in the history

* Add first_token_offset_in_batch to indicate the offset of the request's first token in a BatchConfig

* format
  • Loading branch information
jiazhihao authored Oct 18, 2023
1 parent 4c06a09 commit fb0b21c
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 4 deletions.
1 change: 1 addition & 0 deletions include/flexflow/batch_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class BatchConfig {

struct PerRequestInfo {
int first_token_depth_in_request;
int first_token_offset_in_batch;
int num_tokens_in_batch;
int max_sequence_length;
RequestGuid request_guid;
Expand Down
2 changes: 2 additions & 0 deletions src/ops/inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,8 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
if (bc->request_completed[i]) {
continue;
}
assert(tokens_previous_requests ==
bc->requestsInfo[i].first_token_offset_in_batch);
int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
bc->requestsInfo[i].num_tokens_in_batch;
Expand Down
3 changes: 2 additions & 1 deletion src/ops/spec_inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
if (bc->request_completed[i]) {
continue;
}
assert(tokens_previous_requests ==
bc->requestsInfo[i].first_token_offset_in_batch);
for (int sub_req_id = 0; sub_req_id < bc->sub_requests[i]; sub_req_id++) {
// int num_new_tokens = bc->num_processing_tokens[i];
Expand Down
2 changes: 2 additions & 0 deletions src/ops/tree_inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m,
if (bc->request_completed[i]) {
continue;
}
assert(processed_tokens_in_batch ==
bc->requestsInfo[i].first_token_offset_in_batch);
int last_token_idx_of_the_request =
processed_tokens_in_batch + bc->requestsInfo[i].num_tokens_in_batch - 1;
while (processed_tokens_in_batch <= last_token_idx_of_the_request) {
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/batch_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ using Legion::Memory;
BatchConfig::BatchConfig() : num_tokens(0) {
for (int i = 0; i < MAX_NUM_REQUESTS; i++) {
requestsInfo[i].first_token_depth_in_request = 0;
requestsInfo[i].first_token_offset_in_batch = 0;
requestsInfo[i].num_tokens_in_batch = 0;
request_completed[i] = true;
}
Expand Down Expand Up @@ -104,8 +105,10 @@ std::ostream &operator<<(std::ostream &os, BatchConfig const &bc) {
for (int i = 0; i < bc.max_requests_per_batch(); i++) {
if (!bc.request_completed[i]) {
os << " Request " << i << ":\n";
os << " Token start offset: "
os << " First token depth in request: "
<< bc.requestsInfo[i].first_token_depth_in_request << std::endl;
os << " First token offset in batch: "
<< bc.requestsInfo[i].first_token_offset_in_batch << std::endl;
os << " Number of tokens in batch: "
<< bc.requestsInfo[i].num_tokens_in_batch << std::endl;
os << " GUID: " << bc.requestsInfo[i].request_guid << std::endl;
Expand Down
4 changes: 3 additions & 1 deletion src/runtime/beam_search_batch_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,10 @@ std::ostream &operator<<(std::ostream &os, BeamSearchBatchConfig const &bc) {
for (int i = 0; i < bc.max_requests_per_batch(); i++) {
if (!bc.request_completed[i]) {
os << " Request " << i << ":\n";
os << " Token start offset: "
os << " First token depth in request: "
<< bc.requestsInfo[i].first_token_depth_in_request << std::endl;
os << " First token offset in batch: "
<< bc.requestsInfo[i].first_token_offset_in_batch << std::endl;
os << " Number of tokens in batch: "
<< bc.requestsInfo[i].num_tokens_in_batch << std::endl;
os << " GUID: " << bc.requestsInfo[i].request_guid << std::endl;
Expand Down
8 changes: 8 additions & 0 deletions src/runtime/request_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
pending_request_queue.pop();
// all_requests[new_request.guid] = new_request;
new_bc.requestsInfo[i].first_token_depth_in_request = 0;
new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens;
new_bc.requestsInfo[i].request_guid = new_request.guid;
new_bc.requestsInfo[i].num_tokens_in_batch =
std::min(get_max_tokens_per_batch() - new_bc.num_tokens -
Expand Down Expand Up @@ -466,6 +467,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
} else {
new_bc.request_completed[i] = false;
new_bc.requestsInfo[i].first_token_depth_in_request = processed_tokens;
new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens;
new_bc.requestsInfo[i].request_guid =
old_bc.requestsInfo[i].request_guid;
new_bc.requestsInfo[i].max_sequence_length =
Expand Down Expand Up @@ -688,6 +690,7 @@ BeamSearchBatchConfig
// Normal Request Info
new_bc.requestsInfo[i].first_token_depth_in_request =
verified_tokens.front().second;
new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens;
new_bc.requestsInfo[i].request_guid =
old_bc.requestsInfo[i].request_guid;
new_bc.requestsInfo[i].max_sequence_length =
Expand Down Expand Up @@ -746,6 +749,7 @@ BeamSearchBatchConfig
// Normal Request Info
new_bc.requestsInfo[i].first_token_depth_in_request =
request.ssm_cache_size;
new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens;
new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid;
new_bc.requestsInfo[i].max_sequence_length =
old_bc.requestsInfo[i].max_sequence_length;
Expand Down Expand Up @@ -780,6 +784,7 @@ BeamSearchBatchConfig
pending_request_queue.pop();
// all_requests[new_request.guid] = new_request;
new_bc.requestsInfo[i].first_token_depth_in_request = 0;
new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens;
new_bc.requestsInfo[i].request_guid = new_request.guid;
new_bc.requestsInfo[i].num_tokens_in_batch =
std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
Expand Down Expand Up @@ -958,6 +963,7 @@ BeamSearchBatchConfig
<< new_bc.num_tokens;
new_bc.request_completed[i] = false;
new_bc.requestsInfo[i].first_token_depth_in_request = processed_tokens;
new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens;
new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid;
new_bc.requestsInfo[i].max_sequence_length =
old_bc.requestsInfo[i].max_sequence_length;
Expand Down Expand Up @@ -1158,6 +1164,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
// Normal Request Info
new_bc.requestsInfo[i].first_token_depth_in_request =
dfs_tree_inputs.front().second;
new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens;
new_bc.requestsInfo[i].request_guid =
old_batches.at(0).requestsInfo[i].request_guid;
new_bc.requestsInfo[i].max_sequence_length =
Expand Down Expand Up @@ -1265,6 +1272,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
// Normal Request Info
new_bc.requestsInfo[i].first_token_depth_in_request =
request.llm_cache_size;
new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens;
new_bc.requestsInfo[i].request_guid =
old_batches.at(0).requestsInfo[i].request_guid;
new_bc.requestsInfo[i].max_sequence_length =
Expand Down
4 changes: 3 additions & 1 deletion src/runtime/tree_verify_batch_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ std::ostream &operator<<(std::ostream &os, TreeVerifyBatchConfig const &bc) {
for (int i = 0; i < bc.max_requests_per_batch(); i++) {
if (!bc.request_completed[i]) {
os << " Request " << i << ":\n";
os << " Token start offset: "
os << " First token depth in request: "
<< bc.requestsInfo[i].first_token_depth_in_request << std::endl;
os << " First token offset in batch: "
<< bc.requestsInfo[i].first_token_offset_in_batch << std::endl;
os << " Number of tokens in batch: "
<< bc.requestsInfo[i].num_tokens_in_batch << std::endl;
os << " GUID: " << bc.requestsInfo[i].request_guid << std::endl;
Expand Down

0 comments on commit fb0b21c

Please sign in to comment.