Skip to content

Commit

Permalink
fix stuck floordiv (pytorch#134150)
Browse files Browse the repository at this point in the history
Summary: Fixes pytorch#134133

Test Plan:
Tested on the small repro in the linked issue with different input lengths N (in place of the hard-coded 100), recording N and the time taken in nanoseconds:
10 127268319
20 220839662
30 325463125
40 429259441
50 553136055
60 670799769
70 999170514
80 899014103
90 997168902
100 1168202035
110 1388556619
120 1457488235
130 1609816470
140 2177889877
150 1917560313
160 2121096113
170 2428502334
180 4117450755
190 4003068224

So N ≈ 200 takes about 5 s. Previously, even smaller values of N would run for more than a minute.

Didn't add a perf test because ezyang is planning to build a benchmark.

Also tested on https://www.internalfb.com/diff/D61560171, which now gets past the stuck point.

Differential Revision: D61619660

Pull Request resolved: pytorch#134150
Approved by: https://github.com/ezyang
  • Loading branch information
avikchaudhuri authored and pytorchmergebot committed Aug 26, 2024
1 parent c5f6b72 commit 92c4771
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import sys

from benchmark_base import BenchmarkBase

import torch


class Benchmark(BenchmarkBase):
    """Benchmark that exports a module floor-dividing a sum of ``N`` scalars.

    Background: https://github.com/pytorch/pytorch/issues/134133
    """

    # Number of 0-dim tensors passed to the module under test.
    N = 100

    def name(self):
        """Return the identifier string of this benchmark."""
        return "sum_floordiv_regression"

    def description(self):
        """Return a short description that points at the tracking issue."""
        return "information at https://github.com/pytorch/pytorch/issues/134133"

    def prepare_once(self):
        """Create the module under test and its inputs, storing both on ``self``."""

        class M(torch.nn.Module):
            def forward(self, x):
                # Add up the Python scalars extracted from every tensor in
                # ``x``, then floor-divide the sum by 2.
                scalar_sum = sum(t.item() for t in x)
                return scalar_sum // 2

        # 0-dim tensors holding the values 2, 3, ..., N + 1.
        self.input = [torch.tensor(value) for value in range(2, self.N + 2)]
        self.m = M()

    def prepare(self):
        """Call ``torch._dynamo.reset()``.

        NOTE(review): presumably invoked by the BenchmarkBase harness before
        each measured run so earlier runs do not affect it -- confirm.
        """
        torch._dynamo.reset()

    def work(self):
        """Export the prepared module with the prepared input list.

        NOTE(review): presumably this is the step the harness measures --
        confirm against BenchmarkBase.
        """
        example_args = (self.input,)
        torch.export.export(self.m, example_args)


def main():
    """Run the benchmark and append its results to a file.

    The destination is read from the first command-line argument
    (``sys.argv[1]``).

    Raises:
        SystemExit: with a usage message if no result path was supplied.
    """
    # Fail with a clear usage message instead of a bare IndexError when the
    # result path is missing from the command line.
    if len(sys.argv) < 2:
        prog = sys.argv[0] if sys.argv else "benchmark"
        sys.exit(f"usage: {prog} <result_path>")

    result_path = sys.argv[1]
    benchmark = Benchmark().enable_instruction_count().collect_all()
    benchmark.append_results(result_path)


# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
10 changes: 7 additions & 3 deletions torch/utils/_sympy/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def eval(cls, base, divisor):

# Expands (x + y) // b into x // b + y // b.
# This only works if floor is an identity, i.e. x / b is an integer.
for term in sympy.Add.make_args(base):
base_args = sympy.Add.make_args(base)
for term in base_args:
quotient = term / divisor
if quotient.is_integer and isinstance(divisor, sympy.Integer):
# NB: this is correct even if the divisor is not an integer, but it
Expand All @@ -195,8 +196,11 @@ def eval(cls, base, divisor):
return FloorDiv(base - term, divisor) + quotient

try:
gcd = sympy.gcd(base, divisor)
if not equal_valued(gcd, 1):
# sympy.gcd tends to blow up on large sums, so use it on each summand instead
gcd, *gcds_ = (sympy.gcd(term, divisor) for term in base_args)
if not equal_valued(gcd, 1) and all(
equal_valued(gcd, gcd_) for gcd_ in gcds_
):
return FloorDiv(
sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
)
Expand Down

0 comments on commit 92c4771

Please sign in to comment.