Skip to content

Commit

Permalink
fix: sha precompute selector
Browse files Browse the repository at this point in the history
  • Loading branch information
Bisht13 committed Dec 11, 2024
1 parent cf873f9 commit ae967b7
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 63 deletions.
104 changes: 97 additions & 7 deletions src/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use anyhow::Result;
use anyhow::{anyhow, Result};
use num_bigint::BigInt;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
Expand Down Expand Up @@ -156,6 +156,88 @@ impl CircuitInputParams {
}
}

/// Finds a selector string in cleaned content and maps it back to its original position.
///
/// # Arguments
/// * `clean_content` - The cleaned content as a slice of bytes (no QP soft line breaks).
/// * `selector` - The string to find in the cleaned content.
/// * `position_map` - A slice mapping cleaned indices to original indices.
/// For each i, `position_map[i]` is the index in `original_body` where that cleaned byte originated.
/// If `position_map[i]` is `usize::MAX`, that cleaned position has no corresponding original position.
///
/// # Returns
/// A tuple containing `(selector, original_index)`.
///
/// # Errors
/// Returns an error if the selector is not found in the cleaned content or if the position mapping fails.
fn find_selector_in_clean_content(
clean_content: &[u8],
selector: &str,
position_map: &[usize],
) -> Result<(String, usize)> {
let clean_string = String::from_utf8_lossy(clean_content);
if let Some(selector_index) = clean_string.find(selector) {
// Map this cleaned index back to original
if selector_index < position_map.len() {
let original_index = position_map[selector_index];
if original_index == usize::MAX {
return Err(anyhow!("Failed to map selector position to original body"));
}
Ok((selector.to_string(), original_index))
} else {
Err(anyhow!("Selector index out of range in position map"))
}
} else {
Err(anyhow!(
"SHA precompute selector \"{}\" not found in cleaned body",
selector
))
}
}

/// Gets the adjusted selector string that accounts for potential soft line breaks in QP encoding.
/// If the selector exists in the original body, returns it as-is. Otherwise, finds it in cleaned
/// content and maps it back to the original format, including any soft line breaks.
///
/// # Arguments
/// * `original_body` - The original body as a slice of bytes, possibly containing QP soft line breaks.
/// * `selector` - The string to find in the content.
/// * `clean_content` - The cleaned content with soft line breaks removed.
/// * `position_map` - The index mapping from cleaned content to original content.
///
/// # Returns
/// The adjusted selector string that matches the original body format.
///
/// # Errors
/// Returns an error if the selector cannot be found in either the original or cleaned content.
fn get_adjusted_selector(
original_body: &[u8],
selector: &str,
clean_content: &[u8],
position_map: &[usize],
) -> Result<String> {
let original_str = String::from_utf8_lossy(original_body);

// First, try finding the selector in the original body as-is
if original_str.contains(selector) {
return Ok(selector.to_string());
}

// If not found, we must find it in the cleaned content and map back to original
let (_, original_index) =
find_selector_in_clean_content(clean_content, selector, position_map)?;

// Retrieve the substring from the original body that corresponds to the found selector plus 3 chars
// Note: This +3 accounts for the possible "=\r\n" that may have been present.
// Ensure we don't go out of bounds:
let end_index = std::cmp::min(original_body.len(), original_index + selector.len() + 3);
let adjusted_slice = &original_body[original_index..end_index];

// Convert back to a string. If invalid UTF-8, use lossy conversion.
let adjusted_str = String::from_utf8_lossy(adjusted_slice);
Ok(adjusted_str.to_string())
}

/// Generates the inputs for the circuit from the given parameters.
///
/// This function takes `CircuitInputParams` which includes the email body and header,
Expand Down Expand Up @@ -195,23 +277,31 @@ fn generate_circuit_inputs(params: CircuitInputParams) -> Result<CircuitInput> {
if !params.ignore_body_hash_check {
// Calculate the length needed for SHA-256 padding of the body
let body_sha_length = ((params.body.len() + 63 + 65) / 64) * 64;
println!("Body SHA length: {}", body_sha_length);
println!("Max body length: {}", params.max_body_length);
println!("Body length: {}", params.body.len());
// Pad the body to the maximum length or the calculated SHA-256 padding length
let (body_padded, body_padded_len) = sha256_pad(
params.body,
params.body.clone(),
cmp::max(params.max_body_length, body_sha_length),
);

println!("Body padded length: {}", body_padded_len);
let mut adjusted_selector = params.sha_precompute_selector;

if adjusted_selector.is_some() {
let (cleaned_body, position_map) =
remove_quoted_printable_soft_breaks(body_padded.clone());
adjusted_selector = Some(get_adjusted_selector(
&params.body,
&adjusted_selector.as_ref().unwrap(),
&cleaned_body,
&position_map,
)?);
}

// Ensure that the error type returned by `generate_partial_sha` is sized
// by converting it into an `anyhow::Error` if it's not already.
let result = generate_partial_sha(
body_padded,
body_padded_len,
params.sha_precompute_selector,
adjusted_selector,
params.max_body_length,
);

Expand Down
72 changes: 30 additions & 42 deletions src/cryptos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#[cfg(target_arch = "wasm32")]
use crate::EmailHeaders;
use crate::{field_to_hex, find_index_in_body, hex_to_field, remove_quoted_printable_soft_breaks};
use anyhow::Result;
use anyhow::{anyhow, Result};
use ethers::types::Bytes;
use halo2curves::ff::Field;
use poseidon_rs::{poseidon_bytes, poseidon_fields, Fr, PoseidonError};
Expand Down Expand Up @@ -457,27 +457,6 @@ pub fn partial_sha(msg: &[u8], msg_len: usize) -> Vec<u8> {
result.to_vec()
}

/// Finds the original indices in `body` that correspond to `pattern` in the `cleaned_body`.
/// Returns `Some((original_start, original_end))` if found, or `None` if the pattern isn't present.
fn find_original_indices_for_pattern(
body: &[u8],
cleaned_body: &[u8],
index_map: &[usize],
pattern: &[u8],
) -> Option<(usize, usize)> {
// Search the pattern in cleaned_body
if let Some(cleaned_start_index) = cleaned_body
.windows(pattern.len())
.position(|window| window == pattern)
{
let original_start = index_map[cleaned_start_index];
let original_end = index_map[cleaned_start_index + pattern.len() - 1];
Some((original_start, original_end))
} else {
None
}
}

/// Generates a partial SHA-256 hash of a message up to the point of a selector string, if provided.
///
/// # Arguments
Expand All @@ -497,30 +476,41 @@ pub fn generate_partial_sha(
selector_regex: Option<String>,
max_remaining_body_length: usize,
) -> PartialShaResult {
let (cleaned_body, index_map) = remove_quoted_printable_soft_breaks(body.clone());

let selector_bytes = selector_regex.as_deref().map(|s| s.as_bytes());
let (selector_index, _) = find_original_indices_for_pattern(
&body,
&cleaned_body,
&index_map,
selector_bytes.expect("Selector bytes not found"),
)
.ok_or_else(|| {
Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
"Selector not found in the body",
))
})?;
let mut selector_index = 0;

// Check if a selector is provided
if let Some(selector) = selector_regex {
// Create a regex pattern from the selector
let pattern = regex::Regex::new(&selector).unwrap();
let body_str = {
// Undo SHA padding
let mut trimmed_body = body.clone();
while !(trimmed_body.last() == Some(&10)
&& trimmed_body.get(trimmed_body.len() - 2) == Some(&13))
{
trimmed_body.pop();
}

String::from_utf8(trimmed_body).unwrap()
};

// Find the index of the selector in the body
if let Some(matched) = pattern.find(&body_str) {
selector_index = matched.start();
} else {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Selector {} not found in the body", selector),
)));
}
};

// Calculate the cutoff index for SHA-256 block size (64 bytes)
let sha_cutoff_index = (selector_index / 64) * 64;
let precompute_text = &body[..sha_cutoff_index];
let mut body_remaining = body[sha_cutoff_index..].to_vec();

let body_remaining_length = body.len() - precompute_text.len();

println!("body_remaining_length: {}", body_remaining_length);
let body_remaining_length = body_length - precompute_text.len();

// Check if the remaining body length exceeds the maximum allowed length
if body_remaining_length > max_remaining_body_length {
Expand Down Expand Up @@ -548,8 +538,6 @@ pub fn generate_partial_sha(

// Compute the SHA-256 hash of the pre-selector part of the message
let precomputed_sha = partial_sha(precompute_text, sha_cutoff_index);

println!("body: {:?}", body_remaining);
Ok((precomputed_sha, body_remaining, body_remaining_length))
}

Expand Down
49 changes: 35 additions & 14 deletions src/parse_email.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,44 +268,65 @@ impl ParsedEmail {
}
}

/// Removes quoted-printable soft line breaks from an email body.
/// Removes Quoted-Printable (QP) soft line breaks (`=\r\n`) from the given byte vector while
/// maintaining a mapping from cleaned indices back to the original positions.
///
/// Quoted-printable encoding uses `=` followed by `\r\n` to indicate a soft line break.
/// This function removes such sequences from the input `Vec<u8>`.
/// Quoted-printable encoding may split long lines with `=\r\n` sequences. This function removes
/// these soft line breaks, producing a "cleaned" output array. It also creates an index map so
/// that for each position in the cleaned output, you can find the corresponding original index.
///
/// Any positions in the cleaned output that were added as padding (to match the original length)
/// will have their index map entry set to `usize::MAX`, indicating no corresponding original index.
///
/// # Arguments
///
/// * `body` - A `Vec<u8>` representing the email body to be cleaned.
/// * `body` - A `Vec<u8>` containing the QP-encoded content.
///
/// # Returns
///
/// A `Vec<u8>` with all quoted-printable soft line breaks removed.
/// A tuple of:
/// - `Vec<u8>`: The cleaned content, with all QP soft line breaks removed and padded with zeros
/// to match the original length.
/// - `Vec<usize>`: A mapping from cleaned indices to original indices. For cleaned indices that
/// correspond to actual content, `index_map[i]` gives the original position of
/// that byte in `body`. For padded bytes, the value is `usize::MAX`.
///
/// # Example
///
/// ```
/// let body = b"Hello=\r\nWorld".to_vec();
/// // body: [72,101,108,108,111,61,13,10,87,111,114,108,100]
/// let (clean_content, index_map) = remove_quoted_printable_soft_breaks(body);
///
/// // clean_content might look like [72,101,108,108,111,87,111,114,108,100,0,0,0]
/// // index_map might map:
/// // 0->0, 1->1, 2->2, 3->3, 4->4, 5->8, 6->9, 7->10, 8->11, 9->12, and the rest are usize::MAX.
/// ```
pub fn remove_quoted_printable_soft_breaks(body: Vec<u8>) -> (Vec<u8>, Vec<usize>) {
let original_len = body.len();
let mut result = Vec::with_capacity(original_len);
let mut cleaned = Vec::with_capacity(original_len);
let mut index_map = Vec::with_capacity(original_len);

let mut iter = body.iter().enumerate();
while let Some((i, &byte)) = iter.next() {
// Check if this is the start of a soft line break sequence `=\r\n`
if byte == b'=' && body.get(i + 1..i + 3) == Some(&[b'\r', b'\n']) {
// Skip the next two bytes (soft line break)
// Skip the next two bytes for the soft line break
iter.nth(1);
} else {
result.push(byte);
cleaned.push(byte);
index_map.push(i);
}
}

// Pad `result` to the original length with zeros
result.resize(original_len, 0);
// Pad the cleaned result with zeros to match the original length
cleaned.resize(original_len, 0);

// Pad `index_map` to the same length.
// Since these extra bytes don't map to anything in the original body,
// use a placeholder like usize::MAX.
// Pad index_map with usize::MAX for these padded positions
let padding_needed = original_len - index_map.len();
index_map.extend(std::iter::repeat(usize::MAX).take(padding_needed));

(result, index_map)
(cleaned, index_map)
}

/// Finds the index of the first occurrence of a pattern in the given body.
Expand Down

0 comments on commit ae967b7

Please sign in to comment.