Skip to content

Commit

Permalink
feat(integer): improve shift/rotate by encrypted amount
Browse files Browse the repository at this point in the history
This commit does a few things:
* Changes the BitExtractor to use many_lut to reduce number of PBS
  done
* Add blocks rotation/shift operation
* Implement a new algorithm for bit shift/rotation by encrypted amounts
* Add support bit shift/rotation for 1_1 parameters (as result of adding
  block shift/rotation)

The gist of the new bit shift/rotation is to use the same idea as the scalar
version where we first shift blocks between adjacent blocks,
then use a rotation of blocks.

Doing this requires to do a division and modulo operation:
```rust
let (shift_within_blocks, block_rotations) =
  (amount / bits_per_block, amount % bits_per_block)
```
When `amount` is clear this operation is simple, when `amount` is
encrypted then is harder (`bits_per_block` is always clear).
However, when bits_per_block is a power of 2 (e.g 1, 2, 4) `/` and `%`
can be made by shifting and bit-masking, which are simple operations.

This means the new algorithm is only compatible with 1_1, 2_2, 4_4 but
not 3_3.
The new algorithm improves the latency as well as the throughput as
it requires less PBS in total
  • Loading branch information
tmontaigu committed Oct 14, 2024
1 parent e376049 commit ac71973
Show file tree
Hide file tree
Showing 9 changed files with 959 additions and 349 deletions.
169 changes: 140 additions & 29 deletions tfhe/src/integer/server_key/radix_parallel/bit_extractor.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,45 @@
use crate::integer::ServerKey;
use crate::shortint::server_key::LookupTableOwned;
use crate::shortint::server_key::ManyLookupTableOwned;
use crate::shortint::Ciphertext;
use itertools::iproduct;
use rayon::prelude::*;
use std::collections::VecDeque;

/// Extracts bits from a slice of shortint ciphertext
///
/// * Relies on many-lut to do less PBSes
/// * Position of the extracted bit is customizable
/// * Has a buffer system that allows to pre-generate extracted bits such that the next call to
/// `next()` is free, giving better control as to where computations happen
pub(crate) struct BitExtractor<'a> {
bit_extract_luts: Vec<LookupTableOwned>,
input_blocks: std::slice::Iter<'a, Ciphertext>,
bit_extract_luts: Option<ManyLookupTableOwned>,
bits_per_block: usize,
server_key: &'a ServerKey,
buffer: VecDeque<Ciphertext>,
}

impl<'a> Iterator for BitExtractor<'a> {
type Item = Ciphertext;

fn next(&mut self) -> Option<Self::Item> {
let maybe_bit_block = self.buffer.pop_front();
if maybe_bit_block.is_some() {
return maybe_bit_block;
}

self.prepare_next_batch();

self.buffer.pop_front()
}
}

impl<'a> BitExtractor<'a> {
pub(crate) fn new(server_key: &'a ServerKey, bits_per_block: usize) -> Self {
Self::with_final_offset(server_key, bits_per_block, 0)
pub(crate) fn new(
input: &'a [Ciphertext],
server_key: &'a ServerKey,
bits_per_block: usize,
) -> Self {
Self::with_final_offset(input, server_key, bits_per_block, 0)
}

/// Creates a bit extractor that will extract bits from an input ciphertext
Expand All @@ -23,46 +50,130 @@ impl<'a> BitExtractor<'a> {
/// It may be used to align the bit with a certain position to avoid
/// shifting it later (and increasing noise)
pub(crate) fn with_final_offset(
input: &'a [Ciphertext],
server_key: &'a ServerKey,
bits_per_block: usize,
final_offset: usize,
) -> Self {
let bit_extract_luts = (0..bits_per_block)
.into_par_iter()
.map(|i| {
server_key.key.generate_lookup_table(|x| {
let bit_value = (x >> i) & 1;
bit_value << final_offset
assert_eq!(
server_key.message_modulus().0,
server_key.carry_modulus().0,
"BitExtractor requires parameters with carry modulus == message modulus"
);
let bit_extract_luts = if bits_per_block == 1 && final_offset == 0 {
None
} else {
let bit_extract_fns = (0..bits_per_block)
.into_par_iter()
.map(|i| {
move |x: u64| {
let bit_value = (x >> i) & 1u64;
bit_value << final_offset
}
})
})
.collect::<Vec<_>>();
.collect::<Vec<_>>();

let tmp = bit_extract_fns
.iter()
.map(|func| func as &dyn Fn(u64) -> u64)
.collect::<Vec<_>>();

Some(server_key.key.generate_many_lookup_table(tmp.as_slice()))
};

Self {
input_blocks: input.iter(),
bit_extract_luts,
bits_per_block,
server_key,
buffer: VecDeque::with_capacity(2 * bits_per_block),
}
}

pub(crate) fn extract_all_bits(&self, blocks: &[Ciphertext]) -> Vec<Ciphertext> {
let num_blocks = blocks.len();
self.extract_n_bits(blocks, num_blocks * self.bits_per_block)
pub(crate) fn set_source_blocks(&mut self, blocks: &'a [Ciphertext]) {
self.input_blocks = blocks.iter();
self.buffer.clear();
}

pub(crate) fn extract_n_bits(&self, blocks: &[Ciphertext], n: usize) -> Vec<Ciphertext> {
let num_blocks = blocks.len();
let mut bits = Vec::with_capacity(n);
let jobs = iproduct!(0..num_blocks, 0..self.bits_per_block)
.take(n)
.collect::<Vec<_>>();
jobs.into_par_iter()
.map(|(block_index, bit_index)| {
let block = &blocks[block_index];
let lut = &self.bit_extract_luts[bit_index];
self.server_key.key.apply_lookup_table(block, lut)
})
.collect_into_vec(&mut bits);
pub(crate) fn prepare_next_batch(&mut self) {
if self.buffer.is_empty() {
let Some(next_block_to_extract_from) = self.input_blocks.next() else {
return;
};

match &self.bit_extract_luts {
None => self.buffer.push_back(next_block_to_extract_from.clone()),
Some(bit_extract_luts) => {
let new_bits = self
.server_key
.key
.apply_many_lookup_table(next_block_to_extract_from, bit_extract_luts);

self.buffer.extend(new_bits);
}
}
}
}

pub(crate) fn prepare_all_bits(&mut self) {
let num_blocks = self.input_blocks.len();
self.prepare_n_bits(num_blocks * self.bits_per_block)
}

/// Extract the remaining `n` bits in parallel from the current source blocks
/// and place them into the internal buffer
///
/// # Panics
///
/// Panics if the current slice of blocks has less than n bits available
pub(crate) fn prepare_n_bits(&mut self, n: usize) {
if self.buffer.len() >= n {
return;
}

let num_bits_to_extract = n - self.buffer.len();
let num_blocks_to_process = num_bits_to_extract.div_ceil(self.bits_per_block);
let blocks = self.input_blocks.as_slice();

if let Some(bit_extract_luts) = &self.bit_extract_luts {
let mut new_bits = blocks[..num_blocks_to_process]
.par_iter()
.flat_map(|block| {
self.server_key
.key
.apply_many_lookup_table(block, bit_extract_luts)
})
.collect::<Vec<_>>();
self.buffer.extend(new_bits.drain(..));
} else {
let iterator = blocks[..num_blocks_to_process].iter().cloned();
self.buffer.extend(iterator);
};

// We have to advance our internal iterator
self.input_blocks = blocks[num_blocks_to_process..].iter();
}

/// Extract all the remaining bits in parallel from the current source blocks
pub(crate) fn extract_all_bits(&mut self) -> Vec<Ciphertext> {
let num_blocks = self.input_blocks.len();
self.extract_n_bits(num_blocks * self.bits_per_block)
}

/// Extract the remaining `n` bits in parallel from the current source blocks
///
/// # Panics
///
/// Panics if the current slice of blocks has less than n bits available
pub(crate) fn extract_n_bits(&mut self, n: usize) -> Vec<Ciphertext> {
self.prepare_n_bits(n);

let mut bits = Vec::with_capacity(n);
bits.extend(self.buffer.drain(0..n));
bits
}

pub(crate) fn current_buffer_len(&self) -> usize {
self.buffer.len()
}
}
Loading

0 comments on commit ac71973

Please sign in to comment.