Skip to content

Commit

Permalink
update
Browse files: browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 8, 2024
1 parent b777858 commit c35f37b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pyg_lib/csrc/ops/cpu/matmul_kernel.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -203,6 +203,7 @@ void grouped_matmul_out_kernel_at_impl(const std::vector<at::Tensor> input,
void grouped_matmul_out_kernel_mkl_impl(const std::vector<at::Tensor> input,
const std::vector<at::Tensor> other,
std::vector<at::Tensor> out) {
#if WITH_MKL_BLAS()
// matrix_params<M, N, K>
using matrix_params = std::tuple<int, int, int>;
phmap::flat_hash_map<matrix_params, std::vector<size_t>> groups;
Expand Down Expand Up @@ -261,7 +262,6 @@ void grouped_matmul_out_kernel_mkl_impl(const std::vector<at::Tensor> input,
auto src1_ptrs = const_cast<const scalar_t**>(src1.data());
auto dst_ptrs = dst.data();

#if WITH_MKL_BLAS()
#if AT_MKL_SEQUENTIAL()
// unlikely to happen - requires Torch to be built from source with
// explicit flag denoting MKL sequential version
Expand All @@ -273,9 +273,9 @@ void grouped_matmul_out_kernel_mkl_impl(const std::vector<at::Tensor> input,
src0_ptrs, ld_src0.data(), src1_ptrs, ld_src1.data(),
beta.data(), dst_ptrs, ld_dst.data(), group_count,
group_sizes.data());
#endif
#endif
});
#endif
}

std::vector<at::Tensor> grouped_matmul_kernel(const at::TensorList input,
Expand Down Expand Up @@ -328,6 +328,7 @@ void segment_matmul_out_kernel_mkl_impl(const at::Tensor& input,
const at::Tensor& other,
at::Tensor& out,
const at::IntArrayRef& sizes) {
#if WITH_MKL_BLAS()
const int n = other.size(-1);
const int k = input.size(-1);
const int nk = n * k;
Expand Down Expand Up @@ -390,7 +391,6 @@ void segment_matmul_out_kernel_mkl_impl(const at::Tensor& input,
auto src1_ptrs = const_cast<const scalar_t**>(src1.data());
auto dst_ptrs = dst.data();

#if WITH_MKL_BLAS()
#if AT_MKL_SEQUENTIAL()
// unlikely to happen - requires Torch to be built from source with
// explicit flag denoting MKL sequential version
Expand All @@ -402,9 +402,9 @@ void segment_matmul_out_kernel_mkl_impl(const at::Tensor& input,
src0_ptrs, ld_src0.data(), src1_ptrs, ld_src1.data(),
beta.data(), dst_ptrs, ld_dst.data(), group_count,
group_sizes.data());
#endif
#endif
});
#endif
}

at::Tensor segment_matmul_kernel(const at::Tensor& input,
Expand Down

0 comments on commit c35f37b

Please sign in to comment.