Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translate segments to python definition #3335

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,16 @@ int64_t FusionDefinition::setupSegmentation(
preschedFusion(), map_value_to_fid_, inputs);
}

std::unordered_map<int64_t, int64_t> FusionDefinition::buildSegment(
FusionDefinition& segment_fd,
int64_t segment_id) {
NVF_CHECK(id().has_value(), "FusionDefinition does not exist!");
NVF_CHECK(
segmentation_state_ != nullptr,
"Run setupSegmentation first before trying to build segments!");
return segmentation_state_->buildSegment(segment_fd, segment_id);
}

void FusionDefinition::finalizeSegmentation() {
// Destroy SegmentedState
segmentation_state_.reset();
Expand Down
7 changes: 7 additions & 0 deletions csrc/python_frontend/fusion_definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,13 @@ class NVF_API FusionDefinition : public FusionState {
//! Run segmentation algorithm on FusionDefinition. Returns the number of
//! segments.
NVF_API int64_t setupSegmentation(const at::ArrayRef<c10::IValue>& inputs);
//! Given an empty FusionDefinition and a segment id, buildSegment creates the
//! CPP Fusion, translates it to the python FusionDefinition, then return a
//! mapping from segment fusion state indices to the original fusion state
//! indices.
NVF_API std::unordered_map<int64_t, int64_t> buildSegment(
FusionDefinition& segment_fd,
int64_t segment_id);
//! After creating segments, destroy SegmentationState.
NVF_API void finalizeSegmentation();

Expand Down
7 changes: 7 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,13 @@ void initNvFuserPythonBindings(PyObject* module) {
}
return self.setupSegmentation(inputs);
})
.def(
"_build_segment",
[](FusionDefinition& self,
FusionDefinition& other,
int64_t segment_id) {
return self.buildSegment(other, segment_id);
})
.def(
"_finalize_segmentation",
[](FusionDefinition& self) {
Expand Down
248 changes: 230 additions & 18 deletions csrc/python_frontend/segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,65 +12,69 @@ namespace nvfuser::python_frontend {

int64_t SegmentationState::setupSegmentation(
Fusion* fusion,
const std::unordered_map<const Val*, int64_t>& map_value_to_original_fid,
const std::unordered_map<const Val*, int64_t>&
map_presched_value_to_original_python_index,
const at::ArrayRef<c10::IValue>& inputs) {
// Check state
NVF_ERROR(fusion != nullptr);
NVF_ERROR(cloned_fusion_ == nullptr);
NVF_ERROR(cloned_original_fusion_ == nullptr);
NVF_ERROR(segmented_fusion_ == nullptr);
NVF_ERROR(group_run_order_.empty());
NVF_ERROR(map_cloned_value_to_fid_.empty());
NVF_ERROR(cloned_extents_.empty());
NVF_ERROR(map_cloned_concretized_value_to_original_python_index_.empty());
NVF_ERROR(cloned_original_extents_.empty());

int8_t device = getCommonDeviceCUDA(inputs);
NVF_CHECK(
inputs.empty() || device > -1, "Inputs are not all on the same device!");

// Step 1) Clone preschedFusion CPP Fusion.
cloned_fusion_ = std::make_unique<Fusion>();
cloned_original_fusion_ = std::make_unique<Fusion>();

// The IRCloner returned by Fusion::copy acts as map from the original fusion
// to the cloned fusion.
IrCloner original_to_cloned_map = Fusion::copy(fusion, cloned_fusion_.get());
IrCloner original_to_cloned_map =
Fusion::copy(fusion, cloned_original_fusion_.get());

KernelArgumentHolder args =
KernelArgumentHolder::createKernelArgumentHolder(inputs, device);

// Step 2) Concretize fusion with input arguments.
std::unordered_map<Val*, Val*> symbolic_to_concrete_map =
DynamicTransform::concretizeFusion(cloned_fusion_.get(), args);
DynamicTransform::concretizeFusion(cloned_original_fusion_.get(), args);

// Step 3) Given the map_value_to_original_fid, the IRCloner returned by
// Fusion::copy, AND the symbolic_to_concrete map returned by
// Step 3) Given the map_presched_value_to_original_python_index, the IRCloner
// returned by Fusion::copy, AND the symbolic_to_concrete map returned by
// concretization pass, create a mapping from cloned Vals to original fusion
// state indices.
std::transform(
map_value_to_original_fid.begin(),
map_value_to_original_fid.end(),
std::inserter(map_cloned_value_to_fid_, map_cloned_value_to_fid_.end()),
map_presched_value_to_original_python_index.begin(),
map_presched_value_to_original_python_index.end(),
std::inserter(
map_cloned_concretized_value_to_original_python_index_,
map_cloned_concretized_value_to_original_python_index_.end()),
[&](const auto& item) {
const Val* original_value = item.first;
int64_t fid = item.second;
int64_t python_index = item.second;
Val* cloned_val = original_to_cloned_map.clone(original_value);
if (symbolic_to_concrete_map.count(cloned_val)) {
cloned_val = symbolic_to_concrete_map.at(cloned_val);
}
return std::make_pair(cloned_val, fid);
return std::make_pair(cloned_val, python_index);
});

// Track the extents for input TensorViews in cloned CPP Fusion.
cloned_extents_ = getExtents(cloned_fusion_.get());
cloned_original_extents_ = getExtents(cloned_original_fusion_.get());

// Create runtime infomation
SchedulerRuntimeInfo runtime_info(
cloned_fusion_.get(),
cloned_original_fusion_.get(),
args,
/*precomputed_values=*/nullptr,
cloned_fusion_->allTvs());
cloned_original_fusion_->allTvs());

// Run segmentation algorithm
segmented_fusion_ = SegmentCandidateFinder::segment(
std::move(cloned_fusion_), &args, runtime_info);
std::move(cloned_original_fusion_), &args, runtime_info);

// Get the order for fusion segments
prepareGroupOrder();
Expand All @@ -79,6 +83,214 @@ int64_t SegmentationState::setupSegmentation(
return (int64_t)segmented_fusion_->groups().size();
}

// setupSegmentation transforms the Prescheduled, Symbolic Fusion to Cloned,
// Concretized Fusion. Both CPP fusions corresponds with Original
// FusionDefinition.
//
// The segmentation pass runs on cloned, concretized fusion to create
// SegmentedFusion. Each SegmentedGroup in the SegmentedFusion creates a segment
// CPP fusion that is translated to a python definition.
//
//
// NOTE: Steps 4a through 4d are run for every fusion segment. However,
// sometimes the python definition needs the extents of the original fusion's
// input tensors as extra arguments. Steps 4f to 4l creates mappings for these
// missing extents.
//
// Details:
// 1) Use segment id to get SegmentedGroup from group_run_order_.
// 2) Create CPP Fusion for SegmentedGroup.
// * IrCloner acts as a map from fusion segment to the original fusion.
// 3) Translate CPP Fusion to Python FusionDefinition
// 4) Create map from segment fusion indices to original fusion indices.
// a) Get cloned Vals for SegmentedGroup's inputs and outputs. Map them
// to their original fusion indices.
// b) Map cloned Vals to their segment Vals
// c) Map segment Vals to their fusion indices.
// d) Map original indices to segment indices.
// e) Return map if the number of input arguments for python definition
// matches the number of input arguments for CPP fusion.
// f) Create a map from segment to cloned extents.
// g) Create a map from segment fusion indices to cloned extents.
// h) Find segment inputs that are missing from segment to original
// indices map.
// i) Get segment Vals for the missing segment fusion indices.
// j) Map segment Vals to cloned Vals.
// k) Map cloned Vals to their corresponding fusion indices.
// l) Add missing mappings to segment to original indices map.
// 5) Return the mapping from the segmented FusionDefinition index space to
// original FusionDefinition index space.
std::unordered_map<int64_t, int64_t> SegmentationState::buildSegment(
FusionDefinition& segment_fd,
int64_t segment_id) {
NVF_ERROR(
!segment_fd.completed(),
"Expected an incomplete definition before translation.");
NVF_ERROR(
segmented_fusion_ != nullptr,
"SegmentedFusion is not initialized. Run setupSegmentation first.");
NVF_ERROR(
segment_id >= 0 &&
segment_id < (int64_t)segmented_fusion_->groups().size(),
"The segment id is not valid");

// Step 1) Use segment id to get SegmentedGroup from group_run_order_.
SegmentedGroup* sg = group_run_order_.at(segment_id);
NVF_ERROR(sg != nullptr);

// Step 2) Create CPP Fusion for SegmentedGroup. The IrCloner acts as a map
// from fusion segment to the original fusion.
std::pair<IrCloner, std::unique_ptr<Fusion>> cloner_segment_pair =
segmented_fusion_->makeFusion(sg);
IrCloner cloned_to_segment_map = cloner_segment_pair.first;
std::unique_ptr<Fusion> fusion_segment =
std::move(cloner_segment_pair.second);

// Step 3) Translate CPP Fusion to Python FusionDefinition
std::unordered_map<const nvfuser::Val*, size_t>
map_segment_cpp_value_to_python_index =
translate(fusion_segment.get(), &segment_fd);

// Step 4) Create map from segment fusion indices to original fusion indices.
// Step 4a) Get FusionDefinition index for cloned inputs and outputs. Map them
// to their original fusion indices.
const std::vector<Val*>& cloned_inputs = sg->inputs();
const std::vector<Val*>& cloned_outputs = sg->outputs();

std::vector<int64_t> original_python_index;
original_python_index.reserve(cloned_inputs.size() + cloned_outputs.size());

std::transform(
cloned_inputs.begin(),
cloned_inputs.end(),
std::back_inserter(original_python_index),
[&](Val* v) {
return map_cloned_concretized_value_to_original_python_index_.at(v);
});

std::transform(
cloned_outputs.begin(),
cloned_outputs.end(),
std::back_inserter(original_python_index),
[&](Val* v) {
return map_cloned_concretized_value_to_original_python_index_.at(v);
});

// Step 4b) ir_cloner maps cloned fusion Vals to segment Vals.
std::vector<Val*> segment_inputs_outputs;
segment_inputs_outputs.reserve(cloned_inputs.size() + cloned_outputs.size());

std::transform(
cloned_inputs.begin(),
cloned_inputs.end(),
std::back_inserter(segment_inputs_outputs),
[&](Val* v) { return cloned_to_segment_map.clone(v); });

std::transform(
cloned_outputs.begin(),
cloned_outputs.end(),
std::back_inserter(segment_inputs_outputs),
[&](Val* v) { return cloned_to_segment_map.clone(v); });

// Step 4c) Map segment Vals to their FusionDefinition index.
std::vector<int64_t> segment_python_index;
segment_python_index.reserve(segment_inputs_outputs.size());
std::transform(
segment_inputs_outputs.begin(),
segment_inputs_outputs.end(),
std::back_inserter(segment_python_index),
[&](Val* v) { return map_segment_cpp_value_to_python_index.at(v); });

// Step 4d) Map original indices to segment indices.
NVF_ERROR(original_python_index.size() == segment_python_index.size());
std::unordered_map<int64_t, int64_t> segment_to_original_python_index_map;
for (size_t idx : c10::irange(original_python_index.size())) {
segment_to_original_python_index_map.emplace(
segment_python_index.at(idx), original_python_index.at(idx));
}

// Step 4e) short-circuit: No extra extents required for python definition.
if (fusion_segment->inputs().size() == segment_fd.inputs().size()) {
return segment_to_original_python_index_map;
}

// The python segment can require the size of tensor dimensions from original
jjsjann123 marked this conversation as resolved.
Show resolved Hide resolved
// fusion's input arguments, which the CPP segment does not.

// Step 4f) Create a map from segment to cloned extents.
// Step 4g) Create a map from segment indices to segment extents.
std::unordered_map<Val*, Val*> segment_to_cloned_extents;
std::unordered_map<size_t, Val*> segment_python_index_to_cpp_val;
for (Val* cloned_extent : cloned_original_extents_) {
Val* segment_extent = cloned_to_segment_map.clone(cloned_extent);

// short-circuit: some extents are not used in segment
if (map_segment_cpp_value_to_python_index.count(segment_extent) == 0) {
continue;
}

size_t segment_python_index =
map_segment_cpp_value_to_python_index.at(segment_extent);
segment_to_cloned_extents.emplace(segment_extent, cloned_extent);
segment_python_index_to_cpp_val.emplace(
segment_python_index, segment_extent);
}

// Step 4h) Find the set difference between all segment input indices and
jjsjann123 marked this conversation as resolved.
Show resolved Hide resolved
// known input segment indices.
std::vector<int64_t> missing_segment_python_index;
for (int64_t input_python_index : segment_fd.inputs()) {
if (segment_to_original_python_index_map.count(input_python_index) == 0) {
missing_segment_python_index.push_back(input_python_index);
}
}

// Step 4i) Get segment Val for missing segment input indices.
std::vector<Val*> missing_segment_val;
missing_segment_val.reserve(missing_segment_python_index.size());
std::transform(
missing_segment_python_index.begin(),
missing_segment_python_index.end(),
std::back_inserter(missing_segment_val),
[&](int64_t segment_python_index) {
return segment_python_index_to_cpp_val.at(segment_python_index);
});

// Step 4j) Map segment Vals to cloned Vals
std::vector<Val*> missing_cloned_val;
missing_cloned_val.reserve(missing_segment_val.size());
std::transform(
missing_segment_val.begin(),
missing_segment_val.end(),
std::back_inserter(missing_cloned_val),
[&](Val* segment_val) {
return segment_to_cloned_extents.at(segment_val);
});

// Step 4k) Transform cloned Vals to their original fusion indices.
std::vector<int64_t> missing_cloned_python_index;
missing_cloned_python_index.reserve(missing_cloned_val.size());
std::transform(
missing_cloned_val.begin(),
missing_cloned_val.end(),
std::back_inserter(missing_cloned_python_index),
[&](Val* cloned_val) {
return map_cloned_concretized_value_to_original_python_index_.at(
cloned_val);
});

// Step 4l) Add missing mappings from segment to original indices.
for (size_t idx : c10::irange(missing_segment_python_index.size())) {
segment_to_original_python_index_map.emplace(
missing_segment_python_index.at(idx),
missing_cloned_python_index.at(idx));
}

// Return the mapping from the index space of segment FusionDefinition to the
// index space of the original FusionDefinition.
return segment_to_original_python_index_map;
}

void SegmentationState::prepareGroupOrder() {
NVF_ERROR(segmented_fusion_ != nullptr);

Expand Down
Loading
Loading