Skip to content

Commit

Permalink
fix Tensor._pool edge case (tinygrad#7581)
Browse files Browse the repository at this point in the history
* split into another branch

* polish

* try this

* Revert "try this"

This reverts commit 84f711b.

* try

* Revert "try"

This reverts commit 89c7a76.

* idk anymore

* it is what it is

---------

Co-authored-by: chenyu <[email protected]>
  • Loading branch information
geohotstan and chenyuxyz authored Nov 29, 2024
1 parent 70f052d commit 765096f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 1 addition & 3 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2038,9 +2038,7 @@ def test_avg_pool2d(self):
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5)

# TODO fix edge case
@unittest.expectedFailure
def test_avg_pool2d_failure(self):
    """Regression test for the avg_pool2d edge case fixed in
    https://github.com/tinygrad/tinygrad/pull/7581: kernel smaller than
    stride in one axis combined with padding in the other."""
    # Same pooling configuration for both the torch reference and tinygrad.
    pool_kwargs = dict(kernel_size=(1,2), padding=(0,1), stride=(5,1))
    helper_test_op([(1,1,8,8)],
      lambda x: torch.nn.functional.avg_pool2d(x, **pool_kwargs),
      lambda x: Tensor.avg_pool2d(x, **pool_kwargs), rtol=1e-5)
Expand Down
8 changes: 5 additions & 3 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2008,10 +2008,12 @@ def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilat
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
# repeats such that we don't need padding
x = self.repeat([1]*len(noop) + [ceildiv(k*(i+d), i) for k,i,d in zip(k_,i_,d_)])
# input size scaling factor to make sure shrink for stride is possible
f_ = [1 + int(resolve(o*s > i+d)) for o,s,i,d in zip(o_,s_,i_,d_)]
# # repeats such that we don't need padding
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
# handle dilation
x = x.shrink(tuple(noop + [(0,k*(i+d)) for k,i,d in zip(k_,i_,d_)])).reshape(noop + flatten((k,i+d) for k,i,d in zip(k_,i_,d_)))
x = x.shrink(tuple(noop + [(0,k*(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)])).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
# handle stride
x = x.shrink(tuple(noop + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_,o_,s_)))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
x = x.shrink(tuple(noop + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_,o_)))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
Expand Down

0 comments on commit 765096f

Please sign in to comment.