Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Account for additional local memory requirements in Global dispatch #127

Merged
merged 9 commits into from
Jan 5, 2024
29 changes: 17 additions & 12 deletions src/portfft/descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,12 @@ class committed_descriptor {
}
std::vector<std::tuple<detail::level, std::vector<sycl::kernel_id>, std::vector<Idx>>> param_vec;
auto check_and_select_target_level = [&](IdxGlobal factor_size, bool batch_interleaved_layout = true) -> bool {
if (detail::fits_in_wi<Scalar>(factor_size)) {
constexpr std::size_t LocalRange = PORTFFT_SGS_IN_WG * SubgroupSize;
if (detail::fits_in_wi<Scalar>(factor_size) &&
static_cast<std::size_t>(local_memory_size) >
sizeof(Scalar) * 2 * LocalRange * static_cast<std::size_t>(factor_size)) {
// The local memory requirement (LocalRange * factor_size) is only required for the final factor.
// There is no way to know if this is the last factor, so it is always used.
Rbiessy marked this conversation as resolved.
Show resolved Hide resolved
param_vec.emplace_back(detail::level::WORKITEM,
detail::get_ids<detail::global_kernel, Scalar, Domain, SubgroupSize>(),
std::vector<Idx>{static_cast<Idx>(factor_size)});
Expand All @@ -380,19 +385,19 @@ class committed_descriptor {
IdxGlobal factor_wi = factor_size / factor_sg;
if (detail::can_cast_safely<IdxGlobal, Idx>(factor_sg) && detail::can_cast_safely<IdxGlobal, Idx>(factor_wi)) {
if (batch_interleaved_layout) {
return (2 *
num_scalars_in_local_mem<detail::layout::BATCH_INTERLEAVED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
{static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg) *
sizeof(Scalar) +
2 * static_cast<std::size_t>(factor_size) * sizeof(Scalar)) <
// (local_memory_for_input + local_mem_for_store_modifier + local_mem_for_twiddles) < local_mem
return (sizeof(Scalar) *
(num_scalars_in_local_mem<detail::layout::BATCH_INTERLEAVED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
{static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg) +
static_cast<std::size_t>(factor_size) * LocalRange + 2 * static_cast<std::size_t>(factor_size))) <
AD2605 marked this conversation as resolved.
Show resolved Hide resolved
static_cast<std::size_t>(local_memory_size);
}
return (num_scalars_in_local_mem<detail::layout::PACKED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
{static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg) *
sizeof(Scalar) +
2 * static_cast<std::size_t>(factor_size) * sizeof(Scalar)) <
return (sizeof(Scalar) *
(num_scalars_in_local_mem<detail::layout::PACKED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
{static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg) +
static_cast<std::size_t>(factor_size) * LocalRange + 2 * static_cast<std::size_t>(factor_size))) <
static_cast<std::size_t>(local_memory_size);
}
return false;
Expand Down
7 changes: 7 additions & 0 deletions test/unit_test/instantiate_fft_tests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ INSTANTIATE_TEST_SUITE_P(WorkgroupOrGlobal, FFTTest,
::testing::Values(1, 128), ::testing::Values(sizes_t{8192}, sizes_t{16384}))),
test_params_print());

// Selected individual tests in workgroup or global size range
AD2605 marked this conversation as resolved.
Show resolved Hide resolved
INSTANTIATE_TEST_SUITE_P(WorkgroupOrGlobalRegressionTest, FFTTest,
::testing::ConvertGenerator<basic_param_tuple>(
::testing::Combine(ip_packed_layout, fwd_only, interleaved_storage, ::testing::Values(3),
::testing::Values(sizes_t{15360}))),
test_params_print());

// Sizes that use the global implementations
INSTANTIATE_TEST_SUITE_P(GlobalTest, FFTTest,
::testing::ConvertGenerator<basic_param_tuple>(::testing::Combine(
Expand Down