diff --git a/src/algebra/poly/authenticated_poly.rs b/src/algebra/poly/authenticated_poly.rs index 85cdf9f..ffe674c 100644 --- a/src/algebra/poly/authenticated_poly.rs +++ b/src/algebra/poly/authenticated_poly.rs @@ -288,6 +288,41 @@ impl Mul<&AuthenticatedDensePoly> for &AuthenticatedDensePoly< } } +// --- Scalar Multiplication --- // + +impl Mul<&Scalar> for &AuthenticatedDensePoly { + type Output = AuthenticatedDensePoly; + + fn mul(self, rhs: &Scalar) -> Self::Output { + let new_coeffs = self.coeffs.iter().map(|coeff| coeff * rhs).collect_vec(); + AuthenticatedDensePoly::from_coeffs(new_coeffs) + } +} +impl_borrow_variants!(AuthenticatedDensePoly, Mul, mul, *, Scalar, C: CurveGroup); +impl_commutative!(AuthenticatedDensePoly, Mul, mul, *, Scalar, C: CurveGroup); + +impl Mul<&ScalarResult> for &AuthenticatedDensePoly { + type Output = AuthenticatedDensePoly; + + fn mul(self, rhs: &ScalarResult) -> Self::Output { + let new_coeffs = self.coeffs.iter().map(|coeff| coeff * rhs).collect_vec(); + AuthenticatedDensePoly::from_coeffs(new_coeffs) + } +} +impl_borrow_variants!(AuthenticatedDensePoly, Mul, mul, *, ScalarResult, C: CurveGroup); +impl_commutative!(AuthenticatedDensePoly, Mul, mul, *, ScalarResult, C: CurveGroup); + +impl Mul<&AuthenticatedScalarResult> for &AuthenticatedDensePoly { + type Output = AuthenticatedDensePoly; + + fn mul(self, rhs: &AuthenticatedScalarResult) -> Self::Output { + let new_coeffs = self.coeffs.iter().map(|coeff| coeff * rhs).collect_vec(); + AuthenticatedDensePoly::from_coeffs(new_coeffs) + } +} +impl_borrow_variants!(AuthenticatedDensePoly, Mul, mul, *, AuthenticatedScalarResult, C: CurveGroup); +impl_commutative!(AuthenticatedDensePoly, Mul, mul, *, AuthenticatedScalarResult, C: CurveGroup); + // --- Division --- // /// Given a public divisor b(x) and shared dividend a(x) = a_1(x) + a_2(x) for party shares a_1, a_2 /// We can divide each share locally to obtain a secret sharing of \floor{a(x) / b(x)} @@ -541,18 +576,82 @@ mod test { assert_eq!(res.unwrap(), expected_res); } + /// Tests multiplying by a public constant scalar + #[tokio::test] + async fn test_scalar_mul_constant() { + let mut rng = thread_rng(); + let poly = random_poly(DEGREE_BOUND); + let scaling_factor = Scalar::random(&mut rng); + + let expected_res = &poly * scaling_factor.inner(); + + let (res, _) = execute_mock_mpc(|fabric| { + let poly = poly.clone(); + async move { + let shared_poly = share_poly(poly, PARTY0, &fabric); + (shared_poly * scaling_factor).open_authenticated().await + } + }) + .await; + + assert!(res.is_ok()); + assert_eq!(res.unwrap(), expected_res); + } + + /// Tests multiplying by a public result + #[tokio::test] + async fn test_scalar_mul_public() { + let mut rng = thread_rng(); + let poly = random_poly(DEGREE_BOUND); + let scaling_factor = Scalar::random(&mut rng); + + let expected_res = &poly * scaling_factor.inner(); + + let (res, _) = execute_mock_mpc(|fabric| { + let poly = poly.clone(); + async move { + let shared_poly = share_poly(poly, PARTY0, &fabric); + let scaling_factor = fabric.allocate_scalar(scaling_factor); + + (shared_poly * scaling_factor).open_authenticated().await + } + }) + .await; + + assert!(res.is_ok()); + assert_eq!(res.unwrap(), expected_res); + } + + /// Tests multiplying by a shared scalar + #[tokio::test] + async fn test_scalar_mul() { + let mut rng = thread_rng(); + let poly = random_poly(DEGREE_BOUND); + let scaling_factor = Scalar::random(&mut rng); + + let expected_res = &poly * scaling_factor.inner(); + + let (res, _) = execute_mock_mpc(|fabric| { + let poly = poly.clone(); + async move { + let shared_poly = share_poly(poly, PARTY0, &fabric); + let scaling_factor = fabric.share_scalar(scaling_factor, PARTY0); + + (shared_poly * scaling_factor).open_authenticated().await + } + }) + .await; + + assert!(res.is_ok()); + assert_eq!(res.unwrap(), expected_res); + } + /// Tests dividing a shared polynomial by a public polynomial #[tokio::test] async fn test_div_polynomial_public() { let poly1 = random_poly(DEGREE_BOUND); let poly2 = random_poly(DEGREE_BOUND); - let (poly1, poly2) = if poly1.degree() < poly2.degree() { - (poly2, poly1) - } else { - (poly1, poly2) - }; - let expected_res = &poly1 / &poly2; let (res, _) = execute_mock_mpc(|fabric| { diff --git a/src/algebra/poly/poly.rs b/src/algebra/poly/poly.rs index 08863fc..23da163 100644 --- a/src/algebra/poly/poly.rs +++ b/src/algebra/poly/poly.rs @@ -245,6 +245,7 @@ impl Mul<&DensePolynomialResult> for &DensePolynomialResult impl_borrow_variants!(DensePolynomialResult, Mul, mul, *, DensePolynomialResult, C: CurveGroup); // --- Scalar Multiplication --- // + impl Mul<&Scalar> for &DensePolynomialResult { type Output = DensePolynomialResult;