Translate segments to python definition (#3335)
## Overview:

- `buildSegment` creates the CPP Fusion for a given segment id, translates it to a python FusionDefinition, then returns a mapping from the segment fusion state indices to the original fusion state indices.
- `FusionDefinition.segment` calls `setupSegmentation`, `buildSegment`, and `finalizeSegmentation` to create python definitions for the sub-fusions and their index mappings.

## Changes in this PR

This PR implements `buildSegment` function for user-scheduler segmentation. It is the second PR in a stack, preceded by #3334 and followed by #3025.

1. Implement `buildSegment` function in `csrc/python_frontend/segmentation.cpp`.
2. Complete `segment` function in `nvfuser/__init__.py`

## Example:

### Original Fusion:

A reduction + broadcast + pointwise fusion.

```python
def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
    T2 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float)
    T3 = fd.ops.broadcast(T2, is_broadcast_dim=[False, True])
    T4 = fd.ops.add(T1, T3)
    fd.add_output(T4)
```

**After Segmentation:**

The reduction scheduler does not support fusing any operations with an inner reduction, so the original fusion is divided into two segments.

## First Segment:

The first segment contains the reduction and broadcast operations, which corresponds with [T0, T2, T3] in the original fusion. Therefore, the segment index to original index map has two entries.

| Segment Index | Original Index | Description |
| ------------- | -------------- | ----------- |
| T0 | T0 | The first tensor argument for the original fusion. |
| T2 | T3 | The broadcasted, reduction tensor is this segment's output. |

```python
def nvfuser_fusion_id2(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float)
    T2 = fd.ops.broadcast(T1, is_broadcast_dim=[False, True])
    fd.add_output(T2)
```

## Second Segment:

The second segment is the pointwise addition with the broadcasted reduction. It corresponds with [T1, T3, T4] in the original fusion.

| Segment Index | Original Index | Description |
| ------------- | -------------- | ----------- |
| T0 | T1 | The second tensor argument for the original fusion. |
| T1 | T3 | The broadcasted, reduction tensor, which is the output from the first segment. |
| T2 | T4 | The pointwise addition, which is the output for the original fusion. |

```python
def nvfuser_fusion_id3(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.define_tensor(shape=[-1, 1], contiguity=[True, None], dtype=DataType.Float, is_cpu=False)
    T2 = fd.ops.add(T0, T1)
    fd.add_output(T2)
```
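
## Index Map Sketch:

To make the two index maps above concrete, here is a minimal sketch that models them as plain Python dictionaries and finds the tensor that links the segments. It does not call nvFuser: the dictionary layout and the helper names `to_original` and `shared_original_indices` are assumptions made for illustration, and the actual return type of `buildSegment` may differ.

```python
# Illustrative only: plain dictionaries standing in for the index maps
# in the tables above. These names are hypothetical, not nvFuser API.

# Segment index -> original index, copied from the two tables.
first_segment_map = {0: 0, 2: 3}
second_segment_map = {0: 1, 1: 3, 2: 4}


def to_original(segment_map, segment_index):
    """Translate a segment-local tensor index to its original fusion index."""
    return segment_map[segment_index]


def shared_original_indices(producer_map, consumer_map):
    """Return the original indices that appear in both segment maps."""
    return sorted(set(producer_map.values()) & set(consumer_map.values()))


# The first segment's output (segment T2) is T3 in the original fusion.
first_output = to_original(first_segment_map, 2)

# The second segment receives that same tensor as its input T1.
second_input = to_original(second_segment_map, 1)

# Original indices common to both maps identify the tensor passed
# from the first segment to the second.
linking = shared_original_indices(first_segment_map, second_segment_map)
```

Under these assumptions, `first_output` and `second_input` both resolve to original index 3, and `linking` contains only that index, matching the description of T3 as the first segment's output and the second segment's input.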
Showing 7 changed files with 357 additions and 32 deletions.