From deaf077de82789656c707d4b4b2c2e0d1ecee684 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 1 Mar 2023 10:51:12 -0800 Subject: [PATCH] Don't use guardless contiguity/stride-like implementations (#95733) These prevent us from simplifying tests involving unbacked SymInts, and then you end up with unbacked SymInt in guards, which is bad. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/95733 Approved by: https://github.com/tugsbayasgalan --- c10/core/TensorImpl.cpp | 29 ++++++++++++------------ test/test_proxy_tensor.py | 3 +-- torch/_subclasses/fake_tensor.py | 15 +++++------- torch/fx/experimental/symbolic_shapes.py | 11 ++++++++- 4 files changed, 32 insertions(+), 26 deletions(-) diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 2c1324036e59c..54dcd3e4e06b7 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -447,19 +447,20 @@ SymBool TensorImpl::compute_contiguous(identity) const { } SymIntArrayRef sizes = extra_meta_->sizes_; SymIntArrayRef strides = extra_meta_->strides_; - auto n = normalize_sym_sizes_strides(sizes, strides); - if (n.has_value()) { - SymNode base; - std::vector size_nodes; - std::vector stride_nodes; - std::tie(base, size_nodes, stride_nodes) = *n; - return SymBool(base->is_contiguous(size_nodes, stride_nodes)); - } else { - return _compute_contiguous(sizes, strides, extra_meta_->numel_); - } + return _compute_contiguous(sizes, strides, extra_meta_->numel_); } // The rest of them +#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \ + SymBool TensorImpl::name(identity) const { \ + if (is_sparse()) { \ + return false; \ + } \ + SymIntArrayRef sizes = extra_meta_->sizes_; \ + SymIntArrayRef strides = extra_meta_->strides_; \ + return fallback(sizes, strides); \ + } + #define DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \ SymBool TensorImpl::name(identity) const { \ if (is_sparse()) { \ @@ -480,10 +481,10 @@ SymBool TensorImpl::compute_contiguous(identity) const { } // clang-format off -DEFINE_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, is_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d) -DEFINE_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, is_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d) -DEFINE_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d, is_channels_last_strides_2d) -DEFINE_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d, is_channels_last_strides_3d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, is_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, is_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d, is_channels_last_strides_2d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d, is_channels_last_strides_3d) DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense) // clang-format on diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 88f4aa6d782fe..52c53d3f244c4 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -899,7 +899,7 @@ def forward(self, a_1): def test_item_to_constructor(self): def f(a): r = a.item() - constrain_range(r, min=0) + constrain_range(r, min=2) return torch.empty(r) r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip() @@ -971,7 +971,6 @@ def forward(self, crop_camera_1, mask_1): return None""") @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") - @unittest.expectedFailure def test_unbacked_batch_resnet(self): mod = torchvision.models.resnet18() diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index c3d29185d677a..359885a9f2f06 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -426,23 +426,20 @@ def nonzero(fake_mode, func, arg): raise DynamicOutputShapeException(func) if arg.nonzero_memo is None: - from torch.fx.experimental.symbolic_shapes import ( - constrain_range, - definitely_true, - guard_int, - ) + from torch.fx.experimental.symbolic_shapes import constrain_range nnz = fake_mode.shape_env.create_unbacked_symint() # This is unsound, but it works well in practice # See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit# # TODO: Add a config knob to turn off this unsound behavior + # + # NB: If numel < 2, the bounds here might be COMPLETELY + # disjoint with what can actually occur. But this is fine: + # remember, the hypothesis is that if your later code works + # with N >= 2, it will work with N = 1 and N = 0. lower = 2 upper = None - # But don't give totally unsatisfiable bounds if we know it's too small! - if definitely_true(arg.numel() < 2): - lower = 0 - upper = guard_int(arg.numel()) constrain_range(nnz, min=lower, max=upper) arg._nonzero_memo = nnz diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index f9ad531dbfbf6..95e7015fd2600 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1825,7 +1825,7 @@ def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"], concrete_bool: bo return free = list(expr.free_symbols) - assert len(free) > 0, "The expression should not be static by this point" + assert len(free) > 0, f"The expression should not be static by this point: {expr}" # In case of really gnarly expression, we don't blow up if len(free) > 5: return @@ -1904,6 +1904,15 @@ def evaluate_expr(self, expr: "sympy.Expr", hint=None): # is not actually necessary to save a guard for the equality, # as we will implicitly generate a guard when we match that # input against the symbol + elif isinstance(concrete_val, sympy.Integer): + # WARNING: we cannot actually do simplifications on guards + # on floating point values, because Sympy generally does not + # think expressions on integers can ever be equal to floating + # point (e.g., sympy.Eq(s0/6, 0.5) evaluates to False). Without + # very clear algebraic laws that hold for floating point, such + # simplifications are error prone anyway, so be sure not to + # maybe_guard_eq in those cases. + self._maybe_guard_eq(sympy.Eq(expr, concrete_val), True) # TODO: optimize this; avoid formatting traces until we need them # NB: drop two frames; evaluate_expr and the Sym* function that