diff --git a/test/test_dtype.py b/test/test_dtype.py index 216952cb3338f..acaaf1ccc560e 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -650,8 +650,15 @@ def test_functions(self): ]: t = func(Tensor([4.0, 3.0])).max() == func(Tensor([4.0, 3.0])) result.append(t.numpy().sum()) - assert all(result) +class TestTensorMethod(unittest.TestCase): + @given(strat.sampled_from(core_dtypes)) + def test_abs_diff(self, dt): + if dt == dtypes.bool or not is_dtype_supported(dt): return + a, b = Tensor([2], dtype=dt), Tensor([1], dtype=dt) + ret = (a - b).abs() + np.testing.assert_allclose(ret.numpy(), np.abs(a.numpy()-b.numpy())) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 1cb4b2657e5b0..00ebf5dccdf87 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1140,7 +1140,10 @@ def round(self: Tensor) -> Tensor: def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor: return self + (end - self) * weight def square(self): return self*self def clip(self, min_, max_): return self.maximum(min_).minimum(max_) - def abs(self): return self.relu() + (-self).relu() + def abs(self): + if dtypes.is_unsigned(self.dtype) or self.dtype == dtypes.bool: return self + if dtypes.is_int(self.dtype): return (self < 0).where(0, self) + return self.relu() + (-self).relu() def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype) def reciprocal(self): return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))