Skip to content

Commit

Permalink
Adjust Global/Subgroup required mem calc
Browse files Browse the repository at this point in the history
  • Loading branch information
hjabird committed Jan 3, 2024
1 parent 6984da5 commit 8ecffbc
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/portfft/descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,19 +385,19 @@ class committed_descriptor {
IdxGlobal factor_wi = factor_size / factor_sg;
if (detail::can_cast_safely<IdxGlobal, Idx>(factor_sg) && detail::can_cast_safely<IdxGlobal, Idx>(factor_wi)) {
if (batch_interleaved_layout) {
return (2 *
num_scalars_in_local_mem<detail::layout::BATCH_INTERLEAVED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
{static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg) *
sizeof(Scalar) +
2 * static_cast<std::size_t>(factor_size) * (LocalRange / 2) * sizeof(Scalar)) <
// (local_memory_for_input + local_mem_for_store_modifier + local_mem_for_twiddles) < local_mem
return (sizeof(Scalar) *
(num_scalars_in_local_mem<detail::layout::BATCH_INTERLEAVED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
{static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg) +
static_cast<std::size_t>(factor_size) * LocalRange + 2 * static_cast<std::size_t>(factor_size))) <
static_cast<std::size_t>(local_memory_size);
}
return (num_scalars_in_local_mem<detail::layout::PACKED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
{static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg) *
sizeof(Scalar) +
2 * static_cast<std::size_t>(factor_size) * (LocalRange / 2) * sizeof(Scalar)) <
return (sizeof(Scalar) *
(num_scalars_in_local_mem<detail::layout::PACKED>(
detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
{static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg) +
static_cast<std::size_t>(factor_size) * LocalRange + 2 * static_cast<std::size_t>(factor_size))) <
static_cast<std::size_t>(local_memory_size);
}
return false;
Expand Down

0 comments on commit 8ecffbc

Please sign in to comment.