Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize BDN Signature/Key Aggregation #546

Merged
merged 9 commits into from
Sep 24, 2024
77 changes: 38 additions & 39 deletions sign/bdn/bdn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package bdn
import (
"crypto/cipher"
"errors"
"fmt"
"math/big"

"go.dedis.ch/kyber/v4"
Expand All @@ -31,23 +32,16 @@ var modulus128 = new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 128), big.NewI
// We also use the entire roster so that the coefficient will vary for the same
// public key used in different roster
func hashPointToR(pubs []kyber.Point) ([]kyber.Scalar, error) {
peers := make([][]byte, len(pubs))
for i, pub := range pubs {
peer, err := pub.MarshalBinary()
if err != nil {
return nil, err
}

peers[i] = peer
}

h, err := blake2s.NewXOF(blake2s.OutputLengthUnknown, nil)
if err != nil {
return nil, err
}

for _, peer := range peers {
_, err := h.Write(peer)
for _, pub := range pubs {
peer, err := pub.MarshalBinary()
if err != nil {
return nil, err
}
_, err = h.Write(peer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -128,62 +122,67 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error {

// AggregateSignatures aggregates the signatures using a coefficient for each
// one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128}
func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) {
if len(sigs) != mask.CountEnabled() {
return nil, errors.New("length of signatures and public keys must match")
}

coefs, err := hashPointToR(mask.Publics())
func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask Mask) (kyber.Point, error) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this is technically a breaking change if someone is abstracting over Scheme via an interface. I'm happy to explore alternatives (e.g., add a new function, etc.) if that's an issue.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are in the process of preparing a V4 release anyway 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case.... I can drop the interface and switch to a BDN specific mask if that works better for you (not sure how much breaking you want).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that would be totally fine by me. Actually Mask seems to only be used in bdn anyway, I'd be fine with it also being in internal or so, and have BDN expose the methods required to extract the list of participants from an aggregate signature.

Summoning @pierluca and @K1li4nL in case they have other opinions

bdnMask, err := newCachedMask(mask, false)
if err != nil {
return nil, err
}

agg := scheme.sigGroup.Point()
for i, buf := range sigs {
peerIndex := mask.IndexOfNthEnabled(i)
if peerIndex < 0 {
// this should never happen as we check the lenths at the beginning
// an error here is probably a bug in the mask
return nil, errors.New("couldn't find the index")
for i := range bdnMask.publics {
if enabled, err := mask.GetBit(i); err != nil {
// this should never happen because of the loop boundary
// an error here is probably a bug in the mask implementation
return nil, fmt.Errorf("couldn't find the index %d: %w", i, err)
} else if !enabled {
continue
}

if len(sigs) == 0 {
return nil, errors.New("length of signatures and public keys must match")
}

buf := sigs[0]
sigs = sigs[1:]

sig := scheme.sigGroup.Point()
err = sig.UnmarshalBinary(buf)
if err != nil {
return nil, err
}

sigC := sig.Clone().Mul(coefs[peerIndex], sig)
sigC := sig.Clone().Mul(bdnMask.coefs[i], sig)
// c+1 because R is in the range [1, 2^128] and not [0, 2^128-1]
sigC = sigC.Add(sigC, sig)
agg = agg.Add(agg, sigC)
}

if len(sigs) > 0 {
return nil, errors.New("length of signatures and public keys must match")
}

return agg, nil
}

// AggregatePublicKeys aggregates a set of public keys (similarly to
// AggregateSignatures for signatures) using the hash function
// H: keyGroup -> R with R = {1, ..., 2^128}.
func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) {
coefs, err := hashPointToR(mask.Publics())
func (scheme *Scheme) AggregatePublicKeys(mask Mask) (kyber.Point, error) {
bdnMask, err := newCachedMask(mask, false)
if err != nil {
return nil, err
}

agg := scheme.keyGroup.Point()
for i := 0; i < mask.CountEnabled(); i++ {
peerIndex := mask.IndexOfNthEnabled(i)
if peerIndex < 0 {
for i := range bdnMask.publics {
if enabled, err := mask.GetBit(i); err != nil {
// this should never happen because of the loop boundary
// an error here is probably a bug in the mask implementation
return nil, errors.New("couldn't find the index")
return nil, fmt.Errorf("couldn't find the index %d: %w", i, err)
} else if !enabled {
continue
}

pub := mask.Publics()[peerIndex]
pubC := pub.Clone().Mul(coefs[peerIndex], pub)
pubC = pubC.Add(pubC, pub)
agg = agg.Add(agg, pubC)
agg = agg.Add(agg, bdnMask.getOrComputePubC(i))
}

return agg, nil
Expand Down Expand Up @@ -217,14 +216,14 @@ func Verify(suite pairing.Suite, x kyber.Point, msg, sig []byte) error {
// AggregateSignatures aggregates the signatures using a coefficient for each
// one of them where c = H(pk) and H: G2 -> R with R = {1, ..., 2^128}
// Deprecated: use the new scheme methods instead.
func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (kyber.Point, error) {
func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask Mask) (kyber.Point, error) {
return NewSchemeOnG1(suite).AggregateSignatures(sigs, mask)
}

// AggregatePublicKeys aggregates a set of public keys (similarly to
// AggregateSignatures for signatures) using the hash function
// H: G2 -> R with R = {1, ..., 2^128}.
// Deprecated: use the new scheme methods instead.
func AggregatePublicKeys(suite pairing.Suite, mask *sign.Mask) (kyber.Point, error) {
func AggregatePublicKeys(suite pairing.Suite, mask Mask) (kyber.Point, error) {
return NewSchemeOnG1(suite).AggregatePublicKeys(mask)
}
101 changes: 101 additions & 0 deletions sign/bdn/bdn_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package bdn

import (
"encoding"
"encoding/hex"
"fmt"
"testing"

"github.com/stretchr/testify/require"
"go.dedis.ch/kyber/v4"
"go.dedis.ch/kyber/v4/pairing/bls12381/kilic"
"go.dedis.ch/kyber/v4/pairing/bn256"
"go.dedis.ch/kyber/v4/sign"
"go.dedis.ch/kyber/v4/sign/bls"
Expand Down Expand Up @@ -158,3 +161,101 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) {
AggregateSignatures(suite, [][]byte{sig1, sig2}, mask)
}
}

func Benchmark_BDN_BLS12381_AggregateVerify(b *testing.B) {
suite := kilic.NewBLS12381Suite()
schemeOnG2 := NewSchemeOnG2(suite)

rng := random.New()
pubKeys := make([]kyber.Point, 3000)
privKeys := make([]kyber.Scalar, 3000)
for i := range pubKeys {
privKeys[i], pubKeys[i] = schemeOnG2.NewKeyPair(rng)
}

baseMask, err := sign.NewMask(pubKeys, nil)
require.NoError(b, err)
mask, err := NewCachedMask(baseMask)
require.NoError(b, err)
for i := range pubKeys {
require.NoError(b, mask.SetBit(i, true))
}

msg := []byte("Hello many times Boneh-Lynn-Shacham")
sigs := make([][]byte, len(privKeys))
for i, k := range privKeys {
s, err := schemeOnG2.Sign(k, msg)
require.NoError(b, err)
sigs[i] = s
}

sig, err := schemeOnG2.AggregateSignatures(sigs, mask)
require.NoError(b, err)
sigb, err := sig.MarshalBinary()
require.NoError(b, err)

b.ResetTimer()
for i := 0; i < b.N; i++ {
pk, err := schemeOnG2.AggregatePublicKeys(mask)
require.NoError(b, err)
require.NoError(b, schemeOnG2.Verify(pk, msg, sigb))
}
}

func unmarshalHex[T encoding.BinaryUnmarshaler](t *testing.T, into T, s string) T {
t.Helper()
b, err := hex.DecodeString(s)
require.NoError(t, err)
require.NoError(t, into.UnmarshalBinary(b))
return into
}

// This tests exists to make sure we don't accidentally make breaking changes to signature
// aggregation by using checking against known aggregated signatures and keys.
func TestBDNFixtures(t *testing.T) {
suite := bn256.NewSuite()
schemeOnG1 := NewSchemeOnG1(suite)

public1 := unmarshalHex(t, suite.G2().Point(), "1a30714035c7a161e286e54c191b8c68345bd8239c74925a26290e8e1ae97ed6657958a17dca12c943fadceb11b824402389ff427179e0f10194da3c1b771c6083797d2b5915ea78123cbdb99ea6389d6d6b67dcb512a2b552c373094ee5693524e3ebb4a176f7efa7285c25c80081d8cb598745978f1a63b886c09a316b1493")
private1 := unmarshalHex(t, suite.G2().Scalar(), "49cfe5e9f4532670137184d43c0299f8b635bcacf6b0af7cab262494602d9f38")
public2 := unmarshalHex(t, suite.G2().Point(), "603bc61466ec8762ec6de2ba9a80b9d302d08f580d1685ac45a8e404a6ed549719dc0faf94d896a9983ff23423772720e3de5d800bc200de6f7d7e146162d3183b8880c5c0d8b71ca4b3b40f30c12d8cc0679c81a47c239c6aa7e9cc2edab4a927fe865cd413c1c17e3df8f74108e784cd77dd3e161bdaf30019a55826a32a1f")
private2 := unmarshalHex(t, suite.G2().Scalar(), "493abea4bb35b74c78ad9245f9d37883aeb6ee91f7fb0d8a8e11abf7aa2be581")
public3 := unmarshalHex(t, suite.G2().Point(), "56118769a1f0b6286abacaa32109c1497ab0819c5d21f27317e184b6681c283007aa981cb4760de044946febdd6503ab77a4586bc29c04159e53a6fa5dcb9c0261ccd1cb2e28db5204ca829ac9f6be95f957a626544adc34ba3bc542533b6e2f5cbd0567e343641a61a42b63f26c3625f74b66f6f46d17b3bf1688fae4d455ec")
private3 := unmarshalHex(t, suite.G2().Scalar(), "7fb0ebc317e161502208c3c16a4af890dedc3c7b275e8a04e99c0528aa6a19aa")

sig1Exp, err := hex.DecodeString("0913b76987be19f943be23b636cab9a2484507717326bd8bbdcdbbb6b8d5eb9253cfb3597c3fa550ee4972a398813650825a871f8e0b242ae5ddbce1b7c0e2a8")
require.NoError(t, err)
sig2Exp, err := hex.DecodeString("21195d29b1863bca1559e24375211d1411d8a28a8f4c772870b07f4ccda2fd5e337c1315c210475c683e3aa8b87d3aed3f7255b3087daa30d1e1432dd61d7484")
require.NoError(t, err)
sig3Exp, err := hex.DecodeString("3c1ac80345c1733630dbdc8106925c867544b521c259f9fa9678d477e6e5d3d212b09bc0d95137c3dbc0af2241415156c56e757d5577a609293584d045593195")
require.NoError(t, err)

aggSigExp := unmarshalHex(t, suite.G1().Point(), "43c1d2ad5a7d71a08f3cd7495db6b3c81a4547af1b76438b2f215e85ec178fea048f93f6ffed65a69ea757b47761e7178103bb347fd79689652e55b6e0054af2")
aggKeyExp := unmarshalHex(t, suite.G2().Point(), "43b5161ede207b9a69fc93114b0c5022b76cc22e813ba739c7e622d826b132333cd637505399963b94e393ec7f5d4875f82391620b34be1fde1f232204fa4f723935d4dbfb725f059456bcf2557f846c03190969f7b800e904d25b0b5bcbdd421c9877d443f0313c3425dfc1e7e646b665d27b9e649faadef1129f95670d70e1")

msg := []byte("Hello many times Boneh-Lynn-Shacham")
sig1, err := schemeOnG1.Sign(private1, msg)
require.Nil(t, err)
require.Equal(t, sig1Exp, sig1)

sig2, err := schemeOnG1.Sign(private2, msg)
require.Nil(t, err)
require.Equal(t, sig2Exp, sig2)

sig3, err := schemeOnG1.Sign(private3, msg)
require.Nil(t, err)
require.Equal(t, sig3Exp, sig3)

mask, _ := sign.NewMask([]kyber.Point{public1, public2, public3}, nil)
mask.SetBit(0, true)
mask.SetBit(1, false)
mask.SetBit(2, true)

aggSig, err := schemeOnG1.AggregateSignatures([][]byte{sig1, sig3}, mask)
require.NoError(t, err)
require.True(t, aggSigExp.Equal(aggSig))

aggKey, err := schemeOnG1.AggregatePublicKeys(mask)
require.NoError(t, err)
require.True(t, aggKeyExp.Equal(aggKey))
}
112 changes: 112 additions & 0 deletions sign/bdn/mask.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package bdn

import (
"fmt"

"go.dedis.ch/kyber/v4"
"go.dedis.ch/kyber/v4/sign"
)

type Mask interface {

Check failure on line 10 in sign/bdn/mask.go

View workflow job for this annotation

GitHub Actions / lint

the interface has more than 10 methods: 12 (interfacebloat)
GetBit(i int) (bool, error)
SetBit(i int, enable bool) error

IndexOfNthEnabled(nth int) int
NthEnabledAtIndex(idx int) int

Publics() []kyber.Point
Participants() []kyber.Point

CountEnabled() int
CountTotal() int

Len() int
Mask() []byte
SetMask(mask []byte) error
Merge(mask []byte) error
Stebalien marked this conversation as resolved.
Show resolved Hide resolved
}

var _ Mask = (*sign.Mask)(nil)

// We need to rename this, otherwise we have a public field named Mask (when we embed it) which
// conflicts with the function named Mask. It also makes it private, which is nice.
type maskI = Mask

type CachedMask struct {
maskI
coefs []kyber.Scalar
pubKeyC []kyber.Point
// We could call Mask.Publics() instead of keeping these here, but that function copies the
// slice and this field lets us avoid that copy.
publics []kyber.Point
}

// Convert the passed mask (likely a *sign.Mask) into a BDN-specific mask with pre-computed terms.
//
// This cached mask will:
//
// 1. Pre-compute coefficients for signature aggregation. Once the CachedMask has been instantiated,
// distinct sets of signatures can be aggregated without any BLAKE2S hashing.
// 2. Pre-computes the terms for public key aggregation. Once the CachedMask has been instantiated,
// distinct sets of public keys can be aggregated by simply summing the cached terms, ~2 orders
// of magnitude faster than aggregating from scratch.
func NewCachedMask(mask Mask) (*CachedMask, error) {
return newCachedMask(mask, true)
}

func newCachedMask(mask Mask, precomputePubC bool) (*CachedMask, error) {
if m, ok := mask.(*CachedMask); ok {
return m, nil
}

publics := mask.Publics()
coefs, err := hashPointToR(publics)
if err != nil {
return nil, fmt.Errorf("failed to hash public keys: %w", err)
}

cm := &CachedMask{
maskI: mask,
coefs: coefs,
publics: publics,
}

if precomputePubC {
pubKeyC := make([]kyber.Point, len(publics))
for i := range publics {
pubKeyC[i] = cm.getOrComputePubC(i)
}
cm.pubKeyC = pubKeyC
}

return cm, err
}

// Clone copies the BDN mask while keeping the precomputed coefficients, etc.
func (cm *CachedMask) Clone() *CachedMask {
newMask, err := sign.NewMask(cm.publics, nil)
if err != nil {
// Not possible given that we didn't pass our own key.
panic(fmt.Sprintf("failed to create mask: %s", err))
}
if err := newMask.SetMask(cm.Mask()); err != nil {
// Not possible given that we're using the same sized mask.
panic(fmt.Sprintf("failed to create mask: %s", err))
}
return &CachedMask{
maskI: newMask,
coefs: cm.coefs,
pubKeyC: cm.pubKeyC,
publics: cm.publics,
}
}

func (cm *CachedMask) getOrComputePubC(i int) kyber.Point {
if cm.pubKeyC == nil {
// NOTE: don't cache here as we may be sharing this mask between threads.
pub := cm.publics[i]
pubC := pub.Clone().Mul(cm.coefs[i], pub)
return pubC.Add(pubC, pub)
}
return cm.pubKeyC[i]
}
11 changes: 11 additions & 0 deletions sign/mask.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ func (m *Mask) SetMask(mask []byte) error {
return nil
}

// GetBit returns true if the given bit is set.
func (m *Mask) GetBit(i int) (bool, error) {
if i >= len(m.publics) || i < 0 {
return false, errors.New("index out of range")
}

byteIndex := i / 8
mask := byte(1) << uint(i&7)
return m.mask[byteIndex]&mask != 0, nil
}

// SetBit turns on or off the bit at the given index.
func (m *Mask) SetBit(i int, enable bool) error {
if i >= len(m.publics) || i < 0 {
Expand Down
Loading
Loading