Move softmax_csr autograd module to C++
DamianSzwichtenberg committed Nov 20, 2023
1 parent d23da87 commit 651045e
Showing 3 changed files with 72 additions and 25 deletions.
60 changes: 60 additions & 0 deletions pyg_lib/csrc/ops/autograd/softmax_kernel.cpp
@@ -0,0 +1,60 @@
#include "../softmax.h"

#include <torch/autograd.h>

namespace pyg {
namespace ops {

namespace {

using torch::autograd::Variable;
using torch::autograd::variable_list;

class SoftmaxCSR : public torch::autograd::Function<SoftmaxCSR> {
public:
static variable_list forward(torch::autograd::AutogradContext* ctx,
const Variable& src,
const at::Tensor& ptr,
const int64_t dim) {
at::AutoDispatchBelowADInplaceOrView g;

Variable out = softmax_csr_forward(src, ptr, dim);
ctx->saved_data["dim"] = dim;
ctx->save_for_backward({src, out, ptr});

return {out};
}

static variable_list backward(torch::autograd::AutogradContext* ctx,
variable_list out_grads) {
const auto out_grad = out_grads[0];
const auto saved = ctx->get_saved_variables();
const auto src = saved[0];
const auto out = saved[1];
const auto ptr = saved[2];
const auto dim = ctx->saved_data["dim"].toInt();

auto src_grad = Variable();
if (torch::autograd::any_variable_requires_grad({src})) {
src_grad = softmax_csr_backward(out, out_grad, ptr, dim);
}

return {src_grad, Variable(), Variable()};
}
};

at::Tensor softmax_csr_autograd(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim) {
return SoftmaxCSR::apply(src, ptr, dim)[0];
}

} // namespace

TORCH_LIBRARY_IMPL(pyg, Autograd, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_forward"),
TORCH_FN(softmax_csr_autograd));
}

} // namespace ops
} // namespace pyg
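
Not part of the commit: a minimal C++ usage sketch of what the Autograd-key registration above enables. Once pyg::softmax_csr_forward has an implementation registered under the Autograd key, gradients flow through the op with no Python-side autograd.Function. The include path and the free-standing demo function below are illustrative assumptions, not pyg_lib code.

// Sketch only -- assumes pyg_lib is linked and its headers are reachable;
// the include path is a guess.
#include <ATen/ATen.h>

#include "pyg_lib/csrc/ops/softmax.h"

void softmax_csr_autograd_demo() {
  auto src = at::rand({8, 8});
  src.set_requires_grad(true);
  const auto ptr = at::tensor({0, 3, 4, 7, 8}, at::kLong);

  // Because src requires grad, the dispatcher routes this call through
  // SoftmaxCSR::forward registered under the Autograd key above.
  const auto out = pyg::ops::softmax_csr_forward(src, ptr, /*dim=*/0);

  // backward() then runs SoftmaxCSR::backward, which delegates to
  // softmax_csr_backward.
  out.backward(at::ones_like(out));
  TORCH_CHECK(src.grad().sizes() == src.sizes());
}

The updated test below exercises this same path and compares it against calling softmax_csr_backward directly.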
25 changes: 1 addition & 24 deletions pyg_lib/ops/__init__.py
@@ -331,29 +331,6 @@ def index_sort(
    return torch.ops.pyg.index_sort(inputs, max_value)


class Softmax(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        src: Tensor,
        ptr: Tensor,
        dim: int = 0,
    ) -> Tensor:
        out = torch.ops.pyg.softmax_csr_forward(src, ptr, dim)
        ctx.save_for_backward(out, ptr)
        ctx.dim = dim

        return out

    @staticmethod
    def backward(ctx, out_grad: Tensor) -> Tuple[Union[Tensor, int]]:
        out, ptr = ctx.saved_tensors
        in_grad = torch.ops.pyg.softmax_csr_backward(out, out_grad, ptr,
                                                     ctx.dim)

        return in_grad, None, None


def softmax_csr(
    src: Tensor,
    ptr: Tensor,
@@ -384,7 +361,7 @@ def softmax_csr(
                [0.7792, 0.3502, 0.1638, 0.2145]])
    """
    dim = dim + src.dim() if dim < 0 else dim
    return Softmax.apply(src, ptr, dim)
    return torch.ops.pyg.softmax_csr_forward(src, ptr, dim)


__all__ = [
12 changes: 11 additions & 1 deletion test/csrc/ops/test_softmax.cpp
@@ -40,19 +40,29 @@ TEST_P(CPUTest, SoftmaxCSRForward) {
  EXPECT_TRUE(at::allclose(expected_out, out, 1e-04, 1e-04));
}

TEST_P(CPUTest, SoftmaxCSRBackward) {
TEST_P(CPUTest, SoftmaxCSRAutogradBackward) {
  const auto dim = ::testing::TestWithParam<int64_t>::GetParam();
  const auto src = at::rand({8, 8});
  src.set_requires_grad(true);
  const auto ptr = at::tensor({0, 3, 4, 7, 8}, at::kLong);
  const auto out = softmax2D_ref_impl(src, ptr, dim);
  const auto out_grad = at::rand({8, 8});

  // use softmax_csr_backward directly
  const auto in_grad = pyg::ops::softmax_csr_backward(out, out_grad, ptr, dim);
  out.backward(out_grad);
  EXPECT_EQ(src.grad().size(0), in_grad.size(0));
  EXPECT_EQ(src.grad().size(1), in_grad.size(1));
  EXPECT_TRUE(at::allclose(src.grad(), in_grad, 1e-04, 1e-04));

  // use softmax backward via autograd module
  const auto src2 = src.detach().clone();
  src2.set_requires_grad(true);
  const auto out2 = pyg::ops::softmax_csr_forward(src2, ptr, dim);
  out2.backward(out_grad);
  EXPECT_EQ(src.grad().size(0), src2.grad().size(0));
  EXPECT_EQ(src.grad().size(1), src2.grad().size(1));
  EXPECT_TRUE(at::allclose(src.grad(), src2.grad(), 1e-04, 1e-04));
}

INSTANTIATE_TEST_SUITE_P(OpsTest,
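
For reference only (not pyg_lib code): the gradient that both allclose checks above agree on is the standard softmax vector-Jacobian product applied independently per CSR segment. The sketch below spells this out for the dim == 0 case used with ptr = {0, 3, 4, 7, 8}; the helper name softmax_csr_backward_ref is hypothetical.

// Illustrative reference, dim == 0 only: each row segment [ptr[i], ptr[i+1])
// is an independent softmax, so
//   src_grad = out * (out_grad - sum_over_segment(out * out_grad)).
#include <ATen/ATen.h>

at::Tensor softmax_csr_backward_ref(const at::Tensor& out,
                                    const at::Tensor& out_grad,
                                    const at::Tensor& ptr) {
  auto src_grad = at::empty_like(out);
  const auto ptr_acc = ptr.accessor<int64_t, 1>();
  for (int64_t i = 0; i < ptr.numel() - 1; ++i) {
    const auto start = ptr_acc[i];
    const auto len = ptr_acc[i + 1] - start;
    const auto o = out.narrow(/*dim=*/0, start, len);
    const auto g = out_grad.narrow(/*dim=*/0, start, len);
    // Reduce over the segment's rows, then broadcast back across them.
    const auto dot = (o * g).sum(/*dim=*/0, /*keepdim=*/true);
    src_grad.narrow(/*dim=*/0, start, len).copy_(o * (g - dot));
  }
  return src_grad;
}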
