From 668d6af066f445a4856f442940c39446a70427b2 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 22 Jul 2024 00:40:27 +0000 Subject: [PATCH] adsf --- pyg_lib/ops/scatter_reduce.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pyg_lib/ops/scatter_reduce.py b/pyg_lib/ops/scatter_reduce.py index 725ab784..0bbb69ec 100644 --- a/pyg_lib/ops/scatter_reduce.py +++ b/pyg_lib/ops/scatter_reduce.py @@ -13,9 +13,10 @@ @triton.jit -def _fused_scatter_reduce_kernel(inputs_ptr, index_ptr, out_ptr, num_feats, - num_reductions, numel, REDUCE0, REDUCE1, - REDUCE2, REDUCE3, BLOCK_SIZE: tl.constexpr): +def _fused_scatter_reduce_forward_kernel(inputs_ptr, index_ptr, out_ptr, + num_feats, num_reductions, numel, + REDUCE0, REDUCE1, REDUCE2, REDUCE3, + BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE @@ -136,7 +137,7 @@ def fused_scatter_reduce( grid = lambda meta: ( # noqa: E731 triton.cdiv(inputs.numel(), meta['BLOCK_SIZE']), ) - _fused_scatter_reduce_kernel[grid]( + _fused_scatter_reduce_forward_kernel[grid]( inputs, index, out,