Skip to content

Commit

Permalink
Update matmul_kernel.cu
Browse files Browse the repository at this point in the history
  • Loading branch information
Jokeren authored Jan 6, 2024
1 parent 160d5c3 commit af92b88
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions pyg_lib/csrc/ops/cuda/matmul_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ void run_grouped_gemm(const at::TensorList input,

// Set arguments into gemm_args from input args
for (size_t i = 0; i < num_matrices; ++i) {
auto new_in = input[i].contiguous();
auto new_other = other[i].contiguous();
auto new_out = out[i].contiguous();
auto new_in = input[i];
auto new_other = other[i];
auto new_out = out[i];
auto m = new_in.size(0), k = new_other.size((int)(segment)),
n = new_out.size(1);

Expand Down Expand Up @@ -288,9 +288,14 @@ void grouped_matmul_out_kernel(const at::TensorList input,
// Allocating variant of the grouped matmul: for every pair
// (input[i], other[i]) an output tensor is created and the whole group is then
// handed to `grouped_matmul_out_kernel` in a single call.
//
// Parameters:
//   input - list of left-hand operands; its length determines how many
//           outputs are produced.
//   other - list of right-hand operands, indexed in lockstep with `input`.
//
// Returns:
//   One tensor per entry of `input`, allocated through `new_empty` on that
//   entry with shape {input[i].size(0), other[i].size(-1)}.
//
// NOTE(review): `other[i]` is read for every i < input.size(), so this
// presumably requires other.size() >= input.size() -- confirm that callers
// (or the out-variant kernel) validate the list lengths.
std::vector<at::Tensor> grouped_matmul_kernel(const at::TensorList input,
                                              const at::TensorList other) {
  std::vector<at::Tensor> out(input.size());
  // Contiguous versions of the operands are materialised once here and these,
  // not the original lists, are what the out-variant kernel receives.
  // NOTE(review): presumably the downstream GEMM setup relies on contiguous
  // operands -- confirm against run_grouped_gemm.
  std::vector<at::Tensor> input_contiguous(input.size());
  std::vector<at::Tensor> other_contiguous(other.size());
  for (size_t i = 0; i < input.size(); ++i) {
    input_contiguous[i] = input[i].contiguous();
    other_contiguous[i] = other[i].contiguous();
    out[i] = input[i].new_empty({input[i].size(0), other[i].size(-1)});
  }
  // Single launch for the whole group, issued only after every output has
  // been allocated.
  grouped_matmul_out_kernel(input_contiguous, other_contiguous, out, false);

  return out;
}
Expand Down

0 comments on commit af92b88

Please sign in to comment.