diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index 036e7abd125..9efdebc17ad 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -696,6 +696,16 @@ int64_t FusionDefinition::setupSegmentation( preschedFusion(), map_value_to_fid_, inputs); } +std::unordered_map FusionDefinition::buildSegment( + FusionDefinition& segment_fd, + int64_t segment_id) { + NVF_CHECK(id().has_value(), "FusionDefinition does not exist!"); + NVF_CHECK( + segmentation_state_ != nullptr, + "Run setupSegmentation first before trying to build segments!"); + return segmentation_state_->buildSegment(segment_fd, segment_id); +} + void FusionDefinition::finalizeSegmentation() { // Destroy SegmentedState segmentation_state_.reset(); diff --git a/csrc/python_frontend/fusion_definition.h b/csrc/python_frontend/fusion_definition.h index 4415ff599e9..c359352c565 100644 --- a/csrc/python_frontend/fusion_definition.h +++ b/csrc/python_frontend/fusion_definition.h @@ -262,6 +262,13 @@ class NVF_API FusionDefinition : public FusionState { //! Run segmentation algorithm on FusionDefinition. Returns the number of //! segments. NVF_API int64_t setupSegmentation(const at::ArrayRef& inputs); + //! Given an empty FusionDefinition and a segment id, buildSegment creates the + //! CPP Fusion, translates it to the python FusionDefinition, then return a + //! mapping from segment fusion state indices to the original fusion state + //! indices. + NVF_API std::unordered_map buildSegment( + FusionDefinition& segment_fd, + int64_t segment_id); //! After creating segments, destroy SegmentationState. 
NVF_API void finalizeSegmentation(); diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 004ee939a87..fb80756c6ac 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -1024,6 +1024,13 @@ void initNvFuserPythonBindings(PyObject* module) { } return self.setupSegmentation(inputs); }) + .def( + "_build_segment", + [](FusionDefinition& self, + FusionDefinition& other, + int64_t segment_id) { + return self.buildSegment(other, segment_id); + }) .def( "_finalize_segmentation", [](FusionDefinition& self) { diff --git a/csrc/python_frontend/segmentation.cpp b/csrc/python_frontend/segmentation.cpp index 7bdcbab3c29..42b085af926 100644 --- a/csrc/python_frontend/segmentation.cpp +++ b/csrc/python_frontend/segmentation.cpp @@ -12,65 +12,69 @@ namespace nvfuser::python_frontend { int64_t SegmentationState::setupSegmentation( Fusion* fusion, - const std::unordered_map& map_value_to_original_fid, + const std::unordered_map& + map_presched_value_to_original_python_index, const at::ArrayRef& inputs) { // Check state NVF_ERROR(fusion != nullptr); - NVF_ERROR(cloned_fusion_ == nullptr); + NVF_ERROR(cloned_original_fusion_ == nullptr); NVF_ERROR(segmented_fusion_ == nullptr); NVF_ERROR(group_run_order_.empty()); - NVF_ERROR(map_cloned_value_to_fid_.empty()); - NVF_ERROR(cloned_extents_.empty()); + NVF_ERROR(map_cloned_concretized_value_to_original_python_index_.empty()); + NVF_ERROR(cloned_original_extents_.empty()); int8_t device = getCommonDeviceCUDA(inputs); NVF_CHECK( inputs.empty() || device > -1, "Inputs are not all on the same device!"); // Step 1) Clone preschedFusion CPP Fusion. - cloned_fusion_ = std::make_unique(); + cloned_original_fusion_ = std::make_unique(); // The IRCloner returned by Fusion::copy acts as map from the original fusion // to the cloned fusion. 
- IrCloner original_to_cloned_map = Fusion::copy(fusion, cloned_fusion_.get()); + IrCloner original_to_cloned_map = + Fusion::copy(fusion, cloned_original_fusion_.get()); KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(inputs, device); // Step 2) Concretize fusion with input arguments. std::unordered_map symbolic_to_concrete_map = - DynamicTransform::concretizeFusion(cloned_fusion_.get(), args); + DynamicTransform::concretizeFusion(cloned_original_fusion_.get(), args); - // Step 3) Given the map_value_to_original_fid, the IRCloner returned by - // Fusion::copy, AND the symbolic_to_concrete map returned by + // Step 3) Given the map_presched_value_to_original_python_index, the IRCloner + // returned by Fusion::copy, AND the symbolic_to_concrete map returned by // concretization pass, create a mapping from cloned Vals to original fusion // state indices. std::transform( - map_value_to_original_fid.begin(), - map_value_to_original_fid.end(), - std::inserter(map_cloned_value_to_fid_, map_cloned_value_to_fid_.end()), + map_presched_value_to_original_python_index.begin(), + map_presched_value_to_original_python_index.end(), + std::inserter( + map_cloned_concretized_value_to_original_python_index_, + map_cloned_concretized_value_to_original_python_index_.end()), [&](const auto& item) { const Val* original_value = item.first; - int64_t fid = item.second; + int64_t python_index = item.second; Val* cloned_val = original_to_cloned_map.clone(original_value); if (symbolic_to_concrete_map.count(cloned_val)) { cloned_val = symbolic_to_concrete_map.at(cloned_val); } - return std::make_pair(cloned_val, fid); + return std::make_pair(cloned_val, python_index); }); // Track the extents for input TensorViews in cloned CPP Fusion. 
- cloned_extents_ = getExtents(cloned_fusion_.get()); + cloned_original_extents_ = getExtents(cloned_original_fusion_.get()); // Create runtime infomation SchedulerRuntimeInfo runtime_info( - cloned_fusion_.get(), + cloned_original_fusion_.get(), args, /*precomputed_values=*/nullptr, - cloned_fusion_->allTvs()); + cloned_original_fusion_->allTvs()); // Run segmentation algorithm segmented_fusion_ = SegmentCandidateFinder::segment( - std::move(cloned_fusion_), &args, runtime_info); + std::move(cloned_original_fusion_), &args, runtime_info); // Get the order for fusion segments prepareGroupOrder(); @@ -79,6 +83,214 @@ int64_t SegmentationState::setupSegmentation( return (int64_t)segmented_fusion_->groups().size(); } +// setupSegmentation transforms the Prescheduled, Symbolic Fusion to Cloned, +// Concretized Fusion. Both CPP fusions correspond to the Original +// FusionDefinition. +// +// The segmentation pass runs on cloned, concretized fusion to create +// SegmentedFusion. Each SegmentedGroup in the SegmentedFusion creates a segment +// CPP fusion that is translated to a python definition. +// +// +// NOTE: Steps 4a through 4d are run for every fusion segment. However, +// sometimes the python definition needs the extents of the original fusion's +// input tensors as extra arguments. Steps 4f to 4l create mappings for these +// missing extents. +// +// Details: +// 1) Use segment id to get SegmentedGroup from group_run_order_. +// 2) Create CPP Fusion for SegmentedGroup. +// * IrCloner acts as a map from the cloned fusion to the fusion segment. +// 3) Translate CPP Fusion to Python FusionDefinition +// 4) Create map from segment fusion indices to original fusion indices. +// a) Get cloned Vals for SegmentedGroup's inputs and outputs. Map them +// to their original fusion indices. +// b) Map cloned Vals to their segment Vals +// c) Map segment Vals to their fusion indices. +// d) Map original indices to segment indices. 
+// e) Return map if the number of input arguments for python definition +// matches the number of input arguments for CPP fusion. +// f) Create a map from segment to cloned extents. +// g) Create a map from segment fusion indices to segment extents. +// h) Find segment inputs that are missing from segment to original +// indices map. +// i) Get segment Vals for the missing segment fusion indices. +// j) Map segment Vals to cloned Vals. +// k) Map cloned Vals to their corresponding fusion indices. +// l) Add missing mappings to segment to original indices map. +// 5) Return the mapping from the segmented FusionDefinition index space to +// original FusionDefinition index space. +std::unordered_map SegmentationState::buildSegment( + FusionDefinition& segment_fd, + int64_t segment_id) { + NVF_ERROR( + !segment_fd.completed(), + "Expected an incomplete definition before translation."); + NVF_ERROR( + segmented_fusion_ != nullptr, + "SegmentedFusion is not initialized. Run setupSegmentation first."); + NVF_ERROR( + segment_id >= 0 && + segment_id < (int64_t)segmented_fusion_->groups().size(), + "The segment id is not valid"); + + // Step 1) Use segment id to get SegmentedGroup from group_run_order_. + SegmentedGroup* sg = group_run_order_.at(segment_id); + NVF_ERROR(sg != nullptr); + + // Step 2) Create CPP Fusion for SegmentedGroup. The IrCloner acts as a map + // from the cloned fusion to the fusion segment. + std::pair> cloner_segment_pair = + segmented_fusion_->makeFusion(sg); + IrCloner cloned_to_segment_map = cloner_segment_pair.first; + std::unique_ptr fusion_segment = + std::move(cloner_segment_pair.second); + + // Step 3) Translate CPP Fusion to Python FusionDefinition + std::unordered_map + map_segment_cpp_value_to_python_index = + translate(fusion_segment.get(), &segment_fd); + + // Step 4) Create map from segment fusion indices to original fusion indices. + // Step 4a) Get FusionDefinition index for cloned inputs and outputs. 
Map them + // to their original fusion indices. + const std::vector& cloned_inputs = sg->inputs(); + const std::vector& cloned_outputs = sg->outputs(); + + std::vector original_python_index; + original_python_index.reserve(cloned_inputs.size() + cloned_outputs.size()); + + std::transform( + cloned_inputs.begin(), + cloned_inputs.end(), + std::back_inserter(original_python_index), + [&](Val* v) { + return map_cloned_concretized_value_to_original_python_index_.at(v); + }); + + std::transform( + cloned_outputs.begin(), + cloned_outputs.end(), + std::back_inserter(original_python_index), + [&](Val* v) { + return map_cloned_concretized_value_to_original_python_index_.at(v); + }); + + // Step 4b) ir_cloner maps cloned fusion Vals to segment Vals. + std::vector segment_inputs_outputs; + segment_inputs_outputs.reserve(cloned_inputs.size() + cloned_outputs.size()); + + std::transform( + cloned_inputs.begin(), + cloned_inputs.end(), + std::back_inserter(segment_inputs_outputs), + [&](Val* v) { return cloned_to_segment_map.clone(v); }); + + std::transform( + cloned_outputs.begin(), + cloned_outputs.end(), + std::back_inserter(segment_inputs_outputs), + [&](Val* v) { return cloned_to_segment_map.clone(v); }); + + // Step 4c) Map segment Vals to their FusionDefinition index. + std::vector segment_python_index; + segment_python_index.reserve(segment_inputs_outputs.size()); + std::transform( + segment_inputs_outputs.begin(), + segment_inputs_outputs.end(), + std::back_inserter(segment_python_index), + [&](Val* v) { return map_segment_cpp_value_to_python_index.at(v); }); + + // Step 4d) Map original indices to segment indices. 
+ NVF_ERROR(original_python_index.size() == segment_python_index.size()); + std::unordered_map segment_to_original_python_index_map; + for (size_t idx : c10::irange(original_python_index.size())) { + segment_to_original_python_index_map.emplace( + segment_python_index.at(idx), original_python_index.at(idx)); + } + + // Step 4e) short-circuit: No extra extents required for python definition. + if (fusion_segment->inputs().size() == segment_fd.inputs().size()) { + return segment_to_original_python_index_map; + } + + // The python segment can require the size of tensor dimensions from original + // fusion's input arguments, which the CPP segment does not. + + // Step 4f) Create a map from segment to cloned extents. + // Step 4g) Create a map from segment indices to segment extents. + std::unordered_map segment_to_cloned_extents; + std::unordered_map segment_python_index_to_cpp_val; + for (Val* cloned_extent : cloned_original_extents_) { + Val* segment_extent = cloned_to_segment_map.clone(cloned_extent); + + // short-circuit: some extents are not used in segment + if (map_segment_cpp_value_to_python_index.count(segment_extent) == 0) { + continue; + } + + size_t segment_python_index = + map_segment_cpp_value_to_python_index.at(segment_extent); + segment_to_cloned_extents.emplace(segment_extent, cloned_extent); + segment_python_index_to_cpp_val.emplace( + segment_python_index, segment_extent); + } + + // Step 4h) Find the set difference between all segment input indices and + // known input segment indices. + std::vector missing_segment_python_index; + for (int64_t input_python_index : segment_fd.inputs()) { + if (segment_to_original_python_index_map.count(input_python_index) == 0) { + missing_segment_python_index.push_back(input_python_index); + } + } + + // Step 4i) Get segment Val for missing segment input indices. 
+ std::vector missing_segment_val; + missing_segment_val.reserve(missing_segment_python_index.size()); + std::transform( + missing_segment_python_index.begin(), + missing_segment_python_index.end(), + std::back_inserter(missing_segment_val), + [&](int64_t segment_python_index) { + return segment_python_index_to_cpp_val.at(segment_python_index); + }); + + // Step 4j) Map segment Vals to cloned Vals + std::vector missing_cloned_val; + missing_cloned_val.reserve(missing_segment_val.size()); + std::transform( + missing_segment_val.begin(), + missing_segment_val.end(), + std::back_inserter(missing_cloned_val), + [&](Val* segment_val) { + return segment_to_cloned_extents.at(segment_val); + }); + + // Step 4k) Transform cloned Vals to their original fusion indices. + std::vector missing_cloned_python_index; + missing_cloned_python_index.reserve(missing_cloned_val.size()); + std::transform( + missing_cloned_val.begin(), + missing_cloned_val.end(), + std::back_inserter(missing_cloned_python_index), + [&](Val* cloned_val) { + return map_cloned_concretized_value_to_original_python_index_.at( + cloned_val); + }); + + // Step 4l) Add missing mappings from segment to original indices. + for (size_t idx : c10::irange(missing_segment_python_index.size())) { + segment_to_original_python_index_map.emplace( + missing_segment_python_index.at(idx), + missing_cloned_python_index.at(idx)); + } + + // Return the mapping from the index space of segment FusionDefinition to the + // index space of the original FusionDefinition. + return segment_to_original_python_index_map; +} + void SegmentationState::prepareGroupOrder() { NVF_ERROR(segmented_fusion_ != nullptr); diff --git a/csrc/python_frontend/segmentation.h b/csrc/python_frontend/segmentation.h index b2cb6582fdf..e4ac019dff3 100644 --- a/csrc/python_frontend/segmentation.h +++ b/csrc/python_frontend/segmentation.h @@ -170,21 +170,39 @@ class SegmentationState { // Details: // 1) Clone preschedFusion CPP Fusion. 
// 2) Concretize fusion using input arguments. - // 3) Given the map_value_to_original_fid, the IRCloner returned by - // Fusion::copy, AND symbolic_to_concrete map returned by concretization - // pass, create a mapping from cloned Vals to original fusion state - // indices - // 4) Get extents for cloned fusion - // 5) Create SchedulerRuntimeInfo + // 3) Given the map_presched_value_to_original_python_index, the IRCloner + // returned by Fusion::copy, AND symbolic_to_concrete map returned by + // concretization pass, create a mapping from cloned Vals to original fusion + // state indices. + // 4) Get extents for cloned fusion. + // 5) Create SchedulerRuntimeInfo. // 6) Run segmentation algorithm using cloned fusion, input arguments, and // scheduler runtime information. // 7) Get sequential order of fusion segments using prepareGroupOrder. // 8) Return the number of segments created by segmentation algorithm. int64_t setupSegmentation( Fusion* fusion, - const std::unordered_map& map_value_to_original_fid, + const std::unordered_map& + map_presched_value_to_original_python_index, const at::ArrayRef& inputs); + // Given an empty FusionDefinition and a segment id, buildSegment creates the + // CPP Fusion, translates it to the python FusionDefinition, then returns a + // mapping from segment fusion state indices to the original fusion state + // indices. + // + // The mapping is constructed from the segment's python definition -> + // segment's CPP Fusion -> original's CPP Fusion -> original's python + // definition. + // + // NOTE: Sometimes the python definition requires the extents from the + // original fusion's input tensors as extra arguments. Therefore, the input + // arguments for the python definition and the CPP Fusion may not exactly + // match. + NVF_API std::unordered_map buildSegment( + FusionDefinition& segment_fd, + int64_t segment_id); + private: // prepareGroupOrder is similar to prepareRuntimeOrder. 
It generates the // topological order of SegmentedGroups in SegmentedFusion. @@ -206,7 +224,7 @@ class SegmentationState { private: // Clone of original fusion for segmentation - std::unique_ptr cloned_fusion_ = nullptr; + std::unique_ptr cloned_original_fusion_ = nullptr; // This FusionDefinition may require multiple kernels if it cannot be handled // by a single heuristic scheduler. SegmentedFusion takes a fusion and runs @@ -216,12 +234,13 @@ class SegmentationState { // Pre-determined order to run the segmented groups std::vector group_run_order_; - // Create copy of fusion for segmentation algorithm. IrCloner is a map - // between values in original and cloned fusions. - std::unordered_map map_cloned_value_to_fid_; + // Map values from cloned, concretized fusion to the indices of the original + // python definition. + std::unordered_map + map_cloned_concretized_value_to_original_python_index_; // Extents for TensorView input arguments for cloned Fusion - std::vector cloned_extents_; + std::vector cloned_original_extents_; }; } // namespace nvfuser::python_frontend diff --git a/nvfuser/__init__.py b/nvfuser/__init__.py index 46e95dc9d3c..5581050f4de 100644 --- a/nvfuser/__init__.py +++ b/nvfuser/__init__.py @@ -56,9 +56,46 @@ def __init__(self, id=None, max_length=1024): self.profiled = False def segment(self, inputs): + """ + Decompose this FusionDefinition into a sequence of segment + FusionDefinitions. + + This function runs the nvfuser segmentation algorithm and translates the + segments into their corresponding FusionDefinitions. + + Args: + inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion. + + Returns: + List[FusionDefinition]: The FusionDefinitions corresponding to the + sub-fusion segments of this FusionDefinition. 
+ """ num_segments = self._setup_segmentation(inputs) + if num_segments == 1: + self._finalize_segmentation() + return [] + + # Track all segments for this FusionDefinition + self.segments = [] + + # Track map_segment_fid_to_original_fid for each segment + self.segment_index_space_maps = {} + + # Track the last segment a value is used as an input + self.map_value_to_last_used_segment = {} + + for idx in range(num_segments): + new_fd = FusionDefinition() + map_segment_fid_to_original_fid = self._build_segment(new_fd, idx) + + for segment_input in new_fd.inputs(): + original_input = map_segment_fid_to_original_fid[segment_input] + self.map_value_to_last_used_segment[original_input] = idx + + self.segment_index_space_maps[new_fd] = map_segment_fid_to_original_fid + self.segments.append(new_fd) self._finalize_segmentation() - return num_segments + return self.segments def __enter__(self): return self._setup_definition() diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index f68e83b1fc7..a181ef83fa5 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -1263,7 +1263,40 @@ def fusion_func(fd: FusionDefinition): with FusionDefinition() as fd: fusion_func(fd) - assert fd.segment(inputs) == 2 + + # create segments + segments = fd.segment(inputs) + + # Each segment has a map from its segment index space to the original + # fusion's index space. The original fusion creates two segments. + # Check that segment maps match expected behavior. 
+ assert len(fd.segment_index_space_maps) == 2 + + # First Segment: + # def nvfuser_fusion_id2(fd : FusionDefinition) -> None : + # T0 = fd.define_tensor(shape=[-1, -1], + # contiguity=[True, True], + # dtype=DataType.Float, + # is_cpu=False) + # T1 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float) + # T2 = fd.ops.broadcast(T1, is_broadcast_dim=[False, True]) + # fd.add_output(T2) + # + assert fd.segment_index_space_maps[segments[0]] == {2: 2, 0: 0} + + # Second Segment: + # def nvfuser_fusion_id3(fd : FusionDefinition) -> None : + # T0 = fd.define_tensor(shape=[-1, -1], + # contiguity=[True, True], + # dtype=DataType.Float, + # is_cpu=False) + # T1 = fd.define_tensor(shape=[-1, 1], + # contiguity=[True, None], + # dtype=DataType.Float, + # is_cpu=False) + # T2 = fd.ops.add(T0, T1) + # fd.add_output(T2) + assert fd.segment_index_space_maps[segments[1]] == {2: 3, 1: 2, 0: 1} def test_arithmetic_ops(self): inputs = [