Skip to content

Commit

Permalink
fix Tensor.minimum for int
Browse files Browse the repository at this point in the history
Use bitwise invert instead of plain negation; consolidate min, argmin, and minimum.

Also update maximum so it does not apply the mid-point adjustment for int dtypes.
  • Loading branch information
chenyuxyz committed Dec 10, 2024
1 parent 56c84ce commit 73c3a50
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
19 changes: 19 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,16 +438,35 @@ def test_maximum(self):
helper_test_op([(), ()], torch.maximum, Tensor.maximum)
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], 3.])
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
helper_test_op(None, torch.maximum, Tensor.maximum,
vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.max(dtypes.int)], forward_only=True)
helper_test_op(None, torch.maximum, Tensor.maximum,
vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.min(dtypes.int)], forward_only=True)
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], True], forward_only=True)
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], [True, True, False]], forward_only=True)

# test applying to different dtype
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1, 2, 3], 1.2], forward_only=True)
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], 1.2], forward_only=True)
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], 3], forward_only=True)

def test_minimum(self):
  # float tensors: forward and backward are both checked
  helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
  helper_test_op([(), ()], torch.minimum, Tensor.minimum)
  helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], 3.])
  helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
  # int edge values including INT_MAX/INT_MIN, compared against each extreme scalar (forward only)
  int_edges = [-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)]
  for extreme in (dtypes.max(dtypes.int), dtypes.min(dtypes.int)):
    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[int_edges, extreme], forward_only=True)
  # bool tensors (forward only)
  helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], True], forward_only=True)
  helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], [True, True, False]], forward_only=True)

  # test applying to different dtype
  helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1, 2, 3], 1.2], forward_only=True)
  helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], 1.2], forward_only=True)
  helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], 3], forward_only=True)

def test_tiny_add(self):
  # smallest useful case: two length-3 vectors, forward pass only
  helper_test_op([(3), (3)], lambda a, b: a + b, Tensor.add, forward_only=True)
def test_tiny_mul(self):
Expand Down
15 changes: 10 additions & 5 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,6 +1570,8 @@ def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
"""
return self._reduce(F.Max, axis, keepdim)

def _inverse(self):
  # Order-reversing map so min-style ops can be expressed via max:
  # floats negate, ints bitwise-invert (negation would overflow at INT_MIN), bools logical-not.
  if self.is_floating_point(): return -self
  if dtypes.is_int(self.dtype): return ~self
  return self.logical_not()

def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
  """
  Returns the minimum value of the tensor along the specified axis or axes.

  You can pass in `axis` and `keepdim` keyword arguments to control the axis along
  which the minimum is computed and whether the reduced dimensions are retained.

  ```python exec="true" source="above" session="tensor" result="python"
  print(t.min(axis=1, keepdim=True).numpy())
  ```
  """
  # min(x) == inverse(max(inverse(x))); _inverse handles float/int/bool uniformly,
  # replacing the old per-dtype branches (diff residue of which is removed here).
  return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()

def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
"""
Expand Down Expand Up @@ -1885,7 +1886,7 @@ def argmin(self, axis=None, keepdim=False):
print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
```
"""
return (-self if self.is_floating_point() else ~self).argmax(axis=axis, keepdim=keepdim)
return self._inverse().argmax(axis=axis, keepdim=keepdim)

def rearrange(self, formula: str, **sizes) -> Tensor:
"""
Expand Down Expand Up @@ -3229,7 +3230,9 @@ def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
```
"""
return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
# NOTE: the mid-point is for backward, update after new gradient API
if self.is_floating_point(): return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
return (self<x).detach().where(x, self)

def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
  """
  Computes element-wise minimum of `self` and `x`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
  ```
  """
  # Broadcast first so both operands share one dtype, then reuse maximum:
  # min(a, b) == inverse(max(inverse(a), inverse(b))).
  # NOTE: removed leftover debug print of operand dtypes and the superseded
  # `-((-self).maximum(-x))` form, which is wrong for ints at INT_MIN.
  t, x = self._broadcasted(x)
  return t._inverse().maximum(x._inverse())._inverse()

def where(self:Tensor, x:Union[Tensor, ConstType, sint], y:Union[Tensor, ConstType, sint]):
"""
Expand Down

0 comments on commit 73c3a50

Please sign in to comment.