make //0 return 0 in python_alu
On master it raises because it cannot truncate inf to int, which crashes valid expressions like `(t > 0).where(1//t, t)`.
chenyuxyz committed Dec 9, 2024
1 parent f83d715 commit 96c483e
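
The crash comes from the truncation step: the old IDIV lambda in python_alu returns `x*math.inf` for a zero divisor, and truncating that result to int raises `OverflowError`. A minimal sketch of the failure mode, with the truncation approximated by a bare `int()` rather than the actual tinygrad call path:

```python
import math

# old and new IDIV lambdas, copied from the python_alu diff below
old_idiv = lambda x, y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf
new_idiv = lambda x, y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0

try:
  int(old_idiv(1, 0))  # truncating inf to int raises
except OverflowError as e:
  print(e)  # "cannot convert float infinity to integer"

print(int(new_idiv(1, 0)))  # 0: 1//0 now const-folds without raising
```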
Showing 2 changed files with 10 additions and 3 deletions.
test/test_ops.py: 9 additions & 2 deletions
```diff
@@ -505,11 +505,18 @@ def test_div_int(self):
     helper_test_op(None, lambda x,y: x//y, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32))
     helper_test_op(None, lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
     helper_test_op(None, lambda x: x//2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
-    torch_idiv, tiny_idiv = functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv
-    helper_test_op(None, torch_idiv, tiny_idiv, forward_only=True, vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
+    helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True,
+                   vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
     if is_dtype_supported(dtypes.uint64):
       x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1)
       np.testing.assert_equal(x.numpy(), 2**64 - 1)
+    # 1 // 0 is device dependent, but it should not raise
+    Tensor([1]).idiv(0).realize()
+    if not (CI and (Device.DEFAULT=="LLVM" or getenv("PTX"))): # TODO: crashed in CI
+      # ... because it might be in a where branch where the output is well defined
+      t = Tensor([-1, 0, 1, 2])
+      np.testing.assert_equal((t > 0).where(1//t, t).numpy(), [-1, 0, 1, 0])
+
   def test_scalar_div(self):
     helper_test_op([(45,65)], lambda x: x/255)
     helper_test_op([(45,65)], lambda x: x/1)
```
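For reference, the masked-division assertion can also be run standalone; a repro sketch assuming a tinygrad checkout that includes this commit:

```python
import numpy as np
from tinygrad import Tensor

t = Tensor([-1, 0, 1, 2])
# 1//t divides by zero in one lane, but the where masks that lane out,
# so the whole expression is well defined and must not raise
np.testing.assert_equal((t > 0).where(1//t, t).numpy(), [-1, 0, 1, 0])
```

The full test runs with e.g. `python -m pytest test/test_ops.py -k test_div_int`; the CI guard above skips the where-branch check on LLVM and PTX, where it crashed at the time.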
tinygrad/ops.py: 1 addition & 1 deletion
```diff
@@ -493,7 +493,7 @@ def safe_exp2(x):
   Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
   Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
   Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
-  Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
+  Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
   Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}
 
 def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
```
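As context for the one-line change: the IDIV lambda uses the `abs()` pair plus a sign flip to get C-style truncation toward zero out of Python's floor division, which is why the test compares against `torch.div` with `rounding_mode="trunc"`. A quick standalone sketch of the semantics, including the new `//0` case:

```python
# the python_alu IDIV lambda, pulled out on its own
idiv = lambda x, y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0

assert -7 // 2 == -4      # Python floor division rounds toward -inf
assert idiv(-7, 2) == -3  # abs() + sign flip truncates toward zero, like C
assert idiv(7, -2) == -3
assert idiv(1, 0) == 0    # the new //0 behavior from this commit
```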
