From ba8e4e48dc2e515a64984844fdc3f2157b92c57d Mon Sep 17 00:00:00 2001
From: sarah el kazdadi
Date: Thu, 1 Aug 2024 13:33:02 +0200
Subject: [PATCH] feat(cuda): new fft impl

---
 .../cuda/src/crypto/torus.cuh                 |  11 +-
 .../cuda/src/fft/bnsmfft.cuh                  | 846 +++++++-----------
 .../cuda/src/pbs/bootstrapping_key.cuh        |  10 +-
 .../cuda/src/polynomial/functions.cuh         |  67 +-
 .../cuda/src/polynomial/parameters.cuh        |  65 +-
 .../cuda/src/polynomial/polynomial_math.cuh   |  10 +-
 .../cuda/src/types/complex/operations.cuh     |  17 +-
 7 files changed, 388 insertions(+), 638 deletions(-)

diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh b/backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
index 4a5cd8fb66..5e34dbdc41 100644
--- a/backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
@@ -19,14 +19,19 @@ __device__ inline void typecast_double_to_torus(double x,
 template <>
 __device__ inline void typecast_double_to_torus(double x,
                                                 uint64_t &r) {
-  // The ull intrinsic does not behave in the same way on all architectures and
-  // on some platforms this causes the cmux tree test to fail
-  // Hence the intrinsic is not used here
   uint128 nnnn = make_uint128_from_float(x);
   uint64_t lll = nnnn.lo_;
   r = lll;
 }
 
+template <typename T>
+__device__ inline void typecast_double_round_to_torus(double x, T &r) {
+  double mx = (sizeof(T) == 4) ? 4294967296.0 : 18446744073709551616.0;
+  double frac = x - floor(x);
+  frac *= mx;
+  typecast_double_to_torus(frac, r);
+}
+
 template <typename T>
 __device__ inline T round_to_closest_multiple(T x, uint32_t base_log,
                                               uint32_t level_count) {
diff --git a/backends/tfhe-cuda-backend/cuda/src/fft/bnsmfft.cuh b/backends/tfhe-cuda-backend/cuda/src/fft/bnsmfft.cuh
index a57dd0d0c9..7cfa12186c 100644
--- a/backends/tfhe-cuda-backend/cuda/src/fft/bnsmfft.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/fft/bnsmfft.cuh
@@ -6,609 +6,389 @@
 #include "twiddles.cuh"
 #include "types/complex/operations.cuh"
 
-/*
- * Direct negacyclic FFT:
- *   - before the FFT the N real coefficients are stored into a
- *     N/2 sized complex with the even coefficients in the real part
- *     and the odd coefficients in the imaginary part.
This is referred to - * as the half-size FFT - * - when calling BNSMFFT_direct for the forward negacyclic FFT of PBS, - * opt is divided by 2 because the butterfly pattern is always applied - * between pairs of coefficients - * - instead of twisting each coefficient A_j before the FFT by - * multiplying by the w^j roots of unity (aka twiddles, w=exp(-i pi /N)), - * the FFT is modified, and for each level k of the FFT the twiddle: - * w_j,k = exp(-i pi j/2^k) - * is replaced with: - * \zeta_j,k = exp(-i pi (2j-1)/2^k) - */ -template __device__ void NSMFFT_direct(double2 *A) { - - /* We don't make bit reverse here, since twiddles are already reversed - * Each thread is always in charge of "opt/2" pairs of coefficients, - * which is why we always loop through N/2 by N/opt strides - * The pragma unroll instruction tells the compiler to unroll the - * full loop, which should increase performance - */ - - size_t tid = threadIdx.x; - size_t twid_id; - size_t i1, i2; - double2 u, v, w; - // level 1 - // we don't make actual complex multiplication on level1 since we have only - // one twiddle, it's real and image parts are equal, so we can multiply - // it with simpler operations -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - i1 = tid; - i2 = tid + params::degree / 2; - - u = A[i1]; - v = A[i2] * (double2){0.707106781186547461715008466854, - 0.707106781186547461715008466854}; - - A[i1] += v; - A[i2] = u - v; - - tid += params::degree / params::opt; - } - __syncthreads(); - - // level 2 - // from this level there are more than one twiddles and none of them has equal - // real and imag parts, so complete complex multiplication is needed - // for each level params::degree / 2^level represents number of coefficients - // inside divided chunk of specific level - // - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 4); - i1 = 2 * (params::degree / 4) * twid_id + (tid & (params::degree / 4 - 1)); - i2 = i1 + params::degree / 4; - - w = negtwiddles[twid_id + 2]; - u = A[i1]; - v = A[i2] * w; - - A[i1] += v; - A[i2] = u - v; - - tid += params::degree / params::opt; - } - __syncthreads(); - - // level 3 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 8); - i1 = 2 * (params::degree / 8) * twid_id + (tid & (params::degree / 8 - 1)); - i2 = i1 + params::degree / 8; - - w = negtwiddles[twid_id + 4]; - u = A[i1]; - v = A[i2] * w; - - A[i1] += v; - A[i2] = u - v; - - tid += params::degree / params::opt; - } - __syncthreads(); - - // level 4 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 16); - i1 = - 2 * (params::degree / 16) * twid_id + (tid & (params::degree / 16 - 1)); - i2 = i1 + params::degree / 16; - - w = negtwiddles[twid_id + 8]; - u = A[i1]; - v = A[i2] * w; - - A[i1] += v; - A[i2] = u - v; - - tid += params::degree / params::opt; - } - __syncthreads(); - - // level 5 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 32); - i1 = - 2 * (params::degree / 32) * twid_id + (tid & (params::degree / 32 - 1)); - i2 = i1 + params::degree / 32; +using Index = unsigned; - w = negtwiddles[twid_id + 16]; - u = A[i1]; - v = A[i2] * w; +constexpr inline Index warpsize() { return 32; } - A[i1] += v; - A[i2] = u - v; - - tid += params::degree / params::opt; - } - __syncthreads(); - - // level 6 - tid = threadIdx.x; 
-#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 64); - i1 = - 2 * (params::degree / 64) * twid_id + (tid & (params::degree / 64 - 1)); - i2 = i1 + params::degree / 64; - - w = negtwiddles[twid_id + 32]; - u = A[i1]; - v = A[i2] * w; - - A[i1] += v; - A[i2] = u - v; - - tid += params::degree / params::opt; +template __device__ inline void sync_fft_step() { + if constexpr (SIZE < warpsize()) { + __syncwarp(); + } else { + __syncthreads(); } - __syncthreads(); - - // level 7 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 128); - i1 = 2 * (params::degree / 128) * twid_id + - (tid & (params::degree / 128 - 1)); - i2 = i1 + params::degree / 128; - - w = negtwiddles[twid_id + 64]; - u = A[i1]; - v = A[i2] * w; +} - A[i1] += v; - A[i2] = u - v; +__device__ constexpr bool coalesce_mem_access() { return false; } - tid += params::degree / params::opt; - } - __syncthreads(); +__device__ inline void direct_butterfly(double2 *a0, double2 *a1, double2 w) { + double2 b0 = *a0; + double2 b1 = *a1 * w; - // from level 8, we need to check size of params degree, because we support - // minimum actual polynomial size = 256, when compressed size is halfed and - // minimum supported compressed size is 128, so we always need first 7 - // levels of butterfly operation, since butterfly levels are hardcoded - // we need to check if polynomial size is big enough to require specific level - // of butterfly. - if constexpr (params::degree >= 256) { - // level 8 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 256); - i1 = 2 * (params::degree / 256) * twid_id + - (tid & (params::degree / 256 - 1)); - i2 = i1 + params::degree / 256; + *a0 = b0 + b1; + *a1 = b0 - b1; +} - w = negtwiddles[twid_id + 128]; - u = A[i1]; - v = A[i2] * w; +__device__ inline void inverse_butterfly(double2 *a0, double2 *a1, double2 w) { + double2 b0 = *a0 + *a1; + double2 b1 = *a0 - *a1; - A[i1] += v; - A[i2] = u - v; + *a0 = b0; + *a1 = mul_conj(b1, w); +} - tid += params::degree / params::opt; +template +__device__ inline void NSMFFT_direct_2warpsize(double2 *a0, double2 *a1, + Index const tid) { + if constexpr (T > 0) { + if constexpr (T < warpsize()) { + // thread 0: + // a0: 0 * T + // a1: 2 * T + // thread 1: + // a0: 1 * T + // a1: 3 * T + + bool is_lower = ((tid >> tfhe_log2(T)) & 1) == 0; + double2 tmp = is_lower ? *a1 : *a0; + // thread 0: + // tmp: 2 * T + // thread 1: + // tmp: 1 * T + + tmp.x = __shfl_xor_sync(0xFFFFFFFF, tmp.x, T, 2 * T); + tmp.y = __shfl_xor_sync(0xFFFFFFFF, tmp.y, T, 2 * T); + // thread 0: + // tmp: 1 * T + // thread 1: + // tmp: 2 * T + + *a0 = is_lower ? *a0 : tmp; + *a1 = is_lower ? 
tmp : *a1; + + // thread 0: + // a0: 0 * T + // a1: 1 * T + // thread 1: + // a0: 2 * T + // a1: 3 * T } - __syncthreads(); + double2 w1 = negtwiddles[M + (tid / T)]; + direct_butterfly(a0, a1, w1); + NSMFFT_direct_2warpsize(a0, a1, tid); } +} - if constexpr (params::degree >= 512) { - // level 9 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 512); - i1 = 2 * (params::degree / 512) * twid_id + - (tid & (params::degree / 512 - 1)); - i2 = i1 + params::degree / 512; - - w = negtwiddles[twid_id + 256]; - u = A[i1]; - v = A[i2] * w; - - A[i1] += v; - A[i2] = u - v; - - tid += params::degree / params::opt; +template +__device__ inline void NSMFFT_inverse_2warpsize(double2 *a0, double2 *a1, + Index const tid) { + if constexpr (T > 0) { + NSMFFT_inverse_2warpsize(a0, a1, tid); + double2 w1 = negtwiddles[M + (tid / T)]; + inverse_butterfly(a0, a1, w1); + + if constexpr (T < warpsize()) { + // thread 0: + // a0: 0 * T + // a1: 1 * T + // thread 1: + // a0: 2 * T + // a1: 3 * T + + bool is_lower = ((tid >> tfhe_log2(T)) & 1) == 0; + double2 tmp = is_lower ? *a1 : *a0; + // thread 0: + // tmp: 1 * T + // thread 1: + // tmp: 2 * T + + tmp.x = __shfl_xor_sync(0xFFFFFFFF, tmp.x, T, 2 * T); + tmp.y = __shfl_xor_sync(0xFFFFFFFF, tmp.y, T, 2 * T); + // thread 0: + // tmp: 2 * T + // thread 1: + // tmp: 1 * T + + *a0 = is_lower ? *a0 : tmp; + *a1 = is_lower ? tmp : *a1; + + // thread 0: + // a0: 0 * T + // a1: 2 * T + // thread 1: + // a0: 1 * T + // a1: 3 * T } - __syncthreads(); } +} - if constexpr (params::degree >= 1024) { - // level 10 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 1024); - i1 = 2 * (params::degree / 1024) * twid_id + - (tid & (params::degree / 1024 - 1)); - i2 = i1 + params::degree / 1024; - - w = negtwiddles[twid_id + 512]; - u = A[i1]; - v = A[i2] * w; +template +__device__ inline void NSMFFT_direct_step(double2 *A, Index const tid) { + constexpr Index SIZE = DEGREE / M; + constexpr Index T = SIZE / RADIX; - A[i1] += v; - A[i2] = u - v; + Index i; + double2 w1; - tid += params::degree / params::opt; - } - __syncthreads(); + if constexpr (M == 1) { + i = 0; + w1 = (double2){ + 0.707106781186547461715008466854, + 0.707106781186547461715008466854, + }; + } else { + i = tid / T; + w1 = negtwiddles[M + i]; } - if constexpr (params::degree >= 2048) { - // level 11 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 2048); - i1 = 2 * (params::degree / 2048) * twid_id + - (tid & (params::degree / 2048 - 1)); - i2 = i1 + params::degree / 2048; - - w = negtwiddles[twid_id + 1024]; - u = A[i1]; - v = A[i2] * w; - - A[i1] += v; - A[i2] = u - v; - - tid += params::degree / params::opt; - } - __syncthreads(); - } + Index j = i * SIZE + (tid % T); - if constexpr (params::degree >= 4096) { - // level 12 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 4096); - i1 = 2 * (params::degree / 4096) * twid_id + - (tid & (params::degree / 4096 - 1)); - i2 = i1 + params::degree / 4096; + if constexpr (RADIX == 4) { + double2 w2 = negtwiddles[2 * (M + i) + 0]; + double2 w3 = negtwiddles[2 * (M + i) + 1]; - w = negtwiddles[twid_id + 2048]; - u = A[i1]; - v = A[i2] * w; + Index j0 = j + 0 * T; + Index j1 = j + 1 * T; + Index j2 = j + 2 * T; + Index j3 = j + 3 * T; - A[i1] += v; - A[i2] = u - v; + double2 a0 = A[j0]; + 
double2 a1 = A[j1]; + double2 a2 = A[j2]; + double2 a3 = A[j3]; - tid += params::degree / params::opt; - } - __syncthreads(); - } + direct_butterfly(&a0, &a2, w1); + direct_butterfly(&a1, &a3, w1); + direct_butterfly(&a0, &a1, w2); + direct_butterfly(&a2, &a3, w3); - if constexpr (params::degree >= 8192) { - // level 13 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 8192); - i1 = 2 * (params::degree / 8192) * twid_id + - (tid & (params::degree / 8192 - 1)); - i2 = i1 + params::degree / 8192; + A[j0] = a0; + A[j1] = a1; + A[j2] = a2; + A[j3] = a3; + } else if constexpr (RADIX == 2) { + Index j0 = j + 0 * T; + Index j1 = j + 1 * T; - w = negtwiddles[twid_id + 4096]; - u = A[i1]; - v = A[i2] * w; + double2 a0 = A[j0]; + double2 a1 = A[j1]; - A[i1] += v; - A[i2] = u - v; + direct_butterfly(&a0, &a1, w1); - tid += params::degree / params::opt; - } - __syncthreads(); + A[j0] = a0; + A[j1] = a1; } } -/* - * negacyclic inverse fft - */ -template __device__ void NSMFFT_inverse(double2 *A) { - - /* We don't make bit reverse here, since twiddles are already reversed - * Each thread is always in charge of "opt/2" pairs of coefficients, - * which is why we always loop through N/2 by N/opt strides - * The pragma unroll instruction tells the compiler to unroll the - * full loop, which should increase performance - */ - - size_t tid = threadIdx.x; - size_t twid_id; - size_t i1, i2; - double2 u, w; - - // divide input by compressed polynomial size - tid = threadIdx.x; - for (size_t i = 0; i < params::opt; ++i) { - A[tid] /= params::degree; - tid += params::degree / params::opt; - } - __syncthreads(); - - // none of the twiddles have equal real and imag part, so - // complete complex multiplication has to be done - // here we have more than one twiddle - // mapping in backward fft is reversed - // butterfly operation is started from last level - - if constexpr (params::degree >= 8192) { - // level 13 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 8192); - i1 = 2 * (params::degree / 8192) * twid_id + - (tid & (params::degree / 8192 - 1)); - i2 = i1 + params::degree / 8192; - - w = negtwiddles[twid_id + 4096]; - u = A[i1] - A[i2]; - - A[i1] += A[i2]; - A[i2] = u * conjugate(w); - - tid += params::degree / params::opt; +template +__device__ inline void NSMFFT_inverse_step(double2 *A, Index const tid) { + constexpr Index SIZE = DEGREE / M; + constexpr Index T = SIZE / RADIX; + + Index i; + double2 w1; + + if constexpr (M == 1) { + i = 0; + w1 = (double2){ + 0.707106781186547461715008466854, + 0.707106781186547461715008466854, + }; + } else { + i = tid / T; + w1 = negtwiddles[M + i]; + } + + Index j = i * SIZE + (tid % T); + + if constexpr (RADIX == 4) { + double2 w2 = negtwiddles[2 * (M + i) + 0]; + double2 w3 = negtwiddles[2 * (M + i) + 1]; + + Index j0 = j + 0 * T; + Index j1 = j + 1 * T; + Index j2 = j + 2 * T; + Index j3 = j + 3 * T; + + double2 a0 = A[j0]; + double2 a1 = A[j1]; + double2 a2 = A[j2]; + double2 a3 = A[j3]; + + if constexpr (M == 1) { + a0 *= 1.0 / double(DEGREE); + a1 *= 1.0 / double(DEGREE); + a2 *= 1.0 / double(DEGREE); + a3 *= 1.0 / double(DEGREE); } - __syncthreads(); - } - - if constexpr (params::degree >= 4096) { - // level 12 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 4096); - i1 = 2 * (params::degree / 4096) * twid_id + - (tid & (params::degree / 4096 - 1)); 
- i2 = i1 + params::degree / 4096; - - w = negtwiddles[twid_id + 2048]; - u = A[i1] - A[i2]; - - A[i1] += A[i2]; - A[i2] = u * conjugate(w); - tid += params::degree / params::opt; + inverse_butterfly(&a0, &a1, w2); + inverse_butterfly(&a2, &a3, w3); + inverse_butterfly(&a0, &a2, w1); + inverse_butterfly(&a1, &a3, w1); + + A[j0] = a0; + A[j1] = a1; + A[j2] = a2; + A[j3] = a3; + } else if constexpr (RADIX == 2) { + Index j0 = j + 0 * T; + Index j1 = j + 1 * T; + + double2 a0 = A[j0]; + double2 a1 = A[j1]; + + if constexpr (M == 1) { + a0 *= 1.0 / double(DEGREE); + a1 *= 1.0 / double(DEGREE); } - __syncthreads(); - } - - if constexpr (params::degree >= 2048) { - // level 11 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 2048); - i1 = 2 * (params::degree / 2048) * twid_id + - (tid & (params::degree / 2048 - 1)); - i2 = i1 + params::degree / 2048; - - w = negtwiddles[twid_id + 1024]; - u = A[i1] - A[i2]; - A[i1] += A[i2]; - A[i2] = u * conjugate(w); + inverse_butterfly(&a0, &a1, w1); - tid += params::degree / params::opt; - } - __syncthreads(); + A[j0] = a0; + A[j1] = a1; } +} - if constexpr (params::degree >= 1024) { - // level 10 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 1024); - i1 = 2 * (params::degree / 1024) * twid_id + - (tid & (params::degree / 1024 - 1)); - i2 = i1 + params::degree / 1024; +template +__device__ inline void NSMFFT_direct_impl(double2 *A, Index const tid) { + static_assert(OPT >= RADIX, + "params::opt should be larger than or equal to the fft radix"); - w = negtwiddles[twid_id + 512]; - u = A[i1] - A[i2]; + constexpr Index SIZE = DEGREE / M; - A[i1] += A[i2]; - A[i2] = u * conjugate(w); + if constexpr (coalesce_mem_access()) { + constexpr Index WARPSIZE = warpsize(); - tid += params::degree / params::opt; - } - __syncthreads(); - } + if constexpr (SIZE >= 2 * WARPSIZE) { + constexpr Index radix = (SIZE == 2 * WARPSIZE) ? 2 + : ((SIZE / (2 * WARPSIZE)) % RADIX != 0) ? 2 + : RADIX; + + __syncthreads(); - if constexpr (params::degree >= 512) { - // level 9 - tid = threadIdx.x; #pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 512); - i1 = 2 * (params::degree / 512) * twid_id + - (tid & (params::degree / 512 - 1)); - i2 = i1 + params::degree / 512; + for (Index k = 0; k < OPT / radix; ++k) { + if constexpr (SIZE == 2 * WARPSIZE) { + static_assert(radix == 2, ""); - w = negtwiddles[twid_id + 256]; - u = A[i1] - A[i2]; + Index i = tid / WARPSIZE; + Index j = i * SIZE + (tid % WARPSIZE); + Index j0 = j + 0 * WARPSIZE; + Index j1 = j + 1 * WARPSIZE; - A[i1] += A[i2]; - A[i2] = u * conjugate(w); + double2 a0 = A[j0]; + double2 a1 = A[j1]; + NSMFFT_direct_2warpsize(&a0, &a1, tid); + A[j0] = a0; + A[j1] = a1; + } else { + static_assert(SIZE > 2 * WARPSIZE, ""); + NSMFFT_direct_step(A, tid + k * (DEGREE / OPT)); + } + } - tid += params::degree / params::opt; + NSMFFT_direct_impl(A, tid); } - __syncthreads(); - } + } else { + if constexpr (SIZE > 1) { + constexpr Index radix = (SIZE % RADIX != 0) ? 
2 : RADIX; + sync_fft_step(); - if constexpr (params::degree >= 256) { - // level 8 - tid = threadIdx.x; #pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 256); - i1 = 2 * (params::degree / 256) * twid_id + - (tid & (params::degree / 256 - 1)); - i2 = i1 + params::degree / 256; + for (Index k = 0; k < OPT / radix; ++k) { + NSMFFT_direct_step(A, tid + k * (DEGREE / OPT)); + } - w = negtwiddles[twid_id + 128]; - u = A[i1] - A[i2]; - - A[i1] += A[i2]; - A[i2] = u * conjugate(w); - - tid += params::degree / params::opt; + NSMFFT_direct_impl(A, tid); } - __syncthreads(); } +} - // below level 8, we don't need to check size of params degree, because we - // support minimum actual polynomial size = 256, when compressed size is - // halfed and minimum supported compressed size is 128, so we always need - // last 7 levels of butterfly operation, since butterfly levels are hardcoded - // we don't need to check if polynomial size is big enough to require - // specific level of butterfly. - // level 7 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 128); - i1 = 2 * (params::degree / 128) * twid_id + - (tid & (params::degree / 128 - 1)); - i2 = i1 + params::degree / 128; - - w = negtwiddles[twid_id + 64]; - u = A[i1] - A[i2]; - - A[i1] += A[i2]; - A[i2] = u * conjugate(w); - - tid += params::degree / params::opt; - } - __syncthreads(); +template +__device__ inline void NSMFFT_inverse_impl(double2 *A, Index const tid) { + static_assert(OPT >= RADIX, + "params::opt should be larger than or equal to the fft radix"); - // level 6 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 64); - i1 = - 2 * (params::degree / 64) * twid_id + (tid & (params::degree / 64 - 1)); - i2 = i1 + params::degree / 64; + constexpr Index SIZE = DEGREE / M; - w = negtwiddles[twid_id + 32]; - u = A[i1] - A[i2]; + if constexpr (coalesce_mem_access()) { + constexpr Index WARPSIZE = warpsize(); - A[i1] += A[i2]; - A[i2] = u * conjugate(w); + if constexpr (SIZE >= 2 * WARPSIZE) { + constexpr Index radix = (SIZE == 2 * WARPSIZE) ? 2 + : ((SIZE / (2 * WARPSIZE)) % RADIX != 0) ? 
2 + : RADIX; - tid += params::degree / params::opt; - } - __syncthreads(); + NSMFFT_inverse_impl(A, tid); - // level 5 - tid = threadIdx.x; #pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 32); - i1 = - 2 * (params::degree / 32) * twid_id + (tid & (params::degree / 32 - 1)); - i2 = i1 + params::degree / 32; - - w = negtwiddles[twid_id + 16]; - u = A[i1] - A[i2]; + for (Index k = 0; k < OPT / radix; ++k) { + if constexpr (SIZE == 2 * WARPSIZE) { + static_assert(radix == 2, ""); - A[i1] += A[i2]; - A[i2] = u * conjugate(w); + Index i = tid / WARPSIZE; + Index j = i * SIZE + (tid % WARPSIZE); + Index j0 = j + 0 * WARPSIZE; + Index j1 = j + 1 * WARPSIZE; - tid += params::degree / params::opt; - } - __syncthreads(); - - // level 4 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 16); - i1 = - 2 * (params::degree / 16) * twid_id + (tid & (params::degree / 16 - 1)); - i2 = i1 + params::degree / 16; + double2 a0 = A[j0]; + double2 a1 = A[j1]; + NSMFFT_inverse_2warpsize(&a0, &a1, tid); + A[j0] = a0; + A[j1] = a1; + } else { + static_assert(SIZE > 2 * WARPSIZE, ""); + NSMFFT_inverse_step(A, tid + k * (DEGREE / OPT)); + } + } - w = negtwiddles[twid_id + 8]; - u = A[i1] - A[i2]; - - A[i1] += A[i2]; - A[i2] = u * conjugate(w); + __syncthreads(); + } + } else { + if constexpr (SIZE > 1) { + constexpr Index radix = (SIZE % RADIX != 0) ? 2 : RADIX; - tid += params::degree / params::opt; - } - __syncthreads(); + NSMFFT_inverse_impl(A, tid); - // level 3 - tid = threadIdx.x; #pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 8); - i1 = 2 * (params::degree / 8) * twid_id + (tid & (params::degree / 8 - 1)); - i2 = i1 + params::degree / 8; - - w = negtwiddles[twid_id + 4]; - u = A[i1] - A[i2]; + for (Index k = 0; k < OPT / radix; ++k) { + NSMFFT_inverse_step(A, tid + k * (DEGREE / OPT)); + } - A[i1] += A[i2]; - A[i2] = u * conjugate(w); - - tid += params::degree / params::opt; + sync_fft_step(); + } } - __syncthreads(); - - // level 2 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 4); - i1 = 2 * (params::degree / 4) * twid_id + (tid & (params::degree / 4 - 1)); - i2 = i1 + params::degree / 4; - - w = negtwiddles[twid_id + 2]; - u = A[i1] - A[i2]; - - A[i1] += A[i2]; - A[i2] = u * conjugate(w); +} - tid += params::degree / params::opt; - } +/* + * Direct negacyclic FFT: + * - before the FFT the N real coefficients are stored into a + * N/2 sized complex with the even coefficients in the real part + * and the odd coefficients in the imaginary part. 
This is referred to + * as the half-size FFT + * - when calling BNSMFFT_direct for the forward negacyclic FFT of PBS, + * opt is divided by 2 because the butterfly pattern is always applied + * between pairs of coefficients + * - instead of twisting each coefficient A_j before the FFT by + * multiplying by the w^j roots of unity (aka twiddles, w=exp(-i pi /N)), + * the FFT is modified, and for each level k of the FFT the twiddle: + * w_j,k = exp(-i pi j/2^k) + * is replaced with: + * \zeta_j,k = exp(-i pi (2j-1)/2^k) + */ +template +__device__ void NSMFFT_direct(double2 *A) { + NSMFFT_direct_impl(A, threadIdx.x); __syncthreads(); +} - // level 1 - tid = threadIdx.x; -#pragma unroll - for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / (params::degree / 2); - i1 = 2 * (params::degree / 2) * twid_id + (tid & (params::degree / 2 - 1)); - i2 = i1 + params::degree / 2; - - w = negtwiddles[twid_id + 1]; - u = A[i1] - A[i2]; - - A[i1] += A[i2]; - A[i2] = u * conjugate(w); - - tid += params::degree / params::opt; - } +/* + * negacyclic inverse fft + */ +template +__device__ void NSMFFT_inverse(double2 *A) { __syncthreads(); + NSMFFT_inverse_impl(A, threadIdx.x); } /* diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/bootstrapping_key.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/bootstrapping_key.cuh index 966be0464d..909de8b3f6 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/bootstrapping_key.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/bootstrapping_key.cuh @@ -88,7 +88,9 @@ void cuda_convert_lwe_programmable_bootstrap_key(cudaStream_t stream, total_polynomials * polynomial_size / 2 * sizeof(double2); int gridSize = total_polynomials; - int blockSize = polynomial_size / choose_opt_amortized(polynomial_size); + int blockSize = + polynomial_size / + choose_opt_amortized(polynomial_size, tfhe_fft_default_radix()); double2 *h_bsk; cudaMallocHost((void **)&h_bsk, buffer_size); @@ -263,7 +265,9 @@ void cuda_fourier_polynomial_mul(cudaStream_t stream, uint32_t gpu_index, size_t shared_memory_size = sizeof(double2) * polynomial_size / 2; int gridSize = total_polynomials; - int blockSize = polynomial_size / choose_opt_amortized(polynomial_size); + int blockSize = + polynomial_size / + choose_opt_amortized(polynomial_size, tfhe_fft_default_radix()); double2 *buffer; switch (polynomial_size) { @@ -292,7 +296,7 @@ void cuda_fourier_polynomial_mul(cudaStream_t stream, uint32_t gpu_index, if (shared_memory_size <= cuda_get_max_shared_memory(0)) { buffer = (double2 *)cuda_malloc_async(0, stream, gpu_index); check_cuda_error(cudaFuncSetAttribute( - batch_polynomial_mul, ForwardFFT>, + batch_polynomial_mul, ForwardFFT>, FULLSM>, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); check_cuda_error(cudaFuncSetCacheConfig( diff --git a/backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh b/backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh index f94c8e5848..b641323045 100644 --- a/backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh @@ -58,26 +58,14 @@ divide_by_monomial_negacyclic_inplace(T *accumulator, tid += block_size; } } else { - tid = threadIdx.x; for (int i = 0; i < elems_per_thread; i++) { - if (j < degree) { - // if (tid < degree - j) - // accumulator_slice[tid] = input_slice[tid + j]; - // else - // accumulator_slice[tid] = -input_slice[tid - degree + j]; - int x = tid + j - SEL(degree, 0, tid < degree - j); - accumulator_slice[tid] = - SEL(-1, 1, tid < degree - j) * 
input_slice[x];
-      } else {
-        int32_t jj = j - degree;
-        // if (tid < degree - jj)
-        //   accumulator_slice[tid] = -input_slice[tid + jj];
-        // else
-        //   accumulator_slice[tid] = input_slice[tid - degree + jj];
-        int x = tid + jj - SEL(degree, 0, tid < degree - jj);
-        accumulator_slice[tid] =
-            SEL(1, -1, tid < degree - jj) * input_slice[x];
-      }
+      unsigned x = (unsigned)(tid + j);
+      x %= 2 * (unsigned)(degree);
+      bool wrap_around = x >= degree;
+      x %= (unsigned)(degree);
+
+      accumulator_slice[tid] = wrap_around ? -input_slice[x] : input_slice[x];
+
       tid += block_size;
     }
   }
@@ -101,24 +89,14 @@ __device__ void multiply_by_monomial_negacyclic_and_sub_polynomial(
     T *result_acc_slice = (T *)result_acc + (ptrdiff_t)(z * degree);
     int tid = threadIdx.x;
     for (int i = 0; i < elems_per_thread; i++) {
-      if (j < degree) {
-        // if (tid < j)
-        //   result_acc_slice[tid] = -acc_slice[tid - j + degree]-acc_slice[tid];
-        // else
-        //   result_acc_slice[tid] = acc_slice[tid - j] - acc_slice[tid];
-        int x = tid - j + SEL(0, degree, tid < j);
-        result_acc_slice[tid] =
-            SEL(1, -1, tid < j) * acc_slice[x] - acc_slice[tid];
-      } else {
-        int32_t jj = j - degree;
-        // if (tid < jj)
-        //   result_acc_slice[tid] = acc_slice[tid - jj + degree]-acc_slice[tid];
-        // else
-        //   result_acc_slice[tid] = -acc_slice[tid - jj] - acc_slice[tid];
-        int x = tid - jj + SEL(0, degree, tid < jj);
-        result_acc_slice[tid] =
-            SEL(-1, 1, tid < jj) * acc_slice[x] - acc_slice[tid];
-      }
+      // overflowing is fine here, since the degree is a power of two
+      unsigned x = (unsigned)(tid - j);
+      x %= 2 * (unsigned)(degree);
+      bool wrap_around = x >= degree;
+      x %= (unsigned)(degree);
+
+      result_acc_slice[tid] =
+          (wrap_around ? -acc_slice[x] : acc_slice[x]) - acc_slice[tid];
       tid += block_size;
     }
   }
@@ -155,28 +133,17 @@ __device__ void round_to_closest_multiple_inplace(T *rotated_acc, int base_log,
 template <typename Torus, class params>
 __device__ void add_to_torus(double2 *m_values, Torus *result,
                              bool init_torus = false) {
-  Torus mx = (sizeof(Torus) == 4) ? UINT32_MAX : UINT64_MAX;
   int tid = threadIdx.x;
 #pragma unroll
   for (int i = 0; i < params::opt / 2; i++) {
     double v1 = m_values[tid].x;
     double v2 = m_values[tid].y;
 
-    double frac = v1 - floor(v1);
-    frac *= mx;
-    double carry = frac - floor(frac);
-    frac += (carry >= 0.5);
-
     Torus V1 = 0;
-    typecast_double_to_torus(frac, V1);
-
-    frac = v2 - floor(v2);
-    frac *= mx;
-    carry = frac - floor(v2);
-    frac += (carry >= 0.5);
+    typecast_double_round_to_torus(v1, V1);
 
     Torus V2 = 0;
-    typecast_double_to_torus(frac, V2);
+    typecast_double_round_to_torus(v2, V2);
 
     if (init_torus) {
       result[tid] = V1;
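A quick host-side check of the wrap-around trick used in the two hunks above (not part of the patch; the function names and the tiny test are illustrative). In the negacyclic ring Z[X]/(X^degree + 1), rotating by a monomial only needs one index modulo 2*degree plus a sign flip, which replaces the old SEL-based branches. For the subtraction case (tid - j), the unsigned underflow is harmless only because 2^32 is divisible by 2*degree when degree is a power of two.

#include <cassert>

// Branchy reference, written the way the deleted SEL-based code reads.
static int rotate_naive(const int *poly, int degree, int tid, int j) {
  if (j < degree) {
    return (tid < degree - j) ? poly[tid + j] : -poly[tid + j - degree];
  } else {
    int jj = j - degree;
    return (tid < degree - jj) ? -poly[tid + jj] : poly[tid + jj - degree];
  }
}

// Branch-free version introduced by the patch: one modular index plus a sign.
static int rotate_branch_free(const int *poly, int degree, int tid, int j) {
  unsigned x = (unsigned)(tid + j);
  x %= 2u * (unsigned)degree;               // position in the doubled cycle
  bool wrap_around = x >= (unsigned)degree; // crossed X^degree == -1 once
  x %= (unsigned)degree;
  return wrap_around ? -poly[x] : poly[x];
}

int main() {
  const int degree = 8; // must be a power of two, as in the kernels
  int poly[degree] = {1, 2, 3, 4, 5, 6, 7, 8};
  for (int j = 0; j < 2 * degree; ++j)
    for (int tid = 0; tid < degree; ++tid)
      assert(rotate_naive(poly, degree, tid, j) ==
             rotate_branch_free(poly, degree, tid, j));
  return 0;
}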
diff --git a/backends/tfhe-cuda-backend/cuda/src/polynomial/parameters.cuh b/backends/tfhe-cuda-backend/cuda/src/polynomial/parameters.cuh
index 0c8325dd84..8906f4fbff 100644
--- a/backends/tfhe-cuda-backend/cuda/src/polynomial/parameters.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/polynomial/parameters.cuh
@@ -1,54 +1,48 @@
 #ifndef CUDA_PARAMETERS_CUH
 #define CUDA_PARAMETERS_CUH
 
-constexpr int log2(int n) { return (n <= 2) ? 1 : 1 + log2(n / 2); }
+constexpr int tfhe_fft_default_radix() { return 2; }
 
-constexpr int choose_opt_amortized(int degree) {
-  if (degree <= 1024)
-    return 4;
-  else if (degree == 2048)
-    return 8;
-  else if (degree == 4096)
-    return 16;
-  else if (degree == 8192)
-    return 32;
-  else
-    return 64;
+constexpr int tfhe_log2(int n) { return (n <= 1) ? 0 : 1 + tfhe_log2(n / 2); }
+
+template <typename T> constexpr T tfhe_min(T a, T b) { return a < b ? a : b; }
+template <typename T> constexpr T tfhe_max(T a, T b) { return a > b ? a : b; }
+
+constexpr int choose_opt_amortized(int degree, int radix) {
+  return tfhe_max(2 * radix, (degree <= 1024)   ? 4
+                             : (degree == 2048) ? 8
+                             : (degree == 4096) ? 16
+                             : (degree == 8192) ? 32
+                                                : 64);
 }
 
-constexpr int choose_opt(int degree) {
-  if (degree <= 1024)
-    return 4;
-  else if (degree == 2048)
-    return 4;
-  else if (degree == 4096)
-    return 4;
-  else if (degree == 8192)
-    return 8;
-  else if (degree == 16384)
-    return 16;
-  else
-    return 64;
+constexpr int choose_opt(int degree, int radix) {
+  return tfhe_max(2 * radix, (degree <= 1024)    ? 4
+                             : (degree == 2048)  ? 4
+                             : (degree == 4096)  ? 4
+                             : (degree == 8192)  ? 8
+                             : (degree == 16384) ? 16
+                                                 : 64);
 }
 
-template <class params> class HalfDegree {
+template <class params, int radix = tfhe_fft_default_radix()> class HalfDegree {
 public:
   constexpr static int degree = params::degree / 2;
-  constexpr static int opt = params::opt / 2;
+  constexpr static int opt = tfhe_max(radix, params::opt / 2);
   constexpr static int log2_degree = params::log2_degree - 1;
 };
 
-template <int N> class Degree {
+template <int N, int radix = tfhe_fft_default_radix()> class Degree {
 public:
   constexpr static int degree = N;
-  constexpr static int opt = choose_opt(N);
-  constexpr static int log2_degree = log2(N);
+  constexpr static int opt = choose_opt(N, radix);
+  constexpr static int log2_degree = tfhe_log2(N);
 };
 
-template <int N> class AmortizedDegree {
+template <int N, int radix = tfhe_fft_default_radix()> class AmortizedDegree {
 public:
   constexpr static int degree = N;
-  constexpr static int opt = choose_opt_amortized(N);
-  constexpr static int log2_degree = log2(N);
+  constexpr static int opt = choose_opt_amortized(N, radix);
+  constexpr static int log2_degree = tfhe_log2(N);
 };
 
 enum sharedMemDegree { NOSM = 0, PARTIALSM = 1, FULLSM = 2 };
@@ -76,8 +70,9 @@
 public:
   constexpr static int fft_reorder = reorder::reorder;
 };
 
-template <int N, class direction, class reorder>
-class FFTParams : public Degree<N> {
+template <int N, class direction, class reorder,
+          int radix = tfhe_fft_default_radix()>
+class FFTParams : public Degree<N, radix> {
 public:
   constexpr static int fft_direction = direction::direction;
   constexpr static int fft_reorder = reorder::reorder;
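A compile-time restatement of how the constexpr helpers above fit together (illustrative only, not part of the patch). `opt` is the number of coefficients each thread owns, so a kernel runs with degree / opt threads; clamping `opt` to at least 2 * radix guarantees every thread owns enough coefficients to feed one full radix-`radix` butterfly per pass. The values follow directly from the definitions above.

constexpr int radix = tfhe_fft_default_radix(); // 2 in this patch
static_assert(tfhe_log2(2048) == 11, "");
static_assert(choose_opt(1024, radix) == 4, "");
static_assert(choose_opt_amortized(2048, radix) == 8, "");
static_assert(choose_opt(256, 4) == 2 * 4, "small degrees clamp to 2 * radix");
// Block size as computed in bootstrapping_key.cuh above:
static_assert(2048 / choose_opt_amortized(2048, radix) == 256, "");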
diff --git a/backends/tfhe-cuda-backend/cuda/src/polynomial/polynomial_math.cuh b/backends/tfhe-cuda-backend/cuda/src/polynomial/polynomial_math.cuh
index 37a1135346..4eecc00c08 100644
--- a/backends/tfhe-cuda-backend/cuda/src/polynomial/polynomial_math.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/polynomial/polynomial_math.cuh
@@ -55,9 +55,6 @@ __device__ void polynomial_product_accumulate_in_fourier_domain(
   }
 }
 
-// This method expects to work with polynomial_size / compression_params::opt
-// threads in the x-block If init_accumulator is set, assumes that result was
-// not initialized and does that with the outcome of first * second
 template <typename T>
 __device__ void polynomial_accumulate_monic_monomial_mul(
     T *result, const T *__restrict__ poly, uint64_t monomial_degree,
@@ -70,10 +67,11 @@ __device__ void polynomial_accumulate_monic_monomial_mul(
   int pos = tid;
   for (int i = 0; i < coeff_per_thread; i++) {
     T element = poly[pos];
-    int new_pos = (pos + monomial_degree) % polynomial_size;
+    int new_pos = (pos + monomial_degree) % (2 * polynomial_size);
+    bool negate = new_pos >= polynomial_size;
+    new_pos %= polynomial_size;
 
-    T x = SEL(element, -element, full_cycles_count % 2); // monomial coefficient
-    x = SEL(-x, x, new_pos >= remainder_degrees);
+    T x = negate ? -element : element; // monomial coefficient
 
     if (init_accumulator)
       result[new_pos] = x;
diff --git a/backends/tfhe-cuda-backend/cuda/src/types/complex/operations.cuh b/backends/tfhe-cuda-backend/cuda/src/types/complex/operations.cuh
index 0e4f8e11e0..f80b011364 100644
--- a/backends/tfhe-cuda-backend/cuda/src/types/complex/operations.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/types/complex/operations.cuh
@@ -49,16 +49,17 @@ __device__ inline double2 operator-(const double2 a, const double2 b) {
 }
 
 __device__ inline double2 operator*(const double2 a, const double2 b) {
-  double xx = a.x * b.x;
-  double xy = a.x * b.y;
-  double yx = a.y * b.x;
-  double yy = a.y * b.y;
+  double2 res;
+  res.x = (a.y * -b.y) + (a.x * b.x);
+  res.y = (a.x * b.y) + (a.y * b.x);
+  return res;
+}
 
+// a * conj(b)
+__device__ inline double2 mul_conj(const double2 a, const double2 b) {
   double2 res;
-  // asm volatile("fma.rn.f64 %0, %1, %2, %3;": "=d"(res.x) : "d"(a.x),
-  // "d"(b.x), "d"(yy));
-  res.x = xx - yy;
-  res.y = xy + yx;
+  res.x = (a.y * b.y) + (a.x * b.x);
+  res.y = (a.x * -b.y) + (a.y * b.x);
   return res;
 }
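The lane-exchange pattern used by NSMFFT_direct_2warpsize / NSMFFT_inverse_2warpsize, reduced to a self-contained kernel (illustrative only; kernel and helper names are not part of the patch). Once the butterfly distance T drops below a warp, both halves of a pair live in registers of lanes T apart, so the FFT kernels realign pairs with one __shfl_xor_sync per component instead of a shared-memory round-trip.

#include <cuda_runtime.h>

constexpr unsigned ilog2u(unsigned n) { return (n <= 1) ? 0 : 1 + ilog2u(n / 2); }

template <unsigned T> // butterfly distance, a power of two below 32
__global__ void pair_exchange_demo(double *buf) {
  unsigned tid = threadIdx.x; // launch with a single warp: <<<1, 32>>>
  double a0 = buf[tid];       // this lane's lower element
  double a1 = buf[tid + 32];  // this lane's upper element

  // "Lower" lanes (bit log2(T) of tid clear) hand over a1, "upper" lanes
  // hand over a0; a single XOR shuffle at distance T, within groups of
  // width 2 * T, swaps the two values between partner lanes.
  bool is_lower = ((tid >> ilog2u(T)) & 1) == 0;
  double tmp = is_lower ? a1 : a0;
  tmp = __shfl_xor_sync(0xFFFFFFFF, tmp, T, 2 * T);
  a0 = is_lower ? a0 : tmp;
  a1 = is_lower ? tmp : a1;

  // Each lane now holds the contiguous pair the next butterfly stage needs,
  // matching the thread 0 / thread 1 diagrams in the kernel comments above.
  buf[tid] = a0;
  buf[tid + 32] = a1;
}

Launching pair_exchange_demo<8><<<1, 32>>>(d_buf) on a 64-double buffer and diffing the result against a CPU permutation is a cheap way to sanity-check those diagrams.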