From 9093024ee7f5c3036328e4db96619e767112c5f0 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Sun, 20 Oct 2024 21:32:02 +0000 Subject: [PATCH] fix --- CMakeLists.txt | 2 +- src/c/flexflow_c.cc | 2 +- src/ops/inc_multihead_self_attention.cu | 39 ++++++++++++++++--------- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e24e1e54b..fa9ab75be2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -240,7 +240,7 @@ if(NOT BUILD_LEGION_ONLY) list(APPEND FF_CC_FLAGS -DFF_DEBUG) list(APPEND FF_NVCC_FLAGS - -DFF_DEBUG) + -DFF_DEBUG -lineinfo) endif() message(STATUS "FlexFlow MAX_DIM: ${FF_MAX_DIM}") diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index bfa60a6d54..980ecd356b 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1730,7 +1730,7 @@ void flexflow_model_generate(flexflow_model_t handle_, handle, dataset_fp.c_str(), max_lengths[i], - max_new_tokens[i], + max_new_tokens_[i], training_steps[i]); } else { assert(false && "Unknown request type"); diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 454926bcdb..0081625801 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -384,7 +384,8 @@ __global__ void compute_attention_kernel_generation_kernel( int max_seq_length, int per_head_size, int hidden_size, - BatchConfig::PerRequestInfo *request_infos) { + BatchConfig::PerRequestInfo *request_infos, + BatchConfig::PerTokenInfo *tokens_infos) { // q, k using Q_vec = typename VEC_K::Type; @@ -409,11 +410,14 @@ __global__ void compute_attention_kernel_generation_kernel( int const tidx = threadIdx.x; // head id int const head_idx = blockIdx.x; - // request idx - int const request_idx = blockIdx.y; - + // generation token idx. + int const batch_config_token_idx = blockIdx.y; int const batch_config_request_id = - request_infos[request_idx].batch_config_request_id; + tokens_infos[batch_config_token_idx].request_index; + if (request_infos[batch_config_request_id].prompt_phase) { + return; + } + assert(request_infos[batch_config_request_id].num_tokens_in_batch == 1); int const first_step = 0; @@ -432,11 +436,13 @@ __global__ void compute_attention_kernel_generation_kernel( // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - const DT *q_ptr = query + request_idx * hidden_size * QKV_WEIGHT_NUM + + const DT *q_ptr = query + + batch_config_token_idx * hidden_size * QKV_WEIGHT_NUM + head_idx * per_head_size; __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; // DT const *q_ptr = - // query + request_idx * Dh * QKV_WEIGHT_NUM + head_idx * per_head_size; + // query + batch_config_token_idx * Dh * QKV_WEIGHT_NUM + head_idx * + // per_head_size; // q tensor in this thread // if THREADS_PER_KEY is 4, first thread load 0, 4, 8, 12..., total @@ -615,10 +621,10 @@ __global__ void compute_attention_kernel_generation_kernel( // Output the final values. if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { - convert_from_float( - *reinterpret_cast(output_ptr + request_idx * hidden_size + - head_idx * per_head_size + vi), - out); + convert_from_float(*reinterpret_cast( + output_ptr + batch_config_token_idx * hidden_size + + head_idx * per_head_size + vi), + out); } } @@ -897,14 +903,15 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m, BatchConfig::max_sequence_length(), \ m->qProjSize, \ m->hidden_size, \ - m->request_infos) + m->request_infos, \ + m->token_infos) template void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, DT *output_ptr, cudaStream_t stream) { - dim3 grid(m->num_q_heads, bc->num_generation_tokens); + dim3 grid(m->num_q_heads, bc->num_active_tokens()); int const per_head_size = m->qProjSize; float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; size_t smem_sz; @@ -958,8 +965,14 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, // phase 1: Implement kernel to apply rotary embedding and scaling compute_qkv_kernel( m, bc, shard_id, static_cast
(m->devQKVProjArray), stream); + // key/value cache shape (keyed by request idx in batch config): + // [proj_size * num_heads, max_seq_length, max_requests] update_kv_cache_kernel
(m, bc, stream); + // QK prods shape (single request only): [num_new_tokens, total_tokens, + // num_heads] Attn heads shape (single request only): [vProjSize, num_heads, + // num_new_tokens] + if (bc->num_generation_tokens > 0) { // phase 3: Compute attention score for generation tokens compute_attention_kernel_generation
(