From 5cf3582dca74f1592c0962e6b4654d62e63553f0 Mon Sep 17 00:00:00 2001 From: Andrew Kirillov <20803092+akirillo@users.noreply.github.com> Date: Tue, 12 Sep 2023 14:50:08 -0700 Subject: [PATCH 1/2] using column-major order for weight matrices in verifier --- src/testing/test_utils.cairo | 108 ++-------- src/testing/tests/utils_tests.cairo | 52 +---- src/verifier.cairo | 309 +++++++++++++++++++--------- src/verifier/types.cairo | 154 +++----------- src/verifier/utils.cairo | 19 +- tests/src/verifier/utils.rs | 35 ---- 6 files changed, 279 insertions(+), 398 deletions(-) diff --git a/src/testing/test_utils.cairo b/src/testing/test_utils.cairo index c7e943a6..f4577fc5 100644 --- a/src/testing/test_utils.cairo +++ b/src/testing/test_utils.cairo @@ -61,29 +61,18 @@ fn get_test_matrix() -> SparseWeightMatrix { // [4, 5, 6, 0], // ] - // Matrix (sparse): + // Matrix (sparse, column-major): // [ - // [(0, 1)], - // [(0, 2), (1, 3)], - // [(0, 4), (1, 5), (2, 6)], + // [(0, 1), (1, 2), (2, 4)], + // [(1, 3), (2, 5)], + // [(2, 6)], // ] - let mut matrix = ArrayTrait::new(); - - let mut row_0 = ArrayTrait::new(); - row_0.append((0, 1.into())); - matrix.append(row_0); - - let mut row_1 = ArrayTrait::new(); - row_1.append((0, 2.into())); - row_1.append((1, 3.into())); - matrix.append(row_1); - - let mut row_2 = ArrayTrait::new(); - row_2.append((0, 4.into())); - row_2.append((1, 5.into())); - row_2.append((2, 6.into())); - matrix.append(row_2); + let matrix = array![ + array![(0, 1.into()), (1, 2.into()), (2, 4.into())], + array![(1, 3.into()), (2, 5.into())], + array![(2, 6.into())], + ]; matrix } @@ -91,73 +80,18 @@ fn get_test_matrix() -> SparseWeightMatrix { fn get_dummy_circuit_weights() -> ( SparseWeightMatrix, SparseWeightMatrix, SparseWeightMatrix, SparseWeightMatrix, SparseWeightVec, ) { - let mut W_L = ArrayTrait::new(); - let mut W_L_0 = ArrayTrait::new(); - W_L_0.append((0_usize, -(1.into()))); - W_L.append(W_L_0); - W_L.append(ArrayTrait::new()); - let mut W_L_2 = ArrayTrait::new(); - W_L_2.append((1_usize, -(1.into()))); - W_L.append(W_L_2); - W_L.append(ArrayTrait::new()); - let mut W_L_4 = ArrayTrait::new(); - W_L_4.append((2_usize, -(1.into()))); - W_L.append(W_L_4); - W_L.append(ArrayTrait::new()); - W_L.append(ArrayTrait::new()); - W_L.append(ArrayTrait::new()); - - let mut W_R = ArrayTrait::new(); - W_R.append(ArrayTrait::new()); - let mut W_R_1 = ArrayTrait::new(); - W_R_1.append((0_usize, -(1.into()))); - W_R.append(W_R_1); - W_R.append(ArrayTrait::new()); - let mut W_R_3 = ArrayTrait::new(); - W_R_3.append((1_usize, -(1.into()))); - W_R.append(W_R_3); - W_R.append(ArrayTrait::new()); - let mut W_R_5 = ArrayTrait::new(); - W_R_5.append((2_usize, -(1.into()))); - W_R.append(W_R_5); - W_R.append(ArrayTrait::new()); - W_R.append(ArrayTrait::new()); - - let mut W_O = ArrayTrait::new(); - W_O.append(ArrayTrait::new()); - W_O.append(ArrayTrait::new()); - W_O.append(ArrayTrait::new()); - W_O.append(ArrayTrait::new()); - let mut W_O_4 = ArrayTrait::new(); - W_O_4.append((0_usize, 1.into())); - W_O.append(W_O_4); - let mut W_O_5 = ArrayTrait::new(); - W_O_5.append((1_usize, 1.into())); - W_O.append(W_O_5); - W_O.append(ArrayTrait::new()); - let mut W_O_7 = ArrayTrait::new(); - W_O_7.append((2_usize, 1.into())); - W_O.append(W_O_7); - - let mut W_V = ArrayTrait::new(); - let mut W_V_0 = ArrayTrait::new(); - W_V_0.append((0_usize, -(1.into()))); - W_V.append(W_V_0); - let mut W_V_1 = ArrayTrait::new(); - W_V_1.append((1_usize, -(1.into()))); - W_V.append(W_V_1); - let mut W_V_2 = ArrayTrait::new(); - W_V_2.append((2_usize, -(1.into()))); - W_V.append(W_V_2); - let mut W_V_3 = ArrayTrait::new(); - W_V_3.append((3_usize, -(1.into()))); - W_V.append(W_V_3); - W_V.append(ArrayTrait::new()); - W_V.append(ArrayTrait::new()); - let mut W_V_6 = ArrayTrait::new(); - W_V_6.append((0_usize, -(1.into()))); - W_V.append(W_V_6); - W_V.append(ArrayTrait::new()); + let W_L = array![array![(0, -1.into())], array![(2, -1.into())], array![(4, -1.into())]]; + + let W_R = array![array![(1, -1.into())], array![(3, -1.into())], array![(5, -1.into())]]; + + let W_O = array![array![(4, 1.into())], array![(5, 1.into())], array![(7, 1.into())]]; + + let W_V = array![ + array![(0, -1.into()), (6, -1.into())], + array![(1, -1.into())], + array![(2, -1.into())], + array![(3, -1.into())], + ]; let mut c = ArrayTrait::new(); c.append((6_usize, 69.into())); diff --git a/src/testing/tests/utils_tests.cairo b/src/testing/tests/utils_tests.cairo index 6a27a9b1..11c162f5 100644 --- a/src/testing/tests/utils_tests.cairo +++ b/src/testing/tests/utils_tests.cairo @@ -4,7 +4,7 @@ use array::ArrayTrait; use renegade_contracts::{ utils::{math::{get_consecutive_powers, elt_wise_mul, binary_exp}, eq::ArrayTPartialEq}, - verifier::{scalar::Scalar, types::{SparseWeightMatrixTrait, SparseWeightVecTrait}}, + verifier::{scalar::Scalar, types::SparseWeightVecTrait}, }; use super::super::{ @@ -58,30 +58,6 @@ fn test_binary_exp_basic() { // | VERIFIER UTILS TESTS | // ------------------------ -#[test] -#[available_gas(100000000)] -fn test_flatten_sparse_weight_matrix_basic() { - let matrix = get_test_matrix(); - - let z = 2.into(); - let width = 4; - - // For z := [2, 2^2, 2^4, ...] we have expected := zW - // ("flattening" W matrix via left-multiplication by z) - let mut expected = ArrayTrait::new(); - // 2*1 + 4*2 + 8*4 = 42 - expected.append(42.into()); - // 4*3 + 8*5 = 52 - expected.append(52.into()); - // 8*6 = 48 - expected.append(48.into()); - expected.append(0.into()); - - let flattened = matrix.flatten(z, width); - - assert(flattened == expected, 'wrong flattened matrix'); -} - #[test] #[available_gas(100000000)] fn test_flatten_column_basic() { @@ -98,32 +74,6 @@ fn test_flatten_column_basic() { assert(flattened == 114.into(), 'wrong flattened column'); } -#[test] -#[available_gas(100000000)] -fn test_get_sparse_weight_column_basic() { - let matrix = get_test_matrix(); - - let col_0 = matrix.get_sparse_weight_column(0); - let col_1 = matrix.get_sparse_weight_column(1); - let col_2 = matrix.get_sparse_weight_column(2); - let col_3 = matrix.get_sparse_weight_column(3); - - let mut expected_col_0 = ArrayTrait::new(); - expected_col_0.append((0, 1.into())); - expected_col_0.append((1, 2.into())); - expected_col_0.append((2, 4.into())); - let mut expected_col_1 = ArrayTrait::new(); - expected_col_1.append((1, 3.into())); - expected_col_1.append((2, 5.into())); - let mut expected_col_2 = ArrayTrait::new(); - expected_col_2.append((2, 6.into())); - let expected_col_3 = ArrayTrait::new(); - - assert(col_0 == expected_col_0, 'wrong column 0'); - assert(col_1 == expected_col_1, 'wrong column 1'); - assert(col_2 == expected_col_2, 'wrong column 1'); - assert(col_3 == expected_col_3, 'wrong column 1'); -} // ----------------------- // | STORAGE UTILS TESTS | // ----------------------- diff --git a/src/verifier.cairo b/src/verifier.cairo index 4a8fda28..a86d38ee 100644 --- a/src/verifier.cairo +++ b/src/verifier.cairo @@ -56,6 +56,7 @@ mod MultiVerifier { EcPoint, ec_point_zero, ec_mul, ec_point_unwrap, ec_point_non_zero, ec_point_new, stark_curve }; + use hash::LegacyHash; use alexandria_data_structures::array_ext::ArrayTraitExt; use alexandria_math::fast_power::fast_power; @@ -71,8 +72,7 @@ mod MultiVerifier { types::{ VerificationJob, VerificationJobTrait, RemainingGenerators, RemainingGeneratorsTrait, VecPoly3, VecPoly3Term, VecPoly3Trait, SparseWeightVec, SparseWeightVecTrait, - SparseWeightMatrix, SparseWeightMatrixTrait, VecSubterm, Proof, CircuitParams, - VecIndices, VecIndicesTrait + SparseWeightMatrix, VecSubterm, Proof, CircuitParams, VecIndices, VecIndicesTrait }, utils::{squeeze_challenge_scalars, calc_delta, get_s_elem}, scalar::{Scalar, ScalarTrait} }; @@ -112,14 +112,22 @@ mod MultiVerifier { q: LegacyMap, /// Mapping from circuit ID -> the witness size for the circuit m: LegacyMap, - /// Mapping from circuit ID -> sparse-reduced matrix of left input weights for the circuit - W_L: LegacyMap>, - /// Mapping from circuit ID -> sparse-reduced matrix of right input weights for the circuit - W_R: LegacyMap>, - /// Mapping from circuit ID -> sparse-reduced matrix of output weights for the circuit - W_O: LegacyMap>, - /// Mapping from circuit ID -> sparse-reduced matrix of witness weights for the circuit - W_V: LegacyMap>, + /// Mapping from index -> column of left input weights. + /// All circuits' W_L columns are stored in this single mapping. For a given circuit, + /// the starting index of its columns is given by the hash of the circuit ID. + W_L: LegacyMap>, + /// Mapping from index -> column of right input weights. + /// All circuits' W_R columns are stored in this single mapping. For a given circuit, + /// the starting index of its columns is given by the hash of the circuit ID. + W_R: LegacyMap>, + /// Mapping from index -> column of output weights. + /// All circuits' W_O columns are stored in this single mapping. For a given circuit, + /// the starting index of its columns is given by the hash of the circuit ID. + W_O: LegacyMap>, + /// Mapping from index -> column of witness weights. + /// All circuits' W_V columns are stored in this single mapping. For a given circuit, + /// the starting index of its columns is given by the hash of the circuit ID. + W_V: LegacyMap>, /// Mapping from circuit ID -> sparse-reduced vector of constants for the circuit c: LegacyMap>, /// Mapping from circuit ID -> boolean indicating if the circuit's size params have been set @@ -213,85 +221,7 @@ mod MultiVerifier { !_is_circuit_fully_parameterized(@self, circuit_id), 'circuit already parameterized' ); - match circuit_params { - CircuitParams::SizeParams(circuit_size_params) => { - // Assert that n_plus = 2^k - assert( - fast_power( - 2, circuit_size_params.k.into(), MAX_USIZE.into() + 1 - ) == circuit_size_params - .n_plus - .into(), - 'n_plus != 2^k' - ); - - self.n.write(circuit_id, circuit_size_params.n); - self.n_plus.write(circuit_id, circuit_size_params.n_plus); - self.k.write(circuit_id, circuit_size_params.k); - self.q.write(circuit_id, circuit_size_params.q); - self.m.write(circuit_id, circuit_size_params.m); - - self.size_params_set.write(circuit_id, true); - }, - CircuitParams::W_L(w_l) => { - // Assert size params have been set - assert(self.size_params_set.read(circuit_id), 'size params not set'); - // Assert weight matrix has `q` rows - assert(w_l.len() == self.q.read(circuit_id), 'W_L has wrong number of rows'); - // Assert weight matrix has correct max number of columns - w_l.assert_width(self.n.read(circuit_id)); - - self.W_L.write(circuit_id, StoreSerdeWrapper { inner: w_l }); - - self.W_L_set.write(circuit_id, true); - }, - CircuitParams::W_R(w_r) => { - // Assert size params have been set - assert(self.size_params_set.read(circuit_id), 'size params not set'); - // Assert weight matrix has `q` rows - assert(w_r.len() == self.q.read(circuit_id), 'W_R has wrong number of rows'); - // Assert weight matrix has correct max number of columns - w_r.assert_width(self.n.read(circuit_id)); - - self.W_R.write(circuit_id, StoreSerdeWrapper { inner: w_r }); - - self.W_R_set.write(circuit_id, true); - }, - CircuitParams::W_O(w_o) => { - // Assert size params have been set - assert(self.size_params_set.read(circuit_id), 'size params not set'); - // Assert weight matrix has `q` rows - assert(w_o.len() == self.q.read(circuit_id), 'W_O has wrong number of rows'); - // Assert weight matrix has correct max number of columns - w_o.assert_width(self.n.read(circuit_id)); - - self.W_O.write(circuit_id, StoreSerdeWrapper { inner: w_o }); - - self.W_O_set.write(circuit_id, true); - }, - CircuitParams::W_V(w_v) => { - // Assert size params have been set - assert(self.size_params_set.read(circuit_id), 'size params not set'); - // Assert weight matrix has `q` rows - assert(w_v.len() == self.q.read(circuit_id), 'W_V has wrong number of rows'); - // Assert weight matrix has correct max number of columns - w_v.assert_width(self.m.read(circuit_id)); - - self.W_V.write(circuit_id, StoreSerdeWrapper { inner: w_v }); - - self.W_V_set.write(circuit_id, true); - }, - CircuitParams::C(c) => { - // Assert size params have been set - assert(self.size_params_set.read(circuit_id), 'size params not set'); - // Assert that `c` vector is not too wide - assert(c.len() <= self.q.read(circuit_id), 'c too wide'); - - self.c.write(circuit_id, StoreSerdeWrapper { inner: c }); - - self.c_set.write(circuit_id, true); - }, - }; + circuit_params.write_circuit_params(ref self, circuit_id); if _is_circuit_fully_parameterized(@self, circuit_id) { self.emit(Event::CircuitParameterized(CircuitParameterized { circuit_id })); @@ -330,10 +260,8 @@ mod MultiVerifier { let n_plus = self.n_plus.read(circuit_id); let k = self.k.read(circuit_id); let q = self.q.read(circuit_id); - let W_L = self.W_L.read(circuit_id).inner; - let W_R = self.W_R.read(circuit_id).inner; - let W_O = self.W_O.read(circuit_id).inner; - let W_V = self.W_V.read(circuit_id).inner; + let W_L = ContractAwareCircuitParamsTrait::read_full_W_L(@self, circuit_id); + let W_R = ContractAwareCircuitParamsTrait::read_full_W_R(@self, circuit_id); let c = self.c.read(circuit_id).inner; if self.feature_flags.read().enable_profiling @@ -951,16 +879,20 @@ mod MultiVerifier { ) -> Scalar { match self { VecSubterm::W_L_flat(()) => { - contract.W_L.read(circuit_id).inner.get_flattened_elem(vec_index, z) + let offset = ContractAwareCircuitParamsTrait::get_matrix_offset(circuit_id); + contract.W_L.read(offset + vec_index.into()).inner.flatten(z) }, VecSubterm::W_R_flat(()) => { - contract.W_R.read(circuit_id).inner.get_flattened_elem(vec_index, z) + let offset = ContractAwareCircuitParamsTrait::get_matrix_offset(circuit_id); + contract.W_R.read(offset + vec_index.into()).inner.flatten(z) }, VecSubterm::W_O_flat(()) => { - contract.W_O.read(circuit_id).inner.get_flattened_elem(vec_index, z) + let offset = ContractAwareCircuitParamsTrait::get_matrix_offset(circuit_id); + contract.W_O.read(offset + vec_index.into()).inner.flatten(z) }, VecSubterm::W_V_flat(()) => { - contract.W_V.read(circuit_id).inner.get_flattened_elem(vec_index, z) + let offset = ContractAwareCircuitParamsTrait::get_matrix_offset(circuit_id); + contract.W_V.read(offset + vec_index.into()).inner.flatten(z) }, VecSubterm::S(()) => { get_s_elem(u, vec_index) @@ -980,4 +912,187 @@ mod MultiVerifier { } } } + + #[generate_trait] + impl CircuitParamsImpl of ContractAwareCircuitParamsTrait { + fn get_matrix_offset(circuit_id: felt252) -> felt252 { + LegacyHash::hash(circuit_id, circuit_id) + } + + fn write_circuit_params( + self: CircuitParams, ref contract: ContractState, circuit_id: felt252 + ) { + match self { + CircuitParams::SizeParams(circuit_size_params) => { + // Assert that n_plus = 2^k + assert( + fast_power( + 2, circuit_size_params.k.into(), MAX_USIZE.into() + 1 + ) == circuit_size_params + .n_plus + .into(), + 'n_plus != 2^k' + ); + + contract.n.write(circuit_id, circuit_size_params.n); + contract.n_plus.write(circuit_id, circuit_size_params.n_plus); + contract.k.write(circuit_id, circuit_size_params.k); + contract.q.write(circuit_id, circuit_size_params.q); + contract.m.write(circuit_id, circuit_size_params.m); + + contract.size_params_set.write(circuit_id, true); + }, + CircuitParams::W_L(w_l) => { + // Assert size params have been set + assert(contract.size_params_set.read(circuit_id), 'size params not set'); + // Assert weight matrix has `n` columns + assert( + w_l.len() == contract.n.read(circuit_id), 'W_L has wrong number of columns' + ); + // // Assert weight matrix has correct max number of columns + // w_l.assert_width(contract.n.read(circuit_id)); + + let offset = ContractAwareCircuitParamsTrait::get_matrix_offset(circuit_id); + let mut i = 0; + loop { + if i == w_l.len() { + break; + } + + contract + .W_L + .write(offset + i.into(), StoreSerdeWrapper { inner: w_l[i].clone() }); + + i += 1; + }; + + contract.W_L_set.write(circuit_id, true); + }, + CircuitParams::W_R(w_r) => { + // Assert size params have been set + assert(contract.size_params_set.read(circuit_id), 'size params not set'); + // Assert weight matrix has `n` columns + assert( + w_r.len() == contract.n.read(circuit_id), 'W_R has wrong number of columns' + ); + // // Assert weight matrix has correct max number of columns + // w_r.assert_width(contract.n.read(circuit_id)); + + let offset = ContractAwareCircuitParamsTrait::get_matrix_offset(circuit_id); + let mut i = 0; + loop { + if i == w_r.len() { + break; + } + + contract + .W_R + .write(offset + i.into(), StoreSerdeWrapper { inner: w_r[i].clone() }); + + i += 1; + }; + + contract.W_R_set.write(circuit_id, true); + }, + CircuitParams::W_O(w_o) => { + // Assert size params have been set + assert(contract.size_params_set.read(circuit_id), 'size params not set'); + // Assert weight matrix has `n` columns + assert( + w_o.len() == contract.n.read(circuit_id), 'W_O has wrong number of columns' + ); + // // Assert weight matrix has correct max number of columns + // w_o.assert_width(contract.n.read(circuit_id)); + + let offset = ContractAwareCircuitParamsTrait::get_matrix_offset(circuit_id); + let mut i = 0; + loop { + if i == w_o.len() { + break; + } + + contract + .W_O + .write(offset + i.into(), StoreSerdeWrapper { inner: w_o[i].clone() }); + + i += 1; + }; + + contract.W_O_set.write(circuit_id, true); + }, + CircuitParams::W_V(w_v) => { + // Assert size params have been set + assert(contract.size_params_set.read(circuit_id), 'size params not set'); + // Assert weight matrix has `m` columns + assert( + w_v.len() == contract.m.read(circuit_id), 'W_V has wrong number of columns' + ); + // // Assert weight matrix has correct max number of columns + // w_v.assert_width(contract.n.read(circuit_id)); + + let offset = ContractAwareCircuitParamsTrait::get_matrix_offset(circuit_id); + let mut i = 0; + loop { + if i == w_v.len() { + break; + } + + contract + .W_V + .write(offset + i.into(), StoreSerdeWrapper { inner: w_v[i].clone() }); + + i += 1; + }; + + contract.W_V_set.write(circuit_id, true); + }, + CircuitParams::C(c) => { + // Assert size params have been set + assert(contract.size_params_set.read(circuit_id), 'size params not set'); + // Assert that `c` vector is not too long + assert(c.len() <= contract.q.read(circuit_id), 'c too long'); + + contract.c.write(circuit_id, StoreSerdeWrapper { inner: c }); + + contract.c_set.write(circuit_id, true); + }, + }; + } + + fn read_full_W_L(contract: @ContractState, circuit_id: felt252) -> SparseWeightMatrix { + let offset = ContractAwareCircuitParamsTrait::get_matrix_offset(circuit_id); + let num_columns = contract.n.read(circuit_id); + let mut w_l = ArrayTrait::new(); + let mut i = 0; + loop { + if i == num_columns { + break; + } + + w_l.append(contract.W_L.read(offset + i.into()).inner); + + i += 1; + }; + + w_l + } + + fn read_full_W_R(contract: @ContractState, circuit_id: felt252) -> SparseWeightMatrix { + let offset = ContractAwareCircuitParamsTrait::get_matrix_offset(circuit_id); + let num_columns = contract.n.read(circuit_id); + let mut w_r = ArrayTrait::new(); + let mut i = 0; + loop { + if i == num_columns { + break; + } + + w_r.append(contract.W_R.read(offset + i.into()).inner); + + i += 1; + }; + + w_r + } + } } diff --git a/src/verifier/types.cairo b/src/verifier/types.cairo index 38cfa6b3..c3c7c3a7 100644 --- a/src/verifier/types.cairo +++ b/src/verifier/types.cairo @@ -322,6 +322,7 @@ struct Proof { // (index, weight) entries in a sparse-reduced vector are expected // to be sorted by increasing index type SparseWeightVec = Array<(usize, Scalar)>; +// This matrix is assumed to be in *column-major* form type SparseWeightMatrix = Array; type SparseWeightVecSpan = Span<(usize, Scalar)>; @@ -353,131 +354,34 @@ impl SparseWeightVecImpl of SparseWeightVecTrait { } } -#[generate_trait] -impl SparseWeightMatrixImpl of SparseWeightMatrixTrait { - /// "Flattens" the matrix into a `width`-length vector by computing - /// [z, z^2, ..., z^height] * W_{L, R, O, V} (vector-matrix multiplication) - fn flatten(self: @SparseWeightMatrix, z: Scalar, width: usize) -> Array { - let matrix: SparseWeightMatrixSpan = self.deep_span(); - - // Can't set an item at a given index in an array, can only append, - // so we use a dict here - let mut flattened_dict: Felt252Dict> = Default::default(); - - // Loop over rows first, then entries - // Since matrices are sparse and in row-major form, this ensure that we only loop - // once per non-zero entry - let mut row_index: usize = 0; - loop { - if row_index == matrix.len() { - break; - }; - - let mut row = *matrix.at(row_index); - let mut entry_index = 0; - let z_i = binary_exp(z, (row_index + 1).into()); - loop { - if entry_index == row.len() { - break; - }; - - let (col_index, weight) = *row.at(entry_index); - let col_index_felt = col_index.into(); - let mut scalar = get_scalar_or_zero(ref flattened_dict, col_index_felt); - - // z vector starts at z^1, i.e. is [z, z^2, ..., z^q] - scalar += z_i * weight; - insert_scalar(ref flattened_dict, col_index_felt, scalar); - - entry_index += 1; - }; - - row_index += 1; - }; - - let mut flattened_vec = ArrayTrait::new(); - let mut col_index = 0; - loop { - if col_index == width { - break; - }; - - flattened_vec.append(get_scalar_or_zero(ref flattened_dict, col_index.into())); - col_index += 1; - }; - - flattened_vec - } - - /// Extracts a column from the matrix in the form of a sparse-reduced vector - fn get_sparse_weight_column(self: @SparseWeightMatrix, col_index: usize) -> SparseWeightVec { - let matrix: SparseWeightMatrixSpan = self.deep_span(); - let mut column = ArrayTrait::new(); - let mut row_index = 0; - loop { - if row_index == matrix.len() { - break; - }; - - let mut row = *matrix.at(row_index); - let mut entry_index = 0; - loop { - // Break early if we've passed the desired column's index. - // This relies on the assumption that sparse weight vector entries - // are sorted by increasing index. - if entry_index > col_index || entry_index == row.len() { - break; - }; - - let (current_index, current_weight) = *row.at(entry_index); - if current_index == col_index { - column.append((row_index, current_weight)); - break; - }; - - entry_index += 1; - }; - - row_index += 1; - }; - - column - } - - /// Gets the element at `index` in the flattened matrix - fn get_flattened_elem(self: @SparseWeightMatrix, index: usize, z: Scalar) -> Scalar { - // Pop column `index` from `matrix` as a `SparseWeightVec` - let column = self.get_sparse_weight_column(index); - - // Flatten the column using `z` - column.flatten(z) - } - - /// Asserts that the matrix has a maximum width of `width`. - /// This asserts both that each row has at most `width` entries, and that - /// the last entry in each row has an index less than `width`. - /// This relies on the assumption that sparse weight vector entries are sorted - /// by increasing index. - fn assert_width(self: @SparseWeightMatrix, width: usize) { - let matrix: SparseWeightMatrixSpan = self.deep_span(); - let mut row_index = 0; - loop { - if row_index == matrix.len() { - break; - }; - - let row = *matrix.at(row_index); - let row_len = row.len(); - assert(row_len <= width, 'row has too many entries'); - if row_len > 0 { - let (last_index, _) = *row.at(row.len() - 1); - assert(last_index <= width, 'last index in row too big'); - } - - row_index += 1; - }; - } -} +// #[generate_trait] +// impl SparseWeightMatrixImpl of SparseWeightMatrixTrait { +// /// Asserts that the matrix has a maximum width of `width`. +// /// This asserts both that each row has at most `width` entries, and that +// /// the last entry in each row has an index less than `width`. +// /// This relies on the assumption that sparse weight vector entries are sorted +// /// by increasing index. +// // TODO: Swap for `assert_height`? +// fn assert_width(self: @SparseWeightMatrix, width: usize) { +// let matrix: SparseWeightMatrixSpan = self.deep_span(); +// let mut row_index = 0; +// loop { +// if row_index == matrix.len() { +// break; +// }; + +// let row = *matrix.at(row_index); +// let row_len = row.len(); +// assert(row_len <= width, 'row has too many entries'); +// if row_len > 0 { +// let (last_index, _) = *row.at(row.len() - 1); +// assert(last_index <= width, 'last index in row too big'); +// } + +// row_index += 1; +// }; +// } +// } /// The public sizing parameters of the circuit #[derive(Drop, Clone, Serde, PartialEq)] diff --git a/src/verifier/utils.cairo b/src/verifier/utils.cairo index 33ac8d13..678b4779 100644 --- a/src/verifier/utils.cairo +++ b/src/verifier/utils.cairo @@ -10,7 +10,7 @@ use renegade_contracts::{ }; use super::{ - types::{SparseWeightMatrix, SparseWeightMatrixTrait, Proof}, scalar::{Scalar, ScalarTrait}, + types::{SparseWeightMatrix, SparseWeightVecTrait, Proof}, scalar::{Scalar, ScalarTrait}, }; @@ -129,6 +129,7 @@ fn squeeze_challenge_scalars( // TODO: Can make this more efficient by pre-computing all powers of z & selectively using in dot products // (will need all powers of z across both of W_L, W_R) // TODO: Technically, only need powers of y for which the corresponding column of W_R & W_L is non-zero +// (would require writing a dot product impl over sparse vectors) fn calc_delta( n: usize, y_inv_powers_to_n: Span, @@ -137,8 +138,20 @@ fn calc_delta( W_R: @SparseWeightMatrix ) -> Scalar { // Flatten W_L, W_R using z - let w_L_flat = W_L.flatten(z, n); - let w_R_flat = W_R.flatten(z, n); + let mut w_L_flat = ArrayTrait::new(); + let mut w_R_flat = ArrayTrait::new(); + let mut col_index = 0; + loop { + if col_index == n { + break; + }; + + let w_L_col = W_L[col_index]; + w_L_flat.append(w_L_col.flatten(z)); + + let w_R_col = W_R[col_index]; + w_R_flat.append(w_R_col.flatten(z)); + }; // \delta = dot(elt_wise_mul(y_inv_powers_to_n, w_R_flat.span()).span(), w_L_flat.span()) diff --git a/tests/src/verifier/utils.rs b/tests/src/verifier/utils.rs index e0c5bcd2..d1336120 100644 --- a/tests/src/verifier/utils.rs +++ b/tests/src/verifier/utils.rs @@ -255,38 +255,3 @@ pub fn prep_dummy_circuit_verifier(verifier: &mut Verifier, witness_commitments: debug!("Applying dummy circuit constraints on verifier..."); apply_dummy_circuit_constraints(a_var, b_var, x_var, y_var, verifier); } - -// fn get_dummy_circuit_weights() -> CircuitWeights { -// let mut transcript = HashChainTranscript::new(TRANSCRIPT_SEED.as_bytes()); -// let pc_gens = PedersenGens::default(); -// let mut prover = Prover::new(&pc_gens, &mut transcript); - -// let mut rng = thread_rng(); - -// let (_, a_var) = prover.commit(Scalar::random(&mut rng), Scalar::random(&mut rng)); -// let (_, b_var) = prover.commit(Scalar::random(&mut rng), Scalar::random(&mut rng)); -// let (_, x_var) = prover.commit(Scalar::random(&mut rng), Scalar::random(&mut rng)); -// let (_, y_var) = prover.commit(Scalar::random(&mut rng), Scalar::random(&mut rng)); - -// apply_dummy_circuit_constraints(a_var, b_var, x_var, y_var, &mut prover); - -// prover.get_weights() -// } - -// fn get_dummy_circuit_params() -> [CircuitParams; NUM_CIRCUITS] { -// let circuit_weights = get_dummy_circuit_weights(); -// [ -// CircuitParams::SizeParams(CircuitSizeParams { -// n: DUMMY_CIRCUIT_N, -// n_plus: DUMMY_CIRCUIT_N_PLUS, -// k: DUMMY_CIRCUIT_K, -// q: DUMMY_CIRCUIT_Q, -// m: DUMMY_CIRCUIT_M, -// }), -// CircuitParams::Wl(circuit_weights.w_l), -// CircuitParams::Wr(circuit_weights.w_r), -// CircuitParams::Wo(circuit_weights.w_o), -// CircuitParams::Wv(circuit_weights.w_v), -// CircuitParams::C(circuit_weights.c), -// ] -// } From 44d8d5e919578420583e27104cec6d1f2bf9692b Mon Sep 17 00:00:00 2001 From: Andrew Kirillov <20803092+akirillo@users.noreply.github.com> Date: Tue, 12 Sep 2023 18:15:36 -0700 Subject: [PATCH 2/2] using correct types from mpc-bulletproof branch --- Cargo.lock | 32 +++++++-------- src/verifier/scalar.cairo | 67 +++++++++++++++---------------- src/verifier/utils.cairo | 2 + tests/Cargo.toml | 12 +++--- tests/src/profiling/utils.rs | 6 +-- tests/src/utils.rs | 18 ++++----- tests/src/verifier_utils/utils.rs | 6 +-- 7 files changed, 72 insertions(+), 71 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a61d3760..fc29ac96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1298,7 +1298,7 @@ dependencies = [ [[package]] name = "circuit-macros" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "itertools 0.10.5", "mpc-bulletproof", @@ -1310,7 +1310,7 @@ dependencies = [ [[package]] name = "circuit-types" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "ark-ec", "ark-ff", @@ -1337,7 +1337,7 @@ dependencies = [ [[package]] name = "circuits" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "ark-crypto-primitives", "ark-ff", @@ -1478,7 +1478,7 @@ dependencies = [ [[package]] name = "common" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "base64 0.13.1", "bimap", @@ -1507,7 +1507,7 @@ dependencies = [ [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "base64 0.13.1", "clap 3.2.25", @@ -1548,7 +1548,7 @@ checksum = "28c122c3980598d243d63d9a704629a2d748d101f278052ff068be5a4423ab6f" [[package]] name = "constants" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" [[package]] name = "convert_case" @@ -2413,7 +2413,7 @@ dependencies = [ [[package]] name = "external-api" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "circuit-types", "common", @@ -3444,7 +3444,7 @@ dependencies = [ [[package]] name = "gossip-api" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "circuit-types", "common", @@ -3946,7 +3946,7 @@ checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" [[package]] name = "job-types" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "circuit-types", "circuits", @@ -4540,7 +4540,7 @@ checksum = "4519a88847ba2d5ead3dc53f1060ec6a571de93f325d9c5c4968147382b1cbc3" [[package]] name = "mpc-bulletproof" version = "0.1.0" -source = "git+https://github.com/renegade-fi/mpc-bulletproof.git#0c579bdd3734026f2a096536b52ee6f5863fe3f3" +source = "git+https://github.com/renegade-fi/mpc-bulletproof.git?branch=andrew/column-major-weights#1743b6e95c86417858d0bf68ba6a9d8a9acfbec6" dependencies = [ "ark-ff", "ark-serialize", @@ -5611,7 +5611,7 @@ checksum = "c707298afce11da2efef2f600116fa93ffa7a032b5d7b628aa17711ec81383ca" [[package]] name = "renegade-crypto" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "ark-crypto-primitives", "ark-ff", @@ -6449,7 +6449,7 @@ dependencies = [ [[package]] name = "starknet-client" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "ark-ff", "circuit-types", @@ -6670,7 +6670,7 @@ dependencies = [ [[package]] name = "state" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "base64 0.13.1", "circuit-types", @@ -6788,7 +6788,7 @@ dependencies = [ [[package]] name = "system-bus" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "bus", "common", @@ -6845,7 +6845,7 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" [[package]] name = "test-helpers" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "async-trait", "constants", @@ -7481,7 +7481,7 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "util" version = "0.1.0" -source = "git+https://github.com/renegade-fi/renegade.git?branch=starknet-client-consolidation#895bd7db94296fe88af2e5248eb2fbeff3383449" +source = "git+https://github.com/renegade-fi/renegade.git?branch=andrew/column-major-weights#0f7c63fe9675ba8dad83b082960113030d8216a7" dependencies = [ "chrono", "env_logger 0.9.3", diff --git a/src/verifier/scalar.cairo b/src/verifier/scalar.cairo index 9b561f54..cb45ddd3 100644 --- a/src/verifier/scalar.cairo +++ b/src/verifier/scalar.cairo @@ -23,21 +23,20 @@ struct Scalar { impl ScalarImpl of ScalarTrait { fn inverse(self: @Scalar) -> Scalar { // Safe to unwrap b/c scalar field is smaller than base field - // let inner = mult_inverse((*self.inner).into(), stark_curve::ORDER.into()) - // .try_into() - // .unwrap(); - let inner = felt252_div(1, (*self.inner).try_into().unwrap()); + let inner = mult_inverse((*self.inner).into(), stark_curve::ORDER.into()) + .try_into() + .unwrap(); + // let inner = felt252_div(1, (*self.inner).try_into().unwrap()); Scalar { inner } } fn pow(self: @Scalar, exponent: u256) -> Scalar { // Safe to unwrap b/c scalar field is smaller than base field - // let inner = pow_mod((*self.inner).into(), exponent, stark_curve::ORDER.into()) - // .try_into() - // .unwrap(); - // Scalar { inner } - - binary_exp(*self, exponent) + let inner = pow_mod((*self.inner).into(), exponent, stark_curve::ORDER.into()) + .try_into() + .unwrap(); + Scalar { inner } + // binary_exp(*self, exponent) } } @@ -52,10 +51,10 @@ impl ScalarImpl of ScalarTrait { impl ScalarAdd of Add { fn add(lhs: Scalar, rhs: Scalar) -> Scalar { // Safe to unwrap b/c scalar field is smaller than base field - // let inner = add_mod(lhs.inner.into(), rhs.inner.into(), stark_curve::ORDER.into()) - // .try_into() - // .unwrap(); - let inner = (lhs.inner + rhs.inner); + let inner = add_mod(lhs.inner.into(), rhs.inner.into(), stark_curve::ORDER.into()) + .try_into() + .unwrap(); + // let inner = (lhs.inner + rhs.inner); Scalar { inner } } } @@ -73,10 +72,10 @@ impl ScalarAddEq of AddEq { impl ScalarSub of Sub { fn sub(lhs: Scalar, rhs: Scalar) -> Scalar { // Safe to unwrap b/c scalar field is smaller than base field - // let inner = sub_mod(lhs.inner.into(), rhs.inner.into(), stark_curve::ORDER.into()) - // .try_into() - // .unwrap(); - let inner = (lhs.inner - rhs.inner); + let inner = sub_mod(lhs.inner.into(), rhs.inner.into(), stark_curve::ORDER.into()) + .try_into() + .unwrap(); + // let inner = (lhs.inner - rhs.inner); Scalar { inner } } } @@ -94,10 +93,10 @@ impl ScalarSubEq of SubEq { impl ScalarMul of Mul { fn mul(lhs: Scalar, rhs: Scalar) -> Scalar { // Safe to unwrap b/c scalar field is smaller than base field - // let inner = mult_mod(lhs.inner.into(), rhs.inner.into(), stark_curve::ORDER.into()) - // .try_into() - // .unwrap(); - let inner = (lhs.inner * rhs.inner); + let inner = mult_mod(lhs.inner.into(), rhs.inner.into(), stark_curve::ORDER.into()) + .try_into() + .unwrap(); + // let inner = (lhs.inner * rhs.inner); Scalar { inner } } } @@ -117,10 +116,10 @@ impl ScalarDiv of Div { // Under the hood, this is implemented as // lhs * rhs.inverse() // Safe to unwrap b/c scalar field is smaller than base field - // let inner = div_mod(lhs.inner.into(), rhs.inner.into(), stark_curve::ORDER.into()) - // .try_into() - // .unwrap(); - let inner = felt252_div(lhs.inner, rhs.inner.try_into().unwrap()); + let inner = div_mod(lhs.inner.into(), rhs.inner.into(), stark_curve::ORDER.into()) + .try_into() + .unwrap(); + // let inner = felt252_div(lhs.inner, rhs.inner.try_into().unwrap()); Scalar { inner } } } @@ -138,8 +137,8 @@ impl ScalarDivEq of DivEq { impl ScalarNeg of Neg { fn neg(a: Scalar) -> Scalar { // Safe to unwrap b/c scalar field is smaller than base field - // let inner = add_inverse_mod(a.inner.into(), stark_curve::ORDER.into()).try_into().unwrap(); - let inner = -a.inner; + let inner = add_inverse_mod(a.inner.into(), stark_curve::ORDER.into()).try_into().unwrap(); + // let inner = -a.inner; Scalar { inner } } } @@ -154,8 +153,8 @@ impl ScalarNeg of Neg { impl U256IntoScalar of Into { fn into(self: u256) -> Scalar { - // let inner_u256 = self % stark_curve::ORDER.into(); - let inner_u256 = self % BASE_FIELD_ORDER; + let inner_u256 = self % stark_curve::ORDER.into(); + // let inner_u256 = self % BASE_FIELD_ORDER; // Safe to unwrap b/c scalar field is smaller than base field Scalar { inner: inner_u256.try_into().unwrap() } } @@ -163,10 +162,10 @@ impl U256IntoScalar of Into { impl FeltIntoScalar> of Into { fn into(self: T) -> Scalar { - // let inner_felt: felt252 = self.into(); - // let inner_u256: u256 = inner_felt.into(); - // inner_u256.into() - Scalar { inner: self.into() } + let inner_felt: felt252 = self.into(); + let inner_u256: u256 = inner_felt.into(); + inner_u256.into() + // Scalar { inner: self.into() } } } diff --git a/src/verifier/utils.cairo b/src/verifier/utils.cairo index 678b4779..0542b60b 100644 --- a/src/verifier/utils.cairo +++ b/src/verifier/utils.cairo @@ -151,6 +151,8 @@ fn calc_delta( let w_R_col = W_R[col_index]; w_R_flat.append(w_R_col.flatten(z)); + + col_index += 1; }; // \delta = diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 087bb185..bae6acd4 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -28,10 +28,10 @@ starknet = { workspace = true } katana-core = { git = "https://github.com/renegade-fi/dojo.git", branch = "renegade-testing" } dojo-test-utils = { git = "https://github.com/renegade-fi/dojo.git", branch = "renegade-testing" } mpc-stark = { workspace = true } -mpc-bulletproof = { git = "https://github.com/renegade-fi/mpc-bulletproof.git", features = ["integration_test"] } +mpc-bulletproof = { git = "https://github.com/renegade-fi/mpc-bulletproof.git", branch="andrew/column-major-weights", features = ["integration_test"] } merlin = { git = "https://github.com/renegade-fi/merlin.git" } -renegade-crypto = { git = "https://github.com/renegade-fi/renegade.git", branch = "starknet-client-consolidation" } -circuits = { git = "https://github.com/renegade-fi/renegade.git", branch = "starknet-client-consolidation", features = ["test_helpers"] } -circuit-types = { git = "https://github.com/renegade-fi/renegade.git", branch = "starknet-client-consolidation" } -test-helpers = { git = "https://github.com/renegade-fi/renegade.git", branch = "starknet-client-consolidation" } -starknet-client = { git = "https://github.com/renegade-fi/renegade.git", branch = "starknet-client-consolidation" } +renegade-crypto = { git = "https://github.com/renegade-fi/renegade.git", branch = "andrew/column-major-weights" } +circuits = { git = "https://github.com/renegade-fi/renegade.git", branch = "andrew/column-major-weights", features = ["test_helpers"] } +circuit-types = { git = "https://github.com/renegade-fi/renegade.git", branch = "andrew/column-major-weights" } +test-helpers = { git = "https://github.com/renegade-fi/renegade.git", branch = "andrew/column-major-weights" } +starknet-client = { git = "https://github.com/renegade-fi/renegade.git", branch = "andrew/column-major-weights" } diff --git a/tests/src/profiling/utils.rs b/tests/src/profiling/utils.rs index 6ba773b7..0d798c4f 100644 --- a/tests/src/profiling/utils.rs +++ b/tests/src/profiling/utils.rs @@ -29,7 +29,7 @@ use dojo_test_utils::sequencer::TestSequencer; use eyre::{eyre, Result}; use merlin::HashChainTranscript; use mpc_bulletproof::{ - r1cs::{Prover, R1CSProof, SparseReducedMatrix, Verifier}, + r1cs::{Prover, R1CSProof, SparseWeightMatrix, Verifier}, BulletproofGens, PedersenGens, }; use mpc_stark::algebra::{scalar::Scalar, stark_curve::StarkPoint}; @@ -201,8 +201,8 @@ pub async fn invoke_calc_delta( n: FieldElement, y_inv_powers_to_n: &Vec, z: Scalar, - w_l: SparseReducedMatrix, - w_r: SparseReducedMatrix, + w_l: SparseWeightMatrix, + w_r: SparseWeightMatrix, ) -> Result<()> { let calldata = iter::once(n) .chain(y_inv_powers_to_n.to_calldata()) diff --git a/tests/src/utils.rs b/tests/src/utils.rs index a21aa6d8..1a0d49dd 100644 --- a/tests/src/utils.rs +++ b/tests/src/utils.rs @@ -34,7 +34,7 @@ use merlin::HashChainTranscript; use mpc_bulletproof::{ r1cs::{ CircuitWeights, ConstraintSystem, LinearCombination, Prover, R1CSProof, - RandomizableConstraintSystem, SparseReducedMatrix, SparseWeightRow, Variable, + RandomizableConstraintSystem, SparseWeightMatrix, SparseWeightVec, Variable, }, r1cs_mpc::R1CSError, BulletproofGens, PedersenGens, @@ -658,15 +658,15 @@ pub enum CircuitParams { /// Sizing parameters for the circuit SizeParams(CircuitSizeParams), /// Sparse-reduced matrix of left input weights for the circuit - Wl(SparseReducedMatrix), + Wl(SparseWeightMatrix), /// Sparse-reduced matrix of right input weights for the circuit - Wr(SparseReducedMatrix), + Wr(SparseWeightMatrix), /// Sparse-reduced matrix of output weights for the circuit - Wo(SparseReducedMatrix), + Wo(SparseWeightMatrix), /// Sparse-reduced matrix of witness weights for the circuit - Wv(SparseReducedMatrix), + Wv(SparseWeightMatrix), /// Sparse-reduced vector of constants for the circuit - C(SparseWeightRow), + C(SparseWeightVec), } pub struct NewWalletArgs { @@ -766,7 +766,7 @@ impl CalldataSerializable for [T; N] { } } -// `(usize, Scalar)` represents an entry in a `SparseWeightRow` +// `(usize, Scalar)` represents an entry in a `SparseWeightVec` impl CalldataSerializable for (usize, Scalar) { fn to_calldata(&self) -> Vec { self.0 @@ -777,13 +777,13 @@ impl CalldataSerializable for (usize, Scalar) { } } -impl CalldataSerializable for SparseWeightRow { +impl CalldataSerializable for SparseWeightVec { fn to_calldata(&self) -> Vec { self.0.to_calldata() } } -impl CalldataSerializable for SparseReducedMatrix { +impl CalldataSerializable for SparseWeightMatrix { fn to_calldata(&self) -> Vec { self.0.to_calldata() } diff --git a/tests/src/verifier_utils/utils.rs b/tests/src/verifier_utils/utils.rs index f33a9128..66dacc5b 100644 --- a/tests/src/verifier_utils/utils.rs +++ b/tests/src/verifier_utils/utils.rs @@ -3,7 +3,7 @@ use eyre::{eyre, Result}; use merlin::HashChainTranscript; use mpc_bulletproof::{ inner_product, - r1cs::{R1CSProof, SparseReducedMatrix, Verifier}, + r1cs::{R1CSProof, SparseWeightMatrix, Verifier}, InnerProductProof, TranscriptProtocol, }; use mpc_stark::algebra::{scalar::Scalar, stark_curve::StarkPoint}; @@ -106,8 +106,8 @@ pub async fn calc_delta( account: &ScriptAccount, y_inv_powers_to_n: &Vec, z: Scalar, - w_l: SparseReducedMatrix, - w_r: SparseReducedMatrix, + w_l: SparseWeightMatrix, + w_r: SparseWeightMatrix, ) -> Result { let calldata = iter::once(FieldElement::from(DUMMY_CIRCUIT_N)) .chain(y_inv_powers_to_n.to_calldata())