From 6e4577fe73863675a2fd3f968e21b1a85e62c203 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Toni=20Ram=C3=ADrez?= <58293609+ToniRamirezM@users.noreply.github.com> Date: Tue, 19 Nov 2024 10:34:44 +0100 Subject: [PATCH] fix: aggregating proofs (#191) (#193) * fix: aggregating proofs (#191) * ensure oldAccInputHash is ready * feat: updata sync lib * feat: acc input hash sanity check * feat: check acc input hash -1 * feat: refactor * feat: refactor * fix: batch1 acc input hash * fix: timestamp in input prover * fix: timestamp in input prover * fix: timestamp * feat: remove test * fix: test * fix: test * fix: comments * fix: comments * fix: test --- aggregator/aggregator.go | 88 ++++-- aggregator/aggregator_test.go | 340 ++++++++++-------------- aggregator/interfaces.go | 2 +- aggregator/mocks/mock_prover.go | 21 +- aggregator/prover/mocks/mock_channel.go | 176 ++++++++++++ aggregator/prover/prover.go | 31 ++- aggregator/prover/prover_test.go | 58 +++- go.mod | 2 +- go.sum | 4 +- test/Makefile | 3 +- 10 files changed, 480 insertions(+), 245 deletions(-) create mode 100644 aggregator/prover/mocks/mock_channel.go diff --git a/aggregator/aggregator.go b/aggregator/aggregator.go index 72c316be..0659180f 100644 --- a/aggregator/aggregator.go +++ b/aggregator/aggregator.go @@ -970,7 +970,7 @@ func (a *Aggregator) tryAggregateProofs(ctx context.Context, prover ProverInterf tmpLogger.Infof("Proof ID for aggregated proof: %v", *proof.ProofID) tmpLogger = tmpLogger.WithFields("proofId", *proof.ProofID) - recursiveProof, _, err := prover.WaitRecursiveProof(ctx, *proof.ProofID) + recursiveProof, _, _, err := prover.WaitRecursiveProof(ctx, *proof.ProofID) if err != nil { err = fmt.Errorf("failed to get aggregated proof from prover, %w", err) tmpLogger.Error(FirstToUpper(err.Error())) @@ -1121,7 +1121,7 @@ func (a *Aggregator) getAndLockBatchToProve( // Not found, so it it not possible to verify the batch yet if sequence == nil || errors.Is(err, entities.ErrNotFound) { tmpLogger.Infof("Sequencing event for batch %d has not been synced yet, "+ - "so it is not possible to verify it yet. Waiting...", batchNumberToVerify) + "so it is not possible to verify it yet. Waiting ...", batchNumberToVerify) return nil, nil, nil, state.ErrNotFound } @@ -1138,7 +1138,7 @@ func (a *Aggregator) getAndLockBatchToProve( return nil, nil, nil, err } else if errors.Is(err, entities.ErrNotFound) { a.logger.Infof("Virtual batch %d has not been synced yet, "+ - "so it is not possible to verify it yet. Waiting...", batchNumberToVerify) + "so it is not possible to verify it yet. Waiting ...", batchNumberToVerify) return nil, nil, nil, state.ErrNotFound } @@ -1163,21 +1163,43 @@ func (a *Aggregator) getAndLockBatchToProve( virtualBatch.L1InfoRoot = &l1InfoRoot } + // Ensure the old acc input hash is in memory + oldAccInputHash := a.getAccInputHash(batchNumberToVerify - 1) + if oldAccInputHash == (common.Hash{}) && batchNumberToVerify > 1 { + tmpLogger.Warnf("AccInputHash for previous batch (%d) is not in memory. Waiting ...", batchNumberToVerify-1) + return nil, nil, nil, state.ErrNotFound + } + + forcedBlockHashL1 := rpcBatch.ForcedBlockHashL1() + l1InfoRoot = *virtualBatch.L1InfoRoot + + if batchNumberToVerify == 1 { + l1Block, err := a.l1Syncr.GetL1BlockByNumber(ctx, virtualBatch.BlockNumber) + if err != nil { + a.logger.Errorf("Error getting l1 block: %v", err) + return nil, nil, nil, err + } + + forcedBlockHashL1 = l1Block.ParentHash + l1InfoRoot = rpcBatch.GlobalExitRoot() + } + // Calculate acc input hash as the RPC is not returning the correct one at the moment accInputHash := cdkcommon.CalculateAccInputHash( a.logger, - a.getAccInputHash(batchNumberToVerify-1), + oldAccInputHash, virtualBatch.BatchL2Data, - *virtualBatch.L1InfoRoot, + l1InfoRoot, uint64(sequence.Timestamp.Unix()), rpcBatch.LastCoinbase(), - rpcBatch.ForcedBlockHashL1(), + forcedBlockHashL1, ) // Store the acc input hash a.setAccInputHash(batchNumberToVerify, accInputHash) // Log params to calculate acc input hash a.logger.Debugf("Calculated acc input hash for batch %d: %v", batchNumberToVerify, accInputHash) + a.logger.Debugf("OldAccInputHash: %v", oldAccInputHash) a.logger.Debugf("L1InfoRoot: %v", virtualBatch.L1InfoRoot) // a.logger.Debugf("LastL2BLockTimestamp: %v", rpcBatch.LastL2BLockTimestamp()) a.logger.Debugf("TimestampLimit: %v", uint64(sequence.Timestamp.Unix())) @@ -1196,7 +1218,7 @@ func (a *Aggregator) getAndLockBatchToProve( AccInputHash: accInputHash, L1InfoTreeIndex: rpcBatch.L1InfoTreeIndex(), L1InfoRoot: *virtualBatch.L1InfoRoot, - Timestamp: time.Unix(int64(rpcBatch.LastL2BLockTimestamp()), 0), + Timestamp: sequence.Timestamp, GlobalExitRoot: rpcBatch.GlobalExitRoot(), ChainID: a.cfg.ChainID, ForkID: a.cfg.ForkId, @@ -1325,7 +1347,7 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover ProverInt tmpLogger = tmpLogger.WithFields("proofId", *proof.ProofID) - resGetProof, stateRoot, err := prover.WaitRecursiveProof(ctx, *proof.ProofID) + resGetProof, stateRoot, accInputHash, err := prover.WaitRecursiveProof(ctx, *proof.ProofID) if err != nil { err = fmt.Errorf("failed to get proof from prover, %w", err) tmpLogger.Error(FirstToUpper(err.Error())) @@ -1335,17 +1357,9 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover ProverInt tmpLogger.Info("Batch proof generated") // Sanity Check: state root from the proof must match the one from the batch - if a.cfg.BatchProofSanityCheckEnabled && (stateRoot != common.Hash{}) && (stateRoot != batchToProve.StateRoot) { - for { - tmpLogger.Errorf("State root from the proof does not match the expected for batch %d: Proof = [%s] Expected = [%s]", - batchToProve.BatchNumber, stateRoot.String(), batchToProve.StateRoot.String(), - ) - time.Sleep(a.cfg.RetryTime.Duration) - } - } else { - tmpLogger.Infof("State root sanity check for batch %d passed", batchToProve.BatchNumber) + if a.cfg.BatchProofSanityCheckEnabled { + a.performSanityChecks(tmpLogger, stateRoot, accInputHash, batchToProve) } - proof.Proof = resGetProof // NOTE(pg): the defer func is useless from now on, use a different variable @@ -1374,6 +1388,35 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover ProverInt return true, nil } +func (a *Aggregator) performSanityChecks(tmpLogger *log.Logger, stateRoot, accInputHash common.Hash, + batchToProve *state.Batch) { + // Sanity Check: state root from the proof must match the one from the batch + if (stateRoot != common.Hash{}) && (stateRoot != batchToProve.StateRoot) { + for { + tmpLogger.Errorf("HALTING: "+ + "State root from the proof does not match the expected for batch %d: Proof = [%s] Expected = [%s]", + batchToProve.BatchNumber, stateRoot.String(), batchToProve.StateRoot.String(), + ) + time.Sleep(a.cfg.RetryTime.Duration) + } + } else { + tmpLogger.Infof("State root sanity check for batch %d passed", batchToProve.BatchNumber) + } + + // Sanity Check: acc input hash from the proof must match the one from the batch + if (accInputHash != common.Hash{}) && (accInputHash != batchToProve.AccInputHash) { + for { + tmpLogger.Errorf("HALTING: Acc input hash from the proof does not match the expected for "+ + "batch %d: Proof = [%s] Expected = [%s]", + batchToProve.BatchNumber, accInputHash.String(), batchToProve.AccInputHash.String(), + ) + time.Sleep(a.cfg.RetryTime.Duration) + } + } else { + tmpLogger.Infof("Acc input hash sanity check for batch %d passed", batchToProve.BatchNumber) + } +} + // canVerifyProof returns true if we have reached the timeout to verify a proof // and no other prover is verifying a proof (verifyingProof = false). func (a *Aggregator) canVerifyProof() bool { @@ -1505,10 +1548,17 @@ func (a *Aggregator) buildInputProver( } } + // Ensure the old acc input hash is in memory + oldAccInputHash := a.getAccInputHash(batchToVerify.BatchNumber - 1) + if oldAccInputHash == (common.Hash{}) && batchToVerify.BatchNumber > 1 { + a.logger.Warnf("AccInputHash for previous batch (%d) is not in memory. Waiting ...", batchToVerify.BatchNumber-1) + return nil, fmt.Errorf("acc input hash for previous batch (%d) is not in memory", batchToVerify.BatchNumber-1) + } + inputProver := &prover.StatelessInputProver{ PublicInputs: &prover.StatelessPublicInputs{ Witness: witness, - OldAccInputHash: a.getAccInputHash(batchToVerify.BatchNumber - 1).Bytes(), + OldAccInputHash: oldAccInputHash.Bytes(), OldBatchNum: batchToVerify.BatchNumber - 1, ChainId: batchToVerify.ChainID, ForkId: batchToVerify.ForkID, diff --git a/aggregator/aggregator_test.go b/aggregator/aggregator_test.go index 506ce16c..8d5b5392 100644 --- a/aggregator/aggregator_test.go +++ b/aggregator/aggregator_test.go @@ -1114,7 +1114,7 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once() m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once() - m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, errTest).Once() + m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, errTest).Once() m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchAggregatorCtxFn)).Return(dbTx, nil).Once().NotBefore(lockProofsTxBegin) m.stateMock. On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, dbTx). @@ -1172,7 +1172,7 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once() m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once() - m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, errTest).Once() + m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, errTest).Once() m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchAggregatorCtxFn)).Return(dbTx, nil).Once().NotBefore(lockProofsTxBegin) m.stateMock. On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, dbTx). @@ -1220,7 +1220,7 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once() m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once() - m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, nil).Once() + m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, common.Hash{}, nil).Once() m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchProverCtxFn), proof1.BatchNumber, proof2.BatchNumberFinal, dbTx).Return(errTest).Once() dbTx.On("Rollback", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchAggregatorCtxFn)).Return(dbTx, nil).Once().NotBefore(lockProofsTxBegin) @@ -1280,7 +1280,7 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once() m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once() - m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, nil).Once() + m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, common.Hash{}, nil).Once() m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchProverCtxFn), proof1.BatchNumber, proof2.BatchNumberFinal, dbTx).Return(nil).Once() m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, dbTx).Return(errTest).Once() dbTx.On("Rollback", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() @@ -1315,6 +1315,7 @@ func Test_tryAggregateProofs(t *testing.T) { { name: "time to send final, state error", setup: func(m mox, a *Aggregator) { + a.accInputHashes = make(map[uint64]common.Hash) a.cfg.VerifyProofInterval = types.Duration{Duration: time.Nanosecond} m.proverMock.On("Name").Return(proverName).Times(3) m.proverMock.On("ID").Return(proverID).Times(3) @@ -1343,7 +1344,7 @@ func Test_tryAggregateProofs(t *testing.T) { Once() m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once() - m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, nil).Once() + m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, common.Hash{}, nil).Once() m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchProverCtxFn), proof1.BatchNumber, proof2.BatchNumberFinal, dbTx).Return(nil).Once() expectedInputProver := map[string]interface{}{ "recursive_proof_1": proof1.Proof, @@ -1443,6 +1444,7 @@ func Test_tryGenerateBatchProof(t *testing.T) { IntervalAfterWhichBatchConsolidateAnyway: types.Duration{Duration: time.Second * 1}, ChainID: uint64(1), ForkId: uint64(12), + BatchProofSanityCheckEnabled: true, } lastVerifiedBatchNum := uint64(22) @@ -1456,8 +1458,8 @@ func Test_tryGenerateBatchProof(t *testing.T) { proverName := "proverName" proverID := "proverID" - recursiveProof := "recursiveProof" errTest := errors.New("test error") + errAIH := fmt.Errorf("failed to build input prover, acc input hash for previous batch (22) is not in memory") proverCtx := context.WithValue(context.Background(), "owner", ownerProver) //nolint:staticcheck matchProverCtxFn := func(ctx context.Context) bool { return ctx.Value("owner") == ownerProver } matchAggregatorCtxFn := func(ctx context.Context) bool { return ctx.Value("owner") == ownerAggregator } @@ -1491,6 +1493,49 @@ func Test_tryGenerateBatchProof(t *testing.T) { setup func(mox, *Aggregator) asserts func(bool, *Aggregator, error) }{ + { + name: "getAndLockBatchToProve returns AIH error", + setup: func(m mox, a *Aggregator) { + sequence := synchronizer.SequencedBatches{ + FromBatchNumber: uint64(1), + ToBatchNumber: uint64(2), + } + l1InfoRoot := common.HexToHash("0x057e9950fbd39b002e323f37c2330d0c096e66919e24cc96fb4b2dfa8f4af782") + + virtualBatch := synchronizer.VirtualBatch{ + BatchNumber: 1, + BatchL2Data: []byte{ + 0xb, 0x0, 0x0, 0x0, 0x7b, 0x0, 0x0, 0x1, 0xc8, 0xb, 0x0, 0x0, 0x3, 0x15, 0x0, 0x1, 0x8a, 0xf8, + }, + L1InfoRoot: &l1InfoRoot, + } + rpcBatch := rpctypes.NewRPCBatch(lastVerifiedBatchNum+1, common.Hash{}, []string{}, []byte("batchL2Data"), common.Hash{}, common.BytesToHash([]byte("mock LocalExitRoot")), common.BytesToHash([]byte("mock StateRoot")), common.Address{}, false) + + m.proverMock.On("Name").Return(proverName) + m.proverMock.On("ID").Return(proverID) + m.proverMock.On("Addr").Return("addr") + m.etherman.On("GetLatestVerifiedBatchNum").Return(uint64(0), nil) + m.stateMock.On("CheckProofExistsForBatch", mock.Anything, uint64(1), nil).Return(false, nil) + m.synchronizerMock.On("GetSequenceByBatchNumber", mock.Anything, mock.Anything).Return(&sequence, nil) + m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, mock.Anything).Return(&virtualBatch, nil) + m.synchronizerMock.On("GetL1BlockByNumber", mock.Anything, mock.Anything).Return(&synchronizer.L1Block{ParentHash: common.Hash{}}, nil) + m.rpcMock.On("GetBatch", mock.Anything).Return(rpcBatch, nil) + m.rpcMock.On("GetWitness", mock.Anything, false).Return([]byte("witness"), nil) + m.stateMock.On("AddGeneratedProof", mock.Anything, mock.Anything, nil).Return(nil) + m.stateMock.On("AddSequence", mock.Anything, mock.Anything, nil).Return(nil) + m.stateMock.On("DeleteGeneratedProofs", mock.Anything, uint64(1), uint64(1), nil).Return(nil) + m.synchronizerMock.On("GetLeafsByL1InfoRoot", mock.Anything, l1InfoRoot).Return(l1InfoTreeLeaf, nil) + m.synchronizerMock.On("GetL1InfoTreeLeaves", mock.Anything, mock.Anything).Return(map[uint32]synchronizer.L1InfoTreeLeaf{ + 1: { + BlockNumber: uint64(1), + }, + }, nil) + }, + asserts: func(result bool, a *Aggregator, err error) { + assert.False(result) + assert.ErrorContains(err, errAIH.Error()) + }, + }, { name: "getAndLockBatchToProve returns generic error", setup: func(m mox, a *Aggregator) { @@ -1534,20 +1579,20 @@ func Test_tryGenerateBatchProof(t *testing.T) { L1InfoRoot: &l1InfoRoot, } - m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, lastVerifiedBatchNum).Return(&virtualBatch, nil).Once() + m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, mock.Anything).Return(&virtualBatch, nil).Once() m.etherman.On("GetLatestVerifiedBatchNum").Return(lastVerifiedBatchNum, nil).Once() - m.stateMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), lastVerifiedBatchNum+1, nil).Return(true, nil).Once() + m.stateMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(true, nil) m.stateMock.On("CleanupGeneratedProofs", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() sequence := synchronizer.SequencedBatches{ FromBatchNumber: uint64(10), ToBatchNumber: uint64(20), } - m.synchronizerMock.On("GetSequenceByBatchNumber", mock.MatchedBy(matchProverCtxFn), lastVerifiedBatchNum).Return(&sequence, nil).Once() + m.synchronizerMock.On("GetSequenceByBatchNumber", mock.Anything, mock.Anything).Return(&sequence, nil) rpcBatch := rpctypes.NewRPCBatch(lastVerifiedBatchNum+1, common.Hash{}, []string{}, batchL2Data, common.Hash{}, common.BytesToHash([]byte("mock LocalExitRoot")), common.BytesToHash([]byte("mock StateRoot")), common.Address{}, false) rpcBatch.SetLastL2BLockTimestamp(uint64(time.Now().Unix())) - m.rpcMock.On("GetWitness", lastVerifiedBatchNum, false).Return([]byte("witness"), nil) - m.rpcMock.On("GetBatch", lastVerifiedBatchNum).Return(rpcBatch, nil) + m.rpcMock.On("GetWitness", mock.Anything, false).Return([]byte("witness"), nil) + m.rpcMock.On("GetBatch", mock.Anything).Return(rpcBatch, nil) m.stateMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil).Once() m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( func(args mock.Arguments) { @@ -1587,15 +1632,6 @@ func Test_tryGenerateBatchProof(t *testing.T) { batchL2Data, err := hex.DecodeString(codedL2Block1) require.NoError(err) l1InfoRoot := common.HexToHash("0x057e9950fbd39b002e323f37c2330d0c096e66919e24cc96fb4b2dfa8f4af782") - batch := state.Batch{ - BatchNumber: lastVerifiedBatchNum + 1, - BatchL2Data: batchL2Data, - L1InfoRoot: l1InfoRoot, - Timestamp: time.Now(), - Coinbase: common.Address{}, - ChainID: uint64(1), - ForkID: uint64(12), - } virtualBatch := synchronizer.VirtualBatch{ BatchNumber: lastVerifiedBatchNum + 1, @@ -1630,20 +1666,17 @@ func Test_tryGenerateBatchProof(t *testing.T) { }, ).Return(nil).Once() - m.synchronizerMock.On("GetLeafsByL1InfoRoot", mock.Anything, l1InfoRoot).Return(l1InfoTreeLeaf, nil).Twice() + m.synchronizerMock.On("GetLeafsByL1InfoRoot", mock.Anything, l1InfoRoot).Return(l1InfoTreeLeaf, nil) m.synchronizerMock.On("GetL1InfoTreeLeaves", mock.Anything, mock.Anything).Return(map[uint32]synchronizer.L1InfoTreeLeaf{ 1: { BlockNumber: uint64(35), }, - }, nil).Twice() + }, nil) m.rpcMock.On("GetBatch", lastVerifiedBatchNum+1).Return(rpcBatch, nil) - expectedInputProver, err := a.buildInputProver(context.Background(), &batch, []byte("witness")) - require.NoError(err) - - m.proverMock.On("BatchProof", expectedInputProver).Return(&proofID, nil).Once() - m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, errTest).Once() - m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchAggregatorCtxFn), batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(nil).Once() + m.proverMock.On("BatchProof", mock.Anything).Return(&proofID, nil).Once() + m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, errTest).Once() + m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchAggregatorCtxFn), batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(nil) }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) @@ -1651,36 +1684,35 @@ func Test_tryGenerateBatchProof(t *testing.T) { }, }, { - name: "DeleteBatchProofs error after WaitRecursiveProof prover error", + name: "WaitRecursiveProof no error", setup: func(m mox, a *Aggregator) { - m.proverMock.On("Name").Return(proverName).Twice() - m.proverMock.On("ID").Return(proverID).Twice() - m.proverMock.On("Addr").Return("addr").Twice() + m.proverMock.On("Name").Return(proverName) + m.proverMock.On("ID").Return(proverID) + m.proverMock.On("Addr").Return("addr") batchL2Data, err := hex.DecodeString(codedL2Block1) require.NoError(err) l1InfoRoot := common.HexToHash("0x057e9950fbd39b002e323f37c2330d0c096e66919e24cc96fb4b2dfa8f4af782") - batch := state.Batch{ + + virtualBatch := synchronizer.VirtualBatch{ BatchNumber: lastVerifiedBatchNum + 1, BatchL2Data: batchL2Data, - L1InfoRoot: l1InfoRoot, - Timestamp: time.Now(), - Coinbase: common.Address{}, - ChainID: uint64(1), - ForkID: uint64(12), + L1InfoRoot: &l1InfoRoot, } - m.etherman.On("GetLatestVerifiedBatchNum").Return(lastVerifiedBatchNum, nil).Once() - m.stateMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.AnythingOfType("uint64"), nil).Return(false, nil).Once() + m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, lastVerifiedBatchNum+1).Return(&virtualBatch, nil) + + m.etherman.On("GetLatestVerifiedBatchNum").Return(lastVerifiedBatchNum, nil) + m.stateMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.AnythingOfType("uint64"), nil).Return(false, nil) sequence := synchronizer.SequencedBatches{ FromBatchNumber: uint64(10), ToBatchNumber: uint64(20), } - m.synchronizerMock.On("GetSequenceByBatchNumber", mock.MatchedBy(matchProverCtxFn), lastVerifiedBatchNum+1).Return(&sequence, nil).Once() + m.synchronizerMock.On("GetSequenceByBatchNumber", mock.MatchedBy(matchProverCtxFn), lastVerifiedBatchNum+1).Return(&sequence, nil) rpcBatch := rpctypes.NewRPCBatch(lastVerifiedBatchNum+1, common.Hash{}, []string{}, batchL2Data, common.Hash{}, common.BytesToHash([]byte("mock LocalExitRoot")), common.BytesToHash([]byte("mock StateRoot")), common.Address{}, false) rpcBatch.SetLastL2BLockTimestamp(uint64(time.Now().Unix())) - m.rpcMock.On("GetBatch", lastVerifiedBatchNum+1).Return(rpcBatch, nil) - m.stateMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil).Once() + m.rpcMock.On("GetWitness", lastVerifiedBatchNum+1, false).Return([]byte("witness"), nil) + m.stateMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil) m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( func(args mock.Arguments) { proof, ok := args[1].(*state.Proof) @@ -1693,57 +1725,35 @@ func Test_tryGenerateBatchProof(t *testing.T) { assert.Equal(&proverID, proof.ProverID) assert.InDelta(time.Now().Unix(), proof.GeneratingSince.Unix(), float64(time.Second)) }, - ).Return(nil).Once() + ).Return(nil) - m.synchronizerMock.On("GetLeafsByL1InfoRoot", mock.Anything, l1InfoRoot).Return(l1InfoTreeLeaf, nil).Twice() + m.synchronizerMock.On("GetLeafsByL1InfoRoot", mock.Anything, l1InfoRoot).Return(l1InfoTreeLeaf, nil) m.synchronizerMock.On("GetL1InfoTreeLeaves", mock.Anything, mock.Anything).Return(map[uint32]synchronizer.L1InfoTreeLeaf{ 1: { BlockNumber: uint64(35), }, - }, nil).Twice() - - expectedInputProver, err := a.buildInputProver(context.Background(), &batch, []byte("witness")) - require.NoError(err) - - m.rpcMock.On("GetWitness", lastVerifiedBatchNum+1, false).Return([]byte("witness"), nil) - - virtualBatch := synchronizer.VirtualBatch{ - BatchNumber: lastVerifiedBatchNum + 1, - BatchL2Data: batchL2Data, - L1InfoRoot: &l1InfoRoot, - } - - m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, lastVerifiedBatchNum+1).Return(&virtualBatch, nil).Once() + }, nil) - m.proverMock.On("BatchProof", expectedInputProver).Return(&proofID, nil).Once() - m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, errTest).Once() - m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchAggregatorCtxFn), batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(errTest).Once() + m.rpcMock.On("GetBatch", lastVerifiedBatchNum+1).Return(rpcBatch, nil) + m.proverMock.On("BatchProof", mock.Anything).Return(&proofID, nil) + m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, nil) + m.stateMock.On("UpdateGeneratedProof", mock.Anything, mock.Anything, nil).Return(nil) }, asserts: func(result bool, a *Aggregator, err error) { - assert.False(result) - assert.ErrorIs(err, errTest) + assert.True(result) + assert.NoError(err) }, }, { - name: "not time to send final ok", + name: "DeleteBatchProofs error after WaitRecursiveProof prover error", setup: func(m mox, a *Aggregator) { - a.cfg.BatchProofSanityCheckEnabled = false - m.proverMock.On("Name").Return(proverName).Times(3) - m.proverMock.On("ID").Return(proverID).Times(3) - m.proverMock.On("Addr").Return("addr").Times(3) + m.proverMock.On("Name").Return(proverName).Twice() + m.proverMock.On("ID").Return(proverID).Twice() + m.proverMock.On("Addr").Return("addr").Twice() batchL2Data, err := hex.DecodeString(codedL2Block1) require.NoError(err) l1InfoRoot := common.HexToHash("0x057e9950fbd39b002e323f37c2330d0c096e66919e24cc96fb4b2dfa8f4af782") - batch := state.Batch{ - BatchNumber: lastVerifiedBatchNum + 1, - BatchL2Data: batchL2Data, - L1InfoRoot: l1InfoRoot, - Timestamp: time.Now(), - Coinbase: common.Address{}, - ChainID: uint64(1), - ForkID: uint64(12), - } m.etherman.On("GetLatestVerifiedBatchNum").Return(lastVerifiedBatchNum, nil).Once() m.stateMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.AnythingOfType("uint64"), nil).Return(false, nil).Once() @@ -1752,7 +1762,9 @@ func Test_tryGenerateBatchProof(t *testing.T) { ToBatchNumber: uint64(20), } m.synchronizerMock.On("GetSequenceByBatchNumber", mock.MatchedBy(matchProverCtxFn), lastVerifiedBatchNum+1).Return(&sequence, nil).Once() - + rpcBatch := rpctypes.NewRPCBatch(lastVerifiedBatchNum+1, common.Hash{}, []string{}, batchL2Data, common.Hash{}, common.BytesToHash([]byte("mock LocalExitRoot")), common.BytesToHash([]byte("mock StateRoot")), common.Address{}, false) + rpcBatch.SetLastL2BLockTimestamp(uint64(time.Now().Unix())) + m.rpcMock.On("GetBatch", lastVerifiedBatchNum+1).Return(rpcBatch, nil) m.stateMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil).Once() m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( func(args mock.Arguments) { @@ -1768,81 +1780,12 @@ func Test_tryGenerateBatchProof(t *testing.T) { }, ).Return(nil).Once() - m.synchronizerMock.On("GetLeafsByL1InfoRoot", mock.Anything, l1InfoRoot).Return(l1InfoTreeLeaf, nil).Twice() + m.synchronizerMock.On("GetLeafsByL1InfoRoot", mock.Anything, l1InfoRoot).Return(l1InfoTreeLeaf, nil) m.synchronizerMock.On("GetL1InfoTreeLeaves", mock.Anything, mock.Anything).Return(map[uint32]synchronizer.L1InfoTreeLeaf{ 1: { BlockNumber: uint64(35), }, - }, nil).Twice() - - rpcBatch := rpctypes.NewRPCBatch(lastVerifiedBatchNum+1, common.Hash{}, []string{}, batchL2Data, common.Hash{}, common.BytesToHash([]byte("mock LocalExitRoot")), common.BytesToHash([]byte("mock StateRoot")), common.Address{}, false) - rpcBatch.SetLastL2BLockTimestamp(uint64(time.Now().Unix())) - m.rpcMock.On("GetBatch", lastVerifiedBatchNum+1).Return(rpcBatch, nil) - m.rpcMock.On("GetWitness", lastVerifiedBatchNum+1, false).Return([]byte("witness"), nil) - - virtualBatch := synchronizer.VirtualBatch{ - BatchNumber: lastVerifiedBatchNum + 1, - BatchL2Data: batchL2Data, - L1InfoRoot: &l1InfoRoot, - } - - m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, lastVerifiedBatchNum+1).Return(&virtualBatch, nil).Once() - - expectedInputProver, err := a.buildInputProver(context.Background(), &batch, []byte("witness")) - require.NoError(err) - - m.proverMock.On("BatchProof", expectedInputProver).Return(&proofID, nil).Once() - m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, nil).Once() - m.stateMock.On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), mock.Anything, nil).Run( - func(args mock.Arguments) { - proof, ok := args[1].(*state.Proof) - if !ok { - t.Fatalf("expected args[1] to be of type *state.Proof, got %T", args[1]) - } - assert.Equal(batchToProve.BatchNumber, proof.BatchNumber) - assert.Equal(batchToProve.BatchNumber, proof.BatchNumberFinal) - assert.Equal(&proverName, proof.Prover) - assert.Equal(&proverID, proof.ProverID) - assert.Equal("", proof.InputProver) - assert.Equal(recursiveProof, proof.Proof) - assert.Nil(proof.GeneratingSince) - }, - ).Return(nil).Once() - }, - asserts: func(result bool, a *Aggregator, err error) { - assert.True(result) - assert.NoError(err) - }, - }, - { - name: "time to send final, state error ok", - setup: func(m mox, a *Aggregator) { - a.cfg.VerifyProofInterval = types.NewDuration(0) - a.cfg.BatchProofSanityCheckEnabled = false - m.proverMock.On("Name").Return(proverName).Times(3) - m.proverMock.On("ID").Return(proverID).Times(3) - m.proverMock.On("Addr").Return("addr").Times(3) - - batchL2Data, err := hex.DecodeString(codedL2Block1) - require.NoError(err) - l1InfoRoot := common.HexToHash("0x057e9950fbd39b002e323f37c2330d0c096e66919e24cc96fb4b2dfa8f4af782") - batch := state.Batch{ - BatchNumber: lastVerifiedBatchNum + 1, - BatchL2Data: batchL2Data, - L1InfoRoot: l1InfoRoot, - Timestamp: time.Now(), - Coinbase: common.Address{}, - ChainID: uint64(1), - ForkID: uint64(12), - } - - m.etherman.On("GetLatestVerifiedBatchNum").Return(lastVerifiedBatchNum, nil).Once() - m.stateMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.AnythingOfType("uint64"), nil).Return(false, nil).Once() - sequence := synchronizer.SequencedBatches{ - FromBatchNumber: uint64(10), - ToBatchNumber: uint64(20), - } - m.synchronizerMock.On("GetSequenceByBatchNumber", mock.MatchedBy(matchProverCtxFn), lastVerifiedBatchNum+1).Return(&sequence, nil).Once() + }, nil) m.rpcMock.On("GetWitness", lastVerifiedBatchNum+1, false).Return([]byte("witness"), nil) @@ -1854,63 +1797,18 @@ func Test_tryGenerateBatchProof(t *testing.T) { m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, lastVerifiedBatchNum+1).Return(&virtualBatch, nil).Once() - m.rpcMock.On("GetWitness", lastVerifiedBatchNum+1, false).Return([]byte("witness"), nil) - rpcBatch := rpctypes.NewRPCBatch(lastVerifiedBatchNum+1, common.Hash{}, []string{}, batchL2Data, common.Hash{}, common.BytesToHash([]byte("mock LocalExitRoot")), common.BytesToHash([]byte("mock StateRoot")), common.Address{}, false) - rpcBatch.SetLastL2BLockTimestamp(uint64(time.Now().Unix())) - m.rpcMock.On("GetBatch", lastVerifiedBatchNum+1).Return(rpcBatch, nil) - - m.stateMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil).Once() - m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( - func(args mock.Arguments) { - proof, ok := args[1].(*state.Proof) - if !ok { - t.Fatalf("expected args[1] to be of type *state.Proof, got %T", args[1]) - } - assert.Equal(batchToProve.BatchNumber, proof.BatchNumber) - assert.Equal(batchToProve.BatchNumber, proof.BatchNumberFinal) - assert.Equal(&proverName, proof.Prover) - assert.Equal(&proverID, proof.ProverID) - assert.InDelta(time.Now().Unix(), proof.GeneratingSince.Unix(), float64(time.Second)) - }, - ).Return(nil).Once() - - m.synchronizerMock.On("GetLeafsByL1InfoRoot", mock.Anything, l1InfoRoot).Return(l1InfoTreeLeaf, nil).Twice() - m.synchronizerMock.On("GetL1InfoTreeLeaves", mock.Anything, mock.Anything).Return(map[uint32]synchronizer.L1InfoTreeLeaf{ - 1: { - BlockNumber: uint64(35), - }, - }, nil).Twice() - - expectedInputProver, err := a.buildInputProver(context.Background(), &batch, []byte("witness")) - require.NoError(err) - - m.proverMock.On("BatchProof", expectedInputProver).Return(&proofID, nil).Once() - m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, nil).Once() - m.etherman.On("GetLatestVerifiedBatchNum").Return(uint64(42), errTest).Once() - m.stateMock.On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), mock.Anything, nil).Run( - func(args mock.Arguments) { - proof, ok := args[1].(*state.Proof) - if !ok { - t.Fatalf("expected args[1] to be of type *state.Proof, got %T", args[1]) - } - assert.Equal(batchToProve.BatchNumber, proof.BatchNumber) - assert.Equal(batchToProve.BatchNumber, proof.BatchNumberFinal) - assert.Equal(&proverName, proof.Prover) - assert.Equal(&proverID, proof.ProverID) - assert.Equal("", proof.InputProver) - assert.Equal(recursiveProof, proof.Proof) - assert.Nil(proof.GeneratingSince) - }, - ).Return(nil).Once() + m.proverMock.On("BatchProof", mock.Anything).Return(&proofID, nil).Once() + m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, errTest).Once() + m.stateMock.On("DeleteGeneratedProofs", mock.Anything, batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(errTest).Once() }, asserts: func(result bool, a *Aggregator, err error) { - assert.True(result) - assert.NoError(err) + assert.False(result) + assert.ErrorIs(err, errTest) }, }, } - for _, tc := range testCases { + for x, tc := range testCases { t.Run(tc.name, func(t *testing.T) { stateMock := mocks.NewStateInterfaceMock(t) ethTxManager := mocks.NewEthTxManagerClientMock(t) @@ -1935,6 +1833,9 @@ func Test_tryGenerateBatchProof(t *testing.T) { accInputHashes: make(map[uint64]common.Hash), accInputHashesMutex: &sync.Mutex{}, } + if x > 0 { + a.accInputHashes = populateAccInputHashes() + } aggregatorCtx := context.WithValue(context.Background(), "owner", ownerAggregator) //nolint:staticcheck a.ctx, a.exit = context.WithCancel(aggregatorCtx) @@ -1960,6 +1861,14 @@ func Test_tryGenerateBatchProof(t *testing.T) { } } +func populateAccInputHashes() map[uint64]common.Hash { + accInputHashes := make(map[uint64]common.Hash) + for i := 10; i < 200; i++ { + accInputHashes[uint64(i)] = common.BytesToHash([]byte(fmt.Sprintf("hash%d", i))) + } + return accInputHashes +} + func Test_accInputHashFunctions(t *testing.T) { aggregator := Aggregator{ accInputHashes: make(map[uint64]common.Hash), @@ -1980,3 +1889,32 @@ func Test_accInputHashFunctions(t *testing.T) { aggregator.removeAccInputHashes(1, 2) assert.Equal(t, 0, len(aggregator.accInputHashes)) } + +func Test_sanityChecks(t *testing.T) { + batchToProve := state.Batch{ + BatchNumber: 1, + StateRoot: common.HexToHash("0x01"), + AccInputHash: common.HexToHash("0x02"), + } + + aggregator := Aggregator{ + accInputHashes: make(map[uint64]common.Hash), + accInputHashesMutex: &sync.Mutex{}, + } + + aggregator.performSanityChecks(log.GetDefaultLogger(), batchToProve.StateRoot, batchToProve.AccInputHash, &batchToProve) + + // Halt by SR sanity check + go func() { + aggregator.performSanityChecks(log.GetDefaultLogger(), common.HexToHash("0x03"), batchToProve.AccInputHash, &batchToProve) + time.Sleep(5 * time.Second) + return + }() + + // Halt by AIH sanity check + go func() { + aggregator.performSanityChecks(log.GetDefaultLogger(), batchToProve.StateRoot, common.HexToHash("0x04"), &batchToProve) + time.Sleep(5 * time.Second) + return + }() +} diff --git a/aggregator/interfaces.go b/aggregator/interfaces.go index 81f63d94..f1673c46 100644 --- a/aggregator/interfaces.go +++ b/aggregator/interfaces.go @@ -30,7 +30,7 @@ type ProverInterface interface { BatchProof(input *prover.StatelessInputProver) (*string, error) AggregatedProof(inputProof1, inputProof2 string) (*string, error) FinalProof(inputProof string, aggregatorAddr string) (*string, error) - WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, error) + WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, common.Hash, error) WaitFinalProof(ctx context.Context, proofID string) (*prover.FinalProof, error) } diff --git a/aggregator/mocks/mock_prover.go b/aggregator/mocks/mock_prover.go index 72bd66dc..b6ce1011 100644 --- a/aggregator/mocks/mock_prover.go +++ b/aggregator/mocks/mock_prover.go @@ -220,7 +220,7 @@ func (_m *ProverInterfaceMock) WaitFinalProof(ctx context.Context, proofID strin } // WaitRecursiveProof provides a mock function with given fields: ctx, proofID -func (_m *ProverInterfaceMock) WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, error) { +func (_m *ProverInterfaceMock) WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, common.Hash, error) { ret := _m.Called(ctx, proofID) if len(ret) == 0 { @@ -229,8 +229,9 @@ func (_m *ProverInterfaceMock) WaitRecursiveProof(ctx context.Context, proofID s var r0 string var r1 common.Hash - var r2 error - if rf, ok := ret.Get(0).(func(context.Context, string) (string, common.Hash, error)); ok { + var r2 common.Hash + var r3 error + if rf, ok := ret.Get(0).(func(context.Context, string) (string, common.Hash, common.Hash, error)); ok { return rf(ctx, proofID) } if rf, ok := ret.Get(0).(func(context.Context, string) string); ok { @@ -247,13 +248,21 @@ func (_m *ProverInterfaceMock) WaitRecursiveProof(ctx context.Context, proofID s } } - if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { + if rf, ok := ret.Get(2).(func(context.Context, string) common.Hash); ok { r2 = rf(ctx, proofID) } else { - r2 = ret.Error(2) + if ret.Get(2) != nil { + r2 = ret.Get(2).(common.Hash) + } + } + + if rf, ok := ret.Get(3).(func(context.Context, string) error); ok { + r3 = rf(ctx, proofID) + } else { + r3 = ret.Error(3) } - return r0, r1, r2 + return r0, r1, r2, r3 } // NewProverInterfaceMock creates a new instance of ProverInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. diff --git a/aggregator/prover/mocks/mock_channel.go b/aggregator/prover/mocks/mock_channel.go new file mode 100644 index 00000000..d125896d --- /dev/null +++ b/aggregator/prover/mocks/mock_channel.go @@ -0,0 +1,176 @@ +// Code generated by mockery v2.39.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + metadata "google.golang.org/grpc/metadata" + + prover "github.com/0xPolygon/cdk/aggregator/prover" +) + +// ChannelMock is an autogenerated mock type for the AggregatorService_ChannelServer type +type ChannelMock struct { + mock.Mock +} + +// Context provides a mock function with given fields: +func (_m *ChannelMock) Context() context.Context { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Context") + } + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +// Recv provides a mock function with given fields: +func (_m *ChannelMock) Recv() (*prover.ProverMessage, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Recv") + } + + var r0 *prover.ProverMessage + var r1 error + if rf, ok := ret.Get(0).(func() (*prover.ProverMessage, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() *prover.ProverMessage); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*prover.ProverMessage) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RecvMsg provides a mock function with given fields: m +func (_m *ChannelMock) RecvMsg(m interface{}) error { + ret := _m.Called(m) + + if len(ret) == 0 { + panic("no return value specified for RecvMsg") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Send provides a mock function with given fields: _a0 +func (_m *ChannelMock) Send(_a0 *prover.AggregatorMessage) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Send") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*prover.AggregatorMessage) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SendHeader provides a mock function with given fields: _a0 +func (_m *ChannelMock) SendHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendHeader") + } + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SendMsg provides a mock function with given fields: m +func (_m *ChannelMock) SendMsg(m interface{}) error { + ret := _m.Called(m) + + if len(ret) == 0 { + panic("no return value specified for SendMsg") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetHeader provides a mock function with given fields: _a0 +func (_m *ChannelMock) SetHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SetHeader") + } + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetTrailer provides a mock function with given fields: _a0 +func (_m *ChannelMock) SetTrailer(_a0 metadata.MD) { + _m.Called(_a0) +} + +// NewChannelMock creates a new instance of ChannelMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewChannelMock(t interface { + mock.TestingT + Cleanup(func()) +}) *ChannelMock { + mock := &ChannelMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/aggregator/prover/prover.go b/aggregator/prover/prover.go index 8cb13b1d..ad9c9895 100644 --- a/aggregator/prover/prover.go +++ b/aggregator/prover/prover.go @@ -18,8 +18,10 @@ import ( ) const ( - stateRootStartIndex = 19 - stateRootFinalIndex = stateRootStartIndex + 8 + StateRootStartIndex = 19 + StateRootFinalIndex = StateRootStartIndex + 8 + AccInputHashStartIndex = 27 + AccInputHashFinalIndex = AccInputHashStartIndex + 8 ) var ( @@ -282,30 +284,36 @@ func (p *Prover) CancelProofRequest(proofID string) error { // WaitRecursiveProof waits for a recursive proof to be generated by the prover // and returns it. -func (p *Prover) WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, error) { +func (p *Prover) WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, common.Hash, error) { res, err := p.waitProof(ctx, proofID) if err != nil { - return "", common.Hash{}, err + return "", common.Hash{}, common.Hash{}, err } resProof, ok := res.Proof.(*GetProofResponse_RecursiveProof) if !ok { - return "", common.Hash{}, fmt.Errorf( + return "", common.Hash{}, common.Hash{}, fmt.Errorf( "%w, wanted %T, got %T", ErrBadProverResponse, &GetProofResponse_RecursiveProof{}, res.Proof, ) } - sr, err := GetStateRootFromProof(p.logger, resProof.RecursiveProof) + sr, err := GetSanityCheckHashFromProof(p.logger, resProof.RecursiveProof, StateRootStartIndex, StateRootFinalIndex) if err != nil && sr != (common.Hash{}) { p.logger.Errorf("Error getting state root from proof: %v", err) } + accInputHash, err := GetSanityCheckHashFromProof(p.logger, resProof.RecursiveProof, + AccInputHashStartIndex, AccInputHashFinalIndex) + if err != nil && accInputHash != (common.Hash{}) { + p.logger.Errorf("Error getting acc input hash from proof: %v", err) + } + if sr == (common.Hash{}) { p.logger.Info("Recursive proof does not contain state root. Possibly mock prover is in use.") } - return resProof.RecursiveProof, sr, nil + return resProof.RecursiveProof, sr, accInputHash, nil } // WaitFinalProof waits for the final proof to be generated by the prover and @@ -395,11 +403,8 @@ func (p *Prover) call(req *AggregatorMessage) (*ProverMessage, error) { return res, nil } -// GetStateRootFromProof returns the state root from the proof. -func GetStateRootFromProof(logger *log.Logger, proof string) (common.Hash, error) { - // Log received proof - logger.Debugf("Received proof to get SR from: %s", proof) - +// GetSanityCheckHashFromProof returns info from the proof +func GetSanityCheckHashFromProof(logger *log.Logger, proof string, startIndex, endIndex int) (common.Hash, error) { type Publics struct { Publics []string `mapstructure:"publics"` } @@ -420,7 +425,7 @@ func GetStateRootFromProof(logger *log.Logger, proof string) (common.Hash, error v [8]uint64 j = 0 ) - for i := stateRootStartIndex; i < stateRootFinalIndex; i++ { + for i := startIndex; i < endIndex; i++ { u64, err := strconv.ParseInt(publics.Publics[i], 10, 64) if err != nil { logger.Fatal(err) diff --git a/aggregator/prover/prover_test.go b/aggregator/prover/prover_test.go index 737d5592..438718a3 100644 --- a/aggregator/prover/prover_test.go +++ b/aggregator/prover/prover_test.go @@ -1,12 +1,19 @@ package prover_test import ( + "context" "fmt" + "net" "os" "testing" + "time" "github.com/0xPolygon/cdk/aggregator/prover" + "github.com/0xPolygon/cdk/aggregator/prover/mocks" + "github.com/0xPolygon/cdk/config/types" "github.com/0xPolygon/cdk/log" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -18,6 +25,50 @@ type TestStateRoot struct { Publics []string `mapstructure:"publics"` } +func TestProver(t *testing.T) { + mockChannel := mocks.ChannelMock{} + var addr net.Addr + + mockChannel.On("Send", mock.Anything).Return(nil) + mockChannel.On("Recv").Return(&prover.ProverMessage{ + Id: "test", + Response: &prover.ProverMessage_GetStatusResponse{ + GetStatusResponse: &prover.GetStatusResponse{ + Status: prover.GetStatusResponse_STATUS_IDLE, + ProverName: "testName", + ProverId: "testId", + }, + }, + }, nil).Times(1) + + p, err := prover.New(log.GetDefaultLogger(), &mockChannel, addr, types.Duration{Duration: time.Second * 5}) + require.NoError(t, err) + name := p.Name() + require.Equal(t, "testName", name, "name does not match") + address := p.Addr() + require.Equal(t, "", address, "address does not match") + id := p.ID() + require.Equal(t, "testId", id, "id does not match") + + mockChannel.On("Recv").Return(&prover.ProverMessage{ + Id: "test", + Response: &prover.ProverMessage_GetProofResponse{ + GetProofResponse: &prover.GetProofResponse{ + Proof: &prover.GetProofResponse_RecursiveProof{ + RecursiveProof: "this is a proof", + }, + Result: prover.GetProofResponse_RESULT_COMPLETED_OK, + }, + }, + }, nil) + + proof, sr, accinputHash, err := p.WaitRecursiveProof(context.Background(), "proofID") + require.NoError(t, err) + + require.NotNil(t, proof, "proof is nil") + require.NotNil(t, sr, "state root is nil") + require.Equal(t, common.Hash{}, accinputHash, "state root is not empty") +} func TestCalculateStateRoots(t *testing.T) { var expectedStateRoots = map[string]string{ "1871.json": "0x0ed594d8bc0bb38f3190ff25fb1e5b4fe1baf0e2e0c1d7bf3307f07a55d3a60f", @@ -40,13 +91,18 @@ func TestCalculateStateRoots(t *testing.T) { require.NoError(t, err) // Get the state root from the batch proof - fileStateRoot, err := prover.GetStateRootFromProof(log.GetDefaultLogger(), string(data)) + fileStateRoot, err := prover.GetSanityCheckHashFromProof(log.GetDefaultLogger(), string(data), prover.StateRootStartIndex, prover.StateRootFinalIndex) require.NoError(t, err) // Get the expected state root expectedStateRoot, ok := expectedStateRoots[file.Name()] require.True(t, ok, "Expected state root not found") + // Check Acc Input Hash + accInputHash, err := prover.GetSanityCheckHashFromProof(log.GetDefaultLogger(), string(data), prover.AccInputHashStartIndex, prover.AccInputHashFinalIndex) + require.NotEqual(t, common.Hash{}, accInputHash, "Acc Input Hash is empty") + require.NoError(t, err) + // Compare the state roots require.Equal(t, expectedStateRoot, fileStateRoot.String(), "State roots do not match") } diff --git a/go.mod b/go.mod index c42bb806..70ec5e69 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/0xPolygon/cdk-data-availability v0.0.10 github.com/0xPolygon/cdk-rpc v0.0.0-20241004114257-6c3cb6eebfb6 github.com/0xPolygon/zkevm-ethtx-manager v0.2.1 - github.com/0xPolygonHermez/zkevm-synchronizer-l1 v1.0.5 + github.com/0xPolygonHermez/zkevm-synchronizer-l1 v1.0.6 github.com/ethereum/go-ethereum v1.14.8 github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 github.com/hermeznetwork/tracerr v0.3.2 diff --git a/go.sum b/go.sum index d6cbdb2b..ccf812c4 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,8 @@ github.com/0xPolygon/cdk-rpc v0.0.0-20241004114257-6c3cb6eebfb6 h1:FXL/rcO7/GtZ3 github.com/0xPolygon/cdk-rpc v0.0.0-20241004114257-6c3cb6eebfb6/go.mod h1:2scWqMMufrQXu7TikDgQ3BsyaKoX8qP26D6E262vSOg= github.com/0xPolygon/zkevm-ethtx-manager v0.2.1 h1:2Yb+KdJFMpVrS9LIkd658XiWuN+MCTs7SgeWaopXScg= github.com/0xPolygon/zkevm-ethtx-manager v0.2.1/go.mod h1:lqQmzSo2OXEZItD0R4Cd+lqKFxphXEWgqHefVcGDZZc= -github.com/0xPolygonHermez/zkevm-synchronizer-l1 v1.0.5 h1:YmnhuCl349MoNASN0fMeGKU1o9HqJhiZkfMsA/1cTRA= -github.com/0xPolygonHermez/zkevm-synchronizer-l1 v1.0.5/go.mod h1:X4Su/M/+hSISqdl9yomKlRsbTyuZHsRohporyHsP8gg= +github.com/0xPolygonHermez/zkevm-synchronizer-l1 v1.0.6 h1:+XsCHXvQezRdMnkI37Wa/nV4sOZshJavxNzRpH/R6dw= +github.com/0xPolygonHermez/zkevm-synchronizer-l1 v1.0.6/go.mod h1:X4Su/M/+hSISqdl9yomKlRsbTyuZHsRohporyHsP8gg= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/DataDog/zstd v1.5.6 h1:LbEglqepa/ipmmQJUDnSsfvA8e8IStVcGaFWDuxvGOY= github.com/DataDog/zstd v1.5.6/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= diff --git a/test/Makefile b/test/Makefile index 12f406fd..2e3445db 100644 --- a/test/Makefile +++ b/test/Makefile @@ -59,7 +59,8 @@ generate-mocks-aggregator: ## Generates mocks for aggregator, using mockery tool export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=EthTxManagerClient --dir=../aggregator --output=../aggregator/mocks --outpkg=mocks --structname=EthTxManagerClientMock --filename=mock_eth_tx_manager.go export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=Tx --srcpkg=github.com/jackc/pgx/v4 --output=../aggregator/mocks --outpkg=mocks --structname=DbTxMock --filename=mock_dbtx.go export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=RPCInterface --dir=../aggregator --output=../aggregator/mocks --outpkg=mocks --structname=RPCInterfaceMock --filename=mock_rpc.go - + export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=AggregatorService_ChannelServer --dir=../aggregator/prover --output=../aggregator/prover/mocks --outpkg=mocks --structname=ChannelMock --filename=mock_channel.go + .PHONY: generate-mocks-aggsender generate-mocks-aggsender: ## Generates mocks for aggsender, using mockery tool