diff --git a/sign/bdn/bdn.go b/sign/bdn/bdn.go index 4b1ab1b9c..7c9b9187f 100644 --- a/sign/bdn/bdn.go +++ b/sign/bdn/bdn.go @@ -12,6 +12,7 @@ package bdn import ( "crypto/cipher" "errors" + "fmt" "math/big" "github.com/drand/kyber" @@ -31,23 +32,16 @@ var modulus128 = new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 128), big.NewI // We also use the entire roster so that the coefficient will vary for the same // public key used in different roster func hashPointToR(pubs []kyber.Point) ([]kyber.Scalar, error) { - peers := make([][]byte, len(pubs)) - for i, pub := range pubs { - peer, err := pub.MarshalBinary() - if err != nil { - return nil, err - } - - peers[i] = peer - } - h, err := blake2s.NewXOF(blake2s.OutputLengthUnknown, nil) if err != nil { return nil, err } - - for _, peer := range peers { - _, err := h.Write(peer) + for _, pub := range pubs { + peer, err := pub.MarshalBinary() + if err != nil { + return nil, err + } + _, err = h.Write(peer) if err != nil { return nil, err } @@ -128,32 +122,35 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128} -func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) { - if len(sigs) != mask.CountEnabled() { - return nil, errors.New("length of signatures and public keys must match") - } - - coefs, err := hashPointToR(mask.Publics()) +func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask Mask) (kyber.Point, error) { + bdnMask, err := newCachedMask(mask, false) if err != nil { return nil, err } - agg := scheme.sigGroup.Point() - for i, buf := range sigs { - peerIndex := mask.IndexOfNthEnabled(i) - if peerIndex < 0 { - // this should never happen as we check the lenths at the beginning - // an error here is probably a bug in the mask - return nil, errors.New("couldn't find the index") + for i := range bdnMask.publics { + if enabled, err := mask.GetBit(i); err != nil { + // this should never happen because of the loop boundary + // an error here is probably a bug in the mask implementation + return nil, fmt.Errorf("couldn't find the index %d: %w", i, err) + } else if !enabled { + continue } + if len(sigs) == 0 { + return nil, errors.New("length of signatures and public keys must match") + } + + buf := sigs[0] + sigs = sigs[1:] + sig := scheme.sigGroup.Point() err = sig.UnmarshalBinary(buf) if err != nil { return nil, err } - sigC := sig.Clone().Mul(coefs[peerIndex], sig) + sigC := sig.Clone().Mul(bdnMask.coefs[i], sig) // c+1 because R is in the range [1, 2^128] and not [0, 2^128-1] sigC = sigC.Add(sigC, sig) agg = agg.Add(agg, sigC) @@ -165,25 +162,23 @@ func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber // AggregatePublicKeys aggregates a set of public keys (similarly to // AggregateSignatures for signatures) using the hash function // H: keyGroup -> R with R = {1, ..., 2^128}. -func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) { - coefs, err := hashPointToR(mask.Publics()) +func (scheme *Scheme) AggregatePublicKeys(mask Mask) (kyber.Point, error) { + bdnMask, err := newCachedMask(mask, false) if err != nil { return nil, err } agg := scheme.keyGroup.Point() - for i := 0; i < mask.CountEnabled(); i++ { - peerIndex := mask.IndexOfNthEnabled(i) - if peerIndex < 0 { + for i := range bdnMask.publics { + if enabled, err := mask.GetBit(i); err != nil { // this should never happen because of the loop boundary // an error here is probably a bug in the mask implementation - return nil, errors.New("couldn't find the index") + return nil, fmt.Errorf("couldn't find the index %d: %w", i, err) + } else if !enabled { + continue } - pub := mask.Publics()[peerIndex] - pubC := pub.Clone().Mul(coefs[peerIndex], pub) - pubC = pubC.Add(pubC, pub) - agg = agg.Add(agg, pubC) + agg = agg.Add(agg, bdnMask.getOrComputePubC(i)) } return agg, nil @@ -217,7 +212,7 @@ func Verify(suite pairing.Suite, x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: G2 -> R with R = {1, ..., 2^128} // Deprecated: use the new scheme methods instead. -func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (kyber.Point, error) { +func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask Mask) (kyber.Point, error) { return NewSchemeOnG1(suite).AggregateSignatures(sigs, mask) } @@ -225,6 +220,6 @@ func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (k // AggregateSignatures for signatures) using the hash function // H: G2 -> R with R = {1, ..., 2^128}. // Deprecated: use the new scheme methods instead. -func AggregatePublicKeys(suite pairing.Suite, mask *sign.Mask) (kyber.Point, error) { +func AggregatePublicKeys(suite pairing.Suite, mask Mask) (kyber.Point, error) { return NewSchemeOnG1(suite).AggregatePublicKeys(mask) } diff --git a/sign/bdn/bdn_test.go b/sign/bdn/bdn_test.go index db80d4706..a791d5979 100644 --- a/sign/bdn/bdn_test.go +++ b/sign/bdn/bdn_test.go @@ -1,6 +1,8 @@ package bdn import ( + "encoding" + "encoding/hex" "fmt" "testing" @@ -183,6 +185,104 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) { } } +func Benchmark_BDN_BLS12381_AggregateVerify(b *testing.B) { + suite := bls12381.NewBLS12381Suite() + schemeOnG2 := NewSchemeOnG2(suite) + + rng := random.New() + pubKeys := make([]kyber.Point, 3000) + privKeys := make([]kyber.Scalar, 3000) + for i := range pubKeys { + privKeys[i], pubKeys[i] = schemeOnG2.NewKeyPair(rng) + } + + baseMask, err := sign.NewMask(suite, pubKeys, nil) + require.NoError(b, err) + mask, err := NewCachedMask(baseMask) + require.NoError(b, err) + for i := range pubKeys { + require.NoError(b, mask.SetBit(i, true)) + } + + msg := []byte("Hello many times Boneh-Lynn-Shacham") + sigs := make([][]byte, len(privKeys)) + for i, k := range privKeys { + s, err := schemeOnG2.Sign(k, msg) + require.NoError(b, err) + sigs[i] = s + } + + sig, err := schemeOnG2.AggregateSignatures(sigs, mask) + require.NoError(b, err) + sigb, err := sig.MarshalBinary() + require.NoError(b, err) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pk, err := schemeOnG2.AggregatePublicKeys(mask) + require.NoError(b, err) + require.NoError(b, schemeOnG2.Verify(pk, msg, sigb)) + } +} + +func unmarshalHex[T encoding.BinaryUnmarshaler](t *testing.T, into T, s string) T { + t.Helper() + b, err := hex.DecodeString(s) + require.NoError(t, err) + require.NoError(t, into.UnmarshalBinary(b)) + return into +} + +// This tests exists to make sure we don't accidentally make breaking changes to signature +// aggregation by using checking against known aggregated signatures and keys. +func TestBDNFixtures(t *testing.T) { + suite := bn256.NewSuite() + schemeOnG1 := NewSchemeOnG1(suite) + + public1 := unmarshalHex(t, suite.G2().Point(), "1a30714035c7a161e286e54c191b8c68345bd8239c74925a26290e8e1ae97ed6657958a17dca12c943fadceb11b824402389ff427179e0f10194da3c1b771c6083797d2b5915ea78123cbdb99ea6389d6d6b67dcb512a2b552c373094ee5693524e3ebb4a176f7efa7285c25c80081d8cb598745978f1a63b886c09a316b1493") + private1 := unmarshalHex(t, suite.G2().Scalar(), "49cfe5e9f4532670137184d43c0299f8b635bcacf6b0af7cab262494602d9f38") + public2 := unmarshalHex(t, suite.G2().Point(), "603bc61466ec8762ec6de2ba9a80b9d302d08f580d1685ac45a8e404a6ed549719dc0faf94d896a9983ff23423772720e3de5d800bc200de6f7d7e146162d3183b8880c5c0d8b71ca4b3b40f30c12d8cc0679c81a47c239c6aa7e9cc2edab4a927fe865cd413c1c17e3df8f74108e784cd77dd3e161bdaf30019a55826a32a1f") + private2 := unmarshalHex(t, suite.G2().Scalar(), "493abea4bb35b74c78ad9245f9d37883aeb6ee91f7fb0d8a8e11abf7aa2be581") + public3 := unmarshalHex(t, suite.G2().Point(), "56118769a1f0b6286abacaa32109c1497ab0819c5d21f27317e184b6681c283007aa981cb4760de044946febdd6503ab77a4586bc29c04159e53a6fa5dcb9c0261ccd1cb2e28db5204ca829ac9f6be95f957a626544adc34ba3bc542533b6e2f5cbd0567e343641a61a42b63f26c3625f74b66f6f46d17b3bf1688fae4d455ec") + private3 := unmarshalHex(t, suite.G2().Scalar(), "7fb0ebc317e161502208c3c16a4af890dedc3c7b275e8a04e99c0528aa6a19aa") + + sig1Exp, err := hex.DecodeString("0913b76987be19f943be23b636cab9a2484507717326bd8bbdcdbbb6b8d5eb9253cfb3597c3fa550ee4972a398813650825a871f8e0b242ae5ddbce1b7c0e2a8") + require.NoError(t, err) + sig2Exp, err := hex.DecodeString("21195d29b1863bca1559e24375211d1411d8a28a8f4c772870b07f4ccda2fd5e337c1315c210475c683e3aa8b87d3aed3f7255b3087daa30d1e1432dd61d7484") + require.NoError(t, err) + sig3Exp, err := hex.DecodeString("3c1ac80345c1733630dbdc8106925c867544b521c259f9fa9678d477e6e5d3d212b09bc0d95137c3dbc0af2241415156c56e757d5577a609293584d045593195") + require.NoError(t, err) + + aggSigExp := unmarshalHex(t, suite.G1().Point(), "520875e6667e0acf489e458c6c2233d09af81afa3b2045e0ec2435cfc582ba2c44af281d688efcf991d20975ce32c9933a09f8c4b38c18ef4b4510d8fa0f09d7") + aggKeyExp := unmarshalHex(t, suite.G2().Point(), "394d47291878a81fefb17708c57cf8078b24c46bf4554b3012732acd15395dbf09f13a65e068de766f5449d1de130f09bf09dc35a67f7f822f2a187230e155891d40db3c51afa5b3e05a039c50d04ff9c788718a2887e34644a55a14a2a2679226a3315c281e03367a4d797db819625e0c662d35e45e0e9e7604c104179ae8a7") + + msg := []byte("Hello many times Boneh-Lynn-Shacham") + sig1, err := schemeOnG1.Sign(private1, msg) + require.Nil(t, err) + require.Equal(t, sig1Exp, sig1) + + sig2, err := schemeOnG1.Sign(private2, msg) + require.Nil(t, err) + require.Equal(t, sig2Exp, sig2) + + sig3, err := schemeOnG1.Sign(private3, msg) + require.Nil(t, err) + require.Equal(t, sig3Exp, sig3) + + mask, _ := sign.NewMask(suite, []kyber.Point{public1, public2, public3}, nil) + mask.SetBit(0, true) + mask.SetBit(1, false) + mask.SetBit(2, true) + + aggSig, err := schemeOnG1.AggregateSignatures([][]byte{sig1, sig2, sig3}, mask) + require.NoError(t, err) + require.True(t, aggSigExp.Equal(aggSig)) + + aggKey, err := schemeOnG1.AggregatePublicKeys(mask) + require.NoError(t, err) + require.True(t, aggKeyExp.Equal(aggKey)) +} + func TestBDNDeprecatedAPIs(t *testing.T) { msg := []byte("Hello Boneh-Lynn-Shacham") suite := bn256.NewSuite() diff --git a/sign/bdn/mask.go b/sign/bdn/mask.go new file mode 100644 index 000000000..26c26f8e9 --- /dev/null +++ b/sign/bdn/mask.go @@ -0,0 +1,112 @@ +package bdn + +import ( + "fmt" + + "github.com/drand/kyber" + "github.com/drand/kyber/sign" +) + +type Mask interface { + GetBit(i int) (bool, error) + SetBit(i int, enable bool) error + + IndexOfNthEnabled(nth int) int + NthEnabledAtIndex(idx int) int + + Publics() []kyber.Point + Participants() []kyber.Point + + CountEnabled() int + CountTotal() int + + Len() int + Mask() []byte + SetMask(mask []byte) error + Merge(mask []byte) error +} + +var _ Mask = (*sign.Mask)(nil) + +// We need to rename this, otherwise we have a public field named Mask (when we embed it) which +// conflicts with the function named Mask. It also makes it private, which is nice. +type maskI = Mask + +type CachedMask struct { + maskI + coefs []kyber.Scalar + pubKeyC []kyber.Point + // We could call Mask.Publics() instead of keeping these here, but that function copies the + // slice and this field lets us avoid that copy. + publics []kyber.Point +} + +// Convert the passed mask (likely a *sign.Mask) into a BDN-specific mask with pre-computed terms. +// +// This cached mask will: +// +// 1. Pre-compute coefficients for signature aggregation. Once the CachedMask has been instantiated, +// distinct sets of signatures can be aggregated without any BLAKE2S hashing. +// 2. Pre-computes the terms for public key aggregation. Once the CachedMask has been instantiated, +// distinct sets of public keys can be aggregated by simply summing the cached terms, ~2 orders +// of magnitude faster than aggregating from scratch. +func NewCachedMask(mask Mask) (*CachedMask, error) { + return newCachedMask(mask, true) +} + +func newCachedMask(mask Mask, precomputePubC bool) (*CachedMask, error) { + if m, ok := mask.(*CachedMask); ok { + return m, nil + } + + publics := mask.Publics() + coefs, err := hashPointToR(publics) + if err != nil { + return nil, fmt.Errorf("failed to hash public keys: %w", err) + } + + cm := &CachedMask{ + maskI: mask, + coefs: coefs, + publics: publics, + } + + if precomputePubC { + pubKeyC := make([]kyber.Point, len(publics)) + for i := range publics { + pubKeyC[i] = cm.getOrComputePubC(i) + } + cm.pubKeyC = pubKeyC + } + + return cm, err +} + +// Clone copies the BDN mask while keeping the precomputed coefficients, etc. +func (cm *CachedMask) Clone() *CachedMask { + newMask, err := sign.NewMask(nil, cm.publics, nil) + if err != nil { + // Not possible given that we didn't pass our own key. + panic(fmt.Sprintf("failed to create mask: %s", err)) + } + if err := newMask.SetMask(cm.Mask()); err != nil { + // Not possible given that we're using the same sized mask. + panic(fmt.Sprintf("failed to create mask: %s", err)) + } + return &CachedMask{ + maskI: newMask, + coefs: cm.coefs, + pubKeyC: cm.pubKeyC, + publics: cm.publics, + } +} + +func (cm *CachedMask) getOrComputePubC(i int) kyber.Point { + if cm.pubKeyC == nil { + // NOTE: don't cache here as we may be sharing this mask between threads. + pub := cm.publics[i] + pubC := pub.Clone().Mul(cm.coefs[i], pub) + return pubC.Add(pubC, pub) + } + return cm.pubKeyC[i] +} diff --git a/sign/mask.go b/sign/mask.go index 98e96f0f6..30ac9ec94 100644 --- a/sign/mask.go +++ b/sign/mask.go @@ -59,6 +59,17 @@ func (m *Mask) SetMask(mask []byte) error { return nil } +// GetBit returns true if the given bit is set. +func (m *Mask) GetBit(i int) (bool, error) { + if i >= len(m.publics) || i < 0 { + return false, errors.New("index out of range") + } + + byteIndex := i / 8 + bitIndex := byte(1) << uint(i&7) + return m.mask[byteIndex]&bitIndex != 0, nil +} + // SetBit turns on or off the bit at the given index. func (m *Mask) SetBit(i int, enable bool) error { if i >= len(m.publics) || i < 0 { diff --git a/sign/mask_test.go b/sign/mask_test.go index 06a00bb42..5d36cacbf 100644 --- a/sign/mask_test.go +++ b/sign/mask_test.go @@ -45,20 +45,56 @@ func TestMask_CreateMask(t *testing.T) { require.Error(t, err) } -func TestMask_SetBit(t *testing.T) { +func TestMask_SetGetBit(t *testing.T) { mask, err := NewMask(suite, publics, publics[2]) require.NoError(t, err) + // Make sure the mask is initially as we'd expect. + + bit, err := mask.GetBit(1) + require.NoError(t, err) + require.False(t, bit) + + bit, err = mask.GetBit(2) + require.NoError(t, err) + require.True(t, bit) + + // Set bit 1 + err = mask.SetBit(1, true) require.NoError(t, err) require.Equal(t, uint8(0x6), mask.Mask()[0]) require.Equal(t, 2, len(mask.Participants())) + bit, err = mask.GetBit(1) + require.NoError(t, err) + require.True(t, bit) + + // Unset bit 2 + err = mask.SetBit(2, false) require.NoError(t, err) require.Equal(t, uint8(0x2), mask.Mask()[0]) require.Equal(t, 1, len(mask.Participants())) + bit, err = mask.GetBit(2) + require.NoError(t, err) + require.False(t, bit) + + // Unset bit 10 (using byte 2 now) + + err = mask.SetBit(10, false) + require.NoError(t, err) + require.Equal(t, uint8(0x2), mask.Mask()[0]) + require.Equal(t, uint8(0x4), mask.Mask()[1]) + require.Equal(t, 2, len(mask.Participants())) + + bit, err = mask.GetBit(10) + require.NoError(t, err) + require.True(t, bit) + + // And make sure the range limit works. + err = mask.SetBit(-1, true) require.Error(t, err) err = mask.SetBit(len(publics), true)