diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index 7525e525522..fd3cc1f5034 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -699,66 +699,4 @@ void FusionDefinition::finalizeSegmentation() { segmentation_state_.reset(); } -void SegmentationState::prepareGroupOrder() { - NVF_ERROR(segmented_fusion_ != nullptr); - - // Setup group run order - std::unordered_set available_input; - - // setup the order tensor dimensions are bound - std::copy( - segmented_fusion_->inputs().begin(), - segmented_fusion_->inputs().end(), - std::inserter(available_input, available_input.end())); - - // The size of the tensor dimensions can be used as an input of the segments. - // NvFuser does not support returning scalar values. Segmentation must pass - // those sizes as segment arguments manually. - std::vector extents = getExtents(segmented_fusion_->completeFusion()); - std::copy( - extents.begin(), - extents.end(), - std::inserter(available_input, available_input.end())); - - // Keep track of groups that has run - std::vector group_ran(segmented_fusion_->groups().size(), false); - - while (!std::all_of( - group_ran.begin(), group_ran.end(), [](bool b) { return b; })) { - bool ran_any_group = false; - - // Find the first segment with all inputs available to run - for (size_t group_i : c10::irange(segmented_fusion_->groups().size())) { - SegmentedGroup* group = segmented_fusion_->groups().at(group_i); - // short-circuit: already ran group. - if (group_ran.at(group_i)) { - continue; - } - const std::vector& group_inputs = group->inputs(); - bool ready_to_run = std::all_of( - group_inputs.begin(), - group_inputs.end(), - [&available_input](Val* val) { return available_input.count(val); }); - - // short-circuit: group is not ready to run. 
- if (!ready_to_run) { - continue; - } - - group_run_order_.push_back(group); - - // Insert graph segment output to tensor map - const std::vector& group_outputs = group->outputs(); - for (size_t group_out_i : c10::irange(group_outputs.size())) { - available_input.insert(group_outputs.at(group_out_i)); - } - group_ran[group_i] = true; - ran_any_group = true; - } - NVF_ERROR( - ran_any_group, - "Failed to run all groups; An error must have occured in segmentation."); - } -} - } // namespace nvfuser::python_frontend diff --git a/csrc/python_frontend/segmentation.cpp b/csrc/python_frontend/segmentation.cpp index 26dc55da2af..95ce3c8dd7f 100644 --- a/csrc/python_frontend/segmentation.cpp +++ b/csrc/python_frontend/segmentation.cpp @@ -244,4 +244,68 @@ std::unordered_map SegmentationState::buildSegment( return segment_fid_to_original_fid_map; } +void SegmentationState::prepareGroupOrder() { + NVF_ERROR(segmented_fusion_ != nullptr); + + // Values available before any group runs; start with the fusion inputs. + std::unordered_set available_input; + std::copy( + segmented_fusion_->inputs().begin(), + segmented_fusion_->inputs().end(), + std::inserter(available_input, available_input.end())); + + // The size of the tensor dimensions can be used as an input of the segments. + // NvFuser does not support returning scalar values. Segmentation must pass + // those sizes as segment arguments manually. 
+ std::vector extents = getExtents(segmented_fusion_->completeFusion()); + std::copy( + extents.begin(), + extents.end(), + std::inserter(available_input, available_input.end())); + + // Per-group flag, indexed like groups(): true once the group is scheduled. + std::vector group_ran(segmented_fusion_->groups().size(), false); + + // Sweep the groups until every SegmentedGroup has been scheduled. + while (!std::all_of( + group_ran.begin(), group_ran.end(), [](bool b) { return b; })) { + bool ran_any_group = false; + + // Schedule every pending group whose inputs are all available. + for (size_t group_i : c10::irange(segmented_fusion_->groups().size())) { + SegmentedGroup* group = segmented_fusion_->groups().at(group_i); + + // short-circuit: Already ran this segmented group. + if (group_ran.at(group_i)) { + continue; + } + + const std::vector& group_inputs = group->inputs(); + bool ready_to_run = std::all_of( + group_inputs.begin(), + group_inputs.end(), + [&available_input](Val* val) { return available_input.count(val); }); + + // short-circuit: This segmented group is not ready to run. + if (!ready_to_run) { + continue; + } + + // Add SegmentedGroup to group_run_order_. + group_run_order_.push_back(group); + + // Outputs become available to groups examined after this one. + const std::vector& group_outputs = group->outputs(); + for (size_t group_out_i : c10::irange(group_outputs.size())) { + available_input.insert(group_outputs.at(group_out_i)); + } + group_ran[group_i] = true; + ran_any_group = true; + } + NVF_ERROR( + ran_any_group, + "Failed to run all groups; An error must have occured in segmentation."); + } +} + } // namespace nvfuser::python_frontend