From 1521bd6f61e182eae461e9e02333155e89dc6905 Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Mon, 9 Dec 2024 13:51:53 -0800 Subject: [PATCH] pow int stuff --- test/test_ops.py | 5 ++--- tinygrad/tensor.py | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 3e6d9bd53dc40..c79e41ac565e7 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -604,9 +604,8 @@ def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, ex np.testing.assert_equal((Tensor(11) ** Tensor(7)).item(), 11 ** 7) np.testing.assert_equal((Tensor([11]) ** Tensor(7)).item(), 11 ** 7) - # TODO: fix non-precise int pow - with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor(11) ** Tensor([7])).item(), 11 ** 7) - with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor([11]) ** Tensor([7])).item(), 11 ** 7) + np.testing.assert_equal((Tensor(11) ** Tensor([7])).item(), 11 ** 7) + np.testing.assert_equal((Tensor([11]) ** Tensor([7])).item(), 11 ** 7) # pow to a const int helper_test_op([], lambda: torch.tensor([2], dtype=torch.int) ** torch.tensor(-2, dtype=torch.int), diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 5f16c9da8c06c..3a35b95439428 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3209,6 +3209,14 @@ def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp() base, exponent = self._broadcasted(x, reverse=reverse) + + # pow to int + if dtypes.is_int(exponent.dtype): + # TODO: something about _min_max takes very long time with range(31) + def step(r,b,e): return (r * (e & 1).where(b, 1), b * b, e // 2) + ret = functools.reduce(lambda rbe, _: step(*rbe), range(10), (Tensor.ones_like(self), base, exponent))[0] + return (exponent < 0).where(1//ret, ret) + # start with b ** e = exp(e * log(b)) ret = base.abs().log().mul(exponent).exp() # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)