From 96c483ea80a0490aff552aa1895fef0d9c7255f6 Mon Sep 17 00:00:00 2001
From: Chen-Yu Yang
Date: Mon, 9 Dec 2024 15:27:42 -0800
Subject: [PATCH] make //0 return 0 in python_alu

on master it raises because inf cannot be truncated to int, which crashes
a valid expression like `(t > 0).where(1//t, t)`.
---
 test/test_ops.py | 11 +++++++++--
 tinygrad/ops.py  |  2 +-
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 19b012c6f84c6..7f57c957e2999 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -505,11 +505,18 @@ def test_div_int(self):
     helper_test_op(None, lambda x,y: x//y, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32))
     helper_test_op(None, lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
     helper_test_op(None, lambda x: x//2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
-    torch_idiv, tiny_idiv = functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv
-    helper_test_op(None, torch_idiv, tiny_idiv, forward_only=True, vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
+    helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True,
+                   vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
     if is_dtype_supported(dtypes.uint64):
       x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1)
       np.testing.assert_equal(x.numpy(), 2**64 - 1)
+    # 1 // 0 is device dependent, but it should not raise
+    Tensor([1]).idiv(0).realize()
+    if not (CI and (Device.DEFAULT=="LLVM" or getenv("PTX"))): # TODO: crashed in CI
+      # ... because it might be in a where branch where the output is well defined
+      t = Tensor([-1, 0, 1, 2])
+      np.testing.assert_equal((t > 0).where(1//t, t).numpy(), [-1, 0, 1, 0])
+
   def test_scalar_div(self):
     helper_test_op([(45,65)], lambda x: x/255)
     helper_test_op([(45,65)], lambda x: x/1)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 8900ff7d50ca6..8977dd0cf4f95 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -493,7 +493,7 @@ def safe_exp2(x):
   Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
   Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
   Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
-  Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
+  Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
   Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}
 
 def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
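
For context, the behavior change can be checked in isolation. Below is a
minimal sketch of the truncated (C-style) integer division that the patched
Ops.IDIV lambda implements; the standalone idiv helper is hypothetical,
written only to illustrate the expression, and is not part of the patch:

    # Sketch of the patched python_alu Ops.IDIV entry (hypothetical helper).
    def idiv(x: int, y: int) -> int:
      # Divide magnitudes, then restore the sign: this truncates toward
      # zero, unlike Python's floor-dividing // operator. For y == 0,
      # return 0 instead of x*math.inf, whose truncation to int is what
      # raised on master.
      return abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0

    assert idiv(7, 2) == 3
    assert idiv(-7, 2) == -3  # truncation; Python's -7//2 floors to -4
    assert idiv(1, 0) == 0    # no longer raises; the exact value of x//0
                              # is device dependent anyway, per the test

Returning 0 keeps constant folding total: a //0 that only appears in a dead
where branch folds to some value instead of crashing the whole expression.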