Skip to content

Commit

Permalink
chore: replace generator with hint in RLP decoding (#275)
Browse files Browse the repository at this point in the history
  • Loading branch information
hidenori-shinohara authored Nov 2, 2023
1 parent 09c0fc3 commit 598e609
Showing 1 changed file with 94 additions and 137 deletions.
231 changes: 94 additions & 137 deletions plonky2x/core/src/frontend/eth/rlp/builder.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
use std::marker::PhantomData;

use curta::math::field::Field;
use curta::math::prelude::PrimeField64;
use ethers::types::Bytes;
use log::info;
use num::bigint::ToBigInt;
use num::BigInt;
use plonky2::iop::generator::{GeneratedValues, SimpleGenerator};
use plonky2::iop::target::Target;
use plonky2::iop::witness::PartitionWitness;
use plonky2::plonk::circuit_data::CommonCircuitData;
use plonky2::util::serialization::{Buffer, IoResult};
use serde::{Deserialize, Serialize};

use crate::frontend::hint::simple::hint::Hint;
use crate::prelude::{
ArrayVariable, BoolVariable, ByteVariable, CircuitBuilder, CircuitVariable, PlonkParameters,
Variable,
ArrayVariable, BoolVariable, ByteVariable, CircuitBuilder, PlonkParameters, ValueStream,
Variable, VariableStream,
};

pub fn bool_to_u32(b: bool) -> u32 {
Expand Down Expand Up @@ -180,113 +175,39 @@ pub fn verify_decoded_list<const L: usize, const M: usize>(
assert!(claim_poly == encoding_poly);
}

#[derive(Debug, Clone)]
pub struct RLPDecodeListGenerator<
L: PlonkParameters<D>,
const D: usize,
const ENCODING_LEN: usize,
const LIST_LEN: usize,
const ELEMENT_LEN: usize,
> {
encoding: ArrayVariable<ByteVariable, ENCODING_LEN>,
length: Variable,
finish: BoolVariable,
pub decoded_list: ArrayVariable<ArrayVariable<ByteVariable, ELEMENT_LEN>, LIST_LEN>,
pub decoded_element_lens: ArrayVariable<Variable, LIST_LEN>,
pub len_decoded_list: Variable,
_phantom: PhantomData<L>,
}

impl<
L: PlonkParameters<D>,
const D: usize,
const ENCODING_LEN: usize,
const LIST_LEN: usize,
const ELEMENT_LEN: usize,
> RLPDecodeListGenerator<L, D, ENCODING_LEN, LIST_LEN, ELEMENT_LEN>
{
pub fn new(
builder: &mut CircuitBuilder<L, D>,
encoding: ArrayVariable<ByteVariable, ENCODING_LEN>,
length: Variable,
finish: BoolVariable,
) -> Self {
let decoded_list =
builder.init::<ArrayVariable<ArrayVariable<ByteVariable, ELEMENT_LEN>, LIST_LEN>>();
let decoded_element_lens = builder.init::<ArrayVariable<Variable, LIST_LEN>>();
let len_decoded_list = builder.init::<Variable>();
Self {
encoding,
length,
finish,
decoded_list,
decoded_element_lens,
len_decoded_list,
_phantom: PhantomData,
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct DecodeHint<const ENCODING_LEN: usize, const LIST_LEN: usize, const ELEMENT_LEN: usize> {}
impl<
L: PlonkParameters<D>,
const D: usize,
const ENCODING_LEN: usize,
const LIST_LEN: usize,
const ELEMENT_LEN: usize,
> SimpleGenerator<L::Field, D>
for RLPDecodeListGenerator<L, D, ENCODING_LEN, LIST_LEN, ELEMENT_LEN>
> Hint<L, D> for DecodeHint<ENCODING_LEN, LIST_LEN, ELEMENT_LEN>
{
fn id(&self) -> String {
"RLPDecodeListGenerator".to_string()
}
fn hint(&self, input_stream: &mut ValueStream<L, D>, output_stream: &mut ValueStream<L, D>) {
let encoded = input_stream.read_value::<ArrayVariable<ByteVariable, ENCODING_LEN>>();
let len = input_stream.read_value::<Variable>();
let finish = input_stream.read_value::<BoolVariable>();

fn dependencies(&self) -> Vec<Target> {
let mut targets: Vec<Target> = Vec::new();
targets.extend(self.encoding.targets());
targets.extend(self.length.targets());
targets.extend(self.finish.targets());
targets
}

fn run_once(
&self,
witness: &PartitionWitness<L::Field>,
out_buffer: &mut GeneratedValues<L::Field>,
) {
let finish = self.finish.get(witness);
let encoding = self.encoding.get(witness);
let length = self.length.get(witness).as_canonical_u64() as usize;
let (decoded_list, decoded_list_lens, len_decoded_list) =
decode_element_as_list::<ENCODING_LEN, LIST_LEN, ELEMENT_LEN>(
&encoding, length, finish,
&encoded,
len.as_canonical_u64() as usize,
finish,
);
self.decoded_list.set(out_buffer, decoded_list);
self.decoded_element_lens.set(
out_buffer,

output_stream
.write_value::<ArrayVariable<ArrayVariable<ByteVariable, ELEMENT_LEN>, LIST_LEN>>(
decoded_list,
);
output_stream.write_value::<ArrayVariable<Variable, LIST_LEN>>(
decoded_list_lens
.iter()
.map(|x| L::Field::from_canonical_usize(*x))
.collect(),
.collect::<Vec<_>>(),
);
self.len_decoded_list
.set(out_buffer, L::Field::from_canonical_usize(len_decoded_list));
}

#[allow(unused_variables)]
fn serialize(
&self,
dst: &mut Vec<u8>,
common_data: &CommonCircuitData<L::Field, D>,
) -> IoResult<()> {
todo!()
}

#[allow(unused_variables)]
fn deserialize(
src: &mut Buffer,
common_data: &CommonCircuitData<L::Field, D>,
) -> IoResult<Self> {
todo!()
output_stream.write_value::<Variable>(L::Field::from_canonical_usize(len_decoded_list));
}
}

Expand All @@ -305,14 +226,22 @@ impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
ArrayVariable<Variable, LIST_LEN>,
Variable,
) {
let generator = RLPDecodeListGenerator::new(self, encoded, len, finish);
self.add_simple_generator(generator.clone());
let mut input_stream = VariableStream::new();
input_stream.write(&encoded);
input_stream.write(&len);
input_stream.write(&finish);

let hint = DecodeHint::<ENCODING_LEN, LIST_LEN, ELEMENT_LEN> {};

let output_stream = self.hint(input_stream, hint);
let decoded_list = output_stream
.read::<ArrayVariable<ArrayVariable<ByteVariable, ELEMENT_LEN>, LIST_LEN>>(self);
let decoded_element_lens = output_stream.read::<ArrayVariable<Variable, LIST_LEN>>(self);
let len_decoded_list = output_stream.read::<Variable>(self);

// TODO: here add verification logic constraints using `builder` to check that the decoded list is correct
(
generator.decoded_list,
generator.decoded_element_lens,
generator.len_decoded_list,
)

(decoded_list, decoded_element_lens, len_decoded_list)
}
}

Expand Down Expand Up @@ -356,50 +285,78 @@ mod tests {
}

#[test]
fn test_rlp_decode_list_generator() {

fn test_rlp_decode_hint() {
let mut builder: CircuitBuilder<DefaultParameters, 2> = DefaultBuilder::new();
type F = GoldilocksField;

type F = GoldilocksField;
const ENCODING_LEN: usize = 600;
const LIST_LEN: usize = 17;
const ELEMENT_LEN: usize = 34;
let encoding = builder.read::<ArrayVariable<ByteVariable, ENCODING_LEN>>();
const ELEMENT_LEN: usize = 32;

let hint: DecodeHint<ENCODING_LEN, LIST_LEN, ELEMENT_LEN> =
DecodeHint::<ENCODING_LEN, LIST_LEN, ELEMENT_LEN> {};
let encoded = builder.read::<ArrayVariable<ByteVariable, ENCODING_LEN>>();
let len = builder.read::<Variable>();
let finish = builder.read::<BoolVariable>();

let (decoded_list, decoded_element_lens, len_decoded_list) = builder
.decode_element_as_list::<ENCODING_LEN, LIST_LEN, ELEMENT_LEN>(
encoding.clone(),
len,
finish,
let mut input_stream = VariableStream::new();
input_stream.write(&encoded);
input_stream.write(&len);
input_stream.write(&finish);
let output_stream = builder.hint(input_stream, hint);
let decoded_list = output_stream
.read::<ArrayVariable<ArrayVariable<ByteVariable, ELEMENT_LEN>, LIST_LEN>>(
&mut builder,
);
let decoded_element_lens =
output_stream.read::<ArrayVariable<Variable, LIST_LEN>>(&mut builder);
let len_decoded_list = output_stream.read::<Variable>(&mut builder);

builder.watch(&len_decoded_list, "len_decoded_list");
builder.watch(&decoded_element_lens, "decoded_element_lens");
builder.watch(&decoded_list, "decoded_list");

let circuit = builder.mock_build();
builder.write(decoded_list);
builder.write(decoded_element_lens);
builder.write(len_decoded_list);

let circuit = builder.build();
let mut input = circuit.input();

// This is a RLP-encoded list of length 17. Each of the first 16 elements is a 32-byte hash,
// and the last element is 0.
let rlp_encoding: Vec<u8> = bytes!("0xf90211a0215ead887d4da139eba306f76d765f5c4bfb03f6118ac1eb05eec3a92e1b0076a03eb28e7b61c689fae945b279f873cfdddf4e66db0be0efead563ea08bc4a269fa03025e2cce6f9c1ff09c8da516d938199c809a7f94dcd61211974aebdb85a4e56a0188d1100731419827900267bf4e6ea6d428fa5a67656e021485d1f6c89e69be6a0b281bb20061318a515afbdd02954740f069ebc75e700fde24dfbdf8c76d57119a0d8d77d917f5b7577e7e644bbc7a933632271a8daadd06a8e7e322f12dd828217a00f301190681b368db4308d1d1aa1794f85df08d4f4f646ecc4967c58fd9bd77ba0206598a4356dd50c70cfb1f0285bdb1402b7d65b61c851c095a7535bec230d5aa000959956c2148c82c207272af1ae129403d42e8173aedf44a190e85ee5fef8c3a0c88307e92c80a76e057e82755d9d67934ae040a6ec402bc156ad58dbcd2bcbc4a0e40a8e323d0b0b19d37ab6a3d110de577307c6f8efed15097dfb5551955fc770a02da2c6b12eedab6030b55d4f7df2fb52dab0ef4db292cb9b9789fa170256a11fa0d00e11cde7531fb79a315b4d81ea656b3b452fe3fe7e50af48a1ac7bf4aa6343a066625c0eb2f6609471f20857b97598ae4dfc197666ff72fe47b94e4124900683a0ace3aa5d35ba3ebbdc0abde8add5896876b25261717c0a415c92642c7889ec66a03a4931a67ae8ebc1eca9ffa711c16599b86d5286504182618d9c2da7b83f5ef780");
let mut encoding_fixed_size = [0u8; ENCODING_LEN];
encoding_fixed_size[..rlp_encoding.len()].copy_from_slice(&rlp_encoding);
let finish = false;
input.write::<ArrayVariable<ByteVariable, ENCODING_LEN>>(encoding_fixed_size.to_vec());
input.write::<Variable>(F::from_canonical_usize(rlp_encoding.len()));
input.write::<BoolVariable>(false);

let (witness, _output) = circuit.mock_prove(&input);

let len = len_decoded_list.get(&witness);
let decoded_element_lens = decoded_element_lens.get(&witness);
// let decoded_list = decoded_list.get(&witness);
assert!(len == F::from_canonical_usize(17));
for i in 0..17 {
if i == 16 {
assert!(decoded_element_lens[i] == F::from_canonical_usize(0));
} else {
assert!(decoded_element_lens[i] == F::from_canonical_usize(32));
}
input.write::<BoolVariable>(finish);

let (proof, mut output) = circuit.prove(&input);
circuit.verify(&proof, &input, &output);

let decoded_list_out =
output.read::<ArrayVariable<ArrayVariable<ByteVariable, ELEMENT_LEN>, LIST_LEN>>();
let decoded_element_lens_out = output.read::<ArrayVariable<Variable, LIST_LEN>>();
let len_decoded_list_out = output.read::<Variable>();

let (decoded_list_exp, decoded_list_lens_exp, len_decoded_list_exp) =
decode_element_as_list::<ENCODING_LEN, LIST_LEN, ELEMENT_LEN>(
&encoding_fixed_size,
rlp_encoding.len(),
finish,
);

assert_eq!(
len_decoded_list_out,
F::from_canonical_usize(len_decoded_list_exp)
);
assert_eq!(decoded_list_out.len(), LIST_LEN);
assert_eq!(len_decoded_list_out, F::from_canonical_usize(LIST_LEN));

for i in 0..LIST_LEN {
assert_eq!(decoded_list_out[i], decoded_list_exp[i]);
assert_eq!(
decoded_element_lens_out[i],
F::from_canonical_usize(decoded_list_lens_exp[i])
);
}
}
}

0 comments on commit 598e609

Please sign in to comment.