Skip to content

Commit

Permalink
failed test case for int pow (tinygrad#8128)
Browse files Browse the repository at this point in the history
also updated test_ops so that non-float outputs are compared with `assert_equal` instead of `assert_allclose`. removed `test_multinomial`, which is tested more thoroughly in test_randomness
  • Loading branch information
chenyuxyz authored Dec 9, 2024
1 parent 80de06c commit 12f7d28
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def compare(s, tinygrad_output, torch_output, atol, rtol):
try:
assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}"
assert tinygrad_output.dtype == torch_output.dtype, f"dtype mismatch: tinygrad={tinygrad_output.dtype} | torch={torch_output.dtype}"
np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol)
if np.issubdtype(tinygrad_output.dtype, np.floating):
np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol)
else:
np.testing.assert_equal(tinygrad_output, torch_output)
except Exception as e:
raise Exception(f"{s} failed shape {tinygrad_output.shape}: {e}")

Expand Down Expand Up @@ -573,6 +576,12 @@ def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, ex
# NOTE: torch 0 ** -1 is 0
_test([0, 0, 0], [0, 1, 2])

np.testing.assert_equal((Tensor(11) ** Tensor(7)).item(), 11 ** 7)
np.testing.assert_equal((Tensor([11]) ** Tensor(7)).item(), 11 ** 7)
# TODO: fix non-precise int pow
with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor(11) ** Tensor([7])).item(), 11 ** 7)
with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor([11]) ** Tensor([7])).item(), 11 ** 7)

def test_sqrt(self):
  """sqrt agrees with torch on a 2-D tensor and on a 0-D (scalar) tensor."""
  # single lambda: helper_test_op applies the same callable to both torch and tinygrad tensors
  for shapes in ([(45,65)], [()]):
    helper_test_op(shapes, lambda x: x.sqrt())
Expand Down Expand Up @@ -796,11 +805,6 @@ def test_mish(self):
helper_test_op([(45,65)], torch.nn.functional.mish, Tensor.mish)
helper_test_op([()], torch.nn.functional.mish, Tensor.mish)

def test_multinomial(self):
  # Forward-only sanity check that Tensor.multinomial runs and matches torch's
  # output dtype/shape; clip(0,1) keeps the weights non-negative as multinomial requires.
  # NOTE: this is random, so it has a very large atol
  # NOTE(review): this diff removes this test — multinomial is presumably covered
  # more thoroughly in test_randomness per the commit message; confirm before relying on it here.
  helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1).type(torch.int32),
    lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.)

def test_small_cumsum(self):
  # Compare Tensor.cumsum against torch.cumsum along axis 0 on a small 1-D input.
  # NOTE(review): `(10)` is the int 10, not the 1-tuple `(10,)` — presumably
  # helper_test_op accepts a bare int as a shape; confirm, or write `(10,)` explicitly.
  helper_test_op([(10)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0))
def test_simple_cumsum(self):
Expand Down

0 comments on commit 12f7d28

Please sign in to comment.