diff --git a/mpc/client/src/main.rs b/mpc/client/src/main.rs index 49dc0c8984..f43db80f73 100644 --- a/mpc/client/src/main.rs +++ b/mpc/client/src/main.rs @@ -46,9 +46,10 @@ use types::Status; const CONTRIBUTE_ENDPOINT: &str = "/contribute"; const SK_ENDPOINT: &str = "/secret_key"; -const CLEAR_ENDPOINT: &str = "/clear"; -const CONTRIB_SK_PATH: &str = "contrib_key.sk.asc"; +const ZKGM_DIR: &str = "zkgm"; +const CONTRIB_SK_PATH: &str = "zkgm/contrib_key.sk.asc"; +const SUCCESSFUL_PATH: &str = ".zkgm_successful"; #[derive(PartialEq, Eq, Debug, Clone, Deserialize)] #[serde(rename_all = "camelCase")] @@ -76,12 +77,18 @@ enum Error { Phase2ContributionFailed(#[from] mpc_shared::Phase2ContributionError), #[error(transparent)] Phase2VerificationFailed(#[from] mpc_shared::Phase2VerificationError), + #[error("pgp key couldn't be found")] + PGPKeyNotFound, } type BoxBody = http_body_util::combinators::BoxBody; type DynError = Box; +fn temp_file(payload_id: &str) -> String { + format!("{ZKGM_DIR}/{payload_id}") +} + fn generate_pgp_key(email: String) -> SignedSecretKey { let mut key_params = SecretKeyParamsBuilder::default(); key_params @@ -101,6 +108,20 @@ fn generate_pgp_key(email: String) -> SignedSecretKey { signed_secret_key } +async fn is_already_successful() -> bool { + tokio::fs::metadata(SUCCESSFUL_PATH).await.is_ok() +} + +async fn wait_successful(tx_status: Sender) { + loop { + if is_already_successful().await { + tx_status.send(Status::Successful).expect("impossible"); + tokio::time::sleep(tokio::time::Duration::from_millis(2000)).await; + break; + } + } +} + async fn contribute( tx_status: Sender, Contribute { @@ -110,9 +131,12 @@ async fn contribute( api_key, contributor_id, payload_id, - user_email, + .. }: Contribute, ) -> Result<(), DynError> { + if is_already_successful().await { + return Ok(()); + } let mut secret_key = if let Ok(_) = tokio::fs::metadata(CONTRIB_SK_PATH).await { SignedSecretKey::from_armor_single::<&[u8]>( tokio::fs::read(CONTRIB_SK_PATH).await?.as_ref(), @@ -120,15 +144,7 @@ async fn contribute( .expect("impossible") .0 } else { - let secret_key = generate_pgp_key(user_email.unwrap_or("placeholder@test.com".into())); - tokio::fs::write( - CONTRIB_SK_PATH, - secret_key - .to_armored_bytes(ArmorOptions::default()) - .expect("impossible"), - ) - .await?; - secret_key + return Err(Error::PGPKeyNotFound.into()); }; let client = SupabaseMPCApi::new(supabase_project.clone(), api_key, jwt); let current_contributor = client @@ -159,11 +175,11 @@ async fn contribute( tx_status .send(Status::DownloadEnded(current_payload.id.clone())) .expect("impossible"); - let phase2_contribution = if let Ok(true) = tokio::fs::metadata(&payload_id) + let phase2_contribution = if let Ok(true) = tokio::fs::metadata(temp_file(&payload_id)) .await .map(|meta| meta.size() as usize == CONTRIBUTION_SIZE) { - tokio::fs::read(&payload_id).await? + tokio::fs::read(temp_file(&payload_id)).await? } else { tx_status .send(Status::ContributionStarted) @@ -179,7 +195,7 @@ async fn contribute( tx_status .send(Status::ContributionEnded) .expect("impossible"); - tokio::fs::write(&payload_id, &phase2_contribution).await?; + tokio::fs::write(temp_file(&payload_id), &phase2_contribution).await?; phase2_contribution }; @@ -213,7 +229,7 @@ async fn contribute( ) .await?; let pool = PoolBuilder::new() - .path("db.sqlite3") + .path(temp_file("state.sqlite3")) .flags( OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE @@ -385,7 +401,7 @@ async fn handle( .body(body) .unwrap()) }; - let raw_response = |status, body| { + let file_response = |status, body| { Ok(hyper::Response::builder() .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*") .header(hyper::header::CONTENT_TYPE, "application/octet-stream") @@ -399,16 +415,31 @@ async fn handle( }; let response_empty = |status| response(status, BoxBody::default()); match (req.method(), req.uri().path()) { - (&Method::POST, CLEAR_ENDPOINT) => { - let _ = tokio::fs::remove_file(CONTRIB_SK_PATH).await; - response_empty(hyper::StatusCode::OK) + (&Method::POST, SK_ENDPOINT) => { + let whole_body = req.collect().await?.aggregate(); + let email = serde_json::from_reader(whole_body.reader())?; + let guard = latest_status.write().await; + let result = { + if let Err(_) = tokio::fs::metadata(CONTRIB_SK_PATH).await { + let secret_key = generate_pgp_key(email); + let secret_key_serialized = secret_key + .to_armored_bytes(ArmorOptions::default()) + .expect("impossible"); + tokio::fs::write(CONTRIB_SK_PATH, &secret_key_serialized).await?; + response_empty(hyper::StatusCode::CREATED) + } else { + response_empty(hyper::StatusCode::OK) + } + }; + drop(guard); + result } (&Method::GET, SK_ENDPOINT) => { - if let Ok(_) = tokio::fs::metadata(CONTRIB_SK_PATH).await { - let content = tokio::fs::read(CONTRIB_SK_PATH).await?; - raw_response(hyper::StatusCode::OK, full(content)) - } else { + if let Err(_) = tokio::fs::metadata(CONTRIB_SK_PATH).await { response_empty(hyper::StatusCode::NOT_FOUND) + } else { + let content = tokio::fs::read(CONTRIB_SK_PATH).await?; + file_response(hyper::StatusCode::OK, full(content)) } } (&Method::POST, CONTRIBUTE_ENDPOINT) @@ -430,13 +461,14 @@ async fn handle( .await; match result { Ok(_) => { - lock.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst) + let _ = tokio::fs::write(SUCCESSFUL_PATH, &[1u8]).await; + let _ = tokio::fs::remove_dir(ZKGM_DIR).await; + } + Err(e) => { + tx_status + .send(Status::Failed(format!("{:?}", e))) .expect("impossible"); - tx_status.send(Status::Successful).expect("impossible") } - Err(e) => tx_status - .send(Status::Failed(format!("{:?}", e))) - .expect("impossible"), } }); response_empty(hyper::StatusCode::ACCEPTED) @@ -462,26 +494,28 @@ async fn handle( ), }, // CORS preflight request. - (&Method::OPTIONS, CONTRIBUTE_ENDPOINT | SK_ENDPOINT | CLEAR_ENDPOINT) => { - Ok(hyper::Response::builder() - .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*") - .header( - hyper::header::ACCESS_CONTROL_ALLOW_HEADERS, - hyper::header::CONTENT_DISPOSITION, - ) - .header( - hyper::header::ACCESS_CONTROL_ALLOW_METHODS, - format!( - "{}, {}, {}", - Method::OPTIONS.as_str(), - Method::GET.as_str(), - Method::POST.as_str() - ), - ) - .status(hyper::StatusCode::OK) - .body(BoxBody::default()) - .unwrap()) - } + (&Method::OPTIONS, CONTRIBUTE_ENDPOINT | SK_ENDPOINT) => Ok(hyper::Response::builder() + .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*") + .header( + hyper::header::ACCESS_CONTROL_ALLOW_HEADERS, + format!( + "{}, {}", + hyper::header::CONTENT_TYPE, + hyper::header::CONTENT_DISPOSITION + ), + ) + .header( + hyper::header::ACCESS_CONTROL_ALLOW_METHODS, + format!( + "{}, {}, {}", + Method::OPTIONS.as_str(), + Method::GET.as_str(), + Method::POST.as_str() + ), + ) + .status(hyper::StatusCode::OK) + .body(BoxBody::default()) + .unwrap()), _ => response_empty(hyper::StatusCode::NOT_FOUND), } } @@ -525,6 +559,9 @@ async fn input_and_status_handling( #[tokio::main] async fn main() -> Result<(), DynError> { + if let Err(_) = tokio::fs::metadata(ZKGM_DIR).await { + tokio::fs::create_dir(ZKGM_DIR).await?; + } let status = Arc::new(RwLock::new(Status::Idle)); let lock = Arc::new(AtomicBool::new(false)); let (tx_status, rx_status) = broadcast::channel(64); @@ -532,6 +569,7 @@ async fn main() -> Result<(), DynError> { let status_clone = status.clone(); let token = CancellationToken::new(); let token_clone = token.clone(); + let tx_status_clone = tx_status.clone(); let handle = tokio::spawn(async move { let addr = SocketAddr::from(([0, 0, 0, 0], 0x1337)); let listener = TcpListener::bind(addr).await.unwrap(); @@ -540,7 +578,7 @@ async fn main() -> Result<(), DynError> { Ok((stream, _)) = listener.accept() => { let io = TokioIo::new(stream); let status_clone = status_clone.clone(); - let tx_status_clone = tx_status.clone(); + let tx_status_clone = tx_status_clone.clone(); let lock_clone = lock.clone(); let conn = hyper::server::conn::http1::Builder::new().serve_connection( io, @@ -577,7 +615,10 @@ async fn main() -> Result<(), DynError> { }, )?; input_and_status_handling(status, rx_status, tx_ui).await; - ui::run_ui(&mut terminal, rx_ui).await?; + tokio::select! { + _ = ui::run_ui(&mut terminal, rx_ui) => {} + _ = wait_successful(tx_status) => {} + } terminal.clear()?; crossterm::terminal::disable_raw_mode()?; let _ = execute!(io::stdout(), Show);