I have two examples: a bias+dropout+add+layer_norm backward that fuses into one kernel as expected, and a layer_norm backward that segments into 3 kernels. The most notable difference between the two is the order of the inputs. Could that be an issue?
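For context, a minimal PyTorch-level sketch of the forward pattern whose backward the first repro corresponds to; the function name, argument names, and dropout probability are my own placeholders, not taken from the repro:

import torch.nn.functional as F

# Hypothetical forward composite: bias + dropout + residual add + layer_norm.
# The backward of this pattern is what the 1-kernel repro below captures.
def bias_dropout_add_layer_norm(x, bias, residual, ln_weight, ln_bias, p=0.1):
    y = F.dropout(x + bias, p=p)                                # bias + dropout
    y = y + residual                                            # residual add
    return F.layer_norm(y, (y.shape[-1],), ln_weight, ln_bias)  # layer_norm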
A repro of the 1-Kernel bias+dropout+add+layer_norm backward fusion:
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id9(fd : FusionDefinition) -> None :
    S0 = fd.define_scalar(None, dtype=DataType.Double)
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[None, None, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T6 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T7 = fd.define_tensor(shape=[-1, -1, 1], contiguity=[True, True, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T8 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T9 = fd.ops.cast(T4, dtype=DataType.Float)
    S10 = fd.define_scalar(16, dtype=DataType.Int)
    S11 = fd.define_scalar(512, dtype=DataType.Int)
    S12 = fd.define_scalar(1600, dtype=DataType.Int)
    V13 = fd.define_vector([S10, S11, S12], dtype=DataType.Int)
    T14 = fd.ops.broadcast_in_dim(T7, shape=V13, broadcast_dims=[0, 1, 2])
    T15 = fd.ops.mul(T8, T14)
    T16 = fd.ops.cast(T1, dtype=DataType.Float)
    T17 = fd.ops.cast(T3, dtype=DataType.Float)
    T18 = fd.ops.sum(T17, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    T19 = fd.ops.cast(T18, dtype=DataType.BFloat16)
    T20 = fd.ops.mul(T17, T16)
    T21 = fd.ops.mul(T17, T15)
    T22 = fd.ops.sum(T21, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    T23 = fd.ops.cast(T22, dtype=DataType.BFloat16)
    T24 = fd.ops.mul(T20, T14)
    T25 = fd.ops.mul(T20, T8)
    T26 = fd.ops.sum(T25, axes=[2], keepdim=False, dtype=DataType.Null)
    S27 = fd.define_scalar(16, dtype=DataType.Int)
    S28 = fd.define_scalar(512, dtype=DataType.Int)
    S29 = fd.define_scalar(1, dtype=DataType.Int)
    V30 = fd.define_vector([S27, S28, S29], dtype=DataType.Int)
    T31 = fd.ops.broadcast_in_dim(T26, shape=V30, broadcast_dims=[0, 1])
    T32 = fd.ops.neg(T24)
    T33 = fd.ops.sum(T32, axes=[2], keepdim=False, dtype=DataType.Null)
    S34 = fd.define_scalar(16, dtype=DataType.Int)
    S35 = fd.define_scalar(512, dtype=DataType.Int)
    S36 = fd.define_scalar(1, dtype=DataType.Int)
    V37 = fd.define_vector([S34, S35, S36], dtype=DataType.Int)
    T38 = fd.ops.broadcast_in_dim(T33, shape=V37, broadcast_dims=[0, 1])
    S39 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T40 = fd.ops.mul(S39, T31)
    S41 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T42 = fd.ops.pow(T7, S41)
    T43 = fd.ops.mul(T40, T42)
    T44 = fd.ops.sum(T38, axes=[2], keepdim=False, dtype=DataType.Null)
    T45 = fd.ops.sum(T43, axes=[2], keepdim=False, dtype=DataType.Null)
    S46 = fd.define_scalar(16, dtype=DataType.Int)
    S47 = fd.define_scalar(512, dtype=DataType.Int)
    S48 = fd.define_scalar(1, dtype=DataType.Int)
    V49 = fd.define_vector([S46, S47, S48], dtype=DataType.Int)
    T50 = fd.ops.broadcast_in_dim(T45, shape=V49, broadcast_dims=[0, 1])
    S51 = fd.define_scalar(16, dtype=DataType.Int)
    S52 = fd.define_scalar(512, dtype=DataType.Int)
    S53 = fd.define_scalar(1600, dtype=DataType.Int)
    V54 = fd.define_vector([S51, S52, S53], dtype=DataType.Int)
    T55 = fd.ops.broadcast_in_dim(T50, shape=V54, broadcast_dims=[0, 1, 2])
    S56 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T57 = fd.ops.mul(S56, T55)
    T58 = fd.ops.sub(T6, T5)
    T59 = fd.ops.mul(T57, T58)
    S60 = fd.define_scalar(1600.00, dtype=DataType.Double)
    S61 = fd.ops.reciprocal(S60)
    T62 = fd.ops.mul(T59, S61)
    S63 = fd.define_scalar(16, dtype=DataType.Int)
    S64 = fd.define_scalar(512, dtype=DataType.Int)
    S65 = fd.define_scalar(1, dtype=DataType.Int)
    V66 = fd.define_vector([S63, S64, S65], dtype=DataType.Int)
    T67 = fd.ops.broadcast_in_dim(T44, shape=V66, broadcast_dims=[0, 1])
    S68 = fd.define_scalar(16, dtype=DataType.Int)
    S69 = fd.define_scalar(512, dtype=DataType.Int)
    S70 = fd.define_scalar(1600, dtype=DataType.Int)
    V71 = fd.define_vector([S68, S69, S70], dtype=DataType.Int)
    T72 = fd.ops.broadcast_in_dim(T67, shape=V71, broadcast_dims=[0, 1, 2])
    S73 = fd.define_scalar(0.000625000, dtype=DataType.Double)
    T74 = fd.ops.mul(S73, T72)
    T75 = fd.ops.add(T62, T74)
    T76 = fd.ops.add(T24, T75)
    T77 = fd.ops.add(T2, T76)
    T78 = fd.ops.mul(T77, S0)
    T79 = fd.ops.mul(T78, T9)
    T80 = fd.ops.cast(T79, dtype=DataType.BFloat16)
    T81 = fd.ops.sum(T79, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    T82 = fd.ops.cast(T81, dtype=DataType.BFloat16)
    fd.add_output(T19)
    fd.add_output(T23)
    fd.add_output(T77)
    fd.add_output(T80)
    fd.add_output(T82)

with FusionDefinition() as fd:
    nvfuser_fusion_id9(fd)

inputs = [
    1.0,  # S0
    torch.randn(16, 512, 4096, device='cuda', dtype=torch.bfloat16).as_strided((16, 512, 4096), (0, 0, 1)),  # T1
    torch.randn(16, 512, 4096, device='cuda', dtype=torch.float32),  # T2
    torch.randn(16, 512, 4096, device='cuda', dtype=torch.bfloat16),  # T3
    torch.randn(16, 512, 4096, device='cuda', dtype=torch.float32) > 0,  # T4 Bool
    torch.randn(16, 512, device='cuda', dtype=torch.float32).as_strided((16, 512, 4096), (512, 1, 0)),  # T5
    torch.randn(16, 512, 4096, device='cuda', dtype=torch.float32),  # T6
    torch.randn(16, 512, 1, device='cuda', dtype=torch.float32),  # T7
    torch.randn(16, 512, 4096, device='cuda', dtype=torch.float32),  # T8
]
out = fd.execute(inputs)
A repro of the 3-Kernel layer_norm backward fusion:
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id14(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[None, None, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[-1, -1, 1], contiguity=[True, True, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T6 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T7 = fd.ops.cast(T6, dtype=DataType.Float)
    S8 = fd.define_scalar(16, dtype=DataType.Int)
    S9 = fd.define_scalar(512, dtype=DataType.Int)
    S10 = fd.define_scalar(1600, dtype=DataType.Int)
    V11 = fd.define_vector([S8, S9, S10], dtype=DataType.Int)
    T12 = fd.ops.broadcast_in_dim(T3, shape=V11, broadcast_dims=[0, 1, 2])
    T13 = fd.ops.mul(T0, T12)
    T14 = fd.ops.cast(T1, dtype=DataType.Float)
    T15 = fd.ops.cast(T4, dtype=DataType.Float)
    T16 = fd.ops.sum(T15, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    T17 = fd.ops.cast(T16, dtype=DataType.BFloat16)
    T18 = fd.ops.mul(T15, T14)
    T19 = fd.ops.mul(T15, T13)
    T20 = fd.ops.sum(T19, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    T21 = fd.ops.cast(T20, dtype=DataType.BFloat16)
    T22 = fd.ops.mul(T18, T12)
    T23 = fd.ops.mul(T18, T0)
    T24 = fd.ops.sum(T23, axes=[2], keepdim=False, dtype=DataType.Null)
    S25 = fd.define_scalar(16, dtype=DataType.Int)
    S26 = fd.define_scalar(512, dtype=DataType.Int)
    S27 = fd.define_scalar(1, dtype=DataType.Int)
    V28 = fd.define_vector([S25, S26, S27], dtype=DataType.Int)
    T29 = fd.ops.broadcast_in_dim(T24, shape=V28, broadcast_dims=[0, 1])
    T30 = fd.ops.neg(T22)
    T31 = fd.ops.sum(T30, axes=[2], keepdim=False, dtype=DataType.Null)
    S32 = fd.define_scalar(16, dtype=DataType.Int)
    S33 = fd.define_scalar(512, dtype=DataType.Int)
    S34 = fd.define_scalar(1, dtype=DataType.Int)
    V35 = fd.define_vector([S32, S33, S34], dtype=DataType.Int)
    T36 = fd.ops.broadcast_in_dim(T31, shape=V35, broadcast_dims=[0, 1])
    S37 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T38 = fd.ops.mul(S37, T29)
    S39 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T40 = fd.ops.pow(T3, S39)
    T41 = fd.ops.mul(T38, T40)
    T42 = fd.ops.sum(T36, axes=[2], keepdim=False, dtype=DataType.Null)
    T43 = fd.ops.sum(T41, axes=[2], keepdim=False, dtype=DataType.Null)
    S44 = fd.define_scalar(16, dtype=DataType.Int)
    S45 = fd.define_scalar(512, dtype=DataType.Int)
    S46 = fd.define_scalar(1, dtype=DataType.Int)
    V47 = fd.define_vector([S44, S45, S46], dtype=DataType.Int)
    T48 = fd.ops.broadcast_in_dim(T43, shape=V47, broadcast_dims=[0, 1])
    S49 = fd.define_scalar(16, dtype=DataType.Int)
    S50 = fd.define_scalar(512, dtype=DataType.Int)
    S51 = fd.define_scalar(1600, dtype=DataType.Int)
    V52 = fd.define_vector([S49, S50, S51], dtype=DataType.Int)
    T53 = fd.ops.broadcast_in_dim(T48, shape=V52, broadcast_dims=[0, 1, 2])
    S54 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T55 = fd.ops.mul(S54, T53)
    T56 = fd.ops.sub(T7, T5)
    T57 = fd.ops.mul(T55, T56)
    S58 = fd.define_scalar(1600.00, dtype=DataType.Double)
    S59 = fd.ops.reciprocal(S58)
    T60 = fd.ops.mul(T57, S59)
    S61 = fd.define_scalar(16, dtype=DataType.Int)
    S62 = fd.define_scalar(512, dtype=DataType.Int)
    S63 = fd.define_scalar(1, dtype=DataType.Int)
    V64 = fd.define_vector([S61, S62, S63], dtype=DataType.Int)
    T65 = fd.ops.broadcast_in_dim(T42, shape=V64, broadcast_dims=[0, 1])
    S66 = fd.define_scalar(16, dtype=DataType.Int)
    S67 = fd.define_scalar(512, dtype=DataType.Int)
    S68 = fd.define_scalar(1600, dtype=DataType.Int)
    V69 = fd.define_vector([S66, S67, S68], dtype=DataType.Int)
    T70 = fd.ops.broadcast_in_dim(T65, shape=V69, broadcast_dims=[0, 1, 2])
    S71 = fd.define_scalar(0.000625000, dtype=DataType.Double)
    T72 = fd.ops.mul(S71, T70)
    T73 = fd.ops.add(T60, T72)
    T74 = fd.ops.add(T22, T73)
    T75 = fd.ops.add(T2, T74)
    T76 = fd.ops.cast(T75, dtype=DataType.BFloat16)
    fd.add_output(T17)
    fd.add_output(T21)
    fd.add_output(T76)

with FusionDefinition() as fd:
    nvfuser_fusion_id14(fd)

inputs = [
    torch.randn(16, 512, 4096, device='cuda', dtype=torch.float32),  # T0
    torch.randn(16, 512, 4096, device='cuda', dtype=torch.bfloat16).as_strided((16, 512, 4096), (0, 0, 1)),  # T1
    torch.randn(16, 512, 4096, device='cuda', dtype=torch.float32),  # T2
    torch.randn(16, 512, 1, device='cuda', dtype=torch.float32),  # T3
    torch.randn(16, 512, 4096, device='cuda', dtype=torch.bfloat16),  # T4
    torch.randn(16, 512, device='cuda', dtype=torch.float32).as_strided((16, 512, 4096), (512, 1, 0)),  # T5
    torch.randn(16, 512, 4096, device='cuda', dtype=torch.bfloat16),  # T6
]
out = fd.execute(inputs)
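For a rough wall-clock comparison of the two repros, one can wrap fd.execute in plain CUDA events; this is only my own sketch using the standard torch.cuda.Event API, not the nvFuser profiler that produced the tables below:

import torch

def time_fusion(fd, inputs, iters=100):
    # warm up so compilation time is excluded from the measurement
    for _ in range(3):
        fd.execute(inputs)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fd.execute(inputs)
    end.record()
    torch.cuda.synchronize()
    print(f"avg time per execute: {start.elapsed_time(end) / iters:.3f} ms")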
Perf on A100: 1-Kernel bias+dropout+add+layer_norm backward fusion
Fus#  NSegs  CuEvtTm(ms)  HstTm(ms)  CmpTm(ms)  KerTm(ms)  EffBw(GB/s)  %PeakBw  S-Seg#  S-KerTm(ms)  S-EffBw(GB/s)  S-%PeakBw  S-In(MB)  S-Out(MB)  S-Smem[Dyn,Stat]  S-Regs  S-Grid  S-Block  S-KerName
0  1  1431.770  1431.266  1312.344  0.537  1438.28  74.32  0  0.537  1438.28  74.32  570.491  201.351  [512, 16]  168  [1, 324, 1]  [4, 32, 1]  nvfuser_inner_outer_persistent_f0_c1_r0_g0
3-Kernel layer_norm backward fusion
Fus#  NSegs  CuEvtTm(ms)  HstTm(ms)  CmpTm(ms)  KerTm(ms)  EffBw(GB/s)  %PeakBw  S-Seg#  S-KerTm(ms)  S-EffBw(GB/s)  S-%PeakBw  S-In(MB)  S-Out(MB)  S-Smem[Dyn,Stat]  S-Regs  S-Grid  S-Block  S-KerName
0  3  1336.294  1336.115  575.697  0.547  980.75  50.68  0  0.178  1131.51  58.46  201.359  0.033  [2048, 16]  56  [64, 15, 1]  [64, 8, 1]  nvfuser_reduction_f0_c1_r0_g0
-  -  -  -  -  -  -  -  1  0.163  2063.60  106.63  268.452  67.142  [2048, 0]  30  [8192, 1, 1]  [512, 1, 1]  nvfuser_reduction_f0_c1_r0_g1
-  -  -  -  -  -  -  -  2  0.207  1622.39  83.83  268.534  67.109  [6912, 16]  54  [32768, 1, 1]  [128, 1, 1]  nvfuser_transpose_f0_c1_r0_g2
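Reading the two tables together: the three segments spend 0.178 + 0.163 + 0.207 ≈ 0.548 ms in kernels versus 0.537 ms for the single fused kernel, and the aggregate effective bandwidth drops from 1438.28 GB/s (74.32% of peak) to 980.75 GB/s (50.68% of peak), presumably because the segmented fusion has to materialize and re-read intermediates between kernels.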
I am not sure this is still relevant, so closing.