Skip to content

Commit

Permalink
fix abs of diff of uint
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz committed May 3, 2024
1 parent c736851 commit f4d51ad
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
9 changes: 8 additions & 1 deletion test/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,8 +650,15 @@ def test_functions(self):
]:
t = func(Tensor([4.0, 3.0])).max() == func(Tensor([4.0, 3.0]))
result.append(t.numpy().sum())

assert all(result)

class TestTensorMethod(unittest.TestCase):
@given(strat.sampled_from(core_dtypes))
def test_abs_diff(self, dt):
if dt == dtypes.bool or not is_dtype_supported(dt): return
a, b = Tensor([2], dtype=dt), Tensor([1], dtype=dt)
ret = (a - b).abs()
np.testing.assert_allclose(ret.numpy(), np.abs(a.numpy()-b.numpy()))

if __name__ == '__main__':
unittest.main()
5 changes: 4 additions & 1 deletion tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,10 @@ def round(self: Tensor) -> Tensor:
def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor: return self + (end - self) * weight
def square(self): return self*self
def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
def abs(self): return self.relu() + (-self).relu()
def abs(self):
if dtypes.is_unsigned(self.dtype) or self.dtype == dtypes.bool: return self
if dtypes.is_int(self.dtype): return (self < 0).where(0, self)
return self.relu() + (-self).relu()
def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype)
def reciprocal(self): return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))

Expand Down

0 comments on commit f4d51ad

Please sign in to comment.