diff --git a/test/test_dtype.py b/test/test_dtype.py index aeba03adb48c..571c8ab4afcd 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -660,8 +660,15 @@ def test_functions(self): ]: t = func(Tensor([4.0, 3.0])).max() == func(Tensor([4.0, 3.0])) result.append(t.numpy().sum()) - assert all(result) +class TestTensorMethod(unittest.TestCase): + @given(strat.sampled_from(core_dtypes)) + def test_abs_diff(self, dt): + if dt == dtypes.bool or not is_dtype_supported(dt): return + a, b = Tensor([2], dtype=dt), Tensor([1], dtype=dt) + ret = (a - b).abs() + np.testing.assert_allclose(ret.numpy(), np.abs(a.numpy()-b.numpy())) + if __name__ == '__main__': unittest.main()