default tensors to int32 in test_ops #585

Merged 1 commit on Dec 7, 2024
test/test_ops.py: 57 changes (22 additions, 35 deletions)
@@ -71,6 +71,9 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
np.random.seed(0)
np_data = [np.random.uniform(low=low, high=high, size=size).astype(_to_np_dtype(dtypes.default_float)) for size in shps]
ts = [torch.tensor(data, requires_grad=(not forward_only)) for data in np_data]
for i in range(len(ts)):
# NOTE: torch default int64 for python ints input
if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32)
tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
return ts, tst

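The new downcast in `prepare_test_op` exists because `torch.tensor` infers int64 for plain Python int inputs, while some tinygrad backends (WEBGPU in particular, per the TODO removed further down) cannot handle int64 buffers. A minimal illustrative sketch of that default and the conversion, separate from the test harness itself:

```python
import torch

# torch infers int64 ("long") for Python int inputs...
t = torch.tensor([0, 1, 2])
assert t.dtype == torch.int64

# ...so prepare_test_op now downcasts such tensors to int32 before mirroring
# them into tinygrad Tensors, matching the int32 default this PR moves to.
if t.dtype == torch.int64:
    t = t.type(torch.int32)
assert t.dtype == torch.int32
```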
@@ -312,8 +315,7 @@ def test_where_permute(self):
def _test_cmp(self, fxn, reverse=True):
# test different dtypes
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0.,1,2], [2.,1,0]])
if is_dtype_supported(dtypes.long):
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[True, True, False], [False,True,False]])
# test broadcasting
for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]:
@@ -563,10 +565,7 @@ def test_pow_const(self):
helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)

def test_pow_int(self):
# TODO: better infra for these, helper_test_op creates buffer in long first, so WEBGPU fails
def _test(base, exponent):
np.testing.assert_equal((Tensor(base) ** Tensor(exponent)).numpy(),
(torch.tensor(base, dtype=torch.int) ** torch.tensor(exponent, dtype=torch.int)).numpy())
def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)

for base in ([1, 2, 3], [-1, -2, -3]):
for exponent in ([2, 3, 4], [-2, -3, -4]):
@@ -1098,9 +1097,8 @@ def test_min(self):
helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
helper_test_op([()], lambda x: x.min())

if is_dtype_supported(dtypes.long):
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[True, False]])

@@ -1111,9 +1109,8 @@ def test_max(self):
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
helper_test_op([()], lambda x: x.max())

if is_dtype_supported(dtypes.long):
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[True, False]])

@@ -1216,12 +1213,10 @@ def test_std_mean(self):
helper_test_op([(3,4,5,6)], lambda x: torch.stack(torch.std_mean(x, axis=(1,2))),
lambda x: Tensor.stack(*x.std_mean(axis=(1,2))))
def test_softmax(self):
# exceed per kernel buffer limit with backward
forward_only = (Device.DEFAULT == "WEBGPU")
helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
def test_softmax_other_axis(self):
helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7)
@@ -2246,7 +2241,7 @@ def test_matvec(self):
helper_test_op([(1,128), (128,128)], lambda x,y: (x@y).relu())

@unittest.skip("this test is broken #862")
def test_max_inf(self):
def test_max_nan(self):
n = Tensor([1, float("nan")]).max().numpy()
assert math.isnan(n.item()), f"{n.item()} is not nan"

@@ -2484,53 +2479,47 @@ def test_cross_entropy_smoothing(self):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),
lambda x,y: x.cross_entropy(y, label_smoothing=ls))

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss(self):
helper_test_op([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)

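Note why the torch reference side of these tests keeps `.type(torch.long)` while the tinygrad side now casts to `dtypes.int32`: torch's `F.nll_loss` requires int64 class indices, whereas tinygrad's `nll_loss` in this diff accepts int32 targets, so the explicit long cast can be dropped there. A small standalone illustration with hypothetical values, not part of this diff:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                              # hypothetical batch of 4, 10 classes
targets = torch.tensor([1, 0, 3, 2], dtype=torch.int32)  # class indices fit easily in int32

# torch still needs long targets, hence the .long()/.type(torch.long) on the torch side
loss = F.nll_loss(F.log_softmax(logits, dim=1), targets.long())
print(loss.item())
```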
@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_3d(self):
helper_test_op([(32,10,3,3,3), (32,3,3,3)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_reductions(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long), reduction=r),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), reduction=r), forward_only=True)
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), reduction=r), forward_only=True)
self.helper_test_exception([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(x, torch.clip(y,0).type(torch.long), reduction="typo"),
lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.long), reduction="typo"), expected=ValueError)
lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.int32), reduction="typo"), expected=ValueError)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_weight(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32), (10)],
lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long),
weight=z, reduction=r),
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_3d_weight(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10,3,3,3), (32,3,3,3), (10)],
lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long),
weight=z, reduction=r),
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_ignore_index(self):
logits = [[2.0, 0.5, -1.0],
[1.5, 2.5, -0.5],
[0.0, -2.0, 1.0]]
targets = [0, 1, 2]
helper_test_op(None, lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1),
torch.clip(y,0).type(torch.long), ignore_index=1),
lambda x,y: x.log_softmax().nll_loss(y.clip(0).cast(dtypes.long), ignore_index=1),
lambda x,y: x.log_softmax().nll_loss(y.clip(0), ignore_index=1),
forward_only=True, vals=[logits, targets])

def test_one_hot(self):
@@ -2552,8 +2541,7 @@ def test_masked_fill(self):
@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
def test_cast(self):
helper_test_op([(3, 3)], lambda x: x.float())
if is_dtype_supported(dtypes.long):
helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
helper_test_op(None, lambda x: x.float(), vals=[[True, False]], forward_only=True)
helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)
@@ -2587,7 +2575,6 @@ def test_interpolate_nearest_exact(self):
lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="nearest-exact"),
lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="nearest-exact"), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_min(self):
helper_test_op(None,
lambda x: x.type(torch.uint8).min(),