use CUBLAS_COMPUTE_32F_FAST_16F for full-precision gemm
jiazhihao committed Oct 30, 2023
1 parent 610bb07 commit 65f497a
Showing 4 changed files with 36 additions and 6 deletions.
14 changes: 12 additions & 2 deletions src/ops/inc_multihead_self_attention.cu
@@ -241,8 +241,13 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
 #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
   cudaDataType_t compute_type = cublas_data_type;
 #else
-  // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
+  // For best performance, set the default cublas compute type to
+  // CUBLAS_COMPUTE_16F for half precision and to
+  // CUBLAS_COMPUTE_32F_FAST_16F for full precision
   cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+  if (m->output_type[0] == DT_FLOAT) {
+    compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
+  }
 #endif
   // Compute (W^T)x matmul: einsum(ijkl,im->jmkl)
   // Weights: qSize x qProjSize x 3 x num_q_heads
@@ -511,8 +516,13 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
 #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
   cudaDataType_t compute_type = cublas_data_type;
 #else
-  // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
+  // For best performance, set the default cublas compute type to
+  // CUBLAS_COMPUTE_16F for half precision and to
+  // CUBLAS_COMPUTE_32F_FAST_16F for full precision
   cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+  if (m->output_type[0] == DT_FLOAT) {
+    compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
+  }
 #endif
   // int num_requests = bc->num_active_requests();
   int num_tokens = bc->num_active_tokens();
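For context, the sketch below (not part of the diff) shows where the selected compute_type ends up in the cublasGemmEx calls these kernels make, using the full-precision (DT_FLOAT) path that this commit changes. The GEMM dimensions m_, n_, k_, the leading dimensions, and the A/B/C device pointers are placeholders; only checkCUDA, m->handle.blas, and the compute-type logic come from the surrounding code.

// Sketch of a full-precision GEMM after this change: inputs and outputs stay
// FP32 (CUDA_R_32F), while CUBLAS_COMPUTE_32F_FAST_16F lets cuBLAS down-convert
// operands to FP16 and use tensor cores, keeping FP32 scaling and accumulation.
float alpha = 1.0f, beta = 0.0f;           // placeholder scaling factors
checkCUDA(cublasGemmEx(m->handle.blas,
                       CUBLAS_OP_T,        // transpose the weight matrix
                       CUBLAS_OP_N,
                       m_, n_, k_,         // placeholder GEMM dimensions
                       &alpha,
                       A, CUDA_R_32F, lda, // placeholder device pointers and
                       B, CUDA_R_32F, ldb, // leading dimensions
                       &beta,
                       C, CUDA_R_32F, ldc,
                       compute_type,       // CUBLAS_COMPUTE_32F_FAST_16F here
                       CUBLAS_GEMM_DEFAULT));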
14 changes: 12 additions & 2 deletions src/ops/kernels/linear_kernels.cu
@@ -314,8 +314,13 @@ void forward_kernel(LinearMeta const *m,
 #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
   cudaDataType_t compute_type = cublas_data_type;
 #else
-  // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
+  // For best performance, set the default cublas compute type to
+  // CUBLAS_COMPUTE_16F for half precision and to
+  // CUBLAS_COMPUTE_32F_FAST_16F for full precision
   cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+  if (m->output_type[0] == DT_FLOAT) {
+    compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
+  }
 #endif
   checkCUDA(cublasGemmEx(m->handle.blas,
                          CUBLAS_OP_T,
@@ -404,8 +409,13 @@ void backward_kernel(LinearMeta const *m,
 #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
   cudaDataType_t compute_type = cublas_data_type;
 #else
-  // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
+  // For best performance, set the default cublas compute type to
+  // CUBLAS_COMPUTE_16F for half precision and to
+  // CUBLAS_COMPUTE_32F_FAST_16F for full precision
   cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+  if (m->output_type[0] == DT_FLOAT) {
+    compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
+  }
 #endif
   int output_size = out_dim * batch_size;
   if (m->activation == AC_MODE_RELU) {
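Since the same selection block is now repeated in all five kernels touched by this commit, a small helper could centralize it; the sketch below is a hypothetical refactor, not code from this repository. It also records why the CUDA < 11.0 branch keeps a plain cudaDataType_t: before CUDA 11, cublasGemmEx took its compute type as a cudaDataType_t rather than a cublasComputeType_t.

// Hypothetical helper (not in this commit): one place to choose the compute
// type passed to every cublasGemmEx call.
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
typedef cudaDataType_t ff_compute_type;   // pre-11.0 cublasGemmEx signature
inline ff_compute_type select_compute_type(DataType /*output_type*/,
                                           cudaDataType_t cublas_data_type) {
  return cublas_data_type;                // no separate compute-type enum yet
}
#else
typedef cublasComputeType_t ff_compute_type;
inline ff_compute_type select_compute_type(DataType output_type,
                                           cudaDataType_t /*cublas_data_type*/) {
  // Half-precision models keep pure FP16 compute; full-precision models use
  // FP32 inputs/outputs with FP16 tensor-core math, as introduced here.
  return output_type == DT_FLOAT ? CUBLAS_COMPUTE_32F_FAST_16F
                                 : CUBLAS_COMPUTE_16F;
}
#endif

Each kernel would then call select_compute_type(m->output_type[0], cublas_data_type) in place of the repeated #if block.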
7 changes: 6 additions & 1 deletion src/ops/spec_inc_multihead_self_attention.cu
@@ -218,8 +218,13 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
 #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
   cudaDataType_t compute_type = cublas_data_type;
 #else
-  // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
+  // For best performance, set the default cublas compute type to
+  // CUBLAS_COMPUTE_16F for half precision and to
+  // CUBLAS_COMPUTE_32F_FAST_16F for full precision
   cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+  if (m->output_type[0] == DT_FLOAT) {
+    compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
+  }
 #endif
   // int num_requests = bc->num_active_requests();
   int num_tokens = bc->num_active_tokens();
7 changes: 6 additions & 1 deletion src/ops/tree_inc_multihead_self_attention.cu
@@ -161,8 +161,13 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m,
 #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
   cudaDataType_t compute_type = cublas_data_type;
 #else
-  // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
+  // For best performance, set the default cublas compute type to
+  // CUBLAS_COMPUTE_16F for half precision and to
+  // CUBLAS_COMPUTE_32F_FAST_16F for full precision
   cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+  if (m->output_type[0] == DT_FLOAT) {
+    compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
+  }
 #endif
   // int num_requests = bc->num_active_requests();
   int processed_tokens_in_batch = 0;
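To gauge the numerical cost of the switch, a standalone check (not part of this repository, and assuming CUDA 11.0+) can run the same FP32 GEMM once with full FP32 compute and once with CUBLAS_COMPUTE_32F_FAST_16F, then compare the outputs; all names and sizes below are illustrative.

// compare_fast16f.cu - rough accuracy probe for CUBLAS_COMPUTE_32F_FAST_16F.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
  int const n = 256;                       // small square GEMM
  size_t const bytes = n * n * sizeof(float);
  std::vector<float> h_a(n * n), h_b(n * n);
  for (int i = 0; i < n * n; i++) {
    h_a[i] = static_cast<float>(rand()) / RAND_MAX;
    h_b[i] = static_cast<float>(rand()) / RAND_MAX;
  }
  float *d_a, *d_b, *d_ref, *d_fast;
  cudaMalloc(&d_a, bytes);
  cudaMalloc(&d_b, bytes);
  cudaMalloc(&d_ref, bytes);
  cudaMalloc(&d_fast, bytes);
  cudaMemcpy(d_a, h_a.data(), bytes, cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, h_b.data(), bytes, cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  float alpha = 1.0f, beta = 0.0f;
  // Reference: plain FP32 math and accumulation.
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &alpha,
               d_a, CUDA_R_32F, n, d_b, CUDA_R_32F, n, &beta,
               d_ref, CUDA_R_32F, n, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
  // Same GEMM, but allowing FP16 tensor-core math as this commit does.
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &alpha,
               d_a, CUDA_R_32F, n, d_b, CUDA_R_32F, n, &beta,
               d_fast, CUDA_R_32F, n, CUBLAS_COMPUTE_32F_FAST_16F,
               CUBLAS_GEMM_DEFAULT);

  std::vector<float> h_ref(n * n), h_fast(n * n);
  cudaMemcpy(h_ref.data(), d_ref, bytes, cudaMemcpyDeviceToHost);
  cudaMemcpy(h_fast.data(), d_fast, bytes, cudaMemcpyDeviceToHost);
  float max_abs_diff = 0.0f;
  for (int i = 0; i < n * n; i++) {
    float diff = h_ref[i] - h_fast[i];
    if (diff < 0) diff = -diff;
    if (diff > max_abs_diff) max_abs_diff = diff;
  }
  printf("max |FP32 - FAST_16F| over %d x %d GEMM: %g\n", n, n, max_abs_diff);

  cublasDestroy(handle);
  cudaFree(d_a);
  cudaFree(d_b);
  cudaFree(d_ref);
  cudaFree(d_fast);
  return 0;
}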
