Skip to content

Commit

Permalink
Use integer divison in arange length calculation when start/end/step …
Browse files Browse the repository at this point in the history
…are integral (pytorch#134296)

Fixes pytorch#133338

Test Plan:

```
TORCH_LOGS=dynamic python
import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile()
def f(x):
    y = x.item()
    torch._check_is_size(y)
    r = torch.arange(y, dtype=torch.float32)
    torch._check(r.size(0) == y)
    return r

f(torch.tensor([300]))
```

Before and after diff. Verify the following line

```
I0813 11:05:44.890000 652898 torch/fx/experimental/symbolic_shapes.py:5198] [0/0] runtime_assert Eq(CeilToInt(IntTrueDiv(u0, 1)), u0) [guard added] at aa.py:10 in f (_dynamo/utils.py:2092 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(CeilToInt(IntTrueDiv(u0, 1)), u0)"
```

no longer shows in the logs. Also verify CI passes.

Pull Request resolved: pytorch#134296
Approved by: https://github.com/aorenste
  • Loading branch information
bobrenjc93 authored and pytorchmergebot committed Aug 24, 2024
1 parent 1a0d00f commit 94f92fb
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ def test_do_paste_mask(self):
)
# (dynamic shapes, static shapes)
self.assertIn(cnt.frame_count, (5, 7))
self.assertIn(cnt.op_count, (94, 106, 121))
self.assertIn(cnt.op_count, (92, 106, 119))

def test_convert_boxes_to_pooler_format(self):
boxes1 = [
Expand Down
4 changes: 2 additions & 2 deletions torch/_refs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4989,14 +4989,14 @@ def is_finite(x):
dtype = torch.int64 if integer_args else torch.get_default_dtype()

is_integer = utils.is_integer_dtype(dtype)
if is_integer:
if is_integer or integer_args:
xstart = sym_int(start)
xend = sym_int(end)
xstep = sym_int(step)

# For int64 we truncate arguments to int before calculating length, but
# other integral dtypes we don't. Weird... but needed to match ATen shapes.
if dtype == torch.int64:
if dtype == torch.int64 or integer_args:
# Uses floordiv to avoid ceil in inductor.
sgn = bool(xstep > 0) - bool(xstep < 0) # type: ignore[possibly-undefined]
length = (xend - xstart + xstep - sgn) // xstep # type: ignore[possibly-undefined]
Expand Down

0 comments on commit 94f92fb

Please sign in to comment.