Commit
Change softmax_csr_forward to softmax_csr
DamianSzwichtenberg committed Nov 20, 2023
1 parent 052fd8e commit 41a545c
Showing 5 changed files with 16 additions and 16 deletions.
pyg_lib/csrc/ops/autograd/softmax_kernel.cpp (4 changes: 2 additions & 2 deletions)
@@ -18,7 +18,7 @@ class SoftmaxCSR : public torch::autograd::Function<SoftmaxCSR> {
                                const int64_t dim) {
     at::AutoDispatchBelowADInplaceOrView g;
 
-    Variable out = softmax_csr_forward(src, ptr, dim);
+    Variable out = softmax_csr(src, ptr, dim);
     ctx->saved_data["dim"] = dim;
     ctx->save_for_backward({src, out, ptr});
 
@@ -52,7 +52,7 @@ at::Tensor softmax_csr_autograd(const at::Tensor& src,
 } // namespace
 
 TORCH_LIBRARY_IMPL(pyg, Autograd, m) {
-  m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_forward"),
+  m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr"),
          TORCH_FN(softmax_csr_autograd));
 }
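The change in this file is confined to the dispatched call and the registration string; the wrapper itself keeps its structure. For orientation, a simplified sketch of that wrapper pattern follows (an approximation rather than the verbatim repository code; the softmax_csr_backward argument order is assumed from its schema):

#include <torch/autograd.h>
// softmax_csr / softmax_csr_backward are declared in pyg_lib/csrc/ops/softmax.h.

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

class SoftmaxCSRSketch : public torch::autograd::Function<SoftmaxCSRSketch> {
 public:
  static variable_list forward(AutogradContext* ctx, const Variable& src,
                               const Variable& ptr, const int64_t dim) {
    at::AutoDispatchBelowADInplaceOrView g;     // skip autograd for the inner call
    Variable out = softmax_csr(src, ptr, dim);  // renamed public entry point
    ctx->saved_data["dim"] = dim;
    ctx->save_for_backward({src, out, ptr});
    return {out};
  }

  static variable_list backward(AutogradContext* ctx, variable_list grad_outs) {
    const auto saved = ctx->get_saved_variables();  // {src, out, ptr}
    const auto dim = ctx->saved_data["dim"].toInt();
    // Gradient w.r.t. src via the separately registered backward op
    // (argument order assumed).
    auto grad_src = softmax_csr_backward(saved[1], grad_outs[0], saved[2], dim);
    return {grad_src, Variable(), Variable()};
  }
};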
pyg_lib/csrc/ops/cpu/softmax_kernel.cpp (2 changes: 1 addition & 1 deletion)
@@ -248,7 +248,7 @@ at::Tensor softmax_csr_backward_kernel(const at::Tensor& out,
 } // namespace
 
 TORCH_LIBRARY_IMPL(pyg, CPU, m) {
-  m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_forward"),
+  m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr"),
          TORCH_FN(softmax_csr_forward_kernel));
   m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_backward"),
          TORCH_FN(softmax_csr_backward_kernel));
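Only the schema string changes here; the CPU kernel functions keep their internal names. Any additional backend would have to register against the same renamed schema, as in this hypothetical example (not part of this commit; the CUDA kernel named below is made up for illustration):

// Hypothetical registration for another dispatch key: the schema string must
// match the renamed "pyg::softmax_csr", while the kernel symbol is free to
// keep any internal name.
TORCH_LIBRARY_IMPL(pyg, CUDA, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr"),
         TORCH_FN(softmax_csr_forward_cuda_kernel));  // hypothetical kernel
}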
pyg_lib/csrc/ops/softmax.cpp (16 changes: 8 additions & 8 deletions)
@@ -7,20 +7,20 @@ namespace pyg {
 namespace ops {
 
 // Performs softmax operations for each group.
-PYG_API at::Tensor softmax_csr_forward(const at::Tensor& src,
-                                       const at::Tensor& ptr,
-                                       const int64_t dim) {
+PYG_API at::Tensor softmax_csr(const at::Tensor& src,
+                               const at::Tensor& ptr,
+                               const int64_t dim) {
   at::TensorArg src_arg{src, "src", 0};
   at::TensorArg ptr_arg{ptr, "ptr", 1};
-  at::CheckedFrom c{"softmax_forward"};
+  at::CheckedFrom c{"softmax_csr"};
 
   at::checkAllDefined(c, {src_arg, ptr_arg});
   at::checkContiguous(c, src_arg);
   at::checkContiguous(c, ptr_arg);
 
   static auto op = c10::Dispatcher::singleton()
-                       .findSchemaOrThrow("pyg::softmax_csr_forward", "")
-                       .typed<decltype(softmax_csr_forward)>();
+                       .findSchemaOrThrow("pyg::softmax_csr", "")
+                       .typed<decltype(softmax_csr)>();
   return op.call(src, ptr, dim);
 }
 
@@ -32,7 +32,7 @@ PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,
   at::TensorArg out_arg{out, "out", 0};
   at::TensorArg out_grad_arg{out_grad, "out_grad", 1};
   at::TensorArg ptr_arg{ptr, "ptr", 2};
-  at::CheckedFrom c{"softmax_backward"};
+  at::CheckedFrom c{"softmax_csr_backward"};
 
   at::checkAllDefined(c, {out_arg, out_grad_arg, ptr_arg});
   at::checkContiguous(c, out_arg);
@@ -47,7 +47,7 @@
 
 TORCH_LIBRARY_FRAGMENT(pyg, m) {
   m.def(
-      TORCH_SELECTIVE_SCHEMA("pyg::softmax_csr_forward(Tensor src, Tensor ptr, "
+      TORCH_SELECTIVE_SCHEMA("pyg::softmax_csr(Tensor src, Tensor ptr, "
                              "int dim=0) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA(
       "pyg::softmax_csr_backward(Tensor out, Tensor out_grad, "
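In this file the rename has to land in three places at once: the public C++ entry point, the findSchemaOrThrow lookup string, and the TORCH_LIBRARY_FRAGMENT schema. A minimal self-contained sketch of that define/register/lookup round trip, using a toy operator (pyg::example_scale is made up for illustration, not repository code):

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

// CPU kernel for the toy operator.
at::Tensor example_scale_cpu(const at::Tensor& x, const int64_t k) {
  return x * k;
}

// 1) Define the schema once; this is the name callers look up.
TORCH_LIBRARY_FRAGMENT(pyg, m) {
  m.def(TORCH_SELECTIVE_SCHEMA("pyg::example_scale(Tensor x, int k=1) -> Tensor"));
}

// 2) Register a kernel for each dispatch key under the same schema name.
TORCH_LIBRARY_IMPL(pyg, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::example_scale"), TORCH_FN(example_scale_cpu));
}

// 3) The public entry point resolves the op by the same name at runtime,
// so renaming the operator means updating every one of these strings.
at::Tensor example_scale(const at::Tensor& x, const int64_t k) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("pyg::example_scale", "")
                       .typed<decltype(example_scale)>();
  return op.call(x, k);
}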
pyg_lib/csrc/ops/softmax.h (6 changes: 3 additions & 3 deletions)
@@ -7,9 +7,9 @@ namespace pyg {
 namespace ops {
 
 // Performs softmax operations for each group.
-PYG_API at::Tensor softmax_csr_forward(const at::Tensor& src,
-                                       const at::Tensor& ptr,
-                                       const int64_t dim = 0);
+PYG_API at::Tensor softmax_csr(const at::Tensor& src,
+                               const at::Tensor& ptr,
+                               const int64_t dim = 0);
 
 // Computes gradient for grouped softmax operations.
 PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,
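With the header renamed, downstream C++ callers use pyg::ops::softmax_csr directly and still get the dim = 0 default from the declaration. An illustrative caller, with an assumed include path:

#include <ATen/ATen.h>

#include "pyg_lib/csrc/ops/softmax.h"  // assumed include path for the header above

// Softmax over each group of rows delimited by consecutive entries of ptr;
// dim defaults to 0 as declared in the header.
at::Tensor group_softmax(const at::Tensor& src, const at::Tensor& ptr) {
  return pyg::ops::softmax_csr(src, ptr);
}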
test/csrc/ops/test_softmax.cpp (4 changes: 2 additions & 2 deletions)
@@ -34,7 +34,7 @@ TEST_P(CPUTest, SoftmaxCSRForward) {
   const auto ptr = at::tensor({0, 3, 4, 7, 8}, at::kLong);
   const auto expected_out = softmax2D_ref_impl(src, ptr, dim);
 
-  const auto out = pyg::ops::softmax_csr_forward(src, ptr, dim);
+  const auto out = pyg::ops::softmax_csr(src, ptr, dim);
   EXPECT_EQ(expected_out.size(0), out.size(0));
   EXPECT_EQ(expected_out.size(1), out.size(1));
   EXPECT_TRUE(at::allclose(expected_out, out, 1e-04, 1e-04));
@@ -58,7 +58,7 @@ TEST_P(CPUTest, SoftmaxCSRAutogradBackward) {
   // use softmax backward via autograd module
   const auto src2 = src.detach().clone();
   src2.set_requires_grad(true);
-  const auto out2 = pyg::ops::softmax_csr_forward(src2, ptr, dim);
+  const auto out2 = pyg::ops::softmax_csr(src2, ptr, dim);
   out2.backward(out_grad);
   EXPECT_EQ(src.grad().size(0), src2.grad().size(0));
   EXPECT_EQ(src.grad().size(1), src2.grad().size(1));
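The tests check the renamed operator against a dense reference (softmax2D_ref_impl) and, in the autograd test, compare gradients from the autograd path with gradients from the reference path. A rough sketch of what a per-group reference for dim == 0 could look like (an assumption about the reference's semantics, not the repository's implementation):

#include <ATen/ATen.h>

// Per-group softmax reference for dim == 0 (assumed semantics, illustration
// only): softmax is computed independently over each row slice
// [ptr[i], ptr[i + 1]) of src.
at::Tensor softmax_csr_ref(const at::Tensor& src, const at::Tensor& ptr) {
  auto out = at::empty_like(src);
  const auto num_groups = ptr.numel() - 1;
  for (int64_t i = 0; i < num_groups; ++i) {
    const auto start = ptr[i].item<int64_t>();
    const auto end = ptr[i + 1].item<int64_t>();
    out.slice(/*dim=*/0, start, end)
        .copy_(at::softmax(src.slice(/*dim=*/0, start, end), /*dim=*/0));
  }
  return out;
}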
