diff --git a/activation/handler.go b/activation/handler.go index da99dd999d..3960e484a7 100644 --- a/activation/handler.go +++ b/activation/handler.go @@ -146,7 +146,7 @@ func NewHandler( fetcher: fetcher, beacon: beacon, tortoise: tortoise, - malPublisher: &MalfeasancePublisher{}, + malPublisher: &MalfeasancePublisher{}, // TODO(mafa): pass real publisher when available }, } diff --git a/activation/handler_test.go b/activation/handler_test.go index c080012133..7c30c86fd2 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "sort" "testing" "testing/quick" @@ -128,7 +129,7 @@ type handlerMocks struct { mValidator *MocknipostValidator mbeacon *MockAtxReceiver mtortoise *mocks.MockTortoise - mMalPublish *MockmalfeasancePublisher + mMalPublish *MockatxMalfeasancePublisher } type testHandler struct { @@ -159,6 +160,7 @@ func (h *handlerMocks) expectAtxV1(atx *wire.ActivationTxV1, nodeId types.NodeID } h.mockFetch.EXPECT().RegisterPeerHashes(gomock.Any(), gomock.Any()) h.mockFetch.EXPECT().GetPoetProof(gomock.Any(), types.BytesToHash(atx.NIPost.PostMetadata.Challenge)) + deps := []types.ATXID{atx.PrevATXID, atx.PositioningATXID} if atx.PrevATXID == types.EmptyATXID { h.mValidator.EXPECT().InitialNIPostChallengeV1(gomock.Any(), gomock.Any(), h.goldenATXID) h.mValidator.EXPECT(). @@ -170,9 +172,17 @@ func (h *handlerMocks) expectAtxV1(atx *wire.ActivationTxV1, nodeId types.NodeID time.Sleep(settings.postVerificationDuration) return nil }) + deps = append(deps, *atx.CommitmentATXID) } else { h.mValidator.EXPECT().NIPostChallengeV1(gomock.Any(), gomock.Any(), nodeId) } + deps = slices.Compact(deps) + deps = slices.DeleteFunc(deps, func(dep types.ATXID) bool { + return dep == types.EmptyATXID || dep == h.goldenATXID + }) + if len(deps) > 0 { + h.mockFetch.EXPECT().GetAtxs(gomock.Any(), deps, gomock.Any()) + } h.mValidator.EXPECT().PositioningAtx(atx.PositioningATXID, gomock.Any(), h.goldenATXID, atx.PublishEpoch) h.mValidator.EXPECT(). NIPost(gomock.Any(), nodeId, h.goldenATXID, gomock.Any(), gomock.Any(), atx.NumUnits, gomock.Any()). @@ -194,7 +204,7 @@ func newTestHandlerMocks(tb testing.TB, golden types.ATXID) handlerMocks { mValidator: NewMocknipostValidator(ctrl), mbeacon: NewMockAtxReceiver(ctrl), mtortoise: mocks.NewMockTortoise(ctrl), - mMalPublish: NewMockmalfeasancePublisher(ctrl), + mMalPublish: NewMockatxMalfeasancePublisher(ctrl), } } @@ -205,6 +215,8 @@ func newTestHandler(tb testing.TB, goldenATXID types.ATXID, opts ...HandlerOptio edVerifier := signing.NewEdVerifier() mocks := newTestHandlerMocks(tb, goldenATXID) + // TODO(mafa): make mandatory parameter when real publisher is available + opts = append(opts, func(h *Handler) { h.v2.malPublisher = mocks.mMalPublish }) atxHdlr := NewHandler( "localID", cdb, @@ -341,7 +353,6 @@ func TestHandler_ProcessAtxStoresNewVRFNonce(t *testing.T) { atx2.VRFNonce = (*uint64)(&nonce2) atx2.Sign(sig) atxHdlr.expectAtxV1(atx2, sig.NodeID()) - atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), gomock.Any(), gomock.Any()) require.NoError(t, atxHdlr.HandleGossipAtx(context.Background(), "", codec.MustEncode(atx2))) got, err = atxs.VRFNonce(atxHdlr.cdb, sig.NodeID(), atx2.PublishEpoch+1) @@ -391,7 +402,6 @@ func TestHandler_HandleGossipAtx(t *testing.T) { // second is now valid (deps are in) atxHdlr.expectAtxV1(second, sig.NodeID()) - atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), []types.ATXID{second.PrevATXID}, gomock.Any()) require.NoError(t, atxHdlr.HandleGossipAtx(context.Background(), "", codec.MustEncode(second))) } @@ -695,7 +705,6 @@ func TestHandler_AtxWeight(t *testing.T) { buf = codec.MustEncode(atx2) atxHdlr.expectAtxV1(atx2, sig.NodeID(), func(o *atxHandleOpts) { o.poetLeaves = leaves }) - atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), []types.ATXID{atx1.ID()}, gomock.Any()) require.NoError(t, atxHdlr.HandleSyncedAtx(context.Background(), atx2.ID().Hash32(), peer, buf)) stored2, err := atxHdlr.cdb.GetAtx(atx2.ID()) diff --git a/activation/handler_v2.go b/activation/handler_v2.go index dd3103bcd0..b7386a4dd3 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -69,7 +69,7 @@ type HandlerV2 struct { tortoise system.Tortoise logger *zap.Logger fetcher system.Fetcher - malPublisher malfeasancePublisher + malPublisher atxMalfeasancePublisher } func (h *HandlerV2) processATX( @@ -744,14 +744,6 @@ func (h *HandlerV2) checkMalicious(ctx context.Context, tx sql.Transaction, atx return true, nil } - malicious, err = h.checkDoublePost(ctx, tx, atx) - if err != nil { - return malicious, fmt.Errorf("checking double post: %w", err) - } - if malicious { - return true, nil - } - malicious, err = h.checkDoubleMerge(ctx, tx, atx) if err != nil { return malicious, fmt.Errorf("checking double merge: %w", err) @@ -815,31 +807,6 @@ func (h *HandlerV2) checkDoubleMarry(ctx context.Context, tx sql.Transaction, at return false, nil } -func (h *HandlerV2) checkDoublePost(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { - for id := range atx.ids { - atxIDs, err := atxs.FindDoublePublish(tx, id, atx.PublishEpoch) - switch { - case errors.Is(err, sql.ErrNotFound): - continue - case err != nil: - return false, fmt.Errorf("searching for double publish: %w", err) - } - otherAtxId := slices.IndexFunc(atxIDs, func(other types.ATXID) bool { return other != atx.ID() }) - otherAtx := atxIDs[otherAtxId] - h.logger.Debug( - "found ID that has already contributed its PoST in this epoch", - zap.Stringer("node_id", id), - zap.Stringer("atx_id", atx.ID()), - zap.Stringer("other_atx_id", otherAtx), - zap.Uint32("epoch", atx.PublishEpoch.Uint32()), - ) - // TODO(mafa): finish proof - var proof wire.Proof - return true, h.malPublisher.Publish(ctx, id, proof) - } - return false, nil -} - func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { if atx.MarriageATX == nil { return false, nil @@ -899,22 +866,63 @@ func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx sql.Transaction, atx *a log.ZShortStringer("expected", expectedPrevID), ) - atx1, atx2, err := atxs.PrevATXCollision(tx, data.previous, id) + collisions, err := atxs.PrevATXCollisions(tx, data.previous, id) switch { case errors.Is(err, sql.ErrNotFound): continue case err != nil: - return false, fmt.Errorf("checking for previous ATX collision: %w", err) + return true, fmt.Errorf("checking for previous ATX collision: %w", err) } + var wireAtxV1 *wire.ActivationTxV1 + for _, collision := range collisions { + if collision == atx.ID() { + continue + } + var blob sql.Blob + v, err := atxs.LoadBlob(ctx, tx, collision.Bytes(), &blob) + if err != nil { + return true, fmt.Errorf("get atx blob %s: %w", id.ShortString(), err) + } + switch v { + case types.AtxV1: + if wireAtxV1 == nil { + // we have at least one v2 ATX (the one we are validating right now) so we only need one + // v1 ATX to create the proof if no other v2 ATXs are found + wireAtxV1 = &wire.ActivationTxV1{} + codec.MustDecode(blob.Bytes, wireAtxV1) + } + case types.AtxV2: + wireAtx := &wire.ActivationTxV2{} + codec.MustDecode(blob.Bytes, wireAtx) + // prefer creating a proof with 2 ATXs of version 2 + h.logger.Debug("creating a malfeasance proof for invalid previous ATX", + log.ZShortStringer("smesherID", id), + log.ZShortStringer("atx1", wireAtx.ID()), + log.ZShortStringer("atx2", atx.ActivationTxV2.ID()), + ) + proof, err := wire.NewInvalidPrevAtxProofV2(tx, atx.ActivationTxV2, wireAtx, id) + if err != nil { + return true, fmt.Errorf("creating invalid previous ATX proof: %w", err) + } + return true, h.malPublisher.Publish(ctx, id, proof) + default: + h.logger.Fatal("Failed to create invalid previous ATX proof: unknown ATX version", + zap.Stringer("atx_id", collision), + ) + } + } + + // no ATXv2 found, create a proof with an ATXv1 h.logger.Debug("creating a malfeasance proof for invalid previous ATX", log.ZShortStringer("smesherID", id), - log.ZShortStringer("atx1", atx1), - log.ZShortStringer("atx2", atx2), + log.ZShortStringer("atx1", wireAtxV1.ID()), + log.ZShortStringer("atx2", atx.ActivationTxV2.ID()), ) - - // TODO(mafa): finish proof - var proof wire.Proof + proof, err := wire.NewInvalidPrevAtxProofV1(tx, atx.ActivationTxV2, wireAtxV1, id) + if err != nil { + return true, fmt.Errorf("creating invalid previous ATX proof: %w", err) + } return true, h.malPublisher.Publish(ctx, id, proof) } return false, nil diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 0089a90e0c..4c8f127250 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -135,7 +135,7 @@ func (h *handlerMocks) expectStoreAtxV2(atx *wire.ActivationTxV2) { } func (h *handlerMocks) expectInitialAtxV2(atx *wire.ActivationTxV2) { - h.mclock.EXPECT().CurrentLayer().Return(postGenesisEpoch.FirstLayer()) + h.mclock.EXPECT().CurrentLayer().Return(atx.PublishEpoch.FirstLayer()) h.mValidator.EXPECT().VRFNonceV2( atx.SmesherID, atx.Initial.CommitmentATX, @@ -175,7 +175,7 @@ func (h *handlerMocks) expectMergedAtxV2( equivocationSet []types.NodeID, poetLeaves []uint64, ) { - h.mclock.EXPECT().CurrentLayer().Return(postGenesisEpoch.FirstLayer()) + h.mclock.EXPECT().CurrentLayer().Return(atx.PublishEpoch.FirstLayer()) h.expectFetchDeps(atx) h.mValidator.EXPECT().VRFNonceV2( atx.SmesherID, @@ -2099,64 +2099,6 @@ func Test_MarryingMalicious(t *testing.T) { t.Run("other is malicious", tc(otherSig.NodeID())) } -func TestContextualValidation_DoublePost(t *testing.T) { - t.Parallel() - golden := types.RandomATXID() - sig, err := signing.NewEdSigner() - require.NoError(t, err) - - atxHandler := newV2TestHandler(t, golden) - - // marry - otherSig, err := signing.NewEdSigner() - require.NoError(t, err) - othersAtx := atxHandler.createAndProcessInitial(otherSig) - - mATX := newInitialATXv2(t, golden) - mATX.Marriages = []wire.MarriageCertificate{ - { - Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), - }, - { - ReferenceAtx: othersAtx.ID(), - Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), - }, - } - mATX.Sign(sig) - - atxHandler.expectInitialAtxV2(mATX) - err = atxHandler.processATX(context.Background(), "", mATX, time.Now()) - require.NoError(t, err) - - // publish merged - merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) - post := wire.SubPostV2{ - MarriageIndex: 1, - NumUnits: othersAtx.TotalNumUnits(), - PrevATXIndex: 1, - } - merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) - - mATXID := mATX.ID() - merged.MarriageATX = &mATXID - - merged.PreviousATXs = []types.ATXID{mATX.ID(), othersAtx.ID()} - merged.Sign(sig) - - atxHandler.expectMergedAtxV2(merged, []types.NodeID{sig.NodeID(), otherSig.NodeID()}, []uint64{poetLeaves}) - err = atxHandler.processATX(context.Background(), "", merged, time.Now()) - require.NoError(t, err) - - // The otherSig tries to publish alone in the same epoch. - // This is malfeasance as it tries include his PoST twice. - doubled := newSoloATXv2(t, merged.PublishEpoch, othersAtx.ID(), othersAtx.ID()) - doubled.Sign(otherSig) - atxHandler.expectAtxV2(doubled) - atxHandler.mMalPublish.EXPECT().Publish(gomock.Any(), otherSig.NodeID(), gomock.Any()) - err = atxHandler.processATX(context.Background(), "", doubled, time.Now()) - require.NoError(t, err) -} - func Test_CalculatingUnits(t *testing.T) { t.Parallel() t.Run("units on 1 nipost must not overflow", func(t *testing.T) { @@ -2184,47 +2126,224 @@ func Test_CalculatingUnits(t *testing.T) { } func TestContextual_PreviousATX(t *testing.T) { - golden := types.RandomATXID() - atxHdlr := newV2TestHandler(t, golden) - var ( - signers []*signing.EdSigner - eqSet []types.NodeID - ) - for range 3 { + t.Run("invalid previous ATX, both v2", func(t *testing.T) { + golden := types.RandomATXID() + atxHdlr := newV2TestHandler(t, golden) + var ( + signers []*signing.EdSigner + eqSet []types.NodeID + ) + for range 3 { + sig, err := signing.NewEdSigner() + require.NoError(t, err) + signers = append(signers, sig) + eqSet = append(eqSet, sig.NodeID()) + } + + mATX, otherAtxs := marryIDs(t, atxHdlr, signers, golden) + + // signer 1 creates a solo ATX + soloAtx := newSoloATXv2(t, mATX.PublishEpoch+1, otherAtxs[0].ID(), mATX.ID()) + soloAtx.Sign(signers[1]) + atxHdlr.expectAtxV2(soloAtx) + err := atxHdlr.processATX(context.Background(), "", soloAtx, time.Now()) + require.NoError(t, err) + + // create a MergedATX for all IDs + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + post := wire.SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 1, + NumUnits: soloAtx.TotalNumUnits(), + } + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) + // Pass a wrong previous ATX for signer 1. It's already been used for soloATX + // (which should be used for the previous ATX for signer 1). + merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID()) + matxID := mATX.ID() + merged.MarriageATX = &matxID + merged.Sign(signers[0]) + + atxHdlr.expectMergedAtxV2(merged, eqSet, []uint64{100}) + + verifier := wire.NewMockMalfeasanceValidator(atxHdlr.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHdlr.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + atxHdlr.mMalPublish.EXPECT().Publish( + gomock.Any(), + signers[1].NodeID(), + gomock.AssignableToTypeOf(&wire.ProofInvalidPrevAtxV2{}), + ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { + malProof := proof.(*wire.ProofInvalidPrevAtxV2) + nId, err := malProof.Valid(ctx, verifier) + require.NoError(t, err) + require.Equal(t, signers[1].NodeID(), nId) + return nil + }) + + err = atxHdlr.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + }) + + t.Run("invalid previous ATX, v1 and v2", func(t *testing.T) { + golden := types.RandomATXID() + atxHdlr := newTestHandler(t, golden) + + sig1, err := signing.NewEdSigner() + require.NoError(t, err) + + // signer 1 creates a solo ATX + prevATX := newInitialATXv1(t, golden) + prevATX.Sign(sig1) + atxHdlr.expectAtxV1(prevATX, prevATX.SmesherID) + _, err = atxHdlr.v1.processATX(context.Background(), "", prevATX, time.Now()) + require.NoError(t, err) + atxv1 := newChainedActivationTxV1(t, prevATX, prevATX.ID()) + atxv1.Sign(sig1) + atxHdlr.expectAtxV1(atxv1, atxv1.SmesherID) + _, err = atxHdlr.v1.processATX(context.Background(), "", atxv1, time.Now()) + require.NoError(t, err) + + soloAtx := newSoloATXv2(t, atxv1.PublishEpoch+1, atxv1.ID(), atxv1.ID()) + soloAtx.Sign(sig1) + atxHdlr.expectAtxV2(soloAtx) + err = atxHdlr.v2.processATX(context.Background(), "", soloAtx, time.Now()) + require.NoError(t, err) + + sig2, err := signing.NewEdSigner() + require.NoError(t, err) + mATX := newInitialATXv2(t, golden) + mATX.Marriages = []wire.MarriageCertificate{ + { + ReferenceAtx: types.EmptyATXID, + Signature: sig2.Sign(signing.MARRIAGE, sig2.NodeID().Bytes()), + }, + { + ReferenceAtx: soloAtx.ID(), + Signature: sig1.Sign(signing.MARRIAGE, sig2.NodeID().Bytes()), + }, + } + mATX.PublishEpoch = soloAtx.PublishEpoch + mATX.Sign(sig2) + atxHdlr.expectInitialAtxV2(mATX) + err = atxHdlr.v2.processATX(context.Background(), "", mATX, time.Now()) + require.NoError(t, err) + + // create a MergedATX for all IDs + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + post := wire.SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 1, + NumUnits: soloAtx.TotalNumUnits(), + } + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) + merged.PreviousATXs = append(merged.PreviousATXs, prevATX.ID()) + merged.MarriageATX = new(types.ATXID) + *merged.MarriageATX = mATX.ID() + merged.Sign(sig2) + + atxHdlr.expectMergedAtxV2(merged, []types.NodeID{sig1.NodeID(), sig2.NodeID()}, []uint64{100}) + + verifier := wire.NewMockMalfeasanceValidator(atxHdlr.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHdlr.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + atxHdlr.mMalPublish.EXPECT().Publish( + gomock.Any(), + sig1.NodeID(), + gomock.AssignableToTypeOf(&wire.ProofInvalidPrevAtxV1{}), + ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { + malProof := proof.(*wire.ProofInvalidPrevAtxV1) + nId, err := malProof.Valid(ctx, verifier) + require.NoError(t, err) + require.Equal(t, sig1.NodeID(), nId) + return nil + }) + + err = atxHdlr.v2.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + }) + + t.Run("double publish", func(t *testing.T) { + t.Parallel() + golden := types.RandomATXID() sig, err := signing.NewEdSigner() require.NoError(t, err) - signers = append(signers, sig) - eqSet = append(eqSet, sig.NodeID()) - } - mATX, otherAtxs := marryIDs(t, atxHdlr, signers, golden) + atxHdlr := newV2TestHandler(t, golden) - // signer 1 creates a solo ATX - soloAtx := newSoloATXv2(t, mATX.PublishEpoch+1, otherAtxs[0].ID(), mATX.ID()) - soloAtx.Sign(signers[1]) - atxHdlr.expectAtxV2(soloAtx) - err := atxHdlr.processATX(context.Background(), "", soloAtx, time.Now()) - require.NoError(t, err) + // marry + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + othersAtx := atxHdlr.createAndProcessInitial(otherSig) - // create a MergedATX for all IDs - merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) - post := wire.SubPostV2{ - MarriageIndex: 1, - PrevATXIndex: 1, - NumUnits: soloAtx.TotalNumUnits(), - } - merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) - // Pass a wrong previous ATX for signer 1. It's already been used for soloATX - // (which should be used for the previous ATX for signer 1). - merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID()) - matxID := mATX.ID() - merged.MarriageATX = &matxID - merged.Sign(signers[0]) - - atxHdlr.expectMergedAtxV2(merged, eqSet, []uint64{100}) - atxHdlr.mMalPublish.EXPECT().Publish(gomock.Any(), signers[1].NodeID(), gomock.Any()) - err = atxHdlr.processATX(context.Background(), "", merged, time.Now()) - require.NoError(t, err) + mATX := newInitialATXv2(t, golden) + mATX.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: othersAtx.ID(), + Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + mATX.Sign(sig) + + atxHdlr.expectInitialAtxV2(mATX) + err = atxHdlr.processATX(context.Background(), "", mATX, time.Now()) + require.NoError(t, err) + + // publish merged + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + post := wire.SubPostV2{ + MarriageIndex: 1, + NumUnits: othersAtx.TotalNumUnits(), + PrevATXIndex: 1, + } + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) + + mATXID := mATX.ID() + merged.MarriageATX = &mATXID + + merged.PreviousATXs = []types.ATXID{mATX.ID(), othersAtx.ID()} + merged.Sign(sig) + + atxHdlr.expectMergedAtxV2(merged, []types.NodeID{sig.NodeID(), otherSig.NodeID()}, []uint64{poetLeaves}) + err = atxHdlr.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + + // The otherSig tries to publish alone in the same epoch. + // This is malfeasance as it tries include his PoST twice. + doubled := newSoloATXv2(t, merged.PublishEpoch, othersAtx.ID(), othersAtx.ID()) + doubled.Sign(otherSig) + atxHdlr.expectAtxV2(doubled) + + verifier := wire.NewMockMalfeasanceValidator(atxHdlr.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHdlr.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + atxHdlr.mMalPublish.EXPECT().Publish( + gomock.Any(), + otherSig.NodeID(), + gomock.AssignableToTypeOf(&wire.ProofInvalidPrevAtxV2{}), + ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { + malProof := proof.(*wire.ProofInvalidPrevAtxV2) + nId, err := malProof.Valid(ctx, verifier) + require.NoError(t, err) + require.Equal(t, otherSig.NodeID(), nId) + return nil + }) + + err = atxHdlr.processATX(context.Background(), "", doubled, time.Now()) + require.NoError(t, err) + }) } func Test_CalculatingWeight(t *testing.T) { diff --git a/activation/interface.go b/activation/interface.go index c9c3359091..38c8cf1332 100644 --- a/activation/interface.go +++ b/activation/interface.go @@ -92,7 +92,7 @@ type syncer interface { RegisterForATXSynced() <-chan struct{} } -// malfeasancePublisher is an interface for publishing malfeasance proofs. +// atxMalfeasancePublisher is an interface for publishing malfeasance proofs. // This interface is used to publish proofs in V2. // // The provider of that interface ensures that only valid proofs are published (invalid ones return an error). @@ -100,7 +100,7 @@ type syncer interface { // // Additionally the publisher will only gossip proofs when the node is in sync, otherwise it will only store them // and mark the associated identity as malfeasant. -type malfeasancePublisher interface { +type atxMalfeasancePublisher interface { Publish(ctx context.Context, id types.NodeID, proof wire.Proof) error } diff --git a/activation/mocks.go b/activation/mocks.go index 985f6a05f3..0ab71a7524 100644 --- a/activation/mocks.go +++ b/activation/mocks.go @@ -1092,32 +1092,32 @@ func (c *MocksyncerRegisterForATXSyncedCall) DoAndReturn(f func() <-chan struct{ return c } -// MockmalfeasancePublisher is a mock of malfeasancePublisher interface. -type MockmalfeasancePublisher struct { +// MockatxMalfeasancePublisher is a mock of atxMalfeasancePublisher interface. +type MockatxMalfeasancePublisher struct { ctrl *gomock.Controller - recorder *MockmalfeasancePublisherMockRecorder + recorder *MockatxMalfeasancePublisherMockRecorder isgomock struct{} } -// MockmalfeasancePublisherMockRecorder is the mock recorder for MockmalfeasancePublisher. -type MockmalfeasancePublisherMockRecorder struct { - mock *MockmalfeasancePublisher +// MockatxMalfeasancePublisherMockRecorder is the mock recorder for MockatxMalfeasancePublisher. +type MockatxMalfeasancePublisherMockRecorder struct { + mock *MockatxMalfeasancePublisher } -// NewMockmalfeasancePublisher creates a new mock instance. -func NewMockmalfeasancePublisher(ctrl *gomock.Controller) *MockmalfeasancePublisher { - mock := &MockmalfeasancePublisher{ctrl: ctrl} - mock.recorder = &MockmalfeasancePublisherMockRecorder{mock} +// NewMockatxMalfeasancePublisher creates a new mock instance. +func NewMockatxMalfeasancePublisher(ctrl *gomock.Controller) *MockatxMalfeasancePublisher { + mock := &MockatxMalfeasancePublisher{ctrl: ctrl} + mock.recorder = &MockatxMalfeasancePublisherMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockmalfeasancePublisher) EXPECT() *MockmalfeasancePublisherMockRecorder { +func (m *MockatxMalfeasancePublisher) EXPECT() *MockatxMalfeasancePublisherMockRecorder { return m.recorder } // Publish mocks base method. -func (m *MockmalfeasancePublisher) Publish(ctx context.Context, id types.NodeID, proof wire.Proof) error { +func (m *MockatxMalfeasancePublisher) Publish(ctx context.Context, id types.NodeID, proof wire.Proof) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Publish", ctx, id, proof) ret0, _ := ret[0].(error) @@ -1125,31 +1125,31 @@ func (m *MockmalfeasancePublisher) Publish(ctx context.Context, id types.NodeID, } // Publish indicates an expected call of Publish. -func (mr *MockmalfeasancePublisherMockRecorder) Publish(ctx, id, proof any) *MockmalfeasancePublisherPublishCall { +func (mr *MockatxMalfeasancePublisherMockRecorder) Publish(ctx, id, proof any) *MockatxMalfeasancePublisherPublishCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Publish", reflect.TypeOf((*MockmalfeasancePublisher)(nil).Publish), ctx, id, proof) - return &MockmalfeasancePublisherPublishCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Publish", reflect.TypeOf((*MockatxMalfeasancePublisher)(nil).Publish), ctx, id, proof) + return &MockatxMalfeasancePublisherPublishCall{Call: call} } -// MockmalfeasancePublisherPublishCall wrap *gomock.Call -type MockmalfeasancePublisherPublishCall struct { +// MockatxMalfeasancePublisherPublishCall wrap *gomock.Call +type MockatxMalfeasancePublisherPublishCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockmalfeasancePublisherPublishCall) Return(arg0 error) *MockmalfeasancePublisherPublishCall { +func (c *MockatxMalfeasancePublisherPublishCall) Return(arg0 error) *MockatxMalfeasancePublisherPublishCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockmalfeasancePublisherPublishCall) Do(f func(context.Context, types.NodeID, wire.Proof) error) *MockmalfeasancePublisherPublishCall { +func (c *MockatxMalfeasancePublisherPublishCall) Do(f func(context.Context, types.NodeID, wire.Proof) error) *MockatxMalfeasancePublisherPublishCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockmalfeasancePublisherPublishCall) DoAndReturn(f func(context.Context, types.NodeID, wire.Proof) error) *MockmalfeasancePublisherPublishCall { +func (c *MockatxMalfeasancePublisherPublishCall) DoAndReturn(f func(context.Context, types.NodeID, wire.Proof) error) *MockatxMalfeasancePublisherPublishCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/activation/wire/malfeasance_double_marry.go b/activation/wire/malfeasance_double_marry.go index ac7a760c1f..fc2a98e545 100644 --- a/activation/wire/malfeasance_double_marry.go +++ b/activation/wire/malfeasance_double_marry.go @@ -15,12 +15,10 @@ import ( // ProofDoubleMarry is a proof that two distinct ATXs contain a marriage certificate signed by the same identity. // // We are proving the following: -// 1. The ATXs have different IDs. -// 2. Both ATXs have a valid signature. -// 3. Both ATXs contain a marriage certificate created by the same identity. -// 4. Both marriage certificates have valid signatures. -// -// HINT: this works if the identity that publishes the marriage ATX marries themselves. +// 1. The ATXs have different IDs. +// 2. Both ATXs have a valid signature. +// 3. Both ATXs contain a marriage certificate created by the same identity. +// 4. Both marriage certificates have valid signatures. type ProofDoubleMarry struct { // NodeID is the node ID that married twice. NodeID types.NodeID diff --git a/activation/wire/malfeasance_double_marry_test.go b/activation/wire/malfeasance_double_marry_test.go index f9f686503a..f52f8c8559 100644 --- a/activation/wire/malfeasance_double_marry_test.go +++ b/activation/wire/malfeasance_double_marry_test.go @@ -3,7 +3,6 @@ package wire import ( "context" "fmt" - "slices" "testing" "github.com/stretchr/testify/require" @@ -16,6 +15,8 @@ import ( ) func Test_DoubleMarryProof(t *testing.T) { + t.Parallel() + sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -25,7 +26,9 @@ func Test_DoubleMarryProof(t *testing.T) { edVerifier := signing.NewEdVerifier() t.Run("valid", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} otherAtx.SetID(types.RandomATXID()) otherAtx.SmesherID = otherSig.NodeID() @@ -59,37 +62,55 @@ func Test_DoubleMarryProof(t *testing.T) { require.Equal(t, otherSig.NodeID(), id) }) - t.Run("does not contain same certificate owner", func(t *testing.T) { + t.Run("identity is not included in both ATXs", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + atx1 := newActivationTxV2( withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), ) atx1.Sign(sig) atx2 := newActivationTxV2( withMarriageCertificate(otherSig, types.EmptyATXID, otherSig.NodeID()), + withMarriageCertificate(sig, atx1.ID(), otherSig.NodeID()), ) atx2.Sign(otherSig) + marriages := make([]MarriageCertificate, len(atx1.Marriages)) + copy(marriages, atx1.Marriages) + atx1.Marriages = marriages[:1] proof, err := NewDoubleMarryProof(db, atx1, atx2, otherSig.NodeID()) - require.ErrorContains(t, err, fmt.Sprintf( + require.EqualError(t, err, fmt.Sprintf( "proof for atx1: does not contain a marriage certificate signed by %s", otherSig.NodeID().ShortString(), )) require.Nil(t, proof) + atx1.Marriages = marriages + marriages = make([]MarriageCertificate, len(atx2.Marriages)) + copy(marriages, atx2.Marriages) + atx2.Marriages = marriages[:1] proof, err = NewDoubleMarryProof(db, atx1, atx2, sig.NodeID()) - require.ErrorContains(t, err, fmt.Sprintf( + require.EqualError(t, err, fmt.Sprintf( "proof for atx2: does not contain a marriage certificate signed by %s", sig.NodeID().ShortString(), )) require.Nil(t, proof) + atx2.Marriages = marriages }) t.Run("same ATX ID", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + atx1 := newActivationTxV2() atx1.Sign(sig) - db := statesql.InMemoryTest(t) proof, err := NewDoubleMarryProof(db, atx1, atx1, sig.NodeID()) require.ErrorContains(t, err, "ATXs have the same ID") require.Nil(t, proof) @@ -108,128 +129,10 @@ func Test_DoubleMarryProof(t *testing.T) { require.Equal(t, types.EmptyNodeID, id) }) - t.Run("invalid marriage proof", func(t *testing.T) { + t.Run("invalid proof", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) - otherAtx := &types.ActivationTx{} - otherAtx.SetID(types.RandomATXID()) - otherAtx.SmesherID = otherSig.NodeID() - require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) - - atx1 := newActivationTxV2( - withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), - withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), - ) - atx1.Sign(sig) - - atx2 := newActivationTxV2( - withMarriageCertificate(otherSig, types.EmptyATXID, otherSig.NodeID()), - withMarriageCertificate(sig, atx1.ID(), otherSig.NodeID()), - ) - atx2.Sign(otherSig) - - // manually construct an invalid proof - proof1, err := createMarryProof(db, atx1, otherSig.NodeID()) - require.NoError(t, err) - proof2, err := createMarryProof(db, atx2, otherSig.NodeID()) - require.NoError(t, err) - - proof := &ProofDoubleMarry{ - NodeID: otherSig.NodeID(), - - ATX1: atx1.ID(), - SmesherID1: atx1.SmesherID, - Signature1: atx1.Signature, - Proof1: proof1, - - ATX2: atx2.ID(), - SmesherID2: atx2.SmesherID, - Signature2: atx2.Signature, - Proof2: proof2, - } - - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() - proof.Proof1.MarriageCertificatesProof = slices.Clone(proof1.MarriageCertificatesProof) - proof.Proof1.MarriageCertificatesProof[0] = types.RandomHash() - id, err := proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "proof 1 is invalid: invalid marriage proof") - require.Equal(t, types.EmptyNodeID, id) - - proof.Proof1.MarriageCertificatesProof[0] = proof1.MarriageCertificatesProof[0] - proof.Proof2.MarriageCertificatesProof = slices.Clone(proof2.MarriageCertificatesProof) - proof.Proof2.MarriageCertificatesProof[0] = types.RandomHash() - id, err = proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "proof 2 is invalid: invalid marriage proof") - require.Equal(t, types.EmptyNodeID, id) - }) - - t.Run("invalid certificate proof", func(t *testing.T) { - db := statesql.InMemoryTest(t) - otherAtx := &types.ActivationTx{} - otherAtx.SetID(types.RandomATXID()) - otherAtx.SmesherID = otherSig.NodeID() - require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) - - atx1 := newActivationTxV2( - withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), - withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), - ) - atx1.Sign(sig) - - atx2 := newActivationTxV2( - withMarriageCertificate(otherSig, types.EmptyATXID, otherSig.NodeID()), - withMarriageCertificate(sig, atx1.ID(), otherSig.NodeID()), - ) - atx2.Sign(otherSig) - - // manually construct an invalid proof - proof1, err := createMarryProof(db, atx1, otherSig.NodeID()) - require.NoError(t, err) - proof2, err := createMarryProof(db, atx2, otherSig.NodeID()) - require.NoError(t, err) - - proof := &ProofDoubleMarry{ - NodeID: otherSig.NodeID(), - - ATX1: atx1.ID(), - SmesherID1: atx1.SmesherID, - Signature1: atx1.Signature, - Proof1: proof1, - - ATX2: atx2.ID(), - SmesherID2: atx2.SmesherID, - Signature2: atx2.Signature, - Proof2: proof2, - } - - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() - - proof.Proof1.CertificateProof = slices.Clone(proof1.CertificateProof) - proof.Proof1.CertificateProof[0] = types.RandomHash() - id, err := proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "proof 1 is invalid: invalid certificate proof") - require.Equal(t, types.EmptyNodeID, id) - - proof.Proof1.CertificateProof[0] = proof1.CertificateProof[0] - proof.Proof2.CertificateProof = slices.Clone(proof2.CertificateProof) - proof.Proof2.CertificateProof[0] = types.RandomHash() - id, err = proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "proof 2 is invalid: invalid certificate proof") - require.Equal(t, types.EmptyNodeID, id) - }) - - t.Run("invalid atx signature", func(t *testing.T) { - db := statesql.InMemoryTest(t) otherAtx := &types.ActivationTx{} otherAtx.SetID(types.RandomATXID()) otherAtx.SmesherID = otherSig.NodeID() @@ -257,76 +160,46 @@ func Test_DoubleMarryProof(t *testing.T) { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() + // invalid signature for ATX1 proof.Signature1 = types.RandomEdSignature() id, err := proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "invalid signature for ATX1") require.Equal(t, types.EmptyNodeID, id) - proof.Signature1 = atx1.Signature + + // invalid signature for ATX2 proof.Signature2 = types.RandomEdSignature() id, err = proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "invalid signature for ATX2") require.Equal(t, types.EmptyNodeID, id) - }) - - t.Run("invalid certificate signature", func(t *testing.T) { - db := statesql.InMemoryTest(t) - otherAtx := &types.ActivationTx{} - otherAtx.SetID(types.RandomATXID()) - otherAtx.SmesherID = otherSig.NodeID() - require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) - - atx1 := newActivationTxV2( - withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), - withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), - ) - atx1.Sign(sig) + proof.Signature2 = atx2.Signature - atx2 := newActivationTxV2( - withMarriageCertificate(otherSig, types.EmptyATXID, sig.NodeID()), - withMarriageCertificate(sig, atx1.ID(), sig.NodeID()), - ) - atx2.Sign(otherSig) - - proof, err := NewDoubleMarryProof(db, atx1, atx2, otherSig.NodeID()) - require.NoError(t, err) - - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() - - proof.Proof1.Certificate.Signature = types.RandomEdSignature() - id, err := proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "proof 1 is invalid: invalid certificate signature") + // invalid smesher ID for ATX1 + proof.SmesherID1 = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid signature for ATX1") require.Equal(t, types.EmptyNodeID, id) + proof.SmesherID1 = atx1.SmesherID - proof.Proof1.Certificate.Signature = atx1.Marriages[1].Signature - proof.Proof2.Certificate.Signature = types.RandomEdSignature() + // invalid smesher ID for ATX2 + proof.SmesherID2 = types.RandomNodeID() id, err = proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "proof 2 is invalid: invalid certificate signature") + require.ErrorContains(t, err, "invalid signature for ATX2") require.Equal(t, types.EmptyNodeID, id) - }) - - t.Run("unknown reference ATX", func(t *testing.T) { - db := statesql.InMemoryTest(t) + proof.SmesherID2 = atx2.SmesherID - atx1 := newActivationTxV2( - withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), - withMarriageCertificate(otherSig, types.RandomATXID(), sig.NodeID()), // unknown reference ATX - ) - atx1.Sign(sig) - - atx2 := newActivationTxV2( - withMarriageCertificate(otherSig, types.EmptyATXID, sig.NodeID()), - withMarriageCertificate(sig, atx1.ID(), sig.NodeID()), - ) - atx2.Sign(otherSig) + // invalid ATX ID for ATX1 + proof.ATX1 = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid signature for ATX1") + require.Equal(t, types.EmptyNodeID, id) + proof.ATX1 = atx1.ID() - proof, err := NewDoubleMarryProof(db, atx1, atx2, otherSig.NodeID()) - require.Error(t, err) - require.Nil(t, proof) + // invalid ATX ID for ATX2 + proof.ATX2 = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid signature for ATX2") + require.Equal(t, types.EmptyNodeID, id) + proof.ATX2 = atx2.ID() }) } diff --git a/activation/wire/malfeasance_double_merge.go b/activation/wire/malfeasance_double_merge.go index 7bbaf51e2d..3b3f73194a 100644 --- a/activation/wire/malfeasance_double_merge.go +++ b/activation/wire/malfeasance_double_merge.go @@ -126,7 +126,6 @@ func NewDoubleMergeProof(db sql.Executor, atx1, atx2 *ActivationTxV2) (*ProofDou return &proof, nil } -// Valid implements Proof.Valid. func (p *ProofDoubleMerge) Valid(_ context.Context, edVerifier MalfeasanceValidator) (types.NodeID, error) { // 1. The ATXs have different IDs. if p.ATXID1 == p.ATXID2 { diff --git a/activation/wire/malfeasance_double_merge_test.go b/activation/wire/malfeasance_double_merge_test.go index 707e27ee58..83ba3a2b4a 100644 --- a/activation/wire/malfeasance_double_merge_test.go +++ b/activation/wire/malfeasance_double_merge_test.go @@ -14,11 +14,13 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/statesql" ) -func Test_NewDoubleMergeProof(t *testing.T) { - signer1, err := signing.NewEdSigner() +func Test_DoubleMergeProof(t *testing.T) { + t.Parallel() + + sig, err := signing.NewEdSigner() require.NoError(t, err) - signer2, err := signing.NewEdSigner() + otherSig, err := signing.NewEdSigner() require.NoError(t, err) marrySig, err := signing.NewEdSigner() @@ -30,27 +32,27 @@ func Test_NewDoubleMergeProof(t *testing.T) { wInitialAtx1 := newActivationTxV2( withInitial(types.RandomATXID(), PostV1{}), ) - wInitialAtx1.Sign(signer1) + wInitialAtx1.Sign(sig) initialAtx1 := &types.ActivationTx{ CommitmentATX: &wInitialAtx1.Initial.CommitmentATX, } initialAtx1.SetID(wInitialAtx1.ID()) - initialAtx1.SmesherID = signer1.NodeID() + initialAtx1.SmesherID = sig.NodeID() require.NoError(t, atxs.Add(db, initialAtx1, wInitialAtx1.Blob())) wInitialAtx2 := newActivationTxV2( withInitial(types.RandomATXID(), PostV1{}), ) - wInitialAtx2.Sign(signer2) + wInitialAtx2.Sign(otherSig) initialAtx2 := &types.ActivationTx{} initialAtx2.SetID(wInitialAtx2.ID()) - initialAtx2.SmesherID = signer2.NodeID() + initialAtx2.SmesherID = otherSig.NodeID() require.NoError(t, atxs.Add(db, initialAtx2, wInitialAtx2.Blob())) wMarriageAtx := newActivationTxV2( withMarriageCertificate(marrySig, types.EmptyATXID, marrySig.NodeID()), - withMarriageCertificate(signer1, wInitialAtx1.ID(), marrySig.NodeID()), - withMarriageCertificate(signer2, wInitialAtx2.ID(), marrySig.NodeID()), + withMarriageCertificate(sig, wInitialAtx1.ID(), marrySig.NodeID()), + withMarriageCertificate(otherSig, wInitialAtx2.ID(), marrySig.NodeID()), ) wMarriageAtx.Sign(marrySig) @@ -61,78 +63,142 @@ func Test_NewDoubleMergeProof(t *testing.T) { return wMarriageAtx } - t.Run("ATXs must be different", func(t *testing.T) { + t.Run("valid", func(t *testing.T) { t.Parallel() db := statesql.InMemoryTest(t) - atx := &ActivationTxV2{} - atx.Sign(signer1) - proof, err := NewDoubleMergeProof(db, atx, atx) - require.ErrorContains(t, err, "ATXs have the same ID") + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + marriageAtx := setupMarriage(db) + + atx1 := newActivationTxV2( + withMarriageATX(marriageAtx.ID()), + withPublishEpoch(marriageAtx.PublishEpoch+1), + ) + atx1.Sign(sig) + + atx2 := newActivationTxV2( + withMarriageATX(marriageAtx.ID()), + withPublishEpoch(marriageAtx.PublishEpoch+1), + ) + atx2.Sign(otherSig) + + proof, err := NewDoubleMergeProof(db, atx1, atx2) + require.NoError(t, err) + id, err := proof.Valid(context.Background(), verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), id) + }) + + t.Run("same ATX ID", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + marriageAtx := setupMarriage(db) + + atx1 := newActivationTxV2( + withMarriageATX(marriageAtx.ID()), + withPublishEpoch(marriageAtx.PublishEpoch+1), + ) + atx1.Sign(sig) + + proof, err := NewDoubleMergeProof(db, atx1, atx1) + require.EqualError(t, err, "ATXs have the same ID") require.Nil(t, proof) + + proof = &ProofDoubleMerge{ + ATXID1: atx1.ID(), + ATXID2: atx1.ID(), + } + id, err := proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATXs have the same ID") + require.Equal(t, types.EmptyNodeID, id) }) t.Run("ATXs must have different signers", func(t *testing.T) { - // Note: catching this scenario is the responsibility of "invalid previous ATX proof" t.Parallel() db := statesql.InMemoryTest(t) - atx1 := &ActivationTxV2{} - atx1.Sign(signer1) + atx1 := newActivationTxV2() + atx1.Sign(sig) - atx2 := &ActivationTxV2{VRFNonce: 1} - atx2.Sign(signer1) + atx2 := newActivationTxV2() + atx2.Sign(sig) proof, err := NewDoubleMergeProof(db, atx1, atx2) require.ErrorContains(t, err, "ATXs have the same smesher") require.Nil(t, proof) }) - t.Run("ATXs must have marriage ATX", func(t *testing.T) { + t.Run("ATXs must be published in the same epoch", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + atx := newActivationTxV2( + withPublishEpoch(1), + ) + atx.Sign(sig) + + atx2 := newActivationTxV2( + withPublishEpoch(2), + ) + atx2.Sign(otherSig) + proof, err := NewDoubleMergeProof(db, atx, atx2) + require.ErrorContains(t, err, "ATXs have different publish epoch") + require.Nil(t, proof) + }) + + t.Run("ATXs must have valid marriage ATX", func(t *testing.T) { t.Parallel() db := statesql.InMemoryTest(t) - atx := &ActivationTxV2{} - atx.Sign(signer1) - atx2 := &ActivationTxV2{VRFNonce: 1} - atx2.Sign(signer2) + atx := newActivationTxV2( + withPublishEpoch(1), + ) + atx.Sign(sig) + + atx2 := newActivationTxV2( + withPublishEpoch(1), + ) + atx2.Sign(otherSig) // ATX 1 has no marriage - _, err := NewDoubleMergeProof(db, atx, atx2) + proof, err := NewDoubleMergeProof(db, atx, atx2) require.ErrorContains(t, err, "ATX 1 have no marriage ATX") + require.Nil(t, proof) // ATX 2 has no marriage atx.MarriageATX = new(types.ATXID) *atx.MarriageATX = types.RandomATXID() - _, err = NewDoubleMergeProof(db, atx, atx2) + proof, err = NewDoubleMergeProof(db, atx, atx2) require.ErrorContains(t, err, "ATX 2 have no marriage ATX") + require.Nil(t, proof) // ATX 1 and 2 must have the same marriage ATX atx2.MarriageATX = new(types.ATXID) *atx2.MarriageATX = types.RandomATXID() - _, err = NewDoubleMergeProof(db, atx, atx2) + proof, err = NewDoubleMergeProof(db, atx, atx2) require.ErrorContains(t, err, "ATXs have different marriage ATXs") - }) - - t.Run("ATXs must be published in the same epoch", func(t *testing.T) { - t.Parallel() - db := statesql.InMemoryTest(t) - marriageID := types.RandomATXID() - atx := &ActivationTxV2{ - MarriageATX: &marriageID, - } - atx.Sign(signer1) + require.Nil(t, proof) - atx2 := &ActivationTxV2{ - MarriageATX: &marriageID, - PublishEpoch: 1, - } - atx2.Sign(signer2) - proof, err := NewDoubleMergeProof(db, atx, atx2) - require.ErrorContains(t, err, "ATXs have different publish epoch") + // Marriage ATX must be valid + atx2.MarriageATX = atx.MarriageATX + proof, err = NewDoubleMergeProof(db, atx, atx2) + require.ErrorIs(t, err, sql.ErrNotFound) require.Nil(t, proof) }) - t.Run("valid proof", func(t *testing.T) { + t.Run("invalid proof", func(t *testing.T) { t.Parallel() db := statesql.InMemoryTest(t) @@ -149,59 +215,106 @@ func Test_NewDoubleMergeProof(t *testing.T) { withMarriageATX(marriageAtx.ID()), withPublishEpoch(marriageAtx.PublishEpoch+1), ) - atx1.Sign(signer1) + atx1.Sign(sig) atx2 := newActivationTxV2( withMarriageATX(marriageAtx.ID()), withPublishEpoch(marriageAtx.PublishEpoch+1), ) - atx2.Sign(signer2) + atx2.Sign(otherSig) proof, err := NewDoubleMergeProof(db, atx1, atx2) require.NoError(t, err) + + // invalid marriage ATX ID + marriageAtxID := proof.MarriageATX + proof.MarriageATX = types.RandomATXID() id, err := proof.Valid(context.Background(), verifier) - require.NoError(t, err) - require.Equal(t, signer1.NodeID(), id) + require.EqualError(t, err, "ATX 1 invalid marriage ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.MarriageATX = marriageAtxID - require.Equal(t, signer1.NodeID(), proof.SmesherID1) - require.Equal(t, signer2.NodeID(), proof.SmesherID2) - }) + // invalid marriage ATX smesher ID + smesherID := proof.MarriageATXSmesherID + proof.MarriageATXSmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 1 invalid marriage ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.MarriageATXSmesherID = smesherID - t.Run("invalid marriage proof", func(t *testing.T) { - t.Parallel() - db := statesql.InMemoryTest(t) + // invalid ATX1 ID + id1 := proof.ATXID1 + proof.ATXID1 = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 1 invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXID1 = id1 - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() + // invalid ATX2 ID + id2 := proof.ATXID2 + proof.ATXID2 = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 2 invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXID2 = id2 - marriageAtx := setupMarriage(db) + // invalid ATX1 smesher ID + smesherID1 := proof.SmesherID1 + proof.SmesherID1 = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 1 invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.SmesherID1 = smesherID1 - atx1 := newActivationTxV2( - withMarriageATX(marriageAtx.ID()), - withPublishEpoch(marriageAtx.PublishEpoch+1), - ) - atx1.Sign(signer1) + // invalid ATX2 smesher ID + smesherID2 := proof.SmesherID2 + proof.SmesherID2 = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 2 invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.SmesherID2 = smesherID2 - atx2 := newActivationTxV2( - withMarriageATX(marriageAtx.ID()), - withPublishEpoch(marriageAtx.PublishEpoch+1), - ) - atx2.Sign(signer2) + // invalid ATX1 signature + signature1 := proof.Signature1 + proof.Signature1 = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 1 invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Signature1 = signature1 - proof, err := NewDoubleMergeProof(db, atx1, atx2) - require.NoError(t, err) + // invalid ATX2 signature + signature2 := proof.Signature2 + proof.Signature2 = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 2 invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Signature2 = signature2 + + // invalid publish epoch proof 1 + hash := proof.PublishEpochProof1[0] + proof.PublishEpochProof1[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 1 invalid publish epoch proof") + require.Equal(t, types.EmptyNodeID, id) + proof.PublishEpochProof1[0] = hash + + // invalid publish epoch proof 2 + hash = proof.PublishEpochProof2[0] + proof.PublishEpochProof2[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 2 invalid publish epoch proof") + require.Equal(t, types.EmptyNodeID, id) + proof.PublishEpochProof2[0] = hash - hash := proof.MarriageATXProof1[0] + // invalid marriage ATX proof 1 + hash = proof.MarriageATXProof1[0] proof.MarriageATXProof1[0] = types.RandomHash() - id, err := proof.Valid(context.Background(), verifier) + id, err = proof.Valid(context.Background(), verifier) require.EqualError(t, err, "ATX 1 invalid marriage ATX proof") require.Equal(t, types.EmptyNodeID, id) proof.MarriageATXProof1[0] = hash + // invalid marriage ATX proof 2 hash = proof.MarriageATXProof2[0] proof.MarriageATXProof2[0] = types.RandomHash() id, err = proof.Valid(context.Background(), verifier) @@ -210,117 +323,3 @@ func Test_NewDoubleMergeProof(t *testing.T) { proof.MarriageATXProof2[0] = hash }) } - -func Test_Validate_DoubleMergeProof(t *testing.T) { - edVerifier := signing.NewEdVerifier() - - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() - - t.Run("ATXs must have different IDs", func(t *testing.T) { - t.Parallel() - id := types.RandomATXID() - proof := &ProofDoubleMerge{ - ATXID1: id, - ATXID2: id, - } - _, err := proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "ATXs have the same ID") - }) - t.Run("ATX 1 must have valid signature", func(t *testing.T) { - t.Parallel() - proof := &ProofDoubleMerge{ - ATXID1: types.RandomATXID(), - ATXID2: types.RandomATXID(), - } - _, err := proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "ATX 1 invalid signature") - }) - t.Run("ATX 2 must have valid signature", func(t *testing.T) { - t.Parallel() - - atx1 := &ActivationTxV2{} - signer, err := signing.NewEdSigner() - require.NoError(t, err) - atx1.Sign(signer) - - proof := &ProofDoubleMerge{ - ATXID1: atx1.ID(), - SmesherID1: atx1.SmesherID, - Signature1: atx1.Signature, - - ATXID2: types.RandomATXID(), - } - _, err = proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "ATX 2 invalid signature") - }) - - t.Run("epoch proof for ATX1 must be valid", func(t *testing.T) { - t.Parallel() - - atx1 := &ActivationTxV2{} - signer, err := signing.NewEdSigner() - require.NoError(t, err) - atx1.Sign(signer) - - atx2 := &ActivationTxV2{ - NIPosts: []NIPostV2{ - { - Challenge: types.RandomHash(), - }, - }, - } - signer2, err := signing.NewEdSigner() - require.NoError(t, err) - atx2.Sign(signer2) - - proof := &ProofDoubleMerge{ - ATXID1: atx1.ID(), - SmesherID1: atx1.SmesherID, - Signature1: atx1.Signature, - - ATXID2: atx2.ID(), - SmesherID2: atx2.SmesherID, - Signature2: atx2.Signature, - } - _, err = proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "ATX 1 invalid publish epoch proof") - }) - - t.Run("epoch proof for ATX2 must be valid", func(t *testing.T) { - t.Parallel() - - atx1 := &ActivationTxV2{} - signer, err := signing.NewEdSigner() - require.NoError(t, err) - atx1.Sign(signer) - - atx2 := &ActivationTxV2{ - NIPosts: []NIPostV2{ - { - Challenge: types.RandomHash(), - }, - }, - } - signer2, err := signing.NewEdSigner() - require.NoError(t, err) - atx2.Sign(signer2) - - proof := &ProofDoubleMerge{ - ATXID1: atx1.ID(), - SmesherID1: atx1.SmesherID, - Signature1: atx1.Signature, - PublishEpochProof1: atx1.PublishEpochProof(), - - ATXID2: atx2.ID(), - SmesherID2: atx2.SmesherID, - Signature2: atx2.Signature, - } - _, err = proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "ATX 2 invalid publish epoch proof") - }) -} diff --git a/activation/wire/malfeasance_invalid_post.go b/activation/wire/malfeasance_invalid_post.go index 5cffbc68e2..ddb792c9f0 100644 --- a/activation/wire/malfeasance_invalid_post.go +++ b/activation/wire/malfeasance_invalid_post.go @@ -16,10 +16,10 @@ import ( // ProofInvalidPost is a proof that a merged ATX with an invalid Post was published by a smesher. // // We are proofing the following: -// 1. The ATX has a valid signature. -// 2. If NodeID is different from SmesherID, we prove that NodeID and SmesherID are married. -// 3. The commitment ATX of NodeID used for the invalid PoST based on their initial ATX. -// 4. The provided Post is invalid for the given NodeID. +// 1. The ATX has a valid signature. +// 2. If NodeID is different from SmesherID, we prove that NodeID and SmesherID are married. +// 3. The commitment ATX of NodeID used for the invalid PoST based on their initial ATX. +// 4. The provided Post is invalid for the given NodeID. type ProofInvalidPost struct { // ATXID is the ID of the ATX containing the invalid PoST. ATXID types.ATXID @@ -156,8 +156,8 @@ type InvalidPostProof struct { SubPostRootProof SubPostRootProof `scale:"max=32"` SubPostRootIndex uint16 - // MarriageIndexProof is the proof that the MarriageIndex (CertificateIndex from MarryProof) is contained in the - // SubPostRoot. + // MarriageIndexProof is the proof that the MarriageIndex (CertificateIndex from NodeIDMarryProof) is contained in + // the SubPostRoot. MarriageIndexProof MarriageIndexProof `scale:"max=32"` // Post is the invalid PoST and its proof that it is contained in the SubPostRoot. diff --git a/activation/wire/malfeasance_invalid_post_test.go b/activation/wire/malfeasance_invalid_post_test.go index 89a245c92f..c0bea896a8 100644 --- a/activation/wire/malfeasance_invalid_post_test.go +++ b/activation/wire/malfeasance_invalid_post_test.go @@ -20,6 +20,8 @@ import ( ) func Test_InvalidPostProof(t *testing.T) { + t.Parallel() + // sig is the identity that creates the invalid PoST sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -122,6 +124,7 @@ func Test_InvalidPostProof(t *testing.T) { } t.Run("valid", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -172,6 +175,7 @@ func Test_InvalidPostProof(t *testing.T) { }) t.Run("valid merged atx", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -222,6 +226,7 @@ func Test_InvalidPostProof(t *testing.T) { }) t.Run("post is valid", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -312,6 +317,7 @@ func Test_InvalidPostProof(t *testing.T) { }) t.Run("differing node ID without marriage ATX", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -330,26 +336,10 @@ func Test_InvalidPostProof(t *testing.T) { proof, err := NewInvalidPostProof(db, atx, commitmentAtx, nodeID, 0, invalidPostIndex, validPostIndex) require.EqualError(t, err, "ATX is not a merged ATX, but NodeID is different from SmesherID") require.Nil(t, proof) - - proof, err = NewInvalidPostProof(db, atx, commitmentAtx, sig.NodeID(), 0, invalidPostIndex, validPostIndex) - require.NoError(t, err) - require.NotNil(t, proof) - - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() - - proof.NodeID = types.RandomNodeID() // invalid node ID - - id, err := proof.Valid(context.Background(), verifier) - require.EqualError(t, err, "missing marriage proof") - require.Equal(t, types.EmptyNodeID, id) }) - t.Run("node ID not in marriage ATX", func(t *testing.T) { + t.Run("nipost index is invalid", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -359,20 +349,18 @@ func Test_InvalidPostProof(t *testing.T) { Indices: types.RandomBytes(11), Pow: rand.Uint64(), } - atx := newMergedATXv2(db, nipostChallenge, post, numUnits) + atx := newSoloATXv2(db, nipostChallenge, post, numUnits) commitmentAtx := types.RandomATXID() const invalidPostIndex = 7 const validPostIndex = 15 - nodeID := types.RandomNodeID() - proof, err := NewInvalidPostProof(db, atx, commitmentAtx, nodeID, 0, invalidPostIndex, validPostIndex) - require.ErrorContains(t, err, - fmt.Sprintf("does not contain a marriage certificate signed by %s", nodeID.ShortString()), - ) + proof, err := NewInvalidPostProof(db, atx, commitmentAtx, sig.NodeID(), 1, invalidPostIndex, validPostIndex) + require.EqualError(t, err, "invalid NIPoST index") require.Nil(t, proof) }) - t.Run("invalid marriage proof", func(t *testing.T) { + t.Run("node ID not in marriage ATX", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -383,30 +371,20 @@ func Test_InvalidPostProof(t *testing.T) { Pow: rand.Uint64(), } atx := newMergedATXv2(db, nipostChallenge, post, numUnits) + commitmentAtx := types.RandomATXID() - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() - - // manually construct an invalid proof - proof, err := createMarriageProof(db, atx, sig.NodeID()) - require.NoError(t, err) - - marriageATX := proof.MarriageATX - proof.MarriageATX = types.RandomATXID() // invalid ATX - err = proof.Valid(verifier, atx.ID(), sig.NodeID(), pubSig.NodeID()) - require.ErrorContains(t, err, "invalid marriage ATX proof") - - proof.MarriageATX = marriageATX - proof.MarriageATXProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(verifier, atx.ID(), sig.NodeID(), pubSig.NodeID()) - require.ErrorContains(t, err, "invalid marriage ATX proof") + const invalidPostIndex = 7 + const validPostIndex = 15 + nodeID := types.RandomNodeID() + proof, err := NewInvalidPostProof(db, atx, commitmentAtx, nodeID, 0, invalidPostIndex, validPostIndex) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", nodeID.ShortString()), + ) + require.Nil(t, proof) }) t.Run("node ID did not include post in merged ATX", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -429,28 +407,8 @@ func Test_InvalidPostProof(t *testing.T) { require.Nil(t, proof) }) - t.Run("invalid nipost index", func(t *testing.T) { - db := statesql.InMemoryTest(t) - - nipostChallenge := types.RandomHash() - const numUnits = uint32(11) - post := PostV1{ - Nonce: rand.Uint32(), - Indices: types.RandomBytes(11), - Pow: rand.Uint64(), - } - atx := newSoloATXv2(db, nipostChallenge, post, numUnits) - commitmentAtx := types.RandomATXID() - - const invalidPostIndex = 7 - const validPostIndex = 15 - // 1 is an invalid nipostIndex for this ATX - proof, err := NewInvalidPostProof(db, atx, commitmentAtx, sig.NodeID(), 1, invalidPostIndex, validPostIndex) - require.EqualError(t, err, "invalid NIPoST index") - require.Nil(t, proof) - }) - - t.Run("invalid ATX signature", func(t *testing.T) { + t.Run("invalid solo proof", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -475,122 +433,166 @@ func Test_InvalidPostProof(t *testing.T) { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() - proof.Signature = types.RandomEdSignature() // invalid signature - + // invalid ATXID + proof.ATXID = types.RandomATXID() id, err := proof.Valid(context.Background(), verifier) require.EqualError(t, err, "invalid signature") require.Equal(t, types.EmptyNodeID, id) - }) + proof.ATXID = atx.ID() - t.Run("solo invalid post proof is not valid", func(t *testing.T) { - db := statesql.InMemoryTest(t) + // invalid smesher ID + proof.SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.SmesherID = atx.SmesherID - nipostChallenge := types.RandomHash() - const numUnits = uint32(11) - post := PostV1{ + // invalid signature + proof.Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Signature = atx.Signature + + // invalid node ID + proof.NodeID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "missing marriage proof") + require.Equal(t, types.EmptyNodeID, id) + proof.NodeID = sig.NodeID() + + // invalid niposts root + nipostsRoot := proof.InvalidPostProof.NIPostsRoot + proof.InvalidPostProof.NIPostsRoot = NIPostsRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid NIPosts root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NIPostsRoot = nipostsRoot + + // invalid niposts root proof + hash := proof.InvalidPostProof.NIPostsRootProof[0] + proof.InvalidPostProof.NIPostsRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid NIPosts root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NIPostsRootProof[0] = hash + + // invalid nipost root + nipostRoot := proof.InvalidPostProof.NIPostRoot + proof.InvalidPostProof.NIPostRoot = NIPostRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NIPostRoot = nipostRoot + + // invalid nipost root proof + hash = proof.InvalidPostProof.NIPostRootProof[0] + proof.InvalidPostProof.NIPostRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NIPostRootProof[0] = hash + + // invalid nipost index + proof.InvalidPostProof.NIPostIndex = 1 + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NIPostIndex = 0 + + // invalid challenge + challenge := proof.InvalidPostProof.Challenge + proof.InvalidPostProof.Challenge = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid challenge proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.Challenge = challenge + + // invalid challenge proof + hash = proof.InvalidPostProof.ChallengeProof[0] + proof.InvalidPostProof.ChallengeProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid challenge proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.ChallengeProof[0] = hash + + // invalid subposts root + subPostsRoot := proof.InvalidPostProof.SubPostsRoot + proof.InvalidPostProof.SubPostsRoot = SubPostsRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid sub PoSTs root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.SubPostsRoot = subPostsRoot + + // invalid subposts root proof + hash = proof.InvalidPostProof.SubPostsRootProof[0] + proof.InvalidPostProof.SubPostsRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid sub PoSTs root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.SubPostsRootProof[0] = hash + + // invalid subpost root + subPostRoot := proof.InvalidPostProof.SubPostRoot + proof.InvalidPostProof.SubPostRoot = SubPostRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.SubPostRoot = subPostRoot + + // invalid subpost root proof + hash = proof.InvalidPostProof.SubPostRootProof[0] + proof.InvalidPostProof.SubPostRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.SubPostRootProof[0] = hash + + // invalid subpost root index + proof.InvalidPostProof.SubPostRootIndex++ + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.SubPostRootIndex-- + + // invalid post + post = proof.InvalidPostProof.Post + proof.InvalidPostProof.Post = PostV1{ Nonce: rand.Uint32(), Indices: types.RandomBytes(11), Pow: rand.Uint64(), } - atx := newSoloATXv2(db, nipostChallenge, post, numUnits) - commitmentAtx := types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid post proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.Post = post - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() + // invalid post proof + hash = proof.InvalidPostProof.PostProof[0] + proof.InvalidPostProof.PostProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid post proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.PostProof[0] = hash - // manually construct an invalid proof - const invalidPostIndex = 7 - const validPostIndex = 15 - proof, err := createInvalidPostProof(atx, commitmentAtx, 0, 0, invalidPostIndex, validPostIndex) - require.NoError(t, err) - require.NotNil(t, proof) - - nipostsRoot := proof.NIPostsRoot - proof.NIPostsRoot = NIPostsRoot(types.RandomHash()) // invalid root - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid NIPosts root proof") - proof.NIPostsRoot = nipostsRoot - - proofHash := proof.NIPostsRootProof[0] - proof.NIPostsRootProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid NIPosts root proof") - proof.NIPostsRootProof[0] = proofHash - - proof.NIPostIndex = 1 // invalid index - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid NIPoST root proof") - proof.NIPostIndex = 0 - - nipostRoot := proof.NIPostRoot - proof.NIPostRoot = NIPostRoot(types.RandomHash()) // invalid root - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid NIPoST root proof") - proof.NIPostRoot = nipostRoot - - proofHash = proof.NIPostRootProof[0] - proof.NIPostRootProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid NIPoST root proof") - proof.NIPostRootProof[0] = proofHash - - challenge := proof.Challenge - proof.Challenge = types.RandomHash() // invalid challenge - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid challenge proof") - proof.Challenge = challenge - - proofHash = proof.ChallengeProof[0] - proof.ChallengeProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid challenge proof") - proof.ChallengeProof[0] = proofHash - - subPostsRoot := proof.SubPostsRoot - proof.SubPostsRoot = SubPostsRoot(types.RandomHash()) // invalid root - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid sub PoSTs root proof") - proof.SubPostsRoot = subPostsRoot - - proofHash = proof.SubPostsRootProof[0] - proof.SubPostsRootProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid sub PoSTs root proof") - proof.SubPostsRootProof[0] = proofHash - - proof.SubPostRootIndex = 1 // invalid index - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid sub PoST root proof") - proof.SubPostRootIndex = 0 - - subPost := proof.SubPostRoot - proof.SubPostRoot = SubPostRoot(types.RandomHash()) // invalid root - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid sub PoST root proof") - proof.SubPostRoot = subPost - - proofHash = proof.SubPostRootProof[0] - proof.SubPostRootProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid sub PoST root proof") - proof.SubPostRootProof[0] = proofHash - - proof.Post = PostV1{} // invalid post - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid PoST proof") - proof.Post = post - - proof.NumUnits++ // invalid number of units - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), nil) - require.EqualError(t, err, "invalid num units proof") - proof.NumUnits-- + // invalid numunits + proof.InvalidPostProof.NumUnits++ + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid post proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NumUnits-- + + // invalid numunits proof + hash = proof.InvalidPostProof.NumUnitsProof[0] + proof.InvalidPostProof.NumUnitsProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid post proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NumUnitsProof[0] = hash }) - t.Run("merged invalid post proof is not valid", func(t *testing.T) { + t.Run("invalid merged proof", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -601,6 +603,12 @@ func Test_InvalidPostProof(t *testing.T) { Pow: rand.Uint64(), } atx := newMergedATXv2(db, nipostChallenge, post, numUnits) + commitmentAtx := types.RandomATXID() + + const invalidPostIndex = 7 + const validPostIndex = 15 + proof, err := NewInvalidPostProof(db, atx, commitmentAtx, sig.NodeID(), 0, invalidPostIndex, validPostIndex) + require.NoError(t, err) ctrl := gomock.NewController(t) verifier := NewMockMalfeasanceValidator(ctrl) @@ -609,18 +617,48 @@ func Test_InvalidPostProof(t *testing.T) { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() - // manually construct an invalid proof - marriageIndex := uint32(1) - commitmentAtx := types.RandomATXID() - const invalidPostIndex = 7 - const validPostIndex = 15 - proof, err := createInvalidPostProof(atx, commitmentAtx, 0, 1, invalidPostIndex, validPostIndex) - require.NoError(t, err) - require.NotNil(t, proof) + // invalid ATXID + proof.ATXID = types.RandomATXID() + id, err := proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXID = atx.ID() - invalidMarriageIndex := marriageIndex + 1 + // invalid smesher ID + proof.SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.SmesherID = atx.SmesherID - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), &invalidMarriageIndex) - require.EqualError(t, err, "invalid marriage index proof") + // invalid signature + proof.Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Signature = atx.Signature + + // invalid node ID + proof.NodeID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid marriage proof for NodeID") + require.Equal(t, types.EmptyNodeID, id) + proof.NodeID = sig.NodeID() + + // invalid marriage index proof + hash := proof.InvalidPostProof.MarriageIndexProof[0] + proof.InvalidPostProof.MarriageIndexProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid marriage index proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.MarriageIndexProof[0] = hash + + // invalid numunits proof + hash = proof.InvalidPostProof.NumUnitsProof[0] + proof.InvalidPostProof.NumUnitsProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid post proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NumUnitsProof[0] = hash }) } diff --git a/activation/wire/malfeasance_invalid_prev_atx.go b/activation/wire/malfeasance_invalid_prev_atx.go new file mode 100644 index 0000000000..abb5acb5f6 --- /dev/null +++ b/activation/wire/malfeasance_invalid_prev_atx.go @@ -0,0 +1,343 @@ +package wire + +import ( + "context" + "errors" + "fmt" + "slices" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" +) + +//go:generate scalegen + +// ProofInvalidPrevAtxV2 is a proof that two distinct ATXs reference the same previous ATX for one of the included +// identities. +// +// We are proving the following: +// 1. The ATXs have different IDs. +// 2. Both ATXs have a valid signature. +// 3. Both ATXs reference the same previous ATX for the same identity. +// 4. If the signer of one of the two ATXs is not the identity that referenced the same previous ATX, then the identity +// that did is married to the signer via a valid marriage certificate in the referenced marriage ATX. +type ProofInvalidPrevAtxV2 struct { + // NodeID is the node ID that referenced the same previous ATX twice. + NodeID types.NodeID + + // PrevATX is the ATX that was referenced twice. + PrevATX types.ATXID + + Proofs [2]InvalidPrevAtxProof +} + +var _ Proof = &ProofInvalidPrevAtxV2{} + +func NewInvalidPrevAtxProofV2( + db sql.Executor, + atx1, atx2 *ActivationTxV2, + nodeID types.NodeID, +) (*ProofInvalidPrevAtxV2, error) { + if atx1.ID() == atx2.ID() { + return nil, errors.New("ATXs have the same ID") + } + + if atx1.SmesherID != nodeID && atx1.MarriageATX == nil { + return nil, errors.New("ATX1 is not a merged ATX, but NodeID is different from SmesherID") + } + + if atx2.SmesherID != nodeID && atx2.MarriageATX == nil { + return nil, errors.New("ATX2 is not a merged ATX, but NodeID is different from SmesherID") + } + + var marriageProof1 *MarriageProof + nipostIndex1 := 0 + postIndex1 := 0 + if atx1.SmesherID != nodeID { + proof, err := createMarriageProof(db, atx1, nodeID) + if err != nil { + return nil, fmt.Errorf("marriage proof: %w", err) + } + marriageProof1 = &proof + for i, nipost := range atx1.NIPosts { + postIndex1 = slices.IndexFunc(nipost.Posts, func(post SubPostV2) bool { + return post.MarriageIndex == proof.NodeIDMarryProof.CertificateIndex + }) + if postIndex1 != -1 { + nipostIndex1 = i + break + } + } + if postIndex1 == -1 { + return nil, fmt.Errorf("no PoST from %s in ATX", nodeID.ShortString()) + } + } + + var marriageProof2 *MarriageProof + nipostIndex2 := 0 + postIndex2 := 0 + if atx2.SmesherID != nodeID { + proof, err := createMarriageProof(db, atx2, nodeID) + if err != nil { + return nil, fmt.Errorf("marriage proof: %w", err) + } + marriageProof2 = &proof + for i, nipost := range atx2.NIPosts { + postIndex2 = slices.IndexFunc(nipost.Posts, func(post SubPostV2) bool { + return post.MarriageIndex == proof.NodeIDMarryProof.CertificateIndex + }) + if postIndex2 != -1 { + nipostIndex2 = i + break + } + } + if postIndex2 == -1 { + return nil, fmt.Errorf("no PoST from %s in ATX", nodeID.ShortString()) + } + } + + prevATX1 := atx1.PreviousATXs[atx1.NIPosts[nipostIndex1].Posts[postIndex1].PrevATXIndex] + prevATX2 := atx2.PreviousATXs[atx2.NIPosts[nipostIndex2].Posts[postIndex2].PrevATXIndex] + if prevATX1 != prevATX2 { + return nil, errors.New("ATXs reference different previous ATXs") + } + + proof1, err := createInvalidPrevAtxProof(atx1, prevATX1, nipostIndex1, postIndex1, marriageProof1) + if err != nil { + return nil, fmt.Errorf("proof for atx1: %w", err) + } + + proof2, err := createInvalidPrevAtxProof(atx2, prevATX2, nipostIndex2, postIndex2, marriageProof2) + if err != nil { + return nil, fmt.Errorf("proof for atx2: %w", err) + } + + proof := &ProofInvalidPrevAtxV2{ + NodeID: nodeID, + PrevATX: prevATX1, + Proofs: [2]InvalidPrevAtxProof{proof1, proof2}, + } + return proof, nil +} + +func createInvalidPrevAtxProof( + atx *ActivationTxV2, + prevATX types.ATXID, + nipostIndex, + postIndex int, + marriageProof *MarriageProof, +) (InvalidPrevAtxProof, error) { + proof := InvalidPrevAtxProof{ + ATXID: atx.ID(), + + NIPostsRoot: atx.NIPosts.Root(atx.PreviousATXs), + NIPostsRootProof: atx.NIPostsRootProof(), + + NIPostRoot: atx.NIPosts[nipostIndex].Root(atx.PreviousATXs), + NIPostRootProof: atx.NIPosts.Proof(int(nipostIndex), atx.PreviousATXs), + NIPostIndex: uint16(nipostIndex), + + SubPostsRoot: atx.NIPosts[nipostIndex].Posts.Root(atx.PreviousATXs), + SubPostsRootProof: atx.NIPosts[nipostIndex].PostsRootProof(atx.PreviousATXs), + + SubPostRoot: atx.NIPosts[nipostIndex].Posts[postIndex].Root(atx.PreviousATXs), + SubPostRootProof: atx.NIPosts[nipostIndex].Posts.Proof(postIndex, atx.PreviousATXs), + SubPostRootIndex: uint16(postIndex), + + MarriageIndexProof: atx.NIPosts[nipostIndex].Posts[postIndex].MarriageIndexProof(atx.PreviousATXs), + MarriageProof: marriageProof, + + PrevATXProof: atx.NIPosts[nipostIndex].Posts[postIndex].PrevATXProof(prevATX), + + SmesherID: atx.SmesherID, + Signature: atx.Signature, + } + + return proof, nil +} + +func (p ProofInvalidPrevAtxV2) Valid(_ context.Context, malValidator MalfeasanceValidator) (types.NodeID, error) { + if p.Proofs[0].ATXID == p.Proofs[1].ATXID { + return types.EmptyNodeID, errors.New("proofs have the same ATX ID") + } + if err := p.Proofs[0].Valid(p.PrevATX, p.NodeID, malValidator); err != nil { + return types.EmptyNodeID, fmt.Errorf("proof 1 is invalid: %w", err) + } + if err := p.Proofs[1].Valid(p.PrevATX, p.NodeID, malValidator); err != nil { + return types.EmptyNodeID, fmt.Errorf("proof 2 is invalid: %w", err) + } + return p.NodeID, nil +} + +// ProofInvalidPrevAtxV1 is a proof that two ATXs published by an identity reference the same previous ATX for an +// identity. +// +// We are proving the following: +// 1. Both ATXs have a valid signature. +// 2. Both ATXs reference the same previous ATX for the same identity. +// 3. If the signer of the ATXv2 is not the identity that referenced the same previous ATX, then the included marriage +// proof is valid. +// 4. The ATXv1 has been signed by the identity that referenced the same previous ATX. +type ProofInvalidPrevAtxV1 struct { + // NodeID is the node ID that referenced the same previous ATX twice. + NodeID types.NodeID + + // PrevATX is the ATX that was referenced twice. + PrevATX types.ATXID + + Proof InvalidPrevAtxProof + ATXv1 ActivationTxV1 +} + +var _ Proof = &ProofInvalidPrevAtxV1{} + +func NewInvalidPrevAtxProofV1( + db sql.Executor, + atx1 *ActivationTxV2, + atx2 *ActivationTxV1, + nodeID types.NodeID, +) (*ProofInvalidPrevAtxV1, error) { + if atx1.SmesherID != nodeID && atx1.MarriageATX == nil { + return nil, errors.New("ATX1 is not a merged ATX, but NodeID is different from SmesherID") + } + + if atx2.SmesherID != nodeID { + return nil, errors.New("ATX2 is not signed by NodeID") + } + + var marriageProof *MarriageProof + nipostIndex := 0 + postIndex := 0 + if atx1.SmesherID != nodeID { + proof, err := createMarriageProof(db, atx1, nodeID) + if err != nil { + return nil, fmt.Errorf("marriage proof: %w", err) + } + marriageProof = &proof + for i, nipost := range atx1.NIPosts { + postIndex = slices.IndexFunc(nipost.Posts, func(post SubPostV2) bool { + return post.MarriageIndex == proof.NodeIDMarryProof.CertificateIndex + }) + if postIndex != -1 { + nipostIndex = i + break + } + } + if postIndex == -1 { + return nil, fmt.Errorf("no PoST from %s in ATX", nodeID.ShortString()) + } + } + prevATX1 := atx1.PreviousATXs[atx1.NIPosts[nipostIndex].Posts[postIndex].PrevATXIndex] + prevATX2 := atx2.PrevATXID + if prevATX1 != prevATX2 { + return nil, errors.New("ATXs reference different previous ATXs") + } + + proof, err := createInvalidPrevAtxProof(atx1, prevATX1, nipostIndex, postIndex, marriageProof) + if err != nil { + return nil, fmt.Errorf("proof for atx1: %w", err) + } + + return &ProofInvalidPrevAtxV1{ + NodeID: nodeID, + PrevATX: prevATX1, + Proof: proof, + ATXv1: *atx2, + }, nil +} + +func (p ProofInvalidPrevAtxV1) Valid(_ context.Context, malValidator MalfeasanceValidator) (types.NodeID, error) { + if err := p.Proof.Valid(p.PrevATX, p.NodeID, malValidator); err != nil { + return types.EmptyNodeID, fmt.Errorf("proof is invalid: %w", err) + } + if !malValidator.Signature(signing.ATX, p.ATXv1.SmesherID, p.ATXv1.SignedBytes(), p.ATXv1.Signature) { + return types.EmptyNodeID, errors.New("invalid ATX signature") + } + if p.NodeID != p.ATXv1.SmesherID { + return types.EmptyNodeID, errors.New("ATXv1 has not been signed by the same identity") + } + if p.ATXv1.PrevATXID != p.PrevATX { + return types.EmptyNodeID, errors.New("ATXv1 references a different previous ATX") + } + return p.NodeID, nil +} + +type InvalidPrevAtxProof struct { + // ATXID is the ID of the ATX being proven. + ATXID types.ATXID + // SmesherID is the ID of the smesher that published the ATX. + SmesherID types.NodeID + // Signature is the signature of the ATXID by the smesher. + Signature types.EdSignature + + // NIPostsRoot and its proof that it is contained in the ATX. + NIPostsRoot NIPostsRoot + NIPostsRootProof NIPostsRootProof `scale:"max=32"` + + // NIPostRoot and its proof that it is contained at the given index in the NIPostsRoot. + NIPostRoot NIPostRoot + NIPostRootProof NIPostRootProof `scale:"max=32"` + NIPostIndex uint16 + + // SubPostsRoot and its proof that it is contained in the NIPostRoot. + SubPostsRoot SubPostsRoot + SubPostsRootProof SubPostsRootProof `scale:"max=32"` + + // SubPostRoot and its proof that is contained at the given index in the SubPostsRoot. + SubPostRoot SubPostRoot + SubPostRootProof SubPostRootProof `scale:"max=32"` + SubPostRootIndex uint16 + + // MarriageProof is the proof that NodeID and SmesherID are married. It is nil if NodeID == SmesherID. + MarriageProof *MarriageProof + + // MarriageIndexProof is the proof that the MarriageIndex (CertificateIndex from NodeIDMarryProof) is contained in + // the SubPostRoot. + MarriageIndexProof MarriageIndexProof `scale:"max=32"` + + // PrevATXProof is the proof that the previous ATX is contained in the SubPostRoot. + PrevATXProof PrevATXProof `scale:"max=32"` +} + +func (p InvalidPrevAtxProof) Valid(prevATX types.ATXID, nodeID types.NodeID, malValidator MalfeasanceValidator) error { + if !malValidator.Signature(signing.ATX, p.SmesherID, p.ATXID.Bytes(), p.Signature) { + return errors.New("invalid ATX signature") + } + + if nodeID != p.SmesherID && p.MarriageProof == nil { + return errors.New("missing marriage proof") + } + + if !p.NIPostsRootProof.Valid(p.ATXID, p.NIPostsRoot) { + return errors.New("invalid NIPosts root proof") + } + if !p.NIPostRootProof.Valid(p.NIPostsRoot, int(p.NIPostIndex), p.NIPostRoot) { + return errors.New("invalid NIPoST root proof") + } + if !p.SubPostsRootProof.Valid(p.NIPostRoot, p.SubPostsRoot) { + return errors.New("invalid sub PoSTs root proof") + } + if !p.SubPostRootProof.Valid(p.SubPostsRoot, int(p.SubPostRootIndex), p.SubPostRoot) { + return errors.New("invalid sub PoST root proof") + } + + var marriageIndex *uint32 + if p.MarriageProof != nil { + if err := p.MarriageProof.Valid(malValidator, p.ATXID, nodeID, p.SmesherID); err != nil { + return fmt.Errorf("invalid marriage proof: %w", err) + } + marriageIndex = &p.MarriageProof.NodeIDMarryProof.CertificateIndex + } + if marriageIndex != nil { + if !p.MarriageIndexProof.Valid(p.SubPostRoot, *marriageIndex) { + return errors.New("invalid marriage index proof") + } + } + + if !p.PrevATXProof.Valid(p.SubPostRoot, prevATX) { + return errors.New("invalid previous ATX proof") + } + + return nil +} diff --git a/activation/wire/malfeasance_invalid_prev_atx_scale.go b/activation/wire/malfeasance_invalid_prev_atx_scale.go new file mode 100644 index 0000000000..15b06acc28 --- /dev/null +++ b/activation/wire/malfeasance_invalid_prev_atx_scale.go @@ -0,0 +1,364 @@ +// Code generated by github.com/spacemeshos/go-scale/scalegen. DO NOT EDIT. + +// nolint +package wire + +import ( + "github.com/spacemeshos/go-scale" + "github.com/spacemeshos/go-spacemesh/common/types" +) + +func (t *ProofInvalidPrevAtxV2) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.NodeID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.PrevATX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructArray(enc, t.Proofs[:]) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *ProofInvalidPrevAtxV2) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.NodeID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.PrevATX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeStructArray(dec, t.Proofs[:]) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *ProofInvalidPrevAtxV1) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.NodeID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.PrevATX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.Proof.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.ATXv1.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *ProofInvalidPrevAtxV1) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.NodeID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.PrevATX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.Proof.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.ATXv1.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *InvalidPrevAtxProof) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.ATXID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.SmesherID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.Signature[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.NIPostsRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.NIPostsRootProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.NIPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.NIPostRootProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact16(enc, uint16(t.NIPostIndex)) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.SubPostsRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.SubPostsRootProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.SubPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.SubPostRootProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact16(enc, uint16(t.SubPostRootIndex)) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeOption(enc, t.MarriageProof) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.MarriageIndexProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.PrevATXProof, 32) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *InvalidPrevAtxProof) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.ATXID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.SmesherID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.Signature[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.NIPostsRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.NIPostsRootProof = field + } + { + n, err := scale.DecodeByteArray(dec, t.NIPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.NIPostRootProof = field + } + { + field, n, err := scale.DecodeCompact16(dec) + if err != nil { + return total, err + } + total += n + t.NIPostIndex = uint16(field) + } + { + n, err := scale.DecodeByteArray(dec, t.SubPostsRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.SubPostsRootProof = field + } + { + n, err := scale.DecodeByteArray(dec, t.SubPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.SubPostRootProof = field + } + { + field, n, err := scale.DecodeCompact16(dec) + if err != nil { + return total, err + } + total += n + t.SubPostRootIndex = uint16(field) + } + { + field, n, err := scale.DecodeOption[MarriageProof](dec) + if err != nil { + return total, err + } + total += n + t.MarriageProof = field + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.MarriageIndexProof = field + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.PrevATXProof = field + } + return total, nil +} diff --git a/activation/wire/malfeasance_invalid_prev_atx_test.go b/activation/wire/malfeasance_invalid_prev_atx_test.go new file mode 100644 index 0000000000..1c829e74eb --- /dev/null +++ b/activation/wire/malfeasance_invalid_prev_atx_test.go @@ -0,0 +1,974 @@ +package wire + +import ( + "context" + "fmt" + "slices" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +func Test_InvalidPrevAtxProofV2(t *testing.T) { + t.Parallel() + + // sig is the identity that creates the ATXs referencing the same prevATX + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + // pubSig is the identity that publishes a merged ATX with the same prevATX + pubSig, err := signing.NewEdSigner() + require.NoError(t, err) + + // marrySig is the identity that publishes the marriage ATX + marrySig, err := signing.NewEdSigner() + require.NoError(t, err) + + edVerifier := signing.NewEdVerifier() + + newMergedATXv2 := func( + db sql.Executor, + prevATX types.ATXID, + ) *ActivationTxV2 { + wInitialAtx := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wInitialAtx.Sign(sig) + initialAtx := &types.ActivationTx{ + CommitmentATX: &wInitialAtx.Initial.CommitmentATX, + } + initialAtx.SetID(wInitialAtx.ID()) + initialAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, initialAtx, wInitialAtx.Blob())) + + wPubInitialAtx := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wPubInitialAtx.Sign(pubSig) + pubInitialAtx := &types.ActivationTx{} + pubInitialAtx.SetID(wPubInitialAtx.ID()) + pubInitialAtx.SmesherID = pubSig.NodeID() + require.NoError(t, atxs.Add(db, pubInitialAtx, wPubInitialAtx.Blob())) + + marryInitialAtx := types.RandomATXID() + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(marrySig, types.EmptyATXID, marrySig.NodeID()), + withMarriageCertificate(sig, wInitialAtx.ID(), marrySig.NodeID()), + withMarriageCertificate(pubSig, wPubInitialAtx.ID(), marrySig.NodeID()), + ) + wMarriageAtx.Sign(marrySig) + + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = marrySig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withPreviousATXs(marryInitialAtx, wPubInitialAtx.ID(), prevATX), + withMarriageATX(wMarriageAtx.ID()), + withNIPost( + withNIPostMembershipProof(MerkleProofV2{}), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 0, + PrevATXIndex: 0, + }), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 2, + }), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 2, + PrevATXIndex: 1, + }), + ), + ) + atx.Sign(pubSig) + return atx + } + + t.Run("valid", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATXID := types.RandomATXID() + atx1 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(7), + ) + atx2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // verify the proof + id, err := proof.Valid(context.Background(), verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), id) + }) + + t.Run("valid merged & solo atx", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATXID := types.RandomATXID() + prevAtx := &types.ActivationTx{} + prevAtx.SetID(prevATXID) + prevAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, prevAtx, types.AtxBlob{})) + atx1 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := newMergedATXv2(db, prevATXID) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // verify the proof + id, err := proof.Valid(context.Background(), verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), id) + }) + + // valid merged & merged is covered by either double marry or double merge proofs + + t.Run("same ATX ID", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + atx1 := newActivationTxV2( + withPreviousATXs(types.RandomATXID()), + ) + atx1.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx1, sig.NodeID()) + require.ErrorContains(t, err, "ATXs have the same ID") + require.Nil(t, proof) + }) + + t.Run("smesher ID mismatch", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atx1 := newActivationTxV2( + withPreviousATXs(prevATX), + ) + atx1.Sign(sig) + atx2 := newActivationTxV2( + withPreviousATXs(prevATX), + ) + atx2.Sign(pubSig) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.EqualError(t, err, "ATX2 is not a merged ATX, but NodeID is different from SmesherID") + require.Nil(t, proof) + + proof, err = NewInvalidPrevAtxProofV2(db, atx1, atx2, pubSig.NodeID()) + require.EqualError(t, err, "ATX1 is not a merged ATX, but NodeID is different from SmesherID") + require.Nil(t, proof) + }) + + t.Run("id not married to smesher", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + invalidSig, err := signing.NewEdSigner() + require.NoError(t, err) + + prevATXID := types.RandomATXID() + prevAtx := &types.ActivationTx{} + prevAtx.SetID(prevATXID) + prevAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, prevAtx, types.AtxBlob{})) + atx1 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(5), + ) + atx1.Sign(invalidSig) + atx2 := newMergedATXv2(db, prevATXID) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, invalidSig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", invalidSig.NodeID().ShortString()), + ) + require.Nil(t, proof) + + proof, err = NewInvalidPrevAtxProofV2(db, atx2, atx1, invalidSig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", invalidSig.NodeID().ShortString()), + ) + require.Nil(t, proof) + }) + + t.Run("merged ATX does not contain post from identity", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATXID := types.RandomATXID() + prevAtx := &types.ActivationTx{} + prevAtx.SetID(prevATXID) + prevAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, prevAtx, types.AtxBlob{})) + atx1 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := newMergedATXv2(db, prevATXID) + + // remove the post from sig in the merged ATX + atx2.NIPosts[0].Posts = slices.DeleteFunc(atx2.NIPosts[0].Posts, func(subPost SubPostV2) bool { + return subPost.MarriageIndex == 1 + }) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("no PoST from %s in ATX", sig.NodeID().ShortString()), + ) + require.Nil(t, proof) + + proof, err = NewInvalidPrevAtxProofV2(db, atx2, atx1, sig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("no PoST from %s in ATX", sig.NodeID().ShortString()), + ) + require.Nil(t, proof) + }) + + t.Run("prev ATX differs between ATXs", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + atx1 := newActivationTxV2( + withPreviousATXs(types.RandomATXID()), + withPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := newActivationTxV2( + withPreviousATXs(types.RandomATXID()), + withPublishEpoch(7), + ) + atx2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.ErrorContains(t, err, "ATXs reference different previous ATXs") + require.Nil(t, proof) + }) + + t.Run("invalid solo proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATXID := types.RandomATXID() + atx1 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(7), + ) + atx2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // same ATX ID + proof.Proofs[0].ATXID = atx2.ID() + id, err := proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proofs have the same ATX ID") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].ATXID = atx1.ID() + + // invalid prev ATX + proof.PrevATX = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid previous ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.PrevATX = prevATXID + + // invalid node ID + proof.NodeID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "missing marriage proof") + require.Equal(t, types.EmptyNodeID, id) + proof.NodeID = sig.NodeID() + + // invalid ATX ID + proof.Proofs[0].ATXID = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].ATXID = atx1.ID() + + proof.Proofs[1].ATXID = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].ATXID = atx2.ID() + + // invalid SmesherID + proof.Proofs[0].SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SmesherID = sig.NodeID() + + proof.Proofs[1].SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SmesherID = sig.NodeID() + + // invalid signature + proof.Proofs[0].Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].Signature = atx1.Signature + + proof.Proofs[1].Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].Signature = atx2.Signature + + // invalid NIPosts root + nipostsRoot := proof.Proofs[0].NIPostsRoot + proof.Proofs[0].NIPostsRoot = NIPostsRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid NIPosts root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].NIPostsRoot = nipostsRoot + + nipostsRoot = proof.Proofs[1].NIPostsRoot + proof.Proofs[1].NIPostsRoot = NIPostsRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid NIPosts root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].NIPostsRoot = nipostsRoot + + // invalid NIPosts root proof + hash := proof.Proofs[0].NIPostsRootProof[0] + proof.Proofs[0].NIPostsRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid NIPosts root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].NIPostsRootProof[0] = hash + + hash = proof.Proofs[1].NIPostsRootProof[0] + proof.Proofs[1].NIPostsRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid NIPosts root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].NIPostsRootProof[0] = hash + + // invalid NIPost root + nipostRoot := proof.Proofs[0].NIPostRoot + proof.Proofs[0].NIPostRoot = NIPostRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].NIPostRoot = nipostRoot + + nipostRoot = proof.Proofs[1].NIPostRoot + proof.Proofs[1].NIPostRoot = NIPostRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].NIPostRoot = nipostRoot + + // invalid NIPost root proof + hash = proof.Proofs[0].NIPostRootProof[0] + proof.Proofs[0].NIPostRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].NIPostRootProof[0] = hash + + hash = proof.Proofs[1].NIPostRootProof[0] + proof.Proofs[1].NIPostRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].NIPostRootProof[0] = hash + + // invalid NIPost index + proof.Proofs[0].NIPostIndex++ + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].NIPostIndex-- + + proof.Proofs[1].NIPostIndex++ + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].NIPostIndex-- + + // invalid sub posts root + subPostsRoot := proof.Proofs[0].SubPostsRoot + proof.Proofs[0].SubPostsRoot = SubPostsRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid sub PoSTs root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SubPostsRoot = subPostsRoot + + subPostsRoot = proof.Proofs[1].SubPostsRoot + proof.Proofs[1].SubPostsRoot = SubPostsRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid sub PoSTs root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SubPostsRoot = subPostsRoot + + // invalid sub posts root proof + hash = proof.Proofs[0].SubPostsRootProof[0] + proof.Proofs[0].SubPostsRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid sub PoSTs root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SubPostsRootProof[0] = hash + + hash = proof.Proofs[1].SubPostsRootProof[0] + proof.Proofs[1].SubPostsRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid sub PoSTs root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SubPostsRootProof[0] = hash + + // invalid sub post root + subPostRoot := proof.Proofs[0].SubPostRoot + proof.Proofs[0].SubPostRoot = SubPostRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SubPostRoot = subPostRoot + + subPostRoot = proof.Proofs[1].SubPostRoot + proof.Proofs[1].SubPostRoot = SubPostRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SubPostRoot = subPostRoot + + // invalid sub post root proof + hash = proof.Proofs[0].SubPostRootProof[0] + proof.Proofs[0].SubPostRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SubPostRootProof[0] = hash + + hash = proof.Proofs[1].SubPostRootProof[0] + proof.Proofs[1].SubPostRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SubPostRootProof[0] = hash + + // invalid sub post index + proof.Proofs[0].SubPostRootIndex++ + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SubPostRootIndex-- + + proof.Proofs[1].SubPostRootIndex++ + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SubPostRootIndex-- + + // invalid prev atx proof + hash = proof.Proofs[0].PrevATXProof[0] + proof.Proofs[0].PrevATXProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid previous ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].PrevATXProof[0] = hash + + hash = proof.Proofs[1].PrevATXProof[0] + proof.Proofs[1].PrevATXProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid previous ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].PrevATXProof[0] = hash + }) + + t.Run("invalid merged proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATXID := types.RandomATXID() + prevAtx := &types.ActivationTx{} + prevAtx.SetID(prevATXID) + prevAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, prevAtx, types.AtxBlob{})) + atx1 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := newMergedATXv2(db, prevATXID) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // invalid node ID + proof.NodeID = types.RandomNodeID() + id, err := proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "missing marriage proof") + require.Equal(t, types.EmptyNodeID, id) + proof.NodeID = sig.NodeID() + + // invalid ATX ID + proof.Proofs[0].ATXID = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].ATXID = atx1.ID() + + proof.Proofs[1].ATXID = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].ATXID = atx2.ID() + + // invalid SmesherID + proof.Proofs[0].SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SmesherID = sig.NodeID() + + proof.Proofs[1].SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SmesherID = pubSig.NodeID() + + // invalid signature + proof.Proofs[0].Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].Signature = atx1.Signature + + proof.Proofs[1].Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].Signature = atx2.Signature + + // missing marriage proof + marriageProof := proof.Proofs[1].MarriageProof + proof.Proofs[1].MarriageProof = nil + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "missing marriage proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].MarriageProof = marriageProof + + // invalid marriage index proof + hash := proof.Proofs[1].MarriageIndexProof[0] + proof.Proofs[1].MarriageIndexProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid marriage index proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].MarriageIndexProof[0] = hash + + // invalid prev atx proof + hash = proof.Proofs[0].PrevATXProof[0] + proof.Proofs[0].PrevATXProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid previous ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].PrevATXProof[0] = hash + + hash = proof.Proofs[1].PrevATXProof[0] + proof.Proofs[1].PrevATXProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid previous ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].PrevATXProof[0] = hash + }) +} + +func Test_InvalidPrevAtxProofV1(t *testing.T) { + t.Parallel() + + // sig is the identity that creates the ATXs referencing the same prevATX + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + // pubSig is the identity that publishes a merged ATX with the same prevATX + pubSig, err := signing.NewEdSigner() + require.NoError(t, err) + + // marrySig is the identity that publishes the marriage ATX + marrySig, err := signing.NewEdSigner() + require.NoError(t, err) + + edVerifier := signing.NewEdVerifier() + + newMergedATXv2 := func( + db sql.Executor, + prevATX types.ATXID, + ) *ActivationTxV2 { + wInitialAtx := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wInitialAtx.Sign(sig) + initialAtx := &types.ActivationTx{ + CommitmentATX: &wInitialAtx.Initial.CommitmentATX, + } + initialAtx.SetID(wInitialAtx.ID()) + initialAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, initialAtx, wInitialAtx.Blob())) + + wPubInitialAtx := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wPubInitialAtx.Sign(pubSig) + pubInitialAtx := &types.ActivationTx{} + pubInitialAtx.SetID(wPubInitialAtx.ID()) + pubInitialAtx.SmesherID = pubSig.NodeID() + require.NoError(t, atxs.Add(db, pubInitialAtx, wPubInitialAtx.Blob())) + + marryInitialAtx := types.RandomATXID() + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(marrySig, types.EmptyATXID, marrySig.NodeID()), + withMarriageCertificate(sig, wInitialAtx.ID(), marrySig.NodeID()), + withMarriageCertificate(pubSig, wPubInitialAtx.ID(), marrySig.NodeID()), + ) + wMarriageAtx.Sign(marrySig) + + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = marrySig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withPreviousATXs(marryInitialAtx, wPubInitialAtx.ID(), prevATX), + withMarriageATX(wMarriageAtx.ID()), + withNIPost( + withNIPostMembershipProof(MerkleProofV2{}), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 0, + PrevATXIndex: 0, + }), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 2, + }), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 2, + PrevATXIndex: 1, + }), + ), + ) + atx.Sign(pubSig) + return atx + } + + t.Run("valid", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(sig) + + atxv2 := newActivationTxV2( + withPreviousATXs(prevATX), + withPublishEpoch(7), + ) + atxv2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // verify the proof + id, err := proof.Valid(context.Background(), verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), id) + }) + + t.Run("valid merged", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(sig) + + atxv2 := newMergedATXv2(db, prevATX) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // verify the proof + id, err := proof.Valid(context.Background(), verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), id) + }) + + t.Run("smesher ID mismatch", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(sig) + + atxv2 := newActivationTxV2( + withPreviousATXs(prevATX), + withPublishEpoch(7), + ) + atxv2.Sign(pubSig) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, pubSig.NodeID()) + require.EqualError(t, err, "ATX2 is not signed by NodeID") + require.Nil(t, proof) + + proof, err = NewInvalidPrevAtxProofV1(db, atxv2, atxv1, sig.NodeID()) + require.EqualError(t, err, "ATX1 is not a merged ATX, but NodeID is different from SmesherID") + require.Nil(t, proof) + }) + + t.Run("id not married to smesher", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + invalidSig, err := signing.NewEdSigner() + require.NoError(t, err) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(invalidSig) + + atxv2 := newMergedATXv2(db, prevATX) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, invalidSig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", invalidSig.NodeID().ShortString()), + ) + require.Nil(t, proof) + }) + + t.Run("merged ATX does not contain post from identity", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(sig) + + atxv2 := newMergedATXv2(db, prevATX) + + // remove the post from sig in the merged ATX + atxv2.NIPosts[0].Posts = slices.DeleteFunc(atxv2.NIPosts[0].Posts, func(subPost SubPostV2) bool { + return subPost.MarriageIndex == 1 + }) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, sig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("no PoST from %s in ATX", sig.NodeID().ShortString()), + ) + require.Nil(t, proof) + }) + + t.Run("prev ATX differs between ATXs", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(sig) + + atxv2 := newActivationTxV2( + withPreviousATXs(types.RandomATXID()), + withPublishEpoch(7), + ) + atxv2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, sig.NodeID()) + require.ErrorContains(t, err, "ATXs reference different previous ATXs") + require.Nil(t, proof) + }) + + t.Run("invalid proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(sig) + + atxv2 := newActivationTxV2( + withPreviousATXs(prevATX), + withPublishEpoch(7), + ) + atxv2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // invalid PrevATX + proof.PrevATX = types.RandomATXID() + id, err := proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid previous ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.PrevATX = prevATX + + // invalid SmesherID for atxv1 + proof.ATXv1.SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXv1.SmesherID = sig.NodeID() + + // invalid signature for atxv1 + proof.ATXv1.Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXv1.Signature = atxv1.Signature + + // signer of atxv1 does not match + proof.ATXv1.Sign(pubSig) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "ATXv1 has not been signed by the same identity") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXv1.Sign(sig) + + // prevATX of atxv1 does not match + proof.ATXv1.PrevATXID = types.RandomATXID() + proof.ATXv1.Sign(sig) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "ATXv1 references a different previous ATX") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXv1.PrevATXID = prevATX + proof.ATXv1.Sign(sig) + }) +} diff --git a/activation/wire/malfeasance_shared_test.go b/activation/wire/malfeasance_shared_test.go new file mode 100644 index 0000000000..46fbccea11 --- /dev/null +++ b/activation/wire/malfeasance_shared_test.go @@ -0,0 +1,362 @@ +package wire + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +func Test_MarryProof(t *testing.T) { + t.Parallel() + + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + + edVerifier := signing.NewEdVerifier() + + t.Run("valid", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + atx1 := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + atx1.Sign(sig) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // valid for otherSig + proof, err := createMarryProof(db, atx1, otherSig.NodeID()) + require.NoError(t, err) + require.NotEmpty(t, proof) + + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.NoError(t, err) + + // valid for sig + proof, err = createMarryProof(db, atx1, sig.NodeID()) + require.NoError(t, err) + require.NotEmpty(t, proof) + + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), sig.NodeID()) + require.NoError(t, err) + }) + + t.Run("identity not included in certificates", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + atx1 := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + atx1.Sign(sig) + + nodeID := types.RandomNodeID() + proof, err := createMarryProof(db, atx1, nodeID) + require.EqualError(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", nodeID.ShortString()), + ) + require.Empty(t, proof) + }) + + t.Run("invalid proof", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + atx1 := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + atx1.Sign(sig) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + proof, err := createMarryProof(db, atx1, otherSig.NodeID()) + require.NoError(t, err) + require.NotEmpty(t, proof) + + // not valid for random NodeID + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), types.RandomNodeID()) + require.EqualError(t, err, "invalid certificate signature") + + // not valid for another ATX + err = proof.Valid(verifier, types.RandomATXID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid marriage proof") + + // not valid if certificate signature is invalid + certSig := proof.Certificate.Signature + proof.Certificate.Signature = types.RandomEdSignature() + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid certificate signature") + proof.Certificate.Signature = certSig + + // not valid if marriage root is invalid + marriageRoot := proof.MarriageCertificatesRoot + proof.MarriageCertificatesRoot = MarriageCertificatesRoot(types.RandomHash()) + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid marriage proof") + proof.MarriageCertificatesRoot = marriageRoot + + // not valid if marriage root proof is invalid + hash := proof.MarriageCertificatesProof[0] + proof.MarriageCertificatesProof[0] = types.RandomHash() + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid marriage proof") + proof.MarriageCertificatesProof[0] = hash + + // not valid if certificate proof is invalid + index := proof.CertificateIndex + proof.CertificateIndex = 100 + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid certificate proof") + proof.CertificateIndex = index + + certProof := proof.CertificateProof + proof.CertificateProof = MarriageCertificateProof{types.RandomHash()} + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid certificate proof") + proof.CertificateProof = certProof + + hash = proof.CertificateProof[0] + proof.CertificateProof[0] = types.RandomHash() + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid certificate proof") + proof.CertificateProof[0] = hash + }) +} + +func Test_MarriageProof(t *testing.T) { + t.Parallel() + + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + + edVerifier := signing.NewEdVerifier() + + t.Run("valid", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + wMarriageAtx.Sign(sig) + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withMarriageATX(wMarriageAtx.ID()), + ) + atx.Sign(sig) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + proof, err := createMarriageProof(db, atx, otherSig.NodeID()) + require.NoError(t, err) + require.NotEmpty(t, proof) + + err = proof.Valid(verifier, atx.ID(), otherSig.NodeID(), sig.NodeID()) + require.NoError(t, err) + }) + + t.Run("node ID is the same as smesher ID", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + wMarriageAtx.Sign(sig) + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withMarriageATX(wMarriageAtx.ID()), + ) + atx.Sign(sig) + + proof, err := createMarriageProof(db, atx, sig.NodeID()) + require.EqualError(t, err, "node ID is the same as smesher ID") + require.Empty(t, proof) + }) + + t.Run("marriage ATX is not available", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + + atx := newActivationTxV2( + withMarriageATX(types.RandomATXID()), + ) + atx.Sign(sig) + + proof, err := createMarriageProof(db, atx, otherSig.NodeID()) + require.ErrorIs(t, err, sql.ErrNotFound) + require.Empty(t, proof) + }) + + t.Run("node ID isn't married in marriage ATX", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + wMarriageAtx.Sign(sig) + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withMarriageATX(wMarriageAtx.ID()), + ) + atx.Sign(sig) + + invalidSig, err := signing.NewEdSigner() + require.NoError(t, err) + + proof, err := createMarriageProof(db, atx, invalidSig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", invalidSig.NodeID().ShortString()), + ) + require.Empty(t, proof) + + atx.Sign(invalidSig) + proof, err = createMarriageProof(db, atx, otherSig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", invalidSig.NodeID().ShortString()), + ) + require.Empty(t, proof) + }) + + t.Run("invalid proof", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + wMarriageAtx.Sign(sig) + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withMarriageATX(wMarriageAtx.ID()), + ) + atx.Sign(sig) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + proof, err := createMarriageProof(db, atx, otherSig.NodeID()) + require.NoError(t, err) + require.NotEmpty(t, proof) + + // not valid for random ATX + err = proof.Valid(verifier, types.RandomATXID(), otherSig.NodeID(), sig.NodeID()) + require.EqualError(t, err, "invalid marriage ATX proof") + + // not valid for another smesher + err = proof.Valid(verifier, atx.ID(), otherSig.NodeID(), types.RandomNodeID()) + require.ErrorContains(t, err, "invalid certificate signature") + + // not valid for another nodeID + err = proof.Valid(verifier, atx.ID(), types.RandomNodeID(), sig.NodeID()) + require.ErrorContains(t, err, "invalid certificate signature") + + // not valid for incorrect marriage ATX + marriageATX := proof.MarriageATX + proof.MarriageATX = types.RandomATXID() + err = proof.Valid(verifier, atx.ID(), otherSig.NodeID(), sig.NodeID()) + require.EqualError(t, err, "invalid marriage ATX proof") + proof.MarriageATX = marriageATX + + // not valid for incorrect marriage ATX smesher ID + marriageATXSmesherID := proof.MarriageATXSmesherID + proof.MarriageATXSmesherID = types.RandomNodeID() + err = proof.Valid(verifier, atx.ID(), otherSig.NodeID(), sig.NodeID()) + require.ErrorContains(t, err, "invalid certificate signature") + proof.MarriageATXSmesherID = marriageATXSmesherID + }) +} diff --git a/activation/wire/wire_v2.go b/activation/wire/wire_v2.go index da4dc3fd35..696c4b332d 100644 --- a/activation/wire/wire_v2.go +++ b/activation/wire/wire_v2.go @@ -129,7 +129,7 @@ func (atx *ActivationTxV2) merkleTree(tree *merkle.Tree) { tree.AddLeaf(types.EmptyHash32.Bytes()) } - tree.AddLeaf(atx.PreviousATXs.Root().Bytes()) + tree.AddLeaf(types.Hash32(atx.PreviousATXs.Root()).Bytes()) tree.AddLeaf(types.Hash32(atx.NIPosts.Root(atx.PreviousATXs)).Bytes()) var vrfNonce types.Hash32 @@ -188,10 +188,16 @@ func (p InitialPostRootProof) Valid(atxID types.ATXID, initialPostRoot InitialPo return validateProof(types.Hash32(atxID), types.Hash32(initialPostRoot), p, uint64(InitialPostRootIndex)) } -func (atx *ActivationTxV2) PreviousATXsRootProof() []types.Hash32 { +func (atx *ActivationTxV2) PreviousATXsRootProof() PrevATXsRootProof { return atx.merkleProof(PreviousATXsRootIndex) } +type PrevATXsRootProof []types.Hash32 + +func (p PrevATXsRootProof) Valid(atxID types.ATXID, prevATXsRoot PrevATXsRoot) bool { + return validateProof(types.Hash32(atxID), types.Hash32(prevATXsRoot), p, uint64(PreviousATXsRootIndex)) +} + func (atx *ActivationTxV2) NIPostsRootProof() NIPostsRootProof { return atx.merkleProof(NIPostsRootIndex) } @@ -280,8 +286,23 @@ func (prevATXs PrevATXs) merkleTree(tree *merkle.Tree) { } } -func (prevATXs PrevATXs) Root() types.Hash32 { - return createRoot(prevATXs.merkleTree) +type PrevATXsRoot types.Hash32 + +func (prevATXs PrevATXs) Root() PrevATXsRoot { + return PrevATXsRoot(createRoot(prevATXs.merkleTree)) +} + +func (prevATXs PrevATXs) Proof(index int) PrevATXsProof { + if index < 0 || index >= len(prevATXs) { + panic("index out of range") + } + return createProof(uint64(index), prevATXs.merkleTree) +} + +type PrevATXsProof []types.Hash32 + +func (p PrevATXsProof) Valid(prevATXsRoot PrevATXsRoot, index int, prevATX types.ATXID) bool { + return validateProof(types.Hash32(prevATXsRoot), types.Hash32(prevATX), p, uint64(index)) } type NIPosts []NIPostV2 @@ -465,21 +486,12 @@ func (post *SubPostV2) MarshalLogObject(encoder zapcore.ObjectEncoder) error { return nil } -func (sp *SubPostV2) merkleTree(tree *merkle.Tree, prevATXs []types.ATXID) { +func (sp *SubPostV2) merkleTree(tree *merkle.Tree, prevATX types.ATXID) { var marriageIndex types.Hash32 binary.LittleEndian.PutUint32(marriageIndex[:], sp.MarriageIndex) tree.AddLeaf(marriageIndex.Bytes()) - switch { - case len(prevATXs) == 0: // special case for initial ATX: prevATXs is empty - tree.AddLeaf(types.EmptyATXID.Bytes()) - case int(sp.PrevATXIndex) < len(prevATXs): - tree.AddLeaf(prevATXs[sp.PrevATXIndex].Bytes()) - default: - // prevATXIndex is out of range, don't fail ATXID generation - // will be detected by syntactical validation - tree.AddLeaf(types.EmptyATXID.Bytes()) - } + tree.AddLeaf(prevATX.Bytes()) var leafIndex types.Hash32 binary.LittleEndian.PutUint64(leafIndex[:], sp.MembershipLeafIndex) @@ -494,7 +506,17 @@ func (sp *SubPostV2) merkleTree(tree *merkle.Tree, prevATXs []types.ATXID) { func (sp *SubPostV2) merkleProof(leafIndex SubPostTreeIndex, prevATXs []types.ATXID) []types.Hash32 { return createProof(uint64(leafIndex), func(tree *merkle.Tree) { - sp.merkleTree(tree, prevATXs) + var prevATX types.ATXID + switch { + case len(prevATXs) == 0: // special case for initial ATX: prevATXs is empty + prevATX = types.EmptyATXID + case int(sp.PrevATXIndex) < len(prevATXs): + prevATX = prevATXs[sp.PrevATXIndex] + default: + // not the full set of prevATXs is provided, proof cannot be generated + panic("prevATXIndex out of range or prevATXs incomplete") + } + sp.merkleTree(tree, prevATX) }) } @@ -502,7 +524,18 @@ type SubPostRoot types.Hash32 func (sp *SubPostV2) Root(prevATXs []types.ATXID) SubPostRoot { return SubPostRoot(createRoot(func(tree *merkle.Tree) { - sp.merkleTree(tree, prevATXs) + var prevATX types.ATXID + switch { + case len(prevATXs) == 0: // special case for initial ATX: prevATXs is empty + prevATX = types.EmptyATXID + case int(sp.PrevATXIndex) < len(prevATXs): + prevATX = prevATXs[sp.PrevATXIndex] + default: + // prevATXIndex is out of range, don't fail ATXID generation + // will be detected by syntactical validation + prevATX = types.EmptyATXID + } + sp.merkleTree(tree, prevATX) })) } @@ -522,6 +555,18 @@ func (sp *SubPostV2) PrevATXIndexProof(prevATXs []types.ATXID) []types.Hash32 { return sp.merkleProof(PrevATXIndex, prevATXs) } +func (sp *SubPostV2) PrevATXProof(prevATX types.ATXID) PrevATXProof { + return createProof(uint64(SubPostTreeIndex(PrevATXIndex)), func(tree *merkle.Tree) { + sp.merkleTree(tree, prevATX) + }) +} + +type PrevATXProof []types.Hash32 + +func (p PrevATXProof) Valid(subPostRoot SubPostRoot, prevATX types.ATXID) bool { + return validateProof(types.Hash32(subPostRoot), types.Hash32(prevATX), p, uint64(PrevATXIndex)) +} + func (sp *SubPostV2) MembershipLeafIndexProof(prevATXs []types.ATXID) []types.Hash32 { return sp.merkleProof(MembershipLeafIndex, prevATXs) } diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 7b4c549614..4b022e1803 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -245,8 +245,7 @@ func GetLastIDByNodeID(db sql.Executor, nodeID types.NodeID) (id types.ATXID, er } // PrevIDByNodeID returns the previous ATX ID for a given node ID and public epoch. -// It returns the newest ATX ID containing PoST of the given node ID -// that was published before the given public epoch. +// It returns the newest ATX ID containing PoST of the given node ID that was published in or before the given epoch. func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID) (id types.ATXID, err error) { enc := func(stmt *sql.Statement) { stmt.BindBytes(1, nodeID.Bytes()) @@ -259,7 +258,7 @@ func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID if rows, err := db.Exec(` SELECT atxid FROM posts - WHERE pubkey = ?1 AND publish_epoch < ?2 + WHERE pubkey = ?1 AND publish_epoch <= ?2 ORDER BY publish_epoch DESC LIMIT 1;`, enc, dec); err != nil { return types.EmptyATXID, fmt.Errorf("exec nodeID %v, epoch %d: %w", nodeID, pubEpoch, err) @@ -863,7 +862,10 @@ func IterateAtxIdsWithMalfeasance( return err } -func PrevATXCollision(db sql.Executor, prev types.ATXID, id types.NodeID) (types.ATXID, types.ATXID, error) { +// PrevATXCollisions returns all ATXs with the same prevATX as the given ATX ID from the same node ID. +// It is used to detect double-publishing and double poet registrations. +// The ATXs returned are ordered by received time so that the first one is the one that was seen first by the node. +func PrevATXCollisions(db sql.Executor, prev types.ATXID, id types.NodeID) ([]types.ATXID, error) { var atxs []types.ATXID enc := func(stmt *sql.Statement) { stmt.BindBytes(1, prev[:]) @@ -873,16 +875,22 @@ func PrevATXCollision(db sql.Executor, prev types.ATXID, id types.NodeID) (types var id types.ATXID stmt.ColumnBytes(0, id[:]) atxs = append(atxs, id) - return len(atxs) < 2 + return true } - _, err := db.Exec("SELECT atxid FROM posts WHERE prev_atxid = ?1 AND pubkey = ?2;", enc, dec) + query := `SELECT atxid FROM posts + WHERE prev_atxid = ?1 AND pubkey = ?2 + ORDER BY ( + SELECT received FROM atxs + WHERE id = atxid + );` + _, err := db.Exec(query, enc, dec) if err != nil { - return types.EmptyATXID, types.EmptyATXID, fmt.Errorf("error getting ATXs with same prevATX: %w", err) + return nil, fmt.Errorf("error getting ATXs with same prevATX: %w", err) } - if len(atxs) != 2 { - return types.EmptyATXID, types.EmptyATXID, sql.ErrNotFound + if len(atxs) < 2 { + return nil, sql.ErrNotFound } - return atxs[0], atxs[1], nil + return atxs, nil } func Units(db sql.Executor, atxID types.ATXID, nodeID types.NodeID) (uint32, error) { diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 124b914d6f..c82e4378ba 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -1070,15 +1070,15 @@ func Test_PrevATXCollision(t *testing.T) { require.NoError(t, atxs.SetPost(db, atx2.ID(), prevATXID, 0, atx2.SmesherID, 10, atx2.PublishEpoch)) } - collision1, collision2, err := atxs.PrevATXCollision(db, prevATXID, sig.NodeID()) + collisions, err := atxs.PrevATXCollisions(db, prevATXID, sig.NodeID()) require.NoError(t, err) - require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{collision1, collision2}) + require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, collisions) - _, _, err = atxs.PrevATXCollision(db, types.RandomATXID(), sig.NodeID()) + _, err = atxs.PrevATXCollisions(db, types.RandomATXID(), sig.NodeID()) require.ErrorIs(t, err, sql.ErrNotFound) for _, id := range append(otherIds, types.RandomNodeID()) { - _, _, err := atxs.PrevATXCollision(db, prevATXID, id) + _, err := atxs.PrevATXCollisions(db, prevATXID, id) require.ErrorIs(t, err, sql.ErrNotFound) } } @@ -1392,13 +1392,17 @@ func TestPrevIDByNodeID(t *testing.T) { require.NoError(t, atxs.Add(db, atx2, types.AtxBlob{})) require.NoError(t, atxs.SetPost(db, atx2.ID(), types.EmptyATXID, 0, sig.NodeID(), 4, atx2.PublishEpoch)) - _, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 1) + _, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 0) require.ErrorIs(t, err, sql.ErrNotFound) - prevID, err := atxs.PrevIDByNodeID(db, sig.NodeID(), 2) + prevID, err := atxs.PrevIDByNodeID(db, sig.NodeID(), 1) require.NoError(t, err) require.Equal(t, atx1.ID(), prevID) + prevID, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 2) + require.NoError(t, err) + require.Equal(t, atx2.ID(), prevID) + prevID, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 3) require.NoError(t, err) require.Equal(t, atx2.ID(), prevID)