diff --git a/test/test_ops.py b/test/test_ops.py
index 743e53f302f0..8ea760bc1b6f 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -71,6 +71,9 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
   np.random.seed(0)
   np_data = [np.random.uniform(low=low, high=high, size=size).astype(_to_np_dtype(dtypes.default_float)) for size in shps]
   ts = [torch.tensor(data, requires_grad=(not forward_only)) for data in np_data]
+  for i in range(len(ts)):
+    # NOTE: torch default int64 for python ints input
+    if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32)
   tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
   return ts, tst
@@ -312,8 +315,7 @@ def test_where_permute(self):
   def _test_cmp(self, fxn, reverse=True):
     # test different dtypes
     helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0.,1,2], [2.,1,0]])
-    if is_dtype_supported(dtypes.long):
-      helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
+    helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
     helper_test_op(None, fxn, fxn, forward_only=True, vals=[[True, True, False], [False,True,False]])
     # test broadcasting
     for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]:
@@ -563,10 +565,7 @@ def test_pow_const(self):
     helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)

   def test_pow_int(self):
-    # TODO: better infra for these, helper_test_op creates buffer in long first, so WEBGPU fails
-    def _test(base, exponent):
-      np.testing.assert_equal((Tensor(base) ** Tensor(exponent)).numpy(),
-                              (torch.tensor(base, dtype=torch.int) ** torch.tensor(exponent, dtype=torch.int)).numpy())
+    def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)

     for base in ([1, 2, 3], [-1, -2, -3]):
       for exponent in ([2, 3, 4], [-2, -3, -4]):
@@ -1098,9 +1097,8 @@ def test_min(self):
     helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
     helper_test_op([()], lambda x: x.min())
-    if is_dtype_supported(dtypes.long):
-      helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
-      helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
+    helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
+    helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
     helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[False, True]])
     helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[True, False]])
@@ -1111,9 +1109,8 @@ def test_max(self):
     helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
     helper_test_op([()], lambda x: x.max())
-    if is_dtype_supported(dtypes.long):
-      helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
-      helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
+    helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
+    helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
     helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[False, True]])
     helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[True, False]])
@@ -1216,12 +1213,10 @@ def test_std_mean(self):
     helper_test_op([(3,4,5,6)], lambda x: torch.stack(torch.std_mean(x, axis=(1,2))), lambda x: Tensor.stack(*x.std_mean(axis=(1,2))))

   def test_softmax(self):
-    # exceed per kernel buffer limit with backward
-    forward_only = (Device.DEFAULT == "WEBGPU")
-    helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
-    helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
-    helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
-    helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
+    helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
   def test_softmax_other_axis(self):
     helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7)
@@ -2246,7 +2241,7 @@ def test_matvec(self):
     helper_test_op([(1,128), (128,128)], lambda x,y: (x@y).relu())

   @unittest.skip("this test is broken #862")
-  def test_max_inf(self):
+  def test_max_nan(self):
     n = Tensor([1, float("nan")]).max().numpy()
     assert math.isnan(n.item()), f"{n.item()} is not nan"

@@ -2484,45 +2479,39 @@ def test_cross_entropy_smoothing(self):
       helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),
                                          lambda x,y: x.cross_entropy(y, label_smoothing=ls))

-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss(self):
     helper_test_op([(32,10), (32)],
       lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
-      lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
+      lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)

-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_3d(self):
     helper_test_op([(32,10,3,3,3), (32,3,3,3)],
       lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
-      lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
+      lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)

-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_reductions(self):
     for r in ("mean", "sum", "none"):
       helper_test_op([(32,10), (32)],
         lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long), reduction=r),
-        lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), reduction=r), forward_only=True)
+        lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), reduction=r), forward_only=True)
     self.helper_test_exception([(32,10), (32)],
       lambda x,y: torch.nn.functional.nll_loss(x, torch.clip(y,0).type(torch.long), reduction="typo"),
-      lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.long), reduction="typo"), expected=ValueError)
+      lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.int32), reduction="typo"), expected=ValueError)

-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_weight(self):
     for r in ("mean", "sum", "none"):
       helper_test_op([(32,10), (32), (10)],
         lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long), weight=z, reduction=r),
-        lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
+        lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)

-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_3d_weight(self):
     for r in ("mean", "sum", "none"):
       helper_test_op([(32,10,3,3,3), (32,3,3,3), (10)],
         lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long), weight=z, reduction=r),
-        lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
+        lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)

-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_ignore_index(self):
     logits = [[2.0, 0.5, -1.0],
               [1.5, 2.5, -0.5],
@@ -2530,7 +2519,7 @@ def test_nll_loss_ignore_index(self):
     targets = [0, 1, 2]
     helper_test_op(None,
       lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long), ignore_index=1),
-      lambda x,y: x.log_softmax().nll_loss(y.clip(0).cast(dtypes.long), ignore_index=1),
+      lambda x,y: x.log_softmax().nll_loss(y.clip(0), ignore_index=1),
       forward_only=True, vals=[logits, targets])

   def test_one_hot(self):
@@ -2552,8 +2541,7 @@ def test_masked_fill(self):
   @unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
   def test_cast(self):
     helper_test_op([(3, 3)], lambda x: x.float())
-    if is_dtype_supported(dtypes.long):
-      helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
+    helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
     helper_test_op(None, lambda x: x.float(), vals=[[True, False]], forward_only=True)
     helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
     helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)
@@ -2587,7 +2575,6 @@ def test_interpolate_nearest_exact(self):
         lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="nearest-exact"),
         lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="nearest-exact"), forward_only=True)

-  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_min(self):
     helper_test_op(None, lambda x: x.type(torch.uint8).min(),