Skip to content

Commit

Permalink
feat(pbs): slightly improve f64 pbs perf
Browse files Browse the repository at this point in the history
co-authored-by: sarah el kazdadi <[email protected]>
  • Loading branch information
IceTDrinker and sarah el kazdadi committed Sep 3, 2024
1 parent 10be6f9 commit 15e3474
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 26 deletions.
27 changes: 23 additions & 4 deletions tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,7 @@ pub(crate) fn update_with_fmadd(
fourier_poly_size: usize,
) {
let rhs = S::c64s_as_simd(fourier).0;
let len = rhs.len();

if is_output_uninit {
for (output_fourier, ggsw_poly) in izip!(
Expand All @@ -742,8 +743,17 @@ pub(crate) fn update_with_fmadd(
let out = S::c64s_as_mut_simd(output_fourier).0;
let lhs = S::c64s_as_simd(ggsw_poly).0;

for (out, &lhs, &rhs) in izip!(out, lhs, rhs) {
*out = simd.c64s_mul(lhs, rhs);
// This split is done to make better use of memory prefetchers; see
// https://blog.mattstuchlik.com/2024/07/21/fastest-memory-read.html
let (lhs0, lhs1) = lhs.split_at(len / 2);
let (rhs0, rhs1) = rhs.split_at(len / 2);
let (out0, out1) = out.split_at_mut(len / 2);

for ((out0, out1), (lhs0, lhs1), (rhs0, rhs1)) in
izip!(izip!(out0, out1), izip!(lhs0, lhs1), izip!(rhs0, rhs1),)
{
*out0 = simd.c64s_mul(*lhs0, *rhs0);
*out1 = simd.c64s_mul(*lhs1, *rhs1);
}
}
} else {
Expand All @@ -754,8 +764,17 @@ pub(crate) fn update_with_fmadd(
let out = S::c64s_as_mut_simd(output_fourier).0;
let lhs = S::c64s_as_simd(ggsw_poly).0;

for (out, &lhs, &rhs) in izip!(out, lhs, rhs) {
*out = simd.c64s_mul_add_e(lhs, rhs, *out);
// This split is done to make better use of memory prefetchers; see
// https://blog.mattstuchlik.com/2024/07/21/fastest-memory-read.html
let (lhs0, lhs1) = lhs.split_at(len / 2);
let (rhs0, rhs1) = rhs.split_at(len / 2);
let (out0, out1) = out.split_at_mut(len / 2);

for ((out0, out1), (lhs0, lhs1), (rhs0, rhs1)) in
izip!(izip!(out0, out1), izip!(lhs0, lhs1), izip!(rhs0, rhs1),)
{
*out0 = simd.c64s_mul_add_e(*lhs0, *rhs0, *out0);
*out1 = simd.c64s_mul_add_e(*lhs1, *rhs1, *out1);
}
}
}
Expand Down
42 changes: 20 additions & 22 deletions tfhe/src/core_crypto/fft_impl/fft64/math/fft/x86.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,16 @@ pub fn mm256_cvtpd_epi64(simd: V3, x: __m256d) -> __m256i {
avx2._mm256_blendv_epi8(value_if_positive, value_if_negative, sign_is_negative_mask)
}

/// Convert a vector of f64 values to a vector of i64 values.
/// This intrinsic is currently not available in rust, so we have our own implementation using
/// inline assembly.
///
/// The name matches Intel's convention (re-used by rust in their intrinsics) without the leading
/// `_`.
///
/// [`Intel's documentation`](`https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_cvtt_roundpd_epi64`)
/// Convert a vector of f64 values to a vector of i64 values with rounding to nearest integer.
/// [`Intel's documentation`](`https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_cvt_roundpd_epi64`)
#[cfg(feature = "nightly-avx512")]
#[inline(always)]
pub fn mm512_cvtt_roundpd_epi64(simd: V4, x: __m512d) -> __m512i {
pub fn mm512_cvt_round_nearest_pd_epi64(simd: V4, x: __m512d) -> __m512i {
let _ = simd.avx512dq;

// SAFETY: simd contains an instance of avx512dq, that matches the target feature of
// `implementation`
_ = simd;
unsafe { _mm512_cvttpd_epi64(x) }
unsafe { _mm512_cvt_roundpd_epi64::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(x) }
}

/// Convert a vector of i64 values to a vector of f64 values. Not sure how it works.
Expand Down Expand Up @@ -512,7 +507,7 @@ pub fn convert_forward_integer_u64_avx2_v3(
/// Perform common work for `u32` and `u64`, used by the backward torus transformation.
///
/// This deinterleaves two vectors of c64 values into two vectors of real part and imaginary part,
/// then rounds to the nearest integer.
/// then returns the scaled fractional part.
#[cfg(feature = "nightly-avx512")]
#[inline(always)]
pub fn prologue_convert_torus_v4(
Expand Down Expand Up @@ -555,8 +550,8 @@ pub fn prologue_convert_torus_v4(
let fract_re = avx._mm512_sub_pd(mul_re, avx._mm512_roundscale_pd::<ROUNDING>(mul_re));
let fract_im = avx._mm512_sub_pd(mul_im, avx._mm512_roundscale_pd::<ROUNDING>(mul_im));
// scale fractional part (rounding to nearest is performed later, by the float-to-integer conversion)
let fract_re = avx._mm512_roundscale_pd::<ROUNDING>(avx._mm512_mul_pd(scaling, fract_re));
let fract_im = avx._mm512_roundscale_pd::<ROUNDING>(avx._mm512_mul_pd(scaling, fract_im));
let fract_re = avx._mm512_mul_pd(scaling, fract_re);
let fract_im = avx._mm512_mul_pd(scaling, fract_im);

(fract_re, fract_im)
}
Expand Down Expand Up @@ -624,10 +619,13 @@ pub fn convert_add_backward_torus_u32_v4(
scaling,
);

// round to nearest integer and suppress exceptions
const ROUNDING: i32 = _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC;

// convert f64 to i32
let fract_re = avx512f._mm512_cvtpd_epi32(fract_re);
let fract_re = avx512f._mm512_cvt_roundpd_epi32::<ROUNDING>(fract_re);
// convert f64 to i32
let fract_im = avx512f._mm512_cvtpd_epi32(fract_im);
let fract_im = avx512f._mm512_cvt_roundpd_epi32::<ROUNDING>(fract_im);

// add to input and store
*out_re = pulp::cast(avx2._mm256_add_epi32(fract_re, pulp::cast(*out_re)));
Expand Down Expand Up @@ -708,9 +706,9 @@ pub fn convert_add_backward_torus_u64_v4(
);

// convert f64 to i64
let fract_re = mm512_cvtt_roundpd_epi64(simd, fract_re);
let fract_re = mm512_cvt_round_nearest_pd_epi64(simd, fract_re);
// convert f64 to i64
let fract_im = mm512_cvtt_roundpd_epi64(simd, fract_im);
let fract_im = mm512_cvt_round_nearest_pd_epi64(simd, fract_im);

// add to input and store
*out_re = pulp::cast(avx512f._mm512_add_epi64(fract_re, pulp::cast(*out_re)));
Expand Down Expand Up @@ -1060,7 +1058,7 @@ mod tests {
if x == 2.0f64.powi(63) {
// This is the proper representation in 2's complement, 2^63 gets folded
// onto -2^63
-(2i64.pow(63))
i64::MIN
} else {
x as i64
}
Expand Down Expand Up @@ -1100,14 +1098,14 @@ mod tests {
if x == 2.0f64.powi(63) {
// This is the proper representation in 2's complement, 2^63 gets folded
// onto -2^63
-(2i64.pow(63))
i64::MIN
} else {
x as i64
x.round() as i64
}
});

let computed: [i64; 4] =
pulp::cast_lossy(mm512_cvtt_roundpd_epi64(simd, pulp::cast([v, v])));
pulp::cast_lossy(mm512_cvt_round_nearest_pd_epi64(simd, pulp::cast([v, v])));
assert_eq!(target, computed);
}
}
Expand Down

0 comments on commit 15e3474

Please sign in to comment.