
Commit

Merge branch 'master' into patch-1
Jokeren authored Jan 6, 2024
2 parents 84b4324 + bef1074 commit 4ae38bf
Showing 2 changed files with 19 additions and 8 deletions.
CHANGELOG.md (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Dropped the MKL code path when sampling neighbors with `replace=False` since it does not correctly prevent duplicates ([#275](https://github.com/pyg-team/pyg-lib/pull/275))
- Added `--biased` parameter to run benchmarks for biased sampling ([#267](https://github.com/pyg-team/pyg-lib/pull/267))
- Improved speed of biased sampling ([#270](https://github.com/pyg-team/pyg-lib/pull/270))
- - Fixed `grouped_matmul` when tensors are not contiguous (#290)
+ - Fixed `grouped_matmul` when tensors are not contiguous ([#290](https://github.com/pyg-team/pyg-lib/pull/290))
### Removed

## [0.3.0] - 2023-10-11
pyg_lib/csrc/ops/cuda/matmul_kernel.cu (25 changes: 18 additions & 7 deletions)
@@ -42,6 +42,7 @@ void run_grouped_gemm(const at::TensorList input,
int64_t* ptr_C_data = ld_A_data + 5 * num_matrices;
cutlass::gemm::GemmCoord* problem_sizes_data =
reinterpret_cast<cutlass::gemm::GemmCoord*>(ld_A_data + 6 * num_matrices);
+ std::vector<cutlass::gemm::GemmCoord> host_problem_sizes;

// Set arguments into gemm_args from input args
for (size_t i = 0; i < num_matrices; ++i) {
@@ -52,6 +53,7 @@
n = new_out.size(1);

problem_sizes_data[i] = cutlass::gemm::GemmCoord(m, n, k);
+ host_problem_sizes.push_back(cutlass::gemm::GemmCoord(m, n, k));

ld_A_data[i] = GemmKernel::LayoutA::packed({m, k}).stride(0);
ld_B_data[i] = GemmKernel::LayoutB::packed({k, n}).stride(0);
@@ -86,11 +88,15 @@
reinterpret_cast<float**>(ptr_B_data),
reinterpret_cast<float**>(ptr_C_data),
reinterpret_cast<float**>(ptr_C_data), ld_A_data, ld_B_data, ld_C_data,
- ld_C_data);
+ ld_C_data, host_problem_sizes.data());

GemmGrouped gemm;
+ int64_t workspace_bytes = GemmGrouped::get_workspace_size(args);
+ at::Tensor workspace =
+     at::empty({workspace_bytes},
+               at::TensorOptions().dtype(at::kByte).device(out[0].device()));
auto status =
-     gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream());
+     gemm.initialize(args, workspace.data_ptr(), at::cuda::getCurrentCUDAStream());
TORCH_CHECK(status == cutlass::Status::kSuccess, "GroupedGEMM init failed");
status = gemm.run(at::cuda::getCurrentCUDAStream());
TORCH_CHECK(status == cutlass::Status::kSuccess, "GroupedGEMM run failed");
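Note: with `GroupScheduleMode::kHostPrecompute`, CUTLASS computes the per-tile schedule on the host inside `initialize()` and stages it through the workspace buffer. That is why this commit collects a host-visible copy of the problem sizes (`host_problem_sizes`) and why `initialize()` can no longer be passed a `nullptr` workspace. The sketch below shows how those pieces fit together in the CUTLASS 2.x grouped-GEMM arguments; it is not a verbatim reconstruction of the elided lines of this file, and `threadblock_count`, `alpha`, `beta`, and `ptr_A_data` are placeholders for values set up earlier in the function.

```cpp
// Sketch only: grouped-GEMM arguments end with an optional host_problem_sizes
// pointer, which kernels built with GroupScheduleMode::kHostPrecompute require.
typename GemmGrouped::Arguments args(
    problem_sizes_data,                     // GemmCoord* read by the kernel
    num_matrices,                           // number of grouped problems
    threadblock_count,                      // placeholder: persistent threadblocks
    {alpha, beta},                          // placeholder: epilogue parameters
    reinterpret_cast<float**>(ptr_A_data),  // placeholder: per-problem A pointers
    reinterpret_cast<float**>(ptr_B_data),
    reinterpret_cast<float**>(ptr_C_data),
    reinterpret_cast<float**>(ptr_C_data),  // same pointer passed for C and D
    ld_A_data, ld_B_data, ld_C_data, ld_C_data,
    host_problem_sizes.data());             // host copy used by kHostPrecompute

GemmGrouped gemm;
// The host-precomputed tile schedule lives in the workspace, so
// get_workspace_size() is non-zero and initialize() needs a real device buffer.
int64_t workspace_bytes = GemmGrouped::get_workspace_size(args);
at::Tensor workspace = at::empty(
    {workspace_bytes},
    at::TensorOptions().dtype(at::kByte).device(out[0].device()));
auto status = gemm.initialize(args, workspace.data_ptr(),
                              at::cuda::getCurrentCUDAStream());
TORCH_CHECK(status == cutlass::Status::kSuccess, "GroupedGEMM init failed");
```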
@@ -149,7 +155,8 @@ void grouped_matmul_out_kernel(const at::TensorList input,
float, 1, float, float>, //
cutlass::gemm::threadblock:: // Swizzling Operator
GemmIdentityThreadblockSwizzle<8>, //
- 2 // Stages
+ 2, // Stages
+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute
>::GemmKernel;
run_grouped_gemm<GemmKernel_Volta>(input, other, out, segment);
} else {
@@ -185,7 +192,8 @@ void grouped_matmul_out_kernel(const at::TensorList input,
float, 1, float, float>, //
cutlass::gemm::threadblock:: // Swizzling Operator
GemmIdentityThreadblockSwizzle<8>, //
- 3 // Stages
+ 3, // Stages
+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute
>::GemmKernel;
int grouped_shared_mem =
shared_memory_for_kernel<DefaultGemmKernel_TF32>();
@@ -217,7 +225,8 @@ void grouped_matmul_out_kernel(const at::TensorList input,
float, 1, float, float>, //
cutlass::gemm::threadblock:: // Swizzling Operator
GemmIdentityThreadblockSwizzle<8>, //
- 3 // Stages
+ 3, // Stages
+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute
>::GemmKernel;
run_grouped_gemm<SmallGemmKernel_TF32>(input, other, out, segment);
}
@@ -245,7 +254,8 @@ void grouped_matmul_out_kernel(const at::TensorList input,
float, 1, float, float>, //
cutlass::gemm::threadblock:: // Swizzling Operator
GemmIdentityThreadblockSwizzle<8>, //
- 3 // Stages
+ 3, // Stages
+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute
>::GemmKernel;
int grouped_shared_mem =
shared_memory_for_kernel<DefaultGemmKernel_FP32>();
@@ -277,7 +287,8 @@ void grouped_matmul_out_kernel(const at::TensorList input,
float, 1, float, float>, //
cutlass::gemm::threadblock:: // Swizzling Operator
GemmIdentityThreadblockSwizzle<8>, //
- 3 // Stages
+ 3, // Stages
+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute
>::GemmKernel;
run_grouped_gemm<SmallGemmKernel_FP32>(input, other, out, segment);
}
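Each of the grouped-GEMM kernel definitions above gains the same trailing template argument: `DefaultGemmGrouped` takes the pipeline stage count followed by an optional `GroupScheduleMode`, which defaults to `kDeviceOnly` (the tile-to-problem mapping is computed on the GPU). Selecting `kHostPrecompute` moves that mapping to the host at `initialize()` time, matching the host problem sizes and workspace added earlier. Below is a hedged sketch of one such definition with the schedule mode spelled out; the tile shapes, alignments, and architecture tag are illustrative assumptions, not copied from the elided lines of this file.

```cpp
// Illustrative only: a DefaultGemmGrouped instantiation with the group
// schedule mode made explicit. Shapes, alignments, and arch tag are assumed.
#include <cutlass/gemm/kernel/default_gemm_grouped.h>
#include <cutlass/epilogue/thread/linear_combination.h>

using ExampleGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
    float, cutlass::layout::RowMajor,                       // A
    cutlass::ComplexTransform::kNone, 4,
    float, cutlass::layout::RowMajor,                       // B
    cutlass::ComplexTransform::kNone, 4,
    float, cutlass::layout::RowMajor,                       // C
    float,                                                  // accumulator
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 16>,                 // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 16>,                   // warp tile
    cutlass::gemm::GemmShape<16, 8, 8>,                     // TF32 tensor-core MMA
    cutlass::epilogue::thread::LinearCombination<           // epilogue
        float, 1, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
    3,                                                         // Stages
    cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute  // schedule mode
    >::GemmKernel;
```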
