Skip to content

Commit

Permalink
pow int stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz committed Dec 10, 2024
1 parent 56c84ce commit 3da5796
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
5 changes: 2 additions & 3 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,8 @@ def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, ex

np.testing.assert_equal((Tensor(11) ** Tensor(7)).item(), 11 ** 7)
np.testing.assert_equal((Tensor([11]) ** Tensor(7)).item(), 11 ** 7)
# TODO: fix non-precise int pow
with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor(11) ** Tensor([7])).item(), 11 ** 7)
with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor([11]) ** Tensor([7])).item(), 11 ** 7)
np.testing.assert_equal((Tensor(11) ** Tensor([7])).item(), 11 ** 7)
np.testing.assert_equal((Tensor([11]) ** Tensor([7])).item(), 11 ** 7)

# pow to a const int
helper_test_op([], lambda: torch.tensor([2], dtype=torch.int) ** torch.tensor(-2, dtype=torch.int),
Expand Down
8 changes: 8 additions & 0 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3206,6 +3206,14 @@ def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()

base, exponent = self._broadcasted(x, reverse=reverse)

# pow to int
if dtypes.is_int(exponent.dtype):
# TODO: something about _min_max takes very long time with range(31)
def step(r,b,e): return (r * (e & 1).where(b, 1), b * b, e // 2)
ret = functools.reduce(lambda rbe, _: step(*rbe), range(10), (Tensor.ones_like(self), base, exponent))[0]
return (exponent < 0).where(1//ret, ret)

# start with b ** e = exp(e * log(b))
ret = base.abs().log().mul(exponent).exp()
# correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
Expand Down

0 comments on commit 3da5796

Please sign in to comment.