Skip to content

Commit

Permalink
Define unary operator tests
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zk committed Aug 12, 2024
1 parent db74ffd commit 8fc0e41
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 33 deletions.
2 changes: 1 addition & 1 deletion fhevm-engine/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ

// check before we insert computation that it has
// to succeed according to the type system
let output_type = check_fhe_operand_types(comp.operation, &handle_types, comp.is_scalar)?;
let output_type = check_fhe_operand_types(comp.operation, &handle_types, comp.is_scalar, &comp.input_handles)?;
// fill in types with output handles that are computed as we go
assert!(ct_types.insert(comp.output_handle.clone(), output_type).is_none());
}
Expand Down
2 changes: 1 addition & 1 deletion fhevm-engine/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async fn test_smoke() -> Result<(), Box<dyn std::error::Error>> {

let mut client = FhevmCoprocessorClient::connect(app.app_url().to_string()).await?;

let api_key_header = format!("Bearer {}", default_api_key());
let api_key_header = format!("bearer {}", default_api_key());
let ct_type = 4; // i32

// encrypt two ciphertexts
Expand Down
205 changes: 183 additions & 22 deletions fhevm-engine/src/tests/operators.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use bigdecimal::num_bigint::BigInt;
use strum::IntoEnumIterator;
use tonic::metadata::MetadataValue;
use std::str::FromStr;
use std::{ops::Not, str::FromStr};
use crate::{tests::utils::{setup_test_app, default_api_key}, tfhe_ops::{does_fhe_operation_support_both_encrypted_operands, does_fhe_operation_support_scalar}, types::{FheOperationType, SupportedFheOperations}};
use crate::server::coprocessor::fhevm_coprocessor_client::FhevmCoprocessorClient;
use crate::server::coprocessor::{AsyncComputation, AsyncComputeRequest, DebugDecryptRequest, DebugEncryptRequest, DebugEncryptRequestSingle};
Expand All @@ -20,6 +20,9 @@ struct BinaryOperatorTestCase {
struct UnaryOperatorTestCase {
bits: i32,
inp: BigInt,
operand: i32,
operand_types: i32,
expected_output: BigInt,
}

fn supported_bits() -> &'static [i32] {
Expand Down Expand Up @@ -60,6 +63,8 @@ async fn test_fhe_binary_operands() -> Result<(), Box<dyn std::error::Error>> {
let api_key_header = format!("bearer {}", default_api_key());

let mut output_handles = Vec::with_capacity(ops.len());
let mut enc_request_payload = Vec::with_capacity(ops.len() * 2);
let mut async_computations = Vec::with_capacity(ops.len());
for op in &ops {
let lhs_handle = next_handle();
let rhs_handle = if op.is_scalar {
Expand All @@ -73,13 +78,13 @@ async fn test_fhe_binary_operands() -> Result<(), Box<dyn std::error::Error>> {

println!("Encrypting inputs for binary test bits:{} op:{} is_scalar:{} lhs:{} rhs:{}",
op.bits, op.operand, op.is_scalar, op.lhs.to_string(), op.rhs.to_string());
let mut enc_request_payload = vec![
enc_request_payload.push(
DebugEncryptRequestSingle {
handle: lhs_handle.clone(),
le_value: lhs_bytes,
output_type: op.operand_types,
},
];
}
);
if !op.is_scalar {
let (_, rhs_bytes) = op.rhs.to_bytes_le();
enc_request_payload.push(DebugEncryptRequestSingle {
Expand All @@ -88,32 +93,35 @@ async fn test_fhe_binary_operands() -> Result<(), Box<dyn std::error::Error>> {
output_type: op.operand_types,
});
}
let mut encrypt_request = tonic::Request::new(DebugEncryptRequest {
values: enc_request_payload,
});
encrypt_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap());
let _resp = client.debug_encrypt_ciphertext(encrypt_request).await?;

println!("rhs handle:{}", rhs_handle);
println!("Scheduling computation for binary test bits:{} op:{} is_scalar:{} lhs:{} rhs:{} output:{}",
op.bits, op.operand, op.is_scalar, op.lhs.to_string(), op.rhs.to_string(), op.expected_output.to_string());
let mut compute_request = tonic::Request::new(AsyncComputeRequest {
computations: vec![
AsyncComputation {
operation: op.operand,
is_scalar: op.is_scalar,
output_handle: output_handle,
input_handles: vec![
lhs_handle.clone(),
rhs_handle.clone(),
]
},
async_computations.push(AsyncComputation {
operation: op.operand,
is_scalar: op.is_scalar,
output_handle: output_handle,
input_handles: vec![
lhs_handle.clone(),
rhs_handle.clone(),
]
});
compute_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap());
let _resp = client.async_compute(compute_request).await?;
}

println!("Encrypting inputs...");
let mut encrypt_request = tonic::Request::new(DebugEncryptRequest {
values: enc_request_payload,
});
encrypt_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap());
let _resp = client.debug_encrypt_ciphertext(encrypt_request).await?;

println!("Scheduling computations...");
let mut compute_request = tonic::Request::new(AsyncComputeRequest {
computations: async_computations
});
compute_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap());
let _resp = client.async_compute(compute_request).await?;

println!("Computations scheduled, waiting upon completion...");

loop {
Expand Down Expand Up @@ -149,6 +157,107 @@ async fn test_fhe_binary_operands() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

#[tokio::test]
async fn test_fhe_unary_operands() -> Result<(), Box<dyn std::error::Error>> {
let ops = generate_unary_test_cases();
let app = setup_test_app().await?;
let mut client = FhevmCoprocessorClient::connect(app.app_url().to_string()).await?;
// needed for polling status
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(2)
.connect(app.db_url())
.await?;

let mut handle_counter = 0;
let mut next_handle = || {
let out = handle_counter;
handle_counter += 1;
format!("{:#08x}", out)
};

let api_key_header = format!("bearer {}", default_api_key());

let mut output_handles = Vec::with_capacity(ops.len());
let mut enc_request_payload = Vec::with_capacity(ops.len() * 2);
let mut async_computations = Vec::with_capacity(ops.len());
for op in &ops {
let input_handle = next_handle();
let output_handle = next_handle();
output_handles.push(output_handle.clone());

let (_, inp_bytes) = op.inp.to_bytes_le();

println!("Encrypting inputs for unary test bits:{} op:{} input:{}",
op.bits, op.operand, op.inp.to_string());
enc_request_payload.push(
DebugEncryptRequestSingle {
handle: input_handle.clone(),
le_value: inp_bytes,
output_type: op.operand_types,
}
);

println!("Scheduling computation for binary test bits:{} op:{} input:{} output:{}",
op.bits, op.operand, op.inp.to_string(), op.expected_output.to_string());
async_computations.push(AsyncComputation {
operation: op.operand,
is_scalar: false,
output_handle: output_handle,
input_handles: vec![
input_handle.clone(),
]
});
}

println!("Encrypting inputs...");
let mut encrypt_request = tonic::Request::new(DebugEncryptRequest {
values: enc_request_payload,
});
encrypt_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap());
let _resp = client.debug_encrypt_ciphertext(encrypt_request).await?;

println!("Scheduling computations...");
let mut compute_request = tonic::Request::new(AsyncComputeRequest {
computations: async_computations
});
compute_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap());
let _resp = client.async_compute(compute_request).await?;

println!("Computations scheduled, waiting upon completion...");

loop {
tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
let count =
sqlx::query!("SELECT count(*) FROM computations WHERE NOT is_completed AND NOT is_error")
.fetch_one(&pool)
.await?;
let current_count = count.count.unwrap();
if current_count == 0 {
println!("All computations completed");
break;
} else {
println!("{current_count} computations remaining, waiting...");
}
}

let mut decrypt_request = tonic::Request::new(DebugDecryptRequest {
handles: output_handles.clone(),
});
decrypt_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap());
let resp = client.debug_decrypt_ciphertext(decrypt_request).await?;

assert_eq!(resp.get_ref().values.len(), output_handles.len(), "Outputs length doesn't match");
for (idx, op) in ops.iter().enumerate() {
let decr_response = &resp.get_ref().values[idx];
println!("Checking computation for binary test bits:{} op:{} input:{} output:{}",
op.bits, op.operand, op.inp.to_string(), op.expected_output.to_string());
assert_eq!(decr_response.output_type, op.operand_types, "operand types not equal");
assert_eq!(decr_response.value, op.expected_output.to_string(), "operand output values not equal");
}

Ok(())
}

fn generate_binary_test_cases() -> Vec<BinaryOperatorTestCase> {
let mut cases = Vec::new();
let mut push_case = |bits: i32, is_scalar: bool, shift_by: i32, op: SupportedFheOperations| {
Expand Down Expand Up @@ -192,6 +301,58 @@ fn generate_binary_test_cases() -> Vec<BinaryOperatorTestCase> {
cases
}

fn generate_unary_test_cases() -> Vec<UnaryOperatorTestCase> {
let mut cases = Vec::new();

for bits in supported_bits() {
let bits = *bits;
let shift_by = bits - 8;
for op in SupportedFheOperations::iter() {
if op.op_type() == FheOperationType::Unary {
let mut inp = BigInt::from(7);
inp <<= shift_by;
let expected_output = compute_expected_unary_output(&inp, op, bits);
let operand = op as i32;
cases.push(UnaryOperatorTestCase {
bits,
operand,
operand_types: supported_bits_to_bit_type_in_db(bits),
inp,
expected_output,
});
}
}
}

cases
}

fn compute_expected_unary_output(inp: &BigInt, op: SupportedFheOperations, bits: i32) -> BigInt {
match op {
SupportedFheOperations::FheNot => {
// TODO: find how this is done appropriately in big int crate
match bits {
8 => {
let inp: u8 = inp.try_into().unwrap();
BigInt::from(inp.not())
}
16 => {
let inp: u16 = inp.try_into().unwrap();
BigInt::from(inp.not())
}
32 => {
let inp: u32 = inp.try_into().unwrap();
BigInt::from(inp.not())
}
other => {
panic!("unknown bits: {other}")
}
}
},
other => panic!("unsupported binary operation: {:?}", other),
}
}

fn compute_expected_binary_output(lhs: &BigInt, rhs: &BigInt, op: SupportedFheOperations) -> BigInt {
match op {
SupportedFheOperations::FheAdd => lhs + rhs,
Expand Down
35 changes: 32 additions & 3 deletions fhevm-engine/src/tfhe_ops.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use tfhe::{prelude::FheTryTrivialEncrypt, FheBool, FheUint16, FheUint32, FheUint8};

use crate::types::{CoprocessorError, FheOperationType, SupportedFheCiphertexts, SupportedFheOperations};
use crate::{types::{CoprocessorError, FheOperationType, SupportedFheCiphertexts, SupportedFheOperations}, utils::check_if_handle_is_zero};

pub fn current_ciphertext_version() -> i16 {
1
Expand Down Expand Up @@ -154,7 +154,24 @@ pub fn perform_fhe_operation(fhe_operation: i16, input_operands: &[SupportedFheC
}
}
},
SupportedFheOperations::FheNot => todo!(),
SupportedFheOperations::FheNot => {
assert_eq!(input_operands.len(), 1);

match &input_operands[0] {
SupportedFheCiphertexts::FheUint8(a) => {
Ok(SupportedFheCiphertexts::FheUint8(!a))
}
SupportedFheCiphertexts::FheUint16(a) => {
Ok(SupportedFheCiphertexts::FheUint16(!a))
}
SupportedFheCiphertexts::FheUint32(a) => {
Ok(SupportedFheCiphertexts::FheUint32(!a))
}
_ => {
panic!("Unsupported fhe types");
}
}
},
SupportedFheOperations::FheIfThenElse => todo!(),
}
}
Expand Down Expand Up @@ -213,7 +230,7 @@ pub fn deserialize_fhe_ciphertext(input_type: i16, input_bytes: &[u8]) -> Result
}

// return output ciphertext type
pub fn check_fhe_operand_types(fhe_operation: i32, input_types: &[i16], is_scalar: bool) -> Result<i16, CoprocessorError> {
pub fn check_fhe_operand_types(fhe_operation: i32, input_types: &[i16], is_scalar: bool, input_handles: &[String]) -> Result<i16, CoprocessorError> {
let fhe_op: SupportedFheOperations = fhe_operation.try_into()?;

if is_scalar && !does_fhe_operation_support_scalar(&fhe_op) {
Expand Down Expand Up @@ -245,6 +262,18 @@ pub fn check_fhe_operand_types(fhe_operation: i32, input_types: &[i16], is_scala
});
}

// special case for div operation
if is_scalar && fhe_op == SupportedFheOperations::FheDiv {
if check_if_handle_is_zero(input_handles[1].as_str()) {
return Err(CoprocessorError::FheOperationScalarDivisionByZero {
lhs_handle: input_handles[0].clone(),
rhs_value: input_handles[1].clone(),
fhe_operation,
fhe_operation_name: format!("{:?}", SupportedFheOperations::FheDiv),
});
}
}

return Ok(input_types[0]);
}
FheOperationType::Unary => {
Expand Down
16 changes: 10 additions & 6 deletions fhevm-engine/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ pub enum CoprocessorError {
fhe_operation_name: String,
operand_types: Vec<i16>,
},
// TODO: implement scalar division by zero error
// FheOperationScalarDivisionByZero {
// lhs_handle: String,
// fhe_operation: i32,
// fhe_operation_name: String,
// },
FheOperationScalarDivisionByZero {
lhs_handle: String,
rhs_value: String,
fhe_operation: i32,
fhe_operation_name: String,
},
}

impl std::fmt::Display for CoprocessorError {
Expand Down Expand Up @@ -101,6 +101,9 @@ impl std::fmt::Display for CoprocessorError {
CoprocessorError::TooManyCiphertextsInBatch { maximum_allowed, got } => {
write!(f, "maximum ciphertexts exceeded in batch, maximum: {maximum_allowed}, got: {got}")
},
CoprocessorError::FheOperationScalarDivisionByZero { lhs_handle, rhs_value, fhe_operation, fhe_operation_name } => {
write!(f, "zero on the right side of scalar division, lhs handle: {lhs_handle}, rhs value: {rhs_value}, fhe operation: {fhe_operation} fhe operation name:{fhe_operation_name}")
},
}
}
}
Expand Down Expand Up @@ -202,6 +205,7 @@ impl TryFrom<i16> for SupportedFheOperations {
1 => Ok(SupportedFheOperations::FheSub),
2 => Ok(SupportedFheOperations::FheMul),
3 => Ok(SupportedFheOperations::FheDiv),
4 => Ok(SupportedFheOperations::FheNot),
_ => Err(CoprocessorError::UnknownFheOperation(value as i32))
};

Expand Down
8 changes: 8 additions & 0 deletions fhevm-engine/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ use std::collections::{BTreeSet, HashMap, HashSet};
use lazy_static::lazy_static;
use crate::{server::coprocessor::AsyncComputation, types::CoprocessorError};

pub fn check_if_handle_is_zero(inp: &str) -> bool {
lazy_static! {
static ref TARGET_HANDLE_REGEX: regex::Regex = regex::Regex::new("^0x[0]+$").unwrap();
}

TARGET_HANDLE_REGEX.is_match(inp)
}

// handle must be serializable to bytes for scalar operations
pub fn check_valid_ciphertext_handle(inp: &str) -> Result<(), CoprocessorError> {
lazy_static! {
Expand Down
Loading

0 comments on commit 8fc0e41

Please sign in to comment.