From c35f37baebaddb70751f3475ba9c052443d28266 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 8 Apr 2024 06:52:32 +0000 Subject: [PATCH] update --- pyg_lib/csrc/ops/cpu/matmul_kernel.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyg_lib/csrc/ops/cpu/matmul_kernel.cpp b/pyg_lib/csrc/ops/cpu/matmul_kernel.cpp index 20ee14f4..21b41d84 100644 --- a/pyg_lib/csrc/ops/cpu/matmul_kernel.cpp +++ b/pyg_lib/csrc/ops/cpu/matmul_kernel.cpp @@ -203,6 +203,7 @@ void grouped_matmul_out_kernel_at_impl(const std::vector input, void grouped_matmul_out_kernel_mkl_impl(const std::vector input, const std::vector other, std::vector out) { +#if WITH_MKL_BLAS() // matrix_params using matrix_params = std::tuple; phmap::flat_hash_map> groups; @@ -261,7 +262,6 @@ void grouped_matmul_out_kernel_mkl_impl(const std::vector input, auto src1_ptrs = const_cast(src1.data()); auto dst_ptrs = dst.data(); -#if WITH_MKL_BLAS() #if AT_MKL_SEQUENTIAL() // unlikely to happen - requires Torch to be built from source with // explicit flag denoting MKL sequential version @@ -273,9 +273,9 @@ void grouped_matmul_out_kernel_mkl_impl(const std::vector input, src0_ptrs, ld_src0.data(), src1_ptrs, ld_src1.data(), beta.data(), dst_ptrs, ld_dst.data(), group_count, group_sizes.data()); -#endif #endif }); +#endif } std::vector grouped_matmul_kernel(const at::TensorList input, @@ -328,6 +328,7 @@ void segment_matmul_out_kernel_mkl_impl(const at::Tensor& input, const at::Tensor& other, at::Tensor& out, const at::IntArrayRef& sizes) { +#if WITH_MKL_BLAS() const int n = other.size(-1); const int k = input.size(-1); const int nk = n * k; @@ -390,7 +391,6 @@ void segment_matmul_out_kernel_mkl_impl(const at::Tensor& input, auto src1_ptrs = const_cast(src1.data()); auto dst_ptrs = dst.data(); -#if WITH_MKL_BLAS() #if AT_MKL_SEQUENTIAL() // unlikely to happen - requires Torch to be built from source with // explicit flag denoting MKL sequential version @@ -402,9 +402,9 @@ void segment_matmul_out_kernel_mkl_impl(const at::Tensor& input, src0_ptrs, ld_src0.data(), src1_ptrs, ld_src1.data(), beta.data(), dst_ptrs, ld_dst.data(), group_count, group_sizes.data()); -#endif #endif }); +#endif } at::Tensor segment_matmul_kernel(const at::Tensor& input,