Skip to content

Commit

Permalink
Account for additional local memory requirements in Global dispatch (#127)
Browse files Browse the repository at this point in the history

* fix mismatch between local memory estimation and allocation.

---------

Co-authored-by: atharva.dubey <[email protected]>
  • Loading branch information
hjabird and AD2605 authored Jan 5, 2024
1 parent c71814f commit e429002
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/portfft/common/global.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ std::vector<sycl::event> compute_level(
return 1;
}
if (kd_struct.level == detail::level::SUBGROUP) {
return detail::pad_local(2 * kd_struct.length * static_cast<std::size_t>(local_range / 2), std::size_t(1));
return kd_struct.local_mem_required;
}
}
return std::size_t(1);
Expand Down
1 change: 1 addition & 0 deletions src/portfft/descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ class committed_descriptor {
std::vector<std::tuple<detail::level, std::vector<sycl::kernel_id>, std::vector<Idx>>> param_vec;
auto check_and_select_target_level = [&](IdxGlobal factor_size, bool batch_interleaved_layout = true) -> bool {
if (detail::fits_in_wi<Scalar>(factor_size)) {
// Throughout we have assumed there would always be enough local memory for the WI implementation.
param_vec.emplace_back(detail::level::WORKITEM,
detail::get_ids<detail::global_kernel, Scalar, Domain, SubgroupSize>(),
std::vector<Idx>{static_cast<Idx>(factor_size)});
Expand Down
38 changes: 21 additions & 17 deletions src/portfft/dispatcher/global_dispatcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,21 @@ namespace detail {
* @param level The implementation for the factor
* @param n_compute_units Number of compute units available
* @param subgroup_size Subgroup size chosen
* @param n_sgs_in_wg Number of subgroups in a workgroup.
* @return std::pair containing global and local range
*/
inline std::pair<IdxGlobal, IdxGlobal> get_launch_params(IdxGlobal fft_size, IdxGlobal num_batches, detail::level level,
Idx n_compute_units, Idx subgroup_size) {
Idx n_compute_units, Idx subgroup_size, Idx n_sgs_in_wg) {
IdxGlobal n_available_sgs = 8 * n_compute_units * 64;
IdxGlobal wg_size = static_cast<IdxGlobal>(PORTFFT_SGS_IN_WG * subgroup_size);
IdxGlobal wg_size = n_sgs_in_wg * subgroup_size;
if (level == detail::level::WORKITEM) {
IdxGlobal n_ffts_per_wg = static_cast<IdxGlobal>(PORTFFT_SGS_IN_WG * subgroup_size);
IdxGlobal n_ffts_per_wg = wg_size;
IdxGlobal n_wgs_required = divide_ceil(num_batches, n_ffts_per_wg);
return std::make_pair(std::min(n_wgs_required * wg_size, n_available_sgs), wg_size);
}
if (level == detail::level::SUBGROUP) {
IdxGlobal n_ffts_per_sg = static_cast<IdxGlobal>(subgroup_size) / detail::factorize_sg(fft_size, subgroup_size);
IdxGlobal n_ffts_per_wg = n_ffts_per_sg * PORTFFT_SGS_IN_WG * subgroup_size;
IdxGlobal n_ffts_per_wg = n_ffts_per_sg * n_sgs_in_wg;
IdxGlobal n_wgs_required = divide_ceil(num_batches, n_ffts_per_wg);
return std::make_pair(std::min(n_wgs_required * wg_size, n_available_sgs), wg_size);
}
Expand Down Expand Up @@ -207,35 +208,38 @@ struct committed_descriptor<Scalar, Domain>::calculate_twiddles_struct::inner<de
kernel_data.length = static_cast<std::size_t>(factors_idx_global.at(counter));
if (kernel_data.level == detail::level::WORKITEM) {
// See comments in workitem_dispatcher for layout requirements.
auto [global_range, local_range] =
detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::WORKITEM,
desc.n_compute_units, kernel_data.used_sg_size);
kernel_data.global_range = global_range;
kernel_data.local_range = local_range;
Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG;
if (counter < kernels.size() - 1) {
kernel_data.local_mem_required = static_cast<std::size_t>(1);
} else {
kernel_data.local_mem_required = 2 * static_cast<std::size_t>(local_range * factors_idx_global.at(counter));
kernel_data.local_mem_required = desc.num_scalars_in_local_mem<detail::layout::PACKED>(
detail::level::WORKITEM, static_cast<std::size_t>(factors_idx_global.at(counter)),
kernel_data.used_sg_size, {static_cast<Idx>(factors_idx_global.at(counter))}, num_sgs_in_wg);
}
} else if (kernel_data.level == detail::level::SUBGROUP) {
// See comments in subgroup_dispatcher for layout requirements.
auto [global_range, local_range] =
detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::SUBGROUP,
desc.n_compute_units, kernel_data.used_sg_size);
detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::WORKITEM,
desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg);
kernel_data.global_range = global_range;
kernel_data.local_range = local_range;
} else if (kernel_data.level == detail::level::SUBGROUP) {
Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG;
// See comments in subgroup_dispatcher for layout requirements.
IdxGlobal factor_sg = detail::factorize_sg(factors_idx_global.at(counter), kernel_data.used_sg_size);
IdxGlobal factor_wi = factors_idx_global.at(counter) / factor_sg;
Idx tmp;
if (counter < kernels.size() - 1) {
kernel_data.local_mem_required = desc.num_scalars_in_local_mem<detail::layout::BATCH_INTERLEAVED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factors_idx_global.at(counter)),
kernel_data.used_sg_size, {static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, tmp);
kernel_data.used_sg_size, {static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, num_sgs_in_wg);
} else {
kernel_data.local_mem_required = desc.num_scalars_in_local_mem<detail::layout::PACKED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factors_idx_global.at(counter)),
kernel_data.used_sg_size, {static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, tmp);
kernel_data.used_sg_size, {static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, num_sgs_in_wg);
}
auto [global_range, local_range] =
detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::SUBGROUP,
desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg);
kernel_data.global_range = global_range;
kernel_data.local_range = local_range;
}
counter++;
}
Expand Down
2 changes: 1 addition & 1 deletion test/unit_test/instantiate_fft_tests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ INSTANTIATE_TEST_SUITE_P(GlobalTest, FFTTest,
INSTANTIATE_TEST_SUITE_P(WorkgroupOrGlobalRegressionTest, FFTTest,
::testing::ConvertGenerator<basic_param_tuple>(
::testing::Combine(ip_packed_layout, fwd_only, interleaved_storage, ::testing::Values(3),
::testing::Values(sizes_t{9800}))),
::testing::Values(sizes_t{9800}, sizes_t{15360}))),
test_params_print());

// Backward FFT test suite
Expand Down

0 comments on commit e429002

Please sign in to comment.