Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translate segments to python definition #3335

Merged
merged 1 commit into from
Nov 15, 2024

Conversation

rdspring1
Copy link
Collaborator

@rdspring1 rdspring1 commented Nov 3, 2024

Overview:

  • buildSegment creates the CPP Fusion for a given segment id, translates it to a python FusionDefinition, then returns a mapping from the segment fusion state indices to the original fusion state indices.
  • FusionDefinition.segment calls setupSegmentation, buildSegment, and finalizeSegmentation to create python definitions for the sub-fusions and their index mappings.

Changes in this PR

This PR implements buildSegment function for user-scheduler segmentation. It is the second PR in a stack, preceded by #3334 and followed by #3025.

  1. Implement buildSegment function in csrc/python_frontend/segmentation.cpp.
  2. Complete segment function in nvfuser/__init__.py

Example:

Original Fusion: A reduction + broadcast + pointwise fusion.

def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T1 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T2 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float)
    T3 = fd.ops.broadcast(T2, is_broadcast_dim=[False, True])
    T4 = fd.ops.add(T1, T3)
    fd.add_output(T4)

After Segmentation: The reduction scheduler does not support fusing any operations with an inner reduction, so the original fusion is divided into two segments.

First Segment:

The first segment contains the reduction and broadcast operations, which corresponds with [T0, T2, T3] in the original fusion. Therefore, the segment index to original index map has two entries.

Segment Index Original Index Description
T0 T0 The first tensor argument for the original fusion.
T2 T3 The broadcasted, reduction tensor is this segment's output.
def nvfuser_fusion_id2(fd : FusionDefinition) -> None :
   T0 = fd.define_tensor(shape=[-1, -1],
                         contiguity=[True, True],
                         dtype=DataType.Float,
                         is_cpu=False)
   T1 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float)
   T2 = fd.ops.broadcast(T1, is_broadcast_dim=[False, True])
   fd.add_output(T2)

Second Segment:

The second segment is the pointwise addition with the broadcasted reduction. It corresponds with [T1, T3, T4] in the original fusion.

Segment Index Original Index Description
T0 T1 The second tensor argument for the original fusion.
T1 T3 The broadcasted, reduction tensor, which is the output from the first segment.
T2 T4 The pointwise addition, which is the output for the original fusion.
def nvfuser_fusion_id3(fd : FusionDefinition) -> None :
   T0 = fd.define_tensor(shape=[-1, -1],
                         contiguity=[True, True],
                         dtype=DataType.Float,
                         is_cpu=False)
   T1 = fd.define_tensor(shape=[-1, 1],
                         contiguity=[True, None],
                         dtype=DataType.Float,
                         is_cpu=False)
   T2 = fd.ops.add(T0, T1)
   fd.add_output(T2)

@rdspring1 rdspring1 added the Python API Issues related to the Python API label Nov 3, 2024
@rdspring1 rdspring1 marked this pull request as ready for review November 3, 2024 18:04
@rdspring1 rdspring1 force-pushed the user_sched_segmentation_build branch 2 times, most recently from fadad52 to a53d1dd Compare November 12, 2024 18:02
@Priya2698
Copy link
Collaborator

I am seeing changes from PR #3334, can you rebase to only include changes from this PR for easier review?

@rdspring1 rdspring1 force-pushed the user_sched_segmentation_translate branch from e0a4538 to 98fd9b2 Compare November 13, 2024 16:43
rdspring1 added a commit that referenced this pull request Nov 13, 2024
## General Overview of Segmentation:

Segmentation decomposes a fusion into a directed acyclic graph (DAG) of
sub-fusions. After applying the segmentation algorithm, we can translate
the sub-fusions into their corresponding python definitions. Then, given
the fusion's input arguments, the segments are run in the correct order
to produce the output results.

The original FusionDefinition stores the sequence of sub-fusions and
acts as an argument manager. It gathers the input arguments before
running the sub-fusion and stores its results. To perform this function,
it requires a map from the segment index space to the original index
space. This mapping is generated while creating the python definition
for each sub-fusion.

### CPP functions:
Step 1: `setupSegmentation` runs the segmentation algorithm on the CPP
Fusion to create the `SegmentedFusion`. Then, sub-fusions are ordered
according to their dependencies by the `prepareGroupOrder` function. It
returns the number of segments in `SegmentedFusion`.

Step 2: `buildSegment` creates the CPP `Fusion` for a given segment id,
translates it to a python `FusionDefinition`, then returns a mapping
from the segment fusion state indices to the original fusion state
indices.

Step 3: `finalizeSegmentation` destroys any state stored in
`FusionDefinition`.

### Python functions:

1. `setupSegmentation`, `buildSegment`, and `finalizeSegmentation` are
called together in `FusionDefinition.segment`.
6. If a python `FusionDefinition` has segments, call `_execute_segments`
in the `FusionDefinition.execute`. The original `FusionDefinition` acts
as argument manager, running the sub-fusions in topological order.

## Example:
### Original Fusion: A reduction + broadcast + pointwise fusion.
```python
def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T1 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T2 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float)
    T3 = fd.ops.broadcast(T2, is_broadcast_dim=[False, True])
    T4 = fd.ops.add(T1, T3)
    fd.add_output(T4)
```

### After Segmentation:
- The reduction scheduler does not support fusing any operations with an
inner reduction, so the original fusion is divided into two segments.
Segment 2 depends on Segment 1, so there is a strict ordering of the
segments.
- The first segment contains the reduction and broadcast operations,
which corresponds with [T0, T2, T3] in the original fusion.
- The second segment is the pointwise addition with the broadcasted
reduction. It corresponds with [T1, T3, T4] in the original fusion.

**First Segment:**
```python
def nvfuser_fusion_id2(fd : FusionDefinition) -> None :
   T0 = fd.define_tensor(shape=[-1, -1],
                         contiguity=[True, True],
                         dtype=DataType.Float,
                         is_cpu=False)
   T1 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float)
   T2 = fd.ops.broadcast(T1, is_broadcast_dim=[False, True])
   fd.add_output(T2)
```

**Second Segment:**
```python
def nvfuser_fusion_id3(fd : FusionDefinition) -> None :
   T0 = fd.define_tensor(shape=[-1, -1],
                         contiguity=[True, True],
                         dtype=DataType.Float,
                         is_cpu=False)
   T1 = fd.define_tensor(shape=[-1, 1],
                         contiguity=[True, None],
                         dtype=DataType.Float,
                         is_cpu=False)
   T2 = fd.ops.add(T0, T1)
   fd.add_output(T2)
```

## Changes in this PR

This PR implements `setupSegmentation` function for user-scheduler
segmentation. It is the first PR in a stack, followed by
#3335 and
#3025.

1. Create `SegmentationState` class that contains all segmentation logic
for python-frontend.
2. All segmentation logic is contained in a separate file -
`csrc/python_frontend/segmentation.h`
3. `FusionDefinition` contains an instantiation of `SegmentationState`
and exposes its logic in a public interface. This interface is added to
the python bindings.
4. Created `test_segmentation_reduction_pointwise_epilogue` to test
functionality.
Base automatically changed from user_sched_segmentation_build to main November 13, 2024 18:58
@jjsjann123
Copy link
Collaborator

Oops. I think the merge of #3334 messed up the git history. You might have to resolve the conflicts by hand now.

@rdspring1 rdspring1 force-pushed the user_sched_segmentation_translate branch from 98fd9b2 to 4e22cba Compare November 13, 2024 19:27
@rdspring1
Copy link
Collaborator Author

I used git rebase to fixed the conflicts.

Copy link
Collaborator

@jjsjann123 jjsjann123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong opinion against the implementation. But I'm rather nitpicking on the implementation/comments.
So I'm not pushing for fixing everything in this PR, given that it's just an implementation detail that isn't necessarily exposed to the user.

csrc/python_frontend/segmentation.cpp Show resolved Hide resolved
csrc/python_frontend/segmentation.h Outdated Show resolved Hide resolved
// k) Map cloned Vals to their corresponding fusion indices.
// l) Add missing mappings to segment to original indices map.
// 5) Return the mapping from the segmented FusionDefinition index space to
// original FusionDefinition index space.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my nitpick on the comment along with the implementation.

It felt quite a bit verbose on how the mapping is done. At least I struggle to follow that. I think a higher level description would be useful.

i.e. how the map goes from segmented_fusion_definition -> fusion_segment -> original_fusion -> original_fusion_definition. I think it would also help if we try to rename variables in the implementation to make it easy to see how each map it projecting in the map above.

nvfuser/__init__.py Outdated Show resolved Hide resolved
nvfuser/__init__.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@Priya2698 Priya2698 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall.

I mainly have requests around naming, comments and test.
Thanks for adding detailed comments, very helpful in understanding your implementation. Great work!

Add buildSegment function
@rdspring1 rdspring1 force-pushed the user_sched_segmentation_translate branch from 4e22cba to f6dec35 Compare November 15, 2024 02:46
@rdspring1
Copy link
Collaborator Author

!test

@rdspring1
Copy link
Collaborator Author

I renamed some variables to make things clearer. I hope it helps!!!

@rdspring1 rdspring1 merged commit 3229ed8 into main Nov 15, 2024
47 of 48 checks passed
@rdspring1 rdspring1 deleted the user_sched_segmentation_translate branch November 15, 2024 18:27
rdspring1 added a commit that referenced this pull request Nov 24, 2024
## Overview:
The original `FusionDefinition` stores the sequence of sub-fusions and
acts as an argument manager. It gathers the input arguments before
running the sub-fusion and stores its results. To perform this function,
it uses a map from the segment index space to the original index space.
This mapping was generated while creating the python definition for each
sub-fusion.

## Changes in this PR

This PR implements `_execute_segments ` function for user-scheduler
segmentation. It is the third PR in a stack, preceded by
#3334 and
#3335.

1. Implement `_execute_segments ` function in `nvfuser/__init__.py` to
orchestrate segments in original fusion.
2. Add `supports_segmentation flag` to `exec_nvfuser`, so segmentation
testing is enabled by default for all python tests.

## Example:
### Original Fusion: A reduction + broadcast + pointwise fusion.
```python
def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T1 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T2 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float)
    T3 = fd.ops.broadcast(T2, is_broadcast_dim=[False, True])
    T4 = fd.ops.add(T1, T3)
    fd.add_output(T4)
```

## Step-by-Step execution of `_execute_segments`

### Step 1 before running any segments.
#### `map_original_fid_to_value` state: 6 entries
|   Original Index   |    Description    |
| -----------------| --------------- |
| 0 | The first tensor argument for the original fusion. |
| 1 | The second tensor argument for the original fusion. |
| -1 | Extent of axis 0 for first tensor argument. |
| -2 | Extent of axis 1 for first tensor argument. |
| -3 | Extent of axis 0 for second tensor argument. |
| -4 | Extent of axis 1 for second tensor argument. |

* Omit extents [-1, -4] from table in future steps because they are not
necessary for these segments.

### Step 2 after running the first segment.
#### `map_original_fid_to_value` state: 6 entries
|   Original Index   |    Description    |
| -----------------| --------------- |
| 1 | The second tensor argument for the original fusion. |
| 3 | The broadcasted, reduction tensor, which is the output from the
first segment. |

* Removed the entry for `T0` because the first tensor argument is not
required for second segment.
*  Added the entry for `T3`, which is the output from the first segment.

### Step 3 after running the second segment.
#### `map_original_fid_to_value` state: 5 entries
|   Original Index   |    Description    |
| -----------------| --------------- |
| 4 | The pointwise addition, which is the output for the original
fusion. |

* Removed the entries for `T1` and `T3` because they are not necessary
anymore.
* Added the entry for `T4`, which is the output from the second segment.

### Step 4 after running all segments.
* Return `T4` from `map_original_fid_to_value` as the result for the
original fusion.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Python API Issues related to the Python API
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants