Skip to content

Commit

Permalink
fix bitwise_not for signed int (tinygrad#8117)
Browse files Browse the repository at this point in the history
-1 is correct because 2**32-1 is not within int32 range, so in some case clang casts the whole thing into uint32
  • Loading branch information
chenyuxyz authored Dec 9, 2024
1 parent e22d7b6 commit c814de2
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,8 @@ def test_argmax(self):
helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, True).type(torch.int32), lambda x: x.argmax(1, True), forward_only=True)
# regression test for bitwise_not then argmax
helper_test_op(None, lambda x: (~x).argmax().type(torch.int32), lambda x: (~x).argmax(), forward_only=True, vals=[[2, 2]])

def test_argmin(self):
# check if it returns the first index for multiple occurences
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3152,7 +3152,7 @@ def bitwise_not(self) -> Tensor:
```
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self.logical_not() if self.dtype == dtypes.bool else self ^ ((1<<8*self.dtype.itemsize)-1)
return self.logical_not() if self.dtype == dtypes.bool else self ^ -1

def lshift(self, x:int):
"""
Expand Down

0 comments on commit c814de2

Please sign in to comment.