diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index bcc7aef5..fdd0c36b 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -474,7 +474,7 @@ std::vector compute_level( return 1; } if (kd_struct.level == detail::level::SUBGROUP) { - return detail::pad_local(2 * kd_struct.length * static_cast(local_range / 2), std::size_t(1)); + return kd_struct.local_mem_required; } } return std::size_t(1); diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 61b9b62d..949965bb 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -368,6 +368,7 @@ class committed_descriptor { std::vector, std::vector>> param_vec; auto check_and_select_target_level = [&](IdxGlobal factor_size, bool batch_interleaved_layout = true) -> bool { if (detail::fits_in_wi(factor_size)) { + // Throughout we have assumed there would always be enough local memory for the WI implementation. param_vec.emplace_back(detail::level::WORKITEM, detail::get_ids(), std::vector{static_cast(factor_size)}); diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index fa43ff71..8c3afc8c 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -41,20 +41,21 @@ namespace detail { * @param level The implementation for the factor * @param n_compute_units compute_units available * @param subgroup_size Subgroup size chosen + * @param n_sgs_in_wg Number of subgroups in a workgroup. * @return std::pair containing global and local range */ inline std::pair get_launch_params(IdxGlobal fft_size, IdxGlobal num_batches, detail::level level, - Idx n_compute_units, Idx subgroup_size) { + Idx n_compute_units, Idx subgroup_size, Idx n_sgs_in_wg) { IdxGlobal n_available_sgs = 8 * n_compute_units * 64; - IdxGlobal wg_size = static_cast(PORTFFT_SGS_IN_WG * subgroup_size); + IdxGlobal wg_size = n_sgs_in_wg * subgroup_size; if (level == detail::level::WORKITEM) { - IdxGlobal n_ffts_per_wg = static_cast(PORTFFT_SGS_IN_WG * subgroup_size); + IdxGlobal n_ffts_per_wg = wg_size; IdxGlobal n_wgs_required = divide_ceil(num_batches, n_ffts_per_wg); return std::make_pair(std::min(n_wgs_required * wg_size, n_available_sgs), wg_size); } if (level == detail::level::SUBGROUP) { IdxGlobal n_ffts_per_sg = static_cast(subgroup_size) / detail::factorize_sg(fft_size, subgroup_size); - IdxGlobal n_ffts_per_wg = n_ffts_per_sg * PORTFFT_SGS_IN_WG * subgroup_size; + IdxGlobal n_ffts_per_wg = n_ffts_per_sg * n_sgs_in_wg; IdxGlobal n_wgs_required = divide_ceil(num_batches, n_ffts_per_wg); return std::make_pair(std::min(n_wgs_required * wg_size, n_available_sgs), wg_size); } @@ -207,35 +208,38 @@ struct committed_descriptor::calculate_twiddles_struct::inner(factors_idx_global.at(counter)); if (kernel_data.level == detail::level::WORKITEM) { // See comments in workitem_dispatcher for layout requirments. - auto [global_range, local_range] = - detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::WORKITEM, - desc.n_compute_units, kernel_data.used_sg_size); - kernel_data.global_range = global_range; - kernel_data.local_range = local_range; + Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG; if (counter < kernels.size() - 1) { kernel_data.local_mem_required = static_cast(1); } else { - kernel_data.local_mem_required = 2 * static_cast(local_range * factors_idx_global.at(counter)); + kernel_data.local_mem_required = desc.num_scalars_in_local_mem( + detail::level::WORKITEM, static_cast(factors_idx_global.at(counter)), + kernel_data.used_sg_size, {static_cast(factors_idx_global.at(counter))}, num_sgs_in_wg); } - } else if (kernel_data.level == detail::level::SUBGROUP) { - // See comments in subgroup_dispatcher for layout requirements. auto [global_range, local_range] = - detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::SUBGROUP, - desc.n_compute_units, kernel_data.used_sg_size); + detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::WORKITEM, + desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg); kernel_data.global_range = global_range; kernel_data.local_range = local_range; + } else if (kernel_data.level == detail::level::SUBGROUP) { + Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG; + // See comments in subgroup_dispatcher for layout requirements. IdxGlobal factor_sg = detail::factorize_sg(factors_idx_global.at(counter), kernel_data.used_sg_size); IdxGlobal factor_wi = factors_idx_global.at(counter) / factor_sg; - Idx tmp; if (counter < kernels.size() - 1) { kernel_data.local_mem_required = desc.num_scalars_in_local_mem( detail::level::SUBGROUP, static_cast(factors_idx_global.at(counter)), - kernel_data.used_sg_size, {static_cast(factor_sg), static_cast(factor_wi)}, tmp); + kernel_data.used_sg_size, {static_cast(factor_sg), static_cast(factor_wi)}, num_sgs_in_wg); } else { kernel_data.local_mem_required = desc.num_scalars_in_local_mem( detail::level::SUBGROUP, static_cast(factors_idx_global.at(counter)), - kernel_data.used_sg_size, {static_cast(factor_sg), static_cast(factor_wi)}, tmp); + kernel_data.used_sg_size, {static_cast(factor_sg), static_cast(factor_wi)}, num_sgs_in_wg); } + auto [global_range, local_range] = + detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::SUBGROUP, + desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg); + kernel_data.global_range = global_range; + kernel_data.local_range = local_range; } counter++; } diff --git a/test/unit_test/instantiate_fft_tests.hpp b/test/unit_test/instantiate_fft_tests.hpp index eab1b72a..7fdf25d5 100644 --- a/test/unit_test/instantiate_fft_tests.hpp +++ b/test/unit_test/instantiate_fft_tests.hpp @@ -128,7 +128,7 @@ INSTANTIATE_TEST_SUITE_P(GlobalTest, FFTTest, INSTANTIATE_TEST_SUITE_P(WorkgroupOrGlobalRegressionTest, FFTTest, ::testing::ConvertGenerator( ::testing::Combine(ip_packed_layout, fwd_only, interleaved_storage, ::testing::Values(3), - ::testing::Values(sizes_t{9800}))), + ::testing::Values(sizes_t{9800}, sizes_t{15360}))), test_params_print()); // Backward FFT test suite