Attn debug #1527

Closed · wants to merge 1 commit
CMakeLists.txt (2 changes: 1 addition & 1 deletion)
@@ -240,7 +240,7 @@ if(NOT BUILD_LEGION_ONLY)
   list(APPEND FF_CC_FLAGS
     -DFF_DEBUG)
   list(APPEND FF_NVCC_FLAGS
-    -DFF_DEBUG)
+    -DFF_DEBUG -lineinfo)
 endif()

 message(STATUS "FlexFlow MAX_DIM: ${FF_MAX_DIM}")
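Context on the flag (a note, not part of the diff): nvcc's -lineinfo, long form --generate-line-info, embeds source-line tables in the generated device code without disabling optimizations, so tools such as Nsight Compute and compute-sanitizer can attribute device-side errors and hot spots to a .cu file and line. A rough standalone equivalent of what the debug build now passes would be nvcc -DFF_DEBUG -lineinfo -c inc_multihead_self_attention.cu.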
src/c/flexflow_c.cc (2 changes: 1 addition & 1 deletion)
@@ -1730,7 +1730,7 @@ void flexflow_model_generate(flexflow_model_t handle_,
       handle,
       dataset_fp.c_str(),
       max_lengths[i],
-      max_new_tokens[i],
+      max_new_tokens_[i],
       training_steps[i]);
   } else {
     assert(false && "Unknown request type");
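For readers skimming the one-character fix above: the change from max_new_tokens[i] to max_new_tokens_[i] implies two similarly named sequences are in scope at this call site and the wrong one was read. A minimal C++ sketch of that hazard, with hypothetical names and types (the real FlexFlow signature is not shown in this diff):

#include <vector>

// Hypothetical shape of the bug class: a raw parameter and a derived copy
// coexist, so indexing either one compiles but only one is correct.
int token_budget(int const *max_new_tokens, int num_requests, int i) {
  // e.g. a validated/defaulted copy built from the raw argument
  std::vector<int> max_new_tokens_(max_new_tokens,
                                   max_new_tokens + num_requests);
  // reading max_new_tokens[i] here would silently bypass the derived copy;
  // the fix selects max_new_tokens_[i], the intended source
  return max_new_tokens_[i];
}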
src/ops/inc_multihead_self_attention.cu (39 changes: 26 additions & 13 deletions)
@@ -384,7 +384,8 @@ __global__ void compute_attention_kernel_generation_kernel(
     int max_seq_length,
     int per_head_size,
     int hidden_size,
-    BatchConfig::PerRequestInfo *request_infos) {
+    BatchConfig::PerRequestInfo *request_infos,
+    BatchConfig::PerTokenInfo *tokens_infos) {
   // q, k
   using Q_vec = typename VEC_K<DT, THREADS_PER_KEY>::Type;
@@ -409,11 +410,14 @@ __global__ void compute_attention_kernel_generation_kernel(
   int const tidx = threadIdx.x;
   // head id
   int const head_idx = blockIdx.x;
-  // request idx
-  int const request_idx = blockIdx.y;
+  // generation token idx.
+  int const batch_config_token_idx = blockIdx.y;
   int const batch_config_request_id =
-      request_infos[request_idx].batch_config_request_id;
+      tokens_infos[batch_config_token_idx].request_index;
+  if (request_infos[batch_config_request_id].prompt_phase) {
+    return;
+  }
+  assert(request_infos[batch_config_request_id].num_tokens_in_batch == 1);

   int const first_step = 0;
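The hunk above is the core of the change: blockIdx.y used to enumerate requests (assuming one generation token each) and now enumerates tokens, resolving the owning request through the per-token metadata and skipping requests still in the prompt phase. A standalone CUDA sketch of just that mapping, with the metadata structs reduced to the fields the diff touches (assumed, not FlexFlow's full definitions):

#include <cassert>
#include <cstdio>

// Assumed mirrors of the BatchConfig metadata consumed by the kernel.
struct PerTokenInfo   { int request_index; };
struct PerRequestInfo { bool prompt_phase; int num_tokens_in_batch; };

__global__ void generation_mapping_sketch(PerTokenInfo const *tokens_infos,
                                          PerRequestInfo const *request_infos) {
  int const head_idx = blockIdx.x;               // one grid column per head
  int const batch_config_token_idx = blockIdx.y; // one grid row per token
  int const req_id = tokens_infos[batch_config_token_idx].request_index;
  if (request_infos[req_id].prompt_phase) {
    return; // prompt tokens are handled by the prompt-phase path
  }
  assert(request_infos[req_id].num_tokens_in_batch == 1);
  if (threadIdx.x == 0) {
    printf("head %d, token %d -> request %d\n",
           head_idx, batch_config_token_idx, req_id);
  }
}

int main() {
  PerTokenInfo h_tok[3] = {{0}, {1}, {1}};          // tokens 1, 2 -> request 1
  PerRequestInfo h_req[2] = {{false, 1}, {true, 2}}; // request 1: prompt phase
  PerTokenInfo *d_tok;
  PerRequestInfo *d_req;
  cudaMalloc(&d_tok, sizeof(h_tok));
  cudaMalloc(&d_req, sizeof(h_req));
  cudaMemcpy(d_tok, h_tok, sizeof(h_tok), cudaMemcpyHostToDevice);
  cudaMemcpy(d_req, h_req, sizeof(h_req), cudaMemcpyHostToDevice);
  dim3 grid(/*num_q_heads=*/2, /*num_active_tokens=*/3);
  generation_mapping_sketch<<<grid, 32>>>(d_tok, d_req);
  cudaDeviceSynchronize();
  cudaFree(d_tok);
  cudaFree(d_req);
  return 0;
}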
@@ -432,11 +436,13 @@ __global__ void compute_attention_kernel_generation_kernel(
   // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum
   __shared__ float red_smem[WARPS_PER_BLOCK * 2];
-  const DT *q_ptr = query + request_idx * hidden_size * QKV_WEIGHT_NUM +
+  const DT *q_ptr = query +
+                    batch_config_token_idx * hidden_size * QKV_WEIGHT_NUM +
                     head_idx * per_head_size;
   __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD];
   // DT const *q_ptr =
-  //     query + request_idx * Dh * QKV_WEIGHT_NUM + head_idx * per_head_size;
+  //     query + batch_config_token_idx * Dh * QKV_WEIGHT_NUM + head_idx *
+  //     per_head_size;

   // q tensor in this thread
   // if THREADS_PER_KEY is 4, first thread load 0, 4, 8, 12..., total
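The rewritten q_ptr above only swaps the batch index from request to token; the layout it walks stays the same. A hedged pointer-arithmetic sketch of that offset, with the packing (QKV_WEIGHT_NUM projections of hidden_size values per token, Q first) inferred from the diff rather than confirmed upstream:

#include <cstddef>

constexpr std::size_t QKV_WEIGHT_NUM = 3; // Q, K, V projections per token

// Offset of head `head_idx`'s query slice for token `token_idx` inside the
// packed devQKVProjArray buffer (layout assumed from the diff's arithmetic).
constexpr std::size_t q_offset(std::size_t token_idx, std::size_t head_idx,
                               std::size_t hidden_size,
                               std::size_t per_head_size) {
  return token_idx * hidden_size * QKV_WEIGHT_NUM // skip earlier tokens' QKV
         + head_idx * per_head_size;              // this head's slice of Q
}

// Spot check: token 2, head 1, hidden_size 8, per_head_size 4.
static_assert(q_offset(2, 1, 8, 4) == 2 * 8 * 3 + 1 * 4, "layout sanity");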
@@ -615,10 +621,10 @@ __global__ void compute_attention_kernel_generation_kernel(
   // Output the final values.
   if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
-    convert_from_float(
-        *reinterpret_cast<V_vec *>(output_ptr + request_idx * hidden_size +
-                                   head_idx * per_head_size + vi),
-        out);
+    convert_from_float(*reinterpret_cast<V_vec *>(
+                           output_ptr + batch_config_token_idx * hidden_size +
+                           head_idx * per_head_size + vi),
+                       out);
   }
 }
@@ -897,14 +903,15 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m,
       BatchConfig::max_sequence_length(), \
       m->qProjSize, \
       m->hidden_size, \
-      m->request_infos)
+      m->request_infos, \
+      m->token_infos)

 template <typename DT>
 void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m,
                                          BatchConfig const *bc,
                                          DT *output_ptr,
                                          cudaStream_t stream) {
-  dim3 grid(m->num_q_heads, bc->num_generation_tokens);
+  dim3 grid(m->num_q_heads, bc->num_active_tokens());
   int const per_head_size = m->qProjSize;
   float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f;
   size_t smem_sz;
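A note on the launch change above: sizing grid.y by bc->num_active_tokens() instead of bc->num_generation_tokens launches a block column for every active token, including prompt-phase ones, which immediately retire through the prompt_phase early-return added earlier in this diff. The cost is a few idle blocks; the benefit is that blockIdx.y stays aligned with the tokens_infos array, so no host-side compaction of generation tokens is required, which is plausibly the indexing mismatch this debug branch is isolating.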
@@ -958,8 +965,14 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m,
   // phase 1: Implement kernel to apply rotary embedding and scaling
   compute_qkv_kernel(
       m, bc, shard_id, static_cast<DT *>(m->devQKVProjArray), stream);
+  // key/value cache shape (keyed by request idx in batch config):
+  // [proj_size * num_heads, max_seq_length, max_requests]
   update_kv_cache_kernel<DT>(m, bc, stream);
+  // QK prods shape (single request only): [num_new_tokens, total_tokens,
+  // num_heads] Attn heads shape (single request only): [vProjSize, num_heads,
+  // num_new_tokens]

   if (bc->num_generation_tokens > 0) {
     // phase 3: Compute attention score for generation tokens
     compute_attention_kernel_generation<DT>(