Translate segments to python definition (#3335)
## Overview:

- `buildSegment` creates the CPP Fusion for a given segment id,
translates it to a python FusionDefinition, then returns a mapping from
the segment fusion state indices to the original fusion state indices.
- `FusionDefinition.segment` calls `setupSegmentation`, `buildSegment`,
and `finalizeSegmentation` to create python definitions for the
sub-fusions and their index mappings.
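
The call sequence described above can be sketched roughly as follows. This is a minimal illustration, not the nvFuser implementation: `MockFusionDefinition` and its canned index maps are hypothetical stand-ins, and the method name `_setup_segmentation` is an assumption (only `_build_segment` and `_finalize_segmentation` appear in the bindings in this commit).

```python
class MockFusionDefinition:
    """Hypothetical stand-in for a FusionDefinition; not the nvFuser class."""

    def __init__(self, canned_maps=None):
        # One canned {segment index: original index} map per segment.
        self._canned_maps = canned_maps or []
        self._segmentation_active = False

    def _setup_segmentation(self, inputs):
        # Stands in for setupSegmentation: returns the number of segments.
        self._segmentation_active = True
        return len(self._canned_maps)

    def _build_segment(self, segment_fd, segment_id):
        # Stands in for buildSegment: would populate segment_fd, then
        # return the map from segment indices to original indices.
        assert self._segmentation_active, "Run setup before building segments"
        return dict(self._canned_maps[segment_id])

    def _finalize_segmentation(self):
        # Stands in for finalizeSegmentation: drops the segmentation state.
        self._segmentation_active = False

    def segment(self, inputs):
        # Setup once, build every segment, then finalize.
        num_segments = self._setup_segmentation(inputs)
        segments = []
        for segment_id in range(num_segments):
            segment_fd = MockFusionDefinition()
            index_map = self._build_segment(segment_fd, segment_id)
            segments.append((segment_fd, index_map))
        self._finalize_segmentation()
        return segments


# The canned maps reuse the two-segment example from this description.
fd = MockFusionDefinition(canned_maps=[{0: 0, 2: 3}, {0: 1, 1: 3, 2: 4}])
segments = fd.segment(inputs=[])
index_maps = [index_map for _, index_map in segments]
```

Each segment gets its own empty definition, which matches the check in `buildSegment` that the segment definition is not yet completed.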

## Changes in this PR

This PR implements the `buildSegment` function for user-scheduler
segmentation. It is the second PR in a stack, preceded by
#3334 and followed by
#3025.

1. Implement the `buildSegment` function in
`csrc/python_frontend/segmentation.cpp`.
2. Complete the `segment` function in `nvfuser/__init__.py`.

## Example:
### Original Fusion: A reduction + broadcast + pointwise fusion.
```python
def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T1 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T2 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float)
    T3 = fd.ops.broadcast(T2, is_broadcast_dim=[False, True])
    T4 = fd.ops.add(T1, T3)
    fd.add_output(T4)
```

**After Segmentation:** The reduction scheduler does not support fusing
any operations with an inner reduction, so the original fusion is
divided into two segments.

## First Segment:
The first segment contains the reduction and broadcast operations, which
correspond to [T0, T2, T3] in the original fusion. Only the segment's
input and output are tracked, so the map from segment index to original
index has two entries.

| Segment Index | Original Index | Description |
| -----------------| ---------------  | ------------- |
| T0 | T0 | The first tensor argument for the original fusion. |
| T2 | T3 | The broadcasted, reduction tensor is this segment's output. |

```python
def nvfuser_fusion_id2(fd : FusionDefinition) -> None :
   T0 = fd.define_tensor(shape=[-1, -1],
                         contiguity=[True, True],
                         dtype=DataType.Float,
                         is_cpu=False)
   T1 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float)
   T2 = fd.ops.broadcast(T1, is_broadcast_dim=[False, True])
   fd.add_output(T2)
```
## Second Segment:
The second segment is the pointwise addition with the broadcasted
reduction. It corresponds to [T1, T3, T4] in the original fusion.

| Segment Index | Original Index | Description |
| -----------------| ---------------  | ------------- |
| T0 | T1 | The second tensor argument for the original fusion. |
| T1 | T3 | The broadcasted, reduction tensor, which is the output from the first segment. |
| T2 | T4 | The pointwise addition, which is the output for the original fusion. |

```python
def nvfuser_fusion_id3(fd : FusionDefinition) -> None :
   T0 = fd.define_tensor(shape=[-1, -1],
                         contiguity=[True, True],
                         dtype=DataType.Float,
                         is_cpu=False)
   T1 = fd.define_tensor(shape=[-1, 1],
                         contiguity=[True, None],
                         dtype=DataType.Float,
                         is_cpu=False)
   T2 = fd.ops.add(T0, T1)
   fd.add_output(T2)
```
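
To see why the index maps matter, the following sketch runs both segments on plain Python lists and stores every result under its original fusion index. It is only an illustration under stated assumptions: `segment_1`, `segment_2`, and `run_segment` are hypothetical helpers that mimic the two definitions above with nested lists, and nothing here calls nvFuser.

```python
def segment_1(t0):
    # Mirrors nvfuser_fusion_id2: sum over dim 1, then broadcast to [N, 1].
    return [[sum(row)] for row in t0]


def segment_2(t0, t1):
    # Mirrors nvfuser_fusion_id3: add a [N, M] tensor and a [N, 1] tensor.
    return [[x + b[0] for x in row] for row, b in zip(t0, t1)]


def run_segment(fn, input_ids, output_id, index_map, original_values):
    # Look up each segment input through the map, run the segment, and
    # store the output under its original fusion index.
    args = [original_values[index_map[i]] for i in input_ids]
    original_values[index_map[output_id]] = fn(*args)


# Index maps taken from the two tables above.
segment_1_map = {0: 0, 2: 3}
segment_2_map = {0: 1, 1: 3, 2: 4}

# Values keyed by original fusion index: T0 and T1 are the fusion inputs.
original_values = {
    0: [[1.0, 2.0], [3.0, 4.0]],
    1: [[10.0, 20.0], [30.0, 40.0]],
}

run_segment(segment_1, [0], 2, segment_1_map, original_values)
run_segment(segment_2, [0, 1], 2, segment_2_map, original_values)

result = original_values[4]  # T4 in the original fusion
```

Because the first segment writes its output to original index 3, the second segment can find it there without knowing anything about the first segment's local numbering.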
rdspring1 authored Nov 15, 2024
1 parent 8fa7555 commit 3229ed8
Showing 7 changed files with 357 additions and 32 deletions.
10 changes: 10 additions & 0 deletions csrc/python_frontend/fusion_definition.cpp
@@ -696,6 +696,16 @@ int64_t FusionDefinition::setupSegmentation(
preschedFusion(), map_value_to_fid_, inputs);
}

std::unordered_map<int64_t, int64_t> FusionDefinition::buildSegment(
FusionDefinition& segment_fd,
int64_t segment_id) {
NVF_CHECK(id().has_value(), "FusionDefinition does not exist!");
NVF_CHECK(
segmentation_state_ != nullptr,
"Run setupSegmentation first before trying to build segments!");
return segmentation_state_->buildSegment(segment_fd, segment_id);
}

void FusionDefinition::finalizeSegmentation() {
// Destroy SegmentedState
segmentation_state_.reset();
7 changes: 7 additions & 0 deletions csrc/python_frontend/fusion_definition.h
@@ -262,6 +262,13 @@ class NVF_API FusionDefinition : public FusionState {
//! Run segmentation algorithm on FusionDefinition. Returns the number of
//! segments.
NVF_API int64_t setupSegmentation(const at::ArrayRef<c10::IValue>& inputs);
//! Given an empty FusionDefinition and a segment id, buildSegment creates the
//! CPP Fusion, translates it to the python FusionDefinition, then returns a
//! mapping from segment fusion state indices to the original fusion state
//! indices.
NVF_API std::unordered_map<int64_t, int64_t> buildSegment(
FusionDefinition& segment_fd,
int64_t segment_id);
//! After creating segments, destroy SegmentationState.
NVF_API void finalizeSegmentation();

7 changes: 7 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
@@ -1024,6 +1024,13 @@ void initNvFuserPythonBindings(PyObject* module) {
}
return self.setupSegmentation(inputs);
})
.def(
"_build_segment",
[](FusionDefinition& self,
FusionDefinition& other,
int64_t segment_id) {
return self.buildSegment(other, segment_id);
})
.def(
"_finalize_segmentation",
[](FusionDefinition& self) {
248 changes: 230 additions & 18 deletions csrc/python_frontend/segmentation.cpp
@@ -12,65 +12,69 @@ namespace nvfuser::python_frontend {

int64_t SegmentationState::setupSegmentation(
Fusion* fusion,
- const std::unordered_map<const Val*, int64_t>& map_value_to_original_fid,
+ const std::unordered_map<const Val*, int64_t>&
+ map_presched_value_to_original_python_index,
const at::ArrayRef<c10::IValue>& inputs) {
// Check state
NVF_ERROR(fusion != nullptr);
- NVF_ERROR(cloned_fusion_ == nullptr);
+ NVF_ERROR(cloned_original_fusion_ == nullptr);
NVF_ERROR(segmented_fusion_ == nullptr);
NVF_ERROR(group_run_order_.empty());
- NVF_ERROR(map_cloned_value_to_fid_.empty());
- NVF_ERROR(cloned_extents_.empty());
+ NVF_ERROR(map_cloned_concretized_value_to_original_python_index_.empty());
+ NVF_ERROR(cloned_original_extents_.empty());

int8_t device = getCommonDeviceCUDA(inputs);
NVF_CHECK(
inputs.empty() || device > -1, "Inputs are not all on the same device!");

// Step 1) Clone preschedFusion CPP Fusion.
- cloned_fusion_ = std::make_unique<Fusion>();
+ cloned_original_fusion_ = std::make_unique<Fusion>();

// The IRCloner returned by Fusion::copy acts as map from the original fusion
// to the cloned fusion.
- IrCloner original_to_cloned_map = Fusion::copy(fusion, cloned_fusion_.get());
+ IrCloner original_to_cloned_map =
+ Fusion::copy(fusion, cloned_original_fusion_.get());

KernelArgumentHolder args =
KernelArgumentHolder::createKernelArgumentHolder(inputs, device);

// Step 2) Concretize fusion with input arguments.
std::unordered_map<Val*, Val*> symbolic_to_concrete_map =
- DynamicTransform::concretizeFusion(cloned_fusion_.get(), args);
+ DynamicTransform::concretizeFusion(cloned_original_fusion_.get(), args);

- // Step 3) Given the map_value_to_original_fid, the IRCloner returned by
- // Fusion::copy, AND the symbolic_to_concrete map returned by
+ // Step 3) Given the map_presched_value_to_original_python_index, the IRCloner
+ // returned by Fusion::copy, AND the symbolic_to_concrete map returned by
// concretization pass, create a mapping from cloned Vals to original fusion
// state indices.
std::transform(
- map_value_to_original_fid.begin(),
- map_value_to_original_fid.end(),
- std::inserter(map_cloned_value_to_fid_, map_cloned_value_to_fid_.end()),
+ map_presched_value_to_original_python_index.begin(),
+ map_presched_value_to_original_python_index.end(),
+ std::inserter(
+ map_cloned_concretized_value_to_original_python_index_,
+ map_cloned_concretized_value_to_original_python_index_.end()),
[&](const auto& item) {
const Val* original_value = item.first;
- int64_t fid = item.second;
+ int64_t python_index = item.second;
Val* cloned_val = original_to_cloned_map.clone(original_value);
if (symbolic_to_concrete_map.count(cloned_val)) {
cloned_val = symbolic_to_concrete_map.at(cloned_val);
}
- return std::make_pair(cloned_val, fid);
+ return std::make_pair(cloned_val, python_index);
});

// Track the extents for input TensorViews in cloned CPP Fusion.
- cloned_extents_ = getExtents(cloned_fusion_.get());
+ cloned_original_extents_ = getExtents(cloned_original_fusion_.get());

// Create runtime information
SchedulerRuntimeInfo runtime_info(
- cloned_fusion_.get(),
+ cloned_original_fusion_.get(),
args,
/*precomputed_values=*/nullptr,
- cloned_fusion_->allTvs());
+ cloned_original_fusion_->allTvs());

// Run segmentation algorithm
segmented_fusion_ = SegmentCandidateFinder::segment(
- std::move(cloned_fusion_), &args, runtime_info);
+ std::move(cloned_original_fusion_), &args, runtime_info);

// Get the order for fusion segments
prepareGroupOrder();
@@ -79,6 +83,214 @@ int64_t SegmentationState::setupSegmentation(
return (int64_t)segmented_fusion_->groups().size();
}

// setupSegmentation transforms the Prescheduled, Symbolic Fusion to Cloned,
// Concretized Fusion. Both CPP fusions correspond to the original
// FusionDefinition.
//
// The segmentation pass runs on cloned, concretized fusion to create
// SegmentedFusion. Each SegmentedGroup in the SegmentedFusion creates a segment
// CPP fusion that is translated to a python definition.
//
//
// NOTE: Steps 4a through 4d are run for every fusion segment. However,
// sometimes the python definition needs the extents of the original fusion's
// input tensors as extra arguments. Steps 4f to 4l create mappings for these
// missing extents.
//
// Details:
// 1) Use segment id to get SegmentedGroup from group_run_order_.
// 2) Create CPP Fusion for SegmentedGroup.
// * IrCloner acts as a map from fusion segment to the original fusion.
// 3) Translate CPP Fusion to Python FusionDefinition
// 4) Create map from segment fusion indices to original fusion indices.
// a) Get cloned Vals for SegmentedGroup's inputs and outputs. Map them
// to their original fusion indices.
// b) Map cloned Vals to their segment Vals
// c) Map segment Vals to their fusion indices.
// d) Map original indices to segment indices.
// e) Return map if the number of input arguments for python definition
// matches the number of input arguments for CPP fusion.
// f) Create a map from segment to cloned extents.
// g) Create a map from segment fusion indices to segment extents.
// h) Find segment inputs that are missing from segment to original
// indices map.
// i) Get segment Vals for the missing segment fusion indices.
// j) Map segment Vals to cloned Vals.
// k) Map cloned Vals to their corresponding fusion indices.
// l) Add missing mappings to segment to original indices map.
// 5) Return the mapping from the segmented FusionDefinition index space to
// original FusionDefinition index space.
std::unordered_map<int64_t, int64_t> SegmentationState::buildSegment(
FusionDefinition& segment_fd,
int64_t segment_id) {
NVF_ERROR(
!segment_fd.completed(),
"Expected an incomplete definition before translation.");
NVF_ERROR(
segmented_fusion_ != nullptr,
"SegmentedFusion is not initialized. Run setupSegmentation first.");
NVF_ERROR(
segment_id >= 0 &&
segment_id < (int64_t)segmented_fusion_->groups().size(),
"The segment id is not valid");

// Step 1) Use segment id to get SegmentedGroup from group_run_order_.
SegmentedGroup* sg = group_run_order_.at(segment_id);
NVF_ERROR(sg != nullptr);

// Step 2) Create CPP Fusion for SegmentedGroup. The IrCloner acts as a map
// from fusion segment to the original fusion.
std::pair<IrCloner, std::unique_ptr<Fusion>> cloner_segment_pair =
segmented_fusion_->makeFusion(sg);
IrCloner cloned_to_segment_map = cloner_segment_pair.first;
std::unique_ptr<Fusion> fusion_segment =
std::move(cloner_segment_pair.second);

// Step 3) Translate CPP Fusion to Python FusionDefinition
std::unordered_map<const nvfuser::Val*, size_t>
map_segment_cpp_value_to_python_index =
translate(fusion_segment.get(), &segment_fd);

// Step 4) Create map from segment fusion indices to original fusion indices.
// Step 4a) Get FusionDefinition index for cloned inputs and outputs. Map them
// to their original fusion indices.
const std::vector<Val*>& cloned_inputs = sg->inputs();
const std::vector<Val*>& cloned_outputs = sg->outputs();

std::vector<int64_t> original_python_index;
original_python_index.reserve(cloned_inputs.size() + cloned_outputs.size());

std::transform(
cloned_inputs.begin(),
cloned_inputs.end(),
std::back_inserter(original_python_index),
[&](Val* v) {
return map_cloned_concretized_value_to_original_python_index_.at(v);
});

std::transform(
cloned_outputs.begin(),
cloned_outputs.end(),
std::back_inserter(original_python_index),
[&](Val* v) {
return map_cloned_concretized_value_to_original_python_index_.at(v);
});

// Step 4b) ir_cloner maps cloned fusion Vals to segment Vals.
std::vector<Val*> segment_inputs_outputs;
segment_inputs_outputs.reserve(cloned_inputs.size() + cloned_outputs.size());

std::transform(
cloned_inputs.begin(),
cloned_inputs.end(),
std::back_inserter(segment_inputs_outputs),
[&](Val* v) { return cloned_to_segment_map.clone(v); });

std::transform(
cloned_outputs.begin(),
cloned_outputs.end(),
std::back_inserter(segment_inputs_outputs),
[&](Val* v) { return cloned_to_segment_map.clone(v); });

// Step 4c) Map segment Vals to their FusionDefinition index.
std::vector<int64_t> segment_python_index;
segment_python_index.reserve(segment_inputs_outputs.size());
std::transform(
segment_inputs_outputs.begin(),
segment_inputs_outputs.end(),
std::back_inserter(segment_python_index),
[&](Val* v) { return map_segment_cpp_value_to_python_index.at(v); });

// Step 4d) Map original indices to segment indices.
NVF_ERROR(original_python_index.size() == segment_python_index.size());
std::unordered_map<int64_t, int64_t> segment_to_original_python_index_map;
for (size_t idx : c10::irange(original_python_index.size())) {
segment_to_original_python_index_map.emplace(
segment_python_index.at(idx), original_python_index.at(idx));
}

// Step 4e) short-circuit: No extra extents required for python definition.
if (fusion_segment->inputs().size() == segment_fd.inputs().size()) {
return segment_to_original_python_index_map;
}

// The python segment can require the size of tensor dimensions from original
// fusion's input arguments, which the CPP segment does not.

// Step 4f) Create a map from segment to cloned extents.
// Step 4g) Create a map from segment indices to segment extents.
std::unordered_map<Val*, Val*> segment_to_cloned_extents;
std::unordered_map<size_t, Val*> segment_python_index_to_cpp_val;
for (Val* cloned_extent : cloned_original_extents_) {
Val* segment_extent = cloned_to_segment_map.clone(cloned_extent);

// short-circuit: some extents are not used in segment
if (map_segment_cpp_value_to_python_index.count(segment_extent) == 0) {
continue;
}

size_t segment_python_index =
map_segment_cpp_value_to_python_index.at(segment_extent);
segment_to_cloned_extents.emplace(segment_extent, cloned_extent);
segment_python_index_to_cpp_val.emplace(
segment_python_index, segment_extent);
}

// Step 4h) Find the set difference between all segment input indices and
// known input segment indices.
std::vector<int64_t> missing_segment_python_index;
for (int64_t input_python_index : segment_fd.inputs()) {
if (segment_to_original_python_index_map.count(input_python_index) == 0) {
missing_segment_python_index.push_back(input_python_index);
}
}

// Step 4i) Get segment Val for missing segment input indices.
std::vector<Val*> missing_segment_val;
missing_segment_val.reserve(missing_segment_python_index.size());
std::transform(
missing_segment_python_index.begin(),
missing_segment_python_index.end(),
std::back_inserter(missing_segment_val),
[&](int64_t segment_python_index) {
return segment_python_index_to_cpp_val.at(segment_python_index);
});

// Step 4j) Map segment Vals to cloned Vals
std::vector<Val*> missing_cloned_val;
missing_cloned_val.reserve(missing_segment_val.size());
std::transform(
missing_segment_val.begin(),
missing_segment_val.end(),
std::back_inserter(missing_cloned_val),
[&](Val* segment_val) {
return segment_to_cloned_extents.at(segment_val);
});

// Step 4k) Transform cloned Vals to their original fusion indices.
std::vector<int64_t> missing_cloned_python_index;
missing_cloned_python_index.reserve(missing_cloned_val.size());
std::transform(
missing_cloned_val.begin(),
missing_cloned_val.end(),
std::back_inserter(missing_cloned_python_index),
[&](Val* cloned_val) {
return map_cloned_concretized_value_to_original_python_index_.at(
cloned_val);
});

// Step 4l) Add missing mappings from segment to original indices.
for (size_t idx : c10::irange(missing_segment_python_index.size())) {
segment_to_original_python_index_map.emplace(
missing_segment_python_index.at(idx),
missing_cloned_python_index.at(idx));
}

// Return the mapping from the index space of segment FusionDefinition to the
// index space of the original FusionDefinition.
return segment_to_original_python_index_map;
}
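
The index bookkeeping in `buildSegment` is a composition of lookups, which the following Python sketch tries to mirror. It is a simplified model, not the C++ implementation: strings stand in for `Val*` pointers, plain dicts stand in for the `IrCloner` and the translation map, and all names and example values (including the extent `c_T0_size0` and its indices) are hypothetical.

```python
def build_index_map(cloned_io, cloned_to_original_index,
                    cloned_to_segment, segment_val_to_index):
    """Steps 4a-4d: compose lookups into {segment index: original index}."""
    original_index = [cloned_to_original_index[v] for v in cloned_io]  # 4a
    segment_vals = [cloned_to_segment[v] for v in cloned_io]           # 4b
    segment_index = [segment_val_to_index[v] for v in segment_vals]    # 4c
    return dict(zip(segment_index, original_index))                    # 4d


def add_missing_extents(index_map, segment_input_indices, cloned_extents,
                        cloned_to_original_index, cloned_to_segment,
                        segment_val_to_index):
    """Steps 4f-4l: map extent inputs that only the python segment needs."""
    segment_to_cloned_extent = {}  # 4f
    segment_index_to_val = {}      # 4g
    for cloned_extent in cloned_extents:
        segment_extent = cloned_to_segment.get(cloned_extent)
        if segment_extent not in segment_val_to_index:
            continue  # this extent is not used by the segment
        segment_to_cloned_extent[segment_extent] = cloned_extent
        segment_index_to_val[segment_val_to_index[segment_extent]] = segment_extent

    result = dict(index_map)
    for index in segment_input_indices:
        if index in result:  # 4h: skip inputs that are already mapped
            continue
        segment_val = segment_index_to_val[index]             # 4i
        cloned_val = segment_to_cloned_extent[segment_val]    # 4j
        result[index] = cloned_to_original_index[cloned_val]  # 4k, 4l
    return result


# Hypothetical data shaped like the first segment of the PR example, plus
# one made-up extent input to exercise steps 4f-4l.
cloned_to_original_index = {"c_T0": 0, "c_T3": 3, "c_T0_size0": 5}
cloned_to_segment = {"c_T0": "s_T0", "c_T3": "s_T2", "c_T0_size0": "s_S3"}
segment_val_to_index = {"s_T0": 0, "s_T1": 1, "s_T2": 2, "s_S3": 3}

index_map = build_index_map(
    ["c_T0", "c_T3"], cloned_to_original_index,
    cloned_to_segment, segment_val_to_index)
full_map = add_missing_extents(
    index_map, [0, 3], ["c_T0_size0", "c_unused"],
    cloned_to_original_index, cloned_to_segment, segment_val_to_index)
```

In this model the early return of step 4e corresponds to skipping `add_missing_extents` when every segment input index already appears in `index_map`.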

void SegmentationState::prepareGroupOrder() {
NVF_ERROR(segmented_fusion_ != nullptr);
