diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py new file mode 100644 index 00000000000000..227c0bd911ada3 --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py @@ -0,0 +1,39 @@ +import sys + +from benchmark_base import BenchmarkBase + +import torch + + +class Benchmark(BenchmarkBase): + N = 100 + + def name(self): + return "sum_floordiv_regression" + + def description(self): + return "information at https://github.com/pytorch/pytorch/issues/134133" + + def prepare_once(self): + class M(torch.nn.Module): + def forward(self, x): + total = sum(t.item() for t in x) + return total // 2 + + self.m = M() + self.input = [torch.tensor(i + 2) for i in range(self.N)] + + def prepare(self): + torch._dynamo.reset() + + def work(self): + torch.export.export(self.m, (self.input,)) + + +def main(): + result_path = sys.argv[1] + Benchmark().enable_instruction_count().collect_all().append_results(result_path) + + +if __name__ == "__main__": + main() diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 0998b57678610a..d54495047e2713 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -186,7 +186,8 @@ def eval(cls, base, divisor): # Expands (x + y) // b into x // b + y // b. # This only works if floor is an identity, i.e. x / b is an integer. - for term in sympy.Add.make_args(base): + base_args = sympy.Add.make_args(base) + for term in base_args: quotient = term / divisor if quotient.is_integer and isinstance(divisor, sympy.Integer): # NB: this is correct even if the divisor is not an integer, but it @@ -195,8 +196,11 @@ def eval(cls, base, divisor): return FloorDiv(base - term, divisor) + quotient try: - gcd = sympy.gcd(base, divisor) - if not equal_valued(gcd, 1): + # sympy.gcd tends to blow up on large sums, so use it on each summand instead + gcd, *gcds_ = (sympy.gcd(term, divisor) for term in base_args) + if not equal_valued(gcd, 1) and all( + equal_valued(gcd, gcd_) for gcd_ in gcds_ + ): return FloorDiv( sympy.simplify(base / gcd), sympy.simplify(divisor / gcd) )