forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SoftShrink.cu
44 lines (37 loc) · 937 Bytes
/
SoftShrink.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <THCUNN/THCUNN.h>
#include <TH/THHalf.h>
#include <THC/THCNumerics.cuh>
#include <THC/THCApply.cuh>
/*
 * Element-wise forward functor for SoftShrink.
 *
 * For each element:
 *   out = in - lambda   if in >  lambda
 *   out = in + lambda   if in < -lambda
 *   out = 0             otherwise
 *
 * A NaN input yields 0, because both comparisons evaluate to false.
 * Presumably invoked by THC's pointwise-apply helpers from
 * THCUNN/generic/SoftShrink.cu (included at the bottom of this file) —
 * confirm against that file.
 */
template <typename T>
struct SoftShrinkUpdateOutput
{
  // Shrinkage threshold; fixed at construction time.
  const T lambda_;

  SoftShrinkUpdateOutput(T lambda)
    : lambda_(lambda)
  {}

  // The functor holds no mutable state, so the call operator is
  // const-qualified (matching SoftShrinkUpdateGradInput). The input element
  // is only read, hence `const T *`; plain `T *` arguments still bind.
  __device__ __forceinline__ void operator()(T *out, const T *in) const
  {
    const T x = *in;
    if (x > lambda_) *out = x - lambda_;
    else if (x < -lambda_) *out = x + lambda_;
    else *out = ScalarConvert<int, T>::to(0);
  }
};
/*
 * Element-wise backward functor for SoftShrink.
 *
 * The gradient passes through unchanged wherever the forward input lay
 * strictly outside [-lambda, lambda]. Everywhere else, including NaN inputs
 * (both comparisons are false), the input gradient is set to 0.
 */
template <typename T>
struct SoftShrinkUpdateGradInput
{
  // Shrinkage threshold; must match the value used in the forward pass.
  const T lambda_;

  SoftShrinkUpdateGradInput(T lambda)
    : lambda_(lambda)
  {}

  __device__ __forceinline__ void operator()(T *gradInput, T *input, T *gradOutput) const
  {
    const T value = *input;
    // Written as an explicit OR of strict comparisons (not a negated range
    // test) so that NaN inputs keep mapping to a zero gradient.
    const bool outsideBand = (value > lambda_) || (value < -lambda_);
    *gradInput = outsideBand ? *gradOutput : ScalarConvert<int, T>::to(0);
  }
};
#include <THCUNN/generic/SoftShrink.cu>
#include <THC/THCGenerateFloatTypes.h>