From af92b882aabbb25b3835c27002e613c219064e77 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Fri, 5 Jan 2024 20:23:06 -0500 Subject: [PATCH] Update matmul_kernel.cu --- pyg_lib/csrc/ops/cuda/matmul_kernel.cu | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/pyg_lib/csrc/ops/cuda/matmul_kernel.cu b/pyg_lib/csrc/ops/cuda/matmul_kernel.cu index 9909425a5..6c98bef6b 100644 --- a/pyg_lib/csrc/ops/cuda/matmul_kernel.cu +++ b/pyg_lib/csrc/ops/cuda/matmul_kernel.cu @@ -45,9 +45,9 @@ void run_grouped_gemm(const at::TensorList input, // Set arguments into gemm_args from input args for (size_t i = 0; i < num_matrices; ++i) { - auto new_in = input[i].contiguous(); - auto new_other = other[i].contiguous(); - auto new_out = out[i].contiguous(); + auto new_in = input[i]; + auto new_other = other[i]; + auto new_out = out[i]; auto m = new_in.size(0), k = new_other.size((int)(segment)), n = new_out.size(1); @@ -288,9 +288,14 @@ void grouped_matmul_out_kernel(const at::TensorList input, std::vector grouped_matmul_kernel(const at::TensorList input, const at::TensorList other) { std::vector out(input.size()); - for (size_t i = 0; i < input.size(); ++i) + std::vector input_contiguous(input.size()); + std::vector other_contiguous(other.size()); + for (size_t i = 0; i < input.size(); ++i) { + input_contiguous[i] = input[i].contiguous(); + other_contiguous[i] = other[i].contiguous(); out[i] = input[i].new_empty({input[i].size(0), other[i].size(-1)}); - grouped_matmul_out_kernel(input, other, out, false); + } + grouped_matmul_out_kernel(input_contiguous, other_contiguous, out, false); return out; }