From f534375be982f3449a1b04cec98d188cc746badf Mon Sep 17 00:00:00 2001 From: DanGould Date: Tue, 25 Jun 2024 12:43:34 -0400 Subject: [PATCH] Handle fragments with uri::UrlExt trait This extension trait defines functions to parse and set the ohttp parameter in the fragment of a `pj=` URL. Close #298 --- payjoin/src/send/mod.rs | 20 ++---- payjoin/src/uri/error.rs | 7 --- payjoin/src/uri/mod.rs | 127 ++++++++++++++++++++++---------------- payjoin/src/uri/pj_url.rs | 43 ------------- 4 files changed, 78 insertions(+), 119 deletions(-) delete mode 100644 payjoin/src/uri/pj_url.rs diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index b8fdd113..0561603f 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -202,8 +202,6 @@ impl<'a> RequestBuilder<'a> { psbt.validate_input_utxos(true) .map_err(InternalCreateRequestError::InvalidOriginalInput)?; let endpoint = self.uri.extras.endpoint.clone(); - #[cfg(feature = "v2")] - let ohttp_keys = self.uri.extras.ohttp_keys; let disable_output_substitution = self.uri.extras.disable_output_substitution || self.disable_output_substitution; let payee = self.uri.address.script_pubkey(); @@ -234,8 +232,6 @@ impl<'a> RequestBuilder<'a> { Ok(RequestContext { psbt, endpoint, - #[cfg(feature = "v2")] - ohttp_keys, disable_output_substitution, fee_contribution, payee, @@ -252,8 +248,6 @@ impl<'a> RequestBuilder<'a> { pub struct RequestContext { psbt: Psbt, endpoint: Url, - #[cfg(feature = "v2")] - ohttp_keys: Option, disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, @@ -303,6 +297,7 @@ impl RequestContext { &mut self, ohttp_relay: Url, ) -> Result<(Request, ContextV2), CreateRequestError> { + use crate::uri::PjUrlExt; let rs = Self::rs_pubkey_from_dir_endpoint(&self.endpoint)?; let url = self.endpoint.clone(); let body = serialize_v2_body( @@ -314,7 +309,7 @@ impl RequestContext { let body = crate::v2::encrypt_message_a(body, self.e, rs) .map_err(InternalCreateRequestError::Hpke)?; let (body, ohttp_res) = crate::v2::ohttp_encapsulate( - self.ohttp_keys.as_mut().ok_or(InternalCreateRequestError::MissingOhttpConfig)?, + self.endpoint.ohttp().as_mut().ok_or(InternalCreateRequestError::MissingOhttpConfig)?, "POST", url.as_str(), Some(&body), @@ -384,7 +379,6 @@ impl Serialize for RequestContext { let mut state = serializer.serialize_struct("RequestContext", 8)?; state.serialize_field("psbt", &self.psbt.to_string())?; state.serialize_field("endpoint", &self.endpoint.as_str())?; - state.serialize_field("ohttp_keys", &self.ohttp_keys)?; state.serialize_field("disable_output_substitution", &self.disable_output_substitution)?; state.serialize_field( "fee_contribution", @@ -433,7 +427,6 @@ impl<'de> Deserialize<'de> for RequestContext { { let mut psbt = None; let mut endpoint = None; - let mut ohttp_keys = None; let mut disable_output_substitution = None; let mut fee_contribution = None; let mut min_fee_rate = None; @@ -453,7 +446,6 @@ impl<'de> Deserialize<'de> for RequestContext { url::Url::from_str(&map.next_value::()?) .map_err(de::Error::custom)?, ), - "ohttp_keys" => ohttp_keys = Some(map.next_value()?), "disable_output_substitution" => disable_output_substitution = Some(map.next_value()?), "fee_contribution" => { @@ -479,7 +471,6 @@ impl<'de> Deserialize<'de> for RequestContext { Ok(RequestContext { psbt: psbt.ok_or_else(|| de::Error::missing_field("psbt"))?, endpoint: endpoint.ok_or_else(|| de::Error::missing_field("endpoint"))?, - ohttp_keys: ohttp_keys.ok_or_else(|| de::Error::missing_field("ohttp_keys"))?, disable_output_substitution: disable_output_substitution .ok_or_else(|| de::Error::missing_field("disable_output_substitution"))?, fee_contribution, @@ -975,7 +966,7 @@ fn serialize_v2_body( ) -> Result, CreateRequestError> { // Grug say localhost base be discarded anyway. no big brain needed. let placeholder_url = serialize_url( - "http:/localhost".to_string(), + Url::parse("http://localhost").unwrap(), disable_output_substitution, fee_contribution, min_feerate, @@ -987,12 +978,12 @@ fn serialize_v2_body( } fn serialize_url( - endpoint: String, + endpoint: Url, disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, ) -> Result { - let mut url = Url::parse(&endpoint)?; + let mut url = endpoint; url.query_pairs_mut().append_pair("v", "1"); if disable_output_substitution { url.query_pairs_mut().append_pair("disableoutputsubstitution", "1"); @@ -1066,7 +1057,6 @@ mod test { let req_ctx = RequestContext { psbt: Psbt::from_str(ORIGINAL_PSBT).unwrap(), endpoint: Url::parse("http://localhost:1234").unwrap(), - ohttp_keys: None, disable_output_substitution: false, fee_contribution: None, min_fee_rate: FeeRate::ZERO, diff --git a/payjoin/src/uri/error.rs b/payjoin/src/uri/error.rs index 443394a6..f44d7bed 100644 --- a/payjoin/src/uri/error.rs +++ b/payjoin/src/uri/error.rs @@ -1,6 +1,3 @@ -#[cfg(feature = "v2")] -use crate::uri::OhttpKeysParseError; - #[derive(Debug)] pub struct PjParseError(InternalPjParseError); @@ -11,8 +8,6 @@ pub(crate) enum InternalPjParseError { MissingEndpoint, NotUtf8, BadEndpoint, - #[cfg(feature = "v2")] - ParseOhttpKeys(OhttpKeysParseError), UnsecureEndpoint, } @@ -30,8 +25,6 @@ impl std::fmt::Display for PjParseError { InternalPjParseError::MissingEndpoint => write!(f, "Missing payjoin endpoint"), InternalPjParseError::NotUtf8 => write!(f, "Endpoint is not valid UTF-8"), InternalPjParseError::BadEndpoint => write!(f, "Endpoint is not valid"), - #[cfg(feature = "v2")] - InternalPjParseError::ParseOhttpKeys(e) => write!(f, "OHTTP Keys are not valid: {}", e), InternalPjParseError::UnsecureEndpoint => { write!(f, "Endpoint scheme is not secure (https or onion)") } diff --git a/payjoin/src/uri/mod.rs b/payjoin/src/uri/mod.rs index 625d7ff5..fed988e9 100644 --- a/payjoin/src/uri/mod.rs +++ b/payjoin/src/uri/mod.rs @@ -6,9 +6,6 @@ pub use error::PjParseError; use url::Url; use crate::uri::error::InternalPjParseError; -#[cfg(feature = "v2")] -use crate::v2::OhttpKeysParseError; - pub mod error; #[cfg(feature = "v2")] @@ -33,8 +30,6 @@ impl MaybePayjoinExtras { pub struct PayjoinExtras { pub(crate) endpoint: Url, pub(crate) disable_output_substitution: bool, - #[cfg(feature = "v2")] - pub(crate) ohttp_keys: Option, } impl PayjoinExtras { @@ -99,30 +94,25 @@ pub struct PjUriBuilder { pj: Url, /// Whether or not payjoin output substitution is allowed pjos: bool, - #[cfg(feature = "v2")] - /// Config for ohttp. - /// - /// Required only for v2 payjoin. - ohttp: Option, } impl PjUriBuilder { /// Create a new `PjUriBuilder` with required parameters. + /// + /// ## Parameters + /// - `address`: Represents a bitcoin address. + /// - `origin`: Represents either the payjoin endpoint in v1 or the directory in v2. + /// - `ohttp_keys`: Optional OHTTP keys for v2 (only available if the "v2" feature is enabled). pub fn new( address: Address, - pj: Url, + origin: Url, #[cfg(feature = "v2")] ohttp_keys: Option, ) -> Self { - Self { - address, - amount: None, - message: None, - label: None, - pj, - pjos: false, - #[cfg(feature = "v2")] - ohttp: ohttp_keys, - } + #[allow(unused_mut)] + let mut pj = origin; + #[cfg(feature = "v2")] + pj.set_ohttp(ohttp_keys); + Self { address, amount: None, message: None, label: None, pj, pjos: false } } /// Set the amount you want to receive. pub fn amount(mut self, amount: Amount) -> Self { @@ -153,12 +143,7 @@ impl PjUriBuilder { /// Constructs a `bip21::Uri` with PayjoinParams from the /// parameters set in the builder. pub fn build<'a>(self) -> PjUri<'a> { - let extras = PayjoinExtras { - endpoint: self.pj, - disable_output_substitution: self.pjos, - #[cfg(feature = "v2")] - ohttp_keys: self.ohttp, - }; + let extras = PayjoinExtras { endpoint: self.pj, disable_output_substitution: self.pjos }; let mut pj_uri = bip21::Uri::with_extras(self.address, extras); pj_uri.amount = self.amount; pj_uri.label = self.label.map(Into::into); @@ -183,8 +168,6 @@ impl<'a> bip21::de::DeserializeParams<'a> for MaybePayjoinExtras { pub struct DeserializationState { pj: Option, pjos: Option, - #[cfg(feature = "v2")] - ohttp: Option, } impl<'a> bip21::SerializeParams for &'a MaybePayjoinExtras { @@ -206,18 +189,11 @@ impl<'a> bip21::SerializeParams for &'a PayjoinExtras { type Iterator = std::vec::IntoIter<(Self::Key, Self::Value)>; fn serialize_params(self) -> Self::Iterator { - #[allow(unused_mut)] - let mut params = vec![ + vec![ ("pj", self.endpoint.as_str().to_string()), ("pjos", if self.disable_output_substitution { "1" } else { "0" }.to_string()), - ]; - #[cfg(feature = "v2")] - if let Some(ohttp_keys) = &self.ohttp_keys { - params.push(("ohttp", ohttp_keys.to_string())); - } else { - log::warn!("Failed to encode ohttp config, ignoring"); - } - params.into_iter() + ] + .into_iter() } } @@ -235,19 +211,6 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { ::Error, > { match key { - #[cfg(feature = "v2")] - "ohttp" if self.ohttp.is_none() => { - use std::str::FromStr; - - let base64_config = - Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?; - let config = OhttpKeys::from_str(&base64_config) - .map_err(InternalPjParseError::ParseOhttpKeys)?; - self.ohttp = Some(config); - Ok(bip21::de::ParamKind::Known) - } - #[cfg(feature = "v2")] - "ohttp" => Err(InternalPjParseError::MultipleParams("ohttp").into()), "pj" if self.pj.is_none() => { let endpoint = Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?; let url = Url::parse(&endpoint).map_err(|_| InternalPjParseError::BadEndpoint)?; @@ -283,8 +246,6 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { Ok(MaybePayjoinExtras::Supported(PayjoinExtras { endpoint, disable_output_substitution: pjos.unwrap_or(false), - #[cfg(feature = "v2")] - ohttp_keys: self.ohttp, })) } else { Err(InternalPjParseError::UnsecureEndpoint.into()) @@ -294,6 +255,50 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { } } +/// Parse and set fragment parameters from `&pj=` URLs +#[cfg(feature = "v2")] +pub trait PjUrlExt { + fn ohttp(&self) -> Option; + fn set_ohttp(&mut self, ohttp: Option); +} + +#[cfg(feature = "v2")] +impl PjUrlExt for Url { + fn ohttp(&self) -> Option { + self.fragment().and_then(|f| { + let parts: Vec<&str> = f.splitn(2, "ohttp=").collect(); + if parts.len() == 2 { + let base64_config = Cow::try_from(parts[1]).ok()?; + let config_bytes = + bitcoin::base64::decode_config(&*base64_config, bitcoin::base64::URL_SAFE) + .ok()?; + OhttpKeys::decode(&config_bytes).ok() + } else { + None + } + }) + } + + fn set_ohttp(&mut self, ohttp: Option) { + if let Some(ohttp) = ohttp { + let new_ohttp = format!("ohttp={}", ohttp.to_string()); + let mut fragment = self.fragment().unwrap_or("").to_string(); + if let Some(start) = fragment.find("ohttp=") { + let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i); + fragment.replace_range(start..end, &new_ohttp); + } else { + if !fragment.is_empty() { + fragment.push('&'); + } + fragment.push_str(&new_ohttp); + } + self.set_fragment(Some(&fragment)); + } else { + self.set_fragment(None); + } + } +} + #[cfg(test)] mod tests { use std::convert::TryFrom; @@ -400,4 +405,18 @@ mod tests { } } } + + #[test] + #[cfg(feature = "v2")] + fn test_url_ext_ohttp_fragment() { + use url::Url; + + use super::PjUrlExt; + + let url = Url::parse( + "https://example.com#ohttp=AQAg3WpRjS0aqAxQUoLvpas2VYjT2oIg6-3XSiB-QiYI1BAABAABAAM", + ) + .unwrap(); + assert!(url.ohttp().is_some()); + } } diff --git a/payjoin/src/uri/pj_url.rs b/payjoin/src/uri/pj_url.rs deleted file mode 100644 index 199c922a..00000000 --- a/payjoin/src/uri/pj_url.rs +++ /dev/null @@ -1,43 +0,0 @@ -use url::Url; - -pub struct PjUrl { - url: Url, - ohttp: Option, -} - -impl PjUrl { - pub fn new(url: Url) -> Self { - let (url, ohttp) = Self::extract_ohttp(url); - PjUrl { url, ohttp } - } - - fn extract_ohttp(mut url: Url) -> (Url, Option) { - let fragment = &mut url.fragment().and_then(|f| { - let parts: Vec<&str> = f.splitn(2, "ohttp=").collect(); - if parts.len() == 2 { - Some((parts[0].trim_end_matches('&'), parts[1].to_string())) - } else { - None - } - }); - - if let Some((remaining_fragment, ohttp)) = fragment { - url.set_fragment(Some(remaining_fragment)); - (url, Some(ohttp)) - } else { - (url, None) - } - } - - pub fn into_url(self) -> Url { - let mut url = self.url; - if let Some(ohttp) = self.ohttp { - let fragment = url - .fragment() - .map(|f| format!("{}&ohttp={}", f, ohttp)) - .unwrap_or_else(|| format!("ohttp={}", ohttp)); - url.set_fragment(Some(&fragment)); - } - url - } -}