Skip to content

Commit

Permalink
Enable segmentation support to python frontend user-scheduling (#3025)
Browse files Browse the repository at this point in the history
## Overview:
The original `FusionDefinition` stores the sequence of sub-fusions and
acts as an argument manager. It gathers the input arguments before
running the sub-fusion and stores its results. To perform this function,
it uses a map from the segment index space to the original index space.
This mapping was generated while creating the python definition for each
sub-fusion.

## Changes in this PR

This PR implements `_execute_segments ` function for user-scheduler
segmentation. It is the third PR in a stack, preceded by
#3334 and
#3335.

1. Implement `_execute_segments ` function in `nvfuser/__init__.py` to
orchestrate segments in original fusion.
2. Add `supports_segmentation flag` to `exec_nvfuser`, so segmentation
testing is enabled by default for all python tests.

## Example:
### Original Fusion: A reduction + broadcast + pointwise fusion.
```python
def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T1 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T2 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float)
    T3 = fd.ops.broadcast(T2, is_broadcast_dim=[False, True])
    T4 = fd.ops.add(T1, T3)
    fd.add_output(T4)
```

## Step-by-Step execution of `_execute_segments`

### Step 1 before running any segments.
#### `map_original_fid_to_value` state: 6 entries
|   Original Index   |    Description    |
| -----------------| --------------- |
| 0 | The first tensor argument for the original fusion. |
| 1 | The second tensor argument for the original fusion. |
| -1 | Extent of axis 0 for first tensor argument. |
| -2 | Extent of axis 1 for first tensor argument. |
| -3 | Extent of axis 0 for second tensor argument. |
| -4 | Extent of axis 1 for second tensor argument. |

* Omit extents [-1, -4] from table in future steps because they are not
necessary for these segments.

### Step 2 after running the first segment.
#### `map_original_fid_to_value` state: 6 entries
|   Original Index   |    Description    |
| -----------------| --------------- |
| 1 | The second tensor argument for the original fusion. |
| 3 | The broadcasted, reduction tensor, which is the output from the
first segment. |

* Removed the entry for `T0` because the first tensor argument is not
required for second segment.
*  Added the entry for `T3`, which is the output from the first segment.

### Step 3 after running the second segment.
#### `map_original_fid_to_value` state: 5 entries
|   Original Index   |    Description    |
| -----------------| --------------- |
| 4 | The pointwise addition, which is the output for the original
fusion. |

* Removed the entries for `T1` and `T3` because they are not necessary
anymore.
* Added the entry for `T4`, which is the output from the second segment.

### Step 4 after running all segments.
* Return `T4` from `map_original_fid_to_value` as the result for the
original fusion.
  • Loading branch information
rdspring1 authored Nov 24, 2024
1 parent ad3aa3f commit bb05859
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 22 deletions.
3 changes: 3 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,9 @@ void initNvFuserPythonBindings(PyObject* module) {
// Mark the end of segmentation
inst::Trace::instance()->endEvent(nullptr);
})
.def("inputs", [](FusionDefinition& self) { return self.inputs(); })
.def("outputs", [](FusionDefinition& self) { return self.outputs(); })
.def("extents", [](FusionDefinition& self) { return self.extents(); })
.def(
"__repr__",
[](FusionDefinition& self) {
Expand Down
45 changes: 33 additions & 12 deletions csrc/python_frontend/segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,52 @@ int64_t SegmentationState::setupSegmentation(
IrCloner original_to_cloned_map =
Fusion::copy(fusion, cloned_original_fusion_.get());

// Step 2) Given the map_presched_value_to_original_python_index AND the
// IRCloner returned by Fusion::copy, create a mapping from cloned CPP values
// to original fusion state indices.
std::unordered_map<Val*, int64_t> map_cloned_value_to_original_python_index;
map_cloned_value_to_original_python_index.reserve(
map_presched_value_to_original_python_index.size());
std::transform(
map_presched_value_to_original_python_index.begin(),
map_presched_value_to_original_python_index.end(),
std::inserter(
map_cloned_value_to_original_python_index,
map_cloned_value_to_original_python_index.end()),
[&](const auto& item) {
const Val* original_value = item.first;
int64_t python_index = item.second;
Val* cloned_value = original_to_cloned_map.clone(original_value);
return std::make_pair(cloned_value, python_index);
});

// Step 3) Concretize fusion with input arguments.
KernelArgumentHolder args =
KernelArgumentHolder::createKernelArgumentHolder(inputs, device);

// Step 2) Concretize fusion with input arguments.
std::unordered_map<Val*, Val*> symbolic_to_concrete_map =
DynamicTransform::concretizeFusion(cloned_original_fusion_.get(), args);

// Step 3) Given the map_presched_value_to_original_python_index, the IRCloner
// returned by Fusion::copy, AND the symbolic_to_concrete map returned by
// concretization pass, create a mapping from cloned Vals to original fusion
// state indices.
// Given the map_cloned_value_to_original_python_index AND the
// symbolic_to_concrete map returned by the concretization pass, create a
// mapping from cloned, concretized CPP values to original fusion state
// indices.
map_cloned_concretized_value_to_original_python_index_.reserve(
map_cloned_value_to_original_python_index.size());
std::transform(
map_presched_value_to_original_python_index.begin(),
map_presched_value_to_original_python_index.end(),
map_cloned_value_to_original_python_index.begin(),
map_cloned_value_to_original_python_index.end(),
std::inserter(
map_cloned_concretized_value_to_original_python_index_,
map_cloned_concretized_value_to_original_python_index_.end()),
[&](const auto& item) {
const Val* original_value = item.first;
Val* maybe_concretized_value = item.first;
int64_t python_index = item.second;
Val* cloned_val = original_to_cloned_map.clone(original_value);
if (symbolic_to_concrete_map.count(cloned_val)) {
cloned_val = symbolic_to_concrete_map.at(cloned_val);
if (symbolic_to_concrete_map.count(maybe_concretized_value) > 0) {
maybe_concretized_value =
symbolic_to_concrete_map.at(maybe_concretized_value);
}
return std::make_pair(cloned_val, python_index);
return std::make_pair(maybe_concretized_value, python_index);
});

// Track the extents for input TensorViews in cloned CPP Fusion.
Expand Down
79 changes: 79 additions & 0 deletions nvfuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,82 @@ def __exit__(self, type, value, traceback):
def definition(self):
raise NotImplementedError("definition() should be implemented by child class!")

def _execute_segments(self, input_arguments, *, device=None, profile=False):
"""
Run the sequence of FusionDefinition segments to generate the results
of this FusionDefinition.
This FusionDefinition acts an argument manager. It gathers input
arguments for the segments and stores their output results. After
running a segment, any redundant intermediate values, which are
unnecessary for any other segments, are deleted to save memory.
Args:
inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion.
Kwargs:
device (Optional[Union[int, str, torch.device]]): This is a hint to run
the Fusion on the given CUDA device. This is not typically
necessary, as the device is usually inferred from the locations
of input tensors. However, for some fusion definitions, no
tensors will be input (for example when all tensors are
generated with `full` or `uniform` ops). In these cases, we
must either tell NVFuser where to run the resulting kernel, or
let it default to 0. Note that passing this option providing
and input tensors that lie on another device is an error.
profile (bool): Captures a CUPTI based profile of a fusion.
Returns:
List[Tensor]: The output results for this FusionDefinition.
"""
assert len(self.segments) > 0
assert len(self.segments) == len(self.segment_index_space_maps)

input_arguments_with_extents = [*input_arguments]
for a in input_arguments:
if type(a) is torch.Tensor:
input_arguments_with_extents.extend(a.size())

# Map inputs arguments to original fid
map_original_fid_to_value = {
fd_state: argument
for fd_state, argument in zip(
self.inputs() + self.extents(), input_arguments_with_extents
)
}

# Run all segments in correct order
for idx, segment in enumerate(self.segments):
segment_to_original_map = self.segment_index_space_maps[segment]

# Gather segment input arguments
segment_arguments = [
map_original_fid_to_value[segment_to_original_map[fd_state]]
for fd_state in segment.inputs()
]

# Run segment
segment_outputs = segment.execute(
segment_arguments, device=device, profile=profile
)

# Update original fusion definition indices to outputs
for fd_state, output in zip(segment.outputs(), segment_outputs):
map_original_fid_to_value[segment_to_original_map[fd_state]] = output

# Destroy any arguments that are not used by future segments
for segment_input in segment.inputs():
original_input = segment_to_original_map[segment_input]
if (
original_input not in self.outputs()
and self.map_value_to_last_used_segment[original_input] == idx
):
del map_original_fid_to_value[original_input]

# Map output fid to actual results
return [map_original_fid_to_value[fd_state] for fd_state in self.outputs()]

def execute(
self,
inputs,
Expand Down Expand Up @@ -225,6 +301,9 @@ def execute(
fake_mode = FakeTensorMode()
self.fake_inputs = [fake_mode.from_tensor(inp) for inp in inputs]

if hasattr(self, "segments") and len(self.segments) > 0:
return self._execute_segments(inputs, device=device, profile=profile)

results = None

try:
Expand Down
4 changes: 3 additions & 1 deletion tests/python/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def torch_correctness_test_fn(fd_fn: Callable, nvf_op: OpInfo, sample: SampleInp
assert check_captured_python_definition(nvfuser_result, fd, inputs_cap)

if nvf_op.is_clonable:
assert check_cpp_translation(nvfuser_result, fd, inputs_cap)
assert check_cpp_translation(
nvfuser_result, fd, inputs_cap, supports_segmentation=True
)

torch_result = nvf_op.reference(*sample.args, **sample.kwargs)

Expand Down
16 changes: 9 additions & 7 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2831,7 +2831,8 @@ def fusion_func(fd: FusionDefinition) -> None:
T89 = fd.ops.sum(T98, dims=[4], keepdim=False, dtype=DataType.Null)
fd.add_output(T89)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
# TODO Segmentation fails validateAllocationSizesAndStrides
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)

# This tests no dead code at definition does not cause a problem due to
# removal of empty tensors
Expand Down Expand Up @@ -3237,7 +3238,7 @@ def fusion_func(fd: FusionDefinition) -> None:
fd.add_output(T54)
fd.add_output(T30)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)
# self.assertEqual(nvf_out[0], t24)

# Test that symbolic IterDomains can be concatenated
Expand Down Expand Up @@ -3769,7 +3770,7 @@ def fusion_func(fd: FusionDefinition) -> None:
fd.add_output(T57)
fd.add_output(T101)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)

# A simple pointwise fusion, but passed misaligned input
def test_misaligned_add(self):
Expand Down Expand Up @@ -3935,7 +3936,7 @@ def fusion_func(fd: FusionDefinition) -> None:

fd.add_output(T88)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)

# See https://github.com/NVIDIA/Fuser/issues/2275
@pytest.mark.skipif(
Expand Down Expand Up @@ -3981,7 +3982,7 @@ def fusion_func(fd: FusionDefinition) -> None:
T101 = fd.ops.cat([T7, T100], dim=-1)
fd.add_output(T101)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)

# See https://github.com/NVIDIA/Fuser/issues/2317
@pytest.mark.skipif(
Expand Down Expand Up @@ -4138,7 +4139,8 @@ def fusion_func(fd: FusionDefinition) -> None:
# T7 = fd.ops.reshape(T1, new_shape=V5)
fd.add_output(T7)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
# TODO Segmentation fails validateAllocationSizesAndStrides
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)

# Test empty symbolic tensors can be reshaped
# See https://github.com/NVIDIA/Fuser/issues/2362
Expand Down Expand Up @@ -4762,7 +4764,7 @@ def fusion_func(fd: FusionDefinition) -> None:
T223 = fd.ops.cat([T169, T222], dim=-1, manual_padding=0)
fd.add_output(T223)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)

def test_enable_disable_options(self):
m = 24
Expand Down
19 changes: 17 additions & 2 deletions tests/python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,13 +242,24 @@ def check_captured_python_definition(reference_outputs, fd, inputs, device=None)

# Run original FusionDefinition
# Clone FusionDefinition
# Apply segmentation if it supported for this FusionDefinition
# Run cloned python definition
# Check that the result of cloned python definition matches original results
def check_cpp_translation(reference_outputs, fd, inputs, device=None):
def check_cpp_translation(
reference_outputs, fd, inputs, supports_segmentation, device=None
):
try:
torch.manual_seed(0)

# Clone
cloned_fd = FusionDefinition()
clone(fd, cloned_fd)

# Segment
if supports_segmentation:
cloned_fd.segment(inputs)

# Run
cloned_outputs = cloned_fd.execute(inputs, device=device)

# Make sure the results of original and cloned definitions match.
Expand All @@ -268,6 +279,7 @@ def check_cpp_translation(reference_outputs, fd, inputs, device=None):
print(
"(A failure here suggests a mismatch in functionality between the original and cloned definitions.)"
)
print("Does FusionDefinition supports segmentation?\t", supports_segmentation)
print(fd.getReproErrorString("executing", inputs))
raise err

Expand Down Expand Up @@ -421,6 +433,7 @@ def exec_nvfuser(
new_fusion_expected=True,
device=None,
is_clonable=True,
supports_segmentation=True,
):
fc = FusionCache.get()
before_fusions = fc.num_fusions()
Expand Down Expand Up @@ -450,5 +463,7 @@ def exec_nvfuser(
)

if is_clonable:
self.assertTrue(check_cpp_translation(out, fd, inputs_cloned))
self.assertTrue(
check_cpp_translation(out, fd, inputs_cloned, supports_segmentation)
)
return out, fd

0 comments on commit bb05859

Please sign in to comment.