Skip to content

Commit

Permalink
finally working with lvl1param!
Browse files: browse the repository at this point in the history
  • Loading branch information
nindanaoto committed Jun 20, 2023
1 parent 125b24b commit a32cf4b
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 50 deletions.
8 changes: 4 additions & 4 deletions include/gatebootstrapping_gpu.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ __device__ constexpr typename P::T offsetgen()
}

template <class P>
__device__ inline void RotatedTestVector(TFHEpp::lvl1param::T* tlwe,
__device__ inline void RotatedTestVector(typename P::T* tlwe,
const int32_t bar,
const typename P::T μ)
{
Expand Down Expand Up @@ -202,13 +202,13 @@ __device__ inline void __BlindRotatePreAdd__(typename P::targetP::T* const out,
{
const uint32_t bar =
2 * P::targetP::n -
modSwitchFromTorus<P::targetP>(offset + casign * in0[P::domainP::k*P::domainP::n] +
modSwitchFromTorus<typename P::targetP>(offset + casign * in0[P::domainP::k*P::domainP::n] +
cbsign * in1[P::domainP::k*P::domainP::n]);
RotatedTestVector<P::targetP>(out, bar, P::targetP::μ);
RotatedTestVector<typename P::targetP>(out, bar, P::targetP::μ);
}

// accumulate
for (int i = 0; i < P::domainP::n; i++) { // lvl0param::n iterations
for (int i = 0; i < P::domainP::k*P::domainP::n; i++) { // lvl0param::n iterations
const uint32_t bar = modSwitchFromTorus<P::targetP>(0 + casign * in0[i] +
cbsign * in1[i]);
Accumulate<P>(out, sh_acc_ntt, bar,
Expand Down
4 changes: 2 additions & 2 deletions include/keyswitch_gpu.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ __device__ inline void KeySwitch(typename P::targetP::T* const lwe,
const uint32_t bdim = ThisBlockSize();
for (int i = tid; i <= P::targetP::k*P::targetP::n; i += bdim) {
typename P::targetP::T res = 0;
if (i == P::targetP::n) res = tlwe[P::domainP::n];
if (i == P::targetP::k*P::targetP::n) res = tlwe[P::domainP::k*P::domainP::n];
for (int j = 0; j < P::domainP::k*P::domainP::n; j++) {
typename P::domainP::T tmp;
if (j == 0)
Expand Down Expand Up @@ -66,7 +66,7 @@ __device__ inline void IdentityKeySwitchPreAdd(typename P::targetP::T* const lwe
if (i == P::targetP::k*P::targetP::n) res = casign*ina[P::domainP::k*P::domainP::n]+ cbsign*inb[P::domainP::k*P::domainP::n] + offset;
for (int j = 0; j < P::domainP::k*P::domainP::n; j++) {
typename P::domainP::T tmp;
tmp = casign*ina[j]+ cbsign*inb[j] + decomp_offset;
tmp = casign*ina[j]+ cbsign*inb[j] + 0 + decomp_offset;
for (int k = 0; k < P::t; k++) {
typename P::domainP::T val =
(tmp >>
Expand Down
44 changes: 27 additions & 17 deletions src/bootstrap_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <include/ntt_gpu/ntt.cuh>
#include <limits>
#include <vector>
#include <algorithm>

namespace cufhe {
template<class P = TFHEpp::lvl1param>
Expand Down Expand Up @@ -321,8 +322,8 @@ __device__ inline void __SampleExtractIndex__(typename P::T* const res, const ty
}else {
const uint k = i >> P::nbit;
const uint n = i & nmask;
if (n <= index) res[index] = in[k*P::n + index - n];
else res[index] = -in[k*P::n + P::n + index-n];
if (n <= index) res[i] = in[k*P::n + index - n];
else res[i] = -in[k*P::n + P::n + index-n];
}
}
}
Expand All @@ -335,9 +336,12 @@ __device__ inline void __HomGate__(typename brP::targetP::T* const out,
const CuNTTHandler<> ntt)
{
__shared__ typename iksP::targetP::T tlwe[iksP::targetP::k*iksP::targetP::n+1];
__shared__ typename brP::targetP::T trlwe[(brP::targetP::k+1)*brP::targetP::n];

IdentityKeySwitchPreAdd<iksP, casign, cbsign, offset>(tlwe, in0, in1, ksk);
__syncthreads();

__shared__ typename brP::targetP::T trlwe[(brP::targetP::k+1)*brP::targetP::n];

__BlindRotate__<brP>(trlwe, tlwe, μ, bk,ntt);
__SampleExtractIndex__<typename brP::targetP,0>(out,trlwe);
__threadfence();
Expand Down Expand Up @@ -601,8 +605,10 @@ template<class P>
__global__ __launch_bounds__(NUM_THREAD4HOMGATE) void __CopyBootstrap__(
typename P::T* const out, const typename P::T* const in)
{
const uint32_t tid = ThisThreadRankInBlock();
out[tid] = in[tid];
const uint tid = ThisThreadRankInBlock();
const uint bdim = ThisBlockSize();
for (int i = tid; i <= P::k*P::n; i += bdim)
out[i] = in[i];
__syncthreads();
__threadfence();
}
Expand All @@ -611,8 +617,10 @@ template<class P>
__global__ __launch_bounds__(NUM_THREAD4HOMGATE) void __NotBootstrap__(
typename P::T* const out, const typename P::T* const in)
{
const uint32_t tid = ThisThreadRankInBlock();
out[tid] = -in[tid];
const uint tid = ThisThreadRankInBlock();
const uint bdim = ThisBlockSize();
for (int i = tid; i <= P::k*P::n; i += bdim)
out[i] = -in[i];
__syncthreads();
__threadfence();
}
Expand All @@ -627,17 +635,18 @@ __global__ __launch_bounds__(NUM_THREAD4HOMGATE) void __MuxBootstrap__(
__shared__ typename iksP::targetP::T tlwelvl0[iksP::targetP::k*iksP::targetP::n+1];

IdentityKeySwitchPreAdd<iksP, 1, 1, -iksP::domainP::μ>(tlwelvl0, inc, in1, ksk);
__threadfence();
__syncthreads();
__shared__ typename brP::targetP::T tlwe1[(brP::targetP::k+1)*brP::targetP::n];
__BlindRotate__<brP>(tlwe1,tlwelvl0,μ,bk,ntt);
__SampleExtractIndex__<typename brP::targetP,0>(out, tlwe1);

IdentityKeySwitchPreAdd<iksP, 1, 1, -iksP::domainP::μ>(tlwelvl0, inc, in0, ksk);
__threadfence();
IdentityKeySwitchPreAdd<iksP, -1, 1, -iksP::domainP::μ>(tlwelvl0, inc, in0, ksk);
__syncthreads();
__shared__ typename brP::targetP::T tlwe0[(brP::targetP::k+1)*brP::targetP::n];
__BlindRotate__<brP>(tlwe0,tlwelvl0,μ,bk,ntt);
__SampleExtractIndex__<typename brP::targetP,0>(tlwe1, tlwe0);
__threadfence();

__syncthreads();

volatile const uint32_t tid = ThisThreadRankInBlock();
volatile const uint32_t bdim = ThisBlockSize();
Expand All @@ -661,18 +670,19 @@ __global__ __launch_bounds__(NUM_THREAD4HOMGATE) void __NMuxBootstrap__(
__shared__ typename iksP::targetP::T tlwelvl0[iksP::targetP::k*iksP::targetP::n+1];

IdentityKeySwitchPreAdd<iksP, 1, 1, -iksP::domainP::μ>(tlwelvl0, inc, in1, ksk);
__threadfence();
__syncthreads();
__shared__ typename brP::targetP::T tlwe1[(brP::targetP::k+1)*brP::targetP::n];
__BlindRotate__<brP>(tlwe1,tlwelvl0,μ,bk,ntt);
__SampleExtractIndex__<typename brP::targetP,0>(out, tlwe1);

IdentityKeySwitchPreAdd<iksP, 1, 1, -iksP::domainP::μ>(tlwelvl0, inc, in0, ksk);
__threadfence();
IdentityKeySwitchPreAdd<iksP, -1, 1, -iksP::domainP::μ>(tlwelvl0, inc, in0, ksk);
__syncthreads();
__shared__ typename brP::targetP::T tlwe0[(brP::targetP::k+1)*brP::targetP::n];
__BlindRotate__<brP>(tlwe0,tlwelvl0,μ,bk,ntt);
__SampleExtractIndex__<typename brP::targetP,0>(tlwe1, tlwe0);

__threadfence();
__syncthreads();


volatile const uint32_t tid = ThisThreadRankInBlock();
volatile const uint32_t bdim = ThisBlockSize();
Expand Down Expand Up @@ -1090,7 +1100,7 @@ template<class P>
void CopyBootstrap(typename P::T* const out, const typename P::T* const in,
const cudaStream_t st, const int gpuNum)
{
__CopyBootstrap__<P><<<1, P::n + 1, 0, st>>>(out, in);
__CopyBootstrap__<P><<<1, std::min(P::n + 1,NUM_THREAD4HOMGATE), 0, st>>>(out, in);
CuCheckError();
}
#define INST(P) \
Expand All @@ -1104,7 +1114,7 @@ template<class P>
void NotBootstrap(typename P::T* const out, const typename P::T* const in,
const cudaStream_t st, const int gpuNum)
{
__NotBootstrap__<P><<<1, P::n + 1, 0, st>>>(out, in);
__NotBootstrap__<P><<<1, std::min(P::n + 1,NUM_THREAD4HOMGATE), 0, st>>>(out, in);
CuCheckError();
}
#define INST(P) \
Expand Down
Loading

0 comments on commit a32cf4b

Please sign in to comment.