Skip to content

Commit

Permalink
feat: add FHE computation to the executor
Browse files Browse the repository at this point in the history
Error handling is still rough, can be simplified further.
  • Loading branch information
dartdart26 committed Aug 29, 2024
1 parent 7640c2f commit aea93ce
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 60 deletions.
3 changes: 3 additions & 0 deletions fhevm-engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions fhevm-engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ prost = "0.13"
tonic = { version = "0.12", features = ["server"] }
bincode = "1.3.3"
sha3 = "0.10.8"
anyhow = "1.0.86"

1 change: 1 addition & 0 deletions fhevm-engine/executor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ tonic.workspace = true
tfhe.workspace = true
bincode.workspace = true
sha3.workspace = true
anyhow.workspace = true
fhevm-engine-common = { path = "../fhevm-engine-common" }

[build-dependencies]
Expand Down
135 changes: 95 additions & 40 deletions fhevm-engine/executor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ use executor::{
};
use fhevm_engine_common::{
keys::{FhevmKeys, SerializedFhevmKeys},
tfhe_ops::{current_ciphertext_version, try_expand_ciphertext_list},
types::{FhevmError, Handle, SupportedFheCiphertexts},
tfhe_ops::{current_ciphertext_version, perform_fhe_operation, try_expand_ciphertext_list},
types::{FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN, SCALAR_LEN},
};
use sha3::{Digest, Keccak256};
use tfhe::set_server_key;
use tfhe::{integer::U256, set_server_key};
use tokio::task::spawn_blocking;
use tonic::{transport::Server, Code, Request, Response, Status};

Expand Down Expand Up @@ -56,20 +56,6 @@ struct ComputationState {
ciphertexts: HashMap<Handle, InMemoryCiphertext>,
}

fn error_response(error: SyncComputeError) -> SyncComputeResponse {
SyncComputeResponse {
resp: Some(Resp::Error(error.into())),
}
}

fn success_response(cts: Vec<Ciphertext>) -> SyncComputeResponse {
SyncComputeResponse {
resp: Some(Resp::ResultCiphertexts(ResultCiphertexts {
ciphertexts: cts,
})),
}
}

struct FhevmExecutorService {
keys: Arc<FhevmKeys>,
}
Expand All @@ -95,24 +81,30 @@ impl FhevmExecutor for FhevmExecutorService {
let req = req.get_ref();
let mut state = ComputationState::default();
if Self::expand_inputs(&req.input_lists, &keys, &mut state).is_err() {
return error_response(SyncComputeError::BadInputList);
return SyncComputeResponse {
resp: Some(Resp::Error(SyncComputeError::BadInputList.into())),
};
}

// Execute all computations.
let mut result_cts = Vec::new();
for computation in &req.computations {
let outcome = Self::process_computation(computation, &mut state);
// Either all succeed or we return on the first failure.
match outcome.resp.unwrap() {
Resp::Error(error) => {
return error_response(
SyncComputeError::try_from(error).expect("correct error value"),
);
match outcome {
Ok(cts) => result_cts.extend(cts),
Err(e) => {
return SyncComputeResponse {
resp: Some(Resp::Error(e.into())),
};
}
Resp::ResultCiphertexts(cts) => result_cts.extend(cts.ciphertexts),
}
}
success_response(result_cts)
SyncComputeResponse {
resp: Some(Resp::ResultCiphertexts(ResultCiphertexts {
ciphertexts: result_cts,
})),
}
})
.await;
match resp {
Expand All @@ -135,12 +127,21 @@ impl FhevmExecutorService {
fn process_computation(
comp: &SyncComputation,
state: &mut ComputationState,
) -> SyncComputeResponse {
) -> Result<Vec<Ciphertext>, SyncComputeError> {
// For now, assume only one result handle.
let result_handle = comp
.result_handles
.first()
.filter(|h| h.len() == HANDLE_LEN)
.ok_or_else(|| SyncComputeError::BadResultHandles)?
.clone();
let op = FheOperation::try_from(comp.operation);
match op {
Ok(FheOperation::FheGetInputCiphertext) => Self::get_input_ciphertext(comp, &state),
Ok(_) => error_response(SyncComputeError::UnsupportedOperation),
_ => error_response(SyncComputeError::InvalidOperation),
Ok(FheOperation::FheGetInputCiphertext) => {
Self::get_input_ciphertext(comp, &result_handle, &state)
}
Ok(_) => Self::compute(comp, result_handle, state),
_ => Err(SyncComputeError::InvalidOperation),
}
}

Expand All @@ -151,9 +152,9 @@ impl FhevmExecutorService {
) -> Result<(), FhevmError> {
for list in lists {
let cts = try_expand_ciphertext_list(&list, &keys.server_key)?;
let list_hash: Handle = Keccak256::digest(list).into();
let list_hash: Handle = Keccak256::digest(list).to_vec();
for (i, ct) in cts.iter().enumerate() {
let mut handle = list_hash;
let mut handle = list_hash.clone();
handle[29] = i as u8;
handle[30] = ct.type_num() as u8;
handle[31] = current_ciphertext_version() as u8;
Expand All @@ -171,29 +172,83 @@ impl FhevmExecutorService {

fn get_input_ciphertext(
comp: &SyncComputation,
result_handle: &Handle,
state: &ComputationState,
) -> SyncComputeResponse {
) -> Result<Vec<Ciphertext>, SyncComputeError> {
match (comp.inputs.first(), comp.inputs.len()) {
(
Some(SyncInput {
input: Some(Input::InputHandle(handle)),
}),
1,
) => {
if let Ok(handle) = (handle as &[u8]).try_into() as Result<Handle, _> {
if let Some(in_mem_ciphertext) = state.ciphertexts.get(&handle) {
success_response(vec![Ciphertext {
handle: handle.to_vec(),
if let Some(in_mem_ciphertext) = state.ciphertexts.get(handle) {
if *handle != *result_handle {
Err(SyncComputeError::BadInputs)
} else {
Ok(vec![Ciphertext {
handle: result_handle.to_vec(),
ciphertext: in_mem_ciphertext.compressed.clone(),
}])
} else {
error_response(SyncComputeError::UnknownHandle)
}
} else {
error_response(SyncComputeError::BadInputs)
Err(SyncComputeError::UnknownHandle)
}
}
_ => error_response(SyncComputeError::BadInputs),
_ => Err(SyncComputeError::BadInputs),
}
}

fn compute(
comp: &SyncComputation,
result_handle: Handle,
state: &mut ComputationState,
) -> Result<Vec<Ciphertext>, SyncComputeError> {
// Collect computation inputs.
let inputs: Result<Vec<SupportedFheCiphertexts>, Box<dyn Error>> = comp
.inputs
.iter()
.map(|sync_input| match &sync_input.input {
Some(input) => match input {
Input::Ciphertext(c) if c.handle.len() == HANDLE_LEN => {
let ct_type = c.handle[30] as i16;
Ok(SupportedFheCiphertexts::decompress(ct_type, &c.ciphertext)?)
}
Input::InputHandle(h) => {
let ct = state.ciphertexts.get(h).ok_or(FhevmError::BadInputs)?;
Ok(ct.expanded.clone())
}
Input::Scalar(s) if s.len() == SCALAR_LEN => {
let mut scalar = U256::default();
scalar.copy_from_be_byte_slice(&s);
Ok(SupportedFheCiphertexts::Scalar(scalar))
}
_ => Err(FhevmError::BadInputs.into()),
},
None => Err(FhevmError::BadInputs.into()),
})
.collect();

// Do the computation on the inputs.
match inputs {
Ok(inputs) => match perform_fhe_operation(comp.operation as i16, &inputs) {
Ok(result) => {
let compressed = result.clone().compress();
state.ciphertexts.insert(
result_handle.clone(),
InMemoryCiphertext {
expanded: result,
compressed: compressed.clone(),
},
);
Ok(vec![Ciphertext {
handle: result_handle,
ciphertext: compressed,
}])
}
Err(_) => Err(SyncComputeError::ComputationFailed),
},
Err(_) => Err(SyncComputeError::BadInputs),
}
}
}
91 changes: 76 additions & 15 deletions fhevm-engine/executor/tests/sync_compute.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
use anyhow::{anyhow, Result};
use executor::server::common::FheOperation;
use executor::server::executor::sync_compute_response::Resp;
use executor::server::executor::Ciphertext;
use executor::server::executor::{
fhevm_executor_client::FhevmExecutorClient, SyncComputation, SyncComputeRequest,
};
use executor::server::executor::{sync_input::Input, SyncInput};
use fhevm_engine_common::types::{SupportedFheCiphertexts, HANDLE_LEN};
use tfhe::CompactCiphertextListBuilder;
use utils::get_test;

mod utils;

#[tokio::test]
async fn get_input_ciphertexts() -> Result<(), Box<dyn std::error::Error>> {
async fn get_input_ciphertext() -> Result<()> {
let test = get_test().await;
let mut client = FhevmExecutorClient::connect(test.server_addr.clone()).await?;
let mut builder = CompactCiphertextListBuilder::new(&test.keys.compact_public_key);
let list = bincode::serialize(&builder.push(10_u8).build()).unwrap();
let list = bincode::serialize(&builder.push(10_u8).build())?;
// TODO: tests for all types and avoiding passing in 2 as an identifier for FheUint8.
let input_handle = test.input_handle(&list, 0, 2);
let sync_input = SyncInput {
input: Some(Input::InputHandle(input_handle.to_vec())),
input: Some(Input::InputHandle(input_handle.clone())),
};
let computation = SyncComputation {
operation: FheOperation::FheGetInputCiphertext.into(),
result_handles: vec![vec![0xaa]],
result_handles: vec![input_handle.clone()],
inputs: vec![sync_input],
};
let req = SyncComputeRequest {
Expand All @@ -31,18 +34,76 @@ async fn get_input_ciphertexts() -> Result<(), Box<dyn std::error::Error>> {
};
let response = client.sync_compute(req).await?;
let sync_compute_response = response.get_ref();
match &sync_compute_response.resp {
Some(Resp::ResultCiphertexts(cts)) => {
match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != input_handle || ct.ciphertext.is_empty() {
assert!(false);
}
let resp = <Option<Resp> as Clone>::clone(&sync_compute_response.resp)
.ok_or_else(|| anyhow!("resp is None"))?;
match resp {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != input_handle || ct.ciphertext.is_empty() {
return Err(anyhow!("response handle or ciphertext are unexpected"));
}
_ => assert!(false),
Ok(())
}
}
_ => assert!(false),
_ => Err(anyhow!("unexpected amount of result ciphertexts returned")),
},
Resp::Error(e) => Err(anyhow!(format!("error response: {}", e))),
}
}

#[tokio::test]
async fn fhe_compute_two_ciphertexts() -> Result<()> {
let test = get_test().await;
let mut client = FhevmExecutorClient::connect(test.server_addr.clone()).await?;
let mut builder = CompactCiphertextListBuilder::new(&test.keys.compact_public_key);
let list = builder.push(10_u16).push(11_u16).build();
let expander = list.expand_with_key(&test.keys.server_key)?;
let ct1 = SupportedFheCiphertexts::FheUint16(
expander
.get(0)
.ok_or(anyhow!("missing ciphertext at index 0"))??,
);
let ct1 = test.compress(ct1);
let ct2 = SupportedFheCiphertexts::FheUint16(
expander
.get(1)
.ok_or(anyhow!("missing ciphertext at index 1"))??,
);
let ct2 = test.compress(ct2);
let sync_input1 = SyncInput {
input: Some(Input::Ciphertext(Ciphertext {
handle: test.ciphertext_handle(&ct1, 3).to_vec(),
ciphertext: ct1,
})),
};
let sync_input2 = SyncInput {
input: Some(Input::Ciphertext(Ciphertext {
handle: test.ciphertext_handle(&ct2, 3).to_vec(),
ciphertext: ct2,
})),
};
let computation = SyncComputation {
operation: FheOperation::FheAdd.into(),
result_handles: vec![vec![0xaa; HANDLE_LEN]],
inputs: vec![sync_input1, sync_input2],
};
let req = SyncComputeRequest {
computations: vec![computation],
input_lists: vec![],
};
let response = client.sync_compute(req).await?;
let sync_compute_response = response.get_ref();
let resp = <Option<Resp> as Clone>::clone(&sync_compute_response.resp)
.ok_or_else(|| anyhow!("resp is None"))?;
match resp {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != vec![0xaa; HANDLE_LEN] || ct.ciphertext.is_empty() {
return Err(anyhow!("response handle or ciphertext are unexpected"));
}
Ok(())
}
_ => Err(anyhow!("unexpected amount of result ciphertexts returned")),
},
Resp::Error(e) => Err(anyhow!(format!("error response: {}", e))),
}
Ok(())
}
Loading

0 comments on commit aea93ce

Please sign in to comment.