Commit c9b4ed3

more fix.

xinhaoc committed Oct 14, 2023
1 parent e31249a commit c9b4ed3

Showing 3 changed files with 39 additions and 37 deletions.
41 changes: 21 additions & 20 deletions src/ops/inc_multihead_self_attention.cpp
@@ -118,7 +118,8 @@ __global__ void scaling_query_kernel(DT *input_ptr,
     int hidden_size) {
   CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) {
     int token_idx = i / hidden_size;
-    input_ptr[i + token_idx * hidden_size * QKV_WEIGHT_NUM] *= scaling_factor;
+    input_ptr[i % hidden_size + token_idx * hidden_size * QKV_WEIGHT_NUM] *=
+        scaling_factor;
   }
 }
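Why the index change above matters: the fused QKV projection output stores Q, K,
and V contiguously per token, so each token occupies hidden_size * QKV_WEIGHT_NUM
elements and its Q block starts at token_idx * hidden_size * QKV_WEIGHT_NUM. The
loop index i runs over num_tokens * hidden_size query elements and already
contains token_idx * hidden_size, so the offset inside a token's Q block is
i % hidden_size; the old expression added the full i on top of the per-token
base, drifting into other tokens' regions (and past the end of the buffer for
the last tokens). A minimal standalone sketch of the corrected index math,
assuming plain float data and QKV_WEIGHT_NUM == 3 (names are illustrative, not
the project's API):

// Sketch only: scale the Q part of a fused [Q | K | V] per-token layout.
// Assumes QKV_WEIGHT_NUM == 3 and float data; not the templated FlexFlow kernel.
__global__ void scale_query_sketch(float *qkv_out,
                                   float scaling_factor,
                                   int num_tokens,
                                   int hidden_size) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_tokens * hidden_size) {
    int token_idx = i / hidden_size; // which token this query element belongs to
    int offset = i % hidden_size;    // position inside that token's Q block
    // Each token spans 3 * hidden_size elements; its Q block starts at
    // token_idx * hidden_size * 3.
    qkv_out[token_idx * hidden_size * 3 + offset] *= scaling_factor;
  }
}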

@@ -274,25 +275,25 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
   int k = m->qSize;
   int m_ = m_q * QKV_WEIGHT_NUM;
   int lda = k, ldb = k, ldc = m_;
-  checkCUDA(cublasGemmEx(m->handle.blas,
-                         CUBLAS_OP_T,
-                         CUBLAS_OP_N,
-                         m_,
-                         n,
-                         k,
-                         &alpha,
-                         weight_ptr,
-                         cublas_data_type,
-                         lda,
-                         input_ptr,
-                         cublas_data_type,
-                         ldb,
-                         &beta,
-                         output_ptr,
-                         cublas_data_type,
-                         ldc,
-                         compute_type,
-                         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+  checkCUDA(hipblasGemmEx(m->handle.blas,
+                          HIPBLAS_OP_T,
+                          HIPBLAS_OP_N,
+                          m_,
+                          n,
+                          k,
+                          &alpha,
+                          weight_ptr,
+                          hipblas_data_type,
+                          lda,
+                          input_ptr,
+                          hipblas_data_type,
+                          ldb,
+                          &beta,
+                          output_ptr,
+                          hipblas_data_type,
+                          ldc,
+                          compute_type,
+                          HIPBLAS_GEMM_DEFAULT));

   // apply rotary emmmbedding for q and k
   // step1 change the k, v to complex tensor
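Note on the hunk above: this is the HIP (ROCm) translation unit, so the fused
QKV projection GEMM now goes through hipBLAS instead of cuBLAS: cublasGemmEx
becomes hipblasGemmEx, CUBLAS_OP_T/CUBLAS_OP_N become HIPBLAS_OP_T/HIPBLAS_OP_N,
the datatype arguments use the hipBLAS datatype enum, and the algorithm selector
becomes HIPBLAS_GEMM_DEFAULT. A minimal sketch of a hipblasGemmEx call with the
same argument order, using plain FP32 buffers and made-up sizes (illustrative
only, not the project's wrapper; error handling reduced to asserts):

// Sketch: C = A^T * B with hipblasGemmEx, mirroring the call shape above.
#include <hip/hip_runtime.h>
#include <hipblas.h>
#include <cassert>
#include <vector>

int main() {
  int m = 4, n = 2, k = 3;           // stand-ins for m_, n, k in the hunk
  std::vector<float> A(k * m, 1.0f); // plays the role of weight_ptr (k x m)
  std::vector<float> B(k * n, 1.0f); // plays the role of input_ptr  (k x n)
  std::vector<float> C(m * n, 0.0f); // plays the role of output_ptr (m x n)
  float alpha = 1.0f, beta = 0.0f;

  float *dA, *dB, *dC;
  assert(hipMalloc((void **)&dA, A.size() * sizeof(float)) == hipSuccess);
  assert(hipMalloc((void **)&dB, B.size() * sizeof(float)) == hipSuccess);
  assert(hipMalloc((void **)&dC, C.size() * sizeof(float)) == hipSuccess);
  hipMemcpy(dA, A.data(), A.size() * sizeof(float), hipMemcpyHostToDevice);
  hipMemcpy(dB, B.data(), B.size() * sizeof(float), hipMemcpyHostToDevice);

  hipblasHandle_t handle;
  assert(hipblasCreate(&handle) == HIPBLAS_STATUS_SUCCESS);
  assert(hipblasGemmEx(handle,
                       HIPBLAS_OP_T,
                       HIPBLAS_OP_N,
                       m, n, k,
                       &alpha,
                       dA, HIPBLAS_R_32F, /*lda=*/k,
                       dB, HIPBLAS_R_32F, /*ldb=*/k,
                       &beta,
                       dC, HIPBLAS_R_32F, /*ldc=*/m,
                       HIPBLAS_R_32F, // compute type
                       HIPBLAS_GEMM_DEFAULT) == HIPBLAS_STATUS_SUCCESS);

  hipMemcpy(C.data(), dC, C.size() * sizeof(float), hipMemcpyDeviceToHost);
  // With all-ones inputs, every entry of C should equal k.
  hipblasDestroy(handle);
  hipFree(dA);
  hipFree(dB);
  hipFree(dC);
  return 0;
}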
33 changes: 17 additions & 16 deletions src/ops/inc_multihead_self_attention.cu
@@ -121,7 +121,8 @@ __global__ void scaling_query_kernel(DT *input_ptr,
     int hidden_size) {
   CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) {
     int token_idx = i / hidden_size;
-    input_ptr[i + token_idx * hidden_size * QKV_WEIGHT_NUM] *= scaling_factor;
+    input_ptr[i % hidden_size + token_idx * hidden_size * QKV_WEIGHT_NUM] *=
+        scaling_factor;
   }
 }

@@ -281,21 +282,21 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
   size_t q_array_size = m->qProjSize * num_tokens * m->num_q_heads;
   // apply bias for q, k, v
   if (*m->qkv_bias) {
-    // apply_proj_bias_qkv<<<GET_BLOCKS(parallelism),
-    //                       min(CUDA_NUM_THREADS, parallelism),
-    //                       0,
-    //                       stream>>>(output_ptr,
-    //                                 bias_ptr,
-    //                                 shard_id,
-    //                                 num_tokens,
-    //                                 m->qProjSize,
-    //                                 m->kProjSize,
-    //                                 m->vProjSize,
-    //                                 m->global_num_q_heads,
-    //                                 m->num_q_heads,
-    //                                 *m->scaling_query,
-    //                                 m->scaling_factor,
-    //                                 m->hidden_size);
+    apply_proj_bias_qkv<<<GET_BLOCKS(parallelism),
+                          min(CUDA_NUM_THREADS, parallelism),
+                          0,
+                          stream>>>(output_ptr,
+                                    bias_ptr,
+                                    shard_id,
+                                    num_tokens,
+                                    m->qProjSize,
+                                    m->kProjSize,
+                                    m->vProjSize,
+                                    m->global_num_q_heads,
+                                    m->num_q_heads,
+                                    *m->scaling_query,
+                                    m->scaling_factor,
+                                    m->hidden_size);
   } else if (m->scaling_query) {
     scaling_query_kernel<<<GET_BLOCKS(parallelism),
                            min(CUDA_NUM_THREADS, parallelism),
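The change above re-enables the previously commented-out apply_proj_bias_qkv
launch, so when *m->qkv_bias is set the projection biases are actually applied
to the fused Q/K/V output (and, judging by its parameters, query scaling is
handled there too) rather than being skipped; the else-if branch still falls
back to scaling_query_kernel when only scaling is requested. The launch uses
the usual GET_BLOCKS / CUDA_NUM_THREADS grid-size idiom; a small sketch of what
that pattern typically computes, assuming a Caffe-style helper (the project's
real macros may differ, e.g. by capping the block count):

// Sketch of the common grid-size helper behind this launch pattern.
constexpr int kNumThreads = 1024; // stand-in for CUDA_NUM_THREADS

inline int get_blocks(int n) {
  // Ceiling division: enough blocks so that blocks * threads >= n work items.
  return (n + kNumThreads - 1) / kNumThreads;
}

// Usage mirroring the re-enabled launch, with a hypothetical kernel:
//   int parallelism = /* number of elements to process */;
//   some_kernel<<<get_blocks(parallelism),
//                 min(kNumThreads, parallelism),
//                 0,
//                 stream>>>(/* arguments as in the hunk above */);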
2 changes: 1 addition & 1 deletion src/ops/spec_inc_multihead_self_attention.cpp
@@ -43,7 +43,7 @@ __global__ void spec_store_kv_cache(
     int qProjSize,
     int kProjSize,
     int vProjSize,
-    int num_kv_heads,
+    int num_tokens,
     int max_seq_len,
     int max_beam_width,
     bool is_root,