Skip to content

Commit

Permalink
Merge pull request #623 from Chia-Network/cache-interior-mutability
Browse files Browse the repository at this point in the history
Add interior mutability to BLSCache
  • Loading branch information
matt-o-how authored Jul 26, 2024
2 parents 8768baf + a91aa35 commit 3438e6e
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 161 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,4 @@ zstd = "0.13.2"
blocking-threadpool = "1.0.1"
libfuzzer-sys = "0.4"
wasm-bindgen = "0.2.92"
parking_lot = "0.12.3"
1 change: 1 addition & 0 deletions crates/chia-bls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ thiserror = { workspace = true }
pyo3 = { workspace = true, features = ["multiple-pymethods"], optional = true }
arbitrary = { workspace = true, optional = true }
lru = { workspace = true }
parking_lot = { workspace = true }

[dev-dependencies]
rand = { workspace = true }
Expand Down
54 changes: 10 additions & 44 deletions crates/chia-bls/benches/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,77 +23,43 @@ fn cache_benchmark(c: &mut Criterion) {
pks.push(pk);
}

let mut bls_cache = BlsCache::default();
let bls_cache = BlsCache::default();

c.bench_function("bls_cache.aggregate_verify, 0% cache hits", |b| {
let mut cache = bls_cache.clone();
b.iter(|| {
assert!(
cache
.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig)
.0
);
assert!(bls_cache.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig));
});
});

// populate 10% of keys
bls_cache.aggregate_verify(pks[0..100].into_iter().zip([&msg].iter().cycle()), &agg_sig);
bls_cache.aggregate_verify(pks[0..100].iter().zip([&msg].iter().cycle()), &agg_sig);
c.bench_function("bls_cache.aggregate_verify, 10% cache hits", |b| {
let mut cache = bls_cache.clone();
b.iter(|| {
assert!(
cache
.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig)
.0
);
assert!(bls_cache.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig));
});
});

// populate another 10% of keys
bls_cache.aggregate_verify(
pks[100..200].into_iter().zip([&msg].iter().cycle()),
&agg_sig,
);
bls_cache.aggregate_verify(pks[100..200].iter().zip([&msg].iter().cycle()), &agg_sig);
c.bench_function("bls_cache.aggregate_verify, 20% cache hits", |b| {
let mut cache = bls_cache.clone();
b.iter(|| {
assert!(
cache
.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig)
.0
);
assert!(bls_cache.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig));
});
});

// populate another 30% of keys
bls_cache.aggregate_verify(
pks[200..500].into_iter().zip([&msg].iter().cycle()),
&agg_sig,
);
bls_cache.aggregate_verify(pks[200..500].iter().zip([&msg].iter().cycle()), &agg_sig);
c.bench_function("bls_cache.aggregate_verify, 50% cache hits", |b| {
let mut cache = bls_cache.clone();
b.iter(|| {
assert!(
cache
.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig)
.0
);
assert!(bls_cache.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig));
});
});

// populate all other keys
bls_cache.aggregate_verify(
pks[500..1000].into_iter().zip([&msg].iter().cycle()),
&agg_sig,
);
bls_cache.aggregate_verify(pks[500..1000].iter().zip([&msg].iter().cycle()), &agg_sig);
c.bench_function("bls_cache.aggregate_verify, 100% cache hits", |b| {
let mut cache = bls_cache.clone();
b.iter(|| {
assert!(
cache
.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig)
.0
);
assert!(bls_cache.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig));
});
});

Expand Down
91 changes: 29 additions & 62 deletions crates/chia-bls/src/bls_cache.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{aggregate_verify_gt, hash_to_g2};
use crate::{GTElement, PublicKey, Signature};
use lru::LruCache;
use parking_lot::Mutex;
use sha2::{Digest, Sha256};
use std::borrow::Borrow;
use std::num::NonZeroUsize;
Expand All @@ -15,10 +16,10 @@ use std::num::NonZeroUsize;
/// aggregate_verify() primitive is faster. When long-syncing, that's
/// preferable.
#[cfg_attr(feature = "py-bindings", pyo3::pyclass(name = "BLSCache"))]
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct BlsCache {
// sha256(pubkey + message) -> GTElement
cache: LruCache<[u8; 32], GTElement>,
cache: Mutex<LruCache<[u8; 32], GTElement>>,
}

impl Default for BlsCache {
Expand All @@ -30,30 +31,23 @@ impl Default for BlsCache {
impl BlsCache {
pub fn new(cache_size: NonZeroUsize) -> Self {
Self {
cache: LruCache::new(cache_size),
cache: Mutex::new(LruCache::new(cache_size)),
}
}

pub fn len(&self) -> usize {
self.cache.len()
self.cache.lock().len()
}

pub fn is_empty(&self) -> bool {
self.cache.is_empty()
}

pub fn update(&mut self, new_items: impl IntoIterator<Item = ([u8; 32], GTElement)>) {
for (key, value) in new_items {
self.cache.put(key, value);
}
self.cache.lock().is_empty()
}

pub fn aggregate_verify<Pk: Borrow<PublicKey>, Msg: AsRef<[u8]>>(
&mut self,
&self,
pks_msgs: impl IntoIterator<Item = (Pk, Msg)>,
sig: &Signature,
) -> (bool, Vec<([u8; 32], GTElement)>) {
let mut added: Vec<([u8; 32], GTElement)> = Vec::new();
) -> bool {
let iter = pks_msgs.into_iter().map(|(pk, msg)| -> GTElement {
// Hash pubkey + message
let mut hasher = Sha256::new();
Expand All @@ -62,7 +56,7 @@ impl BlsCache {
let hash: [u8; 32] = hasher.finalize().into();

// If the pairing is in the cache, we don't need to recalculate it.
if let Some(pairing) = self.cache.get(&hash).cloned() {
if let Some(pairing) = self.cache.lock().get(&hash).cloned() {
return pairing;
}

Expand All @@ -76,12 +70,11 @@ impl BlsCache {
let hash: [u8; 32] = hasher.finalize().into();

let pairing = aug_hash.pair(pk.borrow());
self.cache.put(hash, pairing.clone());
added.push((hash, pairing.clone()));
self.cache.lock().put(hash, pairing.clone());
pairing
});

(aggregate_verify_gt(sig, iter), added)
aggregate_verify_gt(sig, iter)
}
}

Expand Down Expand Up @@ -117,7 +110,7 @@ impl BlsCache {
pks: &Bound<'_, PyList>,
msgs: &Bound<'_, PyList>,
sig: &Signature,
) -> PyResult<(bool, Vec<([u8; 32], GTElement)>)> {
) -> PyResult<bool> {
let pks = pks
.iter()?
.map(|item| item?.extract())
Expand All @@ -141,17 +134,19 @@ impl BlsCache {
use pyo3::prelude::*;
use pyo3::types::PyBytes;
let ret = PyList::empty_bound(py);
for (key, value) in &self.cache {
let cache = self.cache.lock();
for (key, value) in cache.iter() {
ret.append((PyBytes::new_bound(py, key), value.clone().into_py(py)))?;
}
Ok(ret.into())
}

#[pyo3(name = "update")]
pub fn py_update(&mut self, other: &Bound<'_, PyList>) -> PyResult<()> {
let mut cache = self.cache.lock();
for item in other.borrow().iter()? {
let (key, value): (Vec<u8>, GTElement) = item?.extract()?;
self.cache.put(
cache.put(
key.try_into()
.map_err(|_| PyValueError::new_err("invalid key"))?,
value,
Expand All @@ -170,7 +165,7 @@ pub mod tests {

#[test]
fn test_aggregate_verify() {
let mut bls_cache = BlsCache::default();
let bls_cache = BlsCache::default();

let sk = SecretKey::from_seed(&[0; 32]);
let pk = sk.public_key();
Expand All @@ -184,25 +179,17 @@ pub mod tests {
assert!(bls_cache.is_empty());

// Verify the signature and add to the cache.
assert!(
bls_cache
.aggregate_verify(pk_list.into_iter().zip(msg_list), &sig)
.0
);
assert!(bls_cache.aggregate_verify(pk_list.into_iter().zip(msg_list), &sig));
assert_eq!(bls_cache.len(), 1);

// Now that it's cached, it shouldn't cache it again.
assert!(
bls_cache
.aggregate_verify(pk_list.into_iter().zip(msg_list), &sig)
.0
);
assert!(bls_cache.aggregate_verify(pk_list.into_iter().zip(msg_list), &sig));
assert_eq!(bls_cache.len(), 1);
}

#[test]
fn test_cache() {
let mut bls_cache = BlsCache::default();
let bls_cache = BlsCache::default();

let sk1 = SecretKey::from_seed(&[0; 32]);
let pk1 = sk1.public_key();
Expand All @@ -216,11 +203,7 @@ pub mod tests {
assert!(bls_cache.is_empty());

// Add the first signature to cache.
assert!(
bls_cache
.aggregate_verify(pk_list.iter().zip(msg_list.iter()), &agg_sig)
.0
);
assert!(bls_cache.aggregate_verify(pk_list.iter().zip(msg_list.iter()), &agg_sig));
assert_eq!(bls_cache.len(), 1);

// Try with the first key message pair in the cache but not the second.
Expand All @@ -232,11 +215,7 @@ pub mod tests {
pk_list.push(pk2);
msg_list.push(msg2);

assert!(
bls_cache
.aggregate_verify(pk_list.iter().zip(msg_list.iter()), &agg_sig)
.0
);
assert!(bls_cache.aggregate_verify(pk_list.iter().zip(msg_list.iter()), &agg_sig));
assert_eq!(bls_cache.len(), 2);

// Try reusing a public key.
Expand All @@ -247,18 +226,14 @@ pub mod tests {
msg_list.push(msg3);

// Verify this signature and add to the cache as well (since it's still a different aggregate).
assert!(
bls_cache
.aggregate_verify(pk_list.iter().zip(msg_list), &agg_sig)
.0
);
assert!(bls_cache.aggregate_verify(pk_list.iter().zip(msg_list), &agg_sig));
assert_eq!(bls_cache.len(), 3);
}

#[test]
fn test_cache_limit() {
// The cache is limited to only 3 items.
let mut bls_cache = BlsCache::new(NonZeroUsize::new(3).unwrap());
let bls_cache = BlsCache::new(NonZeroUsize::new(3).unwrap());

// Before we cache anything, it should be empty.
assert!(bls_cache.is_empty());
Expand All @@ -274,15 +249,11 @@ pub mod tests {
let msg_list = [msg];

// Add to cache by validating them one at a time.
assert!(
bls_cache
.aggregate_verify(pk_list.into_iter().zip(msg_list), &sig)
.0
);
assert!(bls_cache.aggregate_verify(pk_list.into_iter().zip(msg_list), &sig));
}

// The cache should be full now.
assert_eq!(bls_cache.cache.len(), 3);
assert_eq!(bls_cache.len(), 3);

// Recreate first key.
let sk = SecretKey::from_seed(&[1; 32]);
Expand All @@ -296,20 +267,16 @@ pub mod tests {
let hash: [u8; 32] = hasher.finalize().into();

// The first key should have been removed, since it's the oldest that's been accessed.
assert!(!bls_cache.cache.contains(&hash));
assert!(!bls_cache.cache.lock().contains(&hash));
}

#[test]
fn test_empty_sig() {
let mut bls_cache = BlsCache::default();
let bls_cache = BlsCache::default();

let pks: [&PublicKey; 0] = [];
let msgs: [&[u8]; 0] = [];

assert!(
bls_cache
.aggregate_verify(pks.into_iter().zip(msgs), &Signature::default())
.0
);
assert!(bls_cache.aggregate_verify(pks.into_iter().zip(msgs), &Signature::default()));
}
}
6 changes: 3 additions & 3 deletions crates/chia-consensus/src/gen/condition_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ mod tests {
let mut a: Allocator = make_allocator(LIMIT_HEAP);
for v in 0..10000 {
let ptr = a.new_small_number(v).expect("valid u64");
assert_eq!(a.atom(ptr).as_ref(), u64_to_bytes(v as u64).as_slice())
assert_eq!(a.atom(ptr).as_ref(), u64_to_bytes(v as u64).as_slice());
}
for v in 18446744073709551615_u64 - 1000..18446744073709551615 {
for v in 18_446_744_073_709_551_615_u64 - 1000..18_446_744_073_709_551_615 {
let ptr = a.new_number(v.into()).expect("valid u64");
assert_eq!(a.atom(ptr).as_ref(), u64_to_bytes(v).as_slice())
assert_eq!(a.atom(ptr).as_ref(), u64_to_bytes(v).as_slice());
}
}
}
Loading

0 comments on commit 3438e6e

Please sign in to comment.