Host benchmarking for a fusion with multiple segments #3307

Merged: Priya2698 merged 3 commits into main from pm/many_matmul_bench on Oct 30, 2024

Collaborator

@Priya2698 Priya2698 commented Oct 29, 2024

This benchmark chains matmul and pointwise ops to create a fusion with 12 segments, instead of using segment_set to force segmentation.
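
For readers without the screenshots, the data flow can be read off the segment dump posted later in this conversation. The following is a minimal pure-Python sketch of that flow, not the benchmark code itself: the helper names and `many_matmul_reference` are hypothetical, and treating `Set.Permute` as a 2-D transpose is an assumption.

```python
def add(a, b):
    """Elementwise sum of two equally shaped 2-D lists."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def transpose(a):
    """Swap the two axes of a 2-D list (stand-in for Set.Permute; assumed)."""
    return [list(col) for col in zip(*a)]


def matmul(a, b):
    """Plain matrix product of two 2-D lists."""
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]


def many_matmul_reference(a, b, num_matmuls=5):
    """Hypothetical eager reference mirroring the dumped expressions."""
    # T3 = permute(T0 + T1)
    x = transpose(add(a, b))
    for i in range(num_matmuls):
        # e.g. T6 = matmul(T3, T1) + (T3 + T1)
        x = add(matmul(x, b), add(x, b))
        # Every stage except the last ends in a permute (T7, T11, T15, T19).
        if i < num_matmuls - 1:
            x = transpose(x)
    return x


zeros = [[0, 0], [0, 0]]
eye = [[1, 0], [0, 1]]
result = many_matmul_reference(zeros, eye)
```

Each matmul is placed in its own expr_eval segment in the dump, which is what splits the surrounding pointwise ops into separate groups.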

Screenshot 2024-10-29 at 4 41 47 PM

For host_benchmark_mode='compile', the profile is shown below. The "Finding valid segment solutions" pass takes 52 ms.
Screenshot 2024-10-29 at 4 52 38 PM

@naoyam
Collaborator

naoyam commented Oct 30, 2024

Can we validate the fusion results in the intended number of segments?

@Priya2698
Collaborator Author

> Can we validate the fusion results in the intended number of segments?

Done. I mis-wrote that it is 11 segments; it is 12 segments with the permute operation.

group details:
g{(pointwise)
group id: 6
inputs:
  T0_g_float[ iS0{i0}, iS51{i0} ] float
  T1_g_float[ iS52{i0}, iS53{i0} ] float
outputs:
  T3_g_float[ iS55{i0}, iS6{i0} ] float


T2_l_float[ iS4{i0}, iS54{i0} ]
   = T0_g_float[ iS0{i0}, iS51{i0} ]
   + T1_g_float[ iS52{i0}, iS53{i0} ];
(0)
T3_g_float[ iS55{i0}, iS6{i0} ]
   = Set.Permute( T2_l_float[ iS4{i0}, iS54{i0} ], cache_op=Streaming )
(1)
}

g{(expr_eval)
group id: 0
inputs:
  T1_g_float[ iS52{i0}, iS53{i0} ] float
  T3_g_float[ iS55{i0}, iS6{i0} ] float
outputs:
  T4_g_float[ iS56{i0}, iS57{i0}, rS10{i0} ] float


T4_g_float[ iS56{i0}, iS57{i0}, rS10{i0} ]
   = matmul(T3_g_float[ iS55{i0}, iS6{i0} ],
            T1_g_float[ iS52{i0}, iS53{i0} ])
(2)
}

g{(pointwise)
group id: 8
inputs:
  T1_g_float[ iS52{i0}, iS53{i0} ] float
  T3_g_float[ iS55{i0}, iS6{i0} ] float
  T4_g_float[ iS56{i0}, iS57{i0}, rS10{i0} ] float
outputs:
  T7_g_float[ iS61{i0}, iS62{i0} ] float


T5_g_float[ iS58{i0}, iS12{i0} ]
   = T3_g_float[ iS55{i0}, iS6{i0} ]
   + T1_g_float[ iS52{i0}, iS53{i0} ];
(3)
T6_l_float[ iS59{i0}, iS60{i0} ]
   = T4_g_float[ iS56{i0}, iS57{i0}, rS10{i0} ]
   + T5_g_float[ iS58{i0}, iS12{i0} ];
(4)
T7_g_float[ iS61{i0}, iS62{i0} ]
   = Set.Permute( T6_l_float[ iS59{i0}, iS60{i0} ], cache_op=Streaming )
(5)
}

g{(expr_eval)
group id: 1
inputs:
  T1_g_float[ iS52{i0}, iS53{i0} ] float
  T7_g_float[ iS61{i0}, iS62{i0} ] float
outputs:
  T8_g_float[ iS63{i0}, iS64{i0}, rS65{i0} ] float


T8_g_float[ iS63{i0}, iS64{i0}, rS65{i0} ]
   = matmul(T7_g_float[ iS61{i0}, iS62{i0} ],
            T1_g_float[ iS52{i0}, iS53{i0} ])
(6)
}

g{(pointwise)
group id: 2
inputs:
  T1_g_float[ iS52{i0}, iS53{i0} ] float
  T7_g_float[ iS61{i0}, iS62{i0} ] float
outputs:
  T9_g_float[ iS66{i0}, iS67{i0} ] float


T9_g_float[ iS66{i0}, iS67{i0} ]
   = T7_g_float[ iS61{i0}, iS62{i0} ]
   + T1_g_float[ iS52{i0}, iS53{i0} ];
(7)
}

g{(pointwise)
group id: 9
inputs:
  T1_g_float[ iS52{i0}, iS53{i0} ] float
  T8_g_float[ iS63{i0}, iS64{i0}, rS65{i0} ] float
  T9_g_float[ iS66{i0}, iS67{i0} ] float
outputs:
  T11_g_float[ iS70{i0}, iS71{i0} ] float
  T13_g_float[ iS75{i0}, iS76{i0} ] float


T10_l_float[ iS68{i0}, iS69{i0} ]
   = T8_g_float[ iS63{i0}, iS64{i0}, rS65{i0} ]
   + T9_g_float[ iS66{i0}, iS67{i0} ];
(8)
T11_g_float[ iS70{i0}, iS71{i0} ]
   = Set.Permute( T10_l_float[ iS68{i0}, iS69{i0} ], cache_op=Streaming )
(9)
T13_g_float[ iS75{i0}, iS76{i0} ]
   = T11_g_float[ iS70{i0}, iS71{i0} ]
   + T1_g_float[ iS52{i0}, iS53{i0} ];
(11)
}

g{(expr_eval)
group id: 3
inputs:
  T1_g_float[ iS52{i0}, iS53{i0} ] float
  T11_g_float[ iS70{i0}, iS71{i0} ] float
outputs:
  T12_g_float[ iS72{i0}, iS73{i0}, rS74{i0} ] float


T12_g_float[ iS72{i0}, iS73{i0}, rS74{i0} ]
   = matmul(T11_g_float[ iS70{i0}, iS71{i0} ],
            T1_g_float[ iS52{i0}, iS53{i0} ])
(10)
}

g{(pointwise)
group id: 7
inputs:
  T12_g_float[ iS72{i0}, iS73{i0}, rS74{i0} ] float
  T13_g_float[ iS75{i0}, iS76{i0} ] float
outputs:
  T15_g_float[ iS79{i0}, iS80{i0} ] float


T14_l_float[ iS77{i0}, iS78{i0} ]
   = T12_g_float[ iS72{i0}, iS73{i0}, rS74{i0} ]
   + T13_g_float[ iS75{i0}, iS76{i0} ];
(12)
T15_g_float[ iS79{i0}, iS80{i0} ]
   = Set.Permute( T14_l_float[ iS77{i0}, iS78{i0} ], cache_op=Streaming )
(13)
}

g{(expr_eval)
group id: 4
inputs:
  T1_g_float[ iS52{i0}, iS53{i0} ] float
  T15_g_float[ iS79{i0}, iS80{i0} ] float
outputs:
  T16_g_float[ iS81{i0}, iS82{i0}, rS83{i0} ] float


T16_g_float[ iS81{i0}, iS82{i0}, rS83{i0} ]
   = matmul(T15_g_float[ iS79{i0}, iS80{i0} ],
            T1_g_float[ iS52{i0}, iS53{i0} ])
(14)
}

g{(pointwise)
group id: 10
inputs:
  T1_g_float[ iS52{i0}, iS53{i0} ] float
  T15_g_float[ iS79{i0}, iS80{i0} ] float
  T16_g_float[ iS81{i0}, iS82{i0}, rS83{i0} ] float
outputs:
  T19_g_float[ iS88{i0}, iS89{i0} ] float


T17_g_float[ iS84{i0}, iS85{i0} ]
   = T15_g_float[ iS79{i0}, iS80{i0} ]
   + T1_g_float[ iS52{i0}, iS53{i0} ];
(15)
T18_l_float[ iS86{i0}, iS87{i0} ]
   = T16_g_float[ iS81{i0}, iS82{i0}, rS83{i0} ]
   + T17_g_float[ iS84{i0}, iS85{i0} ];
(16)
T19_g_float[ iS88{i0}, iS89{i0} ]
   = Set.Permute( T18_l_float[ iS86{i0}, iS87{i0} ], cache_op=Streaming )
(17)
}

g{(expr_eval)
group id: 5
inputs:
  T1_g_float[ iS52{i0}, iS53{i0} ] float
  T19_g_float[ iS88{i0}, iS89{i0} ] float
outputs:
  T20_g_float[ iS90{i0}, iS91{i0}, rS92{i0} ] float


T20_g_float[ iS90{i0}, iS91{i0}, rS92{i0} ]
   = matmul(T19_g_float[ iS88{i0}, iS89{i0} ],
            T1_g_float[ iS52{i0}, iS53{i0} ])
(18)
}

g{(pointwise)
group id: 11
inputs:
  T1_g_float[ iS52{i0}, iS53{i0} ] float
  T19_g_float[ iS88{i0}, iS89{i0} ] float
  T20_g_float[ iS90{i0}, iS91{i0}, rS92{i0} ] float
outputs:
  T22_g_float[ iS95{i0}, iS96{i0} ] float


T21_g_float[ iS93{i0}, iS94{i0} ]
   = T19_g_float[ iS88{i0}, iS89{i0} ]
   + T1_g_float[ iS52{i0}, iS53{i0} ];
(19)
T22_g_float[ iS95{i0}, iS96{i0} ]
   = T20_g_float[ iS90{i0}, iS91{i0}, rS92{i0} ]
   + T21_g_float[ iS93{i0}, iS94{i0} ];
(20)
}

} //Segmented_Fusion
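
The group count can also be checked mechanically from a dump like the one above. This is a rough sketch that assumes the dump has been captured as a string and that every group opens with a `g{(<scheduler>)` line; `count_segments` is a hypothetical helper, not part of nvfuser.

```python
import re
from collections import Counter


def count_segments(dump: str) -> Counter:
    """Count segment groups per scheduler in a segmented-fusion text dump."""
    # Each group is assumed to open with a line such as "g{(pointwise)".
    return Counter(re.findall(r"^g\{\((\w+)\)", dump, flags=re.MULTILINE))


# Shortened sample in the same shape as the dump above.
sample = """g{(pointwise)
group id: 6
}

g{(expr_eval)
group id: 0
}

g{(pointwise)
group id: 8
}
"""

counts = count_segments(sample)
total = sum(counts.values())
```

Applied to the full dump above, this tally should come to 7 pointwise and 5 expr_eval groups, 12 in total.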

fd.validate(input, [eager_output])

# Validate number of segments
_ = fd.execute(input, profile=True)
Collaborator Author

Note to self:
Create an issue: allow fd.validate to accept nvfuser outputs, in which case they are compared directly with the reference outputs.

Collaborator

@naoyam naoyam left a comment

LGTM

@Priya2698
Collaborator Author

!build

@Priya2698 Priya2698 merged commit a4465df into main Oct 30, 2024
40 of 41 checks passed
@Priya2698 Priya2698 deleted the pm/many_matmul_bench branch October 30, 2024 19:20