Skip to content

Commit

Permalink
fix: sha padding asserts and other bug fixes (#268)
Browse files Browse the repository at this point in the history
Co-authored-by: Ratan Kaliani <[email protected]>
  • Loading branch information
tamirhemo and ratankaliani authored Oct 31, 2023
1 parent 669ad3a commit 5b85909
Show file tree
Hide file tree
Showing 14 changed files with 418 additions and 187 deletions.
18 changes: 9 additions & 9 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 7 additions & 6 deletions plonky2x/core/src/frontend/hash/sha/curta/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ use super::SHA;
use crate::prelude::*;

impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
/// The constraints for an accelerated SHA computation using Curta.
pub(crate) fn curta_constrain_sha<S: SHA<L, D, CYCLE_LEN>, const CYCLE_LEN: usize>(
&mut self,
accelerator: SHAAccelerator<S::IntVariable>,
) where
Chip<S::AirParameters>: Plonky2Air<L::Field, D>,
{
// Write all the digest using the digest hint.
// Get all the digest values using the digest hint.
for (request, response) in accelerator
.sha_requests
.iter()
Expand All @@ -42,22 +43,22 @@ impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
}

// Prove correctness of the digest using the proof hint.

// Initialize the corresponding stark and hint.
let sha_data = S::get_sha_data(self, accelerator);
let parameters = sha_data.parameters();

let sha_stark = S::stark(parameters);
let proof_hint = SHAProofHint::<S, CYCLE_LEN>::new(parameters);

let mut input_stream = VariableStream::new();
input_stream.write_sha_input(&sha_data);

// Read the stark proof and public inputs from the hint's output stream.
let output_stream = self.hint(input_stream, proof_hint);

let sha_stark = S::stark(parameters);

let proof = output_stream.read_byte_stark_proof(self, &sha_stark.stark);
let num_public_inputs = sha_stark.stark.air_data.num_public_inputs;
let public_inputs = output_stream.read_vec(self, num_public_inputs);

// Verify the proof.
sha_stark.verify_proof(self, proof, &public_inputs, sha_data)
}
}
16 changes: 15 additions & 1 deletion plonky2x/core/src/frontend/hash/sha/curta/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,25 @@ use crate::prelude::{
BoolVariable, CircuitVariable, PlonkParameters, ValueStream, Variable, VariableStream,
};

/// Circuit variables for the input data of a SHA computation.
pub struct SHAInputData<T> {
/// The padded chunks of the input message.
pub padded_chunks: Vec<T>,
// A flag for each chunk indicating whether the hash state needs to be restarted after
// processing the chunk.
pub end_bits: Vec<BoolVariable>,
/// A flag for each chunk indicating whether the digest should be read after processing the
/// chunk.
pub digest_bits: Vec<BoolVariable>,
/// The index of the digests to be read, corresponding to their location in `padded chunks`.
pub digest_indices: Vec<Variable>,
/// The message digests.
pub digests: Vec<[T; 8]>,
}

/// The values of the input data of a SHA computation.
///
/// This struct represents the values of the variables of `SHAInputData`.
pub struct SHAInputDataValues<
L: PlonkParameters<D>,
S: SHA<L, D, CYCLE_LEN>,
Expand All @@ -26,13 +37,15 @@ pub struct SHAInputDataValues<
pub digests: Vec<[S::Integer; 8]>,
}

/// The parameters required for reading the input data of a SHA computation from a stream.
#[derive(Clone, Debug, Copy, Serialize, Deserialize)]
pub struct SHAInputParameters {
pub num_chunks: usize,
pub num_digests: usize,
}

impl<T> SHAInputData<T> {
/// Get parameters from the input data.
pub fn parameters(&self) -> SHAInputParameters {
SHAInputParameters {
num_chunks: self.end_bits.len(),
Expand All @@ -42,6 +55,7 @@ impl<T> SHAInputData<T> {
}

impl VariableStream {
/// Read sha input data from the stream.
pub fn write_sha_input<T: CircuitVariable>(&mut self, input: &SHAInputData<T>) {
self.write_slice(&input.padded_chunks);
self.write_slice(&input.end_bits);
Expand All @@ -52,7 +66,7 @@ impl VariableStream {
}

impl<L: PlonkParameters<D>, const D: usize> ValueStream<L, D> {
#[allow(clippy::type_complexity)]
/// Read sha input data from the stream.
pub fn read_sha_input_values<S: SHA<L, D, CYCLE_LEN>, const CYCLE_LEN: usize>(
&mut self,
parameters: SHAInputParameters,
Expand Down
1 change: 1 addition & 0 deletions plonky2x/core/src/frontend/hash/sha/curta/digest_hint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use super::SHA;
use crate::frontend::hint::simple::hint::Hint;
use crate::prelude::*;

/// Provides the SHA of a message usign the algorithm specified by `S`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SHADigestHint<S, const CYCLE_LEN: usize> {
_marker: PhantomData<S>,
Expand Down
63 changes: 56 additions & 7 deletions plonky2x/core/src/frontend/hash/sha/curta/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,86 +23,131 @@ pub mod proof_hint;
pub mod request;
pub mod stark;

/// An interface for a circuit that computes SHA using Curta.
pub trait SHA<L: PlonkParameters<D>, const D: usize, const CYCLE_LEN: usize>:
SHAir<BytesBuilder<Self::AirParameters>, CYCLE_LEN>
{
/// A `CircuitVariable` that represents the integer registers used by the hash function.
///
/// For example, in `SHA256` this would be a `CircuitVariable` that represents a 32-bit integer.
type IntVariable: CircuitVariable<ValueType<L::Field> = Self::Integer> + Copy;
/// A `CircuitVariable` that represents the hash digest.
type DigestVariable: CircuitVariable;

/// The air parameters of the corresponding Curta stark.
type AirParameters: AirParameters<
Field = L::Field,
CubicParams = L::CubicParams,
Instruction = Self::AirInstruction,
>;

/// The air instructions of the corresponding Curta stark.
type AirInstruction: UintInstructions;

/// Pad an input message of fixed length.
fn pad_circuit(
builder: &mut CircuitBuilder<L, D>,
input: &[ByteVariable],
) -> Vec<Self::IntVariable>;

/// Pad an input message of variable length.
fn pad_circuit_variable_length(
builder: &mut CircuitBuilder<L, D>,
input: &[ByteVariable],
length: U32Variable,
last_chunk: U32Variable,
) -> Vec<Self::IntVariable>;

/// Convert a value of the `Self::IntRegister` type to a `Self::IntVariable`.
///
/// This is used to assert compatibility between the stark and the circuit representation of the
/// integer variables.
fn value_to_variable(
builder: &mut CircuitBuilder<L, D>,
value: <Self::IntRegister as Register>::Value<Variable>,
) -> Self::IntVariable;

/// Convert a value of the `Self::DigestRegister` to an array of `Self::IntVariable`s.
fn digest_to_array(
builder: &mut CircuitBuilder<L, D>,
digest: Self::DigestVariable,
) -> [Self::IntVariable; 8];

/// Get the input data for the stark from a `SHAAccelerator`.
fn get_sha_data(
builder: &mut CircuitBuilder<L, D>,
accelerator: SHAAccelerator<Self::IntVariable>,
) -> SHAInputData<Self::IntVariable> {
// Initialze the data struictures of `SHAInputData`.
let mut end_bit_values = Vec::new();
let mut digest_bits = Vec::new();
let mut current_chunk_index = 0;
let mut digest_indices = Vec::<Variable>::new();

// Get the padded chunks from input messages, and assign the correct values of end_bit,
// digest_bits, and digest_indices.
//
// `end_bit` - a bit that is true for the last chunk of a message (or the total_message
// passed in the case of a message of variable length).
// `digest_bit` - a bit that indicates the hash state is read into the digest list after
// processing the chunk. For a message of fioxed length, this is the same as `end_bit`.
// `digest_index` - the index of the digest to be read, corresponding to the location of the
// chunk in the padded chunks.
let padded_chunks = accelerator
.sha_requests
.iter()
.flat_map(|req| {
let (padded_chunks, chunk_index) = match req {
// For every reuqest, we read the corresponding messagem, pad it, and compute the
// corresponding chunk index.

// Get the padded chunks and the number of chunks in the message, depending on the
// type of the request.
let (padded_chunks, last_chunk_index) = match req {
SHARequest::Fixed(input) => {
// If the length is fixed, the chunk_index is just `number of chunks - 1``.
let padded_chunks = Self::pad_circuit(builder, input);
let num_chunks =
builder.constant((padded_chunks.len() / 16 - 1).try_into().unwrap());
(padded_chunks, num_chunks)
}
// If the length of the massage is a variable, we read the chunk index form the
// request.
SHARequest::Variable(input, length, last_chunk) => (
Self::pad_circuit_variable_length(builder, input, *length, *last_chunk),
Self::pad_circuit_variable_length(builder, input, *length),
*last_chunk,
),
};
// Get the total number of chunks processed.
let total_number_of_chunks = padded_chunks.len() / 16;
// Store the end_bit values. The end bit indicates the end of message chunks.
end_bit_values.extend_from_slice(&vec![false; total_number_of_chunks - 1]);
end_bit_values.push(true);
// The chunk index is given by the currenty index plus the chunk index we got from
// the request.
let current_chunk_index_variable = builder
.constant::<Variable>(L::Field::from_canonical_usize(current_chunk_index));
let digest_index = builder.add(current_chunk_index_variable, chunk_index.variable);
let digest_index =
builder.add(current_chunk_index_variable, last_chunk_index.variable);
digest_indices.push(digest_index);
// The digest bit is equal to zero for all chunks except the one that corresponds to
// the `chunk_index`. We find the bits by comparing each value between 0 and the
// total number of chunks to the `chunk_index`.
let mut flag = builder.constant::<BoolVariable>(true);
for j in 0..total_number_of_chunks {
let j_var = builder.constant::<U32Variable>(j as u32);
let lte = builder.lte(chunk_index, j_var);
let lte = builder.lte(last_chunk_index, j_var);
let lte_times_flag = builder.and(lte, flag);
digest_bits.push(lte_times_flag);
let not_lte = builder.not(lte);
flag = builder.and(flag, not_lte);
}
// Increment the current chunk index by the total number of chunks.
current_chunk_index += total_number_of_chunks;
end_bit_values.extend_from_slice(&vec![false; total_number_of_chunks - 1]);
end_bit_values.push(true);

padded_chunks
})
.collect::<Vec<_>>();

// Convert end_bits to variables.
let end_bits = builder.constant_vec(&end_bit_values);

SHAInputData {
Expand All @@ -114,6 +159,7 @@ pub trait SHA<L: PlonkParameters<D>, const D: usize, const CYCLE_LEN: usize>:
}
}

/// The Curta Stark corresponding to the input data.
fn stark(parameters: SHAInputParameters) -> SHAStark<L, Self, D, CYCLE_LEN> {
let mut builder = BytesBuilder::<Self::AirParameters>::new();

Expand All @@ -122,13 +168,16 @@ pub trait SHA<L: PlonkParameters<D>, const D: usize, const CYCLE_LEN: usize>:
.map(|_| builder.alloc_array_public::<Self::IntRegister>(16))
.collect::<Vec<_>>();

// Allocate registers for the public inputs to the Stark.
let end_bits = builder.alloc_array_public::<BitRegister>(num_chunks);
let digest_bits = builder.alloc_array_public::<BitRegister>(num_chunks);
let digest_indices = builder.alloc_array_public::<ElementRegister>(parameters.num_digests);

// Hash the padded chunks.
let digests =
builder.sha::<Self, CYCLE_LEN>(&padded_chunks, &end_bits, &digest_bits, digest_indices);

// Build the stark.
let num_rows_degree = log2_ceil(CYCLE_LEN * num_chunks);
let num_rows = 1 << num_rows_degree;
let stark = builder.build::<L::CurtaConfig, D>(num_rows);
Expand Down
1 change: 1 addition & 0 deletions plonky2x/core/src/frontend/hash/sha/curta/proof_hint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use super::SHA;
use crate::frontend::hint::simple::hint::Hint;
use crate::prelude::*;

/// A hint for Curta proof of a SHA stark.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAProofHint<S, const CYCLE_LEN: usize> {
parameters: SHAInputParameters,
Expand Down
5 changes: 3 additions & 2 deletions plonky2x/core/src/frontend/hash/sha/curta/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@ pub enum SHARequestType {
Variable,
}

/// A SHA request.
/// A request for a SHA computation.
#[derive(Debug, Clone)]
pub enum SHARequest {
/// A message of fixed length.
Fixed(Vec<ByteVariable>),
/// A message of variable length, with the actual legnth given by the Variable.
/// A message of variable length, represented by a tuple `(total_message, lengh, last_chunk)`.
Variable(Vec<ByteVariable>, U32Variable, U32Variable),
}

impl SHARequest {
/// Returns the type of the request.
pub const fn req_type(&self) -> SHARequestType {
match self {
SHARequest::Fixed(_) => SHARequestType::Fixed,
Expand Down
6 changes: 4 additions & 2 deletions plonky2x/core/src/frontend/hash/sha/curta/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,23 @@ where
}
}

/// Generate a proof for the SHA stark given the input data.
#[allow(clippy::type_complexity)]
pub fn prove(
&self,
input: SHAInputDataValues<L, S, D, CYCLE_LEN>,
) -> (ByteStarkProof<L::Field, L::CurtaConfig, D>, Vec<L::Field>) {
// Initialize a writer for the trace.
let writer = TraceWriter::new(&self.stark.air_data, self.num_rows);

// Write the public inputs to the trace.
self.write_input(&writer, input);

// Execute the air instructions.
writer.write_global_instructions(&self.stark.air_data);
for i in 0..self.num_rows {
writer.write_row_instructions(&self.stark.air_data, i);
}

// Extract the trace and public input slice from the writer.
let InnerWriterData { trace, public, .. } = writer.into_inner().unwrap();
let proof = self
.stark
Expand Down
Loading

0 comments on commit 5b85909

Please sign in to comment.