fix Tensor.minimum for int (tinygrad#8145)
Use invert (~) instead of just neg (-); consolidate min, argmin, and minimum through a shared _inverse helper.

Also update maximum to not apply the mid-point for int.
chenyuxyz authored Dec 10, 2024
1 parent 71dd222 commit 286fec1
Showing 2 changed files with 28 additions and 5 deletions.
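The crux of the fix: negation is not a total order-reversing map on two's-complement integers, because -INT_MIN overflows, while bitwise invert (~x == -x - 1) reverses order across the whole range. A minimal sketch of the difference, with numpy standing in for tinygrad's integer semantics and the array mirroring the new test values:

```python
import numpy as np

# int32 values including both extremes, mirroring the new tests below
a = np.array([-1234, 0, 1234, np.iinfo(np.int32).max, np.iinfo(np.int32).min],
             dtype=np.int32)

# Negation overflows at INT_MIN (-INT_MIN wraps back to INT_MIN),
# so the old min-via-max identity silently breaks at the boundary:
print(-(-a).max())   # -1234, wrong
print(a.min())       # -2147483648

# Bitwise invert is ~x == -x - 1: a total, strictly order-reversing
# map on int32, so the identity holds even at the extremes:
print(~(~a).max())   # -2147483648, matches a.min()
```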
19 changes: 19 additions & 0 deletions test/test_ops.py
@@ -438,16 +438,35 @@ def test_maximum(self):
     helper_test_op([(), ()], torch.maximum, Tensor.maximum)
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], 3.])
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
+    helper_test_op(None, torch.maximum, Tensor.maximum,
+                   vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.max(dtypes.int)], forward_only=True)
+    helper_test_op(None, torch.maximum, Tensor.maximum,
+                   vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.min(dtypes.int)], forward_only=True)
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], True], forward_only=True)
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], [True, True, False]], forward_only=True)
+
+    # test applying to different dtype
+    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1, 2, 3], 1.2], forward_only=True)
+    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], 1.2], forward_only=True)
+    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], 3], forward_only=True)

   def test_minimum(self):
     helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
     helper_test_op([(), ()], torch.minimum, Tensor.minimum)
     helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], 3.])
     helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
+    helper_test_op(None, torch.minimum, Tensor.minimum,
+                   vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.max(dtypes.int)], forward_only=True)
+    helper_test_op(None, torch.minimum, Tensor.minimum,
+                   vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.min(dtypes.int)], forward_only=True)
+    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], True], forward_only=True)
     helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], [True, True, False]], forward_only=True)
+
+    # test applying to different dtype
+    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1, 2, 3], 1.2], forward_only=True)
+    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], 1.2], forward_only=True)
+    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], 3], forward_only=True)

   def test_tiny_add(self):
     helper_test_op([(3), (3)], lambda x,y: x+y, Tensor.add, forward_only=True)
   def test_tiny_mul(self):
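The new cases pin the behavior at the int extremes. A quick check of the same values through the public API (a sketch, assuming a current tinygrad install):

```python
from tinygrad import Tensor, dtypes

t = Tensor([-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)])
# Before this fix, minimum against a constant could turn INT_MIN into
# INT_MAX via the negation overflow above; now it matches torch.minimum.
print(t.minimum(dtypes.max(dtypes.int)).numpy())  # the input, unchanged
print(t.min().numpy())                            # dtypes.min(dtypes.int)
```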
14 changes: 9 additions & 5 deletions tinygrad/tensor.py
@@ -1570,6 +1570,8 @@ def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
     """
     return self._reduce(F.Max, axis, keepdim)

+  def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
+
   def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
     """
     Returns the minimum value of the tensor along the specified axis or axes.
@@ -1591,8 +1593,7 @@ def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
     print(t.min(axis=1, keepdim=True).numpy())
     ```
     """
-    if dtypes.is_int(self.dtype) or self.dtype == dtypes.bool: return ~((~self).max(axis=axis, keepdim=keepdim))
-    return -((-self).max(axis=axis, keepdim=keepdim))
+    return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()

   def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
     """
@@ -1885,7 +1886,7 @@ def argmin(self, axis=None, keepdim=False):
     print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
     ```
     """
-    return (-self if self.is_floating_point() else ~self).argmax(axis=axis, keepdim=keepdim)
+    return self._inverse().argmax(axis=axis, keepdim=keepdim)

   def rearrange(self, formula: str, **sizes) -> Tensor:
     """
@@ -3229,7 +3230,9 @@ def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
     print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
     ```
     """
-    return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
+    # NOTE: the mid-point is for backward, revisit after new gradient API
+    if self.is_floating_point(): return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
+    return (self<x).detach().where(x, self)

   def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
     """
@@ -3242,7 +3245,8 @@ def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
     print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
     ```
     """
-    return -((-self).maximum(-x))
+    t, x = self._broadcasted(x)
+    return t._inverse().maximum(x._inverse())._inverse()

   def where(self:Tensor, x:Union[Tensor, ConstType, sint], y:Union[Tensor, ConstType, sint]):
     """
