Skip to content

Commit

Permalink
introduce an unchecked path to the Streamable trait, and make the BLS…
Browse files Browse the repository at this point in the history
… types support unchecked parsing
  • Loading branch information
arvidn committed Dec 4, 2023
1 parent 942b7aa commit d6a4fd1
Show file tree
Hide file tree
Showing 21 changed files with 431 additions and 201 deletions.
7 changes: 3 additions & 4 deletions chia-bls/benches/parse.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use chia_bls::secret_key::SecretKey;
use chia_bls::signature::sign;
use chia_bls::Signature;
use chia_bls::PublicKey;
use chia_bls::Signature;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
Expand All @@ -16,8 +16,8 @@ fn parse_benchmark(c: &mut Criterion) {
let msg = b"The quick brown fox jumps over the lazy dog";
let sig = sign(&sk, msg);

let sig_bytes = sig.to_bytes();
let pk_bytes = pk.to_bytes();
let sig_bytes = sig.to_bytes();
let pk_bytes = pk.to_bytes();

c.bench_function("parse Signature", |b| {
b.iter(|| {
Expand Down Expand Up @@ -46,4 +46,3 @@ fn parse_benchmark(c: &mut Criterion) {

criterion_group!(parse, parse_benchmark);
criterion_main!(parse);

8 changes: 1 addition & 7 deletions chia-bls/src/gtelement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ impl GTElement {
#[pyo3(name = "SIZE")]
const PY_SIZE: usize = Self::SIZE;

#[staticmethod]
#[pyo3(name = "from_bytes_unchecked")]
fn py_from_bytes_unchecked(bytes: [u8; Self::SIZE]) -> Result<GTElement> {
Ok(Self::from_bytes(&bytes))
}

fn __str__(&self) -> pyo3::PyResult<String> {
Ok(hex::encode(self.to_bytes()))
}
Expand Down Expand Up @@ -153,7 +147,7 @@ impl Streamable for GTElement {
Ok(())
}

fn parse(input: &mut Cursor<&[u8]>) -> Result<Self> {
fn parse<const TRUSTED: bool>(input: &mut Cursor<&[u8]>) -> Result<Self> {
Ok(GTElement::from_bytes(
read_bytes(input, Self::SIZE)?.try_into().unwrap(),
))
Expand Down
17 changes: 7 additions & 10 deletions chia-bls/src/public_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,6 @@ impl PublicKey {
Self::default()
}

#[staticmethod]
#[pyo3(name = "from_bytes_unchecked")]
fn py_from_bytes_unchecked(bytes: [u8; Self::SIZE]) -> Result<Self> {
Self::from_bytes_unchecked(&bytes)
}

#[staticmethod]
#[pyo3(name = "generator")]
pub fn py_generator() -> Self {
Expand Down Expand Up @@ -217,10 +211,13 @@ impl Streamable for PublicKey {
Ok(())
}

fn parse(input: &mut Cursor<&[u8]>) -> chia_traits::Result<Self> {
Ok(Self::from_bytes(
read_bytes(input, 48)?.try_into().unwrap(),
)?)
fn parse<const TRUSTED: bool>(input: &mut Cursor<&[u8]>) -> chia_traits::Result<Self> {
let input = read_bytes(input, 48)?.try_into().unwrap();
if TRUSTED {
Ok(Self::from_bytes_unchecked(input)?)
} else {
Ok(Self::from_bytes(input)?)
}
}
}

Expand Down
4 changes: 3 additions & 1 deletion chia-bls/src/secret_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ impl Streamable for SecretKey {
Ok(())
}

fn parse(input: &mut Cursor<&[u8]>) -> chia_traits::chia_error::Result<Self> {
fn parse<const TRUSTED: bool>(
input: &mut Cursor<&[u8]>,
) -> chia_traits::chia_error::Result<Self> {
Ok(Self::from_bytes(
read_bytes(input, 32)?.try_into().unwrap(),
)?)
Expand Down
19 changes: 9 additions & 10 deletions chia-bls/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,15 @@ impl Streamable for Signature {
Ok(())
}

fn parse(input: &mut Cursor<&[u8]>) -> chia_traits::chia_error::Result<Self> {
Ok(Self::from_bytes(
read_bytes(input, 96)?.try_into().unwrap(),
)?)
fn parse<const TRUSTED: bool>(
input: &mut Cursor<&[u8]>,
) -> chia_traits::chia_error::Result<Self> {
let input = read_bytes(input, 96)?.try_into().unwrap();
if TRUSTED {
Ok(Self::from_bytes_unchecked(input)?)
} else {
Ok(Self::from_bytes(input)?)
}
}
}

Expand Down Expand Up @@ -285,12 +290,6 @@ impl Signature {
Self::default()
}

#[staticmethod]
#[pyo3(name = "from_bytes_unchecked")]
pub fn py_from_bytes_unchecked(bytes: [u8; Self::SIZE]) -> Result<Signature> {
Self::from_bytes_unchecked(&bytes)
}

#[pyo3(name = "pair")]
pub fn py_pair(&self, other: &PublicKey) -> GTElement {
self.pair(other)
Expand Down
7 changes: 5 additions & 2 deletions chia-protocol/fuzz/fuzz_targets/streamable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ pub fn test_streamable<T: Streamable + std::fmt::Debug + PartialEq>(obj: &T) {
};
assert_eq!(obj, &obj2);

let obj3 = T::from_bytes_unchecked(&bytes).unwrap();
assert_eq!(obj, &obj3);

let mut ctx = Sha256::new();
ctx.update(&bytes);
let expect_hash: [u8; 32] = ctx.finalize().into();
Expand All @@ -29,12 +32,12 @@ pub fn test_streamable<T: Streamable + std::fmt::Debug + PartialEq>(obj: &T) {
// make sure input too large is an error
let mut corrupt_bytes = bytes.clone();
corrupt_bytes.push(0);
assert!(T::from_bytes(&corrupt_bytes) == Err(chia_traits::Error::InputTooLarge));
assert!(T::from_bytes_unchecked(&corrupt_bytes) == Err(chia_traits::Error::InputTooLarge));

if !bytes.is_empty() {
// make sure input too short is an error
corrupt_bytes.truncate(bytes.len() - 1);
assert!(T::from_bytes(&corrupt_bytes) == Err(chia_traits::Error::EndOfBuffer));
assert!(T::from_bytes_unchecked(&corrupt_bytes) == Err(chia_traits::Error::EndOfBuffer));
}
}
#[cfg(fuzzing)]
Expand Down
10 changes: 5 additions & 5 deletions chia-protocol/src/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ impl Streamable for Bytes {
}
}

fn parse(input: &mut Cursor<&[u8]>) -> chia_error::Result<Self> {
let len = u32::parse(input)?;
fn parse<const TRUSTED: bool>(input: &mut Cursor<&[u8]>) -> chia_error::Result<Self> {
let len = u32::parse::<TRUSTED>(input)?;
Ok(Bytes(read_bytes(input, len as usize)?.to_vec()))
}
}
Expand Down Expand Up @@ -191,7 +191,7 @@ impl<const N: usize> Streamable for BytesImpl<N> {
Ok(())
}

fn parse(input: &mut Cursor<&[u8]>) -> chia_error::Result<Self> {
fn parse<const TRUSTED: bool>(input: &mut Cursor<&[u8]>) -> chia_error::Result<Self> {
Ok(BytesImpl(read_bytes(input, N)?.try_into().unwrap()))
}
}
Expand Down Expand Up @@ -562,15 +562,15 @@ mod tests {

fn from_bytes<T: Streamable + std::fmt::Debug + std::cmp::PartialEq>(buf: &[u8], expected: T) {
let mut input = Cursor::<&[u8]>::new(buf);
assert_eq!(T::parse(&mut input).unwrap(), expected);
assert_eq!(T::parse::<false>(&mut input).unwrap(), expected);
}

fn from_bytes_fail<T: Streamable + std::fmt::Debug + std::cmp::PartialEq>(
buf: &[u8],
expected: chia_error::Error,
) {
let mut input = Cursor::<&[u8]>::new(buf);
assert_eq!(T::parse(&mut input).unwrap_err(), expected);
assert_eq!(T::parse::<false>(&mut input).unwrap_err(), expected);
}

fn stream<T: Streamable>(v: &T) -> Vec<u8> {
Expand Down
2 changes: 1 addition & 1 deletion chia-protocol/src/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ impl Streamable for Program {
Ok(())
}

fn parse(input: &mut Cursor<&[u8]>) -> Result<Self> {
fn parse<const TRUSTED: bool>(input: &mut Cursor<&[u8]>) -> Result<Self> {
let pos = input.position();
let buf: &[u8] = &input.get_ref()[pos as usize..];
let len = serialized_length_from_bytes(buf).map_err(|_e| Error::EndOfBuffer)?;
Expand Down
8 changes: 4 additions & 4 deletions chia-tools/src/bin/analyze-chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ fn main() {
let block_buffer =
zstd::stream::decode_all(&mut std::io::Cursor::<Vec<u8>>::new(block_buffer))
.expect("failed to decompress block");
let block = FullBlock::parse(&mut std::io::Cursor::<&[u8]>::new(&block_buffer))
.expect("failed to parse FullBlock");
let block =
FullBlock::from_bytes_unchecked(&block_buffer).expect("failed to parse FullBlock");

let ti = match block.transactions_info {
Some(ti) => ti,
Expand Down Expand Up @@ -141,8 +141,8 @@ fn main() {
zstd::stream::decode_all(&mut std::io::Cursor::<Vec<u8>>::new(ref_block))
.expect("failed to decompress block");

let ref_block = FullBlock::parse(&mut std::io::Cursor::<&[u8]>::new(&ref_block))
.expect("failed to parse ref-block");
let ref_block =
FullBlock::from_bytes_unchecked(&ref_block).expect("failed to parse ref-block");
let ref_gen = match ref_block.transactions_generator {
None => {
panic!("block ref has no generator");
Expand Down
3 changes: 1 addition & 2 deletions chia-tools/src/bin/fast-forward-spend.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use clap::Parser;
use std::fs;
use std::io::Cursor;

use chia::fast_forward::fast_forward_singleton;
use chia_protocol::bytes::Bytes32;
Expand Down Expand Up @@ -31,7 +30,7 @@ fn main() {
let args = Args::parse();

let spend_bytes = fs::read(args.spend).expect("read file");
let spend = CoinSpend::parse(&mut Cursor::new(&spend_bytes)).expect("parse CoinSpend");
let spend = CoinSpend::from_bytes(&spend_bytes).expect("parse CoinSpend");

let new_parents_parent: Bytes32 = hex::decode(args.new_parents_parent)
.expect("invalid hex")
Expand Down
3 changes: 1 addition & 2 deletions chia-tools/src/bin/run-spend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use clvm_utils::tree_hash;
use clvm_utils::CurriedProgram;
use clvmr::{allocator::NodePtr, Allocator};
use hex_literal::hex;
use std::io::Cursor;

/// Run a puzzle given a solution and print the resulting conditions
#[derive(Parser, Debug)]
Expand Down Expand Up @@ -238,7 +237,7 @@ fn main() {

let mut a = Allocator::new();
let spend = read(args.spend).expect("spend file not found");
let spend = CoinSpend::parse(&mut Cursor::new(spend.as_slice())).expect("parse CoinSpend");
let spend = CoinSpend::from_bytes(&spend).expect("parse CoinSpend");

let puzzle = spend
.puzzle_reveal
Expand Down
8 changes: 4 additions & 4 deletions chia-tools/src/bin/test-block-generators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ fn main() {
let block_buffer =
zstd::stream::decode_all(&mut std::io::Cursor::<Vec<u8>>::new(block_buffer))
.expect("failed to decompress block");
let block = FullBlock::parse(&mut std::io::Cursor::<&[u8]>::new(&block_buffer))
.expect("failed to parse FullBlock");
let block =
FullBlock::from_bytes_unchecked(&block_buffer).expect("failed to parse FullBlock");

let ti = match block.transactions_info {
Some(ti) => ti,
Expand Down Expand Up @@ -218,8 +218,8 @@ fn main() {
zstd::stream::decode_all(&mut std::io::Cursor::<Vec<u8>>::new(ref_block))
.expect("failed to decompress block");

let ref_block = FullBlock::parse(&mut std::io::Cursor::<&[u8]>::new(&ref_block))
.expect("failed to parse ref-block");
let ref_block =
FullBlock::from_bytes_unchecked(&ref_block).expect("failed to parse ref-block");
let ref_gen = ref_block
.transactions_generator
.expect("block ref has no generator");
Expand Down
8 changes: 4 additions & 4 deletions chia-tools/src/visit_spends.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ pub fn iterate_tx_blocks(
let block_buffer =
zstd::stream::decode_all(&mut std::io::Cursor::<Vec<u8>>::new(block_buffer))
.expect("failed to decompress block");
let block = FullBlock::parse(&mut std::io::Cursor::<&[u8]>::new(&block_buffer))
.expect("failed to parse FullBlock");
let block =
FullBlock::from_bytes_unchecked(&block_buffer).expect("failed to parse FullBlock");

if block.transactions_info.is_none() {
continue;
Expand Down Expand Up @@ -83,8 +83,8 @@ pub fn iterate_tx_blocks(
zstd::stream::decode_all(&mut std::io::Cursor::<Vec<u8>>::new(ref_block))
.expect("failed to decompress block");

let ref_block = FullBlock::parse(&mut std::io::Cursor::<&[u8]>::new(&ref_block))
.expect("failed to parse ref-block");
let ref_block =
FullBlock::from_bytes_unchecked(&ref_block).expect("failed to parse ref-block");
let ref_gen = ref_block
.transactions_generator
.expect("block ref has no generator");
Expand Down
Loading

0 comments on commit d6a4fd1

Please sign in to comment.