diff --git a/test/test_ops.py b/test/test_ops.py index 7f57c957e299..287b539fffe4 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -438,16 +438,35 @@ def test_maximum(self): helper_test_op([(), ()], torch.maximum, Tensor.maximum) helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], 3.]) helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]]) + helper_test_op(None, torch.maximum, Tensor.maximum, + vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.max(dtypes.int)], forward_only=True) + helper_test_op(None, torch.maximum, Tensor.maximum, + vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.min(dtypes.int)], forward_only=True) helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], True], forward_only=True) helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], [True, True, False]], forward_only=True) + + # test applying to different dtype + helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1, 2, 3], 1.2], forward_only=True) + helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], 1.2], forward_only=True) + helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], 3], forward_only=True) + def test_minimum(self): helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum) helper_test_op([(), ()], torch.minimum, Tensor.minimum) helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], 3.]) helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]]) + helper_test_op(None, torch.minimum, Tensor.minimum, + vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.max(dtypes.int)], forward_only=True) + helper_test_op(None, torch.minimum, Tensor.minimum, + vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.min(dtypes.int)], forward_only=True) helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], True], forward_only=True) helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], [True, True, False]], forward_only=True) + # test applying to different dtype + helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1, 2, 3], 1.2], forward_only=True) + helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], 1.2], forward_only=True) + helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], 3], forward_only=True) + def test_tiny_add(self): helper_test_op([(3), (3)], lambda x,y: x+y, Tensor.add, forward_only=True) def test_tiny_mul(self): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 12cc7faa8ab2..a3740930c636 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1570,6 +1570,8 @@ def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False): """ return self._reduce(F.Max, axis, keepdim) + def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not() + def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False): """ Returns the minimum value of the tensor along the specified axis or axes. @@ -1591,8 +1593,7 @@ def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False): print(t.min(axis=1, keepdim=True).numpy()) ``` """ - if dtypes.is_int(self.dtype) or self.dtype == dtypes.bool: return ~((~self).max(axis=axis, keepdim=keepdim)) - return -((-self).max(axis=axis, keepdim=keepdim)) + return self._inverse().max(axis=axis, keepdim=keepdim)._inverse() def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False): """ @@ -1885,7 +1886,7 @@ def argmin(self, axis=None, keepdim=False): print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1. ``` """ - return (-self if self.is_floating_point() else ~self).argmax(axis=axis, keepdim=keepdim) + return self._inverse().argmax(axis=axis, keepdim=keepdim) def rearrange(self, formula: str, **sizes) -> Tensor: """ @@ -3229,7 +3230,9 @@ def maximum(self, x:Union[Tensor, ConstType]) -> Tensor: print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy()) ``` """ - return (self Tensor: """ @@ -3242,7 +3245,8 @@ def minimum(self, x:Union[Tensor, ConstType]) -> Tensor: print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy()) ``` """ - return -((-self).maximum(-x)) + t, x = self._broadcasted(x) + return t._inverse().maximum(x._inverse())._inverse() def where(self:Tensor, x:Union[Tensor, ConstType, sint], y:Union[Tensor, ConstType, sint]): """