-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move
softmax_csr
autograd module to C++
- Loading branch information
1 parent
d23da87
commit 651045e
Showing
3 changed files
with
72 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#include "../softmax.h" | ||
|
||
#include <torch/autograd.h> | ||
|
||
namespace pyg { | ||
namespace ops { | ||
|
||
namespace { | ||
|
||
using torch::autograd::Variable; | ||
using torch::autograd::variable_list; | ||
|
||
class SoftmaxCSR : public torch::autograd::Function<SoftmaxCSR> { | ||
public: | ||
static variable_list forward(torch::autograd::AutogradContext* ctx, | ||
const Variable& src, | ||
const at::Tensor& ptr, | ||
const int64_t dim) { | ||
at::AutoDispatchBelowADInplaceOrView g; | ||
|
||
Variable out = softmax_csr_forward(src, ptr, dim); | ||
ctx->saved_data["dim"] = dim; | ||
ctx->save_for_backward({src, out, ptr}); | ||
|
||
return {out}; | ||
} | ||
|
||
static variable_list backward(torch::autograd::AutogradContext* ctx, | ||
variable_list out_grads) { | ||
const auto out_grad = out_grads[0]; | ||
const auto saved = ctx->get_saved_variables(); | ||
const auto src = saved[0]; | ||
const auto out = saved[1]; | ||
const auto ptr = saved[2]; | ||
const auto dim = ctx->saved_data["dim"].toInt(); | ||
|
||
auto src_grad = Variable(); | ||
if (torch::autograd::any_variable_requires_grad({src})) { | ||
src_grad = softmax_csr_backward(out, out_grad, ptr, dim); | ||
} | ||
|
||
return {src_grad, Variable(), Variable()}; | ||
} | ||
}; | ||
|
||
at::Tensor softmax_csr_autograd(const at::Tensor& src, | ||
const at::Tensor& ptr, | ||
const int64_t dim) { | ||
return SoftmaxCSR::apply(src, ptr, dim)[0]; | ||
} | ||
|
||
} // namespace | ||
|
||
TORCH_LIBRARY_IMPL(pyg, Autograd, m) { | ||
m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_forward"), | ||
TORCH_FN(softmax_csr_autograd)); | ||
} | ||
|
||
} // namespace ops | ||
} // namespace pyg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters