Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: inputs in SyncComputeRequest #13

Merged
merged 1 commit into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions fhevm-engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions fhevm-engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ clap = { version = "4.5", features = ["derive"] }
tokio = { version = "1.0", features = ["rt-multi-thread", "macros"] }
prost = "0.13"
tonic = { version = "0.12", features = ["server"] }
bincode = "1.3.3"
sha3 = "0.10.8"
4 changes: 2 additions & 2 deletions fhevm-engine/coprocessor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ hex = "0.4"
bigdecimal = "0.4"
fhevm-engine-common = { path = "../fhevm-engine-common" }
strum = { version = "0.26", features = ["derive"] }
bincode = "1.3.3"
sha3 = "0.10.8"
bincode.workspace = true
sha3.workspace = true

[dev-dependencies]
testcontainers = "0.21"
Expand Down
3 changes: 3 additions & 0 deletions fhevm-engine/executor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ clap.workspace = true
tokio.workspace = true
prost.workspace = true
tonic.workspace = true
tfhe.workspace = true
bincode.workspace = true
sha3.workspace = true
fhevm-engine-common = { path = "../fhevm-engine-common" }

[build-dependencies]
Expand Down
167 changes: 159 additions & 8 deletions fhevm-engine/executor/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
use std::error::Error;
use std::{cell::Cell, collections::HashMap, error::Error, sync::Arc};

use common::FheOperation;
use executor::{
fhevm_executor_server::{FhevmExecutor, FhevmExecutorServer},
SyncComputeRequest, SyncComputeResponse,
sync_compute_response::Resp,
sync_input::Input,
Ciphertext, ResultCiphertexts, SyncComputation, SyncComputeError, SyncComputeRequest,
SyncComputeResponse, SyncInput,
};
use tonic::{transport::Server, Request, Response};
use fhevm_engine_common::{
keys::{FhevmKeys, SerializedFhevmKeys},
tfhe_ops::{current_ciphertext_version, try_expand_ciphertext_list},
types::{FhevmError, Handle, SupportedFheCiphertexts},
};
use sha3::{Digest, Keccak256};
use tfhe::set_server_key;
use tokio::task::spawn_blocking;
use tonic::{transport::Server, Code, Request, Response, Status};

mod common {
pub mod common {
tonic::include_proto!("fhevm.common");
}

Expand All @@ -21,7 +33,7 @@ pub fn start(args: &crate::cli::Args) -> Result<(), Box<dyn Error>> {
.enable_all()
.build()?;

let executor = FhevmExecutorService::default();
let executor = FhevmExecutorService::new();
let addr = args.server_addr.parse().expect("server address");

runtime.block_on(async {
Expand All @@ -34,15 +46,154 @@ pub fn start(args: &crate::cli::Args) -> Result<(), Box<dyn Error>> {
Ok(())
}

struct InMemoryCiphertext {
expanded: SupportedFheCiphertexts,
compressed: Vec<u8>,
}

#[derive(Default)]
pub struct FhevmExecutorService {}
struct ComputationState {
ciphertexts: HashMap<Handle, InMemoryCiphertext>,
}

fn error_response(error: SyncComputeError) -> SyncComputeResponse {
SyncComputeResponse {
resp: Some(Resp::Error(error.into())),
}
}

fn success_response(cts: Vec<Ciphertext>) -> SyncComputeResponse {
SyncComputeResponse {
resp: Some(Resp::ResultCiphertexts(ResultCiphertexts {
ciphertexts: cts,
})),
}
}

struct FhevmExecutorService {
keys: Arc<FhevmKeys>,
}

#[tonic::async_trait]
impl FhevmExecutor for FhevmExecutorService {
async fn sync_compute(
&self,
req: Request<SyncComputeRequest>,
) -> Result<Response<SyncComputeResponse>, tonic::Status> {
Ok(Response::new(SyncComputeResponse::default()))
) -> Result<Response<SyncComputeResponse>, Status> {
let keys = self.keys.clone();
let resp = spawn_blocking(move || {
// Make sure we only clone the server key if needed.
thread_local! {
static SERVER_KEY_IS_SET: Cell<bool> = Cell::new(false);
}
if !SERVER_KEY_IS_SET.get() {
set_server_key(keys.server_key.clone());
SERVER_KEY_IS_SET.set(true);
}

// Exapnd inputs that are global to the whole request.
let req = req.get_ref();
let mut state = ComputationState::default();
if Self::expand_inputs(&req.input_lists, &keys, &mut state).is_err() {
return error_response(SyncComputeError::BadInputList);
}

// Execute all computations.
let mut result_cts = Vec::new();
for computation in &req.computations {
let outcome = Self::process_computation(computation, &mut state);
// Either all succeed or we return on the first failure.
match outcome.resp.unwrap() {
Resp::Error(error) => {
return error_response(
SyncComputeError::try_from(error).expect("correct error value"),
);
}
Resp::ResultCiphertexts(cts) => result_cts.extend(cts.ciphertexts),
}
}
success_response(result_cts)
})
.await;
match resp {
Ok(resp) => Ok(Response::new(resp)),
Err(_) => Err(Status::new(
Code::Unknown,
"failed to execute computation via spawn_blocking",
)),
}
}
}

impl FhevmExecutorService {
fn new() -> Self {
FhevmExecutorService {
keys: Arc::new(SerializedFhevmKeys::load_from_disk().into()),
}
}

fn process_computation(
comp: &SyncComputation,
state: &mut ComputationState,
) -> SyncComputeResponse {
let op = FheOperation::try_from(comp.operation);
match op {
Ok(FheOperation::FheGetInputCiphertext) => Self::get_input_ciphertext(comp, &state),
Ok(_) => error_response(SyncComputeError::UnsupportedOperation),
_ => error_response(SyncComputeError::InvalidOperation),
}
}

fn expand_inputs(
lists: &Vec<Vec<u8>>,
keys: &FhevmKeys,
state: &mut ComputationState,
) -> Result<(), FhevmError> {
for list in lists {
let cts = try_expand_ciphertext_list(&list, &keys.server_key)?;
let list_hash: Handle = Keccak256::digest(list).into();
for (i, ct) in cts.iter().enumerate() {
let mut handle = list_hash;
handle[29] = i as u8;
handle[30] = ct.type_num() as u8;
handle[31] = current_ciphertext_version() as u8;
state.ciphertexts.insert(
handle,
InMemoryCiphertext {
expanded: ct.clone(),
compressed: ct.clone().compress(),
},
);
}
}
Ok(())
}

fn get_input_ciphertext(
comp: &SyncComputation,
state: &ComputationState,
) -> SyncComputeResponse {
match (comp.inputs.first(), comp.inputs.len()) {
(
Some(SyncInput {
input: Some(Input::InputHandle(handle)),
}),
1,
) => {
if let Ok(handle) = (handle as &[u8]).try_into() as Result<Handle, _> {
if let Some(in_mem_ciphertext) = state.ciphertexts.get(&handle) {
success_response(vec![Ciphertext {
handle: handle.to_vec(),
ciphertext: in_mem_ciphertext.compressed.clone(),
}])
} else {
error_response(SyncComputeError::UnknownHandle)
}
} else {
error_response(SyncComputeError::BadInputs)
}
}
_ => error_response(SyncComputeError::BadInputs),
}
}
}
48 changes: 42 additions & 6 deletions fhevm-engine/executor/tests/sync_compute.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,48 @@
use executor::server::executor::{fhevm_executor_client::FhevmExecutorClient, SyncComputeRequest};
use utils::TestInstance;
use executor::server::common::FheOperation;
use executor::server::executor::sync_compute_response::Resp;
use executor::server::executor::{
fhevm_executor_client::FhevmExecutorClient, SyncComputation, SyncComputeRequest,
};
use executor::server::executor::{sync_input::Input, SyncInput};
use tfhe::CompactCiphertextListBuilder;
use utils::get_test;

mod utils;

#[tokio::test]
async fn compute_on_ciphertexts() -> Result<(), Box<dyn std::error::Error>> {
let test_instance = TestInstance::new();
let mut client = FhevmExecutorClient::connect(test_instance.server_addr).await?;
let resp = client.sync_compute(SyncComputeRequest::default()).await?;
async fn get_input_ciphertexts() -> Result<(), Box<dyn std::error::Error>> {
let test = get_test().await;
let mut client = FhevmExecutorClient::connect(test.server_addr.clone()).await?;
let mut builder = CompactCiphertextListBuilder::new(&test.keys.compact_public_key);
let list = bincode::serialize(&builder.push(10_u8).build()).unwrap();
// TODO: tests for all types and avoiding passing in 2 as an identifier for FheUint8.
let input_handle = test.input_handle(&list, 0, 2);
let sync_input = SyncInput {
input: Some(Input::InputHandle(input_handle.to_vec())),
};
let computation = SyncComputation {
operation: FheOperation::FheGetInputCiphertext.into(),
result_handles: vec![vec![0xaa]],
inputs: vec![sync_input],
};
let req = SyncComputeRequest {
computations: vec![computation],
input_lists: vec![list],
};
let response = client.sync_compute(req).await?;
let sync_compute_response = response.get_ref();
match &sync_compute_response.resp {
Some(Resp::ResultCiphertexts(cts)) => {
match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != input_handle || ct.ciphertext.is_empty() {
assert!(false);
}
}
_ => assert!(false),
}
}
_ => assert!(false),
}
Ok(())
}
30 changes: 27 additions & 3 deletions fhevm-engine/executor/tests/utils.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
use std::{sync::Arc, time::Duration};

use clap::Parser;
use executor::{cli::Args, server};
use fhevm_engine_common::keys::{FhevmKeys, SerializedFhevmKeys};
use fhevm_engine_common::{
keys::{FhevmKeys, SerializedFhevmKeys},
tfhe_ops::current_ciphertext_version,
types::Handle,
};
use sha3::{Digest, Keccak256};
use tokio::{sync::OnceCell, time::sleep};

pub struct TestInstance {
pub keys: FhevmKeys,
pub server_addr: String,
}

impl TestInstance {
pub fn new() -> Self {
pub async fn new() -> Self {
// Get defaults by parsing a cmd line without any arguments.
let args = Args::parse_from(&["test"]);

Expand All @@ -20,8 +28,24 @@ impl TestInstance {
std::thread::spawn(move || server::start(&args).expect("start server"));

// TODO: a hacky way to wait for the server to start
std::thread::sleep(std::time::Duration::from_millis(150));
sleep(Duration::from_secs(6)).await;

instance
}

pub fn input_handle(&self, list: &[u8], index: u8, ct_type: u8) -> Handle {
let mut handle: Handle = Keccak256::digest(list).into();
handle[29] = index;
handle[30] = ct_type;
handle[31] = current_ciphertext_version() as u8;
handle
}
}

static TEST: OnceCell<Arc<TestInstance>> = OnceCell::const_new();

pub async fn get_test() -> Arc<TestInstance> {
TEST.get_or_init(|| async { Arc::new(TestInstance::new().await) })
.await
.clone()
}
Loading