From ae967b7a5c329b02b57a47307810bfafdc6b4cbc Mon Sep 17 00:00:00 2001 From: Aditya Bisht Date: Thu, 12 Dec 2024 03:40:47 +0530 Subject: [PATCH] fix: sha precompute selector --- src/circuit.rs | 104 ++++++++++++++++++++++++++++++++++++++++++--- src/cryptos.rs | 72 +++++++++++++------------------ src/parse_email.rs | 49 +++++++++++++++------ 3 files changed, 162 insertions(+), 63 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index d84e586..156c3c3 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use num_bigint::BigInt; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -156,6 +156,88 @@ impl CircuitInputParams { } } +/// Finds a selector string in cleaned content and maps it back to its original position. +/// +/// # Arguments +/// * `clean_content` - The cleaned content as a slice of bytes (no QP soft line breaks). +/// * `selector` - The string to find in the cleaned content. +/// * `position_map` - A slice mapping cleaned indices to original indices. +/// For each i, `position_map[i]` is the index in `original_body` where that cleaned byte originated. +/// If `position_map[i]` is `usize::MAX`, that cleaned position has no corresponding original position. +/// +/// # Returns +/// A tuple containing `(selector, original_index)`. +/// +/// # Errors +/// Returns an error if the selector is not found in the cleaned content or if the position mapping fails. +fn find_selector_in_clean_content( + clean_content: &[u8], + selector: &str, + position_map: &[usize], +) -> Result<(String, usize)> { + let clean_string = String::from_utf8_lossy(clean_content); + if let Some(selector_index) = clean_string.find(selector) { + // Map this cleaned index back to original + if selector_index < position_map.len() { + let original_index = position_map[selector_index]; + if original_index == usize::MAX { + return Err(anyhow!("Failed to map selector position to original body")); + } + Ok((selector.to_string(), original_index)) + } else { + Err(anyhow!("Selector index out of range in position map")) + } + } else { + Err(anyhow!( + "SHA precompute selector \"{}\" not found in cleaned body", + selector + )) + } +} + +/// Gets the adjusted selector string that accounts for potential soft line breaks in QP encoding. +/// If the selector exists in the original body, returns it as-is. Otherwise, finds it in cleaned +/// content and maps it back to the original format, including any soft line breaks. +/// +/// # Arguments +/// * `original_body` - The original body as a slice of bytes, possibly containing QP soft line breaks. +/// * `selector` - The string to find in the content. +/// * `clean_content` - The cleaned content with soft line breaks removed. +/// * `position_map` - The index mapping from cleaned content to original content. +/// +/// # Returns +/// The adjusted selector string that matches the original body format. +/// +/// # Errors +/// Returns an error if the selector cannot be found in either the original or cleaned content. +fn get_adjusted_selector( + original_body: &[u8], + selector: &str, + clean_content: &[u8], + position_map: &[usize], +) -> Result { + let original_str = String::from_utf8_lossy(original_body); + + // First, try finding the selector in the original body as-is + if original_str.contains(selector) { + return Ok(selector.to_string()); + } + + // If not found, we must find it in the cleaned content and map back to original + let (_, original_index) = + find_selector_in_clean_content(clean_content, selector, position_map)?; + + // Retrieve the substring from the original body that corresponds to the found selector plus 3 chars + // Note: This +3 accounts for the possible "=\r\n" that may have been present. + // Ensure we don't go out of bounds: + let end_index = std::cmp::min(original_body.len(), original_index + selector.len() + 3); + let adjusted_slice = &original_body[original_index..end_index]; + + // Convert back to a string. If invalid UTF-8, use lossy conversion. + let adjusted_str = String::from_utf8_lossy(adjusted_slice); + Ok(adjusted_str.to_string()) +} + /// Generates the inputs for the circuit from the given parameters. /// /// This function takes `CircuitInputParams` which includes the email body and header, @@ -195,23 +277,31 @@ fn generate_circuit_inputs(params: CircuitInputParams) -> Result { if !params.ignore_body_hash_check { // Calculate the length needed for SHA-256 padding of the body let body_sha_length = ((params.body.len() + 63 + 65) / 64) * 64; - println!("Body SHA length: {}", body_sha_length); - println!("Max body length: {}", params.max_body_length); - println!("Body length: {}", params.body.len()); // Pad the body to the maximum length or the calculated SHA-256 padding length let (body_padded, body_padded_len) = sha256_pad( - params.body, + params.body.clone(), cmp::max(params.max_body_length, body_sha_length), ); - println!("Body padded length: {}", body_padded_len); + let mut adjusted_selector = params.sha_precompute_selector; + + if adjusted_selector.is_some() { + let (cleaned_body, position_map) = + remove_quoted_printable_soft_breaks(body_padded.clone()); + adjusted_selector = Some(get_adjusted_selector( + ¶ms.body, + &adjusted_selector.as_ref().unwrap(), + &cleaned_body, + &position_map, + )?); + } // Ensure that the error type returned by `generate_partial_sha` is sized // by converting it into an `anyhow::Error` if it's not already. let result = generate_partial_sha( body_padded, body_padded_len, - params.sha_precompute_selector, + adjusted_selector, params.max_body_length, ); diff --git a/src/cryptos.rs b/src/cryptos.rs index 4702f05..b1dc45b 100644 --- a/src/cryptos.rs +++ b/src/cryptos.rs @@ -3,7 +3,7 @@ #[cfg(target_arch = "wasm32")] use crate::EmailHeaders; use crate::{field_to_hex, find_index_in_body, hex_to_field, remove_quoted_printable_soft_breaks}; -use anyhow::Result; +use anyhow::{anyhow, Result}; use ethers::types::Bytes; use halo2curves::ff::Field; use poseidon_rs::{poseidon_bytes, poseidon_fields, Fr, PoseidonError}; @@ -457,27 +457,6 @@ pub fn partial_sha(msg: &[u8], msg_len: usize) -> Vec { result.to_vec() } -/// Finds the original indices in `body` that correspond to `pattern` in the `cleaned_body`. -/// Returns `Some((original_start, original_end))` if found, or `None` if the pattern isn't present. -fn find_original_indices_for_pattern( - body: &[u8], - cleaned_body: &[u8], - index_map: &[usize], - pattern: &[u8], -) -> Option<(usize, usize)> { - // Search the pattern in cleaned_body - if let Some(cleaned_start_index) = cleaned_body - .windows(pattern.len()) - .position(|window| window == pattern) - { - let original_start = index_map[cleaned_start_index]; - let original_end = index_map[cleaned_start_index + pattern.len() - 1]; - Some((original_start, original_end)) - } else { - None - } -} - /// Generates a partial SHA-256 hash of a message up to the point of a selector string, if provided. /// /// # Arguments @@ -497,30 +476,41 @@ pub fn generate_partial_sha( selector_regex: Option, max_remaining_body_length: usize, ) -> PartialShaResult { - let (cleaned_body, index_map) = remove_quoted_printable_soft_breaks(body.clone()); - - let selector_bytes = selector_regex.as_deref().map(|s| s.as_bytes()); - let (selector_index, _) = find_original_indices_for_pattern( - &body, - &cleaned_body, - &index_map, - selector_bytes.expect("Selector bytes not found"), - ) - .ok_or_else(|| { - Box::new(std::io::Error::new( - std::io::ErrorKind::Other, - "Selector not found in the body", - )) - })?; + let mut selector_index = 0; + + // Check if a selector is provided + if let Some(selector) = selector_regex { + // Create a regex pattern from the selector + let pattern = regex::Regex::new(&selector).unwrap(); + let body_str = { + // Undo SHA padding + let mut trimmed_body = body.clone(); + while !(trimmed_body.last() == Some(&10) + && trimmed_body.get(trimmed_body.len() - 2) == Some(&13)) + { + trimmed_body.pop(); + } + + String::from_utf8(trimmed_body).unwrap() + }; + + // Find the index of the selector in the body + if let Some(matched) = pattern.find(&body_str) { + selector_index = matched.start(); + } else { + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + format!("Selector {} not found in the body", selector), + ))); + } + }; // Calculate the cutoff index for SHA-256 block size (64 bytes) let sha_cutoff_index = (selector_index / 64) * 64; let precompute_text = &body[..sha_cutoff_index]; let mut body_remaining = body[sha_cutoff_index..].to_vec(); - let body_remaining_length = body.len() - precompute_text.len(); - - println!("body_remaining_length: {}", body_remaining_length); + let body_remaining_length = body_length - precompute_text.len(); // Check if the remaining body length exceeds the maximum allowed length if body_remaining_length > max_remaining_body_length { @@ -548,8 +538,6 @@ pub fn generate_partial_sha( // Compute the SHA-256 hash of the pre-selector part of the message let precomputed_sha = partial_sha(precompute_text, sha_cutoff_index); - - println!("body: {:?}", body_remaining); Ok((precomputed_sha, body_remaining, body_remaining_length)) } diff --git a/src/parse_email.rs b/src/parse_email.rs index 6eb0c96..c3ff085 100644 --- a/src/parse_email.rs +++ b/src/parse_email.rs @@ -268,44 +268,65 @@ impl ParsedEmail { } } -/// Removes quoted-printable soft line breaks from an email body. +/// Removes Quoted-Printable (QP) soft line breaks (`=\r\n`) from the given byte vector while +/// maintaining a mapping from cleaned indices back to the original positions. /// -/// Quoted-printable encoding uses `=` followed by `\r\n` to indicate a soft line break. -/// This function removes such sequences from the input `Vec`. +/// Quoted-printable encoding may split long lines with `=\r\n` sequences. This function removes +/// these soft line breaks, producing a "cleaned" output array. It also creates an index map so +/// that for each position in the cleaned output, you can find the corresponding original index. +/// +/// Any positions in the cleaned output that were added as padding (to match the original length) +/// will have their index map entry set to `usize::MAX`, indicating no corresponding original index. /// /// # Arguments /// -/// * `body` - A `Vec` representing the email body to be cleaned. +/// * `body` - A `Vec` containing the QP-encoded content. /// /// # Returns /// -/// A `Vec` with all quoted-printable soft line breaks removed. +/// A tuple of: +/// - `Vec`: The cleaned content, with all QP soft line breaks removed and padded with zeros +/// to match the original length. +/// - `Vec`: A mapping from cleaned indices to original indices. For cleaned indices that +/// correspond to actual content, `index_map[i]` gives the original position of +/// that byte in `body`. For padded bytes, the value is `usize::MAX`. +/// +/// # Example +/// +/// ``` +/// let body = b"Hello=\r\nWorld".to_vec(); +/// // body: [72,101,108,108,111,61,13,10,87,111,114,108,100] +/// let (clean_content, index_map) = remove_quoted_printable_soft_breaks(body); +/// +/// // clean_content might look like [72,101,108,108,111,87,111,114,108,100,0,0,0] +/// // index_map might map: +/// // 0->0, 1->1, 2->2, 3->3, 4->4, 5->8, 6->9, 7->10, 8->11, 9->12, and the rest are usize::MAX. +/// ``` pub fn remove_quoted_printable_soft_breaks(body: Vec) -> (Vec, Vec) { let original_len = body.len(); - let mut result = Vec::with_capacity(original_len); + let mut cleaned = Vec::with_capacity(original_len); let mut index_map = Vec::with_capacity(original_len); let mut iter = body.iter().enumerate(); while let Some((i, &byte)) = iter.next() { + // Check if this is the start of a soft line break sequence `=\r\n` if byte == b'=' && body.get(i + 1..i + 3) == Some(&[b'\r', b'\n']) { - // Skip the next two bytes (soft line break) + // Skip the next two bytes for the soft line break iter.nth(1); } else { - result.push(byte); + cleaned.push(byte); index_map.push(i); } } - // Pad `result` to the original length with zeros - result.resize(original_len, 0); + // Pad the cleaned result with zeros to match the original length + cleaned.resize(original_len, 0); - // Pad `index_map` to the same length. - // Since these extra bytes don't map to anything in the original body, - // use a placeholder like usize::MAX. + // Pad index_map with usize::MAX for these padded positions let padding_needed = original_len - index_map.len(); index_map.extend(std::iter::repeat(usize::MAX).take(padding_needed)); - (result, index_map) + (cleaned, index_map) } /// Finds the index of the first occurrence of a pattern in the given body.