Skip to content

Commit

Permalink
move prepareGroupOrder
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Nov 2, 2024
1 parent 8923dc7 commit a56a31b
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 62 deletions.
62 changes: 0 additions & 62 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,66 +699,4 @@ void FusionDefinition::finalizeSegmentation() {
segmentation_state_.reset();
}

void SegmentationState::prepareGroupOrder() {
NVF_ERROR(segmented_fusion_ != nullptr);

// Setup group run order
std::unordered_set<Val*> available_input;

// setup the order tensor dimensions are bound
std::copy(
segmented_fusion_->inputs().begin(),
segmented_fusion_->inputs().end(),
std::inserter(available_input, available_input.end()));

// The size of the tensor dimensions can be used as an input of the segments.
// NvFuser does not support returning scalar values. Segmentation must pass
// those sizes as segment arguments manually.
std::vector<Val*> extents = getExtents(segmented_fusion_->completeFusion());
std::copy(
extents.begin(),
extents.end(),
std::inserter(available_input, available_input.end()));

// Keep track of groups that has run
std::vector<bool> group_ran(segmented_fusion_->groups().size(), false);

while (!std::all_of(
group_ran.begin(), group_ran.end(), [](bool b) { return b; })) {
bool ran_any_group = false;

// Find the first segment with all inputs available to run
for (size_t group_i : c10::irange(segmented_fusion_->groups().size())) {
SegmentedGroup* group = segmented_fusion_->groups().at(group_i);
// short-circuit: already ran group.
if (group_ran.at(group_i)) {
continue;
}
const std::vector<Val*>& group_inputs = group->inputs();
bool ready_to_run = std::all_of(
group_inputs.begin(),
group_inputs.end(),
[&available_input](Val* val) { return available_input.count(val); });

// short-circuit: group is not ready to run.
if (!ready_to_run) {
continue;
}

group_run_order_.push_back(group);

// Insert graph segment output to tensor map
const std::vector<Val*>& group_outputs = group->outputs();
for (size_t group_out_i : c10::irange(group_outputs.size())) {
available_input.insert(group_outputs.at(group_out_i));
}
group_ran[group_i] = true;
ran_any_group = true;
}
NVF_ERROR(
ran_any_group,
"Failed to run all groups; An error must have occured in segmentation.");
}
}

} // namespace nvfuser::python_frontend
64 changes: 64 additions & 0 deletions csrc/python_frontend/segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,4 +244,68 @@ std::unordered_map<int64_t, int64_t> SegmentationState::buildSegment(
return segment_fid_to_original_fid_map;
}

void SegmentationState::prepareGroupOrder() {
NVF_ERROR(segmented_fusion_ != nullptr);

// Gather initial inputs for SegmentedFusion.
std::unordered_set<Val*> available_input;
std::copy(
segmented_fusion_->inputs().begin(),
segmented_fusion_->inputs().end(),
std::inserter(available_input, available_input.end()));

// The size of the tensor dimensions can be used as an input of the segments.
// NvFuser does not support returning scalar values. Segmentation must pass
// those sizes as segment arguments manually.
std::vector<Val*> extents = getExtents(segmented_fusion_->completeFusion());
std::copy(
extents.begin(),
extents.end(),
std::inserter(available_input, available_input.end()));

// Track the run status of all SegmentedGroups in SegmentedFusion
std::vector<bool> group_ran(segmented_fusion_->groups().size(), false);

// While not all the SegmentedGroups are run:
while (!std::all_of(
group_ran.begin(), group_ran.end(), [](bool b) { return b; })) {
bool ran_any_group = false;

// Find the first segment with all inputs available to run
for (size_t group_i : c10::irange(segmented_fusion_->groups().size())) {
SegmentedGroup* group = segmented_fusion_->groups().at(group_i);

// short-circuit: Already ran this segmented group.
if (group_ran.at(group_i)) {
continue;
}

const std::vector<Val*>& group_inputs = group->inputs();
bool ready_to_run = std::all_of(
group_inputs.begin(),
group_inputs.end(),
[&available_input](Val* val) { return available_input.count(val); });

// short-circuit: This segmented group is not ready to run.
if (!ready_to_run) {
continue;
}

// Add SegmentedGroup to group_run_order_.
group_run_order_.push_back(group);

// Mark all outputs of SegmentedGroup as ready.
const std::vector<Val*>& group_outputs = group->outputs();
for (size_t group_out_i : c10::irange(group_outputs.size())) {
available_input.insert(group_outputs.at(group_out_i));
}
group_ran[group_i] = true;
ran_any_group = true;
}
NVF_ERROR(
ran_any_group,
"Failed to run all groups; An error must have occured in segmentation.");
}
}

} // namespace nvfuser::python_frontend

0 comments on commit a56a31b

Please sign in to comment.