Some improvements to nonzero post guard_size_oblivious (pytorch#122156)
Prompted by pytorch#121571

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#122156
Approved by: https://github.com/jansel
ezyang authored and pytorchmergebot committed Mar 28, 2024
1 parent caa57e4 commit 8c8e4e3
Showing 2 changed files with 22 additions and 31 deletions.
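
For context, a minimal sketch of the case this change targets: calling nonzero() on a fake tensor whose numel() is statically zero. The FakeTensorMode/ShapeEnv setup below is assumed for illustration and is not part of this commit.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# Assumed setup: a fake mode whose ShapeEnv permits dynamic output shapes,
# so the nonzero fake kernel in torch/_subclasses/fake_impls.py is reachable.
fake_mode = FakeTensorMode(shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True))
with fake_mode:
    x = torch.zeros(0, 3)   # numel() == 0, with no free symbols
    out = torch.nonzero(x)
    # After this change the fake kernel memoizes nnz = 0 directly rather than
    # allocating an unbacked SymInt that would immediately refine to zero.
    print(out.shape)        # expected: torch.Size([0, 2])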
8 changes: 0 additions & 8 deletions test/onnx/test_fx_op_consistency.py
@@ -1587,14 +1587,6 @@ def skip_torchlib_forward_compatibility(
         reason="Output 'shape' do not match: torch.Size([0, 1]) != torch.Size([0, 0]).",
         model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
     ),
-    xfail(
-        "nonzero",
-        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
-        reason=onnx_test_common.reason_onnx_script_does_not_support(
-            "aten::_assert_async.msg",
-            "https://github.com/pytorch/pytorch/issues/112443",
-        ),
-    ),
     xfail(
         "scatter_add",
         matcher=lambda sample: len(sample.input.shape) == 0,
45 changes: 22 additions & 23 deletions torch/_subclasses/fake_impls.py
@@ -304,38 +304,37 @@ def nonzero(fake_mode, func, arg):
     if arg.nonzero_memo is not None:
         nnz = arg.nonzero_memo
     else:
-        nnz = fake_mode.shape_env.create_unbacked_symint()
-
-        # This is unsound, but it works well in practice
-        # See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#
-        # TODO: Add a config knob to turn off this unsound behavior
-        #
-        # NB: If numel < 2, the bounds here might be COMPLETELY
-        # disjoint with what can actually occur. But this is fine:
-        # remember, the hypothesis is that if your later code works
-        # with N >= 2, it will work with N = 1 and N = 0.
-        maxval = sys.maxsize - 1
-
         # Avoid importing sympy at a module level
         from torch.fx.experimental.symbolic_shapes import (
             _constrain_range_for_size,
             has_free_symbols,
         )

-        if not has_free_symbols(arg.numel()):
-            # Don't upgrade the range if numel is less than two, since we then
-            # have an empty range which makes things go explodey. We also
-            # don't allow for 2 because that would specialize the unbacked
-            # SymInt to 2, which is also likely to be buggy.
-            if arg.numel() > 2:
+        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
+            # If numel is zero, then the output size must be zero.
+            # In this case, we must not allocate an unbacked SymInt,
+            # because if we do, it will immediately get refined to
+            # zero, but this will be inconsistent with size oblivious
+            # tests (which will continue to claim that the unbacked
+            # symint cannot equal zero). We could also unconditionally
+            # allocate an unbacked SymInt and not refine its range,
+            # but this seems more precise.
+            nnz = arg._nonzero_memo = 0
+            arg._nonzero_memo_vc = arg._version
+        else:
+            nnz = fake_mode.shape_env.create_unbacked_symint()
+
+            maxval = sys.maxsize - 1
+
+            if not has_free_symbols(arg.numel()):
                 maxval = int(arg.numel())

-        _constrain_range_for_size(nnz, max=maxval)
+            _constrain_range_for_size(nnz, max=maxval)

-        if not torch.is_inference_mode_enabled():
-            # arg._version N/A in inference mode
-            arg._nonzero_memo = nnz
-            arg._nonzero_memo_vc = arg._version
+            if not torch.is_inference_mode_enabled():
+                # arg._version N/A in inference mode
+                arg._nonzero_memo = nnz
+                arg._nonzero_memo_vc = arg._version

     return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
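
The new comment's claim about size-oblivious tests can be illustrated with a small sketch (assumed setup, not part of this commit): once an unbacked SymInt is marked size-like, size-oblivious reasoning treats it as at least 2, so it would keep reporting that the symbol cannot be zero even if its range were refined to exactly zero.

from torch.fx.experimental.symbolic_shapes import (
    ShapeEnv,
    _constrain_range_for_size,
    guard_size_oblivious,
)

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
_constrain_range_for_size(u0)  # mark u0 as size-like, as the fake kernel does
# Size-oblivious evaluation assumes a size-like unbacked SymInt is >= 2,
# so this resolves statically to False without adding a guard.
print(guard_size_oblivious(u0 == 0))  # False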

