diff --git a/crates/valv/src/cmd/main.rs b/crates/valv/src/cmd/main.rs index 072069f..ddd3736 100644 --- a/crates/valv/src/cmd/main.rs +++ b/crates/valv/src/cmd/main.rs @@ -1,9 +1,13 @@ #![allow(clippy::unwrap_used, clippy::expect_used)] -use std::sync::Arc; +use std::{ + io::{self, Read}, + process::exit, + sync::Arc, +}; use clap::Parser; -use secrecy::{ExposeSecret, Secret}; + use tonic::transport::Server; use valv::{ api, valv::proto::v1::master_key_management_service_server::MasterKeyManagementServiceServer, @@ -13,24 +17,39 @@ use valv::{ #[derive(Parser)] struct Cli { listen_addr: String, - - #[arg(short, long, value_name = "KEY")] - master_key: Secret, } #[tokio::main] async fn main() { - let _guard = unsafe { foundationdb::boot() }; + println!("Please input the 32-byte root key (no newline):"); - let args = Cli::parse(); + let mut buffer = String::new(); + io::stdin() + .read_to_string(&mut buffer) + .expect("Failed to read from stdin"); + + let trimmed_key = buffer.trim(); - let mut valv = Valv::new().await.expect("Failed to initialize Valv"); + if trimmed_key.len() != 32 { + eprintln!( + "Error: Master key must be exactly 32 bytes. Got {} bytes.", + trimmed_key.len() + ); + exit(1); + } - let master_key = args.master_key.clone().expose_secret().clone().into_bytes()[..32] + let root_key: [u8; 32] = trimmed_key + .as_bytes() .try_into() - .expect("Master key must be 32 bytes"); + .expect("Failed to convert trimmed key to 32-byte array"); - valv.set_master_key(master_key); + let _guard = unsafe { foundationdb::boot() }; + + let args = Cli::parse(); + + let mut valv = Valv::new(root_key) + .await + .expect("Failed to initialize Valv"); let api = api::server::API { valv: Arc::new(valv), diff --git a/crates/valv/src/errors.rs b/crates/valv/src/errors.rs index e7e85d7..fa849cc 100644 --- a/crates/valv/src/errors.rs +++ b/crates/valv/src/errors.rs @@ -1,6 +1,10 @@ +use std::sync::{MutexGuard, PoisonError}; + use foundationdb::FdbBindingError; use thiserror::Error; +use crate::ValvState; + #[derive(Error, Debug)] pub enum ValvError { #[error("IO error")] @@ -29,6 +33,9 @@ pub enum ValvError { #[error("Internal error: {0}")] Internal(String), + + #[error("Valv is locked")] + Locked, } pub type Result = std::result::Result; diff --git a/crates/valv/src/integration_tests.rs b/crates/valv/src/integration_tests.rs index b5a7480..98751d5 100644 --- a/crates/valv/src/integration_tests.rs +++ b/crates/valv/src/integration_tests.rs @@ -42,14 +42,14 @@ mod tests { async fn setup_server( ) -> Result>, ValvError> { let addr = SERVER_ADDR.parse().expect("Invalid address"); - let mut valv = Valv::new().await?; let master_key_bytes: [u8; 32] = "77aaee825aa561995d7bda258f9b76b0" .as_bytes() .try_into() .expect("Invalid master key"); - valv.set_master_key(master_key_bytes); + let valv = Valv::new(master_key_bytes).await?; + let api = API { valv: Arc::new(valv), }; @@ -129,8 +129,9 @@ mod tests { plaintext: vec![0; 32].into(), keyring_name: "test_tenant".to_string(), }); + println!("test step encrypted start"); let encrypt_response = client.encrypt(encrypt_request).await.unwrap(); - + println!("test step encrypted pass"); // Assert the encrypt response let original_ciphertext = encrypt_response.get_ref().ciphertext.clone(); assert_eq!( @@ -145,7 +146,9 @@ mod tests { ciphertext: original_ciphertext.clone(), keyring_name: "test_tenant".to_string(), }); + println!("test step decrypt start"); let decrypt_response = client.decrypt(decrypt_request).await.unwrap(); + println!("test step decrypt pass"); // Assert the decrypt response assert_eq!(decrypt_response.get_ref().plaintext.len(), 32); diff --git a/crates/valv/src/lib.rs b/crates/valv/src/lib.rs index be31bac..abfe849 100644 --- a/crates/valv/src/lib.rs +++ b/crates/valv/src/lib.rs @@ -1,4 +1,7 @@ +use core::panic; + use errors::{Result, ValvError}; +use foundationdb::RetryableTransaction; use gen::valv::internal; use prost::bytes::Buf; use secrecy::{ExposeSecret, Secret}; @@ -40,44 +43,81 @@ pub struct KeyMaterial<'a> { #[async_trait::async_trait] pub trait ValvAPI: Send + Sync { - // TODO: Separate get_key into get_key_metadata and get_key_with_primary_version + //async fn rotate_master_key(&self) -> Result; + async fn create_master_key(&self) -> Result; async fn create_key(&self, tenant: &str, name: &str) -> Result; + // TODO: Separate get_key into get_key_metadata and get_key_with_primary_version async fn get_key(&self, tenant: &str, name: &str) -> Result>; async fn list_keys(&self, tenant: &str) -> Result>>; async fn update_key(&self, tenant: &str, key: internal::Key) -> Result; + //async fn rotate_key(&self, tenant: &str, name: &str) -> Result<()>; + async fn get_key_version( &self, tenant: &str, key_name: &str, version_id: u32, ) -> Result>; + async fn encrypt(&self, tenant: &str, key_name: &str, plaintext: Vec) -> Result>; async fn decrypt(&self, tenant: &str, key_name: &str, ciphertext: Vec) -> Result>; } +#[derive(PartialEq, Debug)] +pub enum ValvState { + MissingMasterKey, + Unlocked, +} + pub struct Valv { pub db: FoundationDB, - pub master_key: Secret<[u8; 32]>, + root_key: Secret<[u8; 32]>, + state: ValvState, } +#[allow(clippy::expect_used)] impl Valv { - pub async fn new() -> Result { - Ok(Valv { - db: FoundationDB::new("local").await?, - master_key: [0; 32].into(), - }) - } + pub async fn new(key: [u8; 32]) -> Result { + let db = FoundationDB::new("local").await?; + + let mut valv = Valv { + db, + root_key: key.into(), + state: ValvState::MissingMasterKey, + }; - pub fn set_master_key(&mut self, key: [u8; 32]) { - self.master_key = Secret::new(key); + match valv.get_key("valv", "master_key").await { + Ok(_) => { + valv.state = ValvState::Unlocked; + } + Err(ValvError::KeyNotFound(_)) => { + // Here we allow expect as we want to crash if we cannot create the master key. + // Without a master key, the system is not functional. + valv.create_master_key() + .await + .expect("Could not create master key"); + + valv.state = ValvState::Unlocked; + } + Err(e) => { + panic!( + "Could not attempt to get wrapped master key due to: {}", + e.to_string() + ); + } + } + + Ok(valv) } } #[async_trait::async_trait] impl ValvAPI for Valv { async fn get_key(&self, tenant: &str, key_name: &str) -> Result> { + assert_ne!(self.state, ValvState::MissingMasterKey); + let trx_result = self .db .database @@ -101,6 +141,8 @@ impl ValvAPI for Valv { } async fn list_keys(&self, tenant: &str) -> Result>> { + assert_ne!(self.state, ValvState::MissingMasterKey); + let trx_result = self .db .database @@ -123,25 +165,137 @@ impl ValvAPI for Valv { } } + async fn create_master_key(&self) -> Result { + let mut iv = [0; 12]; + let mut key = [0; 32]; + let mut tag = [0; 16]; + boring::rand::rand_bytes(&mut iv)?; + boring::rand::rand_bytes(&mut key)?; + + println!("encrypting with root key"); + + let encrypted_key = boring::symm::encrypt_aead( + boring::symm::Cipher::aes_256_gcm(), + self.root_key.expose_secret(), + Some(&iv), + &[], + &key, + &mut tag, + )?; + + let encrypted_result: [u8; 4 + 12 + 32 + 16] = { + let mut result = [0u8; 4 + 12 + 32 + 16]; + result[..4].copy_from_slice(&1_u32.to_be_bytes()); + result[4..16].copy_from_slice(&iv); + result[16..16 + encrypted_key.len()].copy_from_slice(&encrypted_key); + result[16 + encrypted_key.len()..].copy_from_slice(&tag); + result + }; + + let key = internal::Key { + key_id: "master_key".to_string(), + primary_version_id: 1, + purpose: "ENCRYPT_DECRYPT".to_string(), + creation_time: Some(prost_types::Timestamp { + seconds: chrono::Utc::now().timestamp(), + nanos: chrono::Utc::now().timestamp_subsec_nanos() as i32, + }), + rotation_schedule: Some(prost_types::Duration { + seconds: chrono::TimeDelta::days(30).num_seconds(), + nanos: 0, + }), + }; + + let key_version = internal::KeyVersion { + key_id: "master_key".to_string(), + key_material: encrypted_result.to_vec().into(), + state: internal::KeyVersionState::Enabled as i32, + version: 1, + creation_time: Some(prost_types::Timestamp { + seconds: chrono::Utc::now().timestamp(), + nanos: chrono::Utc::now().timestamp_subsec_nanos() as i32, + }), + ..Default::default() + }; + + let trx_result = self + .db + .database + .run(|trx, _| async { + let trx = trx; + self.db.update_key_metadata(&trx, "valv", &key).await?; + + self.db + .append_key_version(&trx, "valv", &key, &key_version) + .await?; + + Ok(()) + }) + .await; + + match trx_result { + Ok(_) => Ok(key), + Err(e) => { + println!("Error creating master key: {e}"); + Err(ValvError::Internal(e.to_string())) + } + } + } + async fn create_key(&self, tenant: &str, name: &str) -> Result { + assert_ne!(self.state, ValvState::MissingMasterKey); + + let key = self.get_key("valv", "master_key").await?; + + assert!(key.is_some()); + + let master_key = match key { + Some(key) => key, + None => panic!(), + }; + + let trx_result = self + .db + .database + .run(|trx, _| async { + let trx = trx; + + let material = self + .get_unwrapped_master_key_material(&trx, master_key.primary_version_id) + .await?; + Ok(material) + }) + .await; + + let unwrapped_master_key_material = match trx_result { + Ok(key) => key, + Err(e) => { + println!("Error creating key {name}: {e}"); + return Err(ValvError::Internal(e.to_string())); + } + }; + let mut iv = [0; 12]; let mut key = [0; 32]; let mut tag = [0; 16]; + boring::rand::rand_bytes(&mut iv)?; boring::rand::rand_bytes(&mut key)?; let encrypted_key = boring::symm::encrypt_aead( boring::symm::Cipher::aes_256_gcm(), - self.master_key.expose_secret(), + &unwrapped_master_key_material, Some(&iv), &[], &key, &mut tag, )?; - let mut encrypted_result = Vec::with_capacity(iv.len() + encrypted_key.len() + tag.len()); + let mut encrypted_result = + Vec::with_capacity(4 + iv.len() + encrypted_key.len() + tag.len()); // Add IV, key material and tag to result + encrypted_result.extend_from_slice(&1_u32.to_be_bytes()); encrypted_result.extend_from_slice(&iv); encrypted_result.extend_from_slice(&encrypted_key); encrypted_result.extend_from_slice(&tag); @@ -219,6 +373,10 @@ impl ValvAPI for Valv { } } + /*async fn rotate_key(&self, tenant: &str, name: &str) -> Result<()> { + Ok(()) + }*/ + async fn get_key_version( &self, tenant: &str, @@ -256,6 +414,7 @@ impl ValvAPI for Valv { .run(|trx, _| { async { let trx = trx; + let key = self.db.get_key_metadata(&trx, tenant, key_name).await?; let key = match key { @@ -265,37 +424,10 @@ impl ValvAPI for Valv { } }; - let key_version = self - .db - .get_key_version(&trx, tenant, &key.key_id, key.primary_version_id) + let unwrapped_key_version_material = self + .get_unwrapped_key_material(&trx, tenant, key_name, key.primary_version_id) .await?; - let key_version = match key_version { - Some(key_version) => key_version, - None => { - return Err(ValvError::KeyNotFound(key_name.to_string()).into()); - } - }; - - let (iv, remainder) = key_version.key_material.split_at(12); - let (cipher, tag) = remainder.split_at(remainder.len() - 16); - - let decrypted_key_material = boring::symm::decrypt_aead( - boring::symm::Cipher::aes_256_gcm(), - self.master_key.expose_secret(), - Some(iv), - &[], - cipher, - tag, - ); - - let decrypted_key_material = match decrypted_key_material { - Ok(decrypted_key_material) => decrypted_key_material, - Err(e) => { - return Err(ValvError::Internal(e.to_string()).into()); - } - }; - let mut iv = [0; 12]; boring::rand::rand_bytes(&mut iv).map_err(ValvError::BoringSSL)?; @@ -303,7 +435,7 @@ impl ValvAPI for Valv { let encrypted_key = boring::symm::encrypt_aead( boring::symm::Cipher::aes_256_gcm(), - &decrypted_key_material, + &unwrapped_key_version_material, Some(&iv), &[], &plaintext, @@ -319,7 +451,7 @@ impl ValvAPI for Valv { ); // Add version, IV and encrypted key to result - encrypted_result.extend_from_slice(&(key_version.version).to_be_bytes()); + encrypted_result.extend_from_slice(&(key.primary_version_id).to_be_bytes()); encrypted_result.extend_from_slice(&iv); encrypted_result.extend_from_slice(&encrypted_key); @@ -374,12 +506,19 @@ impl ValvAPI for Valv { } }; - let (kv_iv, kv_remainder) = key_version.key_material.split_at(12); + let (master_key_version, kv_remainder) = key_version.key_material.split_at(4); + let (kv_iv, kv_remainder) = kv_remainder.split_at(12); let (kv_cipher, kv_tag) = kv_remainder.split_at(kv_remainder.len() - 16); - let decrypted_key_material = boring::symm::decrypt_aead( + let master_key_version = std::io::Cursor::new(master_key_version).get_u32(); + + let unwrapped_primary_master_key = self + .get_unwrapped_master_key_material(&trx, master_key_version) + .await?; + + let unwrapped_key_version_material = boring::symm::decrypt_aead( boring::symm::Cipher::aes_256_gcm(), - self.master_key.expose_secret(), + &unwrapped_primary_master_key, Some(kv_iv), &[], kv_cipher, @@ -389,7 +528,7 @@ impl ValvAPI for Valv { let plaintext = boring::symm::decrypt_aead( boring::symm::Cipher::aes_256_gcm(), - &decrypted_key_material, + &unwrapped_key_version_material, Some(iv), &[], cipher, @@ -412,3 +551,85 @@ impl ValvAPI for Valv { } } } + +impl Valv { + async fn get_unwrapped_master_key_material( + &self, + trx: &RetryableTransaction, + version: u32, + ) -> Result> { + assert_ne!(self.state, ValvState::MissingMasterKey); + + let key_version = self + .db + .get_key_version(trx, "valv", "master_key", version) + .await?; + + let key_version = match key_version { + Some(key_version) => key_version, + None => { + return Err(ValvError::KeyNotFound("master_key".to_string())); + } + }; + + let (_, remainder) = key_version.key_material.split_at(4); + let (iv, remainder) = remainder.split_at(12); + let (cipher, tag) = remainder.split_at(remainder.len() - 16); + + let unwrapped_key_version_material = boring::symm::decrypt_aead( + boring::symm::Cipher::aes_256_gcm(), + self.root_key.expose_secret(), + Some(iv), + &[], + cipher, + tag, + ) + .map_err(ValvError::BoringSSL)?; + + println!("decrypted master with root"); + + Ok(unwrapped_key_version_material) + } + async fn get_unwrapped_key_material( + &self, + trx: &RetryableTransaction, + tenant: &str, + key_name: &str, + version: u32, + ) -> Result> { + let key_version = self + .db + .get_key_version(trx, tenant, key_name, version) + .await?; + + let key_version = match key_version { + Some(key_version) => key_version, + None => { + return Err(ValvError::KeyNotFound(key_name.to_string())); + } + }; + + let (master_key_version, remainder) = key_version.key_material.split_at(4); + let (iv, remainder) = remainder.split_at(12); + let (cipher, tag) = remainder.split_at(remainder.len() - 16); + + let master_key_version = std::io::Cursor::new(master_key_version).get_u32(); + + let master_key = self + .get_unwrapped_master_key_material(trx, master_key_version) + .await?; + + let unwrapped_key_version_material = boring::symm::decrypt_aead( + boring::symm::Cipher::aes_256_gcm(), + &master_key, + Some(iv), + &[], + cipher, + tag, + ) + .map_err(ValvError::BoringSSL)?; + println!("decrypted normal key"); + + Ok(unwrapped_key_version_material) + } +}