From 10605188ee34484cfef86b77ad61523b1586bfae Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Fri, 9 Aug 2024 14:51:09 -0700 Subject: [PATCH 1/6] Remove n^2 algorithm from signature/key aggregation CountEnabled and IndexOfNthEnabled are both O(n) in the size of the mask, making this loop n^2. The BLS operations still tend to be the slow part, but the n^2 factor will start to show up with thousands of keys. --- sign/bdn/bdn.go | 45 ++++++++++++++++++++++++++------------------- sign/mask.go | 11 +++++++++++ 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/sign/bdn/bdn.go b/sign/bdn/bdn.go index 4b1ab1b9c..d6a86edc5 100644 --- a/sign/bdn/bdn.go +++ b/sign/bdn/bdn.go @@ -12,6 +12,7 @@ package bdn import ( "crypto/cipher" "errors" + "fmt" "math/big" "github.com/drand/kyber" @@ -129,31 +130,36 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128} func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) { - if len(sigs) != mask.CountEnabled() { - return nil, errors.New("length of signatures and public keys must match") - } - - coefs, err := hashPointToR(mask.Publics()) + publics := mask.Publics() + coefs, err := hashPointToR(publics) if err != nil { return nil, err } agg := scheme.sigGroup.Point() - for i, buf := range sigs { - peerIndex := mask.IndexOfNthEnabled(i) - if peerIndex < 0 { - // this should never happen as we check the lenths at the beginning - // an error here is probably a bug in the mask - return nil, errors.New("couldn't find the index") + for i := range publics { + if enabled, err := mask.GetBit(i); err != nil { + // this should never happen because of the loop boundary + // an error here is probably a bug in the mask implementation + return nil, fmt.Errorf("couldn't find the index %d: %w", i, err) + } else if !enabled { + continue } + if len(sigs) == 0 { + return nil, errors.New("length of signatures and public keys must match") + } + + buf := sigs[0] + sigs = sigs[1:] + sig := scheme.sigGroup.Point() err = sig.UnmarshalBinary(buf) if err != nil { return nil, err } - sigC := sig.Clone().Mul(coefs[peerIndex], sig) + sigC := sig.Clone().Mul(coefs[i], sig) // c+1 because R is in the range [1, 2^128] and not [0, 2^128-1] sigC = sigC.Add(sigC, sig) agg = agg.Add(agg, sigC) @@ -166,22 +172,23 @@ func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber // AggregateSignatures for signatures) using the hash function // H: keyGroup -> R with R = {1, ..., 2^128}. func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) { - coefs, err := hashPointToR(mask.Publics()) + publics := mask.Publics() + coefs, err := hashPointToR(publics) if err != nil { return nil, err } agg := scheme.keyGroup.Point() - for i := 0; i < mask.CountEnabled(); i++ { - peerIndex := mask.IndexOfNthEnabled(i) - if peerIndex < 0 { + for i, pub := range publics { + if enabled, err := mask.GetBit(i); err != nil { // this should never happen because of the loop boundary // an error here is probably a bug in the mask implementation - return nil, errors.New("couldn't find the index") + return nil, fmt.Errorf("couldn't find the index %d: %w", i, err) + } else if !enabled { + continue } - pub := mask.Publics()[peerIndex] - pubC := pub.Clone().Mul(coefs[peerIndex], pub) + pubC := pub.Clone().Mul(coefs[i], pub) pubC = pubC.Add(pubC, pub) agg = agg.Add(agg, pubC) } diff --git a/sign/mask.go b/sign/mask.go index 98e96f0f6..f0b97a144 100644 --- a/sign/mask.go +++ b/sign/mask.go @@ -59,6 +59,17 @@ func (m *Mask) SetMask(mask []byte) error { return nil } +// GetBit returns true if the given bit is set. +func (m *Mask) GetBit(i int) (bool, error) { + if i >= len(m.publics) || i < 0 { + return false, errors.New("index out of range") + } + + byteIndex := i / 8 + mask := byte(1) << uint(i&7) + return m.mask[byteIndex]&mask != 0, nil +} + // SetBit turns on or off the bit at the given index. func (m *Mask) SetBit(i int, enable bool) error { if i >= len(m.publics) || i < 0 { From 4bf5c5dcb342bdb87db861ba25d95526b61d167c Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Tue, 13 Aug 2024 07:49:22 -0700 Subject: [PATCH 2/6] Remove an unnecessary loop from hashPointToR --- sign/bdn/bdn.go | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/sign/bdn/bdn.go b/sign/bdn/bdn.go index d6a86edc5..0577666e3 100644 --- a/sign/bdn/bdn.go +++ b/sign/bdn/bdn.go @@ -32,23 +32,16 @@ var modulus128 = new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 128), big.NewI // We also use the entire roster so that the coefficient will vary for the same // public key used in different roster func hashPointToR(pubs []kyber.Point) ([]kyber.Scalar, error) { - peers := make([][]byte, len(pubs)) - for i, pub := range pubs { - peer, err := pub.MarshalBinary() - if err != nil { - return nil, err - } - - peers[i] = peer - } - h, err := blake2s.NewXOF(blake2s.OutputLengthUnknown, nil) if err != nil { return nil, err } - - for _, peer := range peers { - _, err := h.Write(peer) + for _, pub := range pubs { + peer, err := pub.MarshalBinary() + if err != nil { + return nil, err + } + _, err = h.Write(peer) if err != nil { return nil, err } From 6e99a9f504f3168889bf8f67247ed4f9d1e44312 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Tue, 13 Aug 2024 15:44:41 +0000 Subject: [PATCH 3/6] rename mask -> bitIndex Co-authored-by: AnomalRoil --- sign/mask.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sign/mask.go b/sign/mask.go index f0b97a144..30ac9ec94 100644 --- a/sign/mask.go +++ b/sign/mask.go @@ -66,8 +66,8 @@ func (m *Mask) GetBit(i int) (bool, error) { } byteIndex := i / 8 - mask := byte(1) << uint(i&7) - return m.mask[byteIndex]&mask != 0, nil + bitIndex := byte(1) << uint(i&7) + return m.mask[byteIndex]&bitIndex != 0, nil } // SetBit turns on or off the bit at the given index. From ee6db8f263e1b0a884673fc6bb993094c9772b12 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Mon, 26 Aug 2024 15:40:19 -0700 Subject: [PATCH 4/6] add test fixtures --- sign/bdn/bdn_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/sign/bdn/bdn_test.go b/sign/bdn/bdn_test.go index db80d4706..46fef79f9 100644 --- a/sign/bdn/bdn_test.go +++ b/sign/bdn/bdn_test.go @@ -1,6 +1,8 @@ package bdn import ( + "encoding" + "encoding/hex" "fmt" "testing" @@ -183,6 +185,64 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) { } } +func unmarshalHex[T encoding.BinaryUnmarshaler](t *testing.T, into T, s string) T { + t.Helper() + b, err := hex.DecodeString(s) + require.NoError(t, err) + require.NoError(t, into.UnmarshalBinary(b)) + return into +} + +// This tests exists to make sure we don't accidentally make breaking changes to signature +// aggregation by using checking against known aggregated signatures and keys. +func TestBDNFixtures(t *testing.T) { + suite := bn256.NewSuite() + schemeOnG1 := NewSchemeOnG1(suite) + + public1 := unmarshalHex(t, suite.G2().Point(), "1a30714035c7a161e286e54c191b8c68345bd8239c74925a26290e8e1ae97ed6657958a17dca12c943fadceb11b824402389ff427179e0f10194da3c1b771c6083797d2b5915ea78123cbdb99ea6389d6d6b67dcb512a2b552c373094ee5693524e3ebb4a176f7efa7285c25c80081d8cb598745978f1a63b886c09a316b1493") + private1 := unmarshalHex(t, suite.G2().Scalar(), "49cfe5e9f4532670137184d43c0299f8b635bcacf6b0af7cab262494602d9f38") + public2 := unmarshalHex(t, suite.G2().Point(), "603bc61466ec8762ec6de2ba9a80b9d302d08f580d1685ac45a8e404a6ed549719dc0faf94d896a9983ff23423772720e3de5d800bc200de6f7d7e146162d3183b8880c5c0d8b71ca4b3b40f30c12d8cc0679c81a47c239c6aa7e9cc2edab4a927fe865cd413c1c17e3df8f74108e784cd77dd3e161bdaf30019a55826a32a1f") + private2 := unmarshalHex(t, suite.G2().Scalar(), "493abea4bb35b74c78ad9245f9d37883aeb6ee91f7fb0d8a8e11abf7aa2be581") + public3 := unmarshalHex(t, suite.G2().Point(), "56118769a1f0b6286abacaa32109c1497ab0819c5d21f27317e184b6681c283007aa981cb4760de044946febdd6503ab77a4586bc29c04159e53a6fa5dcb9c0261ccd1cb2e28db5204ca829ac9f6be95f957a626544adc34ba3bc542533b6e2f5cbd0567e343641a61a42b63f26c3625f74b66f6f46d17b3bf1688fae4d455ec") + private3 := unmarshalHex(t, suite.G2().Scalar(), "7fb0ebc317e161502208c3c16a4af890dedc3c7b275e8a04e99c0528aa6a19aa") + + sig1Exp, err := hex.DecodeString("0913b76987be19f943be23b636cab9a2484507717326bd8bbdcdbbb6b8d5eb9253cfb3597c3fa550ee4972a398813650825a871f8e0b242ae5ddbce1b7c0e2a8") + require.NoError(t, err) + sig2Exp, err := hex.DecodeString("21195d29b1863bca1559e24375211d1411d8a28a8f4c772870b07f4ccda2fd5e337c1315c210475c683e3aa8b87d3aed3f7255b3087daa30d1e1432dd61d7484") + require.NoError(t, err) + sig3Exp, err := hex.DecodeString("3c1ac80345c1733630dbdc8106925c867544b521c259f9fa9678d477e6e5d3d212b09bc0d95137c3dbc0af2241415156c56e757d5577a609293584d045593195") + require.NoError(t, err) + + aggSigExp := unmarshalHex(t, suite.G1().Point(), "520875e6667e0acf489e458c6c2233d09af81afa3b2045e0ec2435cfc582ba2c44af281d688efcf991d20975ce32c9933a09f8c4b38c18ef4b4510d8fa0f09d7") + aggKeyExp := unmarshalHex(t, suite.G2().Point(), "394d47291878a81fefb17708c57cf8078b24c46bf4554b3012732acd15395dbf09f13a65e068de766f5449d1de130f09bf09dc35a67f7f822f2a187230e155891d40db3c51afa5b3e05a039c50d04ff9c788718a2887e34644a55a14a2a2679226a3315c281e03367a4d797db819625e0c662d35e45e0e9e7604c104179ae8a7") + + msg := []byte("Hello many times Boneh-Lynn-Shacham") + sig1, err := schemeOnG1.Sign(private1, msg) + require.Nil(t, err) + require.Equal(t, sig1Exp, sig1) + + sig2, err := schemeOnG1.Sign(private2, msg) + require.Nil(t, err) + require.Equal(t, sig2Exp, sig2) + + sig3, err := schemeOnG1.Sign(private3, msg) + require.Nil(t, err) + require.Equal(t, sig3Exp, sig3) + + mask, _ := sign.NewMask(suite, []kyber.Point{public1, public2, public3}, nil) + mask.SetBit(0, true) + mask.SetBit(1, false) + mask.SetBit(2, true) + + aggSig, err := schemeOnG1.AggregateSignatures([][]byte{sig1, sig2, sig3}, mask) + require.NoError(t, err) + require.True(t, aggSigExp.Equal(aggSig)) + + aggKey, err := schemeOnG1.AggregatePublicKeys(mask) + require.NoError(t, err) + require.True(t, aggKeyExp.Equal(aggKey)) +} + func TestBDNDeprecatedAPIs(t *testing.T) { msg := []byte("Hello Boneh-Lynn-Shacham") suite := bn256.NewSuite() From 307e6dc63582189d95ad4a0cd43dcdc59a326f6b Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Mon, 26 Aug 2024 15:50:59 -0700 Subject: [PATCH 5/6] flesh out mask test for getting/setting bits --- sign/mask_test.go | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/sign/mask_test.go b/sign/mask_test.go index 06a00bb42..5d36cacbf 100644 --- a/sign/mask_test.go +++ b/sign/mask_test.go @@ -45,20 +45,56 @@ func TestMask_CreateMask(t *testing.T) { require.Error(t, err) } -func TestMask_SetBit(t *testing.T) { +func TestMask_SetGetBit(t *testing.T) { mask, err := NewMask(suite, publics, publics[2]) require.NoError(t, err) + // Make sure the mask is initially as we'd expect. + + bit, err := mask.GetBit(1) + require.NoError(t, err) + require.False(t, bit) + + bit, err = mask.GetBit(2) + require.NoError(t, err) + require.True(t, bit) + + // Set bit 1 + err = mask.SetBit(1, true) require.NoError(t, err) require.Equal(t, uint8(0x6), mask.Mask()[0]) require.Equal(t, 2, len(mask.Participants())) + bit, err = mask.GetBit(1) + require.NoError(t, err) + require.True(t, bit) + + // Unset bit 2 + err = mask.SetBit(2, false) require.NoError(t, err) require.Equal(t, uint8(0x2), mask.Mask()[0]) require.Equal(t, 1, len(mask.Participants())) + bit, err = mask.GetBit(2) + require.NoError(t, err) + require.False(t, bit) + + // Unset bit 10 (using byte 2 now) + + err = mask.SetBit(10, false) + require.NoError(t, err) + require.Equal(t, uint8(0x2), mask.Mask()[0]) + require.Equal(t, uint8(0x4), mask.Mask()[1]) + require.Equal(t, 2, len(mask.Participants())) + + bit, err = mask.GetBit(10) + require.NoError(t, err) + require.True(t, bit) + + // And make sure the range limit works. + err = mask.SetBit(-1, true) require.Error(t, err) err = mask.SetBit(len(publics), true) From 07398f96e33eea484d03e34a59c4db21fb83a070 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Mon, 26 Aug 2024 17:40:46 -0700 Subject: [PATCH 6/6] Introduce a new CachedMask for BDN This new mask will pre-compute reusable values, speeding up repeated verification and aggregation of aggregate signatures (mostly the former). --- sign/bdn/bdn.go | 25 ++++------ sign/bdn/bdn_test.go | 40 ++++++++++++++++ sign/bdn/mask.go | 112 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 162 insertions(+), 15 deletions(-) create mode 100644 sign/bdn/mask.go diff --git a/sign/bdn/bdn.go b/sign/bdn/bdn.go index 0577666e3..7c9b9187f 100644 --- a/sign/bdn/bdn.go +++ b/sign/bdn/bdn.go @@ -122,15 +122,13 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128} -func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) { - publics := mask.Publics() - coefs, err := hashPointToR(publics) +func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask Mask) (kyber.Point, error) { + bdnMask, err := newCachedMask(mask, false) if err != nil { return nil, err } - agg := scheme.sigGroup.Point() - for i := range publics { + for i := range bdnMask.publics { if enabled, err := mask.GetBit(i); err != nil { // this should never happen because of the loop boundary // an error here is probably a bug in the mask implementation @@ -152,7 +150,7 @@ func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber return nil, err } - sigC := sig.Clone().Mul(coefs[i], sig) + sigC := sig.Clone().Mul(bdnMask.coefs[i], sig) // c+1 because R is in the range [1, 2^128] and not [0, 2^128-1] sigC = sigC.Add(sigC, sig) agg = agg.Add(agg, sigC) @@ -164,15 +162,14 @@ func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber // AggregatePublicKeys aggregates a set of public keys (similarly to // AggregateSignatures for signatures) using the hash function // H: keyGroup -> R with R = {1, ..., 2^128}. -func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) { - publics := mask.Publics() - coefs, err := hashPointToR(publics) +func (scheme *Scheme) AggregatePublicKeys(mask Mask) (kyber.Point, error) { + bdnMask, err := newCachedMask(mask, false) if err != nil { return nil, err } agg := scheme.keyGroup.Point() - for i, pub := range publics { + for i := range bdnMask.publics { if enabled, err := mask.GetBit(i); err != nil { // this should never happen because of the loop boundary // an error here is probably a bug in the mask implementation @@ -181,9 +178,7 @@ func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) continue } - pubC := pub.Clone().Mul(coefs[i], pub) - pubC = pubC.Add(pubC, pub) - agg = agg.Add(agg, pubC) + agg = agg.Add(agg, bdnMask.getOrComputePubC(i)) } return agg, nil @@ -217,7 +212,7 @@ func Verify(suite pairing.Suite, x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: G2 -> R with R = {1, ..., 2^128} // Deprecated: use the new scheme methods instead. -func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (kyber.Point, error) { +func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask Mask) (kyber.Point, error) { return NewSchemeOnG1(suite).AggregateSignatures(sigs, mask) } @@ -225,6 +220,6 @@ func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (k // AggregateSignatures for signatures) using the hash function // H: G2 -> R with R = {1, ..., 2^128}. // Deprecated: use the new scheme methods instead. -func AggregatePublicKeys(suite pairing.Suite, mask *sign.Mask) (kyber.Point, error) { +func AggregatePublicKeys(suite pairing.Suite, mask Mask) (kyber.Point, error) { return NewSchemeOnG1(suite).AggregatePublicKeys(mask) } diff --git a/sign/bdn/bdn_test.go b/sign/bdn/bdn_test.go index 46fef79f9..a791d5979 100644 --- a/sign/bdn/bdn_test.go +++ b/sign/bdn/bdn_test.go @@ -185,6 +185,46 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) { } } +func Benchmark_BDN_BLS12381_AggregateVerify(b *testing.B) { + suite := bls12381.NewBLS12381Suite() + schemeOnG2 := NewSchemeOnG2(suite) + + rng := random.New() + pubKeys := make([]kyber.Point, 3000) + privKeys := make([]kyber.Scalar, 3000) + for i := range pubKeys { + privKeys[i], pubKeys[i] = schemeOnG2.NewKeyPair(rng) + } + + baseMask, err := sign.NewMask(suite, pubKeys, nil) + require.NoError(b, err) + mask, err := NewCachedMask(baseMask) + require.NoError(b, err) + for i := range pubKeys { + require.NoError(b, mask.SetBit(i, true)) + } + + msg := []byte("Hello many times Boneh-Lynn-Shacham") + sigs := make([][]byte, len(privKeys)) + for i, k := range privKeys { + s, err := schemeOnG2.Sign(k, msg) + require.NoError(b, err) + sigs[i] = s + } + + sig, err := schemeOnG2.AggregateSignatures(sigs, mask) + require.NoError(b, err) + sigb, err := sig.MarshalBinary() + require.NoError(b, err) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pk, err := schemeOnG2.AggregatePublicKeys(mask) + require.NoError(b, err) + require.NoError(b, schemeOnG2.Verify(pk, msg, sigb)) + } +} + func unmarshalHex[T encoding.BinaryUnmarshaler](t *testing.T, into T, s string) T { t.Helper() b, err := hex.DecodeString(s) diff --git a/sign/bdn/mask.go b/sign/bdn/mask.go new file mode 100644 index 000000000..26c26f8e9 --- /dev/null +++ b/sign/bdn/mask.go @@ -0,0 +1,112 @@ +package bdn + +import ( + "fmt" + + "github.com/drand/kyber" + "github.com/drand/kyber/sign" +) + +type Mask interface { + GetBit(i int) (bool, error) + SetBit(i int, enable bool) error + + IndexOfNthEnabled(nth int) int + NthEnabledAtIndex(idx int) int + + Publics() []kyber.Point + Participants() []kyber.Point + + CountEnabled() int + CountTotal() int + + Len() int + Mask() []byte + SetMask(mask []byte) error + Merge(mask []byte) error +} + +var _ Mask = (*sign.Mask)(nil) + +// We need to rename this, otherwise we have a public field named Mask (when we embed it) which +// conflicts with the function named Mask. It also makes it private, which is nice. +type maskI = Mask + +type CachedMask struct { + maskI + coefs []kyber.Scalar + pubKeyC []kyber.Point + // We could call Mask.Publics() instead of keeping these here, but that function copies the + // slice and this field lets us avoid that copy. + publics []kyber.Point +} + +// Convert the passed mask (likely a *sign.Mask) into a BDN-specific mask with pre-computed terms. +// +// This cached mask will: +// +// 1. Pre-compute coefficients for signature aggregation. Once the CachedMask has been instantiated, +// distinct sets of signatures can be aggregated without any BLAKE2S hashing. +// 2. Pre-computes the terms for public key aggregation. Once the CachedMask has been instantiated, +// distinct sets of public keys can be aggregated by simply summing the cached terms, ~2 orders +// of magnitude faster than aggregating from scratch. +func NewCachedMask(mask Mask) (*CachedMask, error) { + return newCachedMask(mask, true) +} + +func newCachedMask(mask Mask, precomputePubC bool) (*CachedMask, error) { + if m, ok := mask.(*CachedMask); ok { + return m, nil + } + + publics := mask.Publics() + coefs, err := hashPointToR(publics) + if err != nil { + return nil, fmt.Errorf("failed to hash public keys: %w", err) + } + + cm := &CachedMask{ + maskI: mask, + coefs: coefs, + publics: publics, + } + + if precomputePubC { + pubKeyC := make([]kyber.Point, len(publics)) + for i := range publics { + pubKeyC[i] = cm.getOrComputePubC(i) + } + cm.pubKeyC = pubKeyC + } + + return cm, err +} + +// Clone copies the BDN mask while keeping the precomputed coefficients, etc. +func (cm *CachedMask) Clone() *CachedMask { + newMask, err := sign.NewMask(nil, cm.publics, nil) + if err != nil { + // Not possible given that we didn't pass our own key. + panic(fmt.Sprintf("failed to create mask: %s", err)) + } + if err := newMask.SetMask(cm.Mask()); err != nil { + // Not possible given that we're using the same sized mask. + panic(fmt.Sprintf("failed to create mask: %s", err)) + } + return &CachedMask{ + maskI: newMask, + coefs: cm.coefs, + pubKeyC: cm.pubKeyC, + publics: cm.publics, + } +} + +func (cm *CachedMask) getOrComputePubC(i int) kyber.Point { + if cm.pubKeyC == nil { + // NOTE: don't cache here as we may be sharing this mask between threads. + pub := cm.publics[i] + pubC := pub.Clone().Mul(cm.coefs[i], pub) + return pubC.Add(pubC, pub) + } + return cm.pubKeyC[i] +}