perf: Dory commitments should be done with Blitzar's packed_msm function (PROOF-895) (#109)

# Rationale for this change
To improve the performance of the Dory protocol, the Dory commitment
computation should use Blitzar's `packed_msm` function. Benchmarks indicate
that using `packed_msm` instead of `fixed_msm` improves proof generation
performance.

Benchmarks of the scalar packing done in this PR, run on August 15, 2024,
show speedups on three different VM configurations:
- Multi-A100 
  - 4.48x speed up of `ProofBuilder::commit_intermediate_mles`
  - 1.82x speed up of entire benchmark
- Multi-T4
  - 3.66x speed up of `ProofBuilder::commit_intermediate_mles`
  - 1.81x speed up of entire benchmark
- Single T4
  - 1.91x speed up of `ProofBuilder::commit_intermediate_mles`
  - 1.30x speed up of entire benchmark
  
# What changes are included in this PR?
- The `dory_commitment_helper_gpu` module replaces Blitzar's `fixed_msm`
function with the more efficient `packed_msm` function.
- The `pack_scalars` module is added to compute the parameters required by Blitzar's `packed_msm` function; a conceptual sketch of this packing follows the list.
- The Dory `setup` module exposes Blitzar's `packed_msm` function.
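
The packing idea, in brief: each column contributes an entry to a bit table that records its scalar width in bits, and the scalar bytes are flattened into a single buffer, so one MSM call can cover columns of different widths at once. The sketch below is only a conceptual, self-contained illustration with hypothetical names (`toy_pack`); the exact layout expected by Blitzar's `packed_msm` is produced by the `pack_scalars` module and handles details this sketch ignores (offsets, signed values, sub-commit counts).

```rust
// Conceptual sketch only -- not the actual `pack_scalars` implementation or
// Blitzar's exact byte layout. One bit-table entry per column records the
// scalar width in bits; all scalar bytes are appended to one packed buffer.
fn toy_pack(columns: &[Vec<Vec<u8>>]) -> (Vec<u32>, Vec<u8>) {
    let mut bit_table = Vec::new();
    let mut packed_scalars = Vec::new();
    for column in columns {
        // Every scalar in a column has the same byte width (cf. `ColumnType::byte_size`).
        let byte_size = column[0].len();
        bit_table.push((byte_size * 8) as u32);
        for scalar in column {
            packed_scalars.extend_from_slice(scalar);
        }
    }
    (bit_table, packed_scalars)
}

fn main() {
    // One SmallInt (i16) column and one BigInt (i64) column, two rows each,
    // as little-endian bytes.
    let smallint_col = vec![1i16.to_le_bytes().to_vec(), 2i16.to_le_bytes().to_vec()];
    let bigint_col = vec![3i64.to_le_bytes().to_vec(), 4i64.to_le_bytes().to_vec()];

    let (bit_table, packed) = toy_pack(&[smallint_col, bigint_col]);
    assert_eq!(bit_table, vec![16, 64]);
    assert_eq!(packed.len(), 2 * 2 + 2 * 8);
}
```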

# Are these changes tested?
Yes
jacobtrombetta authored Aug 28, 2024
1 parent 9060658 commit bfe2c56
Showing 7 changed files with 1,019 additions and 181 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -23,7 +23,7 @@ arrow-csv = { version = "51.0" }
 bit-iter = { version = "1.1.1" }
 bigdecimal = { version = "0.4.5", features = ["serde"] }
 blake3 = { version = "1.3.3" }
-blitzar = { version = "3.0.2" }
+blitzar = { version = "3.1.0" }
 bumpalo = { version = "3.11.0" }
 bytemuck = {version = "1.16.3", features = ["derive"]}
 byte-slice-cast = { version = "1.2.1" }
72 changes: 72 additions & 0 deletions crates/proof-of-sql/src/base/database/column.rs
@@ -348,6 +348,23 @@ impl ColumnType {
             },
         }
     }
+
+    /// Returns the byte size of the column type.
+    pub fn byte_size(&self) -> usize {
+        match self {
+            Self::Boolean => std::mem::size_of::<bool>(),
+            Self::SmallInt => std::mem::size_of::<i16>(),
+            Self::Int => std::mem::size_of::<i32>(),
+            Self::BigInt | Self::TimestampTZ(_, _) => std::mem::size_of::<i64>(),
+            Self::Int128 => std::mem::size_of::<i128>(),
+            Self::Scalar | Self::Decimal75(_, _) | Self::VarChar => std::mem::size_of::<[u64; 4]>(),
+        }
+    }
+
+    /// Returns the bit size of the column type.
+    pub fn bit_size(&self) -> u32 {
+        self.byte_size() as u32 * 8
+    }
 }
 
 /// Convert ColumnType values to some arrow DataType
@@ -920,4 +937,59 @@ mod tests {
         let new_owned_col = (&col).into();
         assert_eq!(owned_col, new_owned_col);
     }
+
+    #[test]
+    fn we_can_get_the_data_size_of_a_column() {
+        let column = Column::<DoryScalar>::Boolean(&[true, false, true]);
+        assert_eq!(column.column_type().byte_size(), 1);
+        assert_eq!(column.column_type().bit_size(), 8);
+
+        let column = Column::<Curve25519Scalar>::SmallInt(&[1, 2, 3, 4]);
+        assert_eq!(column.column_type().byte_size(), 2);
+        assert_eq!(column.column_type().bit_size(), 16);
+
+        let column = Column::<Curve25519Scalar>::Int(&[1, 2, 3]);
+        assert_eq!(column.column_type().byte_size(), 4);
+        assert_eq!(column.column_type().bit_size(), 32);
+
+        let column = Column::<Curve25519Scalar>::BigInt(&[1]);
+        assert_eq!(column.column_type().byte_size(), 8);
+        assert_eq!(column.column_type().bit_size(), 64);
+
+        let column = Column::<DoryScalar>::Int128(&[1, 2]);
+        assert_eq!(column.column_type().byte_size(), 16);
+        assert_eq!(column.column_type().bit_size(), 128);
+
+        let scals = [
+            Curve25519Scalar::from(1),
+            Curve25519Scalar::from(2),
+            Curve25519Scalar::from(3),
+        ];
+
+        let column = Column::VarChar((&["a", "b", "c", "d", "e"], &scals));
+        assert_eq!(column.column_type().byte_size(), 32);
+        assert_eq!(column.column_type().bit_size(), 256);
+
+        let column = Column::Scalar(&scals);
+        assert_eq!(column.column_type().byte_size(), 32);
+        assert_eq!(column.column_type().bit_size(), 256);
+
+        let precision = 10;
+        let scale = 2;
+        let decimal_data = [
+            Curve25519Scalar::from(1),
+            Curve25519Scalar::from(2),
+            Curve25519Scalar::from(3),
+        ];
+
+        let precision = Precision::new(precision).unwrap();
+        let column = Column::Decimal75(precision, scale, &decimal_data);
+        assert_eq!(column.column_type().byte_size(), 32);
+        assert_eq!(column.column_type().bit_size(), 256);
+
+        let column: Column<'_, DoryScalar> =
+            Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::Utc, &[1, 2, 3]);
+        assert_eq!(column.column_type().byte_size(), 8);
+        assert_eq!(column.column_type().bit_size(), 64);
+    }
 }
206 changes: 60 additions & 146 deletions crates/proof-of-sql/src/proof_primitive/dory/dory_commitment_helper_gpu.rs
@@ -1,171 +1,85 @@
-use super::{pairings, transpose, DoryCommitment, DoryProverPublicSetup, DoryScalar, G1Affine};
-use crate::{
-    base::commitment::CommittableColumn, proof_primitive::dory::offset_to_bytes::OffsetToBytes,
-};
-use ark_bls12_381::Fr;
-use ark_ec::CurveGroup;
-use ark_std::ops::Mul;
-use blitzar::{compute::ElementP2, sequence::Sequence};
+use super::{pack_scalars, pairings, DoryCommitment, DoryProverPublicSetup, G1Affine};
+use crate::base::commitment::CommittableColumn;
+use blitzar::compute::ElementP2;
 use rayon::prelude::*;
-
-#[tracing::instrument(name = "get_offset_commits (gpu)", level = "debug", skip_all)]
-fn get_offset_commits(
-    column_len: usize,
+use tracing::{span, Level};
+
+#[tracing::instrument(
+    name = "compute_dory_commitments_packed_impl (gpu)",
+    level = "debug",
+    skip_all
+)]
+fn compute_dory_commitments_packed_impl(
+    committable_columns: &[CommittableColumn],
     offset: usize,
-    num_columns: usize,
-    num_of_commits: usize,
-    scalar: Fr,
     setup: &DoryProverPublicSetup,
-) -> Vec<G1Affine> {
-    let first_row_offset = offset % num_columns;
-    let first_row_len = column_len.min(num_columns - first_row_offset);
-    let num_zero_commits = offset / num_columns;
-    let data_size = 1;
+) -> Vec<DoryCommitment> {
+    if committable_columns.is_empty() {
+        return vec![];
+    }
 
-    let ones = vec![1_u8; column_len];
-    let (first_row, remaining_elements) = ones.split_at(first_row_len);
+    let num_columns = 1 << setup.sigma();
 
-    let mut ones_blitzar_commits =
-        vec![ElementP2::<ark_bls12_381::g1::Config>::default(); num_of_commits];
+    // If the offset is larger than the number of columns, we compute an
+    // offset for the gamma_2 table to avoid finding sub-commits of zero.
+    let gamma_2_offset = offset / num_columns;
+    let offset = offset % num_columns;
 
+    // Get the number of sub-commits for each full commit
+    let num_sub_commits_per_full_commit =
+        pack_scalars::sub_commits_per_full_commit(committable_columns, offset, num_columns);
+
+    // Get the bit table and packed scalars for the packed msm
+    let (bit_table, packed_scalars) = pack_scalars::bit_table_and_scalars_for_packed_msm(
+        committable_columns,
+        offset,
+        num_columns,
+        num_sub_commits_per_full_commit,
+    );
+
-    if num_zero_commits < num_of_commits {
-        // Get the commit of the first non-zero row
-        let first_row_offset = offset - (num_zero_commits * num_columns);
-        let first_row_transpose = transpose::transpose_for_fixed_msm(
-            first_row,
-            first_row_offset,
-            1,
-            num_columns,
-            data_size,
-        );
+    let mut sub_commits_from_blitzar =
+        vec![ElementP2::<ark_bls12_381::g1::Config>::default(); bit_table.len()];
 
-        setup.prover_setup().blitzar_msm(
-            &mut ones_blitzar_commits[num_zero_commits..num_zero_commits + 1],
-            data_size as u32,
-            first_row_transpose.as_slice(),
+    // Compute packed msm
+    if !bit_table.is_empty() {
+        setup.prover_setup().blitzar_packed_msm(
+            &mut sub_commits_from_blitzar,
+            &bit_table,
+            packed_scalars.as_slice(),
         );
-
-        // If there are more rows, get the commits of the middle row and duplicate them
-        let mut chunks = remaining_elements.chunks(num_columns);
-        if chunks.len() > 1 {
-            if let Some(middle_row) = chunks.next() {
-                let middle_row_transpose =
-                    transpose::transpose_for_fixed_msm(middle_row, 0, 1, num_columns, data_size);
-                let mut middle_row_blitzar_commit =
-                    vec![ElementP2::<ark_bls12_381::g1::Config>::default(); 1];
-
-                setup.prover_setup().blitzar_msm(
-                    &mut middle_row_blitzar_commit,
-                    data_size as u32,
-                    middle_row_transpose.as_slice(),
-                );
-
-                ones_blitzar_commits[num_zero_commits + 1..num_of_commits - 1]
-                    .par_iter_mut()
-                    .for_each(|commit| *commit = middle_row_blitzar_commit[0].clone());
-            }
-        }
-
-        // Get the commit of the last row to handle an zero padding at the end of the column
-        if let Some(last_row) = remaining_elements.chunks(num_columns).last() {
-            let last_row_transpose =
-                transpose::transpose_for_fixed_msm(last_row, 0, 1, num_columns, data_size);
-
-            setup.prover_setup().blitzar_msm(
-                &mut ones_blitzar_commits[num_of_commits - 1..num_of_commits],
-                data_size as u32,
-                last_row_transpose.as_slice(),
-            );
-        }
     }
 
-    ones_blitzar_commits
+    // Convert the sub-commits to G1Affine
+    let all_sub_commits: Vec<G1Affine> = sub_commits_from_blitzar
         .par_iter()
         .map(Into::into)
-        .map(|commit: G1Affine| commit.mul(scalar).into_affine())
-        .collect()
-}
+        .collect();
 
-#[tracing::instrument(name = "compute_dory_commitment_impl (gpu)", level = "debug", skip_all)]
-fn compute_dory_commitment_impl<'a, T>(
-    column: &'a [T],
-    offset: usize,
-    setup: &DoryProverPublicSetup,
-) -> DoryCommitment
-where
-    &'a T: Into<DoryScalar>,
-    &'a [T]: Into<Sequence<'a>>,
-    T: OffsetToBytes,
-{
-    let num_columns = 1 << setup.sigma();
-    let data_size = std::mem::size_of::<T>();
-
-    // Format column to match column major data layout required by blitzar's msm
-    let num_of_commits = ((column.len() + offset) + num_columns - 1) / num_columns;
-    let column_transpose =
-        transpose::transpose_for_fixed_msm(column, offset, num_of_commits, num_columns, data_size);
-    let gamma_2_slice = &setup.prover_setup().Gamma_2.last().unwrap()[0..num_of_commits];
-
-    // Compute the commitment for the entire data set
-    let mut blitzar_commits =
-        vec![ElementP2::<ark_bls12_381::g1::Config>::default(); num_of_commits];
-    setup.prover_setup().blitzar_msm(
-        &mut blitzar_commits,
-        data_size as u32,
-        column_transpose.as_slice(),
+    // Modify the signed sub-commits by adding the offset
+    let modified_sub_commits = pack_scalars::modify_commits(
+        &all_sub_commits,
+        committable_columns,
+        num_sub_commits_per_full_commit,
     );
 
-    let commits: Vec<G1Affine> = blitzar_commits.par_iter().map(Into::into).collect();
-
-    // Signed data requires offset commitments
-    if T::IS_SIGNED {
-        let offset_commits = get_offset_commits(
-            column.len(),
-            offset,
-            num_columns,
-            num_of_commits,
-            T::min_as_fr(),
-            setup,
-        );
+    let gamma_2_slice = &setup.prover_setup().Gamma_2.last().unwrap()
+        [gamma_2_offset..gamma_2_offset + num_sub_commits_per_full_commit];
 
-        DoryCommitment(
-            pairings::multi_pairing(commits, gamma_2_slice)
-                + pairings::multi_pairing(offset_commits, gamma_2_slice),
-        )
-    } else {
-        DoryCommitment(pairings::multi_pairing(commits, gamma_2_slice))
-    }
-}
+    // Compute the Dory commitments using multi pairing of sub-commits
+    let span = span!(Level::INFO, "multi_pairing").entered();
+    let dc = modified_sub_commits
+        .par_chunks_exact(num_sub_commits_per_full_commit)
+        .map(|sub_commits| DoryCommitment(pairings::multi_pairing(sub_commits, gamma_2_slice)))
+        .collect();
+    span.exit();
 
-fn compute_dory_commitment(
-    committable_column: &CommittableColumn,
-    offset: usize,
-    setup: &DoryProverPublicSetup,
-) -> DoryCommitment {
-    match committable_column {
-        CommittableColumn::SmallInt(column) => compute_dory_commitment_impl(column, offset, setup),
-        CommittableColumn::Int(column) => compute_dory_commitment_impl(column, offset, setup),
-        CommittableColumn::BigInt(column) => compute_dory_commitment_impl(column, offset, setup),
-        CommittableColumn::Int128(column) => compute_dory_commitment_impl(column, offset, setup),
-        CommittableColumn::Decimal75(_, _, column) => {
-            compute_dory_commitment_impl(column, offset, setup)
-        }
-        CommittableColumn::Scalar(column) => compute_dory_commitment_impl(column, offset, setup),
-        CommittableColumn::VarChar(column) => compute_dory_commitment_impl(column, offset, setup),
-        CommittableColumn::Boolean(column) => compute_dory_commitment_impl(column, offset, setup),
-        CommittableColumn::TimestampTZ(_, _, column) => {
-            compute_dory_commitment_impl(column, offset, setup)
-        }
-    }
+    dc
 }
 
 pub(super) fn compute_dory_commitments(
     committable_columns: &[CommittableColumn],
     offset: usize,
     setup: &DoryProverPublicSetup,
 ) -> Vec<DoryCommitment> {
-    committable_columns
-        .iter()
-        .map(|column| compute_dory_commitment(column, offset, setup))
-        .collect()
+    compute_dory_commitments_packed_impl(committable_columns, offset, setup)
 }
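
For intuition on the offset handling in the new `compute_dory_commitments_packed_impl` above: with `sigma = 2` there are `1 << 2 = 4` columns per row, so an element offset of 9 skips `9 / 4 = 2` whole rows (the `gamma_2_offset`, which avoids committing rows of zeros) and the data starts `9 % 4 = 1` column into the next row; full commitments are then multi-pairings over fixed-size chunks of sub-commits. The snippet below is a standalone illustration of that arithmetic, not code from this PR.

```rust
// Standalone illustration of the offset arithmetic in the new GPU helper
// (hypothetical helper name; not part of the PR).
fn split_offset(offset: usize, sigma: usize) -> (usize, usize) {
    let num_columns = 1usize << sigma;
    // Whole leading rows of zeros are skipped via the Gamma_2 offset instead
    // of being committed; the remainder is the offset within the first row.
    (offset / num_columns, offset % num_columns)
}

fn main() {
    let (gamma_2_offset, offset_in_row) = split_offset(9, 2);
    assert_eq!(gamma_2_offset, 2); // start two generators into the Gamma_2 table
    assert_eq!(offset_in_row, 1); // first data element sits one column into its row

    // Each full commitment is a multi-pairing over a fixed-size chunk of
    // sub-commits, e.g. 12 sub-commits at 3 per full commit give 4 commitments.
    let num_sub_commits = 12;
    let num_sub_commits_per_full_commit = 3;
    assert_eq!(num_sub_commits / num_sub_commits_per_full_commit, 4);
}
```
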
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/proof_primitive/dory/mod.rs
@@ -139,6 +139,7 @@ type DeferredG1 = deferred_msm::DeferredMSM<G1Affine, F>;
 type DeferredG2 = deferred_msm::DeferredMSM<G2Affine, F>;
 
 mod offset_to_bytes;
+mod pack_scalars;
 mod pairings;
 mod transpose;
 