-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(mpc): tmp dir for client, pgp secret generation req/res
- Loading branch information
1 parent
e9a9b37
commit c31af66
Showing
1 changed file
with
92 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,9 +46,10 @@ use types::Status; | |
|
||
const CONTRIBUTE_ENDPOINT: &str = "/contribute"; | ||
const SK_ENDPOINT: &str = "/secret_key"; | ||
const CLEAR_ENDPOINT: &str = "/clear"; | ||
|
||
const CONTRIB_SK_PATH: &str = "contrib_key.sk.asc"; | ||
const ZKGM_DIR: &str = "zkgm"; | ||
const CONTRIB_SK_PATH: &str = "zkgm/contrib_key.sk.asc"; | ||
const SUCCESSFUL_PATH: &str = ".zkgm_successful"; | ||
|
||
#[derive(PartialEq, Eq, Debug, Clone, Deserialize)] | ||
#[serde(rename_all = "camelCase")] | ||
|
@@ -76,12 +77,18 @@ enum Error { | |
Phase2ContributionFailed(#[from] mpc_shared::Phase2ContributionError), | ||
#[error(transparent)] | ||
Phase2VerificationFailed(#[from] mpc_shared::Phase2VerificationError), | ||
#[error("pgp key couldn't be found")] | ||
PGPKeyNotFound, | ||
} | ||
|
||
type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>; | ||
|
||
type DynError = Box<dyn std::error::Error + Send + Sync>; | ||
|
||
fn temp_file(payload_id: &str) -> String { | ||
format!("{ZKGM_DIR}/{payload_id}") | ||
} | ||
|
||
fn generate_pgp_key(email: String) -> SignedSecretKey { | ||
let mut key_params = SecretKeyParamsBuilder::default(); | ||
key_params | ||
|
@@ -101,6 +108,20 @@ fn generate_pgp_key(email: String) -> SignedSecretKey { | |
signed_secret_key | ||
} | ||
|
||
async fn is_already_successful() -> bool { | ||
tokio::fs::metadata(SUCCESSFUL_PATH).await.is_ok() | ||
} | ||
|
||
async fn wait_successful(tx_status: Sender<Status>) { | ||
loop { | ||
if is_already_successful().await { | ||
tx_status.send(Status::Successful).expect("impossible"); | ||
tokio::time::sleep(tokio::time::Duration::from_millis(2000)).await; | ||
break; | ||
} | ||
} | ||
} | ||
|
||
async fn contribute( | ||
tx_status: Sender<Status>, | ||
Contribute { | ||
|
@@ -110,25 +131,20 @@ async fn contribute( | |
api_key, | ||
contributor_id, | ||
payload_id, | ||
user_email, | ||
.. | ||
}: Contribute, | ||
) -> Result<(), DynError> { | ||
if is_already_successful().await { | ||
return Ok(()); | ||
} | ||
let mut secret_key = if let Ok(_) = tokio::fs::metadata(CONTRIB_SK_PATH).await { | ||
SignedSecretKey::from_armor_single::<&[u8]>( | ||
tokio::fs::read(CONTRIB_SK_PATH).await?.as_ref(), | ||
) | ||
.expect("impossible") | ||
.0 | ||
} else { | ||
let secret_key = generate_pgp_key(user_email.unwrap_or("[email protected]".into())); | ||
tokio::fs::write( | ||
CONTRIB_SK_PATH, | ||
secret_key | ||
.to_armored_bytes(ArmorOptions::default()) | ||
.expect("impossible"), | ||
) | ||
.await?; | ||
secret_key | ||
return Err(Error::PGPKeyNotFound.into()); | ||
}; | ||
let client = SupabaseMPCApi::new(supabase_project.clone(), api_key, jwt); | ||
let current_contributor = client | ||
|
@@ -159,11 +175,11 @@ async fn contribute( | |
tx_status | ||
.send(Status::DownloadEnded(current_payload.id.clone())) | ||
.expect("impossible"); | ||
let phase2_contribution = if let Ok(true) = tokio::fs::metadata(&payload_id) | ||
let phase2_contribution = if let Ok(true) = tokio::fs::metadata(temp_file(&payload_id)) | ||
.await | ||
.map(|meta| meta.size() as usize == CONTRIBUTION_SIZE) | ||
{ | ||
tokio::fs::read(&payload_id).await? | ||
tokio::fs::read(temp_file(&payload_id)).await? | ||
} else { | ||
tx_status | ||
.send(Status::ContributionStarted) | ||
|
@@ -179,7 +195,7 @@ async fn contribute( | |
tx_status | ||
.send(Status::ContributionEnded) | ||
.expect("impossible"); | ||
tokio::fs::write(&payload_id, &phase2_contribution).await?; | ||
tokio::fs::write(temp_file(&payload_id), &phase2_contribution).await?; | ||
phase2_contribution | ||
}; | ||
|
||
|
@@ -213,7 +229,7 @@ async fn contribute( | |
) | ||
.await?; | ||
let pool = PoolBuilder::new() | ||
.path("db.sqlite3") | ||
.path(temp_file("state.sqlite3")) | ||
.flags( | ||
OpenFlags::SQLITE_OPEN_READ_WRITE | ||
| OpenFlags::SQLITE_OPEN_CREATE | ||
|
@@ -385,7 +401,7 @@ async fn handle( | |
.body(body) | ||
.unwrap()) | ||
}; | ||
let raw_response = |status, body| { | ||
let file_response = |status, body| { | ||
Ok(hyper::Response::builder() | ||
.header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*") | ||
.header(hyper::header::CONTENT_TYPE, "application/octet-stream") | ||
|
@@ -399,16 +415,31 @@ async fn handle( | |
}; | ||
let response_empty = |status| response(status, BoxBody::default()); | ||
match (req.method(), req.uri().path()) { | ||
(&Method::POST, CLEAR_ENDPOINT) => { | ||
let _ = tokio::fs::remove_file(CONTRIB_SK_PATH).await; | ||
response_empty(hyper::StatusCode::OK) | ||
(&Method::POST, SK_ENDPOINT) => { | ||
let whole_body = req.collect().await?.aggregate(); | ||
let email = serde_json::from_reader(whole_body.reader())?; | ||
let guard = latest_status.write().await; | ||
let result = { | ||
if let Err(_) = tokio::fs::metadata(CONTRIB_SK_PATH).await { | ||
let secret_key = generate_pgp_key(email); | ||
let secret_key_serialized = secret_key | ||
.to_armored_bytes(ArmorOptions::default()) | ||
.expect("impossible"); | ||
tokio::fs::write(CONTRIB_SK_PATH, &secret_key_serialized).await?; | ||
response_empty(hyper::StatusCode::CREATED) | ||
} else { | ||
response_empty(hyper::StatusCode::OK) | ||
} | ||
}; | ||
drop(guard); | ||
result | ||
} | ||
(&Method::GET, SK_ENDPOINT) => { | ||
if let Ok(_) = tokio::fs::metadata(CONTRIB_SK_PATH).await { | ||
let content = tokio::fs::read(CONTRIB_SK_PATH).await?; | ||
raw_response(hyper::StatusCode::OK, full(content)) | ||
} else { | ||
if let Err(_) = tokio::fs::metadata(CONTRIB_SK_PATH).await { | ||
response_empty(hyper::StatusCode::NOT_FOUND) | ||
} else { | ||
let content = tokio::fs::read(CONTRIB_SK_PATH).await?; | ||
file_response(hyper::StatusCode::OK, full(content)) | ||
} | ||
} | ||
(&Method::POST, CONTRIBUTE_ENDPOINT) | ||
|
@@ -430,13 +461,14 @@ async fn handle( | |
.await; | ||
match result { | ||
Ok(_) => { | ||
lock.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst) | ||
let _ = tokio::fs::write(SUCCESSFUL_PATH, &[1u8]).await; | ||
let _ = tokio::fs::remove_dir(ZKGM_DIR).await; | ||
} | ||
Err(e) => { | ||
tx_status | ||
.send(Status::Failed(format!("{:?}", e))) | ||
.expect("impossible"); | ||
tx_status.send(Status::Successful).expect("impossible") | ||
} | ||
Err(e) => tx_status | ||
.send(Status::Failed(format!("{:?}", e))) | ||
.expect("impossible"), | ||
} | ||
}); | ||
response_empty(hyper::StatusCode::ACCEPTED) | ||
|
@@ -462,26 +494,28 @@ async fn handle( | |
), | ||
}, | ||
// CORS preflight request. | ||
(&Method::OPTIONS, CONTRIBUTE_ENDPOINT | SK_ENDPOINT | CLEAR_ENDPOINT) => { | ||
Ok(hyper::Response::builder() | ||
.header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*") | ||
.header( | ||
hyper::header::ACCESS_CONTROL_ALLOW_HEADERS, | ||
hyper::header::CONTENT_DISPOSITION, | ||
) | ||
.header( | ||
hyper::header::ACCESS_CONTROL_ALLOW_METHODS, | ||
format!( | ||
"{}, {}, {}", | ||
Method::OPTIONS.as_str(), | ||
Method::GET.as_str(), | ||
Method::POST.as_str() | ||
), | ||
) | ||
.status(hyper::StatusCode::OK) | ||
.body(BoxBody::default()) | ||
.unwrap()) | ||
} | ||
(&Method::OPTIONS, CONTRIBUTE_ENDPOINT | SK_ENDPOINT) => Ok(hyper::Response::builder() | ||
.header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*") | ||
.header( | ||
hyper::header::ACCESS_CONTROL_ALLOW_HEADERS, | ||
format!( | ||
"{}, {}", | ||
hyper::header::CONTENT_TYPE, | ||
hyper::header::CONTENT_DISPOSITION | ||
), | ||
) | ||
.header( | ||
hyper::header::ACCESS_CONTROL_ALLOW_METHODS, | ||
format!( | ||
"{}, {}, {}", | ||
Method::OPTIONS.as_str(), | ||
Method::GET.as_str(), | ||
Method::POST.as_str() | ||
), | ||
) | ||
.status(hyper::StatusCode::OK) | ||
.body(BoxBody::default()) | ||
.unwrap()), | ||
_ => response_empty(hyper::StatusCode::NOT_FOUND), | ||
} | ||
} | ||
|
@@ -525,13 +559,17 @@ async fn input_and_status_handling( | |
|
||
#[tokio::main] | ||
async fn main() -> Result<(), DynError> { | ||
if let Err(_) = tokio::fs::metadata(ZKGM_DIR).await { | ||
tokio::fs::create_dir(ZKGM_DIR).await?; | ||
} | ||
let status = Arc::new(RwLock::new(Status::Idle)); | ||
let lock = Arc::new(AtomicBool::new(false)); | ||
let (tx_status, rx_status) = broadcast::channel(64); | ||
let graceful = GracefulShutdown::new(); | ||
let status_clone = status.clone(); | ||
let token = CancellationToken::new(); | ||
let token_clone = token.clone(); | ||
let tx_status_clone = tx_status.clone(); | ||
let handle = tokio::spawn(async move { | ||
let addr = SocketAddr::from(([0, 0, 0, 0], 0x1337)); | ||
let listener = TcpListener::bind(addr).await.unwrap(); | ||
|
@@ -540,7 +578,7 @@ async fn main() -> Result<(), DynError> { | |
Ok((stream, _)) = listener.accept() => { | ||
let io = TokioIo::new(stream); | ||
let status_clone = status_clone.clone(); | ||
let tx_status_clone = tx_status.clone(); | ||
let tx_status_clone = tx_status_clone.clone(); | ||
let lock_clone = lock.clone(); | ||
let conn = hyper::server::conn::http1::Builder::new().serve_connection( | ||
io, | ||
|
@@ -577,7 +615,10 @@ async fn main() -> Result<(), DynError> { | |
}, | ||
)?; | ||
input_and_status_handling(status, rx_status, tx_ui).await; | ||
ui::run_ui(&mut terminal, rx_ui).await?; | ||
tokio::select! { | ||
_ = ui::run_ui(&mut terminal, rx_ui) => {} | ||
_ = wait_successful(tx_status) => {} | ||
} | ||
terminal.clear()?; | ||
crossterm::terminal::disable_raw_mode()?; | ||
let _ = execute!(io::stdout(), Show); | ||
|