-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Translate segments to python definition #3335
Conversation
fadad52
to
a53d1dd
Compare
I am seeing changes from PR #3334, can you rebase to only include changes from this PR for easier review? |
e0a4538
to
98fd9b2
Compare
## General Overview of Segmentation: Segmentation decomposes a fusion into a directed acyclic graph (DAG) of sub-fusions. After applying the segmentation algorithm, we can translate the sub-fusions into their corresponding python definitions. Then, given the fusion's input arguments, the segments are run in the correct order to produce the output results. The original FusionDefinition stores the sequence of sub-fusions and acts as an argument manager. It gathers the input arguments before running the sub-fusion and stores its results. To perform this function, it requires a map from the segment index space to the original index space. This mapping is generated while creating the python definition for each sub-fusion. ### CPP functions: Step 1: `setupSegmentation` runs the segmentation algorithm on the CPP Fusion to create the `SegmentedFusion`. Then, sub-fusions are ordered according to their dependencies by the `prepareGroupOrder` function. It returns the number of segments in `SegmentedFusion`. Step 2: `buildSegment` creates the CPP `Fusion` for a given segment id, translates it to a python `FusionDefinition`, then returns a mapping from the segment fusion state indices to the original fusion state indices. Step 3: `finalizeSegmentation` destroys any state stored in `FusionDefinition`. ### Python functions: 1. `setupSegmentation`, `buildSegment`, and `finalizeSegmentation` are called together in `FusionDefinition.segment`. 6. If a python `FusionDefinition` has segments, call `_execute_segments` in the `FusionDefinition.execute`. The original `FusionDefinition` acts as argument manager, running the sub-fusions in topological order. ## Example: ### Original Fusion: A reduction + broadcast + pointwise fusion. ```python def nvfuser_fusion_id1(fd : FusionDefinition) -> None : T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False) T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False) T2 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float) T3 = fd.ops.broadcast(T2, is_broadcast_dim=[False, True]) T4 = fd.ops.add(T1, T3) fd.add_output(T4) ``` ### After Segmentation: - The reduction scheduler does not support fusing any operations with an inner reduction, so the original fusion is divided into two segments. Segment 2 depends on Segment 1, so there is a strict ordering of the segments. - The first segment contains the reduction and broadcast operations, which corresponds with [T0, T2, T3] in the original fusion. - The second segment is the pointwise addition with the broadcasted reduction. It corresponds with [T1, T3, T4] in the original fusion. **First Segment:** ```python def nvfuser_fusion_id2(fd : FusionDefinition) -> None : T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False) T1 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float) T2 = fd.ops.broadcast(T1, is_broadcast_dim=[False, True]) fd.add_output(T2) ``` **Second Segment:** ```python def nvfuser_fusion_id3(fd : FusionDefinition) -> None : T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False) T1 = fd.define_tensor(shape=[-1, 1], contiguity=[True, None], dtype=DataType.Float, is_cpu=False) T2 = fd.ops.add(T0, T1) fd.add_output(T2) ``` ## Changes in this PR This PR implements `setupSegmentation` function for user-scheduler segmentation. It is the first PR in a stack, followed by #3335 and #3025. 1. Create `SegmentationState` class that contains all segmentation logic for python-frontend. 2. All segmentation logic is contained in a separate file - `csrc/python_frontend/segmentation.h` 3. `FusionDefinition` contains an instantiation of `SegmentationState` and exposes its logic in a public interface. This interface is added to the python bindings. 4. Created `test_segmentation_reduction_pointwise_epilogue` to test functionality.
Oops. I think the merge of #3334 messed up the git history. You might have to resolve the conflicts by hand now. |
98fd9b2
to
4e22cba
Compare
I used |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't have a strong opinion against the implementation. But I'm rather nitpicking on the implementation/comments.
So I'm not pushing for fixing everything in this PR, given that it's just an implementation detail that isn't necessarily exposed to the user.
csrc/python_frontend/segmentation.h
Outdated
// k) Map cloned Vals to their corresponding fusion indices. | ||
// l) Add missing mappings to segment to original indices map. | ||
// 5) Return the mapping from the segmented FusionDefinition index space to | ||
// original FusionDefinition index space. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my nitpick on the comment along with the implementation.
It felt quite a bit verbose on how the mapping is done. At least I struggle to follow that. I think a higher level description would be useful.
i.e. how the map goes from segmented_fusion_definition -> fusion_segment -> original_fusion -> original_fusion_definition
. I think it would also help if we try to rename variables in the implementation to make it easy to see how each map it projecting in the map above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM overall.
I mainly have requests around naming, comments and test.
Thanks for adding detailed comments, very helpful in understanding your implementation. Great work!
Add buildSegment function
4e22cba
to
f6dec35
Compare
!test |
I renamed some variables to make things clearer. I hope it helps!!! |
## Overview: The original `FusionDefinition` stores the sequence of sub-fusions and acts as an argument manager. It gathers the input arguments before running the sub-fusion and stores its results. To perform this function, it uses a map from the segment index space to the original index space. This mapping was generated while creating the python definition for each sub-fusion. ## Changes in this PR This PR implements `_execute_segments ` function for user-scheduler segmentation. It is the third PR in a stack, preceded by #3334 and #3335. 1. Implement `_execute_segments ` function in `nvfuser/__init__.py` to orchestrate segments in original fusion. 2. Add `supports_segmentation flag` to `exec_nvfuser`, so segmentation testing is enabled by default for all python tests. ## Example: ### Original Fusion: A reduction + broadcast + pointwise fusion. ```python def nvfuser_fusion_id1(fd : FusionDefinition) -> None : T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False) T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False) T2 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float) T3 = fd.ops.broadcast(T2, is_broadcast_dim=[False, True]) T4 = fd.ops.add(T1, T3) fd.add_output(T4) ``` ## Step-by-Step execution of `_execute_segments` ### Step 1 before running any segments. #### `map_original_fid_to_value` state: 6 entries | Original Index | Description | | -----------------| --------------- | | 0 | The first tensor argument for the original fusion. | | 1 | The second tensor argument for the original fusion. | | -1 | Extent of axis 0 for first tensor argument. | | -2 | Extent of axis 1 for first tensor argument. | | -3 | Extent of axis 0 for second tensor argument. | | -4 | Extent of axis 1 for second tensor argument. | * Omit extents [-1, -4] from table in future steps because they are not necessary for these segments. ### Step 2 after running the first segment. #### `map_original_fid_to_value` state: 6 entries | Original Index | Description | | -----------------| --------------- | | 1 | The second tensor argument for the original fusion. | | 3 | The broadcasted, reduction tensor, which is the output from the first segment. | * Removed the entry for `T0` because the first tensor argument is not required for second segment. * Added the entry for `T3`, which is the output from the first segment. ### Step 3 after running the second segment. #### `map_original_fid_to_value` state: 5 entries | Original Index | Description | | -----------------| --------------- | | 4 | The pointwise addition, which is the output for the original fusion. | * Removed the entries for `T1` and `T3` because they are not necessary anymore. * Added the entry for `T4`, which is the output from the second segment. ### Step 4 after running all segments. * Return `T4` from `map_original_fid_to_value` as the result for the original fusion.
Overview:
buildSegment
creates the CPP Fusion for a given segment id, translates it to a python FusionDefinition, then returns a mapping from the segment fusion state indices to the original fusion state indices.FusionDefinition.segment
callssetupSegmentation
,buildSegment
, andfinalizeSegmentation
to create python definitions for the sub-fusions and their index mappings.Changes in this PR
This PR implements
buildSegment
function for user-scheduler segmentation. It is the second PR in a stack, preceded by #3334 and followed by #3025.buildSegment
function incsrc/python_frontend/segmentation.cpp
.segment
function innvfuser/__init__.py
Example:
Original Fusion: A reduction + broadcast + pointwise fusion.
After Segmentation: The reduction scheduler does not support fusing any operations with an inner reduction, so the original fusion is divided into two segments.
First Segment:
The first segment contains the reduction and broadcast operations, which corresponds with [T0, T2, T3] in the original fusion. Therefore, the segment index to original index map has two entries.
Second Segment:
The second segment is the pointwise addition with the broadcasted reduction. It corresponds with [T1, T3, T4] in the original fusion.