From d1f738a4952713418d7c8c5caad3e3b239431584 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 15 May 2024 11:03:34 -0700 Subject: [PATCH] simpler abs and sign --- tinygrad/tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index ab0b784007db..3acd10be9f21 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1151,8 +1151,8 @@ def round(self: Tensor) -> Tensor: def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor: return self + (end - self) * weight def square(self): return self*self def clip(self, min_, max_): return self.maximum(min_).minimum(max_) - def abs(self): return self.relu() + (-self).relu() - def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype) + def abs(self): return (self < 0).detach().where(-self, self) + def sign(self): return (self == 0).detach().where(0, self / self.abs()).cast(self.dtype) def reciprocal(self): return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype))) # ***** activation functions (unary) *****