Skip to content

Commit

Permalink
fix abs of diff of uint
Browse files Browse the repository at this point in the history
redo
  • Loading branch information
chenyuxyz committed May 15, 2024
1 parent 2119e04 commit b8091a0
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion test/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,15 @@ def test_functions(self):
]:
t = func(Tensor([4.0, 3.0])).max() == func(Tensor([4.0, 3.0]))
result.append(t.numpy().sum())

assert all(result)

class TestTensorMethod(unittest.TestCase):
@given(strat.sampled_from(core_dtypes))
def test_abs_diff(self, dt):
if dt == dtypes.bool or not is_dtype_supported(dt): return
a, b = Tensor([2], dtype=dt), Tensor([1], dtype=dt)
ret = (a - b).abs()
np.testing.assert_allclose(ret.numpy(), np.abs(a.numpy()-b.numpy()))

if __name__ == '__main__':
unittest.main()

0 comments on commit b8091a0

Please sign in to comment.