diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index f4c5f1184a0..820c4ff6107 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -1123,6 +1123,9 @@ void initNvFuserPythonBindings(PyObject* module) { // Mark the end of segmentation inst::Trace::instance()->endEvent(nullptr); }) + .def("inputs", [](FusionDefinition& self) { return self.inputs(); }) + .def("outputs", [](FusionDefinition& self) { return self.outputs(); }) + .def("extents", [](FusionDefinition& self) { return self.extents(); }) .def( "__repr__", [](FusionDefinition& self) { diff --git a/csrc/python_frontend/segmentation.cpp b/csrc/python_frontend/segmentation.cpp index 42b085af926..3478b4dbe21 100644 --- a/csrc/python_frontend/segmentation.cpp +++ b/csrc/python_frontend/segmentation.cpp @@ -35,31 +35,52 @@ int64_t SegmentationState::setupSegmentation( IrCloner original_to_cloned_map = Fusion::copy(fusion, cloned_original_fusion_.get()); + // Step 2) Given the map_presched_value_to_original_python_index AND the + // IRCloner returned by Fusion::copy, create a mapping from cloned CPP values + // to original fusion state indices. + std::unordered_map map_cloned_value_to_original_python_index; + map_cloned_value_to_original_python_index.reserve( + map_presched_value_to_original_python_index.size()); + std::transform( + map_presched_value_to_original_python_index.begin(), + map_presched_value_to_original_python_index.end(), + std::inserter( + map_cloned_value_to_original_python_index, + map_cloned_value_to_original_python_index.end()), + [&](const auto& item) { + const Val* original_value = item.first; + int64_t python_index = item.second; + Val* cloned_value = original_to_cloned_map.clone(original_value); + return std::make_pair(cloned_value, python_index); + }); + + // Step 3) Concretize fusion with input arguments. 
KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(inputs, device); - // Step 2) Concretize fusion with input arguments. std::unordered_map symbolic_to_concrete_map = DynamicTransform::concretizeFusion(cloned_original_fusion_.get(), args); - // Step 3) Given the map_presched_value_to_original_python_index, the IRCloner - // returned by Fusion::copy, AND the symbolic_to_concrete map returned by - // concretization pass, create a mapping from cloned Vals to original fusion - // state indices. + // Given the map_cloned_value_to_original_python_index AND the + // symbolic_to_concrete map returned by the concretization pass, create a + // mapping from cloned, concretized CPP values to original fusion state + // indices. + map_cloned_concretized_value_to_original_python_index_.reserve( + map_cloned_value_to_original_python_index.size()); std::transform( - map_presched_value_to_original_python_index.begin(), - map_presched_value_to_original_python_index.end(), + map_cloned_value_to_original_python_index.begin(), + map_cloned_value_to_original_python_index.end(), std::inserter( map_cloned_concretized_value_to_original_python_index_, map_cloned_concretized_value_to_original_python_index_.end()), [&](const auto& item) { - const Val* original_value = item.first; + Val* maybe_concretized_value = item.first; int64_t python_index = item.second; - Val* cloned_val = original_to_cloned_map.clone(original_value); - if (symbolic_to_concrete_map.count(cloned_val)) { - cloned_val = symbolic_to_concrete_map.at(cloned_val); + if (symbolic_to_concrete_map.count(maybe_concretized_value) > 0) { + maybe_concretized_value = + symbolic_to_concrete_map.at(maybe_concretized_value); } - return std::make_pair(cloned_val, python_index); + return std::make_pair(maybe_concretized_value, python_index); }); // Track the extents for input TensorViews in cloned CPP Fusion. 
diff --git a/nvfuser/__init__.py b/nvfuser/__init__.py index 5581050f4de..967be681b7f 100644 --- a/nvfuser/__init__.py +++ b/nvfuser/__init__.py @@ -110,6 +110,82 @@ def __exit__(self, type, value, traceback): def definition(self): raise NotImplementedError("definition() should be implemented by child class!") + def _execute_segments(self, input_arguments, *, device=None, profile=False): + """ + Run the sequence of FusionDefinition segments to generate the results + of this FusionDefinition. + + This FusionDefinition acts as an argument manager. It gathers input + arguments for the segments and stores their output results. After + running a segment, any redundant intermediate values, which are + unnecessary for any other segments, are deleted to save memory. + + Args: + input_arguments (List[Union[Tensor, Scalar]]): A list of inputs to the fusion. + + Kwargs: + device (Optional[Union[int, str, torch.device]]): This is a hint to run + the Fusion on the given CUDA device. This is not typically + necessary, as the device is usually inferred from the locations + of input tensors. However, for some fusion definitions, no + tensors will be input (for example when all tensors are + generated with `full` or `uniform` ops). In these cases, we + must either tell NVFuser where to run the resulting kernel, or + let it default to 0. Note that passing this option while + providing input tensors that lie on another device is an error. + profile (bool): Captures a CUPTI based profile of a fusion. + + + Returns: + List[Tensor]: The output results for this FusionDefinition. 
+ """ + assert len(self.segments) > 0 + assert len(self.segments) == len(self.segment_index_space_maps) + + input_arguments_with_extents = [*input_arguments] + for a in input_arguments: + if type(a) is torch.Tensor: + input_arguments_with_extents.extend(a.size()) + + # Map input arguments to original fid + map_original_fid_to_value = { + fd_state: argument + for fd_state, argument in zip( + self.inputs() + self.extents(), input_arguments_with_extents + ) + } + + # Run all segments in correct order + for idx, segment in enumerate(self.segments): + segment_to_original_map = self.segment_index_space_maps[segment] + + # Gather segment input arguments + segment_arguments = [ + map_original_fid_to_value[segment_to_original_map[fd_state]] + for fd_state in segment.inputs() + ] + + # Run segment + segment_outputs = segment.execute( + segment_arguments, device=device, profile=profile + ) + + # Update original fusion definition indices to outputs + for fd_state, output in zip(segment.outputs(), segment_outputs): + map_original_fid_to_value[segment_to_original_map[fd_state]] = output + + # Destroy any arguments that are not used by future segments + for segment_input in segment.inputs(): + original_input = segment_to_original_map[segment_input] + if ( + original_input not in self.outputs() + and self.map_value_to_last_used_segment[original_input] == idx + ): + del map_original_fid_to_value[original_input] + + # Map output fid to actual results + return [map_original_fid_to_value[fd_state] for fd_state in self.outputs()] + def execute( self, inputs, @@ -225,6 +301,9 @@ def execute( fake_mode = FakeTensorMode() self.fake_inputs = [fake_mode.from_tensor(inp) for inp in inputs] + if hasattr(self, "segments") and len(self.segments) > 0: + return self._execute_segments(inputs, device=device, profile=profile) + results = None try: diff --git a/tests/python/test_ops.py b/tests/python/test_ops.py index f5c8a57dcab..d653e005736 100644 --- a/tests/python/test_ops.py +++ 
b/tests/python/test_ops.py @@ -87,7 +87,9 @@ def torch_correctness_test_fn(fd_fn: Callable, nvf_op: OpInfo, sample: SampleInp assert check_captured_python_definition(nvfuser_result, fd, inputs_cap) if nvf_op.is_clonable: - assert check_cpp_translation(nvfuser_result, fd, inputs_cap) + assert check_cpp_translation( + nvfuser_result, fd, inputs_cap, supports_segmentation=True + ) torch_result = nvf_op.reference(*sample.args, **sample.kwargs) diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index a181ef83fa5..bd7b3416eba 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -2831,7 +2831,8 @@ def fusion_func(fd: FusionDefinition) -> None: T89 = fd.ops.sum(T98, dims=[4], keepdim=False, dtype=DataType.Null) fd.add_output(T89) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + # TODO Segmentation fails validateAllocationSizesAndStrides + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False) # This tests no dead code at definition does not cause a problem due to # removal of empty tensors @@ -3237,7 +3238,7 @@ def fusion_func(fd: FusionDefinition) -> None: fd.add_output(T54) fd.add_output(T30) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False) # self.assertEqual(nvf_out[0], t24) # Test that symbolic IterDomains can be concatenated @@ -3769,7 +3770,7 @@ def fusion_func(fd: FusionDefinition) -> None: fd.add_output(T57) fd.add_output(T101) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False) # A simple pointwise fusion, but passed misaligned input def test_misaligned_add(self): @@ -3935,7 +3936,7 @@ def fusion_func(fd: FusionDefinition) -> None: fd.add_output(T88) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, 
supports_segmentation=False) # See https://github.com/NVIDIA/Fuser/issues/2275 @pytest.mark.skipif( @@ -3981,7 +3982,7 @@ def fusion_func(fd: FusionDefinition) -> None: T101 = fd.ops.cat([T7, T100], dim=-1) fd.add_output(T101) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False) # See https://github.com/NVIDIA/Fuser/issues/2317 @pytest.mark.skipif( @@ -4138,7 +4139,8 @@ def fusion_func(fd: FusionDefinition) -> None: # T7 = fd.ops.reshape(T1, new_shape=V5) fd.add_output(T7) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + # TODO Segmentation fails validateAllocationSizesAndStrides + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False) # Test empty symbolic tensors can be reshaped # See https://github.com/NVIDIA/Fuser/issues/2362 @@ -4762,7 +4764,7 @@ def fusion_func(fd: FusionDefinition) -> None: T223 = fd.ops.cat([T169, T222], dim=-1, manual_padding=0) fd.add_output(T223) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False) def test_enable_disable_options(self): m = 24 diff --git a/tests/python/utils.py b/tests/python/utils.py index 4b0f3e2c06c..3df42bba942 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -242,13 +242,24 @@ def check_captured_python_definition(reference_outputs, fd, inputs, device=None) # Run original FusionDefinition # Clone FusionDefinition +# Apply segmentation if it supported for this FusionDefinition # Run cloned python definition # Check that the result of cloned python definition matches original results -def check_cpp_translation(reference_outputs, fd, inputs, device=None): +def check_cpp_translation( + reference_outputs, fd, inputs, supports_segmentation, device=None +): try: torch.manual_seed(0) + + # Clone cloned_fd = FusionDefinition() clone(fd, cloned_fd) + + # Segment + if supports_segmentation: + 
cloned_fd.segment(inputs) + + # Run cloned_outputs = cloned_fd.execute(inputs, device=device) # Make sure the results of original and cloned definitions match. @@ -268,6 +279,7 @@ def check_cpp_translation(reference_outputs, fd, inputs, device=None): print( "(A failure here suggests a mismatch in functionality between the original and cloned definitions.)" ) + print("Does FusionDefinition supports segmentation?\t", supports_segmentation) print(fd.getReproErrorString("executing", inputs)) raise err @@ -421,6 +433,7 @@ def exec_nvfuser( new_fusion_expected=True, device=None, is_clonable=True, + supports_segmentation=True, ): fc = FusionCache.get() before_fusions = fc.num_fusions() @@ -450,5 +463,7 @@ def exec_nvfuser( ) if is_clonable: - self.assertTrue(check_cpp_translation(out, fd, inputs_cloned)) + self.assertTrue( + check_cpp_translation(out, fd, inputs_cloned, supports_segmentation) + ) return out, fd