diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4876414fd..dc01689ea 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [0.4.0] - 2023-MM-DD
 ### Added
-- Added `softmax_csr` implementation ([#264](https://github.com/pyg-team/pyg-lib/pull/264))
+- Added `softmax_csr` implementation ([#264](https://github.com/pyg-team/pyg-lib/pull/264), [#282](https://github.com/pyg-team/pyg-lib/pull/282))
 - Added support for edge-level sampling ([#280](https://github.com/pyg-team/pyg-lib/pull/280))
 - Added support for `bfloat16` data type in `segment_matmul` and `grouped_matmul` (CPU only) ([#272](https://github.com/pyg-team/pyg-lib/pull/272))
 ### Changed
diff --git a/pyg_lib/csrc/ops/autograd/softmax_kernel.cpp b/pyg_lib/csrc/ops/autograd/softmax_kernel.cpp
new file mode 100644
index 000000000..52f4696b9
--- /dev/null
+++ b/pyg_lib/csrc/ops/autograd/softmax_kernel.cpp
@@ -0,0 +1,60 @@
+#include "../softmax.h"
+
+#include <torch/autograd.h>
+
+namespace pyg {
+namespace ops {
+
+namespace {
+
+using torch::autograd::Variable;
+using torch::autograd::variable_list;
+
+class SoftmaxCSR : public torch::autograd::Function<SoftmaxCSR> {
+ public:
+  static variable_list forward(torch::autograd::AutogradContext* ctx,
+                               const Variable& src,
+                               const at::Tensor& ptr,
+                               const int64_t dim) {
+    at::AutoDispatchBelowADInplaceOrView g;
+
+    Variable out = softmax_csr(src, ptr, dim);
+    ctx->saved_data["dim"] = dim;
+    ctx->save_for_backward({src, out, ptr});
+
+    return {out};
+  }
+
+  static variable_list backward(torch::autograd::AutogradContext* ctx,
+                                variable_list out_grads) {
+    const auto out_grad = out_grads[0];
+    const auto saved = ctx->get_saved_variables();
+    const auto src = saved[0];
+    const auto out = saved[1];
+    const auto ptr = saved[2];
+    const auto dim = ctx->saved_data["dim"].toInt();
+
+    auto src_grad = Variable();
+    if (torch::autograd::any_variable_requires_grad({src})) {
+      src_grad = softmax_csr_backward(out, out_grad, ptr, dim);
+    }
+
+    return {src_grad, Variable(), Variable()};
+  }
+};
+
+at::Tensor softmax_csr_autograd(const at::Tensor& src,
+                                const at::Tensor& ptr,
+                                const int64_t dim) {
+  return SoftmaxCSR::apply(src, ptr, dim)[0];
+}
+
+}  // namespace
+
+TORCH_LIBRARY_IMPL(pyg, Autograd, m) {
+  m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr"),
+         TORCH_FN(softmax_csr_autograd));
+}
+
+}  // namespace ops
+}  // namespace pyg
diff --git a/pyg_lib/csrc/ops/cpu/softmax_kernel.cpp b/pyg_lib/csrc/ops/cpu/softmax_kernel.cpp
index 88575af13..812d0abcf 100644
--- a/pyg_lib/csrc/ops/cpu/softmax_kernel.cpp
+++ b/pyg_lib/csrc/ops/cpu/softmax_kernel.cpp
@@ -248,7 +248,7 @@ at::Tensor softmax_csr_backward_kernel(const at::Tensor& out,
 }  // namespace
 
 TORCH_LIBRARY_IMPL(pyg, CPU, m) {
-  m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_forward"),
+  m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr"),
          TORCH_FN(softmax_csr_forward_kernel));
   m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_backward"),
          TORCH_FN(softmax_csr_backward_kernel));
diff --git a/pyg_lib/csrc/ops/softmax.cpp b/pyg_lib/csrc/ops/softmax.cpp
index 92512cfae..a5c6e14f9 100644
--- a/pyg_lib/csrc/ops/softmax.cpp
+++ b/pyg_lib/csrc/ops/softmax.cpp
@@ -7,20 +7,20 @@ namespace pyg {
 namespace ops {
 
 // Performs softmax operations for each group.
-PYG_API at::Tensor softmax_csr_forward(const at::Tensor& src,
-                                       const at::Tensor& ptr,
-                                       const int64_t dim) {
+PYG_API at::Tensor softmax_csr(const at::Tensor& src,
+                               const at::Tensor& ptr,
+                               const int64_t dim) {
   at::TensorArg src_arg{src, "src", 0};
   at::TensorArg ptr_arg{ptr, "ptr", 1};
-  at::CheckedFrom c{"softmax_forward"};
+  at::CheckedFrom c{"softmax_csr"};
 
   at::checkAllDefined(c, {src_arg, ptr_arg});
   at::checkContiguous(c, src_arg);
   at::checkContiguous(c, ptr_arg);
 
   static auto op = c10::Dispatcher::singleton()
-                       .findSchemaOrThrow("pyg::softmax_csr_forward", "")
-                       .typed<decltype(softmax_csr_forward)>();
+                       .findSchemaOrThrow("pyg::softmax_csr", "")
+                       .typed<decltype(softmax_csr)>();
   return op.call(src, ptr, dim);
 }
 
@@ -32,7 +32,7 @@ PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,
   at::TensorArg out_arg{out, "out", 0};
   at::TensorArg out_grad_arg{out_grad, "out_grad", 1};
   at::TensorArg ptr_arg{ptr, "ptr", 2};
-  at::CheckedFrom c{"softmax_backward"};
+  at::CheckedFrom c{"softmax_csr_backward"};
 
   at::checkAllDefined(c, {out_arg, out_grad_arg, ptr_arg});
   at::checkContiguous(c, out_arg);
@@ -47,7 +47,7 @@ PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,
 
 TORCH_LIBRARY_FRAGMENT(pyg, m) {
   m.def(
-      TORCH_SELECTIVE_SCHEMA("pyg::softmax_csr_forward(Tensor src, Tensor ptr, "
+      TORCH_SELECTIVE_SCHEMA("pyg::softmax_csr(Tensor src, Tensor ptr, "
                              "int dim=0) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA(
       "pyg::softmax_csr_backward(Tensor out, Tensor out_grad, "
diff --git a/pyg_lib/csrc/ops/softmax.h b/pyg_lib/csrc/ops/softmax.h
index f381ae825..6e0d480ab 100644
--- a/pyg_lib/csrc/ops/softmax.h
+++ b/pyg_lib/csrc/ops/softmax.h
@@ -7,9 +7,9 @@ namespace pyg {
 namespace ops {
 
 // Performs softmax operations for each group.
-PYG_API at::Tensor softmax_csr_forward(const at::Tensor& src,
-                                       const at::Tensor& ptr,
-                                       const int64_t dim = 0);
+PYG_API at::Tensor softmax_csr(const at::Tensor& src,
+                               const at::Tensor& ptr,
+                               const int64_t dim = 0);
 
 // Computes gradient for grouped softmax operations.
 PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,
diff --git a/pyg_lib/ops/__init__.py b/pyg_lib/ops/__init__.py
index 81b66c3f1..113e9bc80 100644
--- a/pyg_lib/ops/__init__.py
+++ b/pyg_lib/ops/__init__.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
 
 import torch
 import torch.utils._pytree as pytree
@@ -331,29 +331,6 @@ def index_sort(
     return torch.ops.pyg.index_sort(inputs, max_value)
 
 
-class Softmax(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        src: Tensor,
-        ptr: Tensor,
-        dim: int = 0,
-    ) -> Tensor:
-        out = torch.ops.pyg.softmax_csr_forward(src, ptr, dim)
-        ctx.save_for_backward(out, ptr)
-        ctx.dim = dim
-
-        return out
-
-    @staticmethod
-    def backward(ctx, out_grad: Tensor) -> Tuple[Union[Tensor, int]]:
-        out, ptr = ctx.saved_tensors
-        in_grad = torch.ops.pyg.softmax_csr_backward(out, out_grad, ptr,
-                                                     ctx.dim)
-
-        return in_grad, None, None
-
-
 def softmax_csr(
     src: Tensor,
     ptr: Tensor,
@@ -384,7 +361,7 @@ def softmax_csr(
         [0.7792, 0.3502, 0.1638, 0.2145]])
     """
     dim = dim + src.dim() if dim < 0 else dim
-    return Softmax.apply(src, ptr, dim)
+    return torch.ops.pyg.softmax_csr(src, ptr, dim)
 
 
 __all__ = [
diff --git a/test/csrc/ops/test_softmax.cpp b/test/csrc/ops/test_softmax.cpp
index 476cd51ee..c798f13a5 100644
--- a/test/csrc/ops/test_softmax.cpp
+++ b/test/csrc/ops/test_softmax.cpp
@@ -34,13 +34,13 @@ TEST_P(CPUTest, SoftmaxCSRForward) {
   const auto ptr = at::tensor({0, 3, 4, 7, 8}, at::kLong);
   const auto expected_out = softmax2D_ref_impl(src, ptr, dim);
 
-  const auto out = pyg::ops::softmax_csr_forward(src, ptr, dim);
+  const auto out = pyg::ops::softmax_csr(src, ptr, dim);
   EXPECT_EQ(expected_out.size(0), out.size(0));
   EXPECT_EQ(expected_out.size(1), out.size(1));
   EXPECT_TRUE(at::allclose(expected_out, out, 1e-04, 1e-04));
 }
 
-TEST_P(CPUTest, SoftmaxCSRBackward) {
+TEST_P(CPUTest, SoftmaxCSRAutogradBackward) {
   const auto dim = ::testing::TestWithParam<int64_t>::GetParam();
   const auto src = at::rand({8, 8});
   src.set_requires_grad(true);
@@ -48,11 +48,21 @@
   const auto out = softmax2D_ref_impl(src, ptr, dim);
   const auto out_grad = at::rand({8, 8});
 
+  // use softmax_csr_backward directly
   const auto in_grad = pyg::ops::softmax_csr_backward(out, out_grad, ptr, dim);
   out.backward(out_grad);
   EXPECT_EQ(src.grad().size(0), in_grad.size(0));
   EXPECT_EQ(src.grad().size(1), in_grad.size(1));
   EXPECT_TRUE(at::allclose(src.grad(), in_grad, 1e-04, 1e-04));
+
+  // use softmax backward via autograd module
+  const auto src2 = src.detach().clone();
+  src2.set_requires_grad(true);
+  const auto out2 = pyg::ops::softmax_csr(src2, ptr, dim);
+  out2.backward(out_grad);
+  EXPECT_EQ(src.grad().size(0), src2.grad().size(0));
+  EXPECT_EQ(src.grad().size(1), src2.grad().size(1));
+  EXPECT_TRUE(at::allclose(src.grad(), src2.grad(), 1e-04, 1e-04));
 }
 
 INSTANTIATE_TEST_SUITE_P(OpsTest,
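
For reference, a minimal usage sketch of the public Python entry point after this change, assuming a pyg-lib build that registers the `pyg::softmax_csr` operator; the shapes mirror the C++ test above and the values are illustrative. Autograd now runs through the C++ kernel registered under the `Autograd` dispatch key rather than the removed Python `Softmax` class:

    import torch
    from pyg_lib.ops import softmax_csr  # thin wrapper around torch.ops.pyg.softmax_csr

    # Four CSR groups along dim 0: rows [0, 3), [3, 4), [4, 7), [7, 8).
    src = torch.rand(8, 8, requires_grad=True)
    ptr = torch.tensor([0, 3, 4, 7, 8])

    out = softmax_csr(src, ptr, dim=0)
    # The backward pass dispatches to pyg::softmax_csr_backward through
    # the autograd kernel added in this diff.
    out.backward(torch.ones_like(out))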