Skip to content

Commit

Permalink
Add safety check so that TransposeBigMLFloat16 test passes (#77)
Browse files Browse the repository at this point in the history
* Added maximum gridDim.y overflow check before calling transposeNoOverlap kernel so that TransposeBigMLFloat16 test passes

* Fix formatting
  • Loading branch information
sstamenk authored Nov 19, 2024
1 parent d906a82 commit 061c493
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
23 changes: 21 additions & 2 deletions onnxruntime/core/providers/rocm/fpgeneric.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,26 @@ __global__ void CopyVectorBFloat16(const onnxruntime::BFloat16* x, int incx, onn

} // namespace

rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation , rocblas_operation , int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) {
// Builds the launch grid used for the half-precision transpose of an m x n input.
//   grid.x = n divided by TRANS_TILE_DIM, rounded up
//   grid.y = m divided by TRANS_TILE_DIM, rounded up
//   grid.z = 1
dim3 rocblasTransposeHelperDimGrid(int m, int n) {
  // Ceiling division written as (value + divisor - 1) / divisor.
  const auto blocks_x = (n + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM;
  const auto blocks_y = (m + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM;
  return dim3(blocks_x, blocks_y, 1);
}

// Tells the caller whether rocblasTransposeHelper may be used for an m x n half-precision
// transpose: the y extent of the grid from rocblasTransposeHelperDimGrid(m, n) must be
// strictly below the current device's maxGridSize[1], otherwise the launch would overflow
// the grid's y dimension.
//
// Returns false as well when the current device id or its properties cannot be queried.
__host__ bool CanUse_rocblasTransposeHelper_MLFloat16(int m, int n) {
  // Identify the device the calling host thread is currently bound to.
  int currentDevice;
  if (hipGetDevice(&currentDevice) != 0) {
    return false;
  }

  // Fetch that device's limits; only maxGridSize[1] (the y limit) is consulted below.
  hipDeviceProp_t props;
  if (hipGetDeviceProperties(&props, currentDevice) != 0) {
    return false;
  }

  const dim3 grid = rocblasTransposeHelperDimGrid(m, n);
  return grid.y < props.maxGridSize[1];
}

rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation, rocblas_operation, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) {
if (C != A) {
dim3 dimGrid((n + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, (m + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, 1);
dim3 dimBlock(TRANS_TILE_DIM, BLOCK_ROWS, 1);
Expand All @@ -73,7 +92,7 @@ rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, cons
}

rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const onnxruntime::BFloat16* x, int incx,
onnxruntime::BFloat16* y, int incy) {
onnxruntime::BFloat16* y, int incy) {
dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1);
dim3 dimBlock(COPY_BLOCK_DIM, 1, 1);
CopyVectorBFloat16<<<dimGrid, dimBlock, 0, stream>>>(x, incx, y, incy, n);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ inline rocblas_status rocblasTransposeHelper(hipStream_t /*stream*/, rocblas_han
return rocblas_dgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
}

inline bool CanUse_rocblasTransposeHelper_MLFloat16(int /*m*/, int /*n*/) { return true; } // CUDA has a limited grid size of 65536, ROCm has higher limits.
bool CanUse_rocblasTransposeHelper_MLFloat16(int m, int n);
rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation, rocblas_operation, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int);

// copy
Expand Down

0 comments on commit 061c493

Please sign in to comment.