diff --git a/test/test_ops.py b/test/test_ops.py
index c56a4dbf5482..87f66e8b7fe8 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -487,6 +487,8 @@ def test_celu(self):
   def test_abs(self):
     helper_test_op([(45,65)], torch.abs, Tensor.abs)
     helper_test_op([()], torch.abs, Tensor.abs)
+  def test_abs_exact(self):
+    helper_test_op(None, torch.abs, Tensor.abs, vals=[[-1.,0,1]])
 
   def test_log(self):
     helper_test_op([(45,65)], torch.log, Tensor.log)
@@ -505,9 +507,14 @@ def test_exp2(self):
   def test_sign(self):
     helper_test_op([(45,65)], torch.sign, Tensor.sign)
     helper_test_op([()], torch.sign, Tensor.sign)
+  def test_sign_exact(self):
+    helper_test_op(None, torch.sign, Tensor.sign, vals=[[-1.,0,1]])
+
   def test_softsign(self):
     helper_test_op([(45,65)], torch.nn.functional.softsign, Tensor.softsign)
     helper_test_op([()], torch.nn.functional.softsign, Tensor.softsign)
+  def test_softsign_exact(self):
+    helper_test_op(None, torch.nn.functional.softsign, Tensor.softsign, vals=[[-1.,0,1]])
 
   def test_sigmoid(self):
     helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid)
diff --git a/tinygrad/function.py b/tinygrad/function.py
index 462139ecd38e..189f82bc2643 100644
--- a/tinygrad/function.py
+++ b/tinygrad/function.py
@@ -86,6 +86,13 @@ def forward(self, x:LazyBuffer) -> LazyBuffer:
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)
 
+class Sign(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
+    return x.e(BinaryOps.CMPEQ, x.const(0)).e(TernaryOps.WHERE, x.const(0),
+      x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)))
+  # backward always returns 0 to match torch
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)
+
 # ************* binary ops *************
 
 class Less(Function):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index ab0b784007db..c768b57879cf 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1151,8 +1151,8 @@ def round(self: Tensor) -> Tensor:
   def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor: return self + (end - self) * weight
   def square(self): return self*self
   def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
-  def abs(self): return self.relu() + (-self).relu()
-  def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype)
+  def sign(self): return F.Sign.apply(self)
+  def abs(self): return self * self.sign()
   def reciprocal(self): return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
 
   # ***** activation functions (unary) *****
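
For reference, a minimal usage sketch of the semantics this change introduces (assuming a tinygrad checkout with this patch applied and that Tensor.numpy() is available; expected values mirror the new *_exact tests, where sign is exact at zero and abs is defined as x * sign(x)):

from tinygrad.tensor import Tensor

t = Tensor([-1.0, 0.0, 1.0])
print(t.sign().numpy())  # expected [-1., 0., 1.], matching torch.sign exactly at 0
print(t.abs().numpy())   # expected [1., 0., 1.]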