
Why is LayerNorm Backward Segmenting? #1473

Closed

kevinstephano opened this issue Dec 7, 2023 · 1 comment

Labels: perf, Segmentation (Issues related to nvFuser Segmentation)


kevinstephano commented Dec 7, 2023

I have two examples:

  1. The backward of bias+dropout+add+layer_norm, which fuses into one kernel as expected.
  2. The backward of layer_norm alone, which segments into 3 kernels.

The most notable difference between the two is the order of the inputs. Could that be the issue?
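
For orientation, here is a minimal PyTorch sketch (mine, not part of the repro) of the quantities a layer_norm backward produces for the [16, 512, 4096] shapes used below; the layer_norm portion of both fusion definitions corresponds to this backward:

import torch

# Reference only: grad_input / grad_weight / grad_bias of layer_norm via autograd.
x = torch.randn(16, 512, 4096, device='cuda', requires_grad=True)
w = torch.randn(4096, device='cuda', requires_grad=True)
b = torch.randn(4096, device='cuda', requires_grad=True)
y = torch.nn.functional.layer_norm(x, (4096,), weight=w, bias=b)
y.backward(torch.randn_like(y))
# x.grad, w.grad, b.grad are the analogues of the fusions' tensor outputs.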

A repro of the 1-kernel bias+dropout+add+layer_norm backward fusion:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id9(fd: FusionDefinition) -> None:
    S0 = fd.define_scalar(None, dtype=DataType.Double)
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[None, None, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T6 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T7 = fd.define_tensor(shape=[-1, -1, 1], contiguity=[True, True, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T8 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T9 = fd.ops.cast(T4, dtype=DataType.Float)
    S10 = fd.define_scalar(16, dtype=DataType.Int)
    S11 = fd.define_scalar(512, dtype=DataType.Int)
    S12 = fd.define_scalar(4096, dtype=DataType.Int)
    V13 = fd.define_vector([S10, S11, S12], dtype=DataType.Int)
    T14 = fd.ops.broadcast_in_dim(T7, shape=V13, broadcast_dims=[0, 1, 2])
    T15 = fd.ops.mul(T8, T14)
    T16 = fd.ops.cast(T1, dtype=DataType.Float)
    T17 = fd.ops.cast(T3, dtype=DataType.Float)
    T18 = fd.ops.sum(T17, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    T19 = fd.ops.cast(T18, dtype=DataType.BFloat16)
    T20 = fd.ops.mul(T17, T16)
    T21 = fd.ops.mul(T17, T15)
    T22 = fd.ops.sum(T21, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    T23 = fd.ops.cast(T22, dtype=DataType.BFloat16)
    T24 = fd.ops.mul(T20, T14)
    T25 = fd.ops.mul(T20, T8)
    T26 = fd.ops.sum(T25, axes=[2], keepdim=False, dtype=DataType.Null)
    S27 = fd.define_scalar(16, dtype=DataType.Int)
    S28 = fd.define_scalar(512, dtype=DataType.Int)
    S29 = fd.define_scalar(1, dtype=DataType.Int)
    V30 = fd.define_vector([S27, S28, S29], dtype=DataType.Int)
    T31 = fd.ops.broadcast_in_dim(T26, shape=V30, broadcast_dims=[0, 1])
    T32 = fd.ops.neg(T24)
    T33 = fd.ops.sum(T32, axes=[2], keepdim=False, dtype=DataType.Null)
    S34 = fd.define_scalar(16, dtype=DataType.Int)
    S35 = fd.define_scalar(512, dtype=DataType.Int)
    S36 = fd.define_scalar(1, dtype=DataType.Int)
    V37 = fd.define_vector([S34, S35, S36], dtype=DataType.Int)
    T38 = fd.ops.broadcast_in_dim(T33, shape=V37, broadcast_dims=[0, 1])
    S39 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T40 = fd.ops.mul(S39, T31)
    S41 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T42 = fd.ops.pow(T7, S41)
    T43 = fd.ops.mul(T40, T42)
    T44 = fd.ops.sum(T38, axes=[2], keepdim=False, dtype=DataType.Null)
    T45 = fd.ops.sum(T43, axes=[2], keepdim=False, dtype=DataType.Null)
    S46 = fd.define_scalar(16, dtype=DataType.Int)
    S47 = fd.define_scalar(512, dtype=DataType.Int)
    S48 = fd.define_scalar(1, dtype=DataType.Int)
    V49 = fd.define_vector([S46, S47, S48], dtype=DataType.Int)
    T50 = fd.ops.broadcast_in_dim(T45, shape=V49, broadcast_dims=[0, 1])
    S51 = fd.define_scalar(16, dtype=DataType.Int)
    S52 = fd.define_scalar(512, dtype=DataType.Int)
    S53 = fd.define_scalar(4096, dtype=DataType.Int)
    V54 = fd.define_vector([S51, S52, S53], dtype=DataType.Int)
    T55 = fd.ops.broadcast_in_dim(T50, shape=V54, broadcast_dims=[0, 1, 2])
    S56 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T57 = fd.ops.mul(S56, T55)
    T58 = fd.ops.sub(T6, T5)
    T59 = fd.ops.mul(T57, T58)
    S60 = fd.define_scalar(4096.00, dtype=DataType.Double)
    S61 = fd.ops.reciprocal(S60)
    T62 = fd.ops.mul(T59, S61)
    S63 = fd.define_scalar(16, dtype=DataType.Int)
    S64 = fd.define_scalar(512, dtype=DataType.Int)
    S65 = fd.define_scalar(1, dtype=DataType.Int)
    V66 = fd.define_vector([S63, S64, S65], dtype=DataType.Int)
    T67 = fd.ops.broadcast_in_dim(T44, shape=V66, broadcast_dims=[0, 1])
    S68 = fd.define_scalar(16, dtype=DataType.Int)
    S69 = fd.define_scalar(512, dtype=DataType.Int)
    S70 = fd.define_scalar(4096, dtype=DataType.Int)
    V71 = fd.define_vector([S68, S69, S70], dtype=DataType.Int)
    T72 = fd.ops.broadcast_in_dim(T67, shape=V71, broadcast_dims=[0, 1, 2])
    S73 = fd.define_scalar(0.000244140625, dtype=DataType.Double)
    T74 = fd.ops.mul(S73, T72)
    T75 = fd.ops.add(T62, T74)
    T76 = fd.ops.add(T24, T75)
    T77 = fd.ops.add(T2, T76)
    T78 = fd.ops.mul(T77, S0)
    T79 = fd.ops.mul(T78, T9)
    T80 = fd.ops.cast(T79, dtype=DataType.BFloat16)
    T81 = fd.ops.sum(T79, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    T82 = fd.ops.cast(T81, dtype=DataType.BFloat16)
    fd.add_output(T19)
    fd.add_output(T23)
    fd.add_output(T77)
    fd.add_output(T80)
    fd.add_output(T82)

with FusionDefinition() as fd:
    nvfuser_fusion_id9(fd)

inputs = [
        1.0,
        torch.randn(16, 512, 4096, device='cuda', dtype=torch.bfloat16).as_strided((16, 512, 4096), (0, 0, 1)), # T1
        torch.randn(16, 512, 4096, device='cuda', dtype=torch.float32), # T2
        torch.randn(16, 512, 4096, device='cuda', dtype=torch.bfloat16), # T3
        torch.randn(16, 512, 4096, device='cuda', dtype=torch.float32) > 0, # T4 Bool
        torch.randn(16, 512, device='cuda', dtype=torch.float32).as_strided((16, 512, 4096), (512, 1, 0)), # T5
        torch.randn(16, 512, 4096, device='cuda', dtype=torch.float32), # T6
        torch.randn(16, 512, 1, device='cuda', dtype=torch.float32), # T7
        torch.randn(16, 512, 4096, device='cuda', dtype=torch.float32), # T8
        ]

out = fd.execute(inputs)
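
If it helps triage, the way I'd confirm the segment counts (the dump option name is my assumption about the current frontend build) is to set the dump flag before importing nvfuser:

import os

# Assumed debug option name; prints the segmented fusion groups at compile time.
os.environ["NVFUSER_DUMP"] = "segmented_fusion"

import torch
from nvfuser import FusionDefinition, DataType
# ... define and execute either repro as above; the dump should show
# one group for the first fusion and three for the second.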

A repro of the 3-kernel layer_norm backward fusion:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id14(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[None, None, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[-1, -1, 1], contiguity=[True, True, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T6 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T7 = fd.ops.cast(T6, dtype=DataType.Float)
    S8 = fd.define_scalar(16, dtype=DataType.Int)
    S9 = fd.define_scalar(512, dtype=DataType.Int)
    S10 = fd.define_scalar(4096, dtype=DataType.Int)
    V11 = fd.define_vector([S8, S9, S10], dtype=DataType.Int)
    T12 = fd.ops.broadcast_in_dim(T3, shape=V11, broadcast_dims=[0, 1, 2])
    T13 = fd.ops.mul(T0, T12)
    T14 = fd.ops.cast(T1, dtype=DataType.Float)
    T15 = fd.ops.cast(T4, dtype=DataType.Float)
    T16 = fd.ops.sum(T15, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    T17 = fd.ops.cast(T16, dtype=DataType.BFloat16)
    T18 = fd.ops.mul(T15, T14)
    T19 = fd.ops.mul(T15, T13)
    T20 = fd.ops.sum(T19, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    T21 = fd.ops.cast(T20, dtype=DataType.BFloat16)
    T22 = fd.ops.mul(T18, T12)
    T23 = fd.ops.mul(T18, T0)
    T24 = fd.ops.sum(T23, axes=[2], keepdim=False, dtype=DataType.Null)
    S25 = fd.define_scalar(16, dtype=DataType.Int)
    S26 = fd.define_scalar(512, dtype=DataType.Int)
    S27 = fd.define_scalar(1, dtype=DataType.Int)
    V28 = fd.define_vector([S25, S26, S27], dtype=DataType.Int)
    T29 = fd.ops.broadcast_in_dim(T24, shape=V28, broadcast_dims=[0, 1])
    T30 = fd.ops.neg(T22)
    T31 = fd.ops.sum(T30, axes=[2], keepdim=False, dtype=DataType.Null)
    S32 = fd.define_scalar(16, dtype=DataType.Int)
    S33 = fd.define_scalar(512, dtype=DataType.Int)
    S34 = fd.define_scalar(1, dtype=DataType.Int)
    V35 = fd.define_vector([S32, S33, S34], dtype=DataType.Int)
    T36 = fd.ops.broadcast_in_dim(T31, shape=V35, broadcast_dims=[0, 1])
    S37 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T38 = fd.ops.mul(S37, T29)
    S39 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T40 = fd.ops.pow(T3, S39)
    T41 = fd.ops.mul(T38, T40)
    T42 = fd.ops.sum(T36, axes=[2], keepdim=False, dtype=DataType.Null)
    T43 = fd.ops.sum(T41, axes=[2], keepdim=False, dtype=DataType.Null)
    S44 = fd.define_scalar(16, dtype=DataType.Int)
    S45 = fd.define_scalar(512, dtype=DataType.Int)
    S46 = fd.define_scalar(1, dtype=DataType.Int)
    V47 = fd.define_vector([S44, S45, S46], dtype=DataType.Int)
    T48 = fd.ops.broadcast_in_dim(T43, shape=V47, broadcast_dims=[0, 1])
    S49 = fd.define_scalar(16, dtype=DataType.Int)
    S50 = fd.define_scalar(512, dtype=DataType.Int)
    S51 = fd.define_scalar(4096, dtype=DataType.Int)
    V52 = fd.define_vector([S49, S50, S51], dtype=DataType.Int)
    T53 = fd.ops.broadcast_in_dim(T48, shape=V52, broadcast_dims=[0, 1, 2])
    S54 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T55 = fd.ops.mul(S54, T53)
    T56 = fd.ops.sub(T7, T5)
    T57 = fd.ops.mul(T55, T56)
    S58 = fd.define_scalar(4096.00, dtype=DataType.Double)
    S59 = fd.ops.reciprocal(S58)
    T60 = fd.ops.mul(T57, S59)
    S61 = fd.define_scalar(16, dtype=DataType.Int)
    S62 = fd.define_scalar(512, dtype=DataType.Int)
    S63 = fd.define_scalar(1, dtype=DataType.Int)
    V64 = fd.define_vector([S61, S62, S63], dtype=DataType.Int)
    T65 = fd.ops.broadcast_in_dim(T42, shape=V64, broadcast_dims=[0, 1])
    S66 = fd.define_scalar(16, dtype=DataType.Int)
    S67 = fd.define_scalar(512, dtype=DataType.Int)
    S68 = fd.define_scalar(4096, dtype=DataType.Int)
    V69 = fd.define_vector([S66, S67, S68], dtype=DataType.Int)
    T70 = fd.ops.broadcast_in_dim(T65, shape=V69, broadcast_dims=[0, 1, 2])
    S71 = fd.define_scalar(0.000244140625, dtype=DataType.Double)
    T72 = fd.ops.mul(S71, T70)
    T73 = fd.ops.add(T60, T72)
    T74 = fd.ops.add(T22, T73)
    T75 = fd.ops.add(T2, T74)
    T76 = fd.ops.cast(T75, dtype=DataType.BFloat16)
    fd.add_output(T17)
    fd.add_output(T21)
    fd.add_output(T76)

with FusionDefinition() as fd:
    nvfuser_fusion_id14(fd)

inputs = [
        torch.randn(16, 512, 4096, device='cuda', dtype=torch.float32), # T0
        torch.randn(16, 512, 4096, device='cuda', dtype=torch.bfloat16).as_strided((16, 512, 4096), (0, 0, 1)), # T1
        torch.randn(16, 512, 4096, device='cuda', dtype=torch.float32), # T2
        torch.randn(16, 512, 1, device='cuda', dtype=torch.float32), # T3
        torch.randn(16, 512, 4096, device='cuda', dtype=torch.bfloat16), # T4
        torch.randn(16, 512, device='cuda', dtype=torch.float32).as_strided((16, 512, 4096), (512, 1, 0)), # T5
        torch.randn(16, 512, 4096, device='cuda', dtype=torch.bfloat16), # T6
        ]

out = fd.execute(inputs)

Perf on A100:
1-Kernel bias+dropout+add+layer_norm backward fusion

Fus#  NSegs CuEvtTm(ms) HstTm(ms) CmpTm(ms) KerTm(ms) EffBw(GB/s) %PeakBw   S-Seg# S-KerTm(ms) S-EffBw(GB/s) S-%PeakBw S-In(MB)  S-Out(MB) S-Smem[Dyn,Stat] S-Regs S-Grid           S-Block          S-KerName
    0     1    1431.770  1431.266  1312.344     0.537     1438.28     74.32      0       0.537       1438.28     74.32   570.491   201.351        [512, 16]    168      [1, 324, 1]       [4, 32, 1] nvfuser_inner_outer_persistent_f0_c1_r0_g0

3-Kernel layer_norm backward fusion

Fus#  NSegs CuEvtTm(ms) HstTm(ms) CmpTm(ms) KerTm(ms) EffBw(GB/s) %PeakBw   S-Seg# S-KerTm(ms) S-EffBw(GB/s) S-%PeakBw S-In(MB)  S-Out(MB) S-Smem[Dyn,Stat] S-Regs S-Grid           S-Block          S-KerName
    0     3    1336.294  1336.115   575.697     0.547      980.75     50.68      0       0.178       1131.51     58.46   201.359     0.033       [2048, 16]     56      [64, 15, 1]       [64, 8, 1] nvfuser_reduction_f0_c1_r0_g0
    -     -           -         -         -         -           -         -      1       0.163       2063.60    106.63   268.452    67.142        [2048, 0]     30     [8192, 1, 1]      [512, 1, 1] nvfuser_reduction_f0_c1_r0_g1
    -     -           -         -         -         -           -         -      2       0.207       1622.39     83.83   268.534    67.109       [6912, 16]     54    [32768, 1, 1]      [128, 1, 1] nvfuser_transpose_f0_c1_r0_g2
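
Summing the per-segment columns (my arithmetic, values copied from the tables above): total segment kernel time is nearly identical to the fused kernel's, but the segmented version moves roughly 13% more memory because the intermediates between segments are written out and re-read:

# Values copied from the S-columns of the two tables above.
seg_ker_ms = [0.178, 0.163, 0.207]
seg_in_mb  = [201.359, 268.452, 268.534]
seg_out_mb = [0.033, 67.142, 67.109]

print(sum(seg_ker_ms))                   # ~0.548 ms vs 0.537 ms for the fused kernel
print(sum(seg_in_mb) + sum(seg_out_mb))  # ~872.6 MB vs ~771.8 MB (570.5 in + 201.4 out) fused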
kevinstephano added the perf and Segmentation labels on Dec 7, 2023
kevinstephano commented:

I am not sure this is still relevant, so closing.
