Skip to content

Commit

Permalink
Merge pull request #585 from chenyuxyz/test_ops-int32
Browse files Browse the repository at this point in the history
default tensors to int32 in test_ops
  • Loading branch information
chenyuxyz authored Dec 7, 2024
2 parents e9692de + 25897b8 commit a5183b0
Showing 1 changed file with 22 additions and 35 deletions.
57 changes: 22 additions & 35 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
np.random.seed(0)
np_data = [np.random.uniform(low=low, high=high, size=size).astype(_to_np_dtype(dtypes.default_float)) for size in shps]
ts = [torch.tensor(data, requires_grad=(not forward_only)) for data in np_data]
for i in range(len(ts)):
# NOTE: torch default int64 for python ints input
if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32)
tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
return ts, tst

Expand Down Expand Up @@ -312,8 +315,7 @@ def test_where_permute(self):
def _test_cmp(self, fxn, reverse=True):
# test different dtypes
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0.,1,2], [2.,1,0]])
if is_dtype_supported(dtypes.long):
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[True, True, False], [False,True,False]])
# test broadcasting
for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]:
Expand Down Expand Up @@ -563,10 +565,7 @@ def test_pow_const(self):
helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)

def test_pow_int(self):
# TODO: better infra for these, helper_test_op creates buffer in long first, so WEBGPU fails
def _test(base, exponent):
np.testing.assert_equal((Tensor(base) ** Tensor(exponent)).numpy(),
(torch.tensor(base, dtype=torch.int) ** torch.tensor(exponent, dtype=torch.int)).numpy())
def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)

for base in ([1, 2, 3], [-1, -2, -3]):
for exponent in ([2, 3, 4], [-2, -3, -4]):
Expand Down Expand Up @@ -1098,9 +1097,8 @@ def test_min(self):
helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
helper_test_op([()], lambda x: x.min())

if is_dtype_supported(dtypes.long):
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[True, False]])

Expand All @@ -1111,9 +1109,8 @@ def test_max(self):
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
helper_test_op([()], lambda x: x.max())

if is_dtype_supported(dtypes.long):
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[True, False]])

Expand Down Expand Up @@ -1216,12 +1213,10 @@ def test_std_mean(self):
helper_test_op([(3,4,5,6)], lambda x: torch.stack(torch.std_mean(x, axis=(1,2))),
lambda x: Tensor.stack(*x.std_mean(axis=(1,2))))
def test_softmax(self):
# exceed per kernel buffer limit with backward
forward_only = (Device.DEFAULT == "WEBGPU")
helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
def test_softmax_other_axis(self):
helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7)
Expand Down Expand Up @@ -2246,7 +2241,7 @@ def test_matvec(self):
helper_test_op([(1,128), (128,128)], lambda x,y: (x@y).relu())

@unittest.skip("this test is broken #862")
def test_max_inf(self):
def test_max_nan(self):
n = Tensor([1, float("nan")]).max().numpy()
assert math.isnan(n.item()), f"{n.item()} is not nan"

Expand Down Expand Up @@ -2484,53 +2479,47 @@ def test_cross_entropy_smoothing(self):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),
lambda x,y: x.cross_entropy(y, label_smoothing=ls))

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss(self):
helper_test_op([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_3d(self):
helper_test_op([(32,10,3,3,3), (32,3,3,3)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_reductions(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long), reduction=r),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), reduction=r), forward_only=True)
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), reduction=r), forward_only=True)
self.helper_test_exception([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(x, torch.clip(y,0).type(torch.long), reduction="typo"),
lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.long), reduction="typo"), expected=ValueError)
lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.int32), reduction="typo"), expected=ValueError)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_weight(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32), (10)],
lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long),
weight=z, reduction=r),
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_3d_weight(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10,3,3,3), (32,3,3,3), (10)],
lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long),
weight=z, reduction=r),
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_ignore_index(self):
logits = [[2.0, 0.5, -1.0],
[1.5, 2.5, -0.5],
[0.0, -2.0, 1.0]]
targets = [0, 1, 2]
helper_test_op(None, lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1),
torch.clip(y,0).type(torch.long), ignore_index=1),
lambda x,y: x.log_softmax().nll_loss(y.clip(0).cast(dtypes.long), ignore_index=1),
lambda x,y: x.log_softmax().nll_loss(y.clip(0), ignore_index=1),
forward_only=True, vals=[logits, targets])

def test_one_hot(self):
Expand All @@ -2552,8 +2541,7 @@ def test_masked_fill(self):
@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
def test_cast(self):
helper_test_op([(3, 3)], lambda x: x.float())
if is_dtype_supported(dtypes.long):
helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
helper_test_op(None, lambda x: x.float(), vals=[[True, False]], forward_only=True)
helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)
Expand Down Expand Up @@ -2587,7 +2575,6 @@ def test_interpolate_nearest_exact(self):
lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="nearest-exact"),
lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="nearest-exact"), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_min(self):
helper_test_op(None,
lambda x: x.type(torch.uint8).min(),
Expand Down

0 comments on commit a5183b0

Please sign in to comment.