Skip to content

Commit

Permalink
adsf
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Jul 22, 2024
1 parent d620cda commit 668d6af
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions pyg_lib/ops/scatter_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@


@triton.jit
def _fused_scatter_reduce_kernel(inputs_ptr, index_ptr, out_ptr, num_feats,
num_reductions, numel, REDUCE0, REDUCE1,
REDUCE2, REDUCE3, BLOCK_SIZE: tl.constexpr):
def _fused_scatter_reduce_forward_kernel(inputs_ptr, index_ptr, out_ptr,
num_feats, num_reductions, numel,
REDUCE0, REDUCE1, REDUCE2, REDUCE3,
BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE

Expand Down Expand Up @@ -136,7 +137,7 @@ def fused_scatter_reduce(
grid = lambda meta: ( # noqa: E731
triton.cdiv(inputs.numel(), meta['BLOCK_SIZE']), )

_fused_scatter_reduce_kernel[grid](
_fused_scatter_reduce_forward_kernel[grid](
inputs,
index,
out,
Expand Down

0 comments on commit 668d6af

Please sign in to comment.