Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Account for additional local memory requirements in Global dispatch #127

Merged
merged 9 commits into from
Jan 5, 2024
2 changes: 1 addition & 1 deletion src/portfft/common/global.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ std::vector<sycl::event> compute_level(
return 1;
}
if (kd_struct.level == detail::level::SUBGROUP) {
return detail::pad_local(2 * kd_struct.length * static_cast<std::size_t>(local_range / 2), std::size_t(1));
return kd_struct.local_mem_required;
hjabird marked this conversation as resolved.
Show resolved Hide resolved
}
}
return std::size_t(1);
Expand Down
2 changes: 2 additions & 0 deletions src/portfft/descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,9 @@ class committed_descriptor {
}
std::vector<std::tuple<detail::level, std::vector<sycl::kernel_id>, std::vector<Idx>>> param_vec;
auto check_and_select_target_level = [&](IdxGlobal factor_size, bool batch_interleaved_layout = true) -> bool {
constexpr std::size_t LocalRange = PORTFFT_SGS_IN_WG * SubgroupSize;
if (detail::fits_in_wi<Scalar>(factor_size)) {
// Throughout we have assumed there would always be enough local memory for the WI implementation.
param_vec.emplace_back(detail::level::WORKITEM,
detail::get_ids<detail::global_kernel, Scalar, Domain, SubgroupSize>(),
std::vector<Idx>{static_cast<Idx>(factor_size)});
Expand Down
38 changes: 21 additions & 17 deletions src/portfft/dispatcher/global_dispatcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,21 @@ namespace detail {
* @param level The implementation for the factor
* @param n_compute_units compute_units available
* @param subgroup_size Subgroup size chosen
* @param n_sgs_in_wg Number of subgroups in a workgroup.
* @return std::pair containing global and local range
*/
inline std::pair<IdxGlobal, IdxGlobal> get_launch_params(IdxGlobal fft_size, IdxGlobal num_batches, detail::level level,
Idx n_compute_units, Idx subgroup_size) {
Idx n_compute_units, Idx subgroup_size, Idx n_sgs_in_wg) {
IdxGlobal n_available_sgs = 8 * n_compute_units * 64;
IdxGlobal wg_size = static_cast<IdxGlobal>(PORTFFT_SGS_IN_WG * subgroup_size);
IdxGlobal wg_size = static_cast<IdxGlobal>(n_sgs_in_wg * subgroup_size);
if (level == detail::level::WORKITEM) {
IdxGlobal n_ffts_per_wg = static_cast<IdxGlobal>(PORTFFT_SGS_IN_WG * subgroup_size);
IdxGlobal n_ffts_per_wg = wg_size;
IdxGlobal n_wgs_required = divide_ceil(num_batches, n_ffts_per_wg);
return std::make_pair(std::min(n_wgs_required * wg_size, n_available_sgs), wg_size);
}
if (level == detail::level::SUBGROUP) {
IdxGlobal n_ffts_per_sg = static_cast<IdxGlobal>(subgroup_size) / detail::factorize_sg(fft_size, subgroup_size);
IdxGlobal n_ffts_per_wg = n_ffts_per_sg * PORTFFT_SGS_IN_WG * subgroup_size;
IdxGlobal n_ffts_per_wg = n_ffts_per_sg * n_sgs_in_wg;
hjabird marked this conversation as resolved.
Show resolved Hide resolved
IdxGlobal n_wgs_required = divide_ceil(num_batches, n_ffts_per_wg);
return std::make_pair(std::min(n_wgs_required * wg_size, n_available_sgs), wg_size);
}
Expand Down Expand Up @@ -207,35 +208,38 @@ struct committed_descriptor<Scalar, Domain>::calculate_twiddles_struct::inner<de
kernel_data.length = static_cast<std::size_t>(factors_idx_global.at(counter));
if (kernel_data.level == detail::level::WORKITEM) {
// See comments in workitem_dispatcher for layout requirements.
auto [global_range, local_range] =
detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::WORKITEM,
desc.n_compute_units, kernel_data.used_sg_size);
kernel_data.global_range = global_range;
kernel_data.local_range = local_range;
Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG;
if (counter < kernels.size() - 1) {
kernel_data.local_mem_required = static_cast<std::size_t>(1);
} else {
kernel_data.local_mem_required = 2 * static_cast<std::size_t>(local_range * factors_idx_global.at(counter));
kernel_data.local_mem_required = desc.num_scalars_in_local_mem<detail::layout::PACKED>(
detail::level::WORKITEM, static_cast<std::size_t>(factors_idx_global.at(counter)),
kernel_data.used_sg_size, {static_cast<Idx>(factors_idx_global.at(counter))}, num_sgs_in_wg);
}
} else if (kernel_data.level == detail::level::SUBGROUP) {
// See comments in subgroup_dispatcher for layout requirements.
auto [global_range, local_range] =
detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::SUBGROUP,
desc.n_compute_units, kernel_data.used_sg_size);
detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::WORKITEM,
desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg);
kernel_data.global_range = global_range;
kernel_data.local_range = local_range;
} else if (kernel_data.level == detail::level::SUBGROUP) {
Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG;
// See comments in subgroup_dispatcher for layout requirements.
IdxGlobal factor_sg = detail::factorize_sg(factors_idx_global.at(counter), kernel_data.used_sg_size);
IdxGlobal factor_wi = factors_idx_global.at(counter) / factor_sg;
Idx tmp;
if (counter < kernels.size() - 1) {
kernel_data.local_mem_required = desc.num_scalars_in_local_mem<detail::layout::BATCH_INTERLEAVED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factors_idx_global.at(counter)),
kernel_data.used_sg_size, {static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, tmp);
kernel_data.used_sg_size, {static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, num_sgs_in_wg);
} else {
kernel_data.local_mem_required = desc.num_scalars_in_local_mem<detail::layout::PACKED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factors_idx_global.at(counter)),
kernel_data.used_sg_size, {static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, tmp);
kernel_data.used_sg_size, {static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, num_sgs_in_wg);
}
auto [global_range, local_range] =
detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::SUBGROUP,
desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg);
kernel_data.global_range = global_range;
kernel_data.local_range = local_range;
}
counter++;
}
Expand Down
2 changes: 1 addition & 1 deletion test/unit_test/instantiate_fft_tests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ INSTANTIATE_TEST_SUITE_P(GlobalTest, FFTTest,
INSTANTIATE_TEST_SUITE_P(WorkgroupOrGlobalRegressionTest, FFTTest,
::testing::ConvertGenerator<basic_param_tuple>(
::testing::Combine(ip_packed_layout, fwd_only, interleaved_storage, ::testing::Values(3),
::testing::Values(sizes_t{9800}))),
::testing::Values(sizes_t{9800}, sizes_t{15360}))),
test_params_print());

// Backward FFT test suite
Expand Down