Skip to content

Commit

Permalink
Translate segments to python definition
Browse files Browse the repository at this point in the history
Add buildSegment function
  • Loading branch information
rdspring1 committed Nov 13, 2024
1 parent a53d1dd commit 98fd9b2
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 2 deletions.
10 changes: 10 additions & 0 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,16 @@ int64_t FusionDefinition::setupSegmentation(
preschedFusion(), map_value_to_fid_, inputs);
}

std::unordered_map<int64_t, int64_t> FusionDefinition::buildSegment(
FusionDefinition& other,
int64_t segment_id) {
NVF_CHECK(id().has_value(), "FusionDefinition definition does not exist!");
NVF_CHECK(
segmentation_state_ != nullptr,
"Run setupSegmenation first before trying to build segments!");
return segmentation_state_->buildSegment(other, segment_id);
}

void FusionDefinition::finalizeSegmentation() {
// Destroy SegmentedState
segmentation_state_.reset();
Expand Down
7 changes: 7 additions & 0 deletions csrc/python_frontend/fusion_definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,13 @@ class NVF_API FusionDefinition : public FusionState {
//! Run segmentation algorithm on FusionDefinition. Returns the number of
//! segments.
NVF_API int64_t setupSegmentation(const at::ArrayRef<c10::IValue>& inputs);
//! Given an empty FusionDefinition and a segment id, buildSegment creates the
//! CPP Fusion, translates it to the python FusionDefinition, then return a
//! mapping from segment fusion state indices to the original fusion state
//! indices.
NVF_API std::unordered_map<int64_t, int64_t> buildSegment(
FusionDefinition& other,
int64_t segment_id);
//! After creating segments, destroy SegmentationState.
NVF_API void finalizeSegmentation();

Expand Down
7 changes: 7 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,13 @@ void initNvFuserPythonBindings(PyObject* module) {
}
return self.setupSegmentation(inputs);
})
.def(
"_build_segment",
[](FusionDefinition& self,
FusionDefinition& other,
int64_t segment_id) {
return self.buildSegment(other, segment_id);
})
.def(
"_finalize_segmentation",
[](FusionDefinition& self) {
Expand Down
161 changes: 161 additions & 0 deletions csrc/python_frontend/segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,167 @@ int64_t SegmentationState::setupSegmentation(
return (int64_t)segmented_fusion_->groups().size();
}

std::unordered_map<int64_t, int64_t> SegmentationState::buildSegment(
FusionDefinition& other,
int64_t segment_id) {
NVF_ERROR(
!other.completed(),
"Expected an incomplete definition before translation.");
NVF_ERROR(
segmented_fusion_ != nullptr,
"SegmentedFusion is not initialized. Run setupSegmentation first.");
NVF_ERROR(
segment_id >= 0 &&
segment_id < (int64_t)segmented_fusion_->groups().size(),
"The segment id is not valid");

// Step 1) Use segment id to get SegmentedGroup from group_run_order_.
SegmentedGroup* sg = group_run_order_.at(segment_id);
NVF_ERROR(sg != nullptr);

// Step 2) Create CPP Fusion for SegmentedGroup. The IrCloner acts as a map
// from fusion segment to the original fusion.
std::pair<IrCloner, std::unique_ptr<Fusion>> cloner_segment_pair =
segmented_fusion_->makeFusion(sg);
IrCloner cloned_to_segment_map = cloner_segment_pair.first;
std::unique_ptr<Fusion> fusion_segment =
std::move(cloner_segment_pair.second);

// Step 3) Translate CPP Fusion to Python FusionDefinition
std::unordered_map<const nvfuser::Val*, size_t>
map_translated_val_to_segment_fid =
translate(fusion_segment.get(), &other);

// Step 4) Create map from segment fusion indices to original fusion indices.
// Step 4a) Get FusionDefinition index for cloned inputs and outputs. Map them
// to their original fusion indices.
const std::vector<Val*>& cloned_inputs = sg->inputs();
const std::vector<Val*>& cloned_outputs = sg->outputs();

std::vector<int64_t> original_fid;
original_fid.reserve(cloned_inputs.size() + cloned_outputs.size());

std::transform(
cloned_inputs.begin(),
cloned_inputs.end(),
std::back_inserter(original_fid),
[&](Val* v) { return map_cloned_value_to_fid_.at(v); });

std::transform(
cloned_outputs.begin(),
cloned_outputs.end(),
std::back_inserter(original_fid),
[&](Val* v) { return map_cloned_value_to_fid_.at(v); });

// Step 4b) ir_cloner maps cloned fusion Vals to segment Vals.
std::vector<Val*> segment_inputs_outputs;
segment_inputs_outputs.reserve(cloned_inputs.size() + cloned_outputs.size());

std::transform(
cloned_inputs.begin(),
cloned_inputs.end(),
std::back_inserter(segment_inputs_outputs),
[&](Val* v) { return cloned_to_segment_map.clone(v); });

std::transform(
cloned_outputs.begin(),
cloned_outputs.end(),
std::back_inserter(segment_inputs_outputs),
[&](Val* v) { return cloned_to_segment_map.clone(v); });

// Step 4c) Map segment Vals to their FusionDefinition index.
std::vector<int64_t> segment_fid;
segment_fid.reserve(segment_inputs_outputs.size());
std::transform(
segment_inputs_outputs.begin(),
segment_inputs_outputs.end(),
std::back_inserter(segment_fid),
[&](Val* v) { return map_translated_val_to_segment_fid.at(v); });

// Step 4d) Map original indices to segment indices.
NVF_ERROR(original_fid.size() == segment_fid.size());
std::unordered_map<int64_t, int64_t> segment_fid_to_original_fid_map;
for (size_t idx : c10::irange(original_fid.size())) {
segment_fid_to_original_fid_map.emplace(
segment_fid.at(idx), original_fid.at(idx));
}

// Step 4e) short-circuit: No extra extents required for python definition.
if (fusion_segment->inputs().size() == other.inputs().size()) {
return segment_fid_to_original_fid_map;
}

// The python segment can require the size of tensor dimensions from original
// fusion's input arguments, which the CPP segment does not.

// Step 4f) Create a map from segment to cloned extents.
// Step 4g) Create a map from segment indices to segment extents.
std::unordered_map<Val*, Val*> segment_to_cloned_extents;
std::unordered_map<size_t, Val*> segment_fid_to_translated_val;
for (Val* cloned_extent : cloned_extents_) {
Val* segment_extent = cloned_to_segment_map.clone(cloned_extent);

// short-circuit: some extents are not used in segment
if (map_translated_val_to_segment_fid.count(segment_extent) == 0) {
continue;
}

size_t segment_fid = map_translated_val_to_segment_fid.at(segment_extent);
segment_to_cloned_extents.emplace(segment_extent, cloned_extent);
segment_fid_to_translated_val.emplace(segment_fid, segment_extent);
}

// Step 4h) Find the set difference between all segment input indices and
// known input segment indices.
std::vector<int64_t> missing_segment_fid;
for (int64_t input_fid : other.inputs()) {
if (segment_fid_to_original_fid_map.count(input_fid) == 0) {
missing_segment_fid.push_back(input_fid);
}
}

// Step 4i) Get segment Val for missing segment input indices.
std::vector<Val*> missing_segment_val;
missing_segment_val.reserve(missing_segment_fid.size());
std::transform(
missing_segment_fid.begin(),
missing_segment_fid.end(),
std::back_inserter(missing_segment_val),
[&](int64_t segment_fid) {
return segment_fid_to_translated_val.at(segment_fid);
});

// Step 4j) Map segment Vals to cloned Vals
std::vector<Val*> missing_cloned_val;
missing_cloned_val.reserve(missing_segment_val.size());
std::transform(
missing_segment_val.begin(),
missing_segment_val.end(),
std::back_inserter(missing_cloned_val),
[&](Val* segment_val) {
return segment_to_cloned_extents.at(segment_val);
});

// Step 4k) Transform cloned Vals to their original fusion indices.
std::vector<int64_t> missing_cloned_fid;
missing_cloned_fid.reserve(missing_cloned_val.size());
std::transform(
missing_cloned_val.begin(),
missing_cloned_val.end(),
std::back_inserter(missing_cloned_fid),
[&](Val* cloned_val) { return map_cloned_value_to_fid_.at(cloned_val); });

// Step 4l) Add missing mappings from segment to original indices.
for (size_t idx : c10::irange(missing_segment_fid.size())) {
segment_fid_to_original_fid_map.emplace(
missing_segment_fid.at(idx), missing_cloned_fid.at(idx));
}

// Return the mapping from the index space of segment FusionDefinition to the
// index space of the original FusionDefinition.
return segment_fid_to_original_fid_map;
}

void SegmentationState::prepareGroupOrder() {
NVF_ERROR(segmented_fusion_ != nullptr);

Expand Down
37 changes: 37 additions & 0 deletions csrc/python_frontend/segmentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,43 @@ class SegmentationState {
const std::unordered_map<const Val*, int64_t>& map_value_to_original_fid,
const at::ArrayRef<c10::IValue>& inputs);

// Given an empty FusionDefinition and a segment id, buildSegment creates the
// CPP Fusion, translates it to the python FusionDefinition, then return a
// mapping from segment fusion state indices to the original fusion state
// indices.
//
// NOTE: Steps 4a through 4d are run for every fusion segment. However,
// sometimes the python definition needs the extents of the original fusion's
// input tensors as extra arguments. Steps 4f to 4l creates mappings for these
// missing extents.
//
// Details:
// 1) Use segment id to get SegmentedGroup from group_run_order_.
// 2) Create CPP Fusion for SegmentedGroup.
// * IrCloner acts as a map from fusion segment to the original fusion.
// 3) Translate CPP Fusion to Python FusionDefinition
// 4) Create map from segment fusion indices to original fusion indices.
// a) Get cloned Vals for SegmentedGroup's inputs and outputs. Map them
// to their original fusion indices.
// b) Map cloned Vals to their segment Vals
// c) Map segment Vals to their fusion indices.
// d) Map original indices to segment indices.
// e) Return map if the number of input arguments for python definition
// matches the number of input arguments for CPP fusion.
// f) Create a map from segment to cloned extents.
// g) Create a map from segment fusion indices to cloned extents.
// h) Find segment inputs that are missing from segment to original
// indices map.
// i) Get segment Vals for the missing segment fusion indices.
// j) Map segment Vals to cloned Vals.
// k) Map cloned Vals to their corresponding fusion indices.
// l) Add missing mappings to segment to original indices map.
// 5) Return the mapping from the segmented FusionDefinition index space to
// original FusionDefinition index space.
NVF_API std::unordered_map<int64_t, int64_t> buildSegment(
FusionDefinition& other,
int64_t segment_id);

private:
// prepareGroupOrder is similar to prepareRuntimeOrder. It generates the
// topological order of SegmentedGroups in SegmentedFusion.
Expand Down
34 changes: 33 additions & 1 deletion nvfuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,41 @@ def __init__(self, id=None, max_length=1024):
self.profiled = False

def segment(self, inputs):
"""
Decompose this FusionDefinition into a sequence of segment
FusionDefinitions.
This function runs the nvfuser segmentation algorithm and translates the
segments into their corresponding FusionDefinitions.
Args:
inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion.
Returns:
List[FusionDefinition]: The FusionDefinitions corresponding to the
sub-fusion segments of this FusionDefinition.
"""
num_segments = self._setup_segmentation(inputs)
if num_segments == 1:
self._finalize_segmentation()
return []

self.segments = []
self.segment_maps = []
self.last_used_segment = {}
for idx in range(num_segments):
new_fd = FusionDefinition()
segment_to_original_fid = self._build_segment(new_fd, idx)

# Track the last segment a value is used as an input
for segment_input in new_fd.inputs():
original_input = segment_to_original_fid[segment_input]
self.last_used_segment[original_input] = idx

self.segment_maps.append(segment_to_original_fid)
self.segments.append(new_fd)
self._finalize_segmentation()
return num_segments
return self.segments

def __enter__(self):
return self._setup_definition()
Expand Down
11 changes: 10 additions & 1 deletion tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,7 +1263,16 @@ def fusion_func(fd: FusionDefinition):

with FusionDefinition() as fd:
fusion_func(fd)
assert fd.segment(inputs) == 2

# create segments
fd.segment(inputs)

# Each segment has a map from its segment index space to the original
# fusion's index space. The original fusion creates two segments.
# Check that segment maps match expected behavior.
assert len(fd.segment_maps) == 2
assert fd.segment_maps[0] == {2: 2, 0: 0}
assert fd.segment_maps[1] == {2: 3, 1: 2, 0: 1}

def test_arithmetic_ops(self):
inputs = [
Expand Down

0 comments on commit 98fd9b2

Please sign in to comment.