Skip to content

Commit

Permalink
prevent OOB read from global memory
Browse files Browse the repository at this point in the history
  • Loading branch information
AD2605 committed Mar 19, 2024
1 parent d0e705d commit 823b84f
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 21 deletions.
8 changes: 4 additions & 4 deletions src/portfft/common/subgroup_bluestein.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ PORTFFT_INLINE void sg_bluestein_batch_interleaved(
priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED,
conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier,
store_modifier, twiddles_loc, static_cast<T>(1. / (static_cast<T>(factor_sg * factor_wi))), 0, id_of_wi_in_fft,
factor_sg, factor_wi, global_data);
factor_sg, factor_wi, wi_working, global_data);

// TODO: Currently local memory is being used to load the data back in natural order for the backward phase, as the
// result of sg_dft is transposed. However, the ideal way to this is using shuffles. Implement a batched matrix
Expand Down Expand Up @@ -190,7 +190,7 @@ PORTFFT_INLINE void sg_bluestein_batch_interleaved(
detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED,
detail::complex_conjugate::APPLIED, scale_applied, static_cast<const T*>(nullptr),
load_modifier, twiddles_loc, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi,
global_data);
wi_working, global_data);

if (conjugate_on_store == detail::complex_conjugate::APPLIED) {
global_data.log_message(__func__, "Applying complex conjugate on the output");
Expand Down Expand Up @@ -244,7 +244,7 @@ void sg_bluestein_packed(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddle
priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED,
conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier,
store_modifier, loc_twiddles, static_cast<T>(1. / static_cast<T>(factor_sg * factor_wi)), 0, id_of_wi_in_fft,
factor_sg, factor_wi, global_data);
factor_sg, factor_wi, wi_working, global_data);

if (wi_working) {
global_data.log_message(__func__, "storing result of the forward phase back to local memory");
Expand Down Expand Up @@ -276,7 +276,7 @@ void sg_bluestein_packed(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddle
detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED,
detail::complex_conjugate::APPLIED, scale_applied, static_cast<const T*>(nullptr),
load_modifier, loc_twiddles, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi,
global_data);
wi_working, global_data);
if (conjugate_on_store == detail::complex_conjugate::APPLIED) {
global_data.log_message(__func__, "Applying complex conjugate on the output");
detail::conjugate_inplace(priv, factor_wi);
Expand Down
35 changes: 20 additions & 15 deletions src/portfft/common/subgroup_ct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ void sg_calc_twiddles(Idx factor_sg, Idx factor_wi, Idx n, Idx k, T* sg_twiddles
* @param id_of_wi_in_fft workitem id withing the fft
* @param factor_sg Number of workitems participating for one transform
* @param factor_wi Number of complex elements per workitem for each transform
* @param wi_working Whether or not the workitem participates in the data transfers
* @param global_data global_data_struct associated with the kernel launch
*/
template <Idx SubgroupSize, typename T, typename LocView>
Expand All @@ -348,21 +349,23 @@ PORTFFT_INLINE void sg_cooley_tukey(T* priv, T* private_scratch, detail::element
detail::apply_scale_factor scale_factor_applied, const T* load_modifier_data,
const T* store_modifier_data, LocView& twiddles_loc_view, T scale_factor,
IdxGlobal modifier_start_offset, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi,
detail::global_data_struct<1>& global_data) {
bool wi_working, detail::global_data_struct<1>& global_data) {
using vec2_t = sycl::vec<T, 2>;
vec2_t modifier_vec;
if (conjugate_on_load == detail::complex_conjugate::APPLIED) {
global_data.log_message(__func__, "Applying complex conjugate before computation of the FFT");
detail::conjugate_inplace(priv, factor_wi);
}
if (apply_load_modifier == detail::elementwise_multiply::APPLIED) {
global_data.log_message(__func__, "Applying load modifiers");
PORTFFT_UNROLL
for (Idx j = 0; j < factor_wi; j++) {
modifier_vec = *reinterpret_cast<const vec2_t*>(
&load_modifier_data[modifier_start_offset + 2 * factor_wi * id_of_wi_in_fft + 2 * j]);
detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j],
priv[2 * j + 1]);
if (wi_working) {
global_data.log_message(__func__, "Applying load modifiers");
PORTFFT_UNROLL
for (Idx j = 0; j < factor_wi; j++) {
modifier_vec = *reinterpret_cast<const vec2_t*>(
&load_modifier_data[modifier_start_offset + 2 * factor_wi * id_of_wi_in_fft + 2 * j]);
detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j],
priv[2 * j + 1]);
}
}
}
sg_dft<SubgroupSize>(priv, global_data.sg, factor_wi, factor_sg, twiddles_loc_view, private_scratch);
Expand All @@ -373,13 +376,15 @@ PORTFFT_INLINE void sg_cooley_tukey(T* priv, T* private_scratch, detail::element
}

if (apply_store_modifier == detail::elementwise_multiply::APPLIED) {
global_data.log_message(__func__, "Applying store modifiers");
PORTFFT_UNROLL
for (Idx j = 0; j < factor_wi; j++) {
modifier_vec = *reinterpret_cast<const vec2_t*>(
&store_modifier_data[modifier_start_offset + 2 * j * factor_sg + 2 * id_of_wi_in_fft]);
detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j],
priv[2 * j + 1]);
if (wi_working) {
global_data.log_message(__func__, "Applying store modifiers");
PORTFFT_UNROLL
for (Idx j = 0; j < factor_wi; j++) {
modifier_vec = *reinterpret_cast<const vec2_t*>(
&store_modifier_data[modifier_start_offset + 2 * j * factor_sg + 2 * id_of_wi_in_fft]);
detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j],
priv[2 * j + 1]);
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/portfft/dispatcher/subgroup_dispatcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag
sg_cooley_tukey<SubgroupSize>(priv, wi_private_scratch, multiply_on_load, multiply_on_store,
conjugate_on_load, conjugate_on_store, apply_scale_factor, load_modifier_data,
store_modifier_data, loc_twiddles, scaling_factor, modifier_offset,
id_of_wi_in_fft, factor_sg, factor_wi, global_data);
id_of_wi_in_fft, factor_sg, factor_wi, working_inner, global_data);
} else {
sg_bluestein_batch_interleaved<SubgroupSize>(
priv, wi_private_scratch, loc_view, load_modifier_data, store_modifier_data, loc_twiddles,
Expand Down Expand Up @@ -409,7 +409,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag
conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data,
loc_twiddles, scaling_factor,
static_cast<IdxGlobal>(fft_size) * (i - static_cast<IdxGlobal>(id_of_fft_in_sg)),
id_of_wi_in_fft, factor_sg, factor_wi, global_data);
id_of_wi_in_fft, factor_sg, factor_wi, working, global_data);
} else {
Idx loc_offset_store_view;
Idx loc_offset_load_view;
Expand Down

0 comments on commit 823b84f

Please sign in to comment.