From 8fc0e4188958e44426e8aeec285775ba0e059630 Mon Sep 17 00:00:00 2001 From: David Kazlauskas Date: Mon, 12 Aug 2024 07:54:47 +0300 Subject: [PATCH] Define unary operator tests --- fhevm-engine/src/server.rs | 2 +- fhevm-engine/src/tests/mod.rs | 2 +- fhevm-engine/src/tests/operators.rs | 205 +++++++++++++++++++++++++--- fhevm-engine/src/tfhe_ops.rs | 35 ++++- fhevm-engine/src/types.rs | 16 ++- fhevm-engine/src/utils.rs | 8 ++ proto/coprocessor.proto | 2 + 7 files changed, 237 insertions(+), 33 deletions(-) diff --git a/fhevm-engine/src/server.rs b/fhevm-engine/src/server.rs index 42fce0b1..01b74b22 100644 --- a/fhevm-engine/src/server.rs +++ b/fhevm-engine/src/server.rs @@ -230,7 +230,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ // check before we insert computation that it has // to succeed according to the type system - let output_type = check_fhe_operand_types(comp.operation, &handle_types, comp.is_scalar)?; + let output_type = check_fhe_operand_types(comp.operation, &handle_types, comp.is_scalar, &comp.input_handles)?; // fill in types with output handles that are computed as we go assert!(ct_types.insert(comp.output_handle.clone(), output_type).is_none()); } diff --git a/fhevm-engine/src/tests/mod.rs b/fhevm-engine/src/tests/mod.rs index f8044d96..cf31d4e6 100644 --- a/fhevm-engine/src/tests/mod.rs +++ b/fhevm-engine/src/tests/mod.rs @@ -14,7 +14,7 @@ async fn test_smoke() -> Result<(), Box> { let mut client = FhevmCoprocessorClient::connect(app.app_url().to_string()).await?; - let api_key_header = format!("Bearer {}", default_api_key()); + let api_key_header = format!("bearer {}", default_api_key()); let ct_type = 4; // i32 // encrypt two ciphertexts diff --git a/fhevm-engine/src/tests/operators.rs b/fhevm-engine/src/tests/operators.rs index 73aa9d1d..7df8a6e1 100644 --- a/fhevm-engine/src/tests/operators.rs +++ b/fhevm-engine/src/tests/operators.rs @@ -1,7 +1,7 @@ use bigdecimal::num_bigint::BigInt; use strum::IntoEnumIterator; use tonic::metadata::MetadataValue; -use std::str::FromStr; +use std::{ops::Not, str::FromStr}; use crate::{tests::utils::{setup_test_app, default_api_key}, tfhe_ops::{does_fhe_operation_support_both_encrypted_operands, does_fhe_operation_support_scalar}, types::{FheOperationType, SupportedFheOperations}}; use crate::server::coprocessor::fhevm_coprocessor_client::FhevmCoprocessorClient; use crate::server::coprocessor::{AsyncComputation, AsyncComputeRequest, DebugDecryptRequest, DebugEncryptRequest, DebugEncryptRequestSingle}; @@ -20,6 +20,9 @@ struct BinaryOperatorTestCase { struct UnaryOperatorTestCase { bits: i32, inp: BigInt, + operand: i32, + operand_types: i32, + expected_output: BigInt, } fn supported_bits() -> &'static [i32] { @@ -60,6 +63,8 @@ async fn test_fhe_binary_operands() -> Result<(), Box> { let api_key_header = format!("bearer {}", default_api_key()); let mut output_handles = Vec::with_capacity(ops.len()); + let mut enc_request_payload = Vec::with_capacity(ops.len() * 2); + let mut async_computations = Vec::with_capacity(ops.len()); for op in &ops { let lhs_handle = next_handle(); let rhs_handle = if op.is_scalar { @@ -73,13 +78,13 @@ async fn test_fhe_binary_operands() -> Result<(), Box> { println!("Encrypting inputs for binary test bits:{} op:{} is_scalar:{} lhs:{} rhs:{}", op.bits, op.operand, op.is_scalar, op.lhs.to_string(), op.rhs.to_string()); - let mut enc_request_payload = vec![ + enc_request_payload.push( DebugEncryptRequestSingle { handle: lhs_handle.clone(), le_value: lhs_bytes, output_type: op.operand_types, - }, - ]; + } + ); if !op.is_scalar { let (_, rhs_bytes) = op.rhs.to_bytes_le(); enc_request_payload.push(DebugEncryptRequestSingle { @@ -88,32 +93,35 @@ async fn test_fhe_binary_operands() -> Result<(), Box> { output_type: op.operand_types, }); } - let mut encrypt_request = tonic::Request::new(DebugEncryptRequest { - values: enc_request_payload, - }); - encrypt_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); - let _resp = client.debug_encrypt_ciphertext(encrypt_request).await?; println!("rhs handle:{}", rhs_handle); println!("Scheduling computation for binary test bits:{} op:{} is_scalar:{} lhs:{} rhs:{} output:{}", op.bits, op.operand, op.is_scalar, op.lhs.to_string(), op.rhs.to_string(), op.expected_output.to_string()); - let mut compute_request = tonic::Request::new(AsyncComputeRequest { - computations: vec![ - AsyncComputation { - operation: op.operand, - is_scalar: op.is_scalar, - output_handle: output_handle, - input_handles: vec![ - lhs_handle.clone(), - rhs_handle.clone(), - ] - }, + async_computations.push(AsyncComputation { + operation: op.operand, + is_scalar: op.is_scalar, + output_handle: output_handle, + input_handles: vec![ + lhs_handle.clone(), + rhs_handle.clone(), ] }); - compute_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); - let _resp = client.async_compute(compute_request).await?; } + println!("Encrypting inputs..."); + let mut encrypt_request = tonic::Request::new(DebugEncryptRequest { + values: enc_request_payload, + }); + encrypt_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); + let _resp = client.debug_encrypt_ciphertext(encrypt_request).await?; + + println!("Scheduling computations..."); + let mut compute_request = tonic::Request::new(AsyncComputeRequest { + computations: async_computations + }); + compute_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); + let _resp = client.async_compute(compute_request).await?; + println!("Computations scheduled, waiting upon completion..."); loop { @@ -149,6 +157,107 @@ async fn test_fhe_binary_operands() -> Result<(), Box> { Ok(()) } +#[tokio::test] +async fn test_fhe_unary_operands() -> Result<(), Box> { + let ops = generate_unary_test_cases(); + let app = setup_test_app().await?; + let mut client = FhevmCoprocessorClient::connect(app.app_url().to_string()).await?; + // needed for polling status + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(2) + .connect(app.db_url()) + .await?; + + let mut handle_counter = 0; + let mut next_handle = || { + let out = handle_counter; + handle_counter += 1; + format!("{:#08x}", out) + }; + + let api_key_header = format!("bearer {}", default_api_key()); + + let mut output_handles = Vec::with_capacity(ops.len()); + let mut enc_request_payload = Vec::with_capacity(ops.len() * 2); + let mut async_computations = Vec::with_capacity(ops.len()); + for op in &ops { + let input_handle = next_handle(); + let output_handle = next_handle(); + output_handles.push(output_handle.clone()); + + let (_, inp_bytes) = op.inp.to_bytes_le(); + + println!("Encrypting inputs for unary test bits:{} op:{} input:{}", + op.bits, op.operand, op.inp.to_string()); + enc_request_payload.push( + DebugEncryptRequestSingle { + handle: input_handle.clone(), + le_value: inp_bytes, + output_type: op.operand_types, + } + ); + + println!("Scheduling computation for binary test bits:{} op:{} input:{} output:{}", + op.bits, op.operand, op.inp.to_string(), op.expected_output.to_string()); + async_computations.push(AsyncComputation { + operation: op.operand, + is_scalar: false, + output_handle: output_handle, + input_handles: vec![ + input_handle.clone(), + ] + }); + } + + println!("Encrypting inputs..."); + let mut encrypt_request = tonic::Request::new(DebugEncryptRequest { + values: enc_request_payload, + }); + encrypt_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); + let _resp = client.debug_encrypt_ciphertext(encrypt_request).await?; + + println!("Scheduling computations..."); + let mut compute_request = tonic::Request::new(AsyncComputeRequest { + computations: async_computations + }); + compute_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); + let _resp = client.async_compute(compute_request).await?; + + println!("Computations scheduled, waiting upon completion..."); + + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + let count = + sqlx::query!("SELECT count(*) FROM computations WHERE NOT is_completed AND NOT is_error") + .fetch_one(&pool) + .await?; + let current_count = count.count.unwrap(); + if current_count == 0 { + println!("All computations completed"); + break; + } else { + println!("{current_count} computations remaining, waiting..."); + } + } + + let mut decrypt_request = tonic::Request::new(DebugDecryptRequest { + handles: output_handles.clone(), + }); + decrypt_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); + let resp = client.debug_decrypt_ciphertext(decrypt_request).await?; + + assert_eq!(resp.get_ref().values.len(), output_handles.len(), "Outputs length doesn't match"); + for (idx, op) in ops.iter().enumerate() { + let decr_response = &resp.get_ref().values[idx]; + println!("Checking computation for binary test bits:{} op:{} input:{} output:{}", + op.bits, op.operand, op.inp.to_string(), op.expected_output.to_string()); + assert_eq!(decr_response.output_type, op.operand_types, "operand types not equal"); + assert_eq!(decr_response.value, op.expected_output.to_string(), "operand output values not equal"); + } + + Ok(()) +} + fn generate_binary_test_cases() -> Vec { let mut cases = Vec::new(); let mut push_case = |bits: i32, is_scalar: bool, shift_by: i32, op: SupportedFheOperations| { @@ -192,6 +301,58 @@ fn generate_binary_test_cases() -> Vec { cases } +fn generate_unary_test_cases() -> Vec { + let mut cases = Vec::new(); + + for bits in supported_bits() { + let bits = *bits; + let shift_by = bits - 8; + for op in SupportedFheOperations::iter() { + if op.op_type() == FheOperationType::Unary { + let mut inp = BigInt::from(7); + inp <<= shift_by; + let expected_output = compute_expected_unary_output(&inp, op, bits); + let operand = op as i32; + cases.push(UnaryOperatorTestCase { + bits, + operand, + operand_types: supported_bits_to_bit_type_in_db(bits), + inp, + expected_output, + }); + } + } + } + + cases +} + +fn compute_expected_unary_output(inp: &BigInt, op: SupportedFheOperations, bits: i32) -> BigInt { + match op { + SupportedFheOperations::FheNot => { + // TODO: find how this is done appropriately in big int crate + match bits { + 8 => { + let inp: u8 = inp.try_into().unwrap(); + BigInt::from(inp.not()) + } + 16 => { + let inp: u16 = inp.try_into().unwrap(); + BigInt::from(inp.not()) + } + 32 => { + let inp: u32 = inp.try_into().unwrap(); + BigInt::from(inp.not()) + } + other => { + panic!("unknown bits: {other}") + } + } + }, + other => panic!("unsupported binary operation: {:?}", other), + } +} + fn compute_expected_binary_output(lhs: &BigInt, rhs: &BigInt, op: SupportedFheOperations) -> BigInt { match op { SupportedFheOperations::FheAdd => lhs + rhs, diff --git a/fhevm-engine/src/tfhe_ops.rs b/fhevm-engine/src/tfhe_ops.rs index 3310f013..6ecafb22 100644 --- a/fhevm-engine/src/tfhe_ops.rs +++ b/fhevm-engine/src/tfhe_ops.rs @@ -1,6 +1,6 @@ use tfhe::{prelude::FheTryTrivialEncrypt, FheBool, FheUint16, FheUint32, FheUint8}; -use crate::types::{CoprocessorError, FheOperationType, SupportedFheCiphertexts, SupportedFheOperations}; +use crate::{types::{CoprocessorError, FheOperationType, SupportedFheCiphertexts, SupportedFheOperations}, utils::check_if_handle_is_zero}; pub fn current_ciphertext_version() -> i16 { 1 @@ -154,7 +154,24 @@ pub fn perform_fhe_operation(fhe_operation: i16, input_operands: &[SupportedFheC } } }, - SupportedFheOperations::FheNot => todo!(), + SupportedFheOperations::FheNot => { + assert_eq!(input_operands.len(), 1); + + match &input_operands[0] { + SupportedFheCiphertexts::FheUint8(a) => { + Ok(SupportedFheCiphertexts::FheUint8(!a)) + } + SupportedFheCiphertexts::FheUint16(a) => { + Ok(SupportedFheCiphertexts::FheUint16(!a)) + } + SupportedFheCiphertexts::FheUint32(a) => { + Ok(SupportedFheCiphertexts::FheUint32(!a)) + } + _ => { + panic!("Unsupported fhe types"); + } + } + }, SupportedFheOperations::FheIfThenElse => todo!(), } } @@ -213,7 +230,7 @@ pub fn deserialize_fhe_ciphertext(input_type: i16, input_bytes: &[u8]) -> Result } // return output ciphertext type -pub fn check_fhe_operand_types(fhe_operation: i32, input_types: &[i16], is_scalar: bool) -> Result { +pub fn check_fhe_operand_types(fhe_operation: i32, input_types: &[i16], is_scalar: bool, input_handles: &[String]) -> Result { let fhe_op: SupportedFheOperations = fhe_operation.try_into()?; if is_scalar && !does_fhe_operation_support_scalar(&fhe_op) { @@ -245,6 +262,18 @@ pub fn check_fhe_operand_types(fhe_operation: i32, input_types: &[i16], is_scala }); } + // special case for div operation + if is_scalar && fhe_op == SupportedFheOperations::FheDiv { + if check_if_handle_is_zero(input_handles[1].as_str()) { + return Err(CoprocessorError::FheOperationScalarDivisionByZero { + lhs_handle: input_handles[0].clone(), + rhs_value: input_handles[1].clone(), + fhe_operation, + fhe_operation_name: format!("{:?}", SupportedFheOperations::FheDiv), + }); + } + } + return Ok(input_types[0]); } FheOperationType::Unary => { diff --git a/fhevm-engine/src/types.rs b/fhevm-engine/src/types.rs index a8916c0e..84405d98 100644 --- a/fhevm-engine/src/types.rs +++ b/fhevm-engine/src/types.rs @@ -39,12 +39,12 @@ pub enum CoprocessorError { fhe_operation_name: String, operand_types: Vec, }, - // TODO: implement scalar division by zero error - // FheOperationScalarDivisionByZero { - // lhs_handle: String, - // fhe_operation: i32, - // fhe_operation_name: String, - // }, + FheOperationScalarDivisionByZero { + lhs_handle: String, + rhs_value: String, + fhe_operation: i32, + fhe_operation_name: String, + }, } impl std::fmt::Display for CoprocessorError { @@ -101,6 +101,9 @@ impl std::fmt::Display for CoprocessorError { CoprocessorError::TooManyCiphertextsInBatch { maximum_allowed, got } => { write!(f, "maximum ciphertexts exceeded in batch, maximum: {maximum_allowed}, got: {got}") }, + CoprocessorError::FheOperationScalarDivisionByZero { lhs_handle, rhs_value, fhe_operation, fhe_operation_name } => { + write!(f, "zero on the right side of scalar division, lhs handle: {lhs_handle}, rhs value: {rhs_value}, fhe operation: {fhe_operation} fhe operation name:{fhe_operation_name}") + }, } } } @@ -202,6 +205,7 @@ impl TryFrom for SupportedFheOperations { 1 => Ok(SupportedFheOperations::FheSub), 2 => Ok(SupportedFheOperations::FheMul), 3 => Ok(SupportedFheOperations::FheDiv), + 4 => Ok(SupportedFheOperations::FheNot), _ => Err(CoprocessorError::UnknownFheOperation(value as i32)) }; diff --git a/fhevm-engine/src/utils.rs b/fhevm-engine/src/utils.rs index ef545eea..a8698c5c 100644 --- a/fhevm-engine/src/utils.rs +++ b/fhevm-engine/src/utils.rs @@ -3,6 +3,14 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use lazy_static::lazy_static; use crate::{server::coprocessor::AsyncComputation, types::CoprocessorError}; +pub fn check_if_handle_is_zero(inp: &str) -> bool { + lazy_static! { + static ref TARGET_HANDLE_REGEX: regex::Regex = regex::Regex::new("^0x[0]+$").unwrap(); + } + + TARGET_HANDLE_REGEX.is_match(inp) +} + // handle must be serializable to bytes for scalar operations pub fn check_valid_ciphertext_handle(inp: &str) -> Result<(), CoprocessorError> { lazy_static! { diff --git a/proto/coprocessor.proto b/proto/coprocessor.proto index 362cef79..6d2ca574 100644 --- a/proto/coprocessor.proto +++ b/proto/coprocessor.proto @@ -42,6 +42,8 @@ message DebugDecryptResponseSingle { enum FheOperation { FHE_ADD = 0; FHE_SUB = 1; + FHE_MUL = 2; + FHE_DIV = 3; } message AsyncComputation {