From 97696041181d32679e1d1d0a8d7cf3cc2e1b8a97 Mon Sep 17 00:00:00 2001
From: Gabriele Oliaro
Date: Fri, 10 Nov 2023 03:43:02 +0000
Subject: [PATCH] fix bwd bugs

---
 src/ops/inc_multihead_self_attention.cu |  7 ++++++-
 src/ops/kernels/linear_kernels.cu       | 12 +++++++++---
 src/runtime/cuda_helper.cu              | 22 ++++++++++++++++++++++
 3 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu
index c406435327..1a93251db4 100644
--- a/src/ops/inc_multihead_self_attention.cu
+++ b/src/ops/inc_multihead_self_attention.cu
@@ -469,8 +469,13 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
 #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
   cudaDataType_t compute_type = cublas_data_type;
 #else
-  // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
+  // For best performance, set the default cublas compute type to
+  // CUBLAS_COMPUTE_16F for half precision and to
+  // CUBLAS_COMPUTE_32F_FAST_16F for full precision
   cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+  if (m->output_type[0] == DT_FLOAT) {
+    compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
+  }
 #endif
   for (int i = 0; i < bc->max_requests_per_batch(); i++) {
     if (bc->request_completed[i]) {
diff --git a/src/ops/kernels/linear_kernels.cu b/src/ops/kernels/linear_kernels.cu
index dad6dc4e00..6f4016f2c2 100644
--- a/src/ops/kernels/linear_kernels.cu
+++ b/src/ops/kernels/linear_kernels.cu
@@ -457,14 +457,20 @@ void peft_bwd_kernel(LinearMeta const *m,
   cudaDataType_t weight_type = ff_to_cuda_datatype(m->weight_type[0]);
   cudaDataType_t output_type = ff_to_cuda_datatype(m->output_type[0]);
   // update input_grad_ptr and output_grad_ptr offset
-  input_grad_ptr = static_cast<DT *>(input_grad_ptr) + num_infr_tokens * in_dim;
+  int num_infr_only_tokens = num_infr_tokens - num_peft_tokens;
+  input_grad_ptr = static_cast<DT *>(input_grad_ptr) + num_infr_only_tokens * in_dim;
   output_grad_ptr =
-      static_cast<DT *>(output_grad_ptr) + num_infr_tokens * out_dim;
+      static_cast<DT *>(output_grad_ptr) + num_infr_only_tokens * out_dim;
 #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
   cudaDataType_t compute_type = output_type;
 #else
-  // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance
+  // For best performance, set the default cublas compute type to
+  // CUBLAS_COMPUTE_16F for half precision and to
+  // CUBLAS_COMPUTE_32F_FAST_16F for full precision
   cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+  if (m->output_type[0] == DT_FLOAT) {
+    compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
+  }
 #endif
   int output_size = out_dim * num_peft_tokens;
   if (m->activation == AC_MODE_RELU) {
diff --git a/src/runtime/cuda_helper.cu b/src/runtime/cuda_helper.cu
index c2b2affc40..e2078fa663 100644
--- a/src/runtime/cuda_helper.cu
+++ b/src/runtime/cuda_helper.cu
@@ -608,6 +608,28 @@ cudnnDataType_t cuda_to_cudnn_datatype(cudaDataType_t type) {
   return CUDNN_DATA_FLOAT;
 }
 
+void check_device_vs_host_ptr(void const *maybe_devicePtr) {
+  cudaPointerAttributes attributes;
+  cudaError_t cudaStatus = cudaPointerGetAttributes(&attributes, maybe_devicePtr);
+
+  if (cudaStatus == cudaSuccess) {
+    // Check attributes and report the memory type accordingly
+    if (attributes.type == cudaMemoryTypeDevice) {
+      printf("Pointer is allocated in device memory.\n");
+    } else if (attributes.type == cudaMemoryTypeHost) {
+      printf("Pointer is allocated in host memory.\n");
+    } else if (attributes.type == cudaMemoryTypeUnregistered) {
+      printf("Pointer is unregistered.\n");
+    } else if (attributes.type == cudaMemoryTypeManaged) {
+      printf("Pointer is managed.\n");
+    } else {
+      printf("Pointer is not allocated in a recognized memory type.\n");
+    }
+  } else {
+    fprintf(stderr, "cudaPointerGetAttributes failed: %s\n", cudaGetErrorString(cudaStatus));
+  }
+}
+
 template __global__ void
     assign_kernel<half>(half *ptr, coord_t size, half value);
 template __global__ void
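
Note on the compute-type change in the first two hunks: on CUDA >= 11, both backward kernels keep CUBLAS_COMPUTE_16F only for half-precision tensors and switch to CUBLAS_COMPUTE_32F_FAST_16F when the output type is DT_FLOAT, which keeps FP32 inputs and outputs while letting cuBLAS use FP16 tensor-core math internally. A minimal sketch of how such a compute type is typically passed to cublasGemmEx follows; the helper name, shapes, and leading dimensions are illustrative assumptions, not code from this patch.

#include <cublas_v2.h>

// Hypothetical helper, for illustration only: a single-precision GEMM whose
// inputs/outputs stay FP32 while cuBLAS may use FP16 tensor-core math,
// mirroring the CUBLAS_COMPUTE_32F_FAST_16F default chosen in the patch.
cublasStatus_t gemm_fp32_fast16(cublasHandle_t handle,
                                int m, int n, int k,
                                float const *A, float const *B, float *C) {
  float alpha = 1.0f, beta = 0.0f;
  cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
  return cublasGemmEx(handle,
                      CUBLAS_OP_N, CUBLAS_OP_N,
                      m, n, k,
                      &alpha,
                      A, CUDA_R_32F, /*lda=*/m,
                      B, CUDA_R_32F, /*ldb=*/k,
                      &beta,
                      C, CUDA_R_32F, /*ldc=*/m,
                      compute_type,
                      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}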
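
The new check_device_vs_host_ptr helper in cuda_helper.cu is a debugging aid and is not called anywhere in this patch. Below is a minimal usage sketch, assuming the translation unit is linked against cuda_helper.cu (or the declaration is otherwise visible); the buffers and expected messages are illustrative.

#include <cuda_runtime.h>

void check_device_vs_host_ptr(void const *maybe_devicePtr);  // from cuda_helper.cu

int main() {
  float host_buf[16];  // plain, unregistered host memory
  float *dev_buf = nullptr;
  float *managed_buf = nullptr;
  cudaMalloc(&dev_buf, 16 * sizeof(float));
  cudaMallocManaged(&managed_buf, 16 * sizeof(float));

  check_device_vs_host_ptr(host_buf);     // expected: "Pointer is unregistered."
  check_device_vs_host_ptr(dev_buf);      // expected: "Pointer is allocated in device memory."
  check_device_vs_host_ptr(managed_buf);  // expected: "Pointer is managed."

  cudaFree(dev_buf);
  cudaFree(managed_buf);
  return 0;
}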