
Commit

Merge branch 'inference' into optimize_attn_v2
jiazhihao authored Nov 10, 2023
2 parents 0aab1e3 + c6ad6e2 commit e6a6b0e
Showing 8 changed files with 80 additions and 44 deletions.
14 changes: 8 additions & 6 deletions src/ops/inc_multihead_self_attention.cpp
@@ -257,10 +257,11 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
DT alpha = 1.0f, beta = 0.0f;
assert(m->qSize == m->vSize && m->qSize == m->kSize);
hipblasDatatype_t hipblas_data_type = ff_to_cuda_datatype(m->output_type[0]);
#if CUDA_VERSION >= 11000
// TODO: currently set the default to HIPBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = HIPBLAS_COMPUTE_16F;
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
hipblasDatatype_t compute_type = hipblas_data_type;
#else
// TODO: currently use the hipblas_data_type
// cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
hipblasDatatype_t compute_type = hipblas_data_type;
#endif
// Compute (W^T)x matmul: einsum(ijkl,im->jmkl)
@@ -509,10 +510,11 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
hipblasDatatype_t hipblas_data_type = ff_to_cuda_datatype(m->output_type[0]);
miopenDataType_t miopen_data_type = ff_to_cudnn_datatype(m->output_type[0]);
assert(data_type_size(m->output_type[0]) == sizeof(DT));
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
hipblasDatatype_t compute_type = hipblas_data_type;
#else
// TODO: currently use the hipblas_data_type
// cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
hipblasDatatype_t compute_type = hipblas_data_type;
#endif
// int num_requests = bc->num_active_requests();
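For context, a minimal sketch (not FlexFlow source) of how the compute type chosen above feeds a GEMM on the HIP/ROCm path. The helper name run_attention_gemm and the matrix shapes are hypothetical; only the preprocessor guard, the fallback to the output data type as the compute type, and the hipblasGemmEx signature follow the diff.

#include <hipblas.h>

// DT is float or half; alpha/beta use the GEMM data type, matching the
// compute type selected below.
template <typename DT>
void run_attention_gemm(hipblasHandle_t blas,
                        hipblasDatatype_t hipblas_data_type,
                        int m_dim, int n_dim, int k_dim,
                        DT const *A, DT const *B, DT *C) {
  DT alpha = 1.0f, beta = 0.0f;
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
  hipblasDatatype_t compute_type = hipblas_data_type;
#else
  // No separate hipBLAS compute-type enum is used on this path yet, so the
  // data type doubles as the compute type (mirrors the TODO in the diff).
  hipblasDatatype_t compute_type = hipblas_data_type;
#endif
  // C = A^T * B; leading dimensions assume densely packed column-major storage.
  hipblasGemmEx(blas, HIPBLAS_OP_T, HIPBLAS_OP_N,
                m_dim, n_dim, k_dim,
                &alpha, A, hipblas_data_type, k_dim,
                B, hipblas_data_type, k_dim,
                &beta, C, hipblas_data_type, m_dim,
                compute_type, HIPBLAS_GEMM_DEFAULT);
}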
26 changes: 18 additions & 8 deletions src/ops/inc_multihead_self_attention.cu
@@ -507,11 +507,16 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
DT alpha = 1.0f, beta = 0.0f;
assert(m->qSize == m->vSize && m->qSize == m->kSize);
cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]);
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#else
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
cudaDataType_t compute_type = cublas_data_type;
#else
// For best performance, set the default cublas compute type to
// CUBLAS_COMPUTE_16F for half precision and to
// CUBLAS_COMPUTE_32F_FAST_16F for full precision
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
if (m->output_type[0] == DT_FLOAT) {
compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
}
#endif
// Compute (W^T)x matmul: einsum(ijkl,im->jmkl)
// Weights: qSize x qProjSize x 3 x num_q_heads
@@ -905,11 +910,16 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m,
cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]);
cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]);
assert(data_type_size(m->output_type[0]) == sizeof(DT));
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#else
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
cudaDataType_t compute_type = cublas_data_type;
#else
// For best performance, set the default cublas compute type to
// CUBLAS_COMPUTE_16F for half precision and to
// CUBLAS_COMPUTE_32F_FAST_16F for full precision
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
if (m->output_type[0] == DT_FLOAT) {
compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
}
#endif
// int num_requests = bc->num_active_requests();
int num_tokens = bc->num_active_tokens();
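For context, a minimal sketch (not FlexFlow source) of the compute-type selection this commit introduces on the CUDA path: half-precision tensors keep 16-bit accumulation, while full-precision tensors now use CUBLAS_COMPUTE_32F_FAST_16F. The helper name run_attention_gemm and the matrix shapes are hypothetical; only the guard, the compute-type choice, and the cublasGemmEx signature follow the diff.

#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <type_traits>

// DT is float or __half; alpha/beta use the GEMM data type, which matches the
// scale type cublasGemmEx expects for the compute types chosen below.
template <typename DT>
void run_attention_gemm(cublasHandle_t blas,
                        cudaDataType_t cublas_data_type,
                        int m_dim, int n_dim, int k_dim,
                        DT const *A, DT const *B, DT *C) {
  DT alpha = 1.0f, beta = 0.0f;
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
  // Pre-11.0 cuBLAS takes a cudaDataType_t as the compute type.
  cudaDataType_t compute_type = cublas_data_type;
#else
  // CUDA 11+: default to 16-bit compute for half precision; full precision
  // uses the FP16-accelerated path with FP32 accumulation.
  cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
  if (std::is_same<DT, float>::value) {
    compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
  }
#endif
  // C = A^T * B; leading dimensions assume densely packed column-major storage.
  cublasGemmEx(blas, CUBLAS_OP_T, CUBLAS_OP_N,
               m_dim, n_dim, k_dim,
               &alpha, A, cublas_data_type, k_dim,
               B, cublas_data_type, k_dim,
               &beta, C, cublas_data_type, m_dim,
               compute_type, CUBLAS_GEMM_DEFAULT);
}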
18 changes: 10 additions & 8 deletions src/ops/kernels/linear_kernels.cpp
@@ -241,11 +241,12 @@ void forward_kernel(LinearMeta const *m,
hipblasDatatype_t input_type = ff_to_cuda_datatype(m->input_type[0]);
hipblasDatatype_t weight_type = ff_to_cuda_datatype(m->weight_type[0]);
hipblasDatatype_t output_type = ff_to_cuda_datatype(m->output_type[0]);
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
hipblasDatatype_t compute_type = output_type;
#else
hipblasDatatype_t compute_type = input_type;
// TODO: currently use the output_type
// cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
hipblasDatatype_t compute_type = output_type;
#endif
checkCUDA(hipblasGemmEx(m->handle.blas,
HIPBLAS_OP_T,
@@ -337,11 +338,12 @@ void backward_kernel(LinearMeta const *m,
hipblasDatatype_t input_type = ff_to_cuda_datatype(m->input_type[0]);
hipblasDatatype_t weight_type = ff_to_cuda_datatype(m->weight_type[0]);
hipblasDatatype_t output_type = ff_to_cuda_datatype(m->output_type[0]);
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
hipblasDatatype_t compute_type = output_type;
#else
hipblasDatatype_t compute_type = HIPBLAS_R_32F;
// TODO: currently use output_type
// cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
hipblasDatatype_t compute_type = output_type;
#endif
int output_size = out_dim * batch_size;
if (m->activation == AC_MODE_RELU) {
26 changes: 18 additions & 8 deletions src/ops/kernels/linear_kernels.cu
@@ -311,11 +311,16 @@ void forward_kernel(LinearMeta const *m,
: ff_to_cuda_datatype(m->weight_type[0]);
cudaDataType_t output_type = ff_to_cuda_datatype(m->output_type[0]);
assert(input_type == weight_type && weight_type == output_type);
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
cudaDataType_t compute_type = cublas_data_type;
#else
cudaDataType_t compute_type = input_type;
// For best performance, set the default cublas compute type to
// CUBLAS_COMPUTE_16F for half precision and to
// CUBLAS_COMPUTE_32F_FAST_16F for full precision
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
if (m->output_type[0] == DT_FLOAT) {
compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
}
#endif
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_T,
@@ -401,11 +406,16 @@ void backward_kernel(LinearMeta const *m,
cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type[0]);
cudaDataType_t weight_type = ff_to_cuda_datatype(m->weight_type[0]);
cudaDataType_t output_type = ff_to_cuda_datatype(m->output_type[0]);
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
cudaDataType_t compute_type = cublas_data_type;
#else
cudaDataType_t compute_type = CUDA_R_32F;
// For best performance, set the default cublas compute type to
// CUBLAS_COMPUTE_16F for half precision and to
// CUBLAS_COMPUTE_32F_FAST_16F for full precision
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
if (m->output_type[0] == DT_FLOAT) {
compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
}
#endif
int output_size = out_dim * batch_size;
if (m->activation == AC_MODE_RELU) {
7 changes: 4 additions & 3 deletions src/ops/spec_inc_multihead_self_attention.cpp
@@ -200,10 +200,11 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
hipblasDatatype_t hipblas_data_type = ff_to_cuda_datatype(m->output_type[0]);
miopenDataType_t miopen_data_type = ff_to_cudnn_datatype(m->output_type[0]);
assert(data_type_size(m->output_type[0]) == sizeof(DT));
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
hipblasDatatype_t compute_type = hipblas_data_type;
#else
// TODO: currently use the hipblas_data_type
// cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
hipblasDatatype_t compute_type = hipblas_data_type;
#endif
// int num_requests = bc->num_active_requests();
13 changes: 9 additions & 4 deletions src/ops/spec_inc_multihead_self_attention.cu
@@ -216,11 +216,16 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m,
cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]);
cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]);
assert(data_type_size(m->output_type[0]) == sizeof(DT));
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#else
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
cudaDataType_t compute_type = cublas_data_type;
#else
// For best performance, set the default cublas compute type to
// CUBLAS_COMPUTE_16F for half precision and to
// CUBLAS_COMPUTE_32F_FAST_16F for full precision
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
if (m->output_type[0] == DT_FLOAT) {
compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
}
#endif
// int num_requests = bc->num_active_requests();
int num_tokens = bc->num_active_tokens();
7 changes: 4 additions & 3 deletions src/ops/tree_inc_multihead_self_attention.cpp
@@ -157,10 +157,11 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m,
hipblasDatatype_t hipblas_data_type = ff_to_cuda_datatype(m->output_type[0]);
miopenDataType_t miopen_data_type = ff_to_cudnn_datatype(m->output_type[0]);
assert(data_type_size(m->output_type[0]) == sizeof(DT));
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
hipblasDatatype_t compute_type = hipblas_data_type;
#else
// TODO: currently use the hipblas_data_type
// cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
hipblasDatatype_t compute_type = hipblas_data_type;
#endif
// int num_requests = bc->num_active_requests();
13 changes: 9 additions & 4 deletions src/ops/tree_inc_multihead_self_attention.cu
@@ -430,11 +430,16 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m,
cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]);
cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]);
assert(data_type_size(m->output_type[0]) == sizeof(DT));
#if CUDA_VERSION >= 11000
// TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
#else
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
cudaDataType_t compute_type = cublas_data_type;
#else
// For best performance, set the default cublas compute type to
// CUBLAS_COMPUTE_16F for half precision and to
// CUBLAS_COMPUTE_32F_FAST_16F for full precision
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
if (m->output_type[0] == DT_FLOAT) {
compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
}
#endif
// int num_requests = bc->num_active_requests();
int processed_tokens_in_batch = 0;
