Skip to content

Commit

Permalink
Account for additional local memory requirements in Global dispatch
Browse files Browse the repository at this point in the history
* Not accounting for all local memory requirements in dispatch causes
errors on A100.
* This PR accounts for the additional required memory.
* And also adds a regression test.
* This is not an ideal fix for the issue - the memory requirements
depend on the ordering of the sub-impls. This information is not
available at dispatch time. A more significant refactor is required.
  • Loading branch information
hjabird committed Jan 3, 2024
1 parent c28df72 commit aec8049
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/portfft/descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,12 @@ class committed_descriptor {
}
std::vector<std::tuple<detail::level, std::vector<sycl::kernel_id>, std::vector<Idx>>> param_vec;
auto check_and_select_target_level = [&](IdxGlobal factor_size, bool batch_interleaved_layout = true) -> bool {
if (detail::fits_in_wi<Scalar>(factor_size)) {
constexpr std::size_t LocalRange = PORTFFT_SGS_IN_WG * SubgroupSize;
if (detail::fits_in_wi<Scalar>(factor_size) &&
static_cast<std::size_t>(local_memory_size) >
sizeof(Scalar) * 2 * LocalRange * static_cast<std::size_t>(factor_size)) {
// The local memory requirement (LocalRange * factor_size) is only required for the final factor.
// There is no way to know if this is the last factor, so it is always used.
param_vec.emplace_back(detail::level::WORKITEM,
detail::get_ids<detail::global_kernel, Scalar, Domain, SubgroupSize>(),
std::vector<Idx>{static_cast<Idx>(factor_size)});
Expand All @@ -385,14 +390,14 @@ class committed_descriptor {
detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
{static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg) *
sizeof(Scalar) +
2 * static_cast<std::size_t>(factor_size) * sizeof(Scalar)) <
2 * static_cast<std::size_t>(factor_size) * (LocalRange / 2) * sizeof(Scalar)) <
static_cast<std::size_t>(local_memory_size);
}
return (num_scalars_in_local_mem<detail::layout::PACKED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
{static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg) *
sizeof(Scalar) +
2 * static_cast<std::size_t>(factor_size) * sizeof(Scalar)) <
2 * static_cast<std::size_t>(factor_size) * (LocalRange / 2) * sizeof(Scalar)) <
static_cast<std::size_t>(local_memory_size);
}
return false;
Expand Down
7 changes: 7 additions & 0 deletions test/unit_test/instantiate_fft_tests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ INSTANTIATE_TEST_SUITE_P(WorkgroupOrGlobal, FFTTest,
::testing::Values(1, 128), ::testing::Values(sizes_t{8192}, sizes_t{16384}))),
test_params_print());

// Selected individual tests in workgroup or global size range
INSTANTIATE_TEST_SUITE_P(WorkgroupOrGlobalRegressionTest, FFTTest,
::testing::ConvertGenerator<basic_param_tuple>(
::testing::Combine(ip_packed_layout, fwd_only, interleaved_storage, ::testing::Values(3),
::testing::Values(sizes_t{15360}))),
test_params_print());

// Sizes that use the global implementations
INSTANTIATE_TEST_SUITE_P(GlobalTest, FFTTest,
::testing::ConvertGenerator<basic_param_tuple>(::testing::Combine(
Expand Down

0 comments on commit aec8049

Please sign in to comment.