add SELU to Tensor (tinygrad#7993)
* add selu

* more clean ups
geohotstan authored Dec 2, 2024
1 parent 146e1ca commit 0a2e10b
Showing 4 changed files with 18 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/tensor/elementwise.md
@@ -37,6 +37,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
 ::: tinygrad.Tensor.hardsigmoid
 ::: tinygrad.Tensor.elu
 ::: tinygrad.Tensor.celu
+::: tinygrad.Tensor.selu
 ::: tinygrad.Tensor.swish
 ::: tinygrad.Tensor.silu
 ::: tinygrad.Tensor.relu6
3 changes: 1 addition & 2 deletions extra/onnx_ops.py
@@ -8,7 +8,7 @@
 
 tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan","Relu",
                   "Sigmoid", "MatMul", "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign",
-                  "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Xor", "Round", "Erf"}
+                  "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf"}
 
 # **************** Free Ops ****************
 
@@ -44,7 +44,6 @@ def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, v
 
 def HardSigmoid(x: Tensor, alpha=0.2, beta=0.5): return (alpha*x + beta).clip(0, 1)
 def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
-def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
 def PRelu(X:Tensor, slope:Tensor):
   slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
   return (X > 0).where(X, X * slope)
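
With `"Selu"` added to `tensor_methods`, the ONNX `Selu` op can be served by `Tensor.selu` directly, so the standalone `Selu` helper above is removed. As a sanity check, a minimal NumPy sketch (not part of the commit; constants taken from the new default arguments) showing that the deleted formulation and the new one compute the same function:

```python
import numpy as np

alpha, gamma = 1.67326, 1.0507  # truncated constants, matching Tensor.selu's defaults

def selu_removed(x):
  # the deleted ONNX helper: gamma * (relu(x) - relu(-alpha*exp(x) + alpha))
  return gamma * (np.maximum(x, 0) - np.maximum(-alpha * np.exp(x) + alpha, 0))

def selu_tensor(x):
  # the new Tensor.selu: gamma * where(x >= 0, x, alpha * (exp(x) - 1))
  return gamma * np.where(x >= 0, x, alpha * (np.exp(x) - 1))

x = np.linspace(-3, 3, 101)
assert np.allclose(selu_removed(x), selu_tensor(x))
```

For x >= 0 the `relu(-alpha*exp(x) + alpha)` term is zero, and for x < 0 it reduces to `-alpha*(exp(x) - 1)`, so the two expressions agree on both branches.
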
3 changes: 3 additions & 0 deletions test/test_ops.py
@@ -697,6 +697,9 @@ def test_celu(self):
     for val in range(1, 5):
       helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
       helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
+  def test_selu(self):
+    helper_test_op([(45,65)], torch.nn.functional.selu, Tensor.selu)
+    helper_test_op([()], torch.nn.functional.selu, Tensor.selu)
 
   def test_abs(self):
     helper_test_op([(45,65)], torch.abs, Tensor.abs)
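
`helper_test_op` runs the tinygrad op against the torch reference on random tensors of the given shapes (and also exercises backward). A rough standalone version of the forward comparison, assuming torch and tinygrad are both installed; tolerances are loose here because `Tensor.selu` uses truncated constants while torch uses the full-precision ones:

```python
import numpy as np
import torch
from tinygrad import Tensor

x = np.random.randn(45, 65).astype(np.float32)
out_torch = torch.nn.functional.selu(torch.from_numpy(x)).numpy()
out_tiny = Tensor(x).selu().numpy()
assert np.allclose(out_torch, out_tiny, rtol=1e-4, atol=1e-4)
```
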
13 changes: 13 additions & 0 deletions tinygrad/tensor.py
@@ -2688,6 +2688,19 @@ def celu(self, alpha=1.0):
     """
     return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
 
+  def selu(self, alpha=1.67326, gamma=1.0507):
+    """
+    Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
+
+    - Described: https://paperswithcode.com/method/selu
+    - Paper: https://arxiv.org/abs/1706.02515v5
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).selu().numpy())
+    ```
+    """
+    return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
+
   def swish(self):
     """
     See `.silu()`
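
For reference, the function the new method computes, with the truncated default constants (the SELU paper derives them to higher precision so that activations stay approximately zero-mean, unit-variance):

$$
\mathrm{SELU}(x) = \gamma \cdot \begin{cases} x & x \ge 0 \\ \alpha\,(e^{x} - 1) & x < 0 \end{cases},
\qquad \alpha \approx 1.67326,\quad \gamma \approx 1.0507
$$

The comparison mask `(self >= 0)` is `.detach()`ed so that gradients flow only through the two selected branches of the `where`.
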
