Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-goliaro committed Oct 20, 2024
1 parent 89f10f4 commit 8443667
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ if(NOT BUILD_LEGION_ONLY)
list(APPEND FF_CC_FLAGS
-DFF_DEBUG)
list(APPEND FF_NVCC_FLAGS
-DFF_DEBUG)
-DFF_DEBUG -lineinfo)
endif()

message(STATUS "FlexFlow MAX_DIM: ${FF_MAX_DIM}")
Expand Down
2 changes: 1 addition & 1 deletion src/c/flexflow_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1730,7 +1730,7 @@ void flexflow_model_generate(flexflow_model_t handle_,
handle,
dataset_fp.c_str(),
max_lengths[i],
max_new_tokens[i],
max_new_tokens_[i],
training_steps[i]);
} else {
assert(false && "Unknown request type");
Expand Down
39 changes: 26 additions & 13 deletions src/ops/inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,8 @@ __global__ void compute_attention_kernel_generation_kernel(
int max_seq_length,
int per_head_size,
int hidden_size,
BatchConfig::PerRequestInfo *request_infos) {
BatchConfig::PerRequestInfo *request_infos,
BatchConfig::PerTokenInfo *tokens_infos) {
// q, k
using Q_vec = typename VEC_K<DT, THREADS_PER_KEY>::Type;
Expand All @@ -409,11 +410,14 @@ __global__ void compute_attention_kernel_generation_kernel(
int const tidx = threadIdx.x;
// head id
int const head_idx = blockIdx.x;
// request idx
int const request_idx = blockIdx.y;
// generation token idx.
int const batch_config_token_idx = blockIdx.y;
int const batch_config_request_id =
request_infos[request_idx].batch_config_request_id;
tokens_infos[batch_config_token_idx].request_index;
if (request_infos[batch_config_request_id].prompt_phase) {
return;
}
assert(request_infos[batch_config_request_id].num_tokens_in_batch == 1);
int const first_step = 0;
Expand All @@ -432,11 +436,13 @@ __global__ void compute_attention_kernel_generation_kernel(
// first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
const DT *q_ptr = query + request_idx * hidden_size * QKV_WEIGHT_NUM +
const DT *q_ptr = query +
batch_config_token_idx * hidden_size * QKV_WEIGHT_NUM +
head_idx * per_head_size;
__shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD];
// DT const *q_ptr =
// query + request_idx * Dh * QKV_WEIGHT_NUM + head_idx * per_head_size;
// query + batch_config_token_idx * Dh * QKV_WEIGHT_NUM + head_idx *
// per_head_size;
// q tensor in this thread
// if THREADS_PER_KEY is 4, first thread load 0, 4, 8, 12..., total
Expand Down Expand Up @@ -615,10 +621,10 @@ __global__ void compute_attention_kernel_generation_kernel(
// Output the final values.
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
convert_from_float(
*reinterpret_cast<V_vec *>(output_ptr + request_idx * hidden_size +
head_idx * per_head_size + vi),
out);
convert_from_float(*reinterpret_cast<V_vec *>(
output_ptr + batch_config_token_idx * hidden_size +
head_idx * per_head_size + vi),
out);
}
}
Expand Down Expand Up @@ -897,14 +903,15 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig::max_sequence_length(), \
m->qProjSize, \
m->hidden_size, \
m->request_infos)
m->request_infos, \
m->token_infos)
template <typename DT>
void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
DT *output_ptr,
cudaStream_t stream) {
dim3 grid(m->num_q_heads, bc->num_generation_tokens);
dim3 grid(m->num_q_heads, bc->num_active_tokens());
int const per_head_size = m->qProjSize;
float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f;
size_t smem_sz;
Expand Down Expand Up @@ -958,8 +965,14 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m,
// phase 1: Implement kernel to apply rotary embedding and scaling
compute_qkv_kernel(
m, bc, shard_id, static_cast<DT *>(m->devQKVProjArray), stream);
// key/value cache shape (keyed by request idx in batch config):
// [proj_size * num_heads, max_seq_length, max_requests]
update_kv_cache_kernel<DT>(m, bc, stream);
// QK prods shape (single request only): [num_new_tokens, total_tokens,
// num_heads] Attn heads shape (single request only): [vProjSize, num_heads,
// num_new_tokens]
if (bc->num_generation_tokens > 0) {
// phase 3: Compute attention score for generation tokens
compute_attention_kernel_generation<DT>(
Expand Down

0 comments on commit 8443667

Please sign in to comment.