Move softmax_csr autograd module to C++ (#282)
Co-authored-by: Matthias Fey <[email protected]>
DamianSzwichtenberg and rusty1s authored Nov 20, 2023
1 parent d23da87 commit 28f84eb
Showing 7 changed files with 87 additions and 40 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.4.0] - 2023-MM-DD
### Added
- Added `softmax_csr` implementation ([#264](https://github.com/pyg-team/pyg-lib/pull/264))
- Added `softmax_csr` implementation ([#264](https://github.com/pyg-team/pyg-lib/pull/264), [#282](https://github.com/pyg-team/pyg-lib/pull/282))
- Added support for edge-level sampling ([#280](https://github.com/pyg-team/pyg-lib/pull/280))
- Added support for `bfloat16` data type in `segment_matmul` and `grouped_matmul` (CPU only) ([#272](https://github.com/pyg-team/pyg-lib/pull/272))
### Changed
60 changes: 60 additions & 0 deletions pyg_lib/csrc/ops/autograd/softmax_kernel.cpp
@@ -0,0 +1,60 @@
#include "../softmax.h"

#include <torch/autograd.h>

namespace pyg {
namespace ops {

namespace {

using torch::autograd::Variable;
using torch::autograd::variable_list;

class SoftmaxCSR : public torch::autograd::Function<SoftmaxCSR> {
public:
static variable_list forward(torch::autograd::AutogradContext* ctx,
const Variable& src,
const at::Tensor& ptr,
const int64_t dim) {
at::AutoDispatchBelowADInplaceOrView g;

Variable out = softmax_csr(src, ptr, dim);
ctx->saved_data["dim"] = dim;
ctx->save_for_backward({src, out, ptr});

return {out};
}

static variable_list backward(torch::autograd::AutogradContext* ctx,
variable_list out_grads) {
const auto out_grad = out_grads[0];
const auto saved = ctx->get_saved_variables();
const auto src = saved[0];
const auto out = saved[1];
const auto ptr = saved[2];
const auto dim = ctx->saved_data["dim"].toInt();

auto src_grad = Variable();
if (torch::autograd::any_variable_requires_grad({src})) {
src_grad = softmax_csr_backward(out, out_grad, ptr, dim);
}

return {src_grad, Variable(), Variable()};
}
};

at::Tensor softmax_csr_autograd(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim) {
return SoftmaxCSR::apply(src, ptr, dim)[0];
}

} // namespace

TORCH_LIBRARY_IMPL(pyg, Autograd, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr"),
TORCH_FN(softmax_csr_autograd));
}

} // namespace ops
} // namespace pyg
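
With `SoftmaxCSR` registered on the `Autograd` dispatch key, gradients now flow through `torch.ops.pyg.softmax_csr` without any Python-side wrapper. A minimal sketch of exercising that path from Python — assuming a build of `pyg-lib` that includes this commit, so that importing `pyg_lib` loads the extension and registers the op:

```python
import torch
import pyg_lib  # noqa: F401  # loads the C++ extension and registers torch.ops.pyg.*

src = torch.rand(8, 8, requires_grad=True)
ptr = torch.tensor([0, 3, 4, 7, 8])

# Forward dispatches to the CPU kernel; backward runs through the
# C++ SoftmaxCSR autograd function defined above.
out = torch.ops.pyg.softmax_csr(src, ptr, 0)
out.backward(torch.rand_like(out))
print(src.grad.shape)  # torch.Size([8, 8])
```
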
2 changes: 1 addition & 1 deletion pyg_lib/csrc/ops/cpu/softmax_kernel.cpp
@@ -248,7 +248,7 @@ at::Tensor softmax_csr_backward_kernel(const at::Tensor& out,
} // namespace

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_forward"),
m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr"),
TORCH_FN(softmax_csr_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_backward"),
TORCH_FN(softmax_csr_backward_kernel));
16 changes: 8 additions & 8 deletions pyg_lib/csrc/ops/softmax.cpp
@@ -7,20 +7,20 @@ namespace pyg {
namespace ops {

// Performs softmax operations for each group.
PYG_API at::Tensor softmax_csr_forward(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim) {
PYG_API at::Tensor softmax_csr(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim) {
at::TensorArg src_arg{src, "src", 0};
at::TensorArg ptr_arg{ptr, "ptr", 1};
at::CheckedFrom c{"softmax_forward"};
at::CheckedFrom c{"softmax_csr"};

at::checkAllDefined(c, {src_arg, ptr_arg});
at::checkContiguous(c, src_arg);
at::checkContiguous(c, ptr_arg);

static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::softmax_csr_forward", "")
.typed<decltype(softmax_csr_forward)>();
.findSchemaOrThrow("pyg::softmax_csr", "")
.typed<decltype(softmax_csr)>();
return op.call(src, ptr, dim);
}

@@ -32,7 +32,7 @@ PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,
at::TensorArg out_arg{out, "out", 0};
at::TensorArg out_grad_arg{out_grad, "out_grad", 1};
at::TensorArg ptr_arg{ptr, "ptr", 2};
at::CheckedFrom c{"softmax_backward"};
at::CheckedFrom c{"softmax_csr_backward"};

at::checkAllDefined(c, {out_arg, out_grad_arg, ptr_arg});
at::checkContiguous(c, out_arg);
@@ -47,7 +47,7 @@ PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(
TORCH_SELECTIVE_SCHEMA("pyg::softmax_csr_forward(Tensor src, Tensor ptr, "
TORCH_SELECTIVE_SCHEMA("pyg::softmax_csr(Tensor src, Tensor ptr, "
"int dim=0) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::softmax_csr_backward(Tensor out, Tensor out_grad, "
6 changes: 3 additions & 3 deletions pyg_lib/csrc/ops/softmax.h
@@ -7,9 +7,9 @@ namespace pyg {
namespace ops {

// Performs softmax operations for each group.
PYG_API at::Tensor softmax_csr_forward(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim = 0);
PYG_API at::Tensor softmax_csr(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim = 0);

// Computes gradient for grouped softmax operations.
PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,
27 changes: 2 additions & 25 deletions pyg_lib/ops/__init__.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
@@ -331,29 +331,6 @@ def index_sort(
return torch.ops.pyg.index_sort(inputs, max_value)


class Softmax(torch.autograd.Function):
@staticmethod
def forward(
ctx,
src: Tensor,
ptr: Tensor,
dim: int = 0,
) -> Tensor:
out = torch.ops.pyg.softmax_csr_forward(src, ptr, dim)
ctx.save_for_backward(out, ptr)
ctx.dim = dim

return out

@staticmethod
def backward(ctx, out_grad: Tensor) -> Tuple[Union[Tensor, int]]:
out, ptr = ctx.saved_tensors
in_grad = torch.ops.pyg.softmax_csr_backward(out, out_grad, ptr,
ctx.dim)

return in_grad, None, None


def softmax_csr(
src: Tensor,
ptr: Tensor,
@@ -384,7 +361,7 @@ def softmax_csr(
[0.7792, 0.3502, 0.1638, 0.2145]])
"""
dim = dim + src.dim() if dim < 0 else dim
return Softmax.apply(src, ptr, dim)
return torch.ops.pyg.softmax_csr(src, ptr, dim)


__all__ = [
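
On the Python side, `softmax_csr` is now a thin wrapper: it only normalizes a negative `dim` and forwards to the registered op, with autograd handled entirely in C++. A brief usage sketch with illustrative values (the docstring example above is truncated in this diff):

```python
import torch
from pyg_lib.ops import softmax_csr

src = torch.rand(4, 4, requires_grad=True)
ptr = torch.tensor([0, 4])

# A negative dim is normalized in Python before the call reaches
# torch.ops.pyg.softmax_csr; gradients come from the C++ autograd function.
out = softmax_csr(src, ptr, dim=-2)
out.sum().backward()
```
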
14 changes: 12 additions & 2 deletions test/csrc/ops/test_softmax.cpp
@@ -34,25 +34,35 @@ TEST_P(CPUTest, SoftmaxCSRForward) {
const auto ptr = at::tensor({0, 3, 4, 7, 8}, at::kLong);
const auto expected_out = softmax2D_ref_impl(src, ptr, dim);

const auto out = pyg::ops::softmax_csr_forward(src, ptr, dim);
const auto out = pyg::ops::softmax_csr(src, ptr, dim);
EXPECT_EQ(expected_out.size(0), out.size(0));
EXPECT_EQ(expected_out.size(1), out.size(1));
EXPECT_TRUE(at::allclose(expected_out, out, 1e-04, 1e-04));
}

TEST_P(CPUTest, SoftmaxCSRBackward) {
TEST_P(CPUTest, SoftmaxCSRAutogradBackward) {
const auto dim = ::testing::TestWithParam<int64_t>::GetParam();
const auto src = at::rand({8, 8});
src.set_requires_grad(true);
const auto ptr = at::tensor({0, 3, 4, 7, 8}, at::kLong);
const auto out = softmax2D_ref_impl(src, ptr, dim);
const auto out_grad = at::rand({8, 8});

// use softmax_csr_backward directly
const auto in_grad = pyg::ops::softmax_csr_backward(out, out_grad, ptr, dim);
out.backward(out_grad);
EXPECT_EQ(src.grad().size(0), in_grad.size(0));
EXPECT_EQ(src.grad().size(1), in_grad.size(1));
EXPECT_TRUE(at::allclose(src.grad(), in_grad, 1e-04, 1e-04));

// use softmax backward via autograd module
const auto src2 = src.detach().clone();
src2.set_requires_grad(true);
const auto out2 = pyg::ops::softmax_csr(src2, ptr, dim);
out2.backward(out_grad);
EXPECT_EQ(src.grad().size(0), src2.grad().size(0));
EXPECT_EQ(src.grad().size(1), src2.grad().size(1));
EXPECT_TRUE(at::allclose(src.grad(), src2.grad(), 1e-04, 1e-04));
}

INSTANTIATE_TEST_SUITE_P(OpsTest,
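
The check added to the C++ test can be sketched equivalently in Python: the gradient produced by autograd through `pyg::softmax_csr` should match a direct call to `pyg::softmax_csr_backward`. A rough equivalent, assuming the ops are registered by importing `pyg_lib`:

```python
import torch
import pyg_lib  # noqa: F401

dim = 0
src = torch.rand(8, 8, requires_grad=True)
ptr = torch.tensor([0, 3, 4, 7, 8])
out_grad = torch.rand(8, 8)

out = torch.ops.pyg.softmax_csr(src, ptr, dim)
# Gradient from the standalone backward kernel ...
expected = torch.ops.pyg.softmax_csr_backward(out.detach(), out_grad, ptr, dim)
# ... should match the gradient produced by the C++ autograd function.
out.backward(out_grad)
assert torch.allclose(src.grad, expected, rtol=1e-4, atol=1e-4)
```
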
