From 94f92fbd883605e6f8109e8202a7e9614bcf55a0 Mon Sep 17 00:00:00 2001
From: Bob Ren
Date: Thu, 22 Aug 2024 17:12:12 -0700
Subject: [PATCH] Use integer division in arange length calculation when
 start/end/step are integral (#134296)

Fixes #133338

Test Plan:
```
TORCH_LOGS=dynamic python

import torch

torch._dynamo.config.capture_scalar_outputs = True


@torch.compile()
def f(x):
    y = x.item()
    torch._check_is_size(y)
    r = torch.arange(y, dtype=torch.float32)
    torch._check(r.size(0) == y)
    return r


f(torch.tensor([300]))
```

Run this before and after the diff and verify that the following line

```
I0813 11:05:44.890000 652898 torch/fx/experimental/symbolic_shapes.py:5198] [0/0] runtime_assert Eq(CeilToInt(IntTrueDiv(u0, 1)), u0) [guard added] at aa.py:10 in f (_dynamo/utils.py:2092 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(CeilToInt(IntTrueDiv(u0, 1)), u0)"
```

no longer shows up in the logs. Also verify that CI passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134296
Approved by: https://github.com/aorenste
---
 test/dynamo/test_repros.py | 2 +-
 torch/_refs/__init__.py    | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 4099fb365e767a..f9347ad978798d 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -964,7 +964,7 @@ def test_do_paste_mask(self):
         )
         # (dynamic shapes, static shapes)
         self.assertIn(cnt.frame_count, (5, 7))
-        self.assertIn(cnt.op_count, (94, 106, 121))
+        self.assertIn(cnt.op_count, (92, 106, 119))

     def test_convert_boxes_to_pooler_format(self):
         boxes1 = [
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index d57d205f8a61c0..160a3da6a48909 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -4989,14 +4989,14 @@ def is_finite(x):
     dtype = torch.int64 if integer_args else torch.get_default_dtype()

     is_integer = utils.is_integer_dtype(dtype)
-    if is_integer:
+    if is_integer or integer_args:
         xstart = sym_int(start)
         xend = sym_int(end)
         xstep = sym_int(step)

     # For int64 we truncate arguments to int before calculating length, but
     # other integral dtypes we don't. Weird... but needed to match ATen shapes.
-    if dtype == torch.int64:
+    if dtype == torch.int64 or integer_args:
         # Uses floordiv to avoid ceil in inductor.
         sgn = bool(xstep > 0) - bool(xstep < 0)  # type: ignore[possibly-undefined]
         length = (xend - xstart + xstep - sgn) // xstep  # type: ignore[possibly-undefined]
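
For background on why the floor-division form is shape-preserving, below is a minimal standalone sketch (plain Python; the function names are made up for this illustration and are not part of the patch or the codebase). It checks that, for integral start/end/step, the floor-division expression used in the patched code computes the same length as the ceil-of-true-division form that previously produced the `Eq(CeilToInt(IntTrueDiv(u0, 1)), u0)` guard.

```
import math


def arange_len_ceil(start, end, step):
    # Length via true division + ceil, the form that produced the
    # CeilToInt(IntTrueDiv(...)) runtime assert.
    return math.ceil((end - start) / step)


def arange_len_floordiv(start, end, step):
    # Length via integer (floor) division, mirroring the patched expression:
    # (xend - xstart + xstep - sgn) // xstep
    sgn = bool(step > 0) - bool(step < 0)
    return (end - start + step - sgn) // step


# The two formulas agree for integral inputs that yield a non-empty range,
# including negative steps.
for start, end, step in [(0, 300, 1), (0, 10, 3), (10, 0, -3), (5, 20, 4)]:
    assert arange_len_ceil(start, end, step) == arange_len_floordiv(start, end, step)
print("length formulas agree")
```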