Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creating better Segmented Fusion IR Debug Output #785

Closed
kevinstephano opened this issue Aug 25, 2023 · 1 comment
Closed

Creating better Segmented Fusion IR Debug Output #785

kevinstephano opened this issue Aug 25, 2023 · 1 comment
Assignees
Labels
Segmentation Issues related to nvFuser Segmentation

Comments

@kevinstephano
Copy link
Collaborator

When we debug Fusions generated by the framework, we can't tell why a given set of instructions exists together in a fusion group; when those instructions are later segmented into multiple kernels, we are left to reason top-down about why the segmentation happened. It would be nice if the segment information were captured directly in the Fusion IR dump, instead of requiring a separate set of information to determine what is being grouped.

Here is a simplified example:

Python Definition:

import torch
from nvfuser import FusionDefinition, DataType
def nvfuser_fusion_id6(fd : FusionDefinition) -> None :
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False)
    T6 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False)
    S7 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T8 = fd.ops.mul(S7, T6)
    S9 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T10 = fd.ops.pow(T6, S9)
    T19 = fd.ops.sum(T4, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    fd.add_output(T19)
    fd.add_output(T10)

# Two CUDA tensors matching the define_tensor() shapes above.
# (Fix: the original mixed 4- and 8-space indentation; normalized to
# 4 spaces per PEP 8 — no behavioral change.)
inputs = [
    torch.randn(16, 512, 4096, device="cuda"),
    torch.randn(16, 512, 4096, device="cuda"),
]

with FusionDefinition() as fd:
    nvfuser_fusion_id6(fd)

# Execute repeatedly so the scheduling/segmentation dump paths are hit.
for _ in range(5):
    out = fd.execute(inputs)

Pre-scheduled Fusion IR:
Command:

NVFUSER_DUMP=fusion_ir_presched python test.py
Inputs:
  T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], float
  T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ], float
Outputs:
  T4_g[ rS12{i0}, rS13{i1}, iS14{i2} ], float
  T3_g[ iS9{i4}, iS10{i5}, iS11{i6} ], float

%kernel_math {
T4_g[ rS12{i0}, rS13{i1}, iS14{i2} ]
   = reduction( T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], op = add, initial value = float(0), allreduce = false )
T3_g[ iS9{i4}, iS10{i5}, iS11{i6} ]
   = pow(T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ]
  , double(3));
}

Scheduled Fusion IR (the math only without the transforms of the tensors):
It looks like the printing of the segments is not cleanly separated in stdout.
Command:

NVFUSER_DUMP=fusion_ir_math python test.py
Inputs:
  T1_g[ iS54{( ceilDiv(( ceilDiv(( ceilDiv(( i4 * ( i5 * i6 ) ), 4) ), 1) ), 128) )}, iS53{1}, iS51{4}, iS55{128} ], float
Outputs:
  T3_g[ iblockIdx.x30{( ceilDiv(( ceilDiv(( ceilDiv(( i4 * ( i5 * i6 ) ), 4) ), 1) ), 128) )}, iUS29{1}, iV27{4}, ithreadIdx.x31{128} ] ca_pos( 2 ) produce_pos( 2 ), float

%kernel_math {
Inputs:
  T5_l[ iblockIdx.x46{( ceilDiv(( ceilDiv(( ceilDiv(( i4 * ( i5 * i6 ) ), 4) ), 1) ), 128) )}, iUS45{1}, iV43{4}, ithreadIdx.x47{128} ] ca_pos( 2 )
   = Set( T1_g[ iS54{( ceilDiv(( ceilDiv(( ceilDiv(( i4 * ( i5 * i6 ) ), 4) ), 1) ), 128) )}, iS53{1}, iS51{4}, iS55{128} ] )
T0_g[ iS63{( ceilDiv(i2, blockDim.x) )}, iS64{blockDim.x}, iS73{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( i0 * i1 ), blockDim.y) ), 4) ), 1) ), gridDim.y) )}, iS72{gridDim.y}, iS67{blockDim.y}, iS71{1}, iS69{4} ], float
Outputs:
  T4_g[ iblockIdx.x74{( ceilDiv(i2, blockDim.x) )}, ithreadIdx.x75{blockDim.x} ] ca_pos( 2 ) produce_pos( 2 ), float

%kernel_math {
T6_l[ iblockIdx.x38{( ceilDiv(( ceilDiv(( ceilDiv(( i4 * ( i5 * i6 ) ), 4) ), 1) ), 128) )}, iUS37{1}, iS35{4}, ithreadIdx.x39{128} ] ca_pos( 2 ) produce_pos( 2 )
   = pow(T5_l[ iblockIdx.x46{( ceilDiv(( ceilDiv(( ceilDiv(( i4 * ( i5 * i6 ) ), 4) ), 1) ), 128) )}, iUS45{1}, iV43{4}, ithreadIdx.x47{128} ] ca_pos( 2 )
  , double(3));
T5_l[ iblockIdx.x52{( ceilDiv(i2, blockDim.x) )}, ithreadIdx.x53{blockDim.x}, iS62{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( i0 * i1 ), blockDim.y) ), 4) ), 1) ), gridDim.y) )}, iblockIdx.y61{gridDim.y}, ithreadIdx.y56{blockDim.y}, iUS60{1}, iUR58{4} ] ca_pos( 6 )
   = Set( T0_g[ iS63{( ceilDiv(i2, blockDim.x) )}, iS64{blockDim.x}, iS73{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( i0 * i1 ), blockDim.y) ), 4) ), 1) ), gridDim.y) )}, iS72{gridDim.y}, iS67{blockDim.y}, iS71{1}, iS69{4} ] )
T3_g[ iblockIdx.x30{( ceilDiv(( ceilDiv(( ceilDiv(( i4 * ( i5 * i6 ) ), 4) ), 1) ), 128) )}, iUS29{1}, iV27{4}, ithreadIdx.x31{128} ] ca_pos( 2 ) produce_pos( 2 )
   = Set( T6_l[ iblockIdx.x38{( ceilDiv(( ceilDiv(( ceilDiv(( i4 * ( i5 * i6 ) ), 4) ), 1) ), 128) )}, iUS37{1}, iS35{4}, ithreadIdx.x39{128} ] ca_pos( 2 ) produce_pos( 2 ) )
}

T7_l[ iblockIdx.x36{( ceilDiv(i2, blockDim.x) )}, ithreadIdx.x37{blockDim.x}, rS46{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( i0 * i1 ), blockDim.y) ), 4) ), 1) ), gridDim.y) )}rf, iblockIdx.y45{gridDim.y}rf, ithreadIdx.y40{blockDim.y}rf, rUS44{1}rf, rS42{4}rf ] ca_pos( 2 ) produce_pos( 6 )
   = reduction( T5_l[ iblockIdx.x52{( ceilDiv(i2, blockDim.x) )}, ithreadIdx.x53{blockDim.x}, iS62{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( i0 * i1 ), blockDim.y) ), 4) ), 1) ), gridDim.y) )}, iblockIdx.y61{gridDim.y}, ithreadIdx.y56{blockDim.y}, iUS60{1}, iUR58{4} ] ca_pos( 6 ), op = add, initial value = float(0), allreduce = false )
T6_l[ iblockIdx.x50{( ceilDiv(i2, blockDim.x) )}, ithreadIdx.x51{blockDim.x}, rblockIdx.y47{gridDim.y}, rthreadIdx.y48{blockDim.y} ] ca_pos( 2 ) produce_pos( 2 )
   = reduction( T7_l[ iblockIdx.x36{( ceilDiv(i2, blockDim.x) )}, ithreadIdx.x37{blockDim.x}, rS46{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( i0 * i1 ), blockDim.y) ), 4) ), 1) ), gridDim.y) )}rf, iblockIdx.y45{gridDim.y}rf, ithreadIdx.y40{blockDim.y}rf, rUS44{1}rf, rS42{4}rf ] ca_pos( 2 ) produce_pos( 6 ), op = add, initial value = float(0), allreduce = false )
T4_g[ iblockIdx.x74{( ceilDiv(i2, blockDim.x) )}, ithreadIdx.x75{blockDim.x} ] ca_pos( 2 ) produce_pos( 2 )
   = Set( T6_l[ iblockIdx.x50{( ceilDiv(i2, blockDim.x) )}, ithreadIdx.x51{blockDim.x}, rblockIdx.y47{gridDim.y}, rthreadIdx.y48{blockDim.y} ] ca_pos( 2 ) produce_pos( 2 ) )
}

Segmentation Fusion IR Debug:
It is not clear how the groups map back to the original pre-scheduled Fusion IR. Since this case only has 2 operations, it is easy to infer, but it would be nice if segmentation information were printed naturally alongside the Fusion IR instead of requiring a separate debug output. As we take in larger fusion groups, it will become more likely that operations get segmented.

Command:

NVFUSER_DUMP=segmented_fusion python test.py
Segment the fusion (Original Fusion Un-modified):
Inputs:
  T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], float
  T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ], float
Outputs:
  T4_g[ rS12{i0}, rS13{i1}, iS14{i2} ], float
  T3_g[ iS9{i4}, iS10{i5}, iS11{i6} ], float

%kernel_math {
T4_g[ rS12{i0}, rS13{i1}, iS14{i2} ]
   = reduction( T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], op = add, initial value = float(0), allreduce = false )
T3_g[ iS9{i4}, iS10{i5}, iS11{i6} ]
   = pow(T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ]
  , double(3));
}

Segmented_Fusion Dump: -- Re-written complete fusion:{
Inputs:
  T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], float
  T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ], float
Outputs:
  T4_g[ rS12{i0}, rS13{i1}, iS14{i2} ], float
  T3_g[ iS9{i4}, iS10{i5}, iS11{i6} ], float

%kernel_math {
T4_g[ rS12{i0}, rS13{i1}, iS14{i2} ]
   = reduction( T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], op = add, initial value = float(0), allreduce = false )
T3_g[ iS9{i4}, iS10{i5}, iS11{i6} ]
   = pow(T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ]
  , double(3));
}

} // {Re-written complete fusion}
Segmented_Fusion Dump: -- fusion segments:
Segmented_Fusion{
groups:
g{2}

g{1}

edges:

group details:
g{(reduction)
inputs:
T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ] float
outputs:
T4_g[ rS12{i0}, rS13{i1}, iS14{i2} ] float


T4_g[ rS12{i0}, rS13{i1}, iS14{i2} ]
   = reduction( T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], op = add, initial value = float(0), allreduce = false )
(2)
}

g{(pointwise)
inputs:
T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ] float
outputs:
T3_g[ iS9{i4}, iS10{i5}, iS11{i6} ] float


T3_g[ iS9{i4}, iS10{i5}, iS11{i6} ]
   = pow(T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ]
  , double(3));
(1)
}

} //Segmented_Fusion

The real example of the fusion definition from the backward implementation of Gelu + a Bias Gradient.

import torch
from nvfuser import FusionDefinition, DataType
def nvfuser_fusion_id6(fd : FusionDefinition) -> None :
    """Backward pass of tanh-approximation GELU plus a bias-gradient reduction.

    Produces two outputs: T19 (the bias gradient, a sum-reduction of T4)
    and T38 (the GELU input gradient). The magic constants below look like
    the tanh GELU approximation coefficients (0.044715 and sqrt(2/pi) ~=
    0.797885) — TODO confirm against the framework's GELU derivative.
    """
    # S0..S3 are runtime scalar inputs (values supplied at execute time).
    S0 = fd.define_scalar(None, dtype=DataType.Double)
    S1 = fd.define_scalar(None, dtype=DataType.Double)
    S2 = fd.define_scalar(None, dtype=DataType.Double)
    S3 = fd.define_scalar(None, dtype=DataType.Double)
    # T4: reduced for the bias gradient; T5: incoming grad; T6: activation input.
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False)
    T5 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False)
    T6 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False)
    S7 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T8 = fd.ops.mul(S7, T6)
    S9 = fd.define_scalar(3.00000, dtype=DataType.Double)
    # T13 = T6 + 0.044715 * T6^3 — the inner polynomial of tanh-GELU.
    T10 = fd.ops.pow(T6, S9)
    S11 = fd.define_scalar(0.0447150, dtype=DataType.Double)
    T12 = fd.ops.mul(S11, T10)
    T13 = fd.ops.add(T6, T12)
    S14 = fd.define_scalar(0.797885, dtype=DataType.Double)
    T15 = fd.ops.mul(S14, T13)
    T16 = fd.ops.tanh(T15)
    S17 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T18 = fd.ops.add(S17, T16)
    # Bias gradient: reduce the grad tensor T4 over the leading two axes.
    T19 = fd.ops.sum(T4, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    T20 = fd.ops.mul(T5, T18)
    T21 = fd.ops.mul(T5, T8)
    S22 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T23 = fd.ops.mul(S22, T21)
    # T26 = 1 - tanh^2 — derivative of tanh.
    T24 = fd.ops.mul(T16, T16)
    S25 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T26 = fd.ops.sub(S25, T24)
    T27 = fd.ops.mul(T23, T26)
    # S0..S3 scale intermediate gradient terms below; their meaning depends
    # on the caller's values (see `inputs` at the bottom of this script).
    T28 = fd.ops.mul(T27, S3)
    S29 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T30 = fd.ops.mul(S29, T28)
    T31 = fd.ops.mul(T30, S2)
    T32 = fd.ops.mul(T31, S1)
    S33 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T34 = fd.ops.pow(T6, S33)
    T35 = fd.ops.mul(T32, T34)
    T36 = fd.ops.add(T30, T35)
    T37 = fd.ops.mul(T20, S0)
    # Final input gradient: sum of the polynomial-chain term and the
    # direct tanh term.
    T38 = fd.ops.add(T36, T37)
    fd.add_output(T19)
    fd.add_output(T38)

# Four scalar inputs (S0..S3) followed by three CUDA tensors (T4..T6),
# in definition order. (Fix: the original mixed 4- and 8-space
# indentation within this list and in the loop body; normalized to
# 4 spaces per PEP 8 — no behavioral change.)
inputs = [
    1.0,
    2.0,
    3.0,
    4.0,
    torch.randn(16, 512, 4096, device="cuda"),
    torch.randn(16, 512, 4096, device="cuda"),
    torch.randn(16, 512, 4096, device="cuda"),
]

with FusionDefinition() as fd:
    nvfuser_fusion_id6(fd)

# Execute repeatedly so the scheduling/segmentation dump paths are hit.
for _ in range(5):
    out = fd.execute(inputs)
@kevinstephano kevinstephano added the Segmentation Issues related to nvFuser Segmentation label Aug 30, 2023
@kevinstephano
Copy link
Collaborator Author

Closing as this is not going to be worked on.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Segmentation Issues related to nvFuser Segmentation
Projects
None yet
Development

No branches or pull requests

2 participants