From 0a2e10be1d944a33c8bfa2e46d146e90c0116c4f Mon Sep 17 00:00:00 2001
From: geohotstan <135171913+geohotstan@users.noreply.github.com>
Date: Mon, 2 Dec 2024 23:04:01 +0800
Subject: [PATCH] add SELU to Tensor (#7993)

* add selu

* more clean ups
---
 docs/tensor/elementwise.md |  1 +
 extra/onnx_ops.py          |  3 +--
 test/test_ops.py           |  3 +++
 tinygrad/tensor.py         | 13 +++++++++++++
 4 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/docs/tensor/elementwise.md b/docs/tensor/elementwise.md
index fc5bd70d4e735..3d3858ad79281 100644
--- a/docs/tensor/elementwise.md
+++ b/docs/tensor/elementwise.md
@@ -37,6 +37,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
 ::: tinygrad.Tensor.hardsigmoid
 ::: tinygrad.Tensor.elu
 ::: tinygrad.Tensor.celu
+::: tinygrad.Tensor.selu
 ::: tinygrad.Tensor.swish
 ::: tinygrad.Tensor.silu
 ::: tinygrad.Tensor.relu6
diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py
index 10d58b42f48fe..ae5afedca2600 100644
--- a/extra/onnx_ops.py
+++ b/extra/onnx_ops.py
@@ -8,7 +8,7 @@
 
 tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
                   "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign",
-                  "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Xor", "Round", "Erf"}
+                  "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf"}
 
 # **************** Free Ops ****************
 
@@ -44,7 +44,6 @@ def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, v
 def HardSigmoid(x: Tensor, alpha=0.2, beta=0.5): return (alpha*x + beta).clip(0, 1)
 def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
-def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
 def PRelu(X:Tensor, slope:Tensor):
   slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
   return (X > 0).where(X, X * slope)
 
diff --git a/test/test_ops.py b/test/test_ops.py
index 2dd89ca7e3a13..a376eea19af2a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -697,6 +697,9 @@ def test_celu(self):
     for val in range(1, 5):
       helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
       helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
+  def test_selu(self):
+    helper_test_op([(45,65)], torch.nn.functional.selu, Tensor.selu)
+    helper_test_op([()], torch.nn.functional.selu, Tensor.selu)
 
   def test_abs(self):
     helper_test_op([(45,65)], torch.abs, Tensor.abs)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 30a1f5f60c036..24f076cc41f39 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2688,6 +2688,19 @@ def celu(self, alpha=1.0):
     """
     return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
 
+  def selu(self, alpha=1.67326, gamma=1.0507):
+    """
+    Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
+
+    - Described: https://paperswithcode.com/method/selu
+    - Paper: https://arxiv.org/abs/1706.02515v5
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).selu().numpy())
+    ```
+    """
+    return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
+
   def swish(self):
     """
     See `.silu()`
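
For context (not part of the patch above), here is a minimal sketch of how the new `Tensor.selu` could be sanity-checked against the closed-form SELU definition, selu(x) = gamma * (x if x >= 0 else alpha * (exp(x) - 1)). It assumes tinygrad is installed and reuses the default `alpha`/`gamma` constants introduced by the patch; the helper name `selu_reference` is illustrative only.

```python
# Illustrative sketch: compare Tensor.selu against a plain-Python SELU reference.
import math
from tinygrad import Tensor

def selu_reference(x, alpha=1.67326, gamma=1.0507):
  # Closed-form SELU with the same default constants as the patch.
  return gamma * (x if x >= 0 else alpha * (math.exp(x) - 1))

xs = [-3., -2., -1., 0., 1., 2., 3.]
out = Tensor(xs).selu().numpy()
for x, y in zip(xs, out):
  assert abs(y - selu_reference(x)) < 1e-4, (x, y, selu_reference(x))
print(out)
```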