Skip to content

Commit

Permalink
Support segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Nov 1, 2024
1 parent 4511aa3 commit 594b329
Show file tree
Hide file tree
Showing 5 changed files with 446 additions and 0 deletions.
302 changes: 302 additions & 0 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <fusion_profiler.h>
#include <instrumentation.h>
#include <options.h>
#include <preseg_passes/pre_segmenter.h>
#include <python_frontend/fusion_cache.h>
#include <python_frontend/fusion_definition.h>
#include <python_frontend/translation.h>
Expand Down Expand Up @@ -673,4 +674,305 @@ std::vector<std::pair<double, double>> FusionDefinition::getValTolerances(
return get_val_constants(preschedFusion(), inputs);
}

//! Compute a valid execution order for the segmented groups and store it in
//! group_run_order_. A group is runnable once all of its inputs are
//! available; we repeatedly sweep the group list, appending every runnable
//! group, until all groups have been scheduled (a simple topological sort).
void FusionDefinition::prepareGroupOrder() {
  NVF_ERROR(segmented_fusion_ != nullptr);

  // Values currently available as segment inputs. Seeded with the complete
  // fusion's inputs; grows as groups produce outputs.
  std::unordered_set<Val*> available_input;

  std::copy(
      segmented_fusion_->inputs().begin(),
      segmented_fusion_->inputs().end(),
      std::inserter(available_input, available_input.end()));

  // The size of the tensor dimensions can be used as an input of the segments.
  // NvFuser does not support returning scalar values. Segmentation must pass
  // those sizes as segment arguments manually, so treat all extents of the
  // complete fusion as available inputs as well.
  std::vector<Val*> extents = getExtents(segmented_fusion_->completeFusion());
  std::copy(
      extents.begin(),
      extents.end(),
      std::inserter(available_input, available_input.end()));

  // Keep track of groups that have run
  std::vector<bool> group_ran(segmented_fusion_->groups().size(), false);

  while (!std::all_of(
      group_ran.begin(), group_ran.end(), [](bool b) { return b; })) {
    bool ran_any_group = false;

    // Append every group whose inputs are all available.
    for (size_t group_i : c10::irange(segmented_fusion_->groups().size())) {
      SegmentedGroup* group = segmented_fusion_->groups().at(group_i);
      // short-circuit: already ran group.
      if (group_ran.at(group_i)) {
        continue;
      }
      const std::vector<Val*>& group_inputs = group->inputs();
      bool ready_to_run = std::all_of(
          group_inputs.begin(),
          group_inputs.end(),
          [&available_input](Val* val) { return available_input.count(val); });

      // short-circuit: group is not ready to run.
      if (!ready_to_run) {
        continue;
      }

      group_run_order_.push_back(group);

      // This group's outputs become available inputs for later groups.
      const std::vector<Val*>& group_outputs = group->outputs();
      for (size_t group_out_i : c10::irange(group_outputs.size())) {
        available_input.insert(group_outputs.at(group_out_i));
      }
      group_ran[group_i] = true;
      ran_any_group = true;
    }
    // If a full sweep scheduled nothing, the remaining groups can never run
    // (cyclic or disconnected dependency) — fail loudly instead of spinning.
    NVF_ERROR(
        ran_any_group,
        "Failed to run all groups; An error must have occurred in segmentation.");
  }
}

//! Prepare this FusionDefinition for per-segment translation:
//! clone the pre-scheduled fusion, concretize it for the given inputs, run
//! the segmentation algorithm, and compute a run order for the segments.
//! Returns the number of segments produced.
//! NOTE: statement order matters here — runtime_info holds a raw pointer to
//! the fusion that is later moved into SegmentCandidateFinder::segment (the
//! unique_ptr move does not relocate the Fusion object, so the pointer stays
//! valid).
int64_t FusionDefinition::setupSegmentation(
    const at::ArrayRef<c10::IValue>& inputs) {
  NVF_CHECK(id().has_value(), "FusionDefinition definition does not exist!");
  int8_t device = getCommonDeviceCUDA(inputs);
  NVF_CHECK(
      inputs.empty() || device > -1, "Inputs are not all on the same device!");

  // Check segmentation state: a previous segmentation must have been
  // finalized (finalizeSegmentation) before starting a new one.
  NVF_ERROR(segment_fusion_ == nullptr);
  NVF_ERROR(segmented_fusion_ == nullptr);
  NVF_ERROR(group_run_order_.empty());
  NVF_ERROR(map_cloned_value_to_fid_.empty());
  NVF_ERROR(cloned_extents_.empty());

  // Clone CPP Fusion
  segment_fusion_ = std::make_unique<Fusion>();
  IrCloner original_to_cloned_map =
      Fusion::copy(preschedFusion(), segment_fusion_.get());

  // Get arguments
  KernelArgumentHolder args =
      KernelArgumentHolder::createKernelArgumentHolder(inputs, device);

  // Concretize fusion with input arguments. Then, map original symbolic values
  // to new concrete values when building map_cloned_value_to_fid_
  std::unordered_map<Val*, Val*> symbolic_to_concrete_map =
      DynamicTransform::concretizeFusion(segment_fusion_.get(), args);

  // NOTE: The following tests require using the MarkAliasesPreparePass before
  // segmentation, but not running AllocationDomainPass when running each
  // segment. See test_issue1953 and test_unpadded_catop_issue2275_repro1.

  // Track mapping from cloned CPP fusion and FusionDefinition indices.
  // Concretization may have replaced a cloned symbolic value; if so, record
  // the concrete replacement instead of the stale symbolic value.
  std::transform(
      map_value_to_fid_.begin(),
      map_value_to_fid_.end(),
      std::inserter(map_cloned_value_to_fid_, map_cloned_value_to_fid_.end()),
      [&](const auto& item) {
        const Val* original_value = item.first;
        int64_t fid = item.second;
        Val* cloned_val = original_to_cloned_map.clone(original_value);
        if (symbolic_to_concrete_map.count(cloned_val)) {
          cloned_val = symbolic_to_concrete_map.at(cloned_val);
        }
        return std::make_pair(cloned_val, fid);
      });

  // Track the extents for input TensorViews in cloned CPP Fusion.
  cloned_extents_ = getExtents(segment_fusion_.get());

  // Create runtime information
  SchedulerRuntimeInfo runtime_info(
      segment_fusion_.get(),
      args,
      /*precomputed_values=*/nullptr,
      segment_fusion_->allTvs());

  // Run segmentation algorithm. Ownership of the cloned fusion transfers to
  // segmented_fusion_; segment_fusion_ is null afterwards.
  segmented_fusion_ = SegmentCandidateFinder::segment(
      std::move(segment_fusion_), &args, runtime_info);

  // Get the order for fusion segments
  prepareGroupOrder();

  // Return the number of segments
  return (int64_t)segmented_fusion_->groups().size();
}

//! Translate the segment_id-th group (in run order) into the python
//! FusionDefinition `other`, and return a map from the segment's
//! FusionDefinition indices to the original FusionDefinition indices for all
//! of the segment's inputs and outputs (including extents the python
//! definition needs that the CPP segment does not list as inputs).
std::unordered_map<int64_t, int64_t> FusionDefinition::buildSegment(
    FusionDefinition& other,
    int64_t segment_id) {
  NVF_CHECK(id().has_value(), "FusionDefinition definition does not exist!");
  NVF_ERROR(
      !other.completed(),
      "Expected an incomplete definition before translation.");
  NVF_ERROR(
      segmented_fusion_ != nullptr,
      "SegmentedFusion is not initialized. Run setupSegmentation first.");
  NVF_ERROR(
      segment_id >= 0 &&
          segment_id < (int64_t)segmented_fusion_->groups().size(),
      "The segment id is not valid");

  // Create new fusion segment
  SegmentedGroup* sg = group_run_order_.at(segment_id);
  NVF_ERROR(sg != nullptr);
  std::pair<IrCloner, std::unique_ptr<Fusion>> cloner_segment_pair =
      segmented_fusion_->makeFusion(sg);
  IrCloner original_to_segment_map = cloner_segment_pair.first;
  std::unique_ptr<Fusion> fusion_segment =
      std::move(cloner_segment_pair.second);

  std::unordered_map<const nvfuser::Val*, size_t>
      map_translated_val_to_segment_fid =
          translate(fusion_segment.get(), &other);

  // Step 1: Get FusionDefinition index for original inputs and outputs.
  const std::vector<Val*>& original_inputs = sg->inputs();
  const std::vector<Val*>& original_outputs = sg->outputs();

  std::vector<int64_t> original_fid;
  original_fid.reserve(original_inputs.size() + original_outputs.size());

  std::transform(
      original_inputs.begin(),
      original_inputs.end(),
      std::back_inserter(original_fid),
      [&](Val* v) { return map_cloned_value_to_fid_.at(v); });

  std::transform(
      original_outputs.begin(),
      original_outputs.end(),
      std::back_inserter(original_fid),
      [&](Val* v) { return map_cloned_value_to_fid_.at(v); });

  // Step 2: ir_cloner maps original fusion statements to translated statements.
  std::vector<Val*> segment_inputs_outputs;
  segment_inputs_outputs.reserve(
      original_inputs.size() + original_outputs.size());

  std::transform(
      original_inputs.begin(),
      original_inputs.end(),
      std::back_inserter(segment_inputs_outputs),
      [&](Val* v) { return original_to_segment_map.clone(v); });

  std::transform(
      original_outputs.begin(),
      original_outputs.end(),
      std::back_inserter(segment_inputs_outputs),
      [&](Val* v) { return original_to_segment_map.clone(v); });

  // Step 3: Map translated statements to its FusionDefinition index.
  std::vector<int64_t> segment_fid;
  segment_fid.reserve(segment_inputs_outputs.size());
  std::transform(
      segment_inputs_outputs.begin(),
      segment_inputs_outputs.end(),
      std::back_inserter(segment_fid),
      [&](Val* v) { return map_translated_val_to_segment_fid.at(v); });

  // Step 4: Map segment FusionDefinition index to original FusionDefinition
  // index for inputs and outputs. original_fid and segment_fid are parallel
  // vectors (inputs first, then outputs).
  NVF_ERROR(original_fid.size() == segment_fid.size());

  // Create map from segment fid to original fid.
  std::unordered_map<int64_t, int64_t> segment_fid_to_original_fid_map;
  for (size_t idx : c10::irange(original_fid.size())) {
    segment_fid_to_original_fid_map.emplace(
        segment_fid.at(idx), original_fid.at(idx));
  }

  // short-circuit: No extra extents required for python definition
  if (fusion_segment->inputs().size() == other.inputs().size()) {
    return segment_fid_to_original_fid_map;
  }

  // The python definition can require the size of tensor dimensions from
  // original input arguments, which the original segment does not.

  // Step 1a: Create a map from segment to original extents.
  // Step 1b: Create a map from segment fid to segment extents.
  std::unordered_map<Val*, Val*> segment_to_original_extents;
  std::unordered_map<size_t, Val*> segment_fid_to_translated_val;
  for (Val* original_extent : cloned_extents_) {
    Val* segment_extent = original_to_segment_map.clone(original_extent);

    // short-circuit: some extents are not used in segment
    if (map_translated_val_to_segment_fid.count(segment_extent) == 0) {
      continue;
    }

    // NOTE: renamed from segment_fid to avoid shadowing the vector above.
    size_t extent_segment_fid =
        map_translated_val_to_segment_fid.at(segment_extent);
    segment_to_original_extents.emplace(segment_extent, original_extent);
    segment_fid_to_translated_val.emplace(extent_segment_fid, segment_extent);
  }

  // Step 2: Find the set difference between all segment input fid and known
  // segment fids.
  std::vector<int64_t> missing_segment_fid;
  for (int64_t input_fid : other.inputs()) {
    if (segment_fid_to_original_fid_map.count(input_fid) == 0) {
      missing_segment_fid.push_back(input_fid);
    }
  }

  // Step 3: Get segment Val for missing segment input fids.
  std::vector<Val*> missing_segment_val;
  missing_segment_val.reserve(missing_segment_fid.size());
  std::transform(
      missing_segment_fid.begin(),
      missing_segment_fid.end(),
      std::back_inserter(missing_segment_val),
      [&](int64_t missing_fid) {
        return segment_fid_to_translated_val.at(missing_fid);
      });

  // Step 4: Map segment Val to cloned Val
  std::vector<Val*> missing_cloned_val;
  missing_cloned_val.reserve(missing_segment_val.size());
  std::transform(
      missing_segment_val.begin(),
      missing_segment_val.end(),
      std::back_inserter(missing_cloned_val),
      [&](Val* segment_val) {
        return segment_to_original_extents.at(segment_val);
      });

  // Step 5: Transform cloned Val to original fid.
  std::vector<int64_t> missing_cloned_fid;
  missing_cloned_fid.reserve(missing_cloned_val.size());
  std::transform(
      missing_cloned_val.begin(),
      missing_cloned_val.end(),
      std::back_inserter(missing_cloned_fid),
      [&](Val* original_val) {
        return map_cloned_value_to_fid_.at(original_val);
      });

  // Step 6: Add mapping from segment to original fid.
  for (size_t idx : c10::irange(missing_segment_fid.size())) {
    segment_fid_to_original_fid_map.emplace(
        missing_segment_fid.at(idx), missing_cloned_fid.at(idx));
  }

  return segment_fid_to_original_fid_map;
}

//! Tear down all segmentation state created by setupSegmentation so that a
//! subsequent call to setupSegmentation starts from a clean slate.
void FusionDefinition::finalizeSegmentation() {
  // Release the SegmentedFusion before the (already-moved-from) cloned
  // fusion, then drop the bookkeeping containers.
  segmented_fusion_.reset();
  segment_fusion_.reset();
  cloned_extents_.clear();
  map_cloned_value_to_fid_.clear();
  group_run_order_.clear();
}

} // namespace nvfuser::python_frontend
30 changes: 30 additions & 0 deletions csrc/python_frontend/fusion_definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
// clang-format on
#pragma once
#include <exceptions.h>
#include <functional>
#include <iostream>

#include <python_frontend/fusion_state.h>
#include <runtime/fusion_executor_cache.h>
#include <visibility.h>
#include <unordered_map>

namespace nvfuser::python_frontend {

Expand Down Expand Up @@ -254,6 +256,18 @@ class NVF_API FusionDefinition : public FusionState {
//! Get all Tensors in FusionState.
NVF_API std::vector<Tensor> tensors();

//! Run segmentation algorithm on FusionDefinition. Returns the number of
//! segments.
NVF_API int64_t setupSegmentation(const at::ArrayRef<c10::IValue>& inputs);
//! Given SegmentedFusion and vector of FusionDefinition objects for the
//! fusion segments, create the fusion segments and clone their state to the
//! FusionDefinitions.
NVF_API std::unordered_map<int64_t, int64_t> buildSegment(
FusionDefinition& other,
int64_t segment_id);
//! After creating segments, destroy SegmentedFusion and RuntimeWorkspace.
NVF_API void finalizeSegmentation();

private:
//! Returns the FusionCache Ptr that holds the cache of Fusions
FusionCache* fusionCache() const;
Expand All @@ -267,6 +281,8 @@ class NVF_API FusionDefinition : public FusionState {
// Check that the NvFuser TensorView and the Python Tensor dimensions match.
// Apply after buildFusionIr
void verifyTensorDimensions();
//! Perform a topological sort on SegmentedFusion to segment order.
void prepareGroupOrder();

//! Holds the defined maximum length of a FusionDefinition in order to
//! prevent a run away error. The user should feel free to increase this
Expand All @@ -289,6 +305,20 @@ class NVF_API FusionDefinition : public FusionState {
//! Number of recording_states_ before applying user schedule
int64_t num_recording_states_presched_ = 0;

//! Clone of original fusion for segmentation
std::unique_ptr<Fusion> segment_fusion_ = nullptr;
//! This FusionDefinition may require multiple kernels if it cannot be handled
//! by a single heuristic scheduler. SegmentedFusion takes a fusion and runs
//! the segmentation algorithm.
std::unique_ptr<SegmentedFusion> segmented_fusion_ = nullptr;
//! Pre-determined order to run the segmented groups
std::vector<SegmentedGroup*> group_run_order_;
//! Create copy of fusion for segmentation algorithm. IrCloner is a map
//! between values in original and cloned fusions.
std::unordered_map<const Val*, int64_t> map_cloned_value_to_fid_;
//! Extents for TensorView input arguments for cloned Fusion
std::vector<Val*> cloned_extents_;

public:
//! The Operators are not directly defined in this header. They are defined
//! in the python bindings through lambda functions so the user only needs to
Expand Down
Loading

0 comments on commit 594b329

Please sign in to comment.