diff --git a/csrc/python_frontend/fusion_definition.h b/csrc/python_frontend/fusion_definition.h index 37f4def84c5..d1d6aa8a9e9 100644 --- a/csrc/python_frontend/fusion_definition.h +++ b/csrc/python_frontend/fusion_definition.h @@ -22,8 +22,9 @@ class FusionDefinition; class FusionInterface; class FusionState; struct RecordFunctor; -struct UserSchedule; +class SegmentationState; struct TrieNode; +struct UserSchedule; //! This is helper function used to print a python formated //! Fusion IR DataType when printing a fusion definition. diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 0f2a9f9314d..3a465516dd5 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -1950,7 +1950,7 @@ def fusion_func(fd: FusionDefinition): # ...but skip outputting scalars, which we don't support fd.add_output(t) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False) ab = [inputs[0], inputs[1], 1.0, -1.0] i = 0 @@ -3183,7 +3183,7 @@ def fusion_func(fd: FusionDefinition) -> None: fd.add_output(T54) fd.add_output(T30) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False) # self.assertEqual(nvf_out[0], t24) # Test that symbolic IterDomains can be concatenated @@ -3881,7 +3881,7 @@ def fusion_func(fd: FusionDefinition) -> None: fd.add_output(T88) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=True) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False) # See https://github.com/NVIDIA/Fuser/issues/2275 @pytest.mark.skipif( @@ -3927,7 +3927,7 @@ def fusion_func(fd: FusionDefinition) -> None: T101 = fd.ops.cat([T7, T100], dim=-1) fd.add_output(T101) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=True) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False) 
# See https://github.com/NVIDIA/Fuser/issues/2317 @pytest.mark.skipif( @@ -4708,4 +4708,4 @@ def fusion_func(fd: FusionDefinition) -> None: T223 = fd.ops.cat([T169, T222], dim=-1, manual_padding=0) fd.add_output(T223) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False) diff --git a/tests/python/utils.py b/tests/python/utils.py index aa5e37e1ebb..8808a45ecae 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -242,14 +242,24 @@ def check_captured_python_definition(reference_outputs, fd, inputs, device=None) # Run original FusionDefinition # Clone FusionDefinition +# Apply segmentation if it is supported for this FusionDefinition # Run cloned python definition # Check that the result of cloned python definition matches original results -def check_cpp_translation(reference_outputs, fd, inputs, device=None): +def check_cpp_translation( + reference_outputs, fd, inputs, supports_segmentation, device=None +): try: torch.manual_seed(0) + + # Clone cloned_fd = FusionDefinition() clone(fd, cloned_fd) - cloned_fd.segment(inputs) + + # Segment + if supports_segmentation: + cloned_fd.segment(inputs) + + # Run cloned_outputs = cloned_fd.execute(inputs, device=device) # Make sure the results of original and cloned definitions match. 
@@ -269,6 +279,7 @@ def check_cpp_translation(reference_outputs, fd, inputs, device=None): print( "(A failure here suggests a mismatch in functionality between the original and cloned definitions.)" ) + print("Does FusionDefinition supports segmentation?\t", supports_segmentation) print(fd.getReproErrorString("executing", inputs)) raise err @@ -420,6 +431,7 @@ def exec_nvfuser( new_fusion_expected=True, device=None, is_clonable=True, + supports_segmentation=True, ): fc = FusionCache.get() before_fusions = fc.num_fusions() @@ -442,5 +454,7 @@ def exec_nvfuser( self.assertEqual(fc.num_fusions() - before_fusions, int(new_fusion_expected)) if is_clonable: - self.assertTrue(check_cpp_translation(out, fd, inputs_cloned)) + self.assertTrue( + check_cpp_translation(out, fd, inputs_cloned, supports_segmentation) + ) return out, fd