diff --git a/fhevm-engine/Cargo.lock b/fhevm-engine/Cargo.lock index bdaea17c..98e8aa7b 100644 --- a/fhevm-engine/Cargo.lock +++ b/fhevm-engine/Cargo.lock @@ -135,6 +135,69 @@ version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +[[package]] +name = "ark-ff-asm" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ed4aa4fe255d0bc6d79373f7e31d2ea147bcf486cba1be5ba7ea85abdb92348" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-ff-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7abe79b0e4288889c4574159ab790824d0033b9fdcb2a112a3182fac2e514565" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-serialize" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" +dependencies = [ + "ark-serialize-derive", + "ark-std", + "digest", + "num-bigint", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae3281bc6d0fd7e549af32b52511e1302185bd688fd3359fa36423346ff682ea" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-std" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94893f1e0c6eeab764ade8dc4c0db24caf4fe7cbbaafc0eba0a9030f447b5185" +dependencies = [ + "num-traits", + "rand", + "rayon", +] + +[[package]] +name = "arrayvec" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" + [[package]] name = "async-stream" version = "0.3.5" @@ -528,6 +591,7 @@ dependencies = [ "prost", "regex", "serde_json", + "sha3", "sqlx", "strum", "testcontainers", @@ -681,6 +745,17 @@ dependencies = [ "serde", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "digest" version = "0.10.7" @@ -1292,6 +1367,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.13.0" @@ -2680,6 +2764,74 @@ dependencies = [ "serde", "sha3", "tfhe-versionable", + "tfhe-zk-pok", +] + +[[package]] +name = "tfhe-ark-bls12-381" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6769557a36c7d9b2313052badc55562b0e260a0dcca745150b15c3bae65a4957" +dependencies = [ + "ark-serialize", + "ark-std", + "tfhe-ark-ec", + "tfhe-ark-ff", +] + +[[package]] +name = "tfhe-ark-ec" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ff6eb48b47e2cd6a2db68909cc62888516a7c1c4faaa894cb8dff7d48029b3f" +dependencies = [ + "ark-serialize", + "ark-std", + "derivative", + "hashbrown 0.14.5", + "itertools 0.12.1", + "num-bigint", + "num-integer", + "num-traits", + "rayon", + "tfhe-ark-ff", + "tfhe-ark-poly", + "zeroize", +] + +[[package]] +name = "tfhe-ark-ff" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08ab3109583fa162a9c83082ad4006a877786da8d76ad3cd9180bcb3d7ac9e9" +dependencies = [ + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std", + "arrayvec", + "derivative", + "digest", + "itertools 0.12.1", + "num-bigint", + "num-traits", + "paste", + "rayon", + "zeroize", +] + +[[package]] +name = "tfhe-ark-poly" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4849a880457e8562e759fae62462a8be506f8ce1b8a0d5f90061583c4f6d25de" +dependencies = [ + "ark-serialize", + "ark-std", + "derivative", + "hashbrown 0.14.5", + "rayon", + "tfhe-ark-ff", ] [[package]] @@ -2705,6 +2857,24 @@ dependencies = [ "syn 2.0.75", ] +[[package]] +name = "tfhe-zk-pok" +version = "0.3.0-alpha.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ad66713d0dcdeb042c2e07f939825459bca83365445da202c6dd86eb81814b3" +dependencies = [ + "ark-serialize", + "rand", + "rayon", + "serde", + "sha3", + "tfhe-ark-bls12-381", + "tfhe-ark-ec", + "tfhe-ark-ff", + "tfhe-ark-poly", + "zeroize", +] + [[package]] name = "thiserror" version = "1.0.63" @@ -3396,3 +3566,17 @@ name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.75", +] diff --git a/fhevm-engine/Cargo.toml b/fhevm-engine/Cargo.toml index 51cdbc70..05c3a680 100644 --- a/fhevm-engine/Cargo.toml +++ b/fhevm-engine/Cargo.toml @@ -3,5 +3,5 @@ resolver = "2" members = ["coprocessor", "executor", "fhevm-engine-common"] [workspace.dependencies] -tfhe = { version = "0.8.0-alpha.2", features = ["boolean", "shortint", "integer", "aarch64-unix"] } +tfhe = { version = "0.8.0-alpha.2", features = ["boolean", "shortint", "integer", "aarch64-unix", "zk-pok"] } clap = { version = "4.5", features = ["derive"] } diff --git a/fhevm-engine/coprocessor/Cargo.toml b/fhevm-engine/coprocessor/Cargo.toml index c569f94c..81ad6a63 100644 --- a/fhevm-engine/coprocessor/Cargo.toml +++ b/fhevm-engine/coprocessor/Cargo.toml @@ -25,6 +25,7 @@ bigdecimal = "0.4" fhevm-engine-common = { path = "../fhevm-engine-common" } strum = { version = "0.26", features = ["derive"] } bincode = "1.3.3" +sha3 = "0.10.8" [dev-dependencies] testcontainers = "0.21" diff --git a/fhevm-engine/coprocessor/migrations/20240722111257_coprocessor.sql b/fhevm-engine/coprocessor/migrations/20240722111257_coprocessor.sql index 9549495e..3d86ffee 100644 --- a/fhevm-engine/coprocessor/migrations/20240722111257_coprocessor.sql +++ b/fhevm-engine/coprocessor/migrations/20240722111257_coprocessor.sql @@ -2,6 +2,7 @@ CREATE TABLE IF NOT EXISTS computations ( tenant_id INT NOT NULL, output_handle BYTEA NOT NULL, + output_type SMALLINT NOT NULL, -- can be handle or scalar, depends on is_scalar field -- only second dependency can ever be scalar dependencies BYTEA[] NOT NULL, @@ -21,10 +22,23 @@ CREATE TABLE IF NOT EXISTS ciphertexts ( ciphertext BYTEA NOT NULL, ciphertext_version SMALLINT NOT NULL, ciphertext_type SMALLINT NOT NULL, + -- if ciphertext came from blob we have its reference + input_blob_hash BYTEA, + input_blob_index INT NOT NULL DEFAULT 0, created_at TIMESTAMP DEFAULT NOW(), PRIMARY KEY (tenant_id, handle, ciphertext_version) ); +-- store for audits and historical reference +CREATE TABLE IF NOT EXISTS input_blobs ( + tenant_id INT NOT NULL, + blob_hash BYTEA NOT NULL, + blob_data BYTEA NOT NULL, + blob_ciphertext_count INT NOT NULL, + created_at TIMESTAMP DEFAULT NOW(), + PRIMARY KEY (tenant_id, blob_hash) +); + CREATE TABLE IF NOT EXISTS tenants ( tenant_id SERIAL PRIMARY KEY, tenant_api_key UUID NOT NULL DEFAULT gen_random_uuid(), diff --git a/fhevm-engine/coprocessor/src/cli.rs b/fhevm-engine/coprocessor/src/cli.rs index b8f77fc3..8d9da65f 100644 --- a/fhevm-engine/coprocessor/src/cli.rs +++ b/fhevm-engine/coprocessor/src/cli.rs @@ -27,6 +27,14 @@ pub struct Args { #[arg(long, default_value_t = 32)] pub tenant_key_cache_size: i32, + /// Maximum compact inputs to upload + #[arg(long, default_value_t = 8)] + pub maximimum_compact_inputs_upload: usize, + + /// Maximum compact inputs to upload + #[arg(long, default_value_t = 255)] + pub maximum_handles_per_input: u8, + /// Coprocessor FHE processing threads #[arg(long, default_value_t = 8)] pub coprocessor_fhe_threads: usize, diff --git a/fhevm-engine/coprocessor/src/db_queries.rs b/fhevm-engine/coprocessor/src/db_queries.rs index d07002f6..7dcea8a8 100644 --- a/fhevm-engine/coprocessor/src/db_queries.rs +++ b/fhevm-engine/coprocessor/src/db_queries.rs @@ -1,7 +1,7 @@ use std::collections::{BTreeSet, HashMap}; use std::str::FromStr; -use crate::types::CoprocessorError; +use crate::types::{CoprocessorError, TfheTenantKeys}; use sqlx::{query, Postgres}; /// Returns tenant id upon valid authorization request @@ -55,12 +55,19 @@ pub async fn check_if_ciphertexts_exist_in_db( ) -> Result, i16>, CoprocessorError> { let handles_to_check_in_db_vec = cts.iter().cloned().collect::>(); let ciphertexts = query!( - " - SELECT handle, ciphertext_type + r#" + -- existing computations + SELECT handle AS "handle!", ciphertext_type AS "ciphertext_type!" FROM ciphertexts - WHERE handle = ANY($1::BYTEA[]) - AND tenant_id = $2 - ", + WHERE tenant_id = $2 + AND handle = ANY($1::BYTEA[]) + UNION + -- pending computations + SELECT output_handle AS "handle!", output_type AS "ciphertext_type!" + FROM computations + WHERE tenant_id = $2 + AND output_handle = ANY($1::BYTEA[]) + "#, &handles_to_check_in_db_vec, tenant_id, ) @@ -86,3 +93,69 @@ pub async fn check_if_ciphertexts_exist_in_db( Ok(result) } + +pub async fn fetch_tenant_server_key<'a, T>(tenant_id: i32, pool: T, tenant_key_cache: &std::sync::Arc>>) +-> Result> +where T: sqlx::PgExecutor<'a> + Copy +{ + // try getting from cache until it succeeds with populating cache + loop { + { + let mut w = tenant_key_cache.write().await; + if let Some(key) = w.get(&tenant_id) { + return Ok(key.sks.clone()); + } + } + + populate_cache_with_tenant_keys(vec![tenant_id], pool, &tenant_key_cache).await?; + } +} + +pub async fn query_tenant_keys<'a, T>(tenants_to_query: Vec, conn: T) +-> Result, Box> +where T: sqlx::PgExecutor<'a> +{ + let mut res = Vec::with_capacity(tenants_to_query.len()); + let keys = query!( + " + SELECT tenant_id, pks_key, sks_key + FROM tenants + WHERE tenant_id = ANY($1::INT[]) + ", + &tenants_to_query + ) + .fetch_all(conn) + .await?; + + for key in keys { + let sks: tfhe::ServerKey = bincode::deserialize(&key.sks_key) + .expect("We can't deserialize our own validated sks key"); + let pks: tfhe::CompactPublicKey = bincode::deserialize(&key.pks_key) + .expect("We can't deserialize our own validated pks key"); + res.push(TfheTenantKeys { tenant_id: key.tenant_id, sks, pks }); + } + + Ok(res) +} + +pub async fn populate_cache_with_tenant_keys<'a, T>(tenants_to_query: Vec, conn: T, tenant_key_cache: &std::sync::Arc>>) +-> Result<(), Box> +where T: sqlx::PgExecutor<'a> +{ + if !tenants_to_query.is_empty() { + let keys = query_tenant_keys(tenants_to_query, conn).await?; + + assert!( + keys.len() > 0, + "We should have keys here, otherwise our database is corrupt" + ); + + let mut key_cache = tenant_key_cache.write().await; + + for key in keys { + key_cache.put(key.tenant_id, key); + } + } + + Ok(()) +} \ No newline at end of file diff --git a/fhevm-engine/coprocessor/src/server.rs b/fhevm-engine/coprocessor/src/server.rs index 09816f35..4172cced 100644 --- a/fhevm-engine/coprocessor/src/server.rs +++ b/fhevm-engine/coprocessor/src/server.rs @@ -1,11 +1,15 @@ -use crate::db_queries::{check_if_api_key_is_valid, check_if_ciphertexts_exist_in_db}; +use std::collections::BTreeMap; +use std::num::NonZeroUsize; + +use crate::db_queries::{check_if_api_key_is_valid, check_if_ciphertexts_exist_in_db, fetch_tenant_server_key}; use crate::server::coprocessor::GenericResponse; -use fhevm_engine_common::tfhe_ops::{check_fhe_operand_types, current_ciphertext_version, debug_trivial_encrypt_be_bytes, deserialize_fhe_ciphertext}; -use fhevm_engine_common::types::FhevmError; -use crate::types::CoprocessorError; +use fhevm_engine_common::tfhe_ops::{check_fhe_operand_types, current_ciphertext_version, debug_trivial_encrypt_be_bytes, deserialize_fhe_ciphertext, try_expand_ciphertext_list}; +use fhevm_engine_common::types::{FhevmError, SupportedFheCiphertexts}; +use sha3::{Digest, Keccak256}; +use crate::types::{CoprocessorError, TfheTenantKeys}; use crate::utils::sort_computations_by_dependencies; use coprocessor::async_computation_input::Input; -use coprocessor::{DebugDecryptResponse, DebugDecryptResponseSingle}; +use coprocessor::{DebugDecryptResponse, DebugDecryptResponseSingle, InputCiphertextResponse, InputCiphertextResponseHandle, InputUploadBatch, InputUploadResponse}; use sqlx::{query, Acquire}; use tonic::transport::Server; @@ -16,6 +20,7 @@ pub mod coprocessor { pub struct CoprocessorService { pool: sqlx::Pool, args: crate::cli::Args, + tenant_key_cache: std::sync::Arc>>, } pub async fn run_server( @@ -33,7 +38,12 @@ pub async fn run_server( .connect(&db_url) .await?; - let service = CoprocessorService { pool, args }; + let tenant_key_cache: std::sync::Arc>> = + std::sync::Arc::new(tokio::sync::RwLock::new(lru::LruCache::new( + NonZeroUsize::new(args.tenant_key_cache_size as usize).unwrap(), + ))); + + let service = CoprocessorService { pool, args, tenant_key_cache }; Server::builder() .add_service( @@ -58,10 +68,10 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ let mut public_key = sqlx::query!( " - SELECT sks_key - FROM tenants - WHERE tenant_id = $1 - ", + SELECT sks_key + FROM tenants + WHERE tenant_id = $1 + ", tenant_id ) .fetch_all(&self.pool) @@ -122,10 +132,10 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ let mut priv_key = sqlx::query!( " - SELECT cks_key - FROM tenants - WHERE tenant_id = $1 - ", + SELECT cks_key + FROM tenants + WHERE tenant_id = $1 + ", tenant_id ) .fetch_all(&self.pool) @@ -140,12 +150,12 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ let cts = sqlx::query!( " - SELECT ciphertext, ciphertext_type, handle - FROM ciphertexts - WHERE tenant_id = $1 - AND handle = ANY($2::BYTEA[]) - AND ciphertext_version = $3 - ", + SELECT ciphertext, ciphertext_type, handle + FROM ciphertexts + WHERE tenant_id = $1 + AND handle = ANY($2::BYTEA[]) + AND ciphertext_version = $3 + ", tenant_id, &req.handles, current_ciphertext_version() @@ -182,6 +192,145 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ return Ok(tonic::Response::new(DebugDecryptResponse { values })); } + async fn upload_inputs( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + let req = request.get_ref(); + if req.input_ciphertexts.len() > self.args.maximimum_compact_inputs_upload { + return Err(tonic::Status::from_error(Box::new( + CoprocessorError::MoreThanMaximumCompactInputCiphertextsUploaded { + input_count: req.input_ciphertexts.len(), + maximum_allowed: self.args.maximimum_compact_inputs_upload, + }, + ))); + } + + let mut response = InputUploadResponse { + upload_responses: Vec::with_capacity(req.input_ciphertexts.len()) + }; + if req.input_ciphertexts.is_empty() { + return Ok(tonic::Response::new(response)); + } + + let tenant_id = check_if_api_key_is_valid(&request, &self.pool).await?; + + let server_key = { + fetch_tenant_server_key(tenant_id, &self.pool, &self.tenant_key_cache) + .await + .map_err(|e| { + tonic::Status::from_error(e) + })? + }; + let mut tfhe_work_set = tokio::task::JoinSet::new(); + + // server key is biiig, clone the pointer + let server_key = std::sync::Arc::new(server_key); + for (idx, ci) in req.input_ciphertexts.iter().enumerate() { + let cloned_input = ci.clone(); + let server_key = server_key.clone(); + tfhe_work_set.spawn_blocking( + move || -> Result<_, (Box<(dyn std::error::Error + Send + Sync)>, usize)> { + let expanded = + try_expand_ciphertext_list(&cloned_input.input_payload, &server_key) + .map_err(|e| { + let err: Box<(dyn std::error::Error + Send + Sync)> = Box::new(e); + (err, idx) + })?; + + Ok((expanded, idx)) + }, + ); + } + + let mut results: BTreeMap> = BTreeMap::new(); + while let Some(output) = tfhe_work_set.join_next().await { + let (cts, idx) = output.map_err(|e| { + let err: Box<(dyn std::error::Error + Sync + Send)> = Box::new(e); + tonic::Status::from_error(err) + })?.map_err(|e| { + tonic::Status::from_error(e.0) + })?; + + if cts.len() > self.args.maximum_handles_per_input as usize { + return Err(tonic::Status::from_error( + Box::new(CoprocessorError::CompactInputCiphertextHasMoreCiphertextThanLimitAllows { + input_blob_index: idx, + input_ciphertexts_in_blob: cts.len(), + input_maximum_ciphertexts_allowed: self.args.maximum_handles_per_input as usize, + }) + )); + } + + assert!(results.insert(idx, cts).is_none(), "fresh map, we passed vector ordered by indexes before"); + } + + assert_eq!(results.len(), req.input_ciphertexts.len(), "We should have all the ciphertexts now"); + + let mut trx = self.pool.begin().await.map_err(Into::::into)?; + for (idx, input_blob) in req.input_ciphertexts.iter().enumerate() { + let mut state = Keccak256::new(); + state.update(&input_blob.input_payload); + let blob_hash = state.finalize().to_vec(); + assert_eq!(blob_hash.len(), 32, "should be 32 bytes"); + + let corresponding_unpacked = results.get(&idx).expect("we should have all results computed now"); + + // save blob for audits and historical reference + let _ = sqlx::query!(" + INSERT INTO input_blobs(tenant_id, blob_hash, blob_data, blob_ciphertext_count) + VALUES($1, $2, $3, $4) + ON CONFLICT (tenant_id, blob_hash) DO NOTHING + ", tenant_id, &blob_hash, &input_blob.input_payload, corresponding_unpacked.len() as i32) + .execute(trx.as_mut()).await.map_err(Into::::into)?; + + let mut ct_resp = InputCiphertextResponse { + input_handles: Vec::with_capacity(corresponding_unpacked.len()), + }; + + for (ct_idx, the_ct) in corresponding_unpacked.iter().enumerate() { + let (serialized_type, serialized_ct) = the_ct.serialize(); + let ciphertext_version = current_ciphertext_version(); + let mut handle_hash = Keccak256::new(); + handle_hash.update(&blob_hash); + handle_hash.update(&[idx as u8]); + let mut handle = handle_hash.finalize().to_vec(); + assert_eq!(handle.len(), 32); + // idx cast to u8 must succeed because we don't allow + // more handles than u8 size + handle[29] = idx as u8; + handle[30] = serialized_type as u8; + handle[31] = ciphertext_version as u8; + + let _ = sqlx::query!(" + INSERT INTO ciphertexts( + tenant_id, + handle, + ciphertext, + ciphertext_version, + ciphertext_type, + input_blob_hash, + input_blob_index + ) + VALUES($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (tenant_id, handle, ciphertext_version) DO NOTHING + ", tenant_id, &handle, &serialized_ct, ciphertext_version, serialized_type, &blob_hash, ct_idx as i32) + .execute(trx.as_mut()).await.map_err(Into::::into)?; + + ct_resp.input_handles.push(InputCiphertextResponseHandle { + handle: handle.to_vec(), + ciphertext_type: serialized_type as i32, + }); + } + + response.upload_responses.push(ct_resp); + } + + trx.commit().await.map_err(Into::::into)?; + + Ok(tonic::Response::new(response)) + } + async fn upload_ciphertexts( &self, request: tonic::Request, @@ -248,8 +397,10 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ check_if_ciphertexts_exist_in_db(handles_to_check_in_db, tenant_id, &self.pool).await?; let mut computations_inputs: Vec>> = Vec::with_capacity(sorted_computations.len()); + let mut computations_outputs: Vec> = Vec::with_capacity(sorted_computations.len()); let mut are_comps_scalar: Vec = Vec::with_capacity(sorted_computations.len()); for comp in &sorted_computations { + computations_outputs.push(comp.output_handle.clone()); let mut handle_types = Vec::with_capacity(comp.inputs.len()); let mut is_computation_scalar = false; let mut this_comp_inputs: Vec> = Vec::with_capacity(comp.inputs.len()); @@ -298,19 +449,36 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ .begin() .await .map_err(Into::::into)?; + let mut new_work_available = false; for (idx, comp) in sorted_computations.iter().enumerate() { + let output_type = ct_types + .get(&comp.output_handle) + .expect("we should have collected all output result types by now with check_fhe_operand_types"); let fhe_operation: i16 = comp .operation .try_into() .map_err(|_| CoprocessorError::FhevmError(FhevmError::UnknownFheOperation(comp.operation)))?; let res = query!( " - INSERT INTO computations(tenant_id, output_handle, dependencies, fhe_operation, is_completed, is_scalar) - VALUES($1, $2, $3, $4, false, $5) + INSERT INTO computations( + tenant_id, + output_handle, + dependencies, + fhe_operation, + is_completed, + is_scalar, + output_type + ) + VALUES($1, $2, $3, $4, false, $5, $6) ON CONFLICT (tenant_id, output_handle) DO NOTHING ", - tenant_id, comp.output_handle, &computations_inputs[idx], fhe_operation, are_comps_scalar[idx] + tenant_id, + comp.output_handle, + &computations_inputs[idx], + fhe_operation, + are_comps_scalar[idx], + output_type ).execute(trx.as_mut()).await.map_err(Into::::into)?; if res.rows_affected() > 0 { new_work_available = true; diff --git a/fhevm-engine/coprocessor/src/tests/inputs.rs b/fhevm-engine/coprocessor/src/tests/inputs.rs new file mode 100644 index 00000000..a5ff2f5b --- /dev/null +++ b/fhevm-engine/coprocessor/src/tests/inputs.rs @@ -0,0 +1,82 @@ +use std::str::FromStr; + +use tonic::metadata::MetadataValue; + +use crate::{db_queries::query_tenant_keys, server::coprocessor::{fhevm_coprocessor_client::FhevmCoprocessorClient, DebugDecryptRequest, InputToUpload, InputUploadBatch}, tests::utils::{default_api_key, default_tenant_id, setup_test_app}}; + + +#[tokio::test] +async fn test_fhe_inputs() -> Result<(), Box> { + let app = setup_test_app().await?; + let mut client = FhevmCoprocessorClient::connect(app.app_url().to_string()).await?; + let api_key_header = format!("bearer {}", default_api_key()); + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(2) + .connect(app.db_url()) + .await?; + + let keys = query_tenant_keys(vec![default_tenant_id()], &pool).await.map_err(|e| { + let e: Box = e; + e + })?; + let keys = &keys[0]; + + let mut builder = tfhe::CompactCiphertextListBuilder::new(&keys.pks); + let the_list = builder + .push(false) + .push(1u8) + .push(2u16) + .push(3u32) + .push(4u64) + .build(); + + let serialized = bincode::serialize(&the_list).unwrap(); + + println!("Encrypting inputs..."); + let mut input_request = tonic::Request::new(InputUploadBatch { + input_ciphertexts: vec![ + InputToUpload { + input_payload: serialized, + signature: Vec::new(), + } + ] + }); + input_request.metadata_mut().append( + "authorization", + MetadataValue::from_str(&api_key_header).unwrap(), + ); + let resp = client.upload_inputs(input_request).await?; + let resp = resp.get_ref(); + assert_eq!(resp.upload_responses.len(), 1); + + let first_resp = &resp.upload_responses[0]; + + assert_eq!(first_resp.input_handles.len(), 5); + + let mut decr_handles: Vec> = Vec::new(); + for handle in &first_resp.input_handles { + decr_handles.push(handle.handle.clone()); + } + + let mut decrypt_request = tonic::Request::new(DebugDecryptRequest { + handles: decr_handles, + }); + decrypt_request.metadata_mut().append( + "authorization", + MetadataValue::from_str(&api_key_header).unwrap(), + ); + let resp = client.debug_decrypt_ciphertext(decrypt_request).await?; + let resp = resp.get_ref(); + assert_eq!(resp.values.len(), 5); + + assert_eq!(resp.values[0].output_type, 1); + assert_eq!(resp.values[0].value, "false"); + assert_eq!(resp.values[1].output_type, 2); + assert_eq!(resp.values[1].value, "1"); + assert_eq!(resp.values[2].output_type, 3); + assert_eq!(resp.values[2].value, "2"); + assert_eq!(resp.values[3].output_type, 4); + assert_eq!(resp.values[3].value, "3"); + + Ok(()) +} \ No newline at end of file diff --git a/fhevm-engine/coprocessor/src/tests/mod.rs b/fhevm-engine/coprocessor/src/tests/mod.rs index 69860412..3ef809c6 100644 --- a/fhevm-engine/coprocessor/src/tests/mod.rs +++ b/fhevm-engine/coprocessor/src/tests/mod.rs @@ -11,6 +11,7 @@ use utils::{default_api_key, wait_until_all_ciphertexts_computed}; mod operators; mod utils; +mod inputs; #[tokio::test] async fn test_smoke() -> Result<(), Box> { diff --git a/fhevm-engine/coprocessor/src/tests/utils.rs b/fhevm-engine/coprocessor/src/tests/utils.rs index 65d6abee..c6ff6ec3 100644 --- a/fhevm-engine/coprocessor/src/tests/utils.rs +++ b/fhevm-engine/coprocessor/src/tests/utils.rs @@ -32,6 +32,10 @@ pub fn default_api_key() -> &'static str { "a1503fb6-d79b-4e9e-826d-44cf262f3e05" } +pub fn default_tenant_id() -> i32 { + 1 +} + pub async fn setup_test_app() -> Result> { static PORT_COUNTER: AtomicU16 = AtomicU16::new(10000); @@ -83,10 +87,12 @@ pub async fn setup_test_app() -> Result work_items_batch_size: 40, tenant_key_cache_size: 4, coprocessor_fhe_threads: 4, + maximum_handles_per_input: 255, tokio_threads: 2, pg_pool_max_connections: 2, server_addr: format!("127.0.0.1:{app_port}"), database_url: Some(db_url.clone()), + maximimum_compact_inputs_upload: 10, }; std::thread::spawn(move || { @@ -113,7 +119,7 @@ pub async fn wait_until_all_ciphertexts_computed(test_instance: &TestInstance) - loop { tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; let count = sqlx::query!( - "SELECT count(*) FROM computations WHERE NOT is_completed AND NOT is_error" + "SELECT count(*) FROM computations WHERE NOT is_completed" ) .fetch_one(&pool) .await?; diff --git a/fhevm-engine/coprocessor/src/tfhe_worker.rs b/fhevm-engine/coprocessor/src/tfhe_worker.rs index 2dbadbb0..e13dba54 100644 --- a/fhevm-engine/coprocessor/src/tfhe_worker.rs +++ b/fhevm-engine/coprocessor/src/tfhe_worker.rs @@ -1,7 +1,7 @@ use fhevm_engine_common::tfhe_ops::{ current_ciphertext_version, deserialize_fhe_ciphertext, perform_fhe_operation, }; -use crate::types::TfheTenantKeys; +use crate::{db_queries::populate_cache_with_tenant_keys, types::TfheTenantKeys}; use fhevm_engine_common::types::SupportedFheCiphertexts; use sqlx::{postgres::PgListener, query, Acquire}; use std::{ @@ -28,10 +28,9 @@ pub async fn run_tfhe_worker( async fn tfhe_worker_cycle( args: &crate::cli::Args, ) -> Result<(), Box> { - let key_cache_size = 32; let tenant_key_cache: std::sync::Arc>> = std::sync::Arc::new(tokio::sync::RwLock::new(lru::LruCache::new( - NonZeroUsize::new(key_cache_size).unwrap(), + NonZeroUsize::new(args.tenant_key_cache_size as usize).unwrap(), ))); let db_url = crate::utils::db_url(args); @@ -116,42 +115,16 @@ async fn tfhe_worker_cycle( let tenants_to_query = tenants_to_query.into_iter().collect::>(); let keys_to_query = keys_to_query.into_iter().collect::>(); - if !keys_to_query.is_empty() { - let keys = query!( - " - SELECT tenant_id, pks_key, sks_key - FROM tenants - WHERE tenant_id = ANY($1::INT[]) - ", - &keys_to_query - ) - .fetch_all(trx.as_mut()) - .await?; - - assert!( - keys.len() > 0, - "We should have keys here, otherwise our database is corrupt" - ); - - let mut key_cache = tenant_key_cache.write().await; - - for key in keys { - let sks: tfhe::ServerKey = bincode::deserialize(&key.sks_key) - .expect("We can't deserialize our own validated sks key"); - let pks: tfhe::CompactPublicKey = bincode::deserialize(&key.pks_key) - .expect("We can't deserialize our own validated pks key"); - key_cache.put(key.tenant_id, TfheTenantKeys { sks, pks }); - } - } + populate_cache_with_tenant_keys(keys_to_query, trx.as_mut(), &tenant_key_cache).await?; // TODO: select all the ciphertexts where they're contained in the tuples let ciphertexts_rows = query!( " - SELECT tenant_id, handle, ciphertext, ciphertext_type - FROM ciphertexts - WHERE tenant_id = ANY($1::INT[]) - AND handle = ANY($2::BYTEA[]) - ", + SELECT tenant_id, handle, ciphertext, ciphertext_type + FROM ciphertexts + WHERE tenant_id = ANY($1::INT[]) + AND handle = ANY($2::BYTEA[]) + ", &tenants_to_query, &cts_to_query ) @@ -257,11 +230,11 @@ async fn tfhe_worker_cycle( .await?; let _ = query!( " - UPDATE computations - SET is_completed = true - WHERE tenant_id = $1 - AND output_handle = $2 - ", + UPDATE computations + SET is_completed = true + WHERE tenant_id = $1 + AND output_handle = $2 + ", w.tenant_id, w.output_handle ) @@ -271,11 +244,11 @@ async fn tfhe_worker_cycle( Err((err, tenant_id, output_handle)) => { let _ = query!( " - UPDATE computations - SET is_error = true, error_message = $1 - WHERE tenant_id = $2 - AND output_handle = $3 - ", + UPDATE computations + SET is_error = true, error_message = $1 + WHERE tenant_id = $2 + AND output_handle = $3 + ", err.to_string(), tenant_id, output_handle diff --git a/fhevm-engine/coprocessor/src/types.rs b/fhevm-engine/coprocessor/src/types.rs index 1f814b6d..fe0fc40e 100644 --- a/fhevm-engine/coprocessor/src/types.rs +++ b/fhevm-engine/coprocessor/src/types.rs @@ -9,7 +9,20 @@ pub enum CoprocessorError { CiphertextHandleLongerThan64Bytes, CiphertextHandleMustBeAtLeast1Byte(String), UnexistingInputCiphertextsFound(Vec), + AlreadyExistingResultHandlesFound(Vec), OutputHandleIsAlsoInputHandle(String), + DuplicateResultHandleInInputsUploaded { + hex_handle: String, + }, + MoreThanMaximumCompactInputCiphertextsUploaded { + input_count: usize, + maximum_allowed: usize, + }, + CompactInputCiphertextHasMoreCiphertextThanLimitAllows { + input_blob_index: usize, + input_ciphertexts_in_blob: usize, + input_maximum_ciphertexts_allowed: usize, + }, ComputationInputIsUndefined { computation_output_handle: String, computation_inputs_index: usize, @@ -36,6 +49,15 @@ impl std::fmt::Display for CoprocessorError { Self::DuplicateOutputHandleInBatch(op) => { write!(f, "Duplicate output handle in ciphertext batch: {}", op) } + Self::DuplicateResultHandleInInputsUploaded { hex_handle } => { + write!(f, "Duplicate result handle in inputs detected: {hex_handle}") + } + Self::MoreThanMaximumCompactInputCiphertextsUploaded { input_count, maximum_allowed } => { + write!(f, "More than maximum input blobs uploaded, maximum allowed: {maximum_allowed}, uploaded: {input_count}") + } + Self::CompactInputCiphertextHasMoreCiphertextThanLimitAllows { input_blob_index, input_ciphertexts_in_blob, input_maximum_ciphertexts_allowed } => { + write!(f, "Input blob contains mismatching amount of ciphertexts, input blob index: {input_blob_index}, ciphertexts in blob: {input_ciphertexts_in_blob}, maximum ciphertexts in blob allowed: {input_maximum_ciphertexts_allowed}") + } Self::CiphertextHandleLongerThan64Bytes => { write!(f, "Found ciphertext handle longer than 64 bytes") } @@ -45,6 +67,9 @@ impl std::fmt::Display for CoprocessorError { Self::UnexistingInputCiphertextsFound(handles) => { write!(f, "Ciphertexts not found: {:?}", handles) } + Self::AlreadyExistingResultHandlesFound(e) => { + write!(f, "Handles not found in the database: {:?}", e) + } Self::OutputHandleIsAlsoInputHandle(handle) => { write!( f, @@ -95,6 +120,7 @@ impl From for tonic::Status { } pub struct TfheTenantKeys { + pub tenant_id: i32, pub sks: tfhe::ServerKey, // maybe we'll need this #[allow(dead_code)] diff --git a/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs b/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs index ced6ca93..f30411dc 100644 --- a/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs +++ b/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs @@ -33,9 +33,7 @@ pub fn deserialize_fhe_ciphertext( Ok(SupportedFheCiphertexts::FheUint64(v)) } _ => { - return Err(FhevmError::UnknownCiphertextType( - input_type, - )); + return Err(FhevmError::UnknownFheType(input_type as i32)); } } } @@ -89,6 +87,74 @@ pub fn current_ciphertext_version() -> i16 { 1 } +pub fn try_expand_ciphertext_list( + input_ciphertext: &[u8], + server_key: &tfhe::ServerKey, +) -> Result, FhevmError> { + let mut res = Vec::new(); + + let the_list: tfhe::CompactCiphertextList = + bincode::deserialize(input_ciphertext) + .map_err(|e| { + let err: Box<(dyn std::error::Error + Send + Sync)> = e; + FhevmError::DeserializationError(err) + })?; + + let expanded = the_list.expand_with_key(server_key) + .map_err(|e| { + FhevmError::CiphertextExpansionError(e) + })?; + + for idx in 0..expanded.len() { + let Some(data_kind) = expanded.get_kind_of(idx) else { + panic!("we're itering over what ciphertext told us how many ciphertexts are there, it must exist") + }; + + match data_kind { + tfhe::FheTypes::Bool => { + let ct: tfhe::FheBool = expanded.get(idx) + .expect("Index must exist") + .expect("Must succeed, we just checked this is the type"); + + res.push(SupportedFheCiphertexts::FheBool(ct)); + }, + tfhe::FheTypes::Uint8 => { + let ct: tfhe::FheUint8 = expanded.get(idx) + .expect("Index must exist") + .expect("Must succeed, we just checked this is the type"); + + res.push(SupportedFheCiphertexts::FheUint8(ct)); + }, + tfhe::FheTypes::Uint16 => { + let ct: tfhe::FheUint16 = expanded.get(idx) + .expect("Index must exist") + .expect("Must succeed, we just checked this is the type"); + + res.push(SupportedFheCiphertexts::FheUint16(ct)); + }, + tfhe::FheTypes::Uint32 => { + let ct: tfhe::FheUint32 = expanded.get(idx) + .expect("Index must exist") + .expect("Must succeed, we just checked this is the type"); + + res.push(SupportedFheCiphertexts::FheUint32(ct)); + }, + tfhe::FheTypes::Uint64 => { + let ct: tfhe::FheUint64 = expanded.get(idx) + .expect("Index must exist") + .expect("Must succeed, we just checked this is the type"); + + res.push(SupportedFheCiphertexts::FheUint64(ct)); + }, + other => { + return Err(FhevmError::CiphertextExpansionUnsupportedCiphertextKind(other)); + } + } + } + + Ok(res) +} + // return output ciphertext type pub fn check_fhe_operand_types( fhe_operation: i32, @@ -254,9 +320,9 @@ pub fn check_fhe_operand_types( }); } - let output_type = op[0] as i16; + let output_type = op[0] as i32; validate_fhe_type(output_type)?; - Ok(output_type) + Ok(output_type as i16) } (other_left, other_right) => { let bool_to_op = |inp| { @@ -286,10 +352,11 @@ pub fn check_fhe_operand_types( } } -pub fn validate_fhe_type(input_type: i16) -> Result<(), FhevmError> { - match input_type { +pub fn validate_fhe_type(input_type: i32) -> Result<(), FhevmError> { + let i16_type: i16 = input_type.try_into().or(Err(FhevmError::UnknownFheType(input_type)))?; + match i16_type { 1 | 2 | 3 | 4 | 5 => Ok(()), - _ => Err(FhevmError::UnknownCiphertextType(input_type)), + _ => Err(FhevmError::UnknownFheType(input_type)), } } diff --git a/fhevm-engine/fhevm-engine-common/src/types.rs b/fhevm-engine/fhevm-engine-common/src/types.rs index 1fcd0943..60fdf71e 100644 --- a/fhevm-engine/fhevm-engine-common/src/types.rs +++ b/fhevm-engine/fhevm-engine-common/src/types.rs @@ -5,8 +5,9 @@ use tfhe::prelude::FheDecrypt; pub enum FhevmError { UnknownFheOperation(i32), UnknownFheType(i32), - UnknownCiphertextType(i16), DeserializationError(Box), + CiphertextExpansionError(tfhe::Error), + CiphertextExpansionUnsupportedCiphertextKind(tfhe::FheTypes), FheOperationOnlyOneOperandCanBeScalar { fhe_operation: i32, fhe_operation_name: String, @@ -78,12 +79,15 @@ impl std::fmt::Display for FhevmError { Self::UnknownFheType(op) => { write!(f, "Unknown fhe type: {}", op) } - Self::UnknownCiphertextType(the_type) => { - write!(f, "Unknown input ciphertext type: {}", the_type) - } Self::DeserializationError(e) => { write!(f, "error deserializing ciphertext: {:?}", e) }, + Self::CiphertextExpansionError(e) => { + write!(f, "error expanding compact ciphertext list: {:?}", e) + }, + Self::CiphertextExpansionUnsupportedCiphertextKind(e) => { + write!(f, "unsupported tfhe type found while expanding ciphertexts: {:?}", e) + }, Self::FheOperationDoesntSupportScalar { fhe_operation, fhe_operation_name, diff --git a/proto/coprocessor.proto b/proto/coprocessor.proto index 64a6b464..90124b4b 100644 --- a/proto/coprocessor.proto +++ b/proto/coprocessor.proto @@ -10,6 +10,7 @@ service FhevmCoprocessor { rpc AsyncCompute (AsyncComputeRequest) returns (GenericResponse) {} rpc UploadCiphertexts (CiphertextUploadBatch) returns (GenericResponse) {} rpc WaitComputations (AsyncComputeRequest) returns (FhevmResponses) {} + rpc UploadInputs (InputUploadBatch) returns (InputUploadResponse) {} // for debugging, should be removed in prod rpc DebugDecryptCiphertext (DebugDecryptRequest) returns (DebugDecryptResponse) {} // for debugging, should be removed in prod @@ -85,6 +86,24 @@ message CiphertextToUpload { bytes ciphertext_bytes = 3; } +message InputToUpload { + bytes input_payload = 1; + bytes signature = 2; +} + +message InputCiphertextResponseHandle { + bytes handle = 1; + int32 ciphertext_type = 2; +} + +message InputCiphertextResponse { + repeated InputCiphertextResponseHandle input_handles = 1; +} + +message InputUploadResponse { + repeated InputCiphertextResponse upload_responses = 1; +} + // The request message containing the user's name. message AsyncComputeRequest { repeated AsyncComputation computations = 1; @@ -94,6 +113,10 @@ message CiphertextUploadBatch { repeated CiphertextToUpload input_ciphertexts = 1; } +message InputUploadBatch { + repeated InputToUpload input_ciphertexts = 1; +} + message WaitBatch { repeated string ciphertext_handles = 1; }