Skip to content

Commit

Permalink
Disable int8, gemm_batch for cublas
Browse files Browse the repository at this point in the history
  • Loading branch information
Aidan committed May 31, 2024
1 parent 7f72a3d commit 6e70f68
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/blas/backends/cublas/cublas_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran

GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half)
GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float)
GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float)
GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float)
GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double)
GEMM_STRIDED_BATCH_LAUNCHER(std::complex<float>, std::complex<float>, std::complex<float>,
Expand All @@ -212,6 +211,7 @@ GEMM_STRIDED_BATCH_LAUNCHER(std::complex<double>, std::complex<double>, std::com
dtype_string<TYPE_C>() + "," + dtype_string<TYPE_S>() + ">"); \
}

GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float)
GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float)

#undef GEMM_STRIDED_BATCH_LAUNCHER
Expand Down Expand Up @@ -632,7 +632,6 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra

GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half)
GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float)
GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float)
GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float)
GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double)
GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex<float>, std::complex<float>, std::complex<float>,
Expand All @@ -654,6 +653,7 @@ GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex<double>, std::complex<double>, std:
dtype_string<TYPE_C>() + "," + dtype_string<TYPE_S>() + ">"); \
}

GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float)
GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float)

#undef GEMM_STRIDED_BATCH_LAUNCHER_USM
Expand Down Expand Up @@ -714,7 +714,6 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr

GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half)
GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float)
GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float)
GEMM_BATCH_LAUNCHER_USM(float, float, float, float)
GEMM_BATCH_LAUNCHER_USM(double, double, double, double)
GEMM_BATCH_LAUNCHER_USM(std::complex<float>, std::complex<float>, std::complex<float>,
Expand All @@ -736,6 +735,7 @@ GEMM_BATCH_LAUNCHER_USM(std::complex<double>, std::complex<double>, std::complex
dtype_string<TYPE_C>() + "," + dtype_string<TYPE_S>() + ">"); \
}

GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float)
GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float)

#undef GEMM_BATCH_LAUNCHER_USM
Expand Down

0 comments on commit 6e70f68

Please sign in to comment.