diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 8e32ed11..f53d1ec0 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -17,7 +17,7 @@ use super::{ }; use crate::psbt::PsbtExt; use crate::receive::optional_parameters::Params; -use crate::v2::{HpkePublicKey, HpkeSecretKey, OhttpEncapsulationError}; +use crate::v2::{HpkeKeyPair, HpkePublicKey, OhttpEncapsulationError}; use crate::{OhttpKeys, PjUriBuilder, Request}; pub(crate) mod error; @@ -33,7 +33,7 @@ struct SessionContext { ohttp_keys: OhttpKeys, expiry: SystemTime, ohttp_relay: url::Url, - s: (HpkeSecretKey, HpkePublicKey), + s: HpkeKeyPair, e: Option, } @@ -85,7 +85,7 @@ impl SessionInitializer { ohttp_relay, expiry: SystemTime::now() + expire_after.unwrap_or(TWENTY_FOUR_HOURS_DEFAULT_EXPIRY), - s: crate::v2::gen_keypair(), + s: HpkeKeyPair::gen_keypair(), e: None, }, } @@ -532,7 +532,7 @@ impl PayjoinProposal { Some(e) => { let payjoin_bytes = self.inner.payjoin_psbt.serialize(); log::debug!("THERE IS AN e: {:?}", e); - crate::v2::encrypt_message_b(payjoin_bytes, self.context.s.clone(), e) + crate::v2::encrypt_message_b(payjoin_bytes, &self.context.s, e) } None => Ok(self.extract_v1_req().as_bytes().to_vec()), }?; @@ -602,7 +602,7 @@ mod test { ), ohttp_relay: url::Url::parse("https://relay.com").unwrap(), expiry: SystemTime::now() + Duration::from_secs(60), - s: crate::v2::gen_keypair(), + s: HpkeKeyPair::gen_keypair(), e: None, }, }; diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index a1f2ee31..29e791a7 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -242,7 +242,7 @@ impl<'a> RequestBuilder<'a> { sequence, min_fee_rate: self.min_fee_rate, #[cfg(feature = "v2")] - e: crate::v2::gen_keypair().0, + e: crate::v2::HpkeKeyPair::gen_keypair().secret_key(), }) } } diff --git a/payjoin/src/v2.rs b/payjoin/src/v2.rs index c04f580b..09308682 100644 --- a/payjoin/src/v2.rs +++ b/payjoin/src/v2.rs @@ -9,6 +9,7 @@ use hpke::kdf::HkdfSha256; use hpke::kem::SecpK256HkdfSha256; use hpke::rand_core::OsRng; use hpke::{Deserializable, OpModeR, OpModeS, Serializable}; +use serde::{Deserialize, Serialize}; pub const PADDED_MESSAGE_BYTES: usize = 7168; pub const PADDED_PLAINTEXT_A_LENGTH: usize = @@ -23,9 +24,20 @@ pub type EncappedKey = ::EncappedKey; fn sk_to_pk(sk: &SecretKey) -> PublicKey { ::sk_to_pk(sk) } -pub(crate) fn gen_keypair() -> (HpkeSecretKey, HpkePublicKey) { - let (sk, pk) = ::gen_keypair(&mut OsRng); - (HpkeSecretKey(sk), HpkePublicKey(pk)) +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct HpkeKeyPair(pub HpkeSecretKey, pub HpkePublicKey); + +impl From for (HpkeSecretKey, HpkePublicKey) { + fn from(value: HpkeKeyPair) -> Self { (value.0, value.1) } +} + +impl HpkeKeyPair { + pub fn gen_keypair() -> Self { + let (sk, pk) = ::gen_keypair(&mut OsRng); + Self(HpkeSecretKey(sk), HpkePublicKey(pk)) + } + pub fn secret_key(self) -> HpkeSecretKey { self.0 } + pub fn public_key(self) -> HpkePublicKey { self.1 } } #[derive(Clone, PartialEq, Eq)] @@ -164,13 +176,15 @@ pub fn decrypt_message_a( #[cfg(feature = "receive")] pub fn encrypt_message_b( mut plaintext: Vec, - receiver_keypair: (HpkeSecretKey, HpkePublicKey), + receiver_keypair: &HpkeKeyPair, sender_pk: &HpkePublicKey, ) -> Result, HpkeError> { - let pk = sk_to_pk(&receiver_keypair.0 .0); let (encapsulated_key, mut encryption_context) = hpke::setup_sender::( - &OpModeS::Auth((receiver_keypair.0 .0, pk.clone())), + &OpModeS::Auth(( + receiver_keypair.clone().secret_key().0, + receiver_keypair.clone().public_key().0, + )), &sender_pk.0, INFO_B, &mut OsRng,