diff --git a/Cargo.lock b/Cargo.lock index 131081ab0..fd9838f24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2196,7 +2196,6 @@ dependencies = [ "schnorr-signatures", "secq256k1", "std-shims", - "subtle", "thiserror", "zeroize", ] @@ -2293,10 +2292,12 @@ name = "ec-divisors" version = "0.1.0" dependencies = [ "dalek-ff-group", + "ff", "group", "hex", "pasta_curves", "rand_core", + "subtle", "zeroize", ] diff --git a/crypto/dalek-ff-group/src/field.rs b/crypto/dalek-ff-group/src/field.rs index b1af27114..bc3078c84 100644 --- a/crypto/dalek-ff-group/src/field.rs +++ b/crypto/dalek-ff-group/src/field.rs @@ -35,7 +35,7 @@ impl_modulus!( type ResidueType = Residue; /// A constant-time implementation of the Ed25519 field. -#[derive(Clone, Copy, PartialEq, Eq, Default, Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Default, Debug, Zeroize)] pub struct FieldElement(ResidueType); // Square root of -1. diff --git a/crypto/dkg/Cargo.toml b/crypto/dkg/Cargo.toml index cde0d1539..39ebb6dcb 100644 --- a/crypto/dkg/Cargo.toml +++ b/crypto/dkg/Cargo.toml @@ -37,7 +37,6 @@ schnorr = { package = "schnorr-signatures", path = "../schnorr", version = "^0.5 dleq = { path = "../dleq", version = "^0.4.1", default-features = false } # eVRF DKG dependencies -subtle = { version = "2", default-features = false, features = ["std"], optional = true } generic-array = { version = "1", default-features = false, features = ["alloc"], optional = true } blake2 = { version = "0.10", default-features = false, features = ["std"], optional = true } rand_chacha = { version = "0.3", default-features = false, features = ["std"], optional = true } @@ -82,7 +81,6 @@ borsh = ["dep:borsh"] evrf = [ "std", - "dep:subtle", "dep:generic-array", "dep:blake2", diff --git a/crypto/dkg/src/evrf/proof.rs b/crypto/dkg/src/evrf/proof.rs index ce9c57d14..8eb3ab002 100644 --- a/crypto/dkg/src/evrf/proof.rs +++ b/crypto/dkg/src/evrf/proof.rs @@ -1,6 +1,5 @@ use core::{marker::PhantomData, ops::Deref, fmt}; -use subtle::*; use zeroize::{Zeroize, Zeroizing}; use rand_core::{RngCore, CryptoRng, SeedableRng}; @@ -10,10 +9,7 @@ use generic_array::{typenum::Unsigned, ArrayLength, GenericArray}; use blake2::{Digest, Blake2s256}; use ciphersuite::{ - group::{ - ff::{Field, PrimeField, PrimeFieldBits}, - Group, GroupEncoding, - }, + group::{ff::Field, Group, GroupEncoding}, Ciphersuite, }; @@ -24,7 +20,7 @@ use generalized_bulletproofs::{ }; use generalized_bulletproofs_circuit_abstraction::*; -use ec_divisors::{DivisorCurve, new_divisor}; +use ec_divisors::{DivisorCurve, ScalarDecomposition}; use generalized_bulletproofs_ec_gadgets::*; /// A pair of curves to perform the eVRF with. @@ -309,147 +305,6 @@ impl Evrf { debug_assert!(challenged_generators.next().is_none()); } - /// Convert a scalar to a sequence of coefficients for the polynomial 2**i, where the sum of the - /// coefficients is F::NUM_BITS. - /// - /// Despite the name, the returned coefficients are not guaranteed to be bits (0 or 1). - /// - /// This scalar will presumably be used in a discrete log proof. That requires calculating a - /// divisor which is variable time to the amount of points interpolated. Since the amount of - /// points interpolated is equal to the sum of the coefficients in the polynomial, we need all - /// scalars to have a constant sum of their coefficients (instead of one variable to its bits). - /// - /// We achieve this by finding the highest non-0 coefficient, decrementing it, and increasing the - /// immediately less significant coefficient by 2. This increases the sum of the coefficients by - /// 1 (-1+2=1). - fn scalar_to_bits(scalar: &::F) -> Vec { - let num_bits = u64::from(<::EmbeddedCurve as Ciphersuite>::F::NUM_BITS); - - // Obtain the bits of the private key - let num_bits_usize = usize::try_from(num_bits).unwrap(); - let mut decomposition = vec![0; num_bits_usize]; - for (i, bit) in scalar.to_le_bits().into_iter().take(num_bits_usize).enumerate() { - let bit = u64::from(u8::from(bit)); - decomposition[i] = bit; - } - - // The following algorithm only works if the value of the scalar exceeds num_bits - // If it isn't, we increase it by the modulus such that it does exceed num_bits - { - let mut less_than_num_bits = Choice::from(0); - for i in 0 .. num_bits { - less_than_num_bits |= scalar.ct_eq(&::F::from(i)); - } - let mut decomposition_of_modulus = vec![0; num_bits_usize]; - // Decompose negative one - for (i, bit) in (-::F::ONE) - .to_le_bits() - .into_iter() - .take(num_bits_usize) - .enumerate() - { - let bit = u64::from(u8::from(bit)); - decomposition_of_modulus[i] = bit; - } - // Increment it by one - decomposition_of_modulus[0] += 1; - - // Add the decomposition onto the decomposition of the modulus - for i in 0 .. num_bits_usize { - let new_decomposition = <_>::conditional_select( - &decomposition[i], - &(decomposition[i] + decomposition_of_modulus[i]), - less_than_num_bits, - ); - decomposition[i] = new_decomposition; - } - } - - // Calculcate the sum of the coefficients - let mut sum_of_coefficients: u64 = 0; - for decomposition in &decomposition { - sum_of_coefficients += *decomposition; - } - - /* - Now, because we added a log2(k)-bit number to a k-bit number, we may have our sum of - coefficients be *too high*. We attempt to reduce the sum of the coefficients accordingly. - - This algorithm is guaranteed to complete as expected. Take the sequence `222`. `222` becomes - `032` becomes `013`. Even if the next coefficient in the sequence is `2`, the third - coefficient will be reduced once and the next coefficient (`2`, increased to `3`) will only - be eligible for reduction once. This demonstrates, even for a worst case of log2(k) `2`s - followed by `1`s (as possible if the modulus is a Mersenne prime), the log2(k) `2`s can be - reduced as necessary so long as there is a single coefficient after (requiring the entire - sequence be at least of length log2(k) + 1). For a 2-bit number, log2(k) + 1 == 2, so this - holds for any odd prime field. - - To fully type out the demonstration for the Mersenne prime 3, with scalar to encode 1 (the - highest value less than the number of bits): - - 10 - Little-endian bits of 1 - 21 - Little-endian bits of 1, plus the modulus - 02 - After one reduction, where the sum of the coefficients does in fact equal 2 (the target) - */ - { - let mut log2_num_bits = 0; - while (1 << log2_num_bits) < num_bits { - log2_num_bits += 1; - } - - for _ in 0 .. log2_num_bits { - // If the sum of coefficients is the amount of bits, we're done - let mut done = sum_of_coefficients.ct_eq(&num_bits); - - for i in 0 .. (num_bits_usize - 1) { - let should_act = (!done) & decomposition[i].ct_gt(&1); - // Subtract 2 from this coefficient - let amount_to_sub = <_>::conditional_select(&0, &2, should_act); - decomposition[i] -= amount_to_sub; - // Add 1 to the next coefficient - let amount_to_add = <_>::conditional_select(&0, &1, should_act); - decomposition[i + 1] += amount_to_add; - - // Also update the sum of coefficients - sum_of_coefficients -= <_>::conditional_select(&0, &1, should_act); - - // If we updated the coefficients this loop iter, we're done for this loop iter - done |= should_act; - } - } - } - - for _ in 0 .. num_bits { - // If the sum of coefficients is the amount of bits, we're done - let mut done = sum_of_coefficients.ct_eq(&num_bits); - - // Find the highest coefficient currently non-zero - for i in (1 .. decomposition.len()).rev() { - // If this is non-zero, we should decrement this coefficient if we haven't already - // decremented a coefficient this round - let is_non_zero = !(0.ct_eq(&decomposition[i])); - let should_act = (!done) & is_non_zero; - - // Update this coefficient and the prior coefficient - let amount_to_sub = <_>::conditional_select(&0, &1, should_act); - decomposition[i] -= amount_to_sub; - - let amount_to_add = <_>::conditional_select(&0, &2, should_act); - // i must be at least 1, so i - 1 will be at least 0 (meaning it's safe to index with) - decomposition[i - 1] += amount_to_add; - - // Also update the sum of coefficients - sum_of_coefficients += <_>::conditional_select(&0, &1, should_act); - - // If we updated the coefficients this loop iter, we're done for this loop iter - done |= should_act; - } - } - debug_assert!(bool::from(decomposition.iter().sum::().ct_eq(&num_bits))); - - decomposition - } - /// Prove a point on an elliptic curve had its discrete logarithm generated via an eVRF. pub(crate) fn prove( rng: &mut (impl RngCore + CryptoRng), @@ -471,11 +326,9 @@ impl Evrf { // A function to calculate a divisor and push it onto the tape // This defines a vec, divisor_points, outside of the fn to reuse its allocation - let mut divisor_points = - Vec::with_capacity((::F::NUM_BITS as usize) + 1); let mut divisor = |vector_commitment_tape: &mut Vec<_>, - dlog: &[u64], + dlog: &ScalarDecomposition<<::EmbeddedCurve as Ciphersuite>::F>, push_generator: bool, generator: <::EmbeddedCurve as Ciphersuite>::G, dh: <::EmbeddedCurve as Ciphersuite>::G| { @@ -484,24 +337,7 @@ impl Evrf { generator_tables.push(GeneratorTable::new(&curve_spec, x, y)); } - { - let mut generator = generator; - for coefficient in dlog { - let mut coefficient = *coefficient; - while coefficient != 0 { - coefficient -= 1; - divisor_points.push(generator); - } - generator = generator.double(); - } - debug_assert_eq!( - dlog.iter().sum::(), - u64::from(::F::NUM_BITS) - ); - } - divisor_points.push(-dh); - let mut divisor = new_divisor(&divisor_points).unwrap().normalize_x_coefficient(); - divisor_points.zeroize(); + let mut divisor = dlog.scalar_mul_divisor(generator).normalize_x_coefficient(); vector_commitment_tape.push(divisor.zero_coefficient); @@ -540,11 +376,12 @@ impl Evrf { let evrf_public_key; let mut actual_coefficients = Vec::with_capacity(coefficients); { - let mut dlog = Self::scalar_to_bits(evrf_private_key); + let dlog = + ScalarDecomposition::<::F>::new(**evrf_private_key); let points = Self::transcript_to_points(transcript, coefficients); // Start by pushing the discrete logarithm onto the tape - for coefficient in &dlog { + for coefficient in dlog.decomposition() { vector_commitment_tape.push(<_>::from(*coefficient)); } @@ -573,8 +410,6 @@ impl Evrf { actual_coefficients.push(res); } debug_assert_eq!(actual_coefficients.len(), coefficients); - - dlog.zeroize(); } // Now do the ECDHs for the encryption @@ -595,14 +430,15 @@ impl Evrf { break; } } - let mut dlog = Self::scalar_to_bits(&ecdh_private_key); + let dlog = + ScalarDecomposition::<::F>::new(ecdh_private_key); let ecdh_commitment = ::generator() * ecdh_private_key; ecdh_commitments.push(ecdh_commitment); ecdh_commitments_xy.last_mut().unwrap()[j] = <::G as DivisorCurve>::to_xy(ecdh_commitment).unwrap(); // Start by pushing the discrete logarithm onto the tape - for coefficient in &dlog { + for coefficient in dlog.decomposition() { vector_commitment_tape.push(<_>::from(*coefficient)); } @@ -625,7 +461,6 @@ impl Evrf { *res += dh_x; ecdh_private_key.zeroize(); - dlog.zeroize(); } encryption_masks.push(res); } diff --git a/crypto/evrf/divisors/Cargo.toml b/crypto/evrf/divisors/Cargo.toml index d4e3a2d0d..04e820b63 100644 --- a/crypto/evrf/divisors/Cargo.toml +++ b/crypto/evrf/divisors/Cargo.toml @@ -14,9 +14,11 @@ rustdoc-args = ["--cfg", "docsrs"] [dependencies] rand_core = { version = "0.6", default-features = false } -zeroize = { version = "^1.5", default-features = false, features = ["zeroize_derive"] } +zeroize = { version = "^1.5", default-features = false, features = ["std", "zeroize_derive"] } -group = "0.13" +subtle = { version = "2", default-features = false, features = ["std"] } +ff = { version = "0.13", default-features = false, features = ["std", "bits"] } +group = { version = "0.13", default-features = false } hex = { version = "0.4", optional = true } dalek-ff-group = { path = "../../dalek-ff-group", features = ["std"], optional = true } diff --git a/crypto/evrf/divisors/src/lib.rs b/crypto/evrf/divisors/src/lib.rs index d71aa8a4d..dbeb149f1 100644 --- a/crypto/evrf/divisors/src/lib.rs +++ b/crypto/evrf/divisors/src/lib.rs @@ -3,21 +3,24 @@ #![deny(missing_docs)] #![allow(non_snake_case)] +use subtle::{Choice, ConstantTimeEq, ConstantTimeGreater, ConditionallySelectable}; +use zeroize::{Zeroize, ZeroizeOnDrop}; + use group::{ - ff::{Field, PrimeField}, + ff::{Field, PrimeField, PrimeFieldBits}, Group, }; mod poly; -pub use poly::*; +pub use poly::Poly; #[cfg(test)] mod tests; /// A curve usable with this library. -pub trait DivisorCurve: Group { +pub trait DivisorCurve: Group + ConstantTimeEq + ConditionallySelectable { /// An element of the field this curve is defined over. - type FieldElement: PrimeField; + type FieldElement: Zeroize + PrimeField + ConditionallySelectable; /// The A in the curve equation y^2 = x^3 + A x + B. fn a() -> Self::FieldElement; @@ -72,46 +75,89 @@ pub(crate) fn slope_intercept(a: C, b: C) -> (C::FieldElement, } // The line interpolating two points. -fn line(a: C, mut b: C) -> Poly { - // If they're both the point at infinity, we simply set the line to one - if bool::from(a.is_identity() & b.is_identity()) { - return Poly { - y_coefficients: vec![], - yx_coefficients: vec![], - x_coefficients: vec![], - zero_coefficient: C::FieldElement::ONE, - }; +fn line(a: C, b: C) -> Poly { + #[derive(Clone, Copy)] + struct LinesRes { + y_coefficient: F, + x_coefficient: F, + zero_coefficient: F, + } + impl ConditionallySelectable for LinesRes { + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + Self { + y_coefficient: <_>::conditional_select(&a.y_coefficient, &b.y_coefficient, choice), + x_coefficient: <_>::conditional_select(&a.x_coefficient, &b.x_coefficient, choice), + zero_coefficient: <_>::conditional_select(&a.zero_coefficient, &b.zero_coefficient, choice), + } + } } + let a_is_identity = a.is_identity(); + let b_is_identity = b.is_identity(); + + // If they're both the point at infinity, we simply set the line to one + let both_are_identity = a_is_identity & b_is_identity; + let if_both_are_identity = LinesRes { + y_coefficient: C::FieldElement::ZERO, + x_coefficient: C::FieldElement::ZERO, + zero_coefficient: C::FieldElement::ONE, + }; + // If either point is the point at infinity, or these are additive inverses, the line is // `1 * x - x`. The first `x` is a term in the polynomial, the `x` is the `x` coordinate of these // points (of which there is one, as the second point is either at infinity or has a matching `x` // coordinate). - if bool::from(a.is_identity() | b.is_identity()) || (a == -b) { - let (x, _) = C::to_xy(if !bool::from(a.is_identity()) { a } else { b }).unwrap(); - return Poly { - y_coefficients: vec![], - yx_coefficients: vec![], - x_coefficients: vec![C::FieldElement::ONE], + let one_is_identity = a_is_identity | b_is_identity; + let additive_inverses = a.ct_eq(&-b); + let one_is_identity_or_additive_inverses = one_is_identity | additive_inverses; + let if_one_is_identity_or_additive_inverses = { + // If both are identity, set `a` to the generator so we can safely evaluate the following + // (which we won't select at the end of this function) + let a = <_>::conditional_select(&a, &C::generator(), both_are_identity); + // If `a` is identity, this selects `b`. If `a` isn't identity, this selects `a` + let non_identity = <_>::conditional_select(&a, &b, a.is_identity()); + let (x, _) = C::to_xy(non_identity).unwrap(); + LinesRes { + y_coefficient: C::FieldElement::ZERO, + x_coefficient: C::FieldElement::ONE, zero_coefficient: -x, - }; - } + } + }; + + // The following calculation assumes neither point is the point at infinity + // If either are, we use a prior result + // To ensure we can calculcate a result here, set any points at infinity to the generator + let a = <_>::conditional_select(&a, &C::generator(), a_is_identity); + let b = <_>::conditional_select(&b, &C::generator(), b_is_identity); + // It also assumes a, b aren't additive inverses which is also covered by a prior result + let b = <_>::conditional_select(&b, &a.double(), additive_inverses); // If the points are equal, we use the line interpolating the sum of these points with the point // at infinity - if a == b { - b = -a.double(); - } + let b = <_>::conditional_select(&b, &-a.double(), a.ct_eq(&b)); let (slope, intercept) = slope_intercept::(a, b); // Section 4 of the proofs explicitly state the line `L = y - lambda * x - mu` // y - (slope * x) - intercept + let mut res = LinesRes { + y_coefficient: C::FieldElement::ONE, + x_coefficient: -slope, + zero_coefficient: -intercept, + }; + + res = <_>::conditional_select( + &res, + &if_one_is_identity_or_additive_inverses, + one_is_identity_or_additive_inverses, + ); + res = <_>::conditional_select(&res, &if_both_are_identity, both_are_identity); + Poly { - y_coefficients: vec![C::FieldElement::ONE], + y_coefficients: vec![res.y_coefficient], yx_coefficients: vec![], - x_coefficients: vec![-slope], - zero_coefficient: -intercept, + x_coefficients: vec![res.x_coefficient], + zero_coefficient: res.zero_coefficient, } } @@ -121,36 +167,65 @@ fn line(a: C, mut b: C) -> Poly { /// - No points were passed in /// - The points don't sum to the point at infinity /// - A passed in point was the point at infinity +/// +/// If the arguments were valid, this function executes in an amount of time constant to the amount +/// of points. #[allow(clippy::new_ret_no_self)] pub fn new_divisor(points: &[C]) -> Option> { - // A single point is either the point at infinity, or this doesn't sum to the point at infinity - // Both cause us to return None - if points.len() < 2 { - None?; + // No points were passed in, this is the point at infinity, or the single point isn't infinity + // and accordingly doesn't sum to infinity. All three cause us to return None + // Checks a bit other than the first bit is set, meaning this is >= 2 + let mut invalid_args = (points.len() & (!1)).ct_eq(&0); + // The points don't sum to the point at infinity + invalid_args |= !points.iter().sum::().is_identity(); + // A point was the point at identity + for point in points { + invalid_args |= point.is_identity(); } - if points.iter().sum::() != C::identity() { + if bool::from(invalid_args) { None?; } + let points_len = points.len(); + // Create the initial set of divisors let mut divs = vec![]; let mut iter = points.iter().copied(); while let Some(a) = iter.next() { - if a == C::identity() { - None?; - } - let b = iter.next(); - if b == Some(C::identity()) { - None?; - } // Draw the line between those points - divs.push((a + b.unwrap_or(C::identity()), line::(a, b.unwrap_or(-a)))); + // These unwraps are branching on the length of the iterator, not violating the constant-time + // priorites desired + divs.push((2, a + b.unwrap_or(C::identity()), line::(a, b.unwrap_or(-a)))); } let modulus = C::divisor_modulus(); + // Our Poly algorithm is leaky and will create an excessive amount of y x**j and x**j + // coefficients which are zero, yet as our implementation is constant time, still come with + // an immense performance cost. This code truncates the coefficients we know are zero. + let trim = |divisor: &mut Poly<_>, points_len: usize| { + // We should only be trimming divisors reduced by the modulus + debug_assert!(divisor.yx_coefficients.len() <= 1); + if divisor.yx_coefficients.len() == 1 { + let truncate_to = ((points_len + 1) / 2).saturating_sub(2); + #[cfg(debug_assertions)] + for p in truncate_to .. divisor.yx_coefficients[0].len() { + debug_assert_eq!(divisor.yx_coefficients[0][p], ::ZERO); + } + divisor.yx_coefficients[0].truncate(truncate_to); + } + { + let truncate_to = points_len / 2; + #[cfg(debug_assertions)] + for p in truncate_to .. divisor.x_coefficients.len() { + debug_assert_eq!(divisor.x_coefficients[p], ::ZERO); + } + divisor.x_coefficients.truncate(truncate_to); + } + }; + // Pair them off until only one remains while divs.len() > 1 { let mut next_divs = vec![]; @@ -159,23 +234,208 @@ pub fn new_divisor(points: &[C]) -> Option(a, b), &modulus); - let denominator = line::(a, -a).mul_mod(line::(b, -b), &modulus); - let (q, r) = numerator.div_rem(&denominator); - assert_eq!(r, Poly::zero()); + let numerator = a_div.mul_mod(&b_div, &modulus).mul_mod(&line::(a, b), &modulus); + let denominator = line::(a, -a).mul_mod(&line::(b, -b), &modulus); + let (mut q, r) = numerator.div_rem(&denominator); + debug_assert_eq!(r, Poly::zero()); - next_divs.push((a + b, q)); + trim(&mut q, 1 + points); + + next_divs.push((points, a + b, q)); } divs = next_divs; } // Return the unified divisor - Some(divs.remove(0).1) + let mut divisor = divs.remove(0).2; + trim(&mut divisor, points_len); + Some(divisor) +} + +/// The decomposition of a scalar. +/// +/// The decomposition ($d$) of a scalar ($s$) has the following two properties: +/// +/// - $\sum^{\mathsf{NUM_BITS} - 1}_{i=0} d_i * 2^i = s$ +/// - $\sum^{\mathsf{NUM_BITS} - 1}_{i=0} d_i = \mathsf{NUM_BITS}$ +#[derive(Clone, Zeroize, ZeroizeOnDrop)] +pub struct ScalarDecomposition { + scalar: F, + decomposition: Vec, +} + +impl ScalarDecomposition { + /// Decompose a scalar. + pub fn new(scalar: F) -> Self { + /* + We need the sum of the coefficients to equal F::NUM_BITS. The scalar's bits will be less than + F::NUM_BITS. Accordingly, we need to increment the sum of the coefficients without + incrementing the scalar represented. We do this by finding the highest non-0 coefficient, + decrementing it, and increasing the immediately less significant coefficient by 2. This + increases the sum of the coefficients by 1 (-1+2=1). + */ + + let num_bits = u64::from(F::NUM_BITS); + + // Obtain the bits of the scalar + let num_bits_usize = usize::try_from(num_bits).unwrap(); + let mut decomposition = vec![0; num_bits_usize]; + for (i, bit) in scalar.to_le_bits().into_iter().take(num_bits_usize).enumerate() { + let bit = u64::from(u8::from(bit)); + decomposition[i] = bit; + } + + // The following algorithm only works if the value of the scalar exceeds num_bits + // If it isn't, we increase it by the modulus such that it does exceed num_bits + { + let mut less_than_num_bits = Choice::from(0); + for i in 0 .. num_bits { + less_than_num_bits |= scalar.ct_eq(&F::from(i)); + } + let mut decomposition_of_modulus = vec![0; num_bits_usize]; + // Decompose negative one + for (i, bit) in (-F::ONE).to_le_bits().into_iter().take(num_bits_usize).enumerate() { + let bit = u64::from(u8::from(bit)); + decomposition_of_modulus[i] = bit; + } + // Increment it by one + decomposition_of_modulus[0] += 1; + + // Add the decomposition onto the decomposition of the modulus + for i in 0 .. num_bits_usize { + let new_decomposition = <_>::conditional_select( + &decomposition[i], + &(decomposition[i] + decomposition_of_modulus[i]), + less_than_num_bits, + ); + decomposition[i] = new_decomposition; + } + } + + // Calculcate the sum of the coefficients + let mut sum_of_coefficients: u64 = 0; + for decomposition in &decomposition { + sum_of_coefficients += *decomposition; + } + + /* + Now, because we added a log2(k)-bit number to a k-bit number, we may have our sum of + coefficients be *too high*. We attempt to reduce the sum of the coefficients accordingly. + + This algorithm is guaranteed to complete as expected. Take the sequence `222`. `222` becomes + `032` becomes `013`. Even if the next coefficient in the sequence is `2`, the third + coefficient will be reduced once and the next coefficient (`2`, increased to `3`) will only + be eligible for reduction once. This demonstrates, even for a worst case of log2(k) `2`s + followed by `1`s (as possible if the modulus is a Mersenne prime), the log2(k) `2`s can be + reduced as necessary so long as there is a single coefficient after (requiring the entire + sequence be at least of length log2(k) + 1). For a 2-bit number, log2(k) + 1 == 2, so this + holds for any odd prime field. + + To fully type out the demonstration for the Mersenne prime 3, with scalar to encode 1 (the + highest value less than the number of bits): + + 10 - Little-endian bits of 1 + 21 - Little-endian bits of 1, plus the modulus + 02 - After one reduction, where the sum of the coefficients does in fact equal 2 (the target) + */ + { + let mut log2_num_bits = 0; + while (1 << log2_num_bits) < num_bits { + log2_num_bits += 1; + } + + for _ in 0 .. log2_num_bits { + // If the sum of coefficients is the amount of bits, we're done + let mut done = sum_of_coefficients.ct_eq(&num_bits); + + for i in 0 .. (num_bits_usize - 1) { + let should_act = (!done) & decomposition[i].ct_gt(&1); + // Subtract 2 from this coefficient + let amount_to_sub = <_>::conditional_select(&0, &2, should_act); + decomposition[i] -= amount_to_sub; + // Add 1 to the next coefficient + let amount_to_add = <_>::conditional_select(&0, &1, should_act); + decomposition[i + 1] += amount_to_add; + + // Also update the sum of coefficients + sum_of_coefficients -= <_>::conditional_select(&0, &1, should_act); + + // If we updated the coefficients this loop iter, we're done for this loop iter + done |= should_act; + } + } + } + + for _ in 0 .. num_bits { + // If the sum of coefficients is the amount of bits, we're done + let mut done = sum_of_coefficients.ct_eq(&num_bits); + + // Find the highest coefficient currently non-zero + for i in (1 .. decomposition.len()).rev() { + // If this is non-zero, we should decrement this coefficient if we haven't already + // decremented a coefficient this round + let is_non_zero = !(0.ct_eq(&decomposition[i])); + let should_act = (!done) & is_non_zero; + + // Update this coefficient and the prior coefficient + let amount_to_sub = <_>::conditional_select(&0, &1, should_act); + decomposition[i] -= amount_to_sub; + + let amount_to_add = <_>::conditional_select(&0, &2, should_act); + // i must be at least 1, so i - 1 will be at least 0 (meaning it's safe to index with) + decomposition[i - 1] += amount_to_add; + + // Also update the sum of coefficients + sum_of_coefficients += <_>::conditional_select(&0, &1, should_act); + + // If we updated the coefficients this loop iter, we're done for this loop iter + done |= should_act; + } + } + debug_assert!(bool::from(decomposition.iter().sum::().ct_eq(&num_bits))); + + ScalarDecomposition { scalar, decomposition } + } + + /// The decomposition of the scalar. + pub fn decomposition(&self) -> &[u64] { + &self.decomposition + } + + /// A divisor to prove a scalar multiplication. + /// + /// The divisor will interpolate $d_i$ instances of $2^i \cdot G$ with $-(s \cdot G)$. + /// + /// This function executes in constant time with regards to the scalar. + /// + /// This function MAY panic if this scalar is zero. + pub fn scalar_mul_divisor>( + &self, + mut generator: C, + ) -> Poly { + // The following for loop is constant time to the sum of `dlog`'s elements + let mut divisor_points = + Vec::with_capacity(usize::try_from(::NUM_BITS).unwrap()); + divisor_points.push(-generator * self.scalar); + for coefficient in &self.decomposition { + let mut coefficient = *coefficient; + while coefficient != 0 { + coefficient -= 1; + divisor_points.push(generator); + } + generator = generator.double(); + } + + let res = new_divisor(&divisor_points).unwrap(); + divisor_points.zeroize(); + res + } } #[cfg(any(test, feature = "pasta"))] diff --git a/crypto/evrf/divisors/src/poly.rs b/crypto/evrf/divisors/src/poly.rs index b818433bc..0e41bc49d 100644 --- a/crypto/evrf/divisors/src/poly.rs +++ b/crypto/evrf/divisors/src/poly.rs @@ -1,25 +1,112 @@ use core::ops::{Add, Neg, Sub, Mul, Rem}; -use zeroize::Zeroize; +use subtle::{Choice, ConstantTimeEq, ConstantTimeGreater, ConditionallySelectable}; +use zeroize::{Zeroize, ZeroizeOnDrop}; use group::ff::PrimeField; -/// A structure representing a Polynomial with x**i, y**i, and y**i * x**j terms. -#[derive(Clone, PartialEq, Eq, Debug, Zeroize)] -pub struct Poly> { - /// c[i] * y ** (i + 1) +#[derive(Clone, Copy, PartialEq, Debug)] +struct CoefficientIndex { + y_pow: u64, + x_pow: u64, +} +impl ConditionallySelectable for CoefficientIndex { + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + Self { + y_pow: <_>::conditional_select(&a.y_pow, &b.y_pow, choice), + x_pow: <_>::conditional_select(&a.x_pow, &b.x_pow, choice), + } + } +} +impl ConstantTimeEq for CoefficientIndex { + fn ct_eq(&self, other: &Self) -> Choice { + self.y_pow.ct_eq(&other.y_pow) & self.x_pow.ct_eq(&other.x_pow) + } +} +impl ConstantTimeGreater for CoefficientIndex { + fn ct_gt(&self, other: &Self) -> Choice { + self.y_pow.ct_gt(&other.y_pow) | + (self.y_pow.ct_eq(&other.y_pow) & self.x_pow.ct_gt(&other.x_pow)) + } +} + +/// A structure representing a Polynomial with x^i, y^i, and y^i * x^j terms. +#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)] +pub struct Poly + Zeroize + PrimeField> { + /// c\[i] * y^(i + 1) pub y_coefficients: Vec, - /// c[i][j] * y ** (i + 1) x ** (j + 1) + /// c\[i]\[j] * y^(i + 1) x^(j + 1) pub yx_coefficients: Vec>, - /// c[i] * x ** (i + 1) + /// c\[i] * x^(i + 1) pub x_coefficients: Vec, - /// Coefficient for x ** 0, y ** 0, and x ** 0 y ** 0 (the coefficient for 1) + /// Coefficient for x^0, y^0, and x^0 y^0 (the coefficient for 1) pub zero_coefficient: F, } -impl> Poly { +impl + Zeroize + PrimeField> PartialEq for Poly { + // This is not constant time and is not meant to be + fn eq(&self, b: &Poly) -> bool { + { + let mutual_y_coefficients = self.y_coefficients.len().min(b.y_coefficients.len()); + if self.y_coefficients[.. mutual_y_coefficients] != b.y_coefficients[.. mutual_y_coefficients] + { + return false; + } + for coeff in &self.y_coefficients[mutual_y_coefficients ..] { + if *coeff != F::ZERO { + return false; + } + } + for coeff in &b.y_coefficients[mutual_y_coefficients ..] { + if *coeff != F::ZERO { + return false; + } + } + } + + { + for (i, yx_coeffs) in self.yx_coefficients.iter().enumerate() { + for (j, coeff) in yx_coeffs.iter().enumerate() { + if coeff != b.yx_coefficients.get(i).unwrap_or(&vec![]).get(j).unwrap_or(&F::ZERO) { + return false; + } + } + } + // Run from the other perspective in case other is longer than self + for (i, yx_coeffs) in b.yx_coefficients.iter().enumerate() { + for (j, coeff) in yx_coeffs.iter().enumerate() { + if coeff != self.yx_coefficients.get(i).unwrap_or(&vec![]).get(j).unwrap_or(&F::ZERO) { + return false; + } + } + } + } + + { + let mutual_x_coefficients = self.x_coefficients.len().min(b.x_coefficients.len()); + if self.x_coefficients[.. mutual_x_coefficients] != b.x_coefficients[.. mutual_x_coefficients] + { + return false; + } + for coeff in &self.x_coefficients[mutual_x_coefficients ..] { + if *coeff != F::ZERO { + return false; + } + } + for coeff in &b.x_coefficients[mutual_x_coefficients ..] { + if *coeff != F::ZERO { + return false; + } + } + } + + self.zero_coefficient == b.zero_coefficient + } +} + +impl + Zeroize + PrimeField> Poly { /// A polynomial for zero. - pub fn zero() -> Self { + pub(crate) fn zero() -> Self { Poly { y_coefficients: vec![], yx_coefficients: vec![], @@ -27,37 +114,9 @@ impl> Poly { zero_coefficient: F::ZERO, } } - - /// The amount of terms in the polynomial. - #[allow(clippy::len_without_is_empty)] - #[must_use] - pub fn len(&self) -> usize { - self.y_coefficients.len() + - self.yx_coefficients.iter().map(Vec::len).sum::() + - self.x_coefficients.len() + - usize::from(u8::from(self.zero_coefficient != F::ZERO)) - } - - // Remove high-order zero terms, allowing the length of the vectors to equal the amount of terms. - pub(crate) fn tidy(&mut self) { - let tidy = |vec: &mut Vec| { - while vec.last() == Some(&F::ZERO) { - vec.pop(); - } - }; - - tidy(&mut self.y_coefficients); - for vec in self.yx_coefficients.iter_mut() { - tidy(vec); - } - while self.yx_coefficients.last() == Some(&vec![]) { - self.yx_coefficients.pop(); - } - tidy(&mut self.x_coefficients); - } } -impl> Add<&Self> for Poly { +impl + Zeroize + PrimeField> Add<&Self> for Poly { type Output = Self; fn add(mut self, other: &Self) -> Self { @@ -91,12 +150,11 @@ impl> Add<&Self> for Poly { } self.zero_coefficient += other.zero_coefficient; - self.tidy(); self } } -impl> Neg for Poly { +impl + Zeroize + PrimeField> Neg for Poly { type Output = Self; fn neg(mut self) -> Self { @@ -117,7 +175,7 @@ impl> Neg for Poly { } } -impl> Sub for Poly { +impl + Zeroize + PrimeField> Sub for Poly { type Output = Self; fn sub(self, other: Self) -> Self { @@ -125,14 +183,10 @@ impl> Sub for Poly { } } -impl> Mul for Poly { +impl + Zeroize + PrimeField> Mul for Poly { type Output = Self; fn mul(mut self, scalar: F) -> Self { - if scalar == F::ZERO { - return Poly::zero(); - } - for y_coeff in self.y_coefficients.iter_mut() { *y_coeff *= scalar; } @@ -149,7 +203,7 @@ impl> Mul for Poly { } } -impl> Poly { +impl + Zeroize + PrimeField> Poly { #[must_use] fn shift_by_x(mut self, power_of_x: usize) -> Self { if power_of_x == 0 { @@ -203,17 +257,17 @@ impl> Poly { self.zero_coefficient = F::ZERO; // Move the x coefficients - self.yx_coefficients[power_of_y - 1] = self.x_coefficients; + std::mem::swap(&mut self.yx_coefficients[power_of_y - 1], &mut self.x_coefficients); self.x_coefficients = vec![]; self } } -impl> Mul for Poly { +impl + Zeroize + PrimeField> Mul<&Poly> for Poly { type Output = Self; - fn mul(self, other: Self) -> Self { + fn mul(self, other: &Self) -> Self { let mut res = self.clone() * other.zero_coefficient; for (i, y_coeff) in other.y_coefficients.iter().enumerate() { @@ -233,94 +287,320 @@ impl> Mul for Poly { res = res + &scaled.shift_by_x(i + 1); } - res.tidy(); res } } -impl> Poly { +impl + Zeroize + PrimeField> Poly { + // The leading y coefficient and associated x coefficient. + fn leading_coefficient(&self) -> (usize, usize) { + if self.y_coefficients.len() > self.yx_coefficients.len() { + (self.y_coefficients.len(), 0) + } else if !self.yx_coefficients.is_empty() { + (self.yx_coefficients.len(), self.yx_coefficients.last().unwrap().len()) + } else { + (0, self.x_coefficients.len()) + } + } + + /// Returns the highest non-zero coefficient greater than the specified coefficient. + /// + /// If no non-zero coefficient is greater than the specified coefficient, this will return + /// (0, 0). + fn greater_than_or_equal_coefficient( + &self, + greater_than_or_equal: &CoefficientIndex, + ) -> CoefficientIndex { + let mut leading_coefficient = CoefficientIndex { y_pow: 0, x_pow: 0 }; + for (y_pow_sub_one, coeff) in self.y_coefficients.iter().enumerate() { + let y_pow = u64::try_from(y_pow_sub_one + 1).unwrap(); + let coeff_is_non_zero = !coeff.is_zero(); + let potential = CoefficientIndex { y_pow, x_pow: 0 }; + leading_coefficient = <_>::conditional_select( + &leading_coefficient, + &potential, + coeff_is_non_zero & + potential.ct_gt(&leading_coefficient) & + (potential.ct_gt(greater_than_or_equal) | potential.ct_eq(greater_than_or_equal)), + ); + } + for (y_pow_sub_one, yx_coefficients) in self.yx_coefficients.iter().enumerate() { + let y_pow = u64::try_from(y_pow_sub_one + 1).unwrap(); + for (x_pow_sub_one, coeff) in yx_coefficients.iter().enumerate() { + let x_pow = u64::try_from(x_pow_sub_one + 1).unwrap(); + let coeff_is_non_zero = !coeff.is_zero(); + let potential = CoefficientIndex { y_pow, x_pow }; + leading_coefficient = <_>::conditional_select( + &leading_coefficient, + &potential, + coeff_is_non_zero & + potential.ct_gt(&leading_coefficient) & + (potential.ct_gt(greater_than_or_equal) | potential.ct_eq(greater_than_or_equal)), + ); + } + } + for (x_pow_sub_one, coeff) in self.x_coefficients.iter().enumerate() { + let x_pow = u64::try_from(x_pow_sub_one + 1).unwrap(); + let coeff_is_non_zero = !coeff.is_zero(); + let potential = CoefficientIndex { y_pow: 0, x_pow }; + leading_coefficient = <_>::conditional_select( + &leading_coefficient, + &potential, + coeff_is_non_zero & + potential.ct_gt(&leading_coefficient) & + (potential.ct_gt(greater_than_or_equal) | potential.ct_eq(greater_than_or_equal)), + ); + } + leading_coefficient + } + /// Perform multiplication mod `modulus`. #[must_use] - pub fn mul_mod(self, other: Self, modulus: &Self) -> Self { - ((self % modulus) * (other % modulus)) % modulus + pub(crate) fn mul_mod(self, other: &Self, modulus: &Self) -> Self { + (self * other) % modulus } /// Perform division, returning the result and remainder. /// - /// Panics upon division by zero, with undefined behavior if a non-tidy divisor is used. + /// This function is constant time to the structure of the numerator and denominator. The actual + /// value of the coefficients will not introduce timing differences. + /// + /// Panics upon division by a polynomial where all coefficients are zero. #[must_use] - pub fn div_rem(self, divisor: &Self) -> (Self, Self) { - // The leading y coefficient and associated x coefficient. - let leading_y = |poly: &Self| -> (_, _) { - if poly.y_coefficients.len() > poly.yx_coefficients.len() { - (poly.y_coefficients.len(), 0) - } else if !poly.yx_coefficients.is_empty() { - (poly.yx_coefficients.len(), poly.yx_coefficients.last().unwrap().len()) - } else { - (0, poly.x_coefficients.len()) + pub(crate) fn div_rem(self, denominator: &Self) -> (Self, Self) { + // These functions have undefined behavior if this isn't a valid index for this poly + fn ct_get + Zeroize + PrimeField>( + poly: &Poly, + index: CoefficientIndex, + ) -> F { + let mut res = poly.zero_coefficient; + for (y_pow_sub_one, coeff) in poly.y_coefficients.iter().enumerate() { + res = <_>::conditional_select(&res, coeff, index.ct_eq(&CoefficientIndex { y_pow: (y_pow_sub_one + 1).try_into().unwrap(), x_pow: 0 })); } - }; - - let (div_y, div_x) = leading_y(divisor); - // If this divisor is actually a scalar, don't perform long division - if (div_y == 0) && (div_x == 0) { - return (self * divisor.zero_coefficient.invert().unwrap(), Poly::zero()); + for (y_pow_sub_one, coeffs) in poly.yx_coefficients.iter().enumerate() { + for (x_pow_sub_one, coeff) in coeffs.iter().enumerate() { + res = <_>::conditional_select(&res, coeff, index.ct_eq(&CoefficientIndex { y_pow: (y_pow_sub_one + 1).try_into().unwrap(), x_pow: (x_pow_sub_one + 1).try_into().unwrap() })); + } + } + for (x_pow_sub_one, coeff) in poly.x_coefficients.iter().enumerate() { + res = <_>::conditional_select(&res, coeff, index.ct_eq(&CoefficientIndex { y_pow: 0, x_pow: (x_pow_sub_one + 1).try_into().unwrap() })); + } + res } - // Remove leading terms until the value is less than the divisor - let mut quotient: Poly = Poly::zero(); - let mut remainder = self.clone(); - loop { - // If there's nothing left to divide, return - if remainder == Poly::zero() { - break; + fn ct_set + Zeroize + PrimeField>( + poly: &mut Poly, + index: CoefficientIndex, + value: F, + ) { + for (y_pow_sub_one, coeff) in poly.y_coefficients.iter_mut().enumerate() { + *coeff = <_>::conditional_select(coeff, &value, index.ct_eq(&CoefficientIndex { y_pow: (y_pow_sub_one + 1).try_into().unwrap(), x_pow: 0 })); } - - let (rem_y, rem_x) = leading_y(&remainder); - if (rem_y < div_y) || (rem_x < div_x) { - break; + for (y_pow_sub_one, coeffs) in poly.yx_coefficients.iter_mut().enumerate() { + for (x_pow_sub_one, coeff) in coeffs.iter_mut().enumerate() { + *coeff = <_>::conditional_select(coeff, &value, index.ct_eq(&CoefficientIndex { y_pow: (y_pow_sub_one + 1).try_into().unwrap(), x_pow: (x_pow_sub_one + 1).try_into().unwrap() })); + } } + for (x_pow_sub_one, coeff) in poly.x_coefficients.iter_mut().enumerate() { + *coeff = <_>::conditional_select(coeff, &value, index.ct_eq(&CoefficientIndex { y_pow: 0, x_pow: (x_pow_sub_one + 1).try_into().unwrap() })); + } + poly.zero_coefficient = <_>::conditional_select(&poly.zero_coefficient, &value, index.ct_eq(&CoefficientIndex { y_pow: 0, x_pow: 0 })); + } - let get = |poly: &Poly, y_pow: usize, x_pow: usize| -> F { - if (y_pow == 0) && (x_pow == 0) { - poly.zero_coefficient - } else if x_pow == 0 { - poly.y_coefficients[y_pow - 1] - } else if y_pow == 0 { - poly.x_coefficients[x_pow - 1] - } else { - poly.yx_coefficients[y_pow - 1][x_pow - 1] + fn conditional_select_poly + Zeroize + PrimeField>( + mut a: Poly, + mut b: Poly, + choice: Choice, + ) -> Poly { + let pad_to = |a: &mut Poly, b: &Poly| { + while a.x_coefficients.len() < b.x_coefficients.len() { + a.x_coefficients.push(F::ZERO); + } + while a.yx_coefficients.len() < b.yx_coefficients.len() { + a.yx_coefficients.push(vec![]); + } + for (a, b) in a.yx_coefficients.iter_mut().zip(&b.yx_coefficients) { + while a.len() < b.len() { + a.push(F::ZERO); + } + } + while a.y_coefficients.len() < b.y_coefficients.len() { + a.y_coefficients.push(F::ZERO); } }; - let coeff_numerator = get(&remainder, rem_y, rem_x); - let coeff_denominator = get(divisor, div_y, div_x); + // Pad these to be the same size/layout as each other + pad_to(&mut a, &b); + pad_to(&mut b, &a); + + let mut res = Poly::zero(); + for (a, b) in a.y_coefficients.iter().zip(&b.y_coefficients) { + res.y_coefficients.push(<_>::conditional_select(a, b, choice)); + } + for (a, b) in a.yx_coefficients.iter().zip(&b.yx_coefficients) { + let mut yx_coefficients = Vec::with_capacity(a.len()); + for (a, b) in a.iter().zip(b) { + yx_coefficients.push(<_>::conditional_select(a, b, choice)) + } + res.yx_coefficients.push(yx_coefficients); + } + for (a, b) in a.x_coefficients.iter().zip(&b.x_coefficients) { + res.x_coefficients.push(<_>::conditional_select(a, b, choice)); + } + res.zero_coefficient = <_>::conditional_select(&a.zero_coefficient, &b.zero_coefficient, choice); + + res + } - // We want coeff_denominator scaled by x to equal coeff_numerator - // x * d = n - // n / d = x - let mut quotient_term = Poly::zero(); - // Because this is the coefficient for the leading term of a tidied polynomial, it must be - // non-zero - quotient_term.zero_coefficient = coeff_numerator * coeff_denominator.invert().unwrap(); + // The following long division algorithm only works if the denominator actually has a variable + // If the denominator isn't variable to anything, short-circuit to scalar 'division' + // This is safe as `leading_coefficient` is based on the structure, not the values, of the poly + let denominator_leading_coefficient = denominator.leading_coefficient(); + if denominator_leading_coefficient == (0, 0) { + return (self * denominator.zero_coefficient.invert().unwrap(), Poly::zero()); + } + + // The structure of the quotient, which is the the numerator with all coefficients set to 0 + let mut quotient_structure = Poly { + y_coefficients: vec![F::ZERO; self.y_coefficients.len()], + yx_coefficients: self.yx_coefficients.clone(), + x_coefficients: vec![F::ZERO; self.x_coefficients.len()], + zero_coefficient: F::ZERO, + }; + for coeff in quotient_structure + .yx_coefficients + .iter_mut() + .flat_map(|yx_coefficients| yx_coefficients.iter_mut()) + { + *coeff = F::ZERO; + } - // Add the necessary yx powers - let delta_y = rem_y - div_y; - let delta_x = rem_x - div_x; - let quotient_term = quotient_term.shift_by_y(delta_y).shift_by_x(delta_x); + // Calculate the amount of iterations we need to perform + let iterations = self.y_coefficients.len() + + self.yx_coefficients.iter().map(|yx_coefficients| yx_coefficients.len()).sum::() + + self.x_coefficients.len(); - let to_remove = quotient_term.clone() * divisor.clone(); - debug_assert_eq!(get(&to_remove, rem_y, rem_x), coeff_numerator); + // Find the highest non-zero coefficient in the denominator + // This is the coefficient which we actually perform division with + let denominator_dividing_coefficient = + denominator.greater_than_or_equal_coefficient(&CoefficientIndex { y_pow: 0, x_pow: 0 }); + let denominator_dividing_coefficient_inv = + ct_get(denominator, denominator_dividing_coefficient).invert().unwrap(); - remainder = remainder - to_remove; - quotient = quotient + "ient_term; + let mut quotient = quotient_structure.clone(); + let mut remainder = self.clone(); + for _ in 0 .. iterations { + // Find the numerator coefficient we're clearing + // This will be (0, 0) if we aren't clearing a coefficient + let numerator_coefficient = + remainder.greater_than_or_equal_coefficient(&denominator_dividing_coefficient); + + // We only apply the effects of this iteration if the numerator's coefficient is actually >= + let meaningful_iteration = numerator_coefficient.ct_gt(&denominator_dividing_coefficient) | + numerator_coefficient.ct_eq(&denominator_dividing_coefficient); + + // 1) Find the scalar `q` such that the leading coefficient of `q * denominator` is equal to + // the leading coefficient of self. + let numerator_coefficient_value = ct_get(&remainder, numerator_coefficient); + let q = numerator_coefficient_value * denominator_dividing_coefficient_inv; + + // 2) Calculate the full term of the quotient by scaling with the necessary powers of y/x + let proper_powers_of_yx = CoefficientIndex { + y_pow: numerator_coefficient.y_pow.wrapping_sub(denominator_dividing_coefficient.y_pow), + x_pow: numerator_coefficient.x_pow.wrapping_sub(denominator_dividing_coefficient.x_pow), + }; + let fallabck_powers_of_yx = CoefficientIndex { y_pow: 0, x_pow: 0 }; + let mut quotient_term = quotient_structure.clone(); + ct_set( + &mut quotient_term, + // If the numerator coefficient isn't >=, proper_powers_of_yx will have garbage in them + <_>::conditional_select(&fallabck_powers_of_yx, &proper_powers_of_yx, meaningful_iteration), + q, + ); + + let quotient_if_meaningful = quotient.clone() + "ient_term; + quotient = conditional_select_poly(quotient, quotient_if_meaningful, meaningful_iteration); + + // 3) Remove what we've divided out from self + let remainder_if_meaningful = remainder.clone() - (quotient_term * denominator); + remainder = + conditional_select_poly(remainder, remainder_if_meaningful, meaningful_iteration); + } + + quotient = conditional_select_poly( + quotient, + // If the dividing coefficient was for y**0 x**0, we return the poly scaled by its inverse + self.clone() * denominator_dividing_coefficient_inv, + denominator_dividing_coefficient.ct_eq(&CoefficientIndex { y_pow: 0, x_pow: 0 }), + ); + remainder = conditional_select_poly( + remainder, + // If the dividing coefficient was for y**0 x**0, we're able to perfectly divide and there's + // no remainder + Poly::zero(), + denominator_dividing_coefficient.ct_eq(&CoefficientIndex { y_pow: 0, x_pow: 0 }), + ); + + // Clear any junk terms out of the remainder which are less than the denominator + let denominator_leading_coefficient = CoefficientIndex { + y_pow: denominator_leading_coefficient.0.try_into().unwrap(), + x_pow: denominator_leading_coefficient.1.try_into().unwrap(), + }; + if denominator_leading_coefficient != (CoefficientIndex { y_pow: 0, x_pow: 0 }) { + while { + let index = + CoefficientIndex { y_pow: remainder.y_coefficients.len().try_into().unwrap(), x_pow: 0 }; + bool::from( + index.ct_gt(&denominator_leading_coefficient) | + index.ct_eq(&denominator_leading_coefficient), + ) + } { + let popped = remainder.y_coefficients.pop(); + debug_assert_eq!(popped, Some(F::ZERO)); + } + while { + let index = CoefficientIndex { + y_pow: remainder.yx_coefficients.len().try_into().unwrap(), + x_pow: remainder + .yx_coefficients + .last() + .map(|yx_coefficients| yx_coefficients.len()) + .unwrap_or(0) + .try_into() + .unwrap(), + }; + bool::from( + index.ct_gt(&denominator_leading_coefficient) | + index.ct_eq(&denominator_leading_coefficient), + ) + } { + let popped = remainder.yx_coefficients.last_mut().unwrap().pop(); + // This may have been `vec![]` + if let Some(popped) = popped { + debug_assert_eq!(popped, F::ZERO); + } + if remainder.yx_coefficients.last().unwrap().is_empty() { + let popped = remainder.yx_coefficients.pop(); + debug_assert_eq!(popped, Some(vec![])); + } + } + while { + let index = + CoefficientIndex { y_pow: 0, x_pow: remainder.x_coefficients.len().try_into().unwrap() }; + bool::from( + index.ct_gt(&denominator_leading_coefficient) | + index.ct_eq(&denominator_leading_coefficient), + ) + } { + let popped = remainder.x_coefficients.pop(); + debug_assert_eq!(popped, Some(F::ZERO)); + } } - debug_assert_eq!((quotient.clone() * divisor.clone()) + &remainder, self); (quotient, remainder) } } -impl> Rem<&Self> for Poly { +impl + Zeroize + PrimeField> Rem<&Self> for Poly { type Output = Self; fn rem(self, modulus: &Self) -> Self { @@ -328,10 +608,10 @@ impl> Rem<&Self> for Poly { } } -impl> Poly { +impl + Zeroize + PrimeField> Poly { /// Evaluate this polynomial with the specified x/y values. /// - /// Panics on polynomials with terms whose powers exceed 2**64. + /// Panics on polynomials with terms whose powers exceed 2^64. #[must_use] pub fn eval(&self, x: F, y: F) -> F { let mut res = self.zero_coefficient; @@ -358,14 +638,11 @@ impl> Poly { res } - /// Differentiate a polynomial, reduced by a modulus with a leading y term y**2 x**0, by x and y. + /// Differentiate a polynomial, reduced by a modulus with a leading y term y^2 x^0, by x and y. /// - /// This function panics if a y**2 term is present within the polynomial. + /// This function has undefined behavior if unreduced. #[must_use] pub fn differentiate(&self) -> (Poly, Poly) { - assert!(self.y_coefficients.len() <= 1); - assert!(self.yx_coefficients.len() <= 1); - // Differentation by x practically involves: // - Dropping everything without an x component // - Shifting everything down a power of x @@ -391,17 +668,18 @@ impl> Poly { if !self.yx_coefficients.is_empty() { let mut yx_coeffs = self.yx_coefficients[0].clone(); - diff_x.y_coefficients = vec![yx_coeffs.remove(0)]; - diff_x.yx_coefficients = vec![yx_coeffs]; - - let mut prior_x_power = F::from(2); - for yx_coeff in &mut diff_x.yx_coefficients[0] { - *yx_coeff *= prior_x_power; - prior_x_power += F::ONE; + if !yx_coeffs.is_empty() { + diff_x.y_coefficients = vec![yx_coeffs.remove(0)]; + diff_x.yx_coefficients = vec![yx_coeffs]; + + let mut prior_x_power = F::from(2); + for yx_coeff in &mut diff_x.yx_coefficients[0] { + *yx_coeff *= prior_x_power; + prior_x_power += F::ONE; + } } } - diff_x.tidy(); diff_x }; diff --git a/crypto/evrf/divisors/src/tests/mod.rs b/crypto/evrf/divisors/src/tests/mod.rs index bd8de441a..c7c955670 100644 --- a/crypto/evrf/divisors/src/tests/mod.rs +++ b/crypto/evrf/divisors/src/tests/mod.rs @@ -6,6 +6,8 @@ use pasta_curves::{Ep, Eq}; use crate::{DivisorCurve, Poly, new_divisor}; +mod poly; + // Equation 4 in the security proofs fn check_divisor(points: Vec) { // Create the divisor @@ -184,16 +186,16 @@ fn test_subset_sum_to_infinity() { #[test] fn test_divisor_pallas() { - test_divisor::(); test_same_point::(); test_subset_sum_to_infinity::(); + test_divisor::(); } #[test] fn test_divisor_vesta() { - test_divisor::(); test_same_point::(); test_subset_sum_to_infinity::(); + test_divisor::(); } #[test] @@ -229,7 +231,7 @@ fn test_divisor_ed25519() { } } - test_divisor::(); test_same_point::(); test_subset_sum_to_infinity::(); + test_divisor::(); } diff --git a/crypto/evrf/divisors/src/tests/poly.rs b/crypto/evrf/divisors/src/tests/poly.rs index c630a69e5..63f73a96a 100644 --- a/crypto/evrf/divisors/src/tests/poly.rs +++ b/crypto/evrf/divisors/src/tests/poly.rs @@ -1,3 +1,5 @@ +use rand_core::OsRng; + use group::ff::Field; use pasta_curves::Ep; @@ -16,7 +18,24 @@ fn test_poly() { let mut modulus = Poly::zero(); modulus.y_coefficients = vec![one]; - assert_eq!(poly % &modulus, Poly::zero()); + assert_eq!( + poly.clone().div_rem(&modulus).0, + Poly { + y_coefficients: vec![one], + yx_coefficients: vec![], + x_coefficients: vec![], + zero_coefficient: zero + } + ); + assert_eq!( + poly % &modulus, + Poly { + y_coefficients: vec![], + yx_coefficients: vec![], + x_coefficients: vec![], + zero_coefficient: zero + } + ); } { @@ -25,7 +44,7 @@ fn test_poly() { let mut squared = Poly::zero(); squared.y_coefficients = vec![zero, zero, zero, one]; - assert_eq!(poly.clone() * poly.clone(), squared); + assert_eq!(poly.clone() * &poly, squared); } { @@ -37,18 +56,18 @@ fn test_poly() { let mut res = Poly::zero(); res.zero_coefficient = F::from(6u64); - assert_eq!(a.clone() * b.clone(), res); + assert_eq!(a.clone() * &b, res); b.y_coefficients = vec![F::from(4u64)]; res.y_coefficients = vec![F::from(8u64)]; - assert_eq!(a.clone() * b.clone(), res); - assert_eq!(b.clone() * a.clone(), res); + assert_eq!(a.clone() * &b, res); + assert_eq!(b.clone() * &a, res); a.x_coefficients = vec![F::from(5u64)]; res.x_coefficients = vec![F::from(15u64)]; res.yx_coefficients = vec![vec![F::from(20u64)]]; - assert_eq!(a.clone() * b.clone(), res); - assert_eq!(b * a.clone(), res); + assert_eq!(a.clone() * &b, res); + assert_eq!(b * &a, res); // res is now 20xy + 8*y + 15*x + 6 // res ** 2 = @@ -60,7 +79,7 @@ fn test_poly() { vec![vec![F::from(480u64), F::from(600u64)], vec![F::from(320u64), F::from(400u64)]]; squared.x_coefficients = vec![F::from(180u64), F::from(225u64)]; squared.zero_coefficient = F::from(36u64); - assert_eq!(res.clone() * res, squared); + assert_eq!(res.clone() * &res, squared); } }