Skip to content

Commit

Permalink
Merge pull request #30 from zama-ai/davidk/euint4
Browse files Browse the repository at this point in the history
Adds 4 bit integer support to coprocessor
  • Loading branch information
david-zk authored Sep 16, 2024
2 parents 1e71941 + f700b47 commit 7956808
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 19 deletions.
1 change: 1 addition & 0 deletions fhevm-engine/coprocessor/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mod errors;
mod inputs;
mod operators;
mod utils;
mod random;

#[tokio::test]
async fn test_smoke() -> Result<(), Box<dyn std::error::Error>> {
Expand Down
23 changes: 13 additions & 10 deletions fhevm-engine/coprocessor/src/tests/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ struct UnaryOperatorTestCase {
}

fn supported_bits() -> &'static [i32] {
&[8, 16, 32, 64, 128, 160, 256, 512, 1024, 2048]
&[4, 8, 16, 32, 64, 128, 160, 256, 512, 1024, 2048]
}

pub fn supported_types() -> &'static [i32] {
&[
0, // bool
// 1, TODO: add 4 bit support
0, // bool
1, // 4 bit
2, // 8 bit
3, // 16 bit
4, // 32 bit
Expand All @@ -61,6 +61,7 @@ pub fn supported_types() -> &'static [i32] {

fn supported_bits_to_bit_type_in_db(inp: i32) -> i32 {
match inp {
4 => 1,
8 => 2,
16 => 3,
32 => 4,
Expand Down Expand Up @@ -620,12 +621,12 @@ fn generate_binary_test_cases() -> Vec<BinaryOperatorTestCase> {
SupportedFheOperations::FheRotr,
];
let mut push_case = |bits: i32, is_scalar: bool, shift_by: i32, op: SupportedFheOperations| {
let mut lhs = BigInt::from(12);
let mut rhs = BigInt::from(7);
let mut lhs = BigInt::from(6);
let mut rhs = BigInt::from(2);
lhs <<= shift_by;
// don't shift by much for bit shift opts not to make result 0
if bit_shift_ops.contains(&op) {
rhs = BigInt::from(2);
rhs = BigInt::from(1);
} else {
rhs <<= shift_by;
}
Expand All @@ -651,7 +652,8 @@ fn generate_binary_test_cases() -> Vec<BinaryOperatorTestCase> {

for bits in supported_bits() {
let bits = *bits;
let mut shift_by = bits - 8;
let mut shift_by =
if bits > 4 { bits - 8 } else { 0 };
for op in SupportedFheOperations::iter() {
if bits <= 256 || op.supports_ebytes_inputs() {
if op == SupportedFheOperations::FheMul {
Expand Down Expand Up @@ -679,12 +681,13 @@ fn generate_unary_test_cases() -> Vec<UnaryOperatorTestCase> {

for bits in supported_bits() {
let bits = *bits;
let shift_by = bits - 8;
let shift_by = bits - 3;
let max_bits_value = (BigInt::from(1) << bits) - 1;
for op in SupportedFheOperations::iter() {
if op.op_type() == FheOperationType::Unary {
let mut inp = BigInt::from(7);
let mut inp = BigInt::from(3);
inp <<= shift_by;
let expected_output = compute_expected_unary_output(&inp, op);
let expected_output = compute_expected_unary_output(&inp, op) & &max_bits_value;
let operand = op as i32;
cases.push(UnaryOperatorTestCase {
bits,
Expand Down
11 changes: 7 additions & 4 deletions fhevm-engine/coprocessor/src/tests/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
},
},
tests::utils::{
decrypt_ciphertexts, default_api_key, random_handle_start, setup_test_app,
decrypt_ciphertexts, default_api_key, random_handle, setup_test_app,
wait_until_all_ciphertexts_computed, DecryptionResult,
},
};
Expand All @@ -28,7 +28,7 @@ async fn test_fhe_random_basic() -> Result<(), Box<dyn std::error::Error>> {
.await?;
let mut client = FhevmCoprocessorClient::connect(app.app_url().to_string()).await?;

let mut handle_counter = random_handle_start();
let mut handle_counter = random_handle();
let mut next_handle = || {
let out: u64 = handle_counter;
handle_counter += 1;
Expand Down Expand Up @@ -117,6 +117,7 @@ async fn test_fhe_random_basic() -> Result<(), Box<dyn std::error::Error>> {
let resp = decrypt_ciphertexts(&pool, 1, decrypt_request).await?;
let expected: Vec<DecryptionResult> = vec![
DecryptionResult { value: "true".to_string(), output_type: 0 },
DecryptionResult { value: "15".to_string(), output_type: 1 },
DecryptionResult { value: "191".to_string(), output_type: 2 },
DecryptionResult { value: "31935".to_string(), output_type: 3 },
DecryptionResult { value: "50166975".to_string(), output_type: 4 },
Expand Down Expand Up @@ -160,7 +161,7 @@ async fn test_fhe_random_bounded() -> Result<(), Box<dyn std::error::Error>> {
.await?;
let mut client = FhevmCoprocessorClient::connect(app.app_url().to_string()).await?;

let mut handle_counter = random_handle_start();
let mut handle_counter = random_handle();
let mut next_handle = || {
let out: u64 = handle_counter;
handle_counter += 1;
Expand All @@ -174,7 +175,8 @@ async fn test_fhe_random_bounded() -> Result<(), Box<dyn std::error::Error>> {

let deterministic_seed = 123u8;
let bounds = [
"8",
"2",
"4",
"128",
"16384",
"1073741824",
Expand All @@ -185,6 +187,7 @@ async fn test_fhe_random_bounded() -> Result<(), Box<dyn std::error::Error>> {
];
let results = [
"true",
"3",
"127",
"15551",
"50166975",
Expand Down
2 changes: 1 addition & 1 deletion fhevm-engine/coprocessor/src/tests/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::cli::Args;
use fhevm_engine_common::tfhe_ops::{current_ciphertext_version, deserialize_fhe_ciphertext};
use rand::{Rng, RngCore};
use rand::Rng;
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicU16, Ordering};
use testcontainers::{core::WaitFor, runners::AsyncRunner, GenericImage, ImageExt};
Expand Down
Loading

0 comments on commit 7956808

Please sign in to comment.