From abb4a98f9f1c45e2c57b9d8fa7fbc50878f59382 Mon Sep 17 00:00:00 2001 From: Nicolas Gailly Date: Fri, 12 Jun 2020 17:01:07 +0100 Subject: [PATCH] fix set of messages (#15) * fix set of messages * fix from comments --- share/dkg/dkg.go | 3 +- share/dkg/proto_test.go | 46 +++++++++++++ share/dkg/protocol.go | 145 ++++++++++++++++++++++++++-------------- share/dkg/structs.go | 27 ++++++++ 4 files changed, 169 insertions(+), 52 deletions(-) diff --git a/share/dkg/dkg.go b/share/dkg/dkg.go index 096edf3e7..7f41e468c 100644 --- a/share/dkg/dkg.go +++ b/share/dkg/dkg.go @@ -100,7 +100,8 @@ type DkgConfig struct { FastSync bool // Nonce is required to avoid replay attacks from previous runs of a DKG / - // resharing. A Nonce must be of length 32 bytes. User can get a secure + // resharing. The required property of the Nonce is that it must be unique + // accross runs. A Nonce must be of length 32 bytes. User can get a secure // nonce by calling `GetNonce()`. Nonce []byte } diff --git a/share/dkg/proto_test.go b/share/dkg/proto_test.go index a83b9b62e..ac5ae5549 100644 --- a/share/dkg/proto_test.go +++ b/share/dkg/proto_test.go @@ -5,8 +5,10 @@ import ( "testing" "time" + "github.com/drand/kyber" "github.com/drand/kyber/group/edwards25519" "github.com/drand/kyber/sign/schnorr" + "github.com/drand/kyber/util/random" clock "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" ) @@ -452,3 +454,47 @@ func TestProtoThresholdFast(t *testing.T) { } } } + +func generateDeal(idx Index) *DealBundle { + suite := edwards25519.NewBlakeSHA256Ed25519() + deals := make([]Deal, 2) + deals[0].ShareIndex = 56 + deals[1].ShareIndex = 57 + deals[0].EncryptedShare = []byte("My first secure share") + deals[1].EncryptedShare = []byte("It keeps getting more secure") + return &DealBundle{ + DealerIndex: idx, + Deals: deals, + Public: []kyber.Point{suite.Point().Pick(random.New())}, + SessionID: []byte("Blob"), + } +} + +func TestSet(t *testing.T) { + s := newSet() + deal := generateDeal(1) + s.Push(deal) + require.NotNil(t, s.vals[1]) + require.Nil(t, s.bad) + // push a second time shouldn't change the set + s.Push(deal) + require.NotNil(t, s.vals[1]) + require.Nil(t, s.bad) + + deal2 := generateDeal(2) + s.Push(deal2) + require.Equal(t, 2, len(s.vals)) + require.Nil(t, s.bad) + + // push a different deal + deal1b := generateDeal(1) + s.Push(deal1b) + require.Equal(t, 1, len(s.vals)) + require.Contains(t, s.bad, Index(1)) + + // try again, it should fail directly + s.Push(deal1b) + require.Equal(t, 1, len(s.vals)) + require.Contains(t, s.bad, Index(1)) + +} diff --git a/share/dkg/protocol.go b/share/dkg/protocol.go index d94195527..ed2ee1cf1 100644 --- a/share/dkg/protocol.go +++ b/share/dkg/protocol.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "fmt" - "strconv" "strings" "time" @@ -131,9 +130,9 @@ func (p *Protocol) Start() { p.startFast() return } - var deals []*DealBundle - var resps []*ResponseBundle - var justifs []*JustificationBundle + var deals = newSet() + var resps = newSet() + var justifs = newSet() for { select { case newPhase := <-p.phaser.NextPhase(): @@ -143,37 +142,37 @@ func (p *Protocol) Start() { return } case ResponsePhase: - if !p.sendResponses(deals) { + if !p.sendResponses(deals.ToDeals()) { return } case JustifPhase: - if !p.sendJustifications(resps) { + if !p.sendJustifications(resps.ToResponses()) { return } case FinishPhase: - p.finish(justifs) + p.finish(justifs.ToJustifications()) return } case newDeal := <-p.board.IncomingDeal(): if err := p.VerifySignature(newDeal); err == nil { - deals = append(deals, newDeal.Bundle) + deals.Push(newDeal.Bundle) } case newResp := <-p.board.IncomingResponse(): if err := p.VerifySignature(newResp); err == nil { - resps = append(resps, newResp.Bundle) + resps.Push(newResp.Bundle) } case newJust := <-p.board.IncomingJustification(): if err := p.VerifySignature(newJust); err == nil { - justifs = append(justifs, newJust.Bundle) + justifs.Push(newJust.Bundle) } } } } func (p *Protocol) startFast() { - var deals = make(map[uint32]*DealBundle) - var resps = make(map[uint32]*ResponseBundle) - var justifs = make(map[uint32]*JustificationBundle) + var deals = newSet() + var resps = newSet() + var justifs = newSet() var newN = len(p.conf.DkgConfig.NewNodes) var oldN = len(p.conf.DkgConfig.OldNodes) var phase Phase @@ -182,11 +181,7 @@ func (p *Protocol) startFast() { return true } phase = ResponsePhase - bdeals := make([]*DealBundle, 0, len(deals)) - for _, d := range deals { - bdeals = append(bdeals, d) - } - if !p.sendResponses(bdeals) { + if !p.sendResponses(deals.ToDeals()) { return false } return true @@ -196,11 +191,7 @@ func (p *Protocol) startFast() { return true } phase = JustifPhase - bresps := make([]*ResponseBundle, 0, len(resps)) - for _, r := range resps { - bresps = append(bresps, r) - } - if !p.sendJustifications(bresps) { + if !p.sendJustifications(resps.ToResponses()) { return false } return true @@ -210,11 +201,7 @@ func (p *Protocol) startFast() { // although it should never happen twice but never too sure return } - bjusts := make([]*JustificationBundle, 0, len(justifs)) - for _, j := range justifs { - bjusts = append(bjusts, j) - } - p.finish(bjusts) + p.finish(justifs.ToJustifications()) } for { select { @@ -239,28 +226,11 @@ func (p *Protocol) startFast() { } case newDeal := <-p.board.IncomingDeal(): if err := p.VerifySignature(newDeal); err == nil { - // we make sure we don't see two deals from the same dealer that - // are inconsistent - For example we might receive multiple - // times the same deal from the network due to the use of - // gossiping; here we make sure they're all consistent. - if prevDeal, ok := deals[newDeal.Bundle.DealerIndex]; ok { - prevHash := prevDeal.Hash() - newHash := newDeal.Bundle.Hash() - if !bytes.Equal(prevHash, newHash) { - delete(deals, newDeal.Bundle.DealerIndex) - continue - } - } - deals[newDeal.Bundle.DealerIndex] = newDeal.Bundle - var idxs []string - for idx := range deals { - idxs = append(idxs, strconv.Itoa(int(idx))) - } - + deals.Push(newDeal.Bundle) } // XXX This assumes we receive our own deal bundle since we use a // broadcast channel - may need to revisit that assumption - if len(deals) == oldN { + if deals.Len() == oldN { if !sendResponseFn() { return } @@ -269,9 +239,9 @@ func (p *Protocol) startFast() { // TODO See how can we deal with inconsistent answers from different // share holders if err := p.VerifySignature(newResp); err == nil { - resps[newResp.Bundle.ShareIndex] = newResp.Bundle + resps.Push(newResp.Bundle) } - if len(resps) == newN { + if resps.Len() == newN { if !sendJustifFn() { return } @@ -280,9 +250,9 @@ func (p *Protocol) startFast() { // TODO see how can we deal with inconsistent answers from different // dealers if err := p.VerifySignature(newJust); err == nil { - justifs[newJust.Bundle.DealerIndex] = newJust.Bundle + justifs.Push(newJust.Bundle) } - if len(justifs) == oldN { + if justifs.Len() == oldN { finishFn() return } @@ -439,3 +409,76 @@ type OptionResult struct { Result *Result Error error } + +type packet interface { + Hash() []byte + Index() Index +} + +type set struct { + vals map[Index]packet + bad []Index +} + +func newSet() *set { + return &set{ + vals: make(map[Index]packet), + } +} + +func (s *set) Push(p packet) { + hash := p.Hash() + idx := p.Index() + if s.isBad(idx) { + // already misbehaved before + return + } + prev, present := s.vals[idx] + if present { + if !bytes.Equal(prev.Hash(), hash) { + // bad behavior - we evict + delete(s.vals, idx) + s.bad = append(s.bad, idx) + } + // same packet just rebroadcasted - all good + return + } + s.vals[idx] = p +} + +func (s *set) isBad(idx Index) bool { + for _, i := range s.bad { + if idx == i { + return true + } + } + return false +} + +func (s *set) ToDeals() []*DealBundle { + deals := make([]*DealBundle, 0, len(s.vals)) + for _, p := range s.vals { + deals = append(deals, p.(*DealBundle)) + } + return deals +} + +func (s *set) ToResponses() []*ResponseBundle { + resps := make([]*ResponseBundle, 0, len(s.vals)) + for _, p := range s.vals { + resps = append(resps, p.(*ResponseBundle)) + } + return resps +} + +func (s *set) ToJustifications() []*JustificationBundle { + justs := make([]*JustificationBundle, 0, len(s.vals)) + for _, p := range s.vals { + justs = append(justs, p.(*JustificationBundle)) + } + return justs +} + +func (s *set) Len() int { + return len(s.vals) +} diff --git a/share/dkg/structs.go b/share/dkg/structs.go index 6f5de4d3e..56c199e58 100644 --- a/share/dkg/structs.go +++ b/share/dkg/structs.go @@ -94,6 +94,10 @@ type Deal struct { EncryptedShare []byte } +var _ packet = (*DealBundle)(nil) + +// DealBundle is the struct sent out by dealers that contains all the deals and +// the public polynomial. type DealBundle struct { DealerIndex uint32 Deals []Deal @@ -119,9 +123,14 @@ func (d *DealBundle) Hash() []byte { binary.Write(h, binary.BigEndian, deal.ShareIndex) h.Write(deal.EncryptedShare) } + h.Write(d.SessionID) return h.Sum(nil) } +func (d *DealBundle) Index() Index { + return d.DealerIndex +} + // Response holds the Response from another participant as well as the index of // the target Dealer. type Response struct { @@ -130,6 +139,10 @@ type Response struct { Status bool } +var _ packet = (*ResponseBundle)(nil) + +// ResponseBundle is the struct sent out by share holder containing the status +// for the deals received in the first phase. type ResponseBundle struct { // Index of the share holder for which these reponses are for ShareIndex uint32 @@ -154,9 +167,14 @@ func (r *ResponseBundle) Hash() []byte { binary.Write(h, binary.BigEndian, byte(0)) } } + h.Write(r.SessionID) return h.Sum(nil) } +func (b *ResponseBundle) Index() Index { + return b.ShareIndex +} + func (b *ResponseBundle) String() string { var s = fmt.Sprintf("ShareHolder %d: ", b.ShareIndex) var arr []string @@ -167,6 +185,10 @@ func (b *ResponseBundle) String() string { return s } +var _ packet = (*JustificationBundle)(nil) + +// JustificationBundle is the struct that contains all justifications for each +// complaint in the precedent phase. type JustificationBundle struct { DealerIndex uint32 Justifications []Justification @@ -191,9 +213,14 @@ func (j *JustificationBundle) Hash() []byte { sbuff, _ := just.Share.MarshalBinary() h.Write(sbuff) } + h.Write(j.SessionID) return h.Sum(nil) } +func (j *JustificationBundle) Index() Index { + return j.DealerIndex +} + type AuthDealBundle struct { Bundle *DealBundle Signature []byte