From 8ecffbce8282740f35d39be8d6a4d1c092c1533b Mon Sep 17 00:00:00 2001 From: Hugh Bird Date: Wed, 3 Jan 2024 16:41:29 +0000 Subject: [PATCH] Adjust Global/Subgroup required mem calc --- src/portfft/descriptor.hpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 164bb331..cea94f56 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -385,19 +385,19 @@ class committed_descriptor { IdxGlobal factor_wi = factor_size / factor_sg; if (detail::can_cast_safely(factor_sg) && detail::can_cast_safely(factor_wi)) { if (batch_interleaved_layout) { - return (2 * - num_scalars_in_local_mem( - detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, - {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg) * - sizeof(Scalar) + - 2 * static_cast(factor_size) * (LocalRange / 2) * sizeof(Scalar)) < + // (local_memory_for_input + local_mem_for_store_modifier + local_mem_for_twiddles) < local_mem + return (sizeof(Scalar) * + (num_scalars_in_local_mem( + detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, + {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg) + + static_cast(factor_size) * LocalRange + 2 * static_cast(factor_size))) < static_cast(local_memory_size); } - return (num_scalars_in_local_mem( - detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, - {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg) * - sizeof(Scalar) + - 2 * static_cast(factor_size) * (LocalRange / 2) * sizeof(Scalar)) < + return (sizeof(Scalar) * + (num_scalars_in_local_mem( + detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, + {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg) + + static_cast(factor_size) * LocalRange + 2 * static_cast(factor_size))) < static_cast(local_memory_size); } return false;