diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4e24e1e54b..fa9ab75be2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -240,7 +240,7 @@ if(NOT BUILD_LEGION_ONLY)
list(APPEND FF_CC_FLAGS
-DFF_DEBUG)
list(APPEND FF_NVCC_FLAGS
- -DFF_DEBUG)
+ -DFF_DEBUG -lineinfo)
endif()
message(STATUS "FlexFlow MAX_DIM: ${FF_MAX_DIM}")
diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc
index bfa60a6d54..980ecd356b 100644
--- a/src/c/flexflow_c.cc
+++ b/src/c/flexflow_c.cc
@@ -1730,7 +1730,7 @@ void flexflow_model_generate(flexflow_model_t handle_,
handle,
dataset_fp.c_str(),
max_lengths[i],
- max_new_tokens[i],
+ max_new_tokens_[i],
training_steps[i]);
} else {
assert(false && "Unknown request type");
diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu
index 454926bcdb..0081625801 100644
--- a/src/ops/inc_multihead_self_attention.cu
+++ b/src/ops/inc_multihead_self_attention.cu
@@ -384,7 +384,8 @@ __global__ void compute_attention_kernel_generation_kernel(
int max_seq_length,
int per_head_size,
int hidden_size,
- BatchConfig::PerRequestInfo *request_infos) {
+ BatchConfig::PerRequestInfo *request_infos,
+ BatchConfig::PerTokenInfo *tokens_infos) {
// q, k
   using Q_vec = typename VEC_K<DT, THREADS_PER_KEY>::Type;
@@ -409,11 +410,14 @@ __global__ void compute_attention_kernel_generation_kernel(
int const tidx = threadIdx.x;
// head id
int const head_idx = blockIdx.x;
- // request idx
- int const request_idx = blockIdx.y;
-
+ // generation token idx.
+ int const batch_config_token_idx = blockIdx.y;
int const batch_config_request_id =
- request_infos[request_idx].batch_config_request_id;
+ tokens_infos[batch_config_token_idx].request_index;
+ if (request_infos[batch_config_request_id].prompt_phase) {
+ return;
+ }
+ assert(request_infos[batch_config_request_id].num_tokens_in_batch == 1);
int const first_step = 0;
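The hunk above changes the meaning of blockIdx.y from a request index to a token index, resolving the owning request through the per-token table. A minimal standalone sketch of that mapping (the struct fields mirror the PerTokenInfo/PerRequestInfo members referenced in the hunk; everything else is illustrative):

#include <cassert>

// Illustrative stand-ins for BatchConfig::PerTokenInfo / PerRequestInfo;
// only the fields used in the hunk are reproduced.
struct TokenInfoSketch   { int request_index; };
struct RequestInfoSketch { bool prompt_phase; int num_tokens_in_batch; };

__global__ void token_indexed_mapping_sketch(
    TokenInfoSketch const *tokens_infos,
    RequestInfoSketch const *request_infos) {
  // blockIdx.y now walks tokens in the batch config, not requests
  int const batch_config_token_idx = blockIdx.y;
  // resolve the owning request through the per-token table
  int const batch_config_request_id =
      tokens_infos[batch_config_token_idx].request_index;
  // prompt-phase tokens are handled by the prompt kernel; bail out here
  if (request_infos[batch_config_request_id].prompt_phase) {
    return;
  }
  // a decode-phase request contributes exactly one token to the batch
  assert(request_infos[batch_config_request_id].num_tokens_in_batch == 1);
}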
@@ -432,11 +436,13 @@ __global__ void compute_attention_kernel_generation_kernel(
// first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
- const DT *q_ptr = query + request_idx * hidden_size * QKV_WEIGHT_NUM +
+ const DT *q_ptr = query +
+ batch_config_token_idx * hidden_size * QKV_WEIGHT_NUM +
head_idx * per_head_size;
__shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD];
// DT const *q_ptr =
- // query + request_idx * Dh * QKV_WEIGHT_NUM + head_idx * per_head_size;
+ // query + batch_config_token_idx * Dh * QKV_WEIGHT_NUM + head_idx *
+ // per_head_size;
// q tensor in this thread
// if THREADS_PER_KEY is 4, first thread load 0, 4, 8, 12..., total
@@ -615,10 +621,10 @@ __global__ void compute_attention_kernel_generation_kernel(
// Output the final values.
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
-    convert_from_float(
-        *reinterpret_cast<V_vec *>(output_ptr + request_idx * hidden_size +
-                                   head_idx * per_head_size + vi),
-        out);
+    convert_from_float(*reinterpret_cast<V_vec *>(
+                           output_ptr + batch_config_token_idx * hidden_size +
+                           head_idx * per_head_size + vi),
+                       out);
}
}
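Correspondingly, the query and output pointers are now offset by the token index rather than the request index. A sketch of that addressing under the same assumptions (QKV_WEIGHT_NUM, hidden_size, and per_head_size follow the names in the kernel; the float element type is illustrative):

// Token-indexed addressing sketch: each token owns a
// [hidden_size * QKV_WEIGHT_NUM] slab of the fused QKV projection, with
// heads laid out as contiguous per_head_size chunks inside it.
__device__ float const *q_ptr_sketch(float const *query,
                                     int batch_config_token_idx,
                                     int head_idx,
                                     int per_head_size,
                                     int hidden_size,
                                     int qkv_weight_num) {
  return query + batch_config_token_idx * hidden_size * qkv_weight_num +
         head_idx * per_head_size;
}

// Output rows are packed per token: one [hidden_size] row per generation
// token, so the final values land at the token's slot, not the request's.
__device__ float *output_ptr_sketch(float *output_ptr,
                                    int batch_config_token_idx,
                                    int head_idx,
                                    int per_head_size,
                                    int hidden_size,
                                    int vi) {
  return output_ptr + batch_config_token_idx * hidden_size +
         head_idx * per_head_size + vi;
}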
@@ -897,14 +903,15 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig::max_sequence_length(), \
m->qProjSize, \
m->hidden_size, \
- m->request_infos)
+ m->request_infos, \
+ m->token_infos)
template <typename DT>
void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
DT *output_ptr,
cudaStream_t stream) {
- dim3 grid(m->num_q_heads, bc->num_generation_tokens);
+ dim3 grid(m->num_q_heads, bc->num_active_tokens());
int const per_head_size = m->qProjSize;
float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f;
size_t smem_sz;
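On the host side, the grid's y dimension grows from the number of generation tokens to all active tokens, and prompt-phase blocks exclude themselves via the in-kernel guard shown earlier. A hedged sketch of the launch shape, reusing the types and kernel from the mapping sketch above (block size and names are illustrative):

// Launch-shape sketch only: one block per (q head, active token). Blocks
// whose token belongs to a prompt-phase request return on entry, trading a
// few idle blocks for a grid that needs no compacted generation-token list.
void launch_generation_attention_sketch(int num_q_heads,
                                        int num_active_tokens,
                                        size_t smem_sz,
                                        cudaStream_t stream,
                                        TokenInfoSketch const *tokens_infos,
                                        RequestInfoSketch const *request_infos) {
  dim3 grid(num_q_heads, num_active_tokens);
  token_indexed_mapping_sketch<<<grid, 128, smem_sz, stream>>>(tokens_infos,
                                                               request_infos);
}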
@@ -958,8 +965,14 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m,
// phase 1: Implement kernel to apply rotary embedding and scaling
compute_qkv_kernel(
       m, bc, shard_id, static_cast<DT *>(m->devQKVProjArray), stream);
+ // key/value cache shape (keyed by request idx in batch config):
+ // [proj_size * num_heads, max_seq_length, max_requests]
update_kv_cache_kernel(m, bc, stream);
+  // QK prods shape (single request only):
+  //   [num_new_tokens, total_tokens, num_heads]
+  // Attn heads shape (single request only):
+  //   [vProjSize, num_heads, num_new_tokens]
+
if (bc->num_generation_tokens > 0) {
// phase 3: Compute attention score for generation tokens
     compute_attention_kernel_generation<DT>(
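The comment added above documents the key/value cache layout as [proj_size * num_heads, max_seq_length, max_requests], keyed by the request's slot in the batch config. Assuming the first-listed dimension varies fastest, a sketch of the corresponding offset arithmetic (all parameter names are illustrative):

// Offset sketch for the documented KV-cache layout: one
// [proj_size * num_heads] token entry, max_seq_length entries per request
// slot, max_requests slots in total, innermost dimension contiguous.
__device__ size_t kv_cache_offset_sketch(int request_slot, // batch-config slot
                                         int step,         // position in sequence
                                         int head_idx,
                                         int dim,          // index within proj_size
                                         int proj_size,
                                         int num_heads,
                                         int max_seq_length) {
  size_t const elems_per_token = (size_t)proj_size * num_heads;
  return ((size_t)request_slot * max_seq_length + step) * elems_per_token +
         (size_t)head_idx * proj_size + dim;
}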