Commit

cleanup&hip
xinhaoc committed Oct 14, 2023
1 parent 23f5891 commit e31249a
Showing 7 changed files with 381 additions and 669 deletions.
3 changes: 2 additions & 1 deletion include/flexflow/ops/inc_multihead_self_attention.h
@@ -176,7 +176,8 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta {
   size_t weights_params, weightSize, biasSize, reserveSpaceSize,
       quantized_weightSize;
   int qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize;
-  int global_num_q_heads, global_num_kv_heads, num_q_heads, num_kv_heads, hidden_size;
+  int global_num_q_heads, global_num_kv_heads, num_q_heads, num_kv_heads,
+      hidden_size;
   bool *has_load_weights;
   bool *apply_rotary_embedding;
   bool *qkv_bias;
438 changes: 153 additions & 285 deletions src/ops/inc_multihead_self_attention.cpp

Large diffs are not rendered by default.

49 changes: 26 additions & 23 deletions src/ops/inc_multihead_self_attention.cu
@@ -86,7 +86,7 @@ __global__ void apply_proj_bias_qkv(DT *input_ptr,
   // int qkv_index = i / (num_tokens * qProjSize) % 3;

   int token_idx = i / (hidden_size * QKV_WEIGHT_NUM);
-  size_t in_token_idx = i - token_idx * hidden_size * 3;
+  size_t in_token_idx = i - token_idx * hidden_size * QKV_WEIGHT_NUM;

   int qkv_index = in_token_idx / hidden_size;

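The kernel above decodes a flat index i against the packed per-token [Q | K | V] activation buffer; this hunk just swaps the hard-coded 3 for the QKV_WEIGHT_NUM constant. A minimal host-side sketch of the same arithmetic, with illustrative sizes that are not taken from the commit:

// Sketch: decode a flat index into (token, Q/K/V, head) for the packed
// QKV layout used above. All sizes below are made-up assumptions.
#include <cstdio>

int main() {
  const int QKV_WEIGHT_NUM = 3;                    // Q, K, V packed per token
  const int proj_size = 4;                         // per-head projection size
  const int num_q_heads = 2;
  const int hidden_size = proj_size * num_q_heads; // 8
  const int num_tokens = 2;
  for (int i = 0; i < num_tokens * hidden_size * QKV_WEIGHT_NUM; ++i) {
    int token_idx = i / (hidden_size * QKV_WEIGHT_NUM);
    int in_token_idx = i - token_idx * hidden_size * QKV_WEIGHT_NUM;
    int qkv_index = in_token_idx / hidden_size;    // 0 = Q, 1 = K, 2 = V
    int head_idx = (in_token_idx % hidden_size) / proj_size;
    printf("i=%2d -> token %d, %c, head %d\n",
           i, token_idx, "QKV"[qkv_index], head_idx);
  }
  return 0;
}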
@@ -97,9 +97,10 @@
   int global_head_idx = head_idx + shard_id * num_q_heads;

   size_t pre_length =
-      qkv_index == 0 ? 0
-                     : (qkv_index == 1 ? qProjSize * global_num_q_heads
-                                       : qProjSize * global_num_q_heads * 2);
+      qkv_index == 0
+          ? 0
+          : (qkv_index == 1 ? qProjSize * global_num_q_heads
+                            : qProjSize * global_num_q_heads * KV_WEIGHT_NUM);

   size_t bias_idx = pre_length + global_head_idx * proj_size + i % proj_size;

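The ternary above selects where a section starts in the concatenated [q_bias | k_bias | v_bias] vector: V sits after the Q and K sections, so the literal 2 becomes KV_WEIGHT_NUM (presumably 2). A hedged sketch of the same offset logic, assuming equally sized Q and K sections as the surrounding code implies:

// Sketch of the bias-offset selection above, assuming KV_WEIGHT_NUM == 2.
#include <cstddef>

size_t pre_length(int qkv_index, size_t qProjSize, size_t global_num_q_heads) {
  const size_t section = qProjSize * global_num_q_heads;
  return qkv_index == 0 ? 0 : (qkv_index == 1 ? section : 2 * section);
}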
@@ -194,7 +195,7 @@ __global__ void
   int head_idx = (real_i - (token_idx * (hidden_size / 2))) / (proj_size / 2);

   int real_part_index = idx + head_idx * proj_size +
-                        token_idx * hidden_size * 3 +
+                        token_idx * hidden_size * QKV_WEIGHT_NUM +
                         hidden_size * (q_tensor ? 0 : 1);
   int complex_part_index = real_part_index + (proj_size / 2);

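This rotary-embedding kernel pairs element d of a head with element d + proj_size / 2 as one complex number and rotates it by a position-dependent angle. A standalone sketch of that rotation, assuming the conventional RoPE frequency base of 10000 (the base is not visible in this diff):

#include <cmath>
#include <complex>
#include <cstdio>

// Rotate one (real, imag) pair for pair index d2 in [0, proj_size / 2).
std::complex<float> rope_rotate(float re, float im, int pos, int d2, int proj_size) {
  float freq = std::pow(10000.0f, -2.0f * d2 / proj_size); // assumed base
  std::complex<float> rot(std::cos(pos * freq), std::sin(pos * freq));
  return std::complex<float>(re, im) * rot;
}

int main() {
  std::complex<float> r = rope_rotate(1.0f, 0.0f, 3, 0, 64);
  printf("rotated: (%.3f, %.3f)\n", r.real(), r.imag());
  return 0;
}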
@@ -252,7 +253,7 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
   assert(m_q == m_k && m_k == m_v); // keep things simple for now
   int n = bc->num_active_tokens();
   int k = m->qSize;
-  int m_ = m_q * 3;
+  int m_ = m_q * QKV_WEIGHT_NUM;
   int lda = k, ldb = k, ldc = m_;
   checkCUDA(cublasGemmEx(m->handle.blas,
                          CUBLAS_OP_T,
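This GEMM computes Q, K and V in one call: with the weights transposed, the output C is m_ x n, where m_ = qProjSize * num_q_heads * QKV_WEIGHT_NUM, so each output column (one token) holds its Q, K and V projections back to back. A tiny dimension check with assumed sizes:

#include <cassert>

int main() {
  // Illustrative sizes only; none appear in the commit.
  const int QKV_WEIGHT_NUM = 3;
  int qSize = 768, qProjSize = 64, num_q_heads = 12, num_active_tokens = 5;
  int k = qSize;                                     // shared inner dimension
  int m_ = qProjSize * num_q_heads * QKV_WEIGHT_NUM; // rows of C
  int n = num_active_tokens;                         // one column per token
  // C (m_ x n) = A^T (m_ x k) * B (k x n); column t is [Q_t | K_t | V_t].
  assert(m_ == 2304 && k == 768 && n == 5);
  return 0;
}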
@@ -280,21 +281,21 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
   size_t q_array_size = m->qProjSize * num_tokens * m->num_q_heads;
   // apply bias for q, k, v
   if (*m->qkv_bias) {
-    apply_proj_bias_qkv<<<GET_BLOCKS(parallelism),
-                          min(CUDA_NUM_THREADS, parallelism),
-                          0,
-                          stream>>>(output_ptr,
-                                    bias_ptr,
-                                    shard_id,
-                                    num_tokens,
-                                    m->qProjSize,
-                                    m->kProjSize,
-                                    m->vProjSize,
-                                    m->global_num_q_heads,
-                                    m->num_q_heads,
-                                    *m->scaling_query,
-                                    m->scaling_factor,
-                                    m->hidden_size);
+    // apply_proj_bias_qkv<<<GET_BLOCKS(parallelism),
+    //                       min(CUDA_NUM_THREADS, parallelism),
+    //                       0,
+    //                       stream>>>(output_ptr,
+    //                                 bias_ptr,
+    //                                 shard_id,
+    //                                 num_tokens,
+    //                                 m->qProjSize,
+    //                                 m->kProjSize,
+    //                                 m->vProjSize,
+    //                                 m->global_num_q_heads,
+    //                                 m->num_q_heads,
+    //                                 *m->scaling_query,
+    //                                 m->scaling_factor,
+    //                                 m->hidden_size);
   } else if (m->scaling_query) {
     scaling_query_kernel<<<GET_BLOCKS(parallelism),
                            min(CUDA_NUM_THREADS, parallelism),
@@ -460,7 +461,8 @@ __global__ void store_kv_cache(DT const *devQKVProjArray,
   int token_idx = i / hidden_size;
   int offset = i % hidden_size;
-  size_t val_idx = token_idx * 3 * hidden_size + hidden_size + offset;
+  size_t val_idx =
+      token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset;
   DT kVal = devQKVProjArray[val_idx];
   DT vVal = devQKVProjArray[val_idx + hidden_size];
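The offsets above follow from the same packed layout: each token's slice is QKV_WEIGHT_NUM * hidden_size wide, its K block starts one hidden_size in, and its V block one hidden_size further. A compile-time sketch (the helper name is made up):

#include <cstddef>

constexpr size_t QKV_WEIGHT_NUM = 3;

// Offset of token t's K block in the packed QKV buffer; V follows one
// hidden_size later, matching val_idx and val_idx + hidden_size above.
constexpr size_t k_offset(size_t t, size_t hidden_size) {
  return t * QKV_WEIGHT_NUM * hidden_size + hidden_size;
}

static_assert(k_offset(0, 8) == 8, "token 0: K starts right after its Q");
static_assert(k_offset(2, 8) == 56, "token 2: 2 * 3 * 8 + 8");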
@@ -536,7 +538,8 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
   int m_ = num_new_tokens;
   int n = total_tokens;
   int k = m->qProjSize;
-  int lda = k * m->num_q_heads * 3, ldb = k * m->num_q_heads, ldc = m_;
+  int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads,
+      ldc = m_;
   int strideA = q_block_size;
   int strideB = kt_block_size;
   int strideC = num_new_tokens * total_tokens;
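The leading dimensions encode where each operand lives: the queries still sit inside the packed QKV buffer, so stepping to the next token's Q for the same head skips a whole packed token (k * num_q_heads * QKV_WEIGHT_NUM elements), while the key cache stores only keys and steps by k * num_q_heads. A small sanity check under those assumptions:

#include <cassert>

int main() {
  // Illustrative sizes; not taken from the commit.
  const int QKV_WEIGHT_NUM = 3;
  int qProjSize = 64, num_q_heads = 12;
  int k = qProjSize;
  int lda = k * num_q_heads * QKV_WEIGHT_NUM; // stride between packed tokens
  int ldb = k * num_q_heads;                  // stride between cached tokens
  assert(lda == QKV_WEIGHT_NUM * ldb);
  return 0;
}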