diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp
index c5b89f08639..e2d291fb92b 100644
--- a/csrc/dynamic_transform.cpp
+++ b/csrc/dynamic_transform.cpp
@@ -739,9 +739,10 @@ void DynamicTransformConcretizer::concretizeReshape() {
 
     auto concrete_reshape_out_tv = reshape(inp_tv, view_analysis);
 
-    // We do the replacement directly here, but we must still check that the
-    // replacement is valid
-    checkConcretizedUses(incomplete_out_tv, concrete_reshape_out_tv);
+    // NOTE: The replacement might not yet actually be valid. For example, if
+    // inp_tv contains Symbolic domains that need to be squeezed, this check
+    // would fail at this point. So we skip checkConcretizedUses here and
+    // perform it later in mutate(TensorView*).
 
     // Extent expressions often change when concretizing a reshape. Here we
     // replace these in all downstream expressions so that the Fusion looks just
@@ -1031,6 +1032,9 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {
   // TensorDomain and then TensorView
   mutate(tv->domain());
   OptOutMutator::mutate(tv);
+  // Check concretization is valid after we've done the replacement. See note
+  // about squeeze inside concretizeReshape above.
+  checkConcretizedUses(tv, tv);
 }
 
 // Almost an exact copy of OptOutMutator::mutate(TensorDomain*), but
diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp
index 3b7024d4145..b610815f867 100644
--- a/csrc/ir/nodes.cpp
+++ b/csrc/ir/nodes.cpp
@@ -1409,11 +1409,9 @@ void SqueezeOp::checkConcretization(Val* old_val, Val* new_val) const {
         new_id->toString(),
         " must concretize to IterType::Broadcast but found ",
         new_id->toString());
-    NVF_CHECK(
-        !new_id->hasExpandedExtent(), "Can not squeeze expanded dimension(s).");
-    NVF_CHECK(
-        new_id->extent()->isOneInt(),
-        "Can not squeeze dimension(s) with size != 1.");
+    // NOTE: we do not check the extent here. Even if the extent is not a const
+    // scalar we know that it would simplify to 1 for these inputs, since this
+    // IterDomain is concretized to Broadcast.
   }
 }
 
diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py
index f64ca30f474..e8cc0d186e4 100644
--- a/tests/python/test_python_frontend.py
+++ b/tests/python/test_python_frontend.py
@@ -4055,6 +4055,37 @@ def fusion_func(fd: FusionDefinition) -> None:
                 "FusionDefinition's execute() did not run correctly with profile enabled!"
             )
 
+    # Small repro from https://github.com/NVIDIA/Fuser/issues/2359
+    def test_reshape_squeeze_concretization(self):
+        inputs = [
+            torch.randn((100,), dtype=torch.float32, device="cuda:0").as_strided(
+                (2, 5, 10), (50, 10, 1)
+            ),
+        ]
+
+        def fusion_func(fd: FusionDefinition) -> None:
+            T0 = fd.define_tensor(
+                shape=[-1, -1, -1],
+                contiguity=[True, True, True],
+                dtype=DataType.Float,
+                is_cpu=False,
+                stride_order=[2, 1, 0],
+            )
+            T1 = fd.ops.slice(
+                T0, start_indices=[0, 0, 0], end_indices=[1, 2, 4], strides=[1, 1, 1]
+            )
+            S2 = fd.define_scalar(1, dtype=DataType.Int)
+            S3 = fd.define_scalar(8, dtype=DataType.Int)
+            V4 = fd.define_vector([S2, S3], dtype=DataType.Int)
+            V5 = fd.define_vector([S3], dtype=DataType.Int)
+            T6 = fd.ops.reshape(T1, new_shape=V4)
+            T7 = fd.ops.reshape(T6, new_shape=V5)
+            # this works fine
+            # T7 = fd.ops.reshape(T1, new_shape=V5)
+            fd.add_output(T7)
+
+        nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
+
 
 if __name__ == "__main__":
     run_tests()