Skip to content

Commit

Permalink
feat: make ciphertext inputs global per request
Browse files Browse the repository at this point in the history
Instead of passing in ciphertexts as part of the SyncInput, pass them
globally for the whole request. Decompress them into a map that lives
for the duration of the request. Computations then refer to ciphertexts
in the map via handles, applying to inputs, compressed ciphertexts and
results.

Above also allows for use of results as inputs in subsequent
computations.
  • Loading branch information
dartdart26 committed Aug 30, 2024
1 parent 112ce72 commit eab925b
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 61 deletions.
38 changes: 30 additions & 8 deletions fhevm-engine/executor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use executor::{
use fhevm_engine_common::{
keys::{FhevmKeys, SerializedFhevmKeys},
tfhe_ops::{current_ciphertext_version, perform_fhe_operation, try_expand_ciphertext_list},
types::{FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN, SCALAR_LEN},
types::{get_ct_type, FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN, SCALAR_LEN},
};
use sha3::{Digest, Keccak256};
use tfhe::{integer::U256, set_server_key};
Expand Down Expand Up @@ -77,15 +77,23 @@ impl FhevmExecutor for FhevmExecutorService {
SERVER_KEY_IS_SET.set(true);
}

// Exapnd inputs that are global to the whole request.
let req = req.get_ref();
let mut state = ComputationState::default();

// Exapnd inputs that are global to the whole request.
if Self::expand_inputs(&req.input_lists, &keys, &mut state).is_err() {
return SyncComputeResponse {
resp: Some(Resp::Error(SyncComputeError::BadInputList.into())),
};
}

// Decompress ciphertext that are global to the whole request.
if Self::decompress_ciphertexts(&req.ciphertexts, &mut state).is_err() {
return SyncComputeResponse {
resp: Some(Resp::Error(SyncComputeError::BadInputCiphertext.into())),
};
}

// Execute all computations.
let mut result_cts = Vec::new();
for computation in &req.computations {
Expand Down Expand Up @@ -170,6 +178,24 @@ impl FhevmExecutorService {
Ok(())
}

fn decompress_ciphertexts(
cts: &Vec<Ciphertext>,
state: &mut ComputationState,
) -> Result<(), Box<dyn Error>> {
for ct in cts.iter() {
let ct_type = get_ct_type(&ct.handle)?;
let supported_ct = SupportedFheCiphertexts::decompress(ct_type, &ct.ciphertext)?;
state.ciphertexts.insert(
ct.handle.clone(),
InMemoryCiphertext {
expanded: supported_ct,
compressed: ct.ciphertext.clone(),
},
);
}
Ok(())
}

fn get_ciphertext(
comp: &SyncComputation,
result_handle: &Handle,
Expand All @@ -178,7 +204,7 @@ impl FhevmExecutorService {
match (comp.inputs.first(), comp.inputs.len()) {
(
Some(SyncInput {
input: Some(Input::InputHandle(handle)),
input: Some(Input::Handle(handle)),
}),
1,
) => {
Expand Down Expand Up @@ -210,11 +236,7 @@ impl FhevmExecutorService {
.iter()
.map(|sync_input| match &sync_input.input {
Some(input) => match input {
Input::Ciphertext(c) if c.handle.len() == HANDLE_LEN => {
let ct_type = c.handle[30] as i16;
Ok(SupportedFheCiphertexts::decompress(ct_type, &c.ciphertext)?)
}
Input::InputHandle(h) => {
Input::Handle(h) => {
let ct = state.ciphertexts.get(h).ok_or(FhevmError::BadInputs)?;
Ok(ct.expanded.clone())
}
Expand Down
219 changes: 178 additions & 41 deletions fhevm-engine/executor/tests/sync_compute.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use anyhow::{anyhow, Result};
use executor::server::common::FheOperation;
use executor::server::executor::sync_compute_response::Resp;
use executor::server::executor::Ciphertext;
Expand All @@ -13,15 +12,17 @@ use utils::get_test;
mod utils;

#[tokio::test]
async fn get_input_ciphertext() -> Result<()> {
async fn get_input_ciphertext() {
let test = get_test().await;
let mut client = FhevmExecutorClient::connect(test.server_addr.clone()).await?;
let mut client = FhevmExecutorClient::connect(test.server_addr.clone())
.await
.unwrap();
let mut builder = CompactCiphertextListBuilder::new(&test.keys.compact_public_key);
let list = bincode::serialize(&builder.push(10_u8).build())?;
let list = bincode::serialize(&builder.push(10_u8).build()).unwrap();
// TODO: tests for all types and avoiding passing in 2 as an identifier for FheUint8.
let input_handle = test.input_handle(&list, 0, 2);
let sync_input = SyncInput {
input: Some(Input::InputHandle(input_handle.clone())),
input: Some(Input::Handle(input_handle.clone())),
};
let computation = SyncComputation {
operation: FheOperation::FheGetCiphertext.into(),
Expand All @@ -31,55 +32,44 @@ async fn get_input_ciphertext() -> Result<()> {
let req = SyncComputeRequest {
computations: vec![computation],
input_lists: vec![list],
ciphertexts: vec![],
};
let response = client.sync_compute(req).await?;
let response = client.sync_compute(req).await.unwrap();
let sync_compute_response = response.get_ref();
let resp = <Option<Resp> as Clone>::clone(&sync_compute_response.resp)
.ok_or_else(|| anyhow!("resp is None"))?;
let resp = sync_compute_response.resp.clone().unwrap();
match resp {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != input_handle || ct.ciphertext.is_empty() {
return Err(anyhow!("response handle or ciphertext are unexpected"));
assert!(false, "response handle or ciphertext are unexpected");
}
Ok(())
}
_ => Err(anyhow!("unexpected amount of result ciphertexts returned")),
_ => assert!(false, "no response"),
},
Resp::Error(e) => Err(anyhow!(format!("error response: {}", e))),
Resp::Error(e) => assert!(false, "error: {}", e),
}
}

#[tokio::test]
async fn fhe_compute_two_ciphertexts() -> Result<()> {
async fn compute_on_two_ciphertexts() {
let test = get_test().await;
let mut client = FhevmExecutorClient::connect(test.server_addr.clone()).await?;
let mut client = FhevmExecutorClient::connect(test.server_addr.clone())
.await
.unwrap();
let mut builder = CompactCiphertextListBuilder::new(&test.keys.compact_public_key);
let list = builder.push(10_u16).push(11_u16).build();
let expander = list.expand_with_key(&test.keys.server_key)?;
let ct1 = SupportedFheCiphertexts::FheUint16(
expander
.get(0)
.ok_or(anyhow!("missing ciphertext at index 0"))??,
);
let expander = list.expand_with_key(&test.keys.server_key).unwrap();
let ct1 = SupportedFheCiphertexts::FheUint16(expander.get(0).unwrap().unwrap());
let ct1 = test.compress(ct1);
let ct2 = SupportedFheCiphertexts::FheUint16(
expander
.get(1)
.ok_or(anyhow!("missing ciphertext at index 1"))??,
);
let ct2 = SupportedFheCiphertexts::FheUint16(expander.get(1).unwrap().unwrap());
let ct2 = test.compress(ct2);
let handle1 = test.ciphertext_handle(&ct1, 3);
let sync_input1 = SyncInput {
input: Some(Input::Ciphertext(Ciphertext {
handle: test.ciphertext_handle(&ct1, 3).to_vec(),
ciphertext: ct1,
})),
input: Some(Input::Handle(handle1.clone())),
};
let handle2 = test.ciphertext_handle(&ct2, 3);
let sync_input2 = SyncInput {
input: Some(Input::Ciphertext(Ciphertext {
handle: test.ciphertext_handle(&ct2, 3).to_vec(),
ciphertext: ct2,
})),
input: Some(Input::Handle(handle2.clone())),
};
let computation = SyncComputation {
operation: FheOperation::FheAdd.into(),
Expand All @@ -89,21 +79,168 @@ async fn fhe_compute_two_ciphertexts() -> Result<()> {
let req = SyncComputeRequest {
computations: vec![computation],
input_lists: vec![],
ciphertexts: vec![
Ciphertext {
handle: handle1,
ciphertext: ct1,
},
Ciphertext {
handle: handle2,
ciphertext: ct2,
},
],
};
let response = client.sync_compute(req).await?;
let response = client.sync_compute(req).await.unwrap();
let sync_compute_response = response.get_ref();
let resp = <Option<Resp> as Clone>::clone(&sync_compute_response.resp)
.ok_or_else(|| anyhow!("resp is None"))?;
let resp = sync_compute_response.resp.clone().unwrap();
match resp {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != vec![0xaa; HANDLE_LEN] || ct.ciphertext.is_empty() {
return Err(anyhow!("response handle or ciphertext are unexpected"));
if ct.handle != vec![0xaa; HANDLE_LEN] {
assert!(false, "response handle is unexpected");
}
let ct = SupportedFheCiphertexts::decompress(3, &ct.ciphertext).unwrap();
match ct
.decrypt(&test.as_ref().keys.client_key.clone().unwrap())
.as_str()
{
"21" => (),
s => assert!(false, "unexpected result: {}", s),
}
}
_ => assert!(false, "unexpected amount of result ciphertexts returned"),
},
Resp::Error(e) => assert!(false, "error response: {}", e),
}
}

#[tokio::test]
async fn compute_on_input_and_ciphertext() {
let test = get_test().await;
let mut client = FhevmExecutorClient::connect(test.server_addr.clone())
.await
.unwrap();
let mut builder_input = CompactCiphertextListBuilder::new(&test.keys.compact_public_key);
let input_list = bincode::serialize(&builder_input.push(10_u16).build()).unwrap();
let mut builder_cts = CompactCiphertextListBuilder::new(&test.keys.compact_public_key);
let list = builder_cts.push(11_u16).build();
let expander = list.expand_with_key(&test.keys.server_key).unwrap();
let ct1 = SupportedFheCiphertexts::FheUint16(expander.get(0).unwrap().unwrap());
let ct1 = test.compress(ct1);
let handle1 = test.ciphertext_handle(&ct1, 3);
let sync_input1 = SyncInput {
input: Some(Input::Handle(handle1.clone())),
};
let handle2 = test.input_handle(&input_list, 0, 3);
let sync_input2 = SyncInput {
input: Some(Input::Handle(handle2.clone())),
};
let computation = SyncComputation {
operation: FheOperation::FheAdd.into(),
result_handles: vec![vec![0xaa; HANDLE_LEN]],
inputs: vec![sync_input1, sync_input2],
};
let req = SyncComputeRequest {
computations: vec![computation],
input_lists: vec![input_list],
ciphertexts: vec![Ciphertext {
handle: handle1,
ciphertext: ct1,
}],
};
let response = client.sync_compute(req).await.unwrap();
let sync_compute_response = response.get_ref();
let resp = sync_compute_response.resp.clone().unwrap();
match resp {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != vec![0xaa; HANDLE_LEN] {
assert!(false, "response handle is unexpected");
}
let ct = SupportedFheCiphertexts::decompress(3, &ct.ciphertext).unwrap();
match ct
.decrypt(&test.as_ref().keys.client_key.clone().unwrap())
.as_str()
{
"21" => (),
s => assert!(false, "unexpected result: {}", s),
}
}
_ => assert!(false, "unexpected amount of result ciphertexts returned"),
},
Resp::Error(e) => assert!(false, "error response: {}", e),
}
}

#[tokio::test]
async fn compute_on_result_ciphertext() {
let test = get_test().await;
let mut client = FhevmExecutorClient::connect(test.server_addr.clone())
.await
.unwrap();
let mut builder = CompactCiphertextListBuilder::new(&test.keys.compact_public_key);
let list = builder.push(10_u16).push(11_u16).build();
let expander = list.expand_with_key(&test.keys.server_key).unwrap();
let ct1 = SupportedFheCiphertexts::FheUint16(expander.get(0).unwrap().unwrap());
let ct1 = test.compress(ct1);
let ct2 = SupportedFheCiphertexts::FheUint16(expander.get(1).unwrap().unwrap());
let ct2 = test.compress(ct2);
let handle1 = test.ciphertext_handle(&ct1, 3);
let sync_input1 = SyncInput {
input: Some(Input::Handle(handle1.clone())),
};
let handle2 = test.ciphertext_handle(&ct2, 3);
let sync_input2 = SyncInput {
input: Some(Input::Handle(handle2.clone())),
};
let computation1 = SyncComputation {
operation: FheOperation::FheAdd.into(),
result_handles: vec![vec![0xaa; HANDLE_LEN]],
inputs: vec![sync_input1, sync_input2.clone()],
};
let sync_input3 = SyncInput {
input: Some(Input::Handle(vec![0xaa; HANDLE_LEN])),
};
// 10 + 11 = 21. Then, add the 21 result to 11 and expect 32.
let computation2 = SyncComputation {
operation: FheOperation::FheAdd.into(),
result_handles: vec![vec![0xbb; HANDLE_LEN]],
inputs: vec![sync_input3, sync_input2],
};
let req = SyncComputeRequest {
computations: vec![computation1, computation2],
input_lists: vec![],
ciphertexts: vec![
Ciphertext {
handle: handle1,
ciphertext: ct1,
},
Ciphertext {
handle: handle2,
ciphertext: ct2,
},
],
};
let response = client.sync_compute(req).await.unwrap();
let sync_compute_response = response.get_ref();
let resp = sync_compute_response.resp.clone().unwrap();
match resp {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.get(1), cts.ciphertexts.len()) {
(Some(ct), 2) => {
if ct.handle != vec![0xbb; HANDLE_LEN] {
assert!(false, "response handle is unexpected");
}
let ct = SupportedFheCiphertexts::decompress(3, &ct.ciphertext).unwrap();
match ct
.decrypt(&test.as_ref().keys.client_key.clone().unwrap())
.as_str()
{
"32" => (),
s => assert!(false, "unexpected result: {}", s),
}
Ok(())
}
_ => Err(anyhow!("unexpected amount of result ciphertexts returned")),
_ => assert!(false, "unexpected amount of result ciphertexts returned"),
},
Resp::Error(e) => Err(anyhow!(format!("error response: {}", e))),
Resp::Error(e) => assert!(false, "error response: {}", e),
}
}
10 changes: 10 additions & 0 deletions fhevm-engine/fhevm-engine-common/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ pub enum FhevmError {
},
BadInputs,
MissingTfheRsData,
InvalidHandle,
}

impl std::error::Error for FhevmError {}
Expand Down Expand Up @@ -192,6 +193,9 @@ impl std::fmt::Display for FhevmError {
Self::MissingTfheRsData => {
write!(f, "Missing TFHE-rs data")
}
Self::InvalidHandle => {
write!(f, "Invalid ciphertext handle")
}
}
}
}
Expand Down Expand Up @@ -460,3 +464,9 @@ pub type Handle = Vec<u8>;
pub const HANDLE_LEN: usize = 32;
pub const SCALAR_LEN: usize = 32;

pub fn get_ct_type(handle: &[u8]) -> Result<i16, FhevmError> {
match handle.len() {
HANDLE_LEN => Ok(handle[30] as i16),
_ => Err(FhevmError::InvalidHandle),
}
}
Loading

0 comments on commit eab925b

Please sign in to comment.