From 46a67897a8c92bb7782b23cacb13714f5902594d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Toni=20Ram=C3=ADrez?= <58293609+ToniRamirezM@users.noreply.github.com> Date: Fri, 22 Nov 2024 13:56:41 +0100 Subject: [PATCH] feat: sqlite aggregator (#189) * feat: sqlite aggregator --- aggregator/aggregator.go | 138 +++---- aggregator/aggregator_test.go | 367 +++++++++--------- aggregator/config.go | 13 +- aggregator/db/config.go | 25 -- aggregator/db/db.go | 31 -- aggregator/db/dbstorage/dbstorage.go | 35 ++ aggregator/db/dbstorage/proof.go | 356 +++++++++++++++++ aggregator/db/dbstorage/proof_test.go | 150 +++++++ aggregator/db/dbstorage/sequence.go | 21 + aggregator/db/migrations.go | 37 +- aggregator/db/migrations/0001.sql | 34 +- aggregator/db/migrations/0002.sql | 8 - aggregator/db/migrations/0003.sql | 7 - aggregator/db/migrations/0004.sql | 23 -- aggregator/db/migrations_test.go | 9 + aggregator/interfaces.go | 29 +- aggregator/mocks/mock_dbtx.go | 350 ----------------- .../mocks/{mock_state.go => mock_storage.go} | 106 ++--- aggregator/mocks/mock_txer.go | 163 ++++++++ aggregator/profitabilitychecker.go | 92 ----- cmd/run.go | 41 +- config/default.go | 9 +- scripts/local_config | 3 - state/config.go | 13 - state/interfaces.go | 26 -- state/pgstatestorage/interfaces.go | 14 - state/pgstatestorage/pgstatestorage.go | 29 -- state/pgstatestorage/proof.go | 266 ------------- state/pgstatestorage/sequence.go | 21 - state/state.go | 40 -- test/Makefile | 4 +- .../kurtosis-cdk-node-config.toml.template | 8 - test/config/test.config.toml | 8 - 33 files changed, 1080 insertions(+), 1396 deletions(-) delete mode 100644 aggregator/db/config.go delete mode 100644 aggregator/db/db.go create mode 100644 aggregator/db/dbstorage/dbstorage.go create mode 100644 aggregator/db/dbstorage/proof.go create mode 100644 aggregator/db/dbstorage/proof_test.go create mode 100644 aggregator/db/dbstorage/sequence.go delete mode 100644 aggregator/db/migrations/0002.sql delete mode 100644 aggregator/db/migrations/0003.sql delete mode 100644 aggregator/db/migrations/0004.sql delete mode 100644 aggregator/mocks/mock_dbtx.go rename aggregator/mocks/{mock_state.go => mock_storage.go} (54%) create mode 100644 aggregator/mocks/mock_txer.go delete mode 100644 aggregator/profitabilitychecker.go delete mode 100644 state/config.go delete mode 100644 state/interfaces.go delete mode 100644 state/pgstatestorage/interfaces.go delete mode 100644 state/pgstatestorage/pgstatestorage.go delete mode 100644 state/pgstatestorage/proof.go delete mode 100644 state/pgstatestorage/sequence.go delete mode 100644 state/state.go diff --git a/aggregator/aggregator.go b/aggregator/aggregator.go index 0659180f..d9ecb1a2 100644 --- a/aggregator/aggregator.go +++ b/aggregator/aggregator.go @@ -17,6 +17,7 @@ import ( cdkTypes "github.com/0xPolygon/cdk-rpc/types" "github.com/0xPolygon/cdk/agglayer" + "github.com/0xPolygon/cdk/aggregator/db/dbstorage" ethmanTypes "github.com/0xPolygon/cdk/aggregator/ethmantypes" "github.com/0xPolygon/cdk/aggregator/prover" cdkcommon "github.com/0xPolygon/cdk/common" @@ -59,7 +60,7 @@ type Aggregator struct { cfg Config logger *log.Logger - state StateInterface + storage StorageInterface etherman Etherman ethTxManager EthTxManagerClient l1Syncr synchronizer.Synchronizer @@ -67,10 +68,9 @@ type Aggregator struct { accInputHashes map[uint64]common.Hash accInputHashesMutex *sync.Mutex - profitabilityChecker aggregatorTxProfitabilityChecker timeSendFinalProof time.Time timeCleanupLockedProofs types.Duration - stateDBMutex 
*sync.Mutex + storageMutex *sync.Mutex timeSendFinalProofMutex *sync.RWMutex finalProof chan finalProofMsg @@ -93,21 +93,7 @@ func New( ctx context.Context, cfg Config, logger *log.Logger, - stateInterface StateInterface, etherman Etherman) (*Aggregator, error) { - var profitabilityChecker aggregatorTxProfitabilityChecker - - switch cfg.TxProfitabilityCheckerType { - case ProfitabilityBase: - profitabilityChecker = NewTxProfitabilityCheckerBase( - stateInterface, cfg.IntervalAfterWhichBatchConsolidateAnyway.Duration, cfg.TxProfitabilityMinReward.Int, - ) - case ProfitabilityAcceptAll: - profitabilityChecker = NewTxProfitabilityCheckerAcceptAll( - stateInterface, cfg.IntervalAfterWhichBatchConsolidateAnyway.Duration, - ) - } - // Create ethtxmanager client cfg.EthTxManager.Log = ethtxlog.Config{ Environment: ethtxlog.LogEnvironment(cfg.Log.Environment), @@ -150,18 +136,22 @@ func New( } } + storage, err := dbstorage.NewDBStorage(cfg.DBPath) + if err != nil { + return nil, err + } + a := &Aggregator{ ctx: ctx, cfg: cfg, logger: logger, - state: stateInterface, + storage: storage, etherman: etherman, ethTxManager: ethTxManager, l1Syncr: l1Syncr, accInputHashes: make(map[uint64]common.Hash), accInputHashesMutex: &sync.Mutex{}, - profitabilityChecker: profitabilityChecker, - stateDBMutex: &sync.Mutex{}, + storageMutex: &sync.Mutex{}, timeSendFinalProofMutex: &sync.RWMutex{}, timeCleanupLockedProofs: cfg.CleanupLockedProofsInterval, finalProof: make(chan finalProofMsg), @@ -213,7 +203,7 @@ func (a *Aggregator) handleReorg(reorgData synchronizer.ReorgExecutionResult) { a.logger.Errorf("Error getting last virtual batch number: %v", err) } else { // Delete wip proofs - err = a.state.DeleteUngeneratedProofs(a.ctx, nil) + err = a.storage.DeleteUngeneratedProofs(a.ctx, nil) if err != nil { a.logger.Errorf("Error deleting ungenerated proofs: %v", err) } else { @@ -221,7 +211,7 @@ func (a *Aggregator) handleReorg(reorgData synchronizer.ReorgExecutionResult) { } // Delete any proof for the batches that have been rolled back - err = a.state.DeleteGeneratedProofs(a.ctx, lastVBatchNumber+1, maxDBBigIntValue, nil) + err = a.storage.DeleteGeneratedProofs(a.ctx, lastVBatchNumber+1, maxDBBigIntValue, nil) if err != nil { a.logger.Errorf("Error deleting generated proofs: %v", err) } else { @@ -275,7 +265,7 @@ func (a *Aggregator) handleRollbackBatches(rollbackData synchronizer.RollbackBat // Delete wip proofs if err == nil { - err = a.state.DeleteUngeneratedProofs(a.ctx, nil) + err = a.storage.DeleteUngeneratedProofs(a.ctx, nil) if err != nil { a.logger.Errorf("Error deleting ungenerated proofs: %v", err) } else { @@ -285,7 +275,7 @@ func (a *Aggregator) handleRollbackBatches(rollbackData synchronizer.RollbackBat // Delete any proof for the batches that have been rolled back if err == nil { - err = a.state.DeleteGeneratedProofs(a.ctx, rollbackData.LastBatchNumber+1, maxDBBigIntValue, nil) + err = a.storage.DeleteGeneratedProofs(a.ctx, rollbackData.LastBatchNumber+1, maxDBBigIntValue, nil) if err != nil { a.logger.Errorf("Error deleting generated proofs: %v", err) } else { @@ -336,7 +326,7 @@ func (a *Aggregator) Start() error { grpchealth.RegisterHealthServer(a.srv, healthService) // Delete ungenerated recursive proofs - err = a.state.DeleteUngeneratedProofs(a.ctx, nil) + err = a.storage.DeleteUngeneratedProofs(a.ctx, nil) if err != nil { return fmt.Errorf("failed to initialize proofs cache %w", err) } @@ -608,7 +598,7 @@ func (a *Aggregator) handleFailureToAddVerifyBatchToBeMonitored(ctx context.Cont "batches", 
fmt.Sprintf("%d-%d", proof.BatchNumber, proof.BatchNumberFinal), ) proof.GeneratingSince = nil - err := a.state.UpdateGeneratedProof(ctx, proof, nil) + err := a.storage.UpdateGeneratedProof(ctx, proof, nil) if err != nil { tmpLogger.Errorf("Failed updating proof state (false): %v", err) } @@ -703,7 +693,7 @@ func (a *Aggregator) tryBuildFinalProof(ctx context.Context, prover ProverInterf if err != nil { // Set the generating state to false for the proof ("unlock" it) proof.GeneratingSince = nil - err2 := a.state.UpdateGeneratedProof(a.ctx, proof, nil) + err2 := a.storage.UpdateGeneratedProof(a.ctx, proof, nil) if err2 != nil { tmpLogger.Errorf("Failed to unlock proof: %v", err2) } @@ -766,7 +756,7 @@ func (a *Aggregator) validateEligibleFinalProof( // We have a proof that contains batches below that the last batch verified, we need to delete this proof a.logger.Warnf("Proof %d-%d lower than next batch to verify %d. Deleting it", proof.BatchNumber, proof.BatchNumberFinal, batchNumberToVerify) - err := a.state.DeleteGeneratedProofs(ctx, proof.BatchNumber, proof.BatchNumberFinal, nil) + err := a.storage.DeleteGeneratedProofs(ctx, proof.BatchNumber, proof.BatchNumberFinal, nil) if err != nil { return false, fmt.Errorf("failed to delete discarded proof, err: %w", err) } @@ -779,7 +769,7 @@ func (a *Aggregator) validateEligibleFinalProof( } } - bComplete, err := a.state.CheckProofContainsCompleteSequences(ctx, proof, nil) + bComplete, err := a.storage.CheckProofContainsCompleteSequences(ctx, proof, nil) if err != nil { return false, fmt.Errorf("failed to check if proof contains complete sequences, %w", err) } @@ -795,11 +785,11 @@ func (a *Aggregator) validateEligibleFinalProof( func (a *Aggregator) getAndLockProofReadyToVerify( ctx context.Context, lastVerifiedBatchNum uint64, ) (*state.Proof, error) { - a.stateDBMutex.Lock() - defer a.stateDBMutex.Unlock() + a.storageMutex.Lock() + defer a.storageMutex.Unlock() // Get proof ready to be verified - proofToVerify, err := a.state.GetProofReadyToVerify(ctx, lastVerifiedBatchNum, nil) + proofToVerify, err := a.storage.GetProofReadyToVerify(ctx, lastVerifiedBatchNum, nil) if err != nil { return nil, err } @@ -807,7 +797,7 @@ func (a *Aggregator) getAndLockProofReadyToVerify( now := time.Now().Round(time.Microsecond) proofToVerify.GeneratingSince = &now - err = a.state.UpdateGeneratedProof(ctx, proofToVerify, nil) + err = a.storage.UpdateGeneratedProof(ctx, proofToVerify, nil) if err != nil { return nil, err } @@ -817,21 +807,21 @@ func (a *Aggregator) getAndLockProofReadyToVerify( func (a *Aggregator) unlockProofsToAggregate(ctx context.Context, proof1 *state.Proof, proof2 *state.Proof) error { // Release proofs from generating state in a single transaction - dbTx, err := a.state.BeginStateTransaction(ctx) + dbTx, err := a.storage.BeginTx(ctx, nil) if err != nil { a.logger.Warnf("Failed to begin transaction to release proof aggregation state, err: %v", err) return err } proof1.GeneratingSince = nil - err = a.state.UpdateGeneratedProof(ctx, proof1, dbTx) + err = a.storage.UpdateGeneratedProof(ctx, proof1, dbTx) if err == nil { proof2.GeneratingSince = nil - err = a.state.UpdateGeneratedProof(ctx, proof2, dbTx) + err = a.storage.UpdateGeneratedProof(ctx, proof2, dbTx) } if err != nil { - if err := dbTx.Rollback(ctx); err != nil { + if err := dbTx.Rollback(); err != nil { err := fmt.Errorf("failed to rollback proof aggregation state: %w", err) a.logger.Error(FirstToUpper(err.Error())) return err @@ -840,7 +830,7 @@ func (a *Aggregator) 
unlockProofsToAggregate(ctx context.Context, proof1 *state. return fmt.Errorf("failed to release proof aggregation state: %w", err) } - err = dbTx.Commit(ctx) + err = dbTx.Commit() if err != nil { return fmt.Errorf("failed to release proof aggregation state %w", err) } @@ -856,16 +846,16 @@ func (a *Aggregator) getAndLockProofsToAggregate( "proverAddr", prover.Addr(), ) - a.stateDBMutex.Lock() - defer a.stateDBMutex.Unlock() + a.storageMutex.Lock() + defer a.storageMutex.Unlock() - proof1, proof2, err := a.state.GetProofsToAggregate(ctx, nil) + proof1, proof2, err := a.storage.GetProofsToAggregate(ctx, nil) if err != nil { return nil, nil, err } // Set proofs in generating state in a single transaction - dbTx, err := a.state.BeginStateTransaction(ctx) + dbTx, err := a.storage.BeginTx(ctx, nil) if err != nil { tmpLogger.Errorf("Failed to begin transaction to set proof aggregation state, err: %v", err) return nil, nil, err @@ -873,14 +863,14 @@ func (a *Aggregator) getAndLockProofsToAggregate( now := time.Now().Round(time.Microsecond) proof1.GeneratingSince = &now - err = a.state.UpdateGeneratedProof(ctx, proof1, dbTx) + err = a.storage.UpdateGeneratedProof(ctx, proof1, dbTx) if err == nil { proof2.GeneratingSince = &now - err = a.state.UpdateGeneratedProof(ctx, proof2, dbTx) + err = a.storage.UpdateGeneratedProof(ctx, proof2, dbTx) } if err != nil { - if err := dbTx.Rollback(ctx); err != nil { + if err := dbTx.Rollback(); err != nil { err := fmt.Errorf("failed to rollback proof aggregation state %w", err) tmpLogger.Error(FirstToUpper(err.Error())) return nil, nil, err @@ -889,7 +879,7 @@ func (a *Aggregator) getAndLockProofsToAggregate( return nil, nil, fmt.Errorf("failed to set proof aggregation state %w", err) } - err = dbTx.Commit(ctx) + err = dbTx.Commit() if err != nil { return nil, nil, fmt.Errorf("failed to set proof aggregation state %w", err) } @@ -983,16 +973,16 @@ func (a *Aggregator) tryAggregateProofs(ctx context.Context, prover ProverInterf // update the state by removing the 2 aggregated proofs and storing the // newly generated recursive proof - dbTx, err := a.state.BeginStateTransaction(ctx) + dbTx, err := a.storage.BeginTx(ctx, nil) if err != nil { err = fmt.Errorf("failed to begin transaction to update proof aggregation state, %w", err) tmpLogger.Error(FirstToUpper(err.Error())) return false, err } - err = a.state.DeleteGeneratedProofs(ctx, proof1.BatchNumber, proof2.BatchNumberFinal, dbTx) + err = a.storage.DeleteGeneratedProofs(ctx, proof1.BatchNumber, proof2.BatchNumberFinal, dbTx) if err != nil { - if err := dbTx.Rollback(ctx); err != nil { + if err := dbTx.Rollback(); err != nil { err := fmt.Errorf("failed to rollback proof aggregation state, %w", err) tmpLogger.Error(FirstToUpper(err.Error())) return false, err @@ -1005,9 +995,9 @@ func (a *Aggregator) tryAggregateProofs(ctx context.Context, prover ProverInterf now := time.Now().Round(time.Microsecond) proof.GeneratingSince = &now - err = a.state.AddGeneratedProof(ctx, proof, dbTx) + err = a.storage.AddGeneratedProof(ctx, proof, dbTx) if err != nil { - if err := dbTx.Rollback(ctx); err != nil { + if err := dbTx.Rollback(); err != nil { err := fmt.Errorf("failed to rollback proof aggregation state, %w", err) tmpLogger.Error(FirstToUpper(err.Error())) return false, err @@ -1017,7 +1007,7 @@ func (a *Aggregator) tryAggregateProofs(ctx context.Context, prover ProverInterf return false, err } - err = dbTx.Commit(ctx) + err = dbTx.Commit() if err != nil { err = fmt.Errorf("failed to store the recursive proof, %w", err) 
tmpLogger.Error(FirstToUpper(err.Error())) @@ -1041,7 +1031,7 @@ func (a *Aggregator) tryAggregateProofs(ctx context.Context, prover ProverInterf proof.GeneratingSince = nil // final proof has not been generated, update the recursive proof - err := a.state.UpdateGeneratedProof(a.ctx, proof, nil) + err := a.storage.UpdateGeneratedProof(a.ctx, proof, nil) if err != nil { err = fmt.Errorf("failed to store batch proof result, %w", err) tmpLogger.Error(FirstToUpper(err.Error())) @@ -1073,8 +1063,8 @@ func (a *Aggregator) getAndLockBatchToProve( "proverAddr", prover.Addr(), ) - a.stateDBMutex.Lock() - defer a.stateDBMutex.Unlock() + a.storageMutex.Lock() + defer a.storageMutex.Unlock() // Get last virtual batch number from L1 lastVerifiedBatchNumber, err := a.etherman.GetLatestVerifiedBatchNum() @@ -1088,7 +1078,7 @@ func (a *Aggregator) getAndLockBatchToProve( // Look for the batch number to verify for proofExists { batchNumberToVerify++ - proofExists, err = a.state.CheckProofExistsForBatch(ctx, batchNumberToVerify, nil) + proofExists, err = a.storage.CheckProofExistsForBatch(ctx, batchNumberToVerify, nil) if err != nil { tmpLogger.Infof("Error checking proof exists for batch %d", batchNumberToVerify) @@ -1101,7 +1091,7 @@ func (a *Aggregator) getAndLockBatchToProve( tmpLogger.Warnf("AccInputHash for batch %d is not in memory, "+ "deleting proofs to regenerate acc input hash chain in memory", batchNumberToVerify) - err := a.state.CleanupGeneratedProofs(ctx, math.MaxInt, nil) + err := a.storage.CleanupGeneratedProofs(ctx, math.MaxInt, nil) if err != nil { tmpLogger.Infof("Error cleaning up generated proofs for batch %d", batchNumberToVerify) return nil, nil, nil, err @@ -1201,7 +1191,6 @@ func (a *Aggregator) getAndLockBatchToProve( a.logger.Debugf("Calculated acc input hash for batch %d: %v", batchNumberToVerify, accInputHash) a.logger.Debugf("OldAccInputHash: %v", oldAccInputHash) a.logger.Debugf("L1InfoRoot: %v", virtualBatch.L1InfoRoot) - // a.logger.Debugf("LastL2BLockTimestamp: %v", rpcBatch.LastL2BLockTimestamp()) a.logger.Debugf("TimestampLimit: %v", uint64(sequence.Timestamp.Unix())) a.logger.Debugf("LastCoinbase: %v", rpcBatch.LastCoinbase()) a.logger.Debugf("ForcedBlockHashL1: %v", rpcBatch.ForcedBlockHashL1()) @@ -1242,7 +1231,7 @@ func (a *Aggregator) getAndLockBatchToProve( a.logger.Debugf("Time to get witness for batch %d: %v", batchNumberToVerify, end.Sub(start)) // Store the sequence in aggregator DB - err = a.state.AddSequence(ctx, stateSequence, nil) + err = a.storage.AddSequence(ctx, stateSequence, nil) if err != nil { tmpLogger.Infof("Error storing sequence for batch %d", batchNumberToVerify) @@ -1250,25 +1239,9 @@ func (a *Aggregator) getAndLockBatchToProve( } // All the data required to generate a proof is ready - tmpLogger.Infof("Found virtual batch %d pending to generate proof", virtualBatch.BatchNumber) + tmpLogger.Infof("All information to generate proof for batch %d is ready", virtualBatch.BatchNumber) tmpLogger = tmpLogger.WithFields("batch", virtualBatch.BatchNumber) - tmpLogger.Info("Checking profitability to aggregate batch") - - // pass pol collateral as zero here, bcs in smart contract fee for aggregator is not defined yet - isProfitable, err := a.profitabilityChecker.IsProfitable(ctx, big.NewInt(0)) - if err != nil { - tmpLogger.Errorf("Failed to check aggregator profitability, err: %v", err) - - return nil, nil, nil, err - } - - if !isProfitable { - tmpLogger.Infof("Batch is not profitable, pol collateral %d", big.NewInt(0)) - - return nil, nil, nil, err - } - 
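// Sketch of the storage and transaction contract implied by the call sites in the
// hunks above, assuming database/sql semantics. The Txer and StorageInterface names
// follow the new mocks (mock_storage.go, mock_txer.go); the authoritative definitions
// live in aggregator/interfaces.go and aggregator/db/dbstorage and may differ in
// parameter types and detail.
package aggregator

import (
	"context"
	"database/sql"

	"github.com/0xPolygon/cdk/state"
)

// Txer is the minimal transaction handle the aggregator needs; *sql.Tx already
// satisfies it, which is why Commit/Rollback no longer take a context the way the
// old pgx-based BeginStateTransaction flow did.
type Txer interface {
	Commit() error
	Rollback() error
}

// StorageInterface collects the proof/sequence operations used above. A nil Txer
// means "run outside an explicit transaction".
type StorageInterface interface {
	BeginTx(ctx context.Context, options *sql.TxOptions) (Txer, error)
	CheckProofExistsForBatch(ctx context.Context, batchNumber uint64, dbTx Txer) (bool, error)
	CheckProofContainsCompleteSequences(ctx context.Context, proof *state.Proof, dbTx Txer) (bool, error)
	GetProofReadyToVerify(ctx context.Context, lastVerifiedBatchNumber uint64, dbTx Txer) (*state.Proof, error)
	GetProofsToAggregate(ctx context.Context, dbTx Txer) (*state.Proof, *state.Proof, error)
	AddGeneratedProof(ctx context.Context, proof *state.Proof, dbTx Txer) error
	UpdateGeneratedProof(ctx context.Context, proof *state.Proof, dbTx Txer) error
	DeleteGeneratedProofs(ctx context.Context, batchNumber uint64, batchNumberFinal uint64, dbTx Txer) error
	DeleteUngeneratedProofs(ctx context.Context, dbTx Txer) error
	CleanupGeneratedProofs(ctx context.Context, batchNumber uint64, dbTx Txer) error
	CleanupLockedProofs(ctx context.Context, duration string, dbTx Txer) (int64, error)
	AddSequence(ctx context.Context, sequence state.Sequence, dbTx Txer) error
}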
now := time.Now().Round(time.Microsecond) proof := &state.Proof{ BatchNumber: virtualBatch.BatchNumber, @@ -1279,9 +1252,9 @@ func (a *Aggregator) getAndLockBatchToProve( } // Avoid other prover to process the same batch - err = a.state.AddGeneratedProof(ctx, proof, nil) + err = a.storage.AddGeneratedProof(ctx, proof, nil) if err != nil { - tmpLogger.Errorf("Failed to add batch proof, err: %v", err) + tmpLogger.Errorf("Failed to add batch proof to DB for batch %d, err: %v", virtualBatch.BatchNumber, err) return nil, nil, nil, err } @@ -1317,7 +1290,7 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover ProverInt defer func() { if err != nil { tmpLogger.Debug("Deleting proof in progress") - err2 := a.state.DeleteGeneratedProofs(a.ctx, proof.BatchNumber, proof.BatchNumberFinal, nil) + err2 := a.storage.DeleteGeneratedProofs(a.ctx, proof.BatchNumber, proof.BatchNumberFinal, nil) if err2 != nil { tmpLogger.Errorf("Failed to delete proof in progress, err: %v", err2) } @@ -1377,7 +1350,7 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover ProverInt proof.GeneratingSince = nil // final proof has not been generated, update the batch proof - err := a.state.UpdateGeneratedProof(a.ctx, proof, nil) + err := a.storage.UpdateGeneratedProof(a.ctx, proof, nil) if err != nil { err = fmt.Errorf("failed to store batch proof result, %w", err) tmpLogger.Error(FirstToUpper(err.Error())) @@ -1583,7 +1556,6 @@ func printInputProver(logger *log.Logger, inputProver *prover.StatelessInputProv logger.Debugf("Witness length: %v", len(inputProver.PublicInputs.Witness)) logger.Debugf("BatchL2Data length: %v", len(inputProver.PublicInputs.BatchL2Data)) - // logger.Debugf("Full DataStream: %v", common.Bytes2Hex(inputProver.PublicInputs.DataStream)) logger.Debugf("OldAccInputHash: %v", common.BytesToHash(inputProver.PublicInputs.OldAccInputHash)) logger.Debugf("L1InfoRoot: %v", common.BytesToHash(inputProver.PublicInputs.L1InfoRoot)) logger.Debugf("TimestampLimit: %v", inputProver.PublicInputs.TimestampLimit) @@ -1652,7 +1624,7 @@ func (a *Aggregator) handleMonitoredTxResult(result ethtxtypes.MonitoredTxResult } } - err = a.state.DeleteGeneratedProofs(a.ctx, firstBatch, lastBatch, nil) + err = a.storage.DeleteGeneratedProofs(a.ctx, firstBatch, lastBatch, nil) if err != nil { mTxResultLogger.Errorf("failed to delete generated proofs from %d to %d: %v", firstBatch, lastBatch, err) } @@ -1670,7 +1642,7 @@ func (a *Aggregator) cleanupLockedProofs() { case <-a.ctx.Done(): return case <-time.After(a.timeCleanupLockedProofs.Duration): - n, err := a.state.CleanupLockedProofs(a.ctx, a.cfg.GeneratingProofCleanupThreshold, nil) + n, err := a.storage.CleanupLockedProofs(a.ctx, a.cfg.GeneratingProofCleanupThreshold, nil) if err != nil { a.logger.Errorf("Failed to cleanup locked proofs: %v", err) } diff --git a/aggregator/aggregator_test.go b/aggregator/aggregator_test.go index 8d5b5392..95b55367 100644 --- a/aggregator/aggregator_test.go +++ b/aggregator/aggregator_test.go @@ -6,6 +6,7 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "database/sql" "encoding/hex" "encoding/json" "errors" @@ -50,13 +51,14 @@ const ( ) type mox struct { - stateMock *mocks.StateInterfaceMock + storageMock *mocks.StorageInterfaceMock ethTxManager *mocks.EthTxManagerClientMock etherman *mocks.EthermanMock proverMock *mocks.ProverInterfaceMock aggLayerClientMock *agglayer.AgglayerClientMock synchronizerMock *mocks.SynchronizerInterfaceMock rpcMock *mocks.RPCInterfaceMock + txerMock *mocks.TxerMock } func 
WaitUntil(t *testing.T, wg *sync.WaitGroup, timeout time.Duration) { @@ -76,7 +78,7 @@ func WaitUntil(t *testing.T, wg *sync.WaitGroup, timeout time.Duration) { } func Test_Start(t *testing.T) { - mockState := new(mocks.StateInterfaceMock) + mockStorage := new(mocks.StorageInterfaceMock) mockL1Syncr := new(mocks.SynchronizerInterfaceMock) mockEtherman := new(mocks.EthermanMock) mockEthTxManager := new(mocks.EthTxManagerClientMock) @@ -84,21 +86,21 @@ func Test_Start(t *testing.T) { mockL1Syncr.On("Sync", mock.Anything).Return(nil) mockEtherman.On("GetLatestVerifiedBatchNum").Return(uint64(90), nil).Once() mockEtherman.On("GetBatchAccInputHash", mock.Anything, uint64(90)).Return(common.Hash{}, nil).Once() - mockState.On("DeleteUngeneratedProofs", mock.Anything, nil).Return(nil).Once() - mockState.On("CleanupLockedProofs", mock.Anything, "", nil).Return(int64(0), nil) + mockStorage.On("DeleteUngeneratedProofs", mock.Anything, nil).Return(nil).Once() + mockStorage.On("CleanupLockedProofs", mock.Anything, "", nil).Return(int64(0), nil) mockEthTxManager.On("Start").Return(nil) ctx := context.Background() a := &Aggregator{ - state: mockState, + storage: mockStorage, logger: log.GetDefaultLogger(), halted: atomic.Bool{}, l1Syncr: mockL1Syncr, etherman: mockEtherman, ethTxManager: mockEthTxManager, ctx: ctx, - stateDBMutex: &sync.Mutex{}, + storageMutex: &sync.Mutex{}, timeSendFinalProofMutex: &sync.RWMutex{}, timeCleanupLockedProofs: types.Duration{Duration: 5 * time.Second}, accInputHashes: make(map[uint64]common.Hash), @@ -117,26 +119,26 @@ func Test_handleReorg(t *testing.T) { t.Parallel() mockL1Syncr := new(mocks.SynchronizerInterfaceMock) - mockState := new(mocks.StateInterfaceMock) + mockStorage := new(mocks.StorageInterfaceMock) reorgData := synchronizer.ReorgExecutionResult{} a := &Aggregator{ l1Syncr: mockL1Syncr, - state: mockState, + storage: mockStorage, logger: log.GetDefaultLogger(), halted: atomic.Bool{}, ctx: context.Background(), } mockL1Syncr.On("GetLastestVirtualBatchNumber", mock.Anything).Return(uint64(100), nil).Once() - mockState.On("DeleteGeneratedProofs", mock.Anything, mock.Anything, mock.Anything, nil).Return(nil).Once() - mockState.On("DeleteUngeneratedProofs", mock.Anything, nil).Return(nil).Once() + mockStorage.On("DeleteGeneratedProofs", mock.Anything, mock.Anything, mock.Anything, nil).Return(nil).Once() + mockStorage.On("DeleteUngeneratedProofs", mock.Anything, nil).Return(nil).Once() go a.handleReorg(reorgData) time.Sleep(3 * time.Second) assert.True(t, a.halted.Load()) - mockState.AssertExpectations(t) + mockStorage.AssertExpectations(t) mockL1Syncr.AssertExpectations(t) } @@ -144,7 +146,7 @@ func Test_handleRollbackBatches(t *testing.T) { t.Parallel() mockEtherman := new(mocks.EthermanMock) - mockState := new(mocks.StateInterfaceMock) + mockStorage := new(mocks.StorageInterfaceMock) // Test data rollbackData := synchronizer.RollbackBatchesData{ @@ -153,13 +155,13 @@ func Test_handleRollbackBatches(t *testing.T) { mockEtherman.On("GetLatestVerifiedBatchNum").Return(uint64(90), nil).Once() mockEtherman.On("GetBatchAccInputHash", mock.Anything, uint64(90)).Return(common.Hash{}, nil).Once() - mockState.On("DeleteUngeneratedProofs", mock.Anything, mock.Anything).Return(nil).Once() - mockState.On("DeleteGeneratedProofs", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + mockStorage.On("DeleteUngeneratedProofs", mock.Anything, mock.Anything).Return(nil).Once() + mockStorage.On("DeleteGeneratedProofs", mock.Anything, mock.Anything, 
mock.Anything, mock.Anything).Return(nil).Once() a := Aggregator{ ctx: context.Background(), etherman: mockEtherman, - state: mockState, + storage: mockStorage, logger: log.GetDefaultLogger(), halted: atomic.Bool{}, accInputHashes: make(map[uint64]common.Hash), @@ -171,18 +173,18 @@ func Test_handleRollbackBatches(t *testing.T) { assert.False(t, a.halted.Load()) mockEtherman.AssertExpectations(t) - mockState.AssertExpectations(t) + mockStorage.AssertExpectations(t) } func Test_handleRollbackBatchesHalt(t *testing.T) { t.Parallel() mockEtherman := new(mocks.EthermanMock) - mockState := new(mocks.StateInterfaceMock) + mockStorage := new(mocks.StorageInterfaceMock) mockEtherman.On("GetLatestVerifiedBatchNum").Return(uint64(110), nil).Once() - mockState.On("DeleteUngeneratedProofs", mock.Anything, mock.Anything).Return(nil).Once() - mockState.On("DeleteGeneratedProofs", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + mockStorage.On("DeleteUngeneratedProofs", mock.Anything, mock.Anything).Return(nil).Once() + mockStorage.On("DeleteGeneratedProofs", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() // Test data rollbackData := synchronizer.RollbackBatchesData{ @@ -192,7 +194,7 @@ func Test_handleRollbackBatchesHalt(t *testing.T) { a := Aggregator{ ctx: context.Background(), etherman: mockEtherman, - state: mockState, + storage: mockStorage, logger: log.GetDefaultLogger(), halted: atomic.Bool{}, accInputHashes: make(map[uint64]common.Hash), @@ -211,7 +213,7 @@ func Test_handleRollbackBatchesError(t *testing.T) { t.Parallel() mockEtherman := new(mocks.EthermanMock) - mockState := new(mocks.StateInterfaceMock) + mockStorage := new(mocks.StorageInterfaceMock) mockEtherman.On("GetLatestVerifiedBatchNum").Return(uint64(110), fmt.Errorf("error")).Once() @@ -223,7 +225,7 @@ func Test_handleRollbackBatchesError(t *testing.T) { a := Aggregator{ ctx: context.Background(), etherman: mockEtherman, - state: mockState, + storage: mockStorage, logger: log.GetDefaultLogger(), halted: atomic.Bool{}, accInputHashes: make(map[uint64]common.Hash), @@ -308,7 +310,7 @@ func Test_sendFinalProofSuccess(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - stateMock := mocks.NewStateInterfaceMock(t) + storageMock := mocks.NewStorageInterfaceMock(t) ethTxManager := mocks.NewEthTxManagerClientMock(t) etherman := mocks.NewEthermanMock(t) aggLayerClient := agglayer.NewAgglayerClientMock(t) @@ -319,14 +321,14 @@ func Test_sendFinalProofSuccess(t *testing.T) { require.NoError(err, "error generating key") a := Aggregator{ - state: stateMock, + storage: storageMock, etherman: etherman, ethTxManager: ethTxManager, aggLayerClient: aggLayerClient, finalProof: make(chan finalProofMsg), logger: log.GetDefaultLogger(), verifyingProof: false, - stateDBMutex: &sync.Mutex{}, + storageMutex: &sync.Mutex{}, timeSendFinalProofMutex: &sync.RWMutex{}, sequencerPrivateKey: privateKey, rpcClient: rpcMock, @@ -336,7 +338,7 @@ func Test_sendFinalProofSuccess(t *testing.T) { a.ctx, a.exit = context.WithCancel(context.Background()) m := mox{ - stateMock: stateMock, + storageMock: storageMock, ethTxManager: ethTxManager, etherman: etherman, aggLayerClientMock: aggLayerClient, @@ -418,7 +420,7 @@ func Test_sendFinalProofError(t *testing.T) { fmt.Println("Stopping sendFinalProof") a.exit() }).Return(nil, errTest).Once() - m.stateMock.On("UpdateGeneratedProof", mock.Anything, mock.Anything, nil).Return(nil).Once() + 
m.storageMock.On("UpdateGeneratedProof", mock.Anything, mock.Anything, nil).Return(nil).Once() }, asserts: func(a *Aggregator) { assert.False(a.verifyingProof) @@ -442,7 +444,7 @@ func Test_sendFinalProofError(t *testing.T) { fmt.Println("Stopping sendFinalProof") a.exit() }).Return(errTest) - m.stateMock.On("UpdateGeneratedProof", mock.Anything, mock.Anything, nil).Return(nil).Once() + m.storageMock.On("UpdateGeneratedProof", mock.Anything, mock.Anything, nil).Return(nil).Once() }, asserts: func(a *Aggregator) { assert.False(a.verifyingProof) @@ -464,7 +466,7 @@ func Test_sendFinalProofError(t *testing.T) { fmt.Println("Stopping sendFinalProof") a.exit() }).Return(nil, nil, errTest) - m.stateMock.On("UpdateGeneratedProof", mock.Anything, recursiveProof, nil).Return(nil).Once() + m.storageMock.On("UpdateGeneratedProof", mock.Anything, recursiveProof, nil).Return(nil).Once() }, asserts: func(a *Aggregator) { assert.False(a.verifyingProof) @@ -488,7 +490,7 @@ func Test_sendFinalProofError(t *testing.T) { fmt.Println("Stopping sendFinalProof") a.exit() }).Return(nil, errTest).Once() - m.stateMock.On("UpdateGeneratedProof", mock.Anything, recursiveProof, nil).Return(nil).Once() + m.storageMock.On("UpdateGeneratedProof", mock.Anything, recursiveProof, nil).Return(nil).Once() }, asserts: func(a *Aggregator) { assert.False(a.verifyingProof) @@ -499,7 +501,7 @@ func Test_sendFinalProofError(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - stateMock := mocks.NewStateInterfaceMock(t) + storageMock := mocks.NewStorageInterfaceMock(t) ethTxManager := mocks.NewEthTxManagerClientMock(t) etherman := mocks.NewEthermanMock(t) aggLayerClient := agglayer.NewAgglayerClientMock(t) @@ -510,14 +512,14 @@ func Test_sendFinalProofError(t *testing.T) { require.NoError(err, "error generating key") a := Aggregator{ - state: stateMock, + storage: storageMock, etherman: etherman, ethTxManager: ethTxManager, aggLayerClient: aggLayerClient, finalProof: make(chan finalProofMsg), logger: log.GetDefaultLogger(), verifyingProof: false, - stateDBMutex: &sync.Mutex{}, + storageMutex: &sync.Mutex{}, timeSendFinalProofMutex: &sync.RWMutex{}, sequencerPrivateKey: privateKey, rpcClient: rpcMock, @@ -527,7 +529,7 @@ func Test_sendFinalProofError(t *testing.T) { a.ctx, a.exit = context.WithCancel(context.Background()) m := mox{ - stateMock: stateMock, + storageMock: storageMock, ethTxManager: ethTxManager, etherman: etherman, aggLayerClientMock: aggLayerClient, @@ -626,16 +628,16 @@ func Test_buildFinalProof(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { proverMock := mocks.NewProverInterfaceMock(t) - stateMock := mocks.NewStateInterfaceMock(t) + storageMock := mocks.NewStorageInterfaceMock(t) rpcMock := mocks.NewRPCInterfaceMock(t) m := mox{ - proverMock: proverMock, - stateMock: stateMock, - rpcMock: rpcMock, + proverMock: proverMock, + storageMock: storageMock, + rpcMock: rpcMock, } a := Aggregator{ - state: stateMock, - logger: log.GetDefaultLogger(), + storage: storageMock, + logger: log.GetDefaultLogger(), cfg: Config{ SenderAddress: common.BytesToAddress([]byte("from")).Hex(), }, @@ -656,9 +658,8 @@ func Test_tryBuildFinalProof(t *testing.T) { errTest := errors.New("test error") from := common.BytesToAddress([]byte("from")) cfg := Config{ - VerifyProofInterval: types.Duration{Duration: time.Millisecond * 1}, - TxProfitabilityCheckerType: ProfitabilityAcceptAll, - SenderAddress: from.Hex(), + VerifyProofInterval: types.Duration{Duration: time.Millisecond * 1}, + 
SenderAddress: from.Hex(), } latestVerifiedBatchNum := uint64(22) batchNum := uint64(23) @@ -728,10 +729,10 @@ func Test_tryBuildFinalProof(t *testing.T) { m.proverMock.On("ID").Return(proverID).Twice() m.proverMock.On("Addr").Return("addr").Twice() m.etherman.On("GetLatestVerifiedBatchNum").Return(latestVerifiedBatchNum, nil).Once() - m.stateMock.On("GetProofReadyToVerify", mock.MatchedBy(matchProverCtxFn), latestVerifiedBatchNum, nil).Return(&proofToVerify, nil).Once() - proofGeneratingTrueCall := m.stateMock.On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proofToVerify, nil).Return(nil).Once() + m.storageMock.On("GetProofReadyToVerify", mock.MatchedBy(matchProverCtxFn), latestVerifiedBatchNum, nil).Return(&proofToVerify, nil).Once() + proofGeneratingTrueCall := m.storageMock.On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proofToVerify, nil).Return(nil).Once() m.proverMock.On("FinalProof", proofToVerify.Proof, from.String()).Return(nil, errTest).Once() - m.stateMock. + m.storageMock. On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proofToVerify, nil). Return(nil). Once(). @@ -749,11 +750,11 @@ func Test_tryBuildFinalProof(t *testing.T) { m.proverMock.On("ID").Return(proverID).Twice() m.proverMock.On("Addr").Return("addr").Twice() m.etherman.On("GetLatestVerifiedBatchNum").Return(latestVerifiedBatchNum, nil).Once() - m.stateMock.On("GetProofReadyToVerify", mock.MatchedBy(matchProverCtxFn), latestVerifiedBatchNum, nil).Return(&proofToVerify, nil).Once() - proofGeneratingTrueCall := m.stateMock.On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proofToVerify, nil).Return(nil).Once() + m.storageMock.On("GetProofReadyToVerify", mock.MatchedBy(matchProverCtxFn), latestVerifiedBatchNum, nil).Return(&proofToVerify, nil).Once() + proofGeneratingTrueCall := m.storageMock.On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proofToVerify, nil).Return(nil).Once() m.proverMock.On("FinalProof", proofToVerify.Proof, from.String()).Return(&finalProofID, nil).Once() m.proverMock.On("WaitFinalProof", mock.MatchedBy(matchProverCtxFn), finalProofID).Return(nil, errTest).Once() - m.stateMock. + m.storageMock. On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proofToVerify, nil). Return(nil). Once(). 
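// Sketch of how the updated tests stub the database/sql transaction flow once the
// pgx DbTxMock is gone: StorageInterfaceMock returns a TxerMock from BeginTx, and
// Commit/Rollback are expected without a context argument. The helper name below is
// hypothetical; the Test_tryAggregateProofs cases further down set these expectations
// inline on the mox fields shown earlier (storageMock, txerMock).
func setupTxStubs(t *testing.T, m mox) {
	t.Helper()

	// The aggregator calls BeginTx with a nil *sql.TxOptions.
	m.storageMock.On("BeginTx", mock.Anything, (*sql.TxOptions)(nil)).Return(m.txerMock, nil)

	// TxerMock replaces the old DbTxMock: Commit/Rollback take no arguments and may
	// or may not be reached depending on the scenario under test.
	m.txerMock.On("Commit").Return(nil).Maybe()
	m.txerMock.On("Rollback").Return(nil).Maybe()
}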
@@ -771,7 +772,7 @@ func Test_tryBuildFinalProof(t *testing.T) { m.proverMock.On("ID").Return(proverID).Once() m.proverMock.On("Addr").Return(proverID).Once() m.etherman.On("GetLatestVerifiedBatchNum").Return(latestVerifiedBatchNum, nil).Once() - m.stateMock.On("GetProofReadyToVerify", mock.MatchedBy(matchProverCtxFn), latestVerifiedBatchNum, nil).Return(nil, errTest).Once() + m.storageMock.On("GetProofReadyToVerify", mock.MatchedBy(matchProverCtxFn), latestVerifiedBatchNum, nil).Return(nil, errTest).Once() }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) @@ -785,7 +786,7 @@ func Test_tryBuildFinalProof(t *testing.T) { m.proverMock.On("ID").Return(proverID).Once() m.proverMock.On("Addr").Return(proverID).Once() m.etherman.On("GetLatestVerifiedBatchNum").Return(latestVerifiedBatchNum, nil).Once() - m.stateMock.On("GetProofReadyToVerify", mock.MatchedBy(matchProverCtxFn), latestVerifiedBatchNum, nil).Return(nil, state.ErrNotFound).Once() + m.storageMock.On("GetProofReadyToVerify", mock.MatchedBy(matchProverCtxFn), latestVerifiedBatchNum, nil).Return(nil, state.ErrNotFound).Once() }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) @@ -799,8 +800,8 @@ func Test_tryBuildFinalProof(t *testing.T) { m.proverMock.On("ID").Return(proverID).Twice() m.proverMock.On("Addr").Return(proverID).Twice() m.etherman.On("GetLatestVerifiedBatchNum").Return(latestVerifiedBatchNum, nil).Once() - m.stateMock.On("GetProofReadyToVerify", mock.MatchedBy(matchProverCtxFn), latestVerifiedBatchNum, nil).Return(&proofToVerify, nil).Once() - m.stateMock.On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proofToVerify, nil).Return(nil).Once() + m.storageMock.On("GetProofReadyToVerify", mock.MatchedBy(matchProverCtxFn), latestVerifiedBatchNum, nil).Return(&proofToVerify, nil).Once() + m.storageMock.On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proofToVerify, nil).Return(nil).Once() m.proverMock.On("FinalProof", proofToVerify.Proof, from.String()).Return(&finalProofID, nil).Once() m.proverMock.On("WaitFinalProof", mock.MatchedBy(matchProverCtxFn), finalProofID).Return(&finalProof, nil).Once() }, @@ -822,7 +823,7 @@ func Test_tryBuildFinalProof(t *testing.T) { m.proverMock.On("ID").Return(proverID).Once() m.proverMock.On("Addr").Return(proverID).Once() m.etherman.On("GetLatestVerifiedBatchNum").Return(latestVerifiedBatchNum, nil).Once() - m.stateMock.On("CheckProofContainsCompleteSequences", mock.MatchedBy(matchProverCtxFn), &proofToVerify, nil).Return(false, errTest).Once() + m.storageMock.On("CheckProofContainsCompleteSequences", mock.MatchedBy(matchProverCtxFn), &proofToVerify, nil).Return(false, errTest).Once() }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) @@ -851,7 +852,7 @@ func Test_tryBuildFinalProof(t *testing.T) { m.proverMock.On("ID").Return(proverID).Once() m.proverMock.On("Addr").Return(proverID).Once() m.etherman.On("GetLatestVerifiedBatchNum").Return(latestVerifiedBatchNum, nil).Once() - m.stateMock.On("CheckProofContainsCompleteSequences", mock.MatchedBy(matchProverCtxFn), &proofToVerify, nil).Return(false, nil).Once() + m.storageMock.On("CheckProofContainsCompleteSequences", mock.MatchedBy(matchProverCtxFn), &proofToVerify, nil).Return(false, nil).Once() }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) @@ -866,7 +867,7 @@ func Test_tryBuildFinalProof(t *testing.T) { m.proverMock.On("ID").Return(proverID).Twice() m.proverMock.On("Addr").Return(proverID).Twice() 
m.etherman.On("GetLatestVerifiedBatchNum").Return(latestVerifiedBatchNum, nil).Once() - m.stateMock.On("CheckProofContainsCompleteSequences", mock.MatchedBy(matchProverCtxFn), &proofToVerify, nil).Return(true, nil).Once() + m.storageMock.On("CheckProofContainsCompleteSequences", mock.MatchedBy(matchProverCtxFn), &proofToVerify, nil).Return(true, nil).Once() m.proverMock.On("FinalProof", proofToVerify.Proof, from.String()).Return(&finalProofID, nil).Once() m.proverMock.On("WaitFinalProof", mock.MatchedBy(matchProverCtxFn), finalProofID).Return(&finalProof, nil).Once() }, @@ -885,18 +886,18 @@ func Test_tryBuildFinalProof(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - stateMock := mocks.NewStateInterfaceMock(t) + storageMock := mocks.NewStorageInterfaceMock(t) ethTxManager := mocks.NewEthTxManagerClientMock(t) etherman := mocks.NewEthermanMock(t) proverMock := mocks.NewProverInterfaceMock(t) a := Aggregator{ cfg: cfg, - state: stateMock, + storage: storageMock, etherman: etherman, ethTxManager: ethTxManager, logger: log.GetDefaultLogger(), - stateDBMutex: &sync.Mutex{}, + storageMutex: &sync.Mutex{}, timeSendFinalProofMutex: &sync.RWMutex{}, timeCleanupLockedProofs: cfg.CleanupLockedProofsInterval, finalProof: make(chan finalProofMsg), @@ -907,7 +908,7 @@ func Test_tryBuildFinalProof(t *testing.T) { aggregatorCtx := context.WithValue(context.Background(), "owner", ownerAggregator) //nolint:staticcheck a.ctx, a.exit = context.WithCancel(aggregatorCtx) m := mox{ - stateMock: stateMock, + storageMock: storageMock, ethTxManager: ethTxManager, etherman: etherman, proverMock: proverMock, @@ -974,7 +975,7 @@ func Test_tryAggregateProofs(t *testing.T) { m.proverMock.On("Name").Return(proverName).Twice() m.proverMock.On("ID").Return(proverID).Twice() m.proverMock.On("Addr").Return("addr") - m.stateMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(nil, nil, errTest).Once() + m.storageMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(nil, nil, errTest).Once() }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) @@ -987,7 +988,7 @@ func Test_tryAggregateProofs(t *testing.T) { m.proverMock.On("Name").Return(proverName).Twice() m.proverMock.On("ID").Return(proverID).Twice() m.proverMock.On("Addr").Return("addr") - m.stateMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(nil, nil, state.ErrNotFound).Once() + m.storageMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(nil, nil, state.ErrNotFound).Once() }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) @@ -1000,12 +1001,12 @@ func Test_tryAggregateProofs(t *testing.T) { m.proverMock.On("Name").Return(proverName).Twice() m.proverMock.On("ID").Return(proverID).Twice() m.proverMock.On("Addr").Return("addr") - dbTx := &mocks.DbTxMock{} - dbTx.On("Rollback", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() - m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchProverCtxFn)).Return(dbTx, nil).Once() - m.stateMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, dbTx). 
+ m.txerMock.On("Rollback").Return(nil).Once() + m.storageMock.On("BeginTx", mock.Anything, (*sql.TxOptions)(nil)).Return(m.txerMock, nil).Once() + m.storageMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() + + m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1019,18 +1020,19 @@ func Test_tryAggregateProofs(t *testing.T) { assert.ErrorIs(err, errTest) }, }, + { name: "AggregatedProof error", setup: func(m mox, a *Aggregator) { m.proverMock.On("Name").Return(proverName).Twice() m.proverMock.On("ID").Return(proverID).Twice() m.proverMock.On("Addr").Return("addr") - dbTx := &mocks.DbTxMock{} - lockProofsTxBegin := m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchProverCtxFn)).Return(dbTx, nil).Once() - lockProofsTxCommit := dbTx.On("Commit", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() - m.stateMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() - proof1GeneratingTrueCall := m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, dbTx). + + lockProofsTxBegin := m.storageMock.On("BeginTx", mock.MatchedBy(matchProverCtxFn), (*sql.TxOptions)(nil)).Return(m.txerMock, nil).Once() + // lockProofsTxCommit := m.proverMock.On("Commit").Return(nil).Once() + m.storageMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() + proof1GeneratingTrueCall := m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1038,8 +1040,8 @@ func Test_tryAggregateProofs(t *testing.T) { }). Return(nil). Once() - proof2GeneratingTrueCall := m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof2, dbTx). + proof2GeneratingTrueCall := m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof2, mock.Anything). Run(func(args mock.Arguments) { // Use a type assertion with a check proofArg, ok := args[1].(*state.Proof) @@ -1051,9 +1053,9 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once() m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(nil, errTest).Once() - m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchAggregatorCtxFn)).Return(dbTx, nil).Once().NotBefore(lockProofsTxBegin) - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, dbTx). + m.storageMock.On("BeginTx", mock.Anything, (*sql.TxOptions)(nil)).Return(m.txerMock, nil).Once().NotBefore(lockProofsTxBegin) + m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) if !ok { @@ -1064,8 +1066,8 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once(). NotBefore(proof1GeneratingTrueCall) - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof2, dbTx). + m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof2, mock.Anything). 
Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) if !ok { @@ -1076,25 +1078,25 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once(). NotBefore(proof2GeneratingTrueCall) - dbTx.On("Commit", mock.MatchedBy(matchAggregatorCtxFn)).Return(nil).Once().NotBefore(lockProofsTxCommit) + m.txerMock.On("Commit").Return(nil) }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) assert.ErrorIs(err, errTest) }, }, + { name: "WaitRecursiveProof prover error", setup: func(m mox, a *Aggregator) { m.proverMock.On("Name").Return(proverName).Twice() m.proverMock.On("ID").Return(proverID).Twice() m.proverMock.On("Addr").Return("addr") - dbTx := &mocks.DbTxMock{} - lockProofsTxBegin := m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchProverCtxFn)).Return(dbTx, nil).Once() - lockProofsTxCommit := dbTx.On("Commit", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() - m.stateMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() - proof1GeneratingTrueCall := m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, dbTx). + lockProofsTxBegin := m.storageMock.On("BeginTx", mock.MatchedBy(matchProverCtxFn), (*sql.TxOptions)(nil)).Return(m.txerMock, nil).Once() + // lockProofsTxCommit := dbTx.On("Commit", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() + m.storageMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() + proof1GeneratingTrueCall := m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) if !ok { @@ -1104,8 +1106,8 @@ func Test_tryAggregateProofs(t *testing.T) { }). Return(nil). Once() - proof2GeneratingTrueCall := m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof2, dbTx). + proof2GeneratingTrueCall := m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof2, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1114,10 +1116,11 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once() m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once() + m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, errTest).Once() - m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchAggregatorCtxFn)).Return(dbTx, nil).Once().NotBefore(lockProofsTxBegin) - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, dbTx). + m.storageMock.On("BeginTx", mock.MatchedBy(matchAggregatorCtxFn), (*sql.TxOptions)(nil)).Return(m.txerMock, nil).Once().NotBefore(lockProofsTxBegin) + m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1126,8 +1129,8 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once(). NotBefore(proof1GeneratingTrueCall) - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof2, dbTx). + m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof2, mock.Anything). 
Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1136,25 +1139,25 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once(). NotBefore(proof2GeneratingTrueCall) - dbTx.On("Commit", mock.MatchedBy(matchAggregatorCtxFn)).Return(nil).Once().NotBefore(lockProofsTxCommit) + m.txerMock.On("Commit").Return(nil) }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) assert.ErrorIs(err, errTest) }, }, + { name: "unlockProofsToAggregate error after WaitRecursiveProof prover error", setup: func(m mox, a *Aggregator) { m.proverMock.On("Name").Return(proverName).Twice() m.proverMock.On("ID").Return(proverID).Twice() m.proverMock.On("Addr").Return(proverID) - dbTx := &mocks.DbTxMock{} - lockProofsTxBegin := m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchProverCtxFn)).Return(dbTx, nil).Once() - dbTx.On("Commit", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() - m.stateMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() - proof1GeneratingTrueCall := m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, dbTx). + // lockProofsTxBegin := m.storageMock.On("BeginTx", mock.MatchedBy(matchProverCtxFn)).Return(m.txerMock, nil).Once() + m.txerMock.On("Commit").Return(nil) + m.storageMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() + proof1GeneratingTrueCall := m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1162,8 +1165,8 @@ func Test_tryAggregateProofs(t *testing.T) { }). Return(nil). Once() - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof2, dbTx). + m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof2, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1173,9 +1176,9 @@ func Test_tryAggregateProofs(t *testing.T) { Once() m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once() m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, errTest).Once() - m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchAggregatorCtxFn)).Return(dbTx, nil).Once().NotBefore(lockProofsTxBegin) - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, dbTx). + m.storageMock.On("BeginTx", mock.Anything, (*sql.TxOptions)(nil)).Return(m.txerMock, nil) + m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1184,7 +1187,7 @@ func Test_tryAggregateProofs(t *testing.T) { Return(errTest). Once(). 
NotBefore(proof1GeneratingTrueCall) - dbTx.On("Rollback", mock.MatchedBy(matchAggregatorCtxFn)).Return(nil).Once() + m.txerMock.On("Rollback").Return(nil).Once() }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) @@ -1197,12 +1200,11 @@ func Test_tryAggregateProofs(t *testing.T) { m.proverMock.On("Name").Return(proverName).Twice() m.proverMock.On("ID").Return(proverID).Twice() m.proverMock.On("Addr").Return("addr") - dbTx := &mocks.DbTxMock{} - lockProofsTxBegin := m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchProverCtxFn)).Return(dbTx, nil).Twice() - lockProofsTxCommit := dbTx.On("Commit", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() - m.stateMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() - proof1GeneratingTrueCall := m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, dbTx). + // lockProofsTxBegin := m.storageMock.On("BeginTx", mock.MatchedBy(matchProverCtxFn)).Return(dbTx, nil).Twice() + // lockProofsTxCommit := dbTx.On("Commit", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() + m.storageMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() + proof1GeneratingTrueCall := m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1210,8 +1212,8 @@ func Test_tryAggregateProofs(t *testing.T) { }). Return(nil). Once() - proof2GeneratingTrueCall := m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof2, dbTx). + proof2GeneratingTrueCall := m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof2, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1221,11 +1223,11 @@ func Test_tryAggregateProofs(t *testing.T) { Once() m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once() m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, common.Hash{}, nil).Once() - m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchProverCtxFn), proof1.BatchNumber, proof2.BatchNumberFinal, dbTx).Return(errTest).Once() - dbTx.On("Rollback", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() - m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchAggregatorCtxFn)).Return(dbTx, nil).Once().NotBefore(lockProofsTxBegin) - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, dbTx). + m.storageMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchProverCtxFn), proof1.BatchNumber, proof2.BatchNumberFinal, mock.Anything).Return(errTest).Once() + m.txerMock.On("Rollback").Return(nil).Once() + m.storageMock.On("BeginTx", mock.Anything, (*sql.TxOptions)(nil)).Return(m.txerMock, nil) + m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1234,8 +1236,8 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once(). NotBefore(proof1GeneratingTrueCall) - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof2, dbTx). + m.storageMock. 
+ On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof2, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1244,25 +1246,25 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once(). NotBefore(proof2GeneratingTrueCall) - dbTx.On("Commit", mock.MatchedBy(matchAggregatorCtxFn)).Return(nil).Once().NotBefore(lockProofsTxCommit) + m.txerMock.On("Commit").Return(nil) }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) assert.ErrorIs(err, errTest) }, }, + { name: "rollback after AddGeneratedProof error in db transaction", setup: func(m mox, a *Aggregator) { m.proverMock.On("Name").Return(proverName).Twice() m.proverMock.On("ID").Return(proverID).Twice() m.proverMock.On("Addr").Return("addr") - dbTx := &mocks.DbTxMock{} - lockProofsTxBegin := m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchProverCtxFn)).Return(dbTx, nil).Twice() - lockProofsTxCommit := dbTx.On("Commit", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() - m.stateMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() - proof1GeneratingTrueCall := m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, dbTx). + // lockProofsTxBegin := m.storageMock.On("BeginTx", mock.MatchedBy(matchProverCtxFn)).Return(dbTx, nil).Twice() + // lockProofsTxCommit := dbTx.On("Commit", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() + m.storageMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() + proof1GeneratingTrueCall := m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1270,8 +1272,8 @@ func Test_tryAggregateProofs(t *testing.T) { }). Return(nil). Once() - proof2GeneratingTrueCall := m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof2, dbTx). + proof2GeneratingTrueCall := m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof2, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1281,12 +1283,12 @@ func Test_tryAggregateProofs(t *testing.T) { Once() m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once() m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, common.Hash{}, nil).Once() - m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchProverCtxFn), proof1.BatchNumber, proof2.BatchNumberFinal, dbTx).Return(nil).Once() - m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, dbTx).Return(errTest).Once() - dbTx.On("Rollback", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once() - m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchAggregatorCtxFn)).Return(dbTx, nil).Once().NotBefore(lockProofsTxBegin) - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, dbTx). 
+ m.storageMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchProverCtxFn), proof1.BatchNumber, proof2.BatchNumberFinal, mock.Anything).Return(nil).Once() + m.storageMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, mock.Anything).Return(errTest).Once() + m.txerMock.On("Rollback").Return(nil).Once() + m.storageMock.On("BeginTx", mock.Anything, (*sql.TxOptions)(nil)).Return(m.txerMock, nil) + m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1295,8 +1297,8 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once(). NotBefore(proof1GeneratingTrueCall) - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof2, dbTx). + m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof2, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1305,7 +1307,7 @@ func Test_tryAggregateProofs(t *testing.T) { Return(nil). Once(). NotBefore(proof2GeneratingTrueCall) - dbTx.On("Commit", mock.MatchedBy(matchAggregatorCtxFn)).Return(nil).Once().NotBefore(lockProofsTxCommit) + m.txerMock.On("Commit").Return(nil) }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) @@ -1320,12 +1322,11 @@ func Test_tryAggregateProofs(t *testing.T) { m.proverMock.On("Name").Return(proverName).Times(3) m.proverMock.On("ID").Return(proverID).Times(3) m.proverMock.On("Addr").Return("addr") - dbTx := &mocks.DbTxMock{} - m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchProverCtxFn)).Return(dbTx, nil).Twice() - dbTx.On("Commit", mock.MatchedBy(matchProverCtxFn)).Return(nil).Twice() - m.stateMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, dbTx). + m.storageMock.On("BeginTx", mock.Anything, (*sql.TxOptions)(nil)).Return(m.txerMock, nil) + m.txerMock.On("Commit").Return(nil) + m.storageMock.On("GetProofsToAggregate", mock.MatchedBy(matchProverCtxFn), nil).Return(&proof1, &proof2, nil).Once() + m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof1, mock.Anything). Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1333,8 +1334,8 @@ func Test_tryAggregateProofs(t *testing.T) { }). Return(nil). Once() - m.stateMock. - On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof2, dbTx). + m.storageMock. + On("UpdateGeneratedProof", mock.MatchedBy(matchProverCtxFn), &proof2, mock.Anything). 
Run(func(args mock.Arguments) { proofArg, ok := args[1].(*state.Proof) assert.True(ok, "Expected argument of type *state.Proof") @@ -1345,14 +1346,14 @@ func Test_tryAggregateProofs(t *testing.T) { m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once() m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, common.Hash{}, nil).Once() - m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchProverCtxFn), proof1.BatchNumber, proof2.BatchNumberFinal, dbTx).Return(nil).Once() + m.storageMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchProverCtxFn), proof1.BatchNumber, proof2.BatchNumberFinal, m.txerMock).Return(nil).Once() expectedInputProver := map[string]interface{}{ "recursive_proof_1": proof1.Proof, "recursive_proof_2": proof2.Proof, } b, err := json.Marshal(expectedInputProver) require.NoError(err) - m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, dbTx).Run( + m.storageMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, mock.Anything).Run( func(args mock.Arguments) { proof, ok := args[1].(*state.Proof) if !ok { @@ -1369,7 +1370,7 @@ func Test_tryAggregateProofs(t *testing.T) { ).Return(nil).Once() m.etherman.On("GetLatestVerifiedBatchNum").Return(uint64(42), errTest).Once() - m.stateMock.On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), mock.Anything, nil).Run( + m.storageMock.On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), mock.Anything, nil).Run( func(args mock.Arguments) { proof, ok := args[1].(*state.Proof) if !ok { @@ -1394,17 +1395,18 @@ func Test_tryAggregateProofs(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - stateMock := mocks.NewStateInterfaceMock(t) + storageMock := mocks.NewStorageInterfaceMock(t) ethTxManager := mocks.NewEthTxManagerClientMock(t) etherman := mocks.NewEthermanMock(t) proverMock := mocks.NewProverInterfaceMock(t) + txerMock := mocks.NewTxerMock(t) a := Aggregator{ cfg: cfg, - state: stateMock, + storage: storageMock, etherman: etherman, ethTxManager: ethTxManager, logger: log.GetDefaultLogger(), - stateDBMutex: &sync.Mutex{}, + storageMutex: &sync.Mutex{}, timeSendFinalProofMutex: &sync.RWMutex{}, timeCleanupLockedProofs: cfg.CleanupLockedProofsInterval, finalProof: make(chan finalProofMsg), @@ -1414,10 +1416,11 @@ func Test_tryAggregateProofs(t *testing.T) { aggregatorCtx := context.WithValue(context.Background(), "owner", ownerAggregator) //nolint:staticcheck a.ctx, a.exit = context.WithCancel(aggregatorCtx) m := mox{ - stateMock: stateMock, + storageMock: storageMock, ethTxManager: ethTxManager, etherman: etherman, proverMock: proverMock, + txerMock: txerMock, } if tc.setup != nil { tc.setup(m, &a) @@ -1439,7 +1442,6 @@ func Test_tryGenerateBatchProof(t *testing.T) { from := common.BytesToAddress([]byte("from")) cfg := Config{ VerifyProofInterval: types.Duration{Duration: time.Duration(10000000)}, - TxProfitabilityCheckerType: ProfitabilityAcceptAll, SenderAddress: from.Hex(), IntervalAfterWhichBatchConsolidateAnyway: types.Duration{Duration: time.Second * 1}, ChainID: uint64(1), @@ -1515,15 +1517,15 @@ func Test_tryGenerateBatchProof(t *testing.T) { m.proverMock.On("ID").Return(proverID) m.proverMock.On("Addr").Return("addr") m.etherman.On("GetLatestVerifiedBatchNum").Return(uint64(0), nil) - m.stateMock.On("CheckProofExistsForBatch", mock.Anything, uint64(1), nil).Return(false, nil) + 
m.storageMock.On("CheckProofExistsForBatch", mock.Anything, uint64(1), nil).Return(false, nil) m.synchronizerMock.On("GetSequenceByBatchNumber", mock.Anything, mock.Anything).Return(&sequence, nil) m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, mock.Anything).Return(&virtualBatch, nil) m.synchronizerMock.On("GetL1BlockByNumber", mock.Anything, mock.Anything).Return(&synchronizer.L1Block{ParentHash: common.Hash{}}, nil) m.rpcMock.On("GetBatch", mock.Anything).Return(rpcBatch, nil) m.rpcMock.On("GetWitness", mock.Anything, false).Return([]byte("witness"), nil) - m.stateMock.On("AddGeneratedProof", mock.Anything, mock.Anything, nil).Return(nil) - m.stateMock.On("AddSequence", mock.Anything, mock.Anything, nil).Return(nil) - m.stateMock.On("DeleteGeneratedProofs", mock.Anything, uint64(1), uint64(1), nil).Return(nil) + m.storageMock.On("AddGeneratedProof", mock.Anything, mock.Anything, nil).Return(nil) + m.storageMock.On("AddSequence", mock.Anything, mock.Anything, nil).Return(nil) + m.storageMock.On("DeleteGeneratedProofs", mock.Anything, uint64(1), uint64(1), nil).Return(nil) m.synchronizerMock.On("GetLeafsByL1InfoRoot", mock.Anything, l1InfoRoot).Return(l1InfoTreeLeaf, nil) m.synchronizerMock.On("GetL1InfoTreeLeaves", mock.Anything, mock.Anything).Return(map[uint32]synchronizer.L1InfoTreeLeaf{ 1: { @@ -1581,8 +1583,8 @@ func Test_tryGenerateBatchProof(t *testing.T) { m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, mock.Anything).Return(&virtualBatch, nil).Once() m.etherman.On("GetLatestVerifiedBatchNum").Return(lastVerifiedBatchNum, nil).Once() - m.stateMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(true, nil) - m.stateMock.On("CleanupGeneratedProofs", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + m.storageMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(true, nil) + m.storageMock.On("CleanupGeneratedProofs", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() sequence := synchronizer.SequencedBatches{ FromBatchNumber: uint64(10), ToBatchNumber: uint64(20), @@ -1593,8 +1595,8 @@ func Test_tryGenerateBatchProof(t *testing.T) { rpcBatch.SetLastL2BLockTimestamp(uint64(time.Now().Unix())) m.rpcMock.On("GetWitness", mock.Anything, false).Return([]byte("witness"), nil) m.rpcMock.On("GetBatch", mock.Anything).Return(rpcBatch, nil) - m.stateMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil).Once() - m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( + m.storageMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil).Once() + m.storageMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( func(args mock.Arguments) { proof, ok := args[1].(*state.Proof) if !ok { @@ -1615,7 +1617,7 @@ func Test_tryGenerateBatchProof(t *testing.T) { }, nil) m.proverMock.On("BatchProof", mock.Anything).Return(nil, errTest).Once() - m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchAggregatorCtxFn), batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(nil).Once() + m.storageMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchAggregatorCtxFn), batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(nil).Once() }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) @@ -1642,7 +1644,7 @@ func Test_tryGenerateBatchProof(t *testing.T) { 
m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, lastVerifiedBatchNum+1).Return(&virtualBatch, nil).Once() m.etherman.On("GetLatestVerifiedBatchNum").Return(lastVerifiedBatchNum, nil).Once() - m.stateMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.AnythingOfType("uint64"), nil).Return(false, nil).Once() + m.storageMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.AnythingOfType("uint64"), nil).Return(false, nil).Once() sequence := synchronizer.SequencedBatches{ FromBatchNumber: uint64(10), ToBatchNumber: uint64(20), @@ -1651,8 +1653,8 @@ func Test_tryGenerateBatchProof(t *testing.T) { rpcBatch := rpctypes.NewRPCBatch(lastVerifiedBatchNum+1, common.Hash{}, []string{}, batchL2Data, common.Hash{}, common.BytesToHash([]byte("mock LocalExitRoot")), common.BytesToHash([]byte("mock StateRoot")), common.Address{}, false) rpcBatch.SetLastL2BLockTimestamp(uint64(time.Now().Unix())) m.rpcMock.On("GetWitness", lastVerifiedBatchNum+1, false).Return([]byte("witness"), nil) - m.stateMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil).Once() - m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( + m.storageMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil).Once() + m.storageMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( func(args mock.Arguments) { proof, ok := args[1].(*state.Proof) if !ok { @@ -1676,7 +1678,7 @@ func Test_tryGenerateBatchProof(t *testing.T) { m.rpcMock.On("GetBatch", lastVerifiedBatchNum+1).Return(rpcBatch, nil) m.proverMock.On("BatchProof", mock.Anything).Return(&proofID, nil).Once() m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, errTest).Once() - m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchAggregatorCtxFn), batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(nil) + m.storageMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchAggregatorCtxFn), batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(nil) }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) @@ -1703,7 +1705,7 @@ func Test_tryGenerateBatchProof(t *testing.T) { m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, lastVerifiedBatchNum+1).Return(&virtualBatch, nil) m.etherman.On("GetLatestVerifiedBatchNum").Return(lastVerifiedBatchNum, nil) - m.stateMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.AnythingOfType("uint64"), nil).Return(false, nil) + m.storageMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.AnythingOfType("uint64"), nil).Return(false, nil) sequence := synchronizer.SequencedBatches{ FromBatchNumber: uint64(10), ToBatchNumber: uint64(20), @@ -1712,8 +1714,8 @@ func Test_tryGenerateBatchProof(t *testing.T) { rpcBatch := rpctypes.NewRPCBatch(lastVerifiedBatchNum+1, common.Hash{}, []string{}, batchL2Data, common.Hash{}, common.BytesToHash([]byte("mock LocalExitRoot")), common.BytesToHash([]byte("mock StateRoot")), common.Address{}, false) rpcBatch.SetLastL2BLockTimestamp(uint64(time.Now().Unix())) m.rpcMock.On("GetWitness", lastVerifiedBatchNum+1, false).Return([]byte("witness"), nil) - m.stateMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil) - m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( + 
m.storageMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil) + m.storageMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( func(args mock.Arguments) { proof, ok := args[1].(*state.Proof) if !ok { @@ -1737,7 +1739,7 @@ func Test_tryGenerateBatchProof(t *testing.T) { m.rpcMock.On("GetBatch", lastVerifiedBatchNum+1).Return(rpcBatch, nil) m.proverMock.On("BatchProof", mock.Anything).Return(&proofID, nil) m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, nil) - m.stateMock.On("UpdateGeneratedProof", mock.Anything, mock.Anything, nil).Return(nil) + m.storageMock.On("UpdateGeneratedProof", mock.Anything, mock.Anything, nil).Return(nil) }, asserts: func(result bool, a *Aggregator, err error) { assert.True(result) @@ -1756,7 +1758,7 @@ func Test_tryGenerateBatchProof(t *testing.T) { l1InfoRoot := common.HexToHash("0x057e9950fbd39b002e323f37c2330d0c096e66919e24cc96fb4b2dfa8f4af782") m.etherman.On("GetLatestVerifiedBatchNum").Return(lastVerifiedBatchNum, nil).Once() - m.stateMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.AnythingOfType("uint64"), nil).Return(false, nil).Once() + m.storageMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.AnythingOfType("uint64"), nil).Return(false, nil).Once() sequence := synchronizer.SequencedBatches{ FromBatchNumber: uint64(10), ToBatchNumber: uint64(20), @@ -1765,8 +1767,8 @@ func Test_tryGenerateBatchProof(t *testing.T) { rpcBatch := rpctypes.NewRPCBatch(lastVerifiedBatchNum+1, common.Hash{}, []string{}, batchL2Data, common.Hash{}, common.BytesToHash([]byte("mock LocalExitRoot")), common.BytesToHash([]byte("mock StateRoot")), common.Address{}, false) rpcBatch.SetLastL2BLockTimestamp(uint64(time.Now().Unix())) m.rpcMock.On("GetBatch", lastVerifiedBatchNum+1).Return(rpcBatch, nil) - m.stateMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil).Once() - m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( + m.storageMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil).Once() + m.storageMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( func(args mock.Arguments) { proof, ok := args[1].(*state.Proof) if !ok { @@ -1799,7 +1801,7 @@ func Test_tryGenerateBatchProof(t *testing.T) { m.proverMock.On("BatchProof", mock.Anything).Return(&proofID, nil).Once() m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, errTest).Once() - m.stateMock.On("DeleteGeneratedProofs", mock.Anything, batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(errTest).Once() + m.storageMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchAggregatorCtxFn), batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(errTest).Once() }, asserts: func(result bool, a *Aggregator, err error) { assert.False(result) @@ -1810,7 +1812,7 @@ func Test_tryGenerateBatchProof(t *testing.T) { for x, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - stateMock := mocks.NewStateInterfaceMock(t) + storageMock := mocks.NewStorageInterfaceMock(t) ethTxManager := mocks.NewEthTxManagerClientMock(t) etherman := mocks.NewEthermanMock(t) proverMock := mocks.NewProverInterfaceMock(t) @@ -1819,15 +1821,14 @@ func Test_tryGenerateBatchProof(t *testing.T) { a := Aggregator{ cfg: cfg, - state: 
stateMock, + storage: storageMock, etherman: etherman, ethTxManager: ethTxManager, logger: log.GetDefaultLogger(), - stateDBMutex: &sync.Mutex{}, + storageMutex: &sync.Mutex{}, timeSendFinalProofMutex: &sync.RWMutex{}, timeCleanupLockedProofs: cfg.CleanupLockedProofsInterval, finalProof: make(chan finalProofMsg), - profitabilityChecker: NewTxProfitabilityCheckerAcceptAll(stateMock, cfg.IntervalAfterWhichBatchConsolidateAnyway.Duration), l1Syncr: synchronizerMock, rpcClient: mockRPC, accInputHashes: make(map[uint64]common.Hash), @@ -1840,7 +1841,7 @@ func Test_tryGenerateBatchProof(t *testing.T) { a.ctx, a.exit = context.WithCancel(aggregatorCtx) m := mox{ - stateMock: stateMock, + storageMock: storageMock, ethTxManager: ethTxManager, etherman: etherman, proverMock: proverMock, diff --git a/aggregator/config.go b/aggregator/config.go index 2d7178f7..e17d68af 100644 --- a/aggregator/config.go +++ b/aggregator/config.go @@ -4,7 +4,6 @@ import ( "fmt" "math/big" - "github.com/0xPolygon/cdk/aggregator/db" "github.com/0xPolygon/cdk/config/types" "github.com/0xPolygon/cdk/log" "github.com/0xPolygon/zkevm-ethtx-manager/ethtxmanager" @@ -62,14 +61,6 @@ type Config struct { // ProofStatePollingInterval is the interval time to polling the prover about the generation state of a proof ProofStatePollingInterval types.Duration `mapstructure:"ProofStatePollingInterval"` - // TxProfitabilityCheckerType type for checking is it profitable for aggregator to validate batch - // possible values: base/acceptall - TxProfitabilityCheckerType TxProfitabilityCheckerType `mapstructure:"TxProfitabilityCheckerType"` - - // TxProfitabilityMinReward min reward for base tx profitability checker when aggregator will validate batch - // this parameter is used for the base tx profitability checker - TxProfitabilityMinReward TokenAmountWithDecimals `mapstructure:"TxProfitabilityMinReward"` - // IntervalAfterWhichBatchConsolidateAnyway is the interval duration for the main sequencer to check // if there are no transactions. If there are no transactions in this interval, the sequencer will // consolidate the batch anyway. @@ -117,8 +108,8 @@ type Config struct { // UseFullWitness is a flag to enable the use of full witness in the aggregator UseFullWitness bool `mapstructure:"UseFullWitness"` - // DB is the database configuration - DB db.Config `mapstructure:"DB"` + // DBPath is the path to the database + DBPath string `mapstructure:"DBPath"` // EthTxManager is the config for the ethtxmanager EthTxManager ethtxmanager.Config `mapstructure:"EthTxManager"` diff --git a/aggregator/db/config.go b/aggregator/db/config.go deleted file mode 100644 index ad56155f..00000000 --- a/aggregator/db/config.go +++ /dev/null @@ -1,25 +0,0 @@ -package db - -// Config provide fields to configure the pool -type Config struct { - // Database name - Name string `mapstructure:"Name"` - - // Database User name - User string `mapstructure:"User"` - - // Database Password of the user - Password string `mapstructure:"Password"` - - // Host address of database - Host string `mapstructure:"Host"` - - // Port Number of database - Port string `mapstructure:"Port"` - - // EnableLog - EnableLog bool `mapstructure:"EnableLog"` - - // MaxConns is the maximum number of connections in the pool. 
- MaxConns int `mapstructure:"MaxConns"` -} diff --git a/aggregator/db/db.go b/aggregator/db/db.go deleted file mode 100644 index ecfffc11..00000000 --- a/aggregator/db/db.go +++ /dev/null @@ -1,31 +0,0 @@ -package db - -import ( - "context" - "fmt" - - "github.com/0xPolygon/cdk/log" - "github.com/jackc/pgx/v4/pgxpool" -) - -// NewSQLDB creates a new SQL DB -func NewSQLDB(logger *log.Logger, cfg Config) (*pgxpool.Pool, error) { - config, err := pgxpool.ParseConfig(fmt.Sprintf("postgres://%s:%s@%s:%s/%s?pool_max_conns=%d", - cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name, cfg.MaxConns)) - if err != nil { - logger.Errorf("Unable to parse DB config: %v\n", err) - return nil, err - } - - if cfg.EnableLog { - config.ConnConfig.Logger = dbLoggerImpl{} - } - - conn, err := pgxpool.ConnectConfig(context.Background(), config) - if err != nil { - logger.Errorf("Unable to connect to database: %v\n", err) - return nil, err - } - - return conn, nil -} diff --git a/aggregator/db/dbstorage/dbstorage.go b/aggregator/db/dbstorage/dbstorage.go new file mode 100644 index 00000000..b20a1c71 --- /dev/null +++ b/aggregator/db/dbstorage/dbstorage.go @@ -0,0 +1,35 @@ +package dbstorage + +import ( + "context" + "database/sql" + + "github.com/0xPolygon/cdk/db" +) + +// DBStorage implements the Storage interface +type DBStorage struct { + DB *sql.DB +} + +// NewDBStorage creates a new DBStorage instance +func NewDBStorage(dbPath string) (*DBStorage, error) { + db, err := db.NewSQLiteDB(dbPath) + if err != nil { + return nil, err + } + + return &DBStorage{DB: db}, nil +} + +func (d *DBStorage) BeginTx(ctx context.Context, options *sql.TxOptions) (db.Txer, error) { + return db.NewTx(ctx, d.DB) +} + +func (d *DBStorage) getExecQuerier(dbTx db.Txer) db.Querier { + if dbTx == nil { + return d.DB + } + + return dbTx +} diff --git a/aggregator/db/dbstorage/proof.go b/aggregator/db/dbstorage/proof.go new file mode 100644 index 00000000..d3065c7e --- /dev/null +++ b/aggregator/db/dbstorage/proof.go @@ -0,0 +1,356 @@ +package dbstorage + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/0xPolygon/cdk/db" + "github.com/0xPolygon/cdk/state" +) + +// CheckProofExistsForBatch checks if the batch is already included in any proof +func (d *DBStorage) CheckProofExistsForBatch(ctx context.Context, batchNumber uint64, dbTx db.Txer) (bool, error) { + const checkProofExistsForBatchSQL = ` + SELECT EXISTS (SELECT 1 FROM proof p WHERE $1 >= p.batch_num AND $1 <= p.batch_num_final) + ` + e := d.getExecQuerier(dbTx) + var exists bool + err := e.QueryRow(checkProofExistsForBatchSQL, batchNumber).Scan(&exists) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return exists, err + } + return exists, nil +} + +// CheckProofContainsCompleteSequences checks if a recursive proof contains complete sequences +func (d *DBStorage) CheckProofContainsCompleteSequences( + ctx context.Context, proof *state.Proof, dbTx db.Txer, +) (bool, error) { + const getProofContainsCompleteSequencesSQL = ` + SELECT EXISTS (SELECT 1 FROM sequence s1 WHERE s1.from_batch_num = $1) AND + EXISTS (SELECT 1 FROM sequence s2 WHERE s2.to_batch_num = $2) + ` + e := d.getExecQuerier(dbTx) + var exists bool + err := e.QueryRow(getProofContainsCompleteSequencesSQL, proof.BatchNumber, proof.BatchNumberFinal).Scan(&exists) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return exists, err + } + return exists, nil +} + +// GetProofReadyToVerify return the proof that is ready to verify +func (d *DBStorage) GetProofReadyToVerify( + ctx 
context.Context, lastVerfiedBatchNumber uint64, dbTx db.Txer, +) (*state.Proof, error) { + const getProofReadyToVerifySQL = ` + SELECT + p.batch_num, + p.batch_num_final, + p.proof, + p.proof_id, + p.input_prover, + p.prover, + p.prover_id, + p.generating_since, + p.created_at, + p.updated_at + FROM proof p + WHERE batch_num = $1 AND generating_since IS NULL AND + EXISTS (SELECT 1 FROM sequence s1 WHERE s1.from_batch_num = p.batch_num) AND + EXISTS (SELECT 1 FROM sequence s2 WHERE s2.to_batch_num = p.batch_num_final) + ` + + var proof = &state.Proof{} + + e := d.getExecQuerier(dbTx) + row := e.QueryRow(getProofReadyToVerifySQL, lastVerfiedBatchNumber+1) + + var ( + generatingSince *uint64 + createdAt *uint64 + updatedAt *uint64 + ) + err := row.Scan( + &proof.BatchNumber, &proof.BatchNumberFinal, &proof.Proof, &proof.ProofID, + &proof.InputProver, &proof.Prover, &proof.ProverID, &generatingSince, + &createdAt, &updatedAt, + ) + + if generatingSince != nil { + timeSince := time.Unix(int64(*generatingSince), 0) + proof.GeneratingSince = &timeSince + } + + if createdAt != nil { + proof.CreatedAt = time.Unix(int64(*createdAt), 0) + } + + if updatedAt != nil { + proof.UpdatedAt = time.Unix(int64(*updatedAt), 0) + } + + if errors.Is(err, sql.ErrNoRows) { + return nil, state.ErrNotFound + } else if err != nil { + return nil, err + } + + return proof, err +} + +// GetProofsToAggregate return the next to proof that it is possible to aggregate +func (d *DBStorage) GetProofsToAggregate(ctx context.Context, dbTx db.Txer) (*state.Proof, *state.Proof, error) { + var ( + proof1 = &state.Proof{} + proof2 = &state.Proof{} + ) + + // TODO: add comments to explain the query + const getProofsToAggregateSQL = ` + SELECT + p1.batch_num as p1_batch_num, + p1.batch_num_final as p1_batch_num_final, + p1.proof as p1_proof, + p1.proof_id as p1_proof_id, + p1.input_prover as p1_input_prover, + p1.prover as p1_prover, + p1.prover_id as p1_prover_id, + p1.generating_since as p1_generating_since, + p1.created_at as p1_created_at, + p1.updated_at as p1_updated_at, + p2.batch_num as p2_batch_num, + p2.batch_num_final as p2_batch_num_final, + p2.proof as p2_proof, + p2.proof_id as p2_proof_id, + p2.input_prover as p2_input_prover, + p2.prover as p2_prover, + p2.prover_id as p2_prover_id, + p2.generating_since as p2_generating_since, + p2.created_at as p2_created_at, + p2.updated_at as p2_updated_at + FROM proof p1 INNER JOIN proof p2 ON p1.batch_num_final = p2.batch_num - 1 + WHERE p1.generating_since IS NULL AND p2.generating_since IS NULL AND + p1.proof IS NOT NULL AND p2.proof IS NOT NULL AND + ( + EXISTS ( + SELECT 1 FROM sequence s + WHERE p1.batch_num >= s.from_batch_num AND p1.batch_num <= s.to_batch_num AND + p1.batch_num_final >= s.from_batch_num AND p1.batch_num_final <= s.to_batch_num AND + p2.batch_num >= s.from_batch_num AND p2.batch_num <= s.to_batch_num AND + p2.batch_num_final >= s.from_batch_num AND p2.batch_num_final <= s.to_batch_num + ) + OR + ( + EXISTS ( SELECT 1 FROM sequence s WHERE p1.batch_num = s.from_batch_num) AND + EXISTS ( SELECT 1 FROM sequence s WHERE p1.batch_num_final = s.to_batch_num) AND + EXISTS ( SELECT 1 FROM sequence s WHERE p2.batch_num = s.from_batch_num) AND + EXISTS ( SELECT 1 FROM sequence s WHERE p2.batch_num_final = s.to_batch_num) + ) + ) + ORDER BY p1.batch_num ASC + LIMIT 1 + ` + + e := d.getExecQuerier(dbTx) + row := e.QueryRow(getProofsToAggregateSQL) + + var ( + generatingSince1, generatingSince2 *uint64 + createdAt1, createdAt2 *uint64 + updatedAt1, updatedAt2 *uint64 
+ ) + + err := row.Scan( + &proof1.BatchNumber, &proof1.BatchNumberFinal, &proof1.Proof, &proof1.ProofID, + &proof1.InputProver, &proof1.Prover, &proof1.ProverID, &generatingSince1, + &createdAt1, &updatedAt1, + &proof2.BatchNumber, &proof2.BatchNumberFinal, &proof2.Proof, &proof2.ProofID, + &proof2.InputProver, &proof2.Prover, &proof2.ProverID, &generatingSince2, + &createdAt1, &updatedAt1, + ) + + if generatingSince1 != nil { + timeSince1 := time.Unix(int64(*generatingSince1), 0) + proof1.GeneratingSince = &timeSince1 + } + + if generatingSince2 != nil { + timeSince2 := time.Unix(int64(*generatingSince2), 0) + proof2.GeneratingSince = &timeSince2 + } + + if createdAt1 != nil { + proof1.CreatedAt = time.Unix(int64(*createdAt1), 0) + } + + if createdAt2 != nil { + proof2.CreatedAt = time.Unix(int64(*createdAt2), 0) + } + + if updatedAt1 != nil { + proof1.UpdatedAt = time.Unix(int64(*updatedAt1), 0) + } + + if updatedAt2 != nil { + proof2.UpdatedAt = time.Unix(int64(*updatedAt2), 0) + } + + if errors.Is(err, sql.ErrNoRows) { + return nil, nil, state.ErrNotFound + } else if err != nil { + return nil, nil, err + } + + return proof1, proof2, err +} + +// AddGeneratedProof adds a generated proof to the storage +func (d *DBStorage) AddGeneratedProof(ctx context.Context, proof *state.Proof, dbTx db.Txer) error { + const addGeneratedProofSQL = ` + INSERT INTO proof ( + batch_num, batch_num_final, proof, proof_id, input_prover, prover, + prover_id, generating_since, created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 + ) + ` + e := d.getExecQuerier(dbTx) + now := time.Now().UTC().Round(time.Microsecond) + + var ( + generatingSince *uint64 + createdAt *uint64 + updatedAt *uint64 + ) + + if proof.GeneratingSince != nil { + generatingSince = new(uint64) + *generatingSince = uint64(proof.GeneratingSince.Unix()) + } + + if !proof.CreatedAt.IsZero() { + createdAt = new(uint64) + *createdAt = uint64(proof.CreatedAt.Unix()) + } else { + createdAt = new(uint64) + *createdAt = uint64(now.Unix()) + } + + if !proof.UpdatedAt.IsZero() { + updatedAt = new(uint64) + *updatedAt = uint64(proof.UpdatedAt.Unix()) + } else { + updatedAt = new(uint64) + *updatedAt = uint64(now.Unix()) + } + + _, err := e.Exec( + addGeneratedProofSQL, proof.BatchNumber, proof.BatchNumberFinal, proof.Proof, proof.ProofID, + proof.InputProver, proof.Prover, proof.ProverID, generatingSince, createdAt, updatedAt, + ) + return err +} + +// UpdateGeneratedProof updates a generated proof in the storage +func (d *DBStorage) UpdateGeneratedProof(ctx context.Context, proof *state.Proof, dbTx db.Txer) error { + const updateGeneratedProofSQL = ` + UPDATE proof + SET proof = $3, + proof_id = $4, + input_prover = $5, + prover = $6, + prover_id = $7, + generating_since = $8, + updated_at = $9 + WHERE batch_num = $1 + AND batch_num_final = $2 + ` + e := d.getExecQuerier(dbTx) + now := time.Now().UTC().Round(time.Microsecond) + + var ( + generatingSince *uint64 + updatedAt *uint64 + ) + + if proof.GeneratingSince != nil { + generatingSince = new(uint64) + *generatingSince = uint64(proof.GeneratingSince.Unix()) + } + + if !proof.UpdatedAt.IsZero() { + updatedAt = new(uint64) + *updatedAt = uint64(proof.UpdatedAt.Unix()) + } else { + updatedAt = new(uint64) + *updatedAt = uint64(now.Unix()) + } + _, err := e.Exec( + updateGeneratedProofSQL, proof.Proof, proof.ProofID, proof.InputProver, + proof.Prover, proof.ProverID, generatingSince, updatedAt, proof.BatchNumber, proof.BatchNumberFinal, + ) + return err +} + +// 
DeleteGeneratedProofs deletes from the storage the generated proofs falling +// inside the batch numbers range. +func (d *DBStorage) DeleteGeneratedProofs( + ctx context.Context, batchNumber uint64, batchNumberFinal uint64, dbTx db.Txer, +) error { + const deleteGeneratedProofSQL = "DELETE FROM proof WHERE batch_num >= $1 AND batch_num_final <= $2" + e := d.getExecQuerier(dbTx) + _, err := e.Exec(deleteGeneratedProofSQL, batchNumber, batchNumberFinal) + return err +} + +// CleanupGeneratedProofs deletes from the storage the generated proofs up to +// the specified batch number included. +func (d *DBStorage) CleanupGeneratedProofs(ctx context.Context, batchNumber uint64, dbTx db.Txer) error { + const deleteGeneratedProofSQL = "DELETE FROM proof WHERE batch_num_final <= $1" + e := d.getExecQuerier(dbTx) + _, err := e.Exec(deleteGeneratedProofSQL, batchNumber) + return err +} + +// CleanupLockedProofs deletes from the storage the proofs locked in generating +// state for more than the provided threshold. +func (d *DBStorage) CleanupLockedProofs(ctx context.Context, duration string, dbTx db.Txer) (int64, error) { + seconds, err := convertDurationToSeconds(duration) + if err != nil { + return 0, err + } + + difference := time.Now().Unix() - seconds + + sql := fmt.Sprintf("DELETE FROM proof WHERE generating_since is not null and generating_since < %d", difference) + e := d.getExecQuerier(dbTx) + ct, err := e.Exec(sql) + if err != nil { + return 0, err + } + return ct.RowsAffected() +} + +// DeleteUngeneratedProofs deletes ungenerated proofs. +// This method is meant to be use during aggregator boot-up sequence +func (d *DBStorage) DeleteUngeneratedProofs(ctx context.Context, dbTx db.Txer) error { + const deleteUngeneratedProofsSQL = "DELETE FROM proof WHERE generating_since IS NOT NULL" + e := d.getExecQuerier(dbTx) + _, err := e.Exec(deleteUngeneratedProofsSQL) + return err +} + +func convertDurationToSeconds(duration string) (int64, error) { + // Parse the duration using time.ParseDuration + parsedDuration, err := time.ParseDuration(duration) + if err != nil { + return 0, fmt.Errorf("invalid duration format: %w", err) + } + + // Return the duration in seconds + return int64(parsedDuration.Seconds()), nil +} diff --git a/aggregator/db/dbstorage/proof_test.go b/aggregator/db/dbstorage/proof_test.go new file mode 100644 index 00000000..f8095086 --- /dev/null +++ b/aggregator/db/dbstorage/proof_test.go @@ -0,0 +1,150 @@ +package dbstorage + +import ( + "context" + "math" + "testing" + "time" + + "github.com/0xPolygon/cdk/aggregator/db" + "github.com/0xPolygon/cdk/state" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + proofID = "proof_1" + prover = "prover_1" + proverID = "prover_id" +) + +func Test_Proof(t *testing.T) { + dbPath := "file::memory:?cache=shared" + err := db.RunMigrationsUp(dbPath, db.AggregatorMigrationName) + assert.NoError(t, err) + + ctx := context.Background() + now := time.Now() + + DBStorage, err := NewDBStorage(dbPath) + assert.NoError(t, err) + + dbtxer, err := DBStorage.BeginTx(ctx, nil) + require.NoError(t, err) + + exists, err := DBStorage.CheckProofExistsForBatch(ctx, 1, dbtxer) + assert.NoError(t, err) + assert.False(t, exists) + + proof := state.Proof{ + BatchNumber: 1, + BatchNumberFinal: 1, + Proof: "proof content", + InputProver: "input prover", + ProofID: &proofID, + Prover: &prover, + ProverID: &proofID, + GeneratingSince: nil, + CreatedAt: now, + UpdatedAt: now, + } + + err = DBStorage.AddGeneratedProof(ctx, &proof, 
dbtxer) + assert.NoError(t, err) + + err = DBStorage.AddSequence(ctx, state.Sequence{FromBatchNumber: 1, ToBatchNumber: 1}, dbtxer) + assert.NoError(t, err) + + contains, err := DBStorage.CheckProofContainsCompleteSequences(ctx, &proof, dbtxer) + assert.NoError(t, err) + assert.True(t, contains) + + proof2, err := DBStorage.GetProofReadyToVerify(ctx, 0, dbtxer) + assert.NoError(t, err) + assert.NotNil(t, proof2) + + require.Equal(t, proof.BatchNumber, proof2.BatchNumber) + require.Equal(t, proof.BatchNumberFinal, proof2.BatchNumberFinal) + require.Equal(t, proof.Proof, proof2.Proof) + require.Equal(t, *proof.ProofID, *proof2.ProofID) + require.Equal(t, proof.InputProver, proof2.InputProver) + require.Equal(t, *proof.Prover, *proof2.Prover) + require.Equal(t, *proof.ProverID, *proof2.ProverID) + require.Equal(t, proof.CreatedAt.Unix(), proof2.CreatedAt.Unix()) + require.Equal(t, proof.UpdatedAt.Unix(), proof2.UpdatedAt.Unix()) + + proof = state.Proof{ + BatchNumber: 1, + BatchNumberFinal: 1, + Proof: "proof content", + InputProver: "input prover", + ProofID: &proofID, + Prover: &prover, + ProverID: &proofID, + GeneratingSince: &now, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + err = DBStorage.UpdateGeneratedProof(ctx, &proof, dbtxer) + assert.NoError(t, err) + + sequence := state.Sequence{FromBatchNumber: 3, ToBatchNumber: 4} + + proof3 := state.Proof{ + BatchNumber: 3, + BatchNumberFinal: 3, + GeneratingSince: nil, + } + + proof4 := state.Proof{ + BatchNumber: 4, + BatchNumberFinal: 4, + GeneratingSince: nil, + } + + err = DBStorage.AddSequence(ctx, sequence, dbtxer) + assert.NoError(t, err) + + err = DBStorage.AddGeneratedProof(ctx, &proof3, dbtxer) + assert.NoError(t, err) + + err = DBStorage.AddGeneratedProof(ctx, &proof4, dbtxer) + assert.NoError(t, err) + + proof5, proof6, err := DBStorage.GetProofsToAggregate(ctx, dbtxer) + assert.NoError(t, err) + assert.NotNil(t, proof5) + assert.NotNil(t, proof6) + + err = DBStorage.DeleteGeneratedProofs(ctx, 1, math.MaxInt, dbtxer) + assert.NoError(t, err) + + err = DBStorage.CleanupGeneratedProofs(ctx, 1, dbtxer) + assert.NoError(t, err) + + now = time.Now() + + proof3.GeneratingSince = &now + proof4.GeneratingSince = &now + + err = DBStorage.AddGeneratedProof(ctx, &proof3, dbtxer) + assert.NoError(t, err) + + err = DBStorage.AddGeneratedProof(ctx, &proof4, dbtxer) + assert.NoError(t, err) + + time.Sleep(5 * time.Second) + + affected, err := DBStorage.CleanupLockedProofs(ctx, "4s", dbtxer) + assert.NoError(t, err) + require.Equal(t, int64(2), affected) + + proof5, proof6, err = DBStorage.GetProofsToAggregate(ctx, dbtxer) + assert.EqualError(t, err, state.ErrNotFound.Error()) + assert.Nil(t, proof5) + assert.Nil(t, proof6) + + err = DBStorage.DeleteUngeneratedProofs(ctx, dbtxer) + assert.NoError(t, err) +} diff --git a/aggregator/db/dbstorage/sequence.go b/aggregator/db/dbstorage/sequence.go new file mode 100644 index 00000000..96063201 --- /dev/null +++ b/aggregator/db/dbstorage/sequence.go @@ -0,0 +1,21 @@ +package dbstorage + +import ( + "context" + + "github.com/0xPolygon/cdk/db" + "github.com/0xPolygon/cdk/state" +) + +// AddSequence stores the sequence information to allow the aggregator verify sequences. 
+func (d *DBStorage) AddSequence(ctx context.Context, sequence state.Sequence, dbTx db.Txer) error { + const addSequenceSQL = ` + INSERT INTO sequence (from_batch_num, to_batch_num) + VALUES($1, $2) + ON CONFLICT (from_batch_num) DO UPDATE SET to_batch_num = $2 + ` + + e := d.getExecQuerier(dbTx) + _, err := e.Exec(addSequenceSQL, sequence.FromBatchNumber, sequence.ToBatchNumber) + return err +} diff --git a/aggregator/db/migrations.go b/aggregator/db/migrations.go index 20e8c29a..221fb145 100644 --- a/aggregator/db/migrations.go +++ b/aggregator/db/migrations.go @@ -4,15 +4,14 @@ import ( "embed" "fmt" + "github.com/0xPolygon/cdk/db" "github.com/0xPolygon/cdk/log" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/stdlib" migrate "github.com/rubenv/sql-migrate" ) const ( // AggregatorMigrationName is the name of the migration used to associate with the migrations dir - AggregatorMigrationName = "zkevm-aggregator-db" + AggregatorMigrationName = "aggregator-db" ) var ( @@ -28,38 +27,33 @@ func init() { } // RunMigrationsUp runs migrate-up for the given config. -func RunMigrationsUp(cfg Config, name string) error { +func RunMigrationsUp(dbPath string, name string) error { log.Info("running migrations up") - return runMigrations(cfg, name, migrate.Up) + return runMigrations(dbPath, name, migrate.Up) } // CheckMigrations runs migrate-up for the given config. -func CheckMigrations(cfg Config, name string) error { - return checkMigrations(cfg, name) +func CheckMigrations(dbPath string, name string) error { + return checkMigrations(dbPath, name) } // RunMigrationsDown runs migrate-down for the given config. -func RunMigrationsDown(cfg Config, name string) error { +func RunMigrationsDown(dbPath string, name string) error { log.Info("running migrations down") - return runMigrations(cfg, name, migrate.Down) + return runMigrations(dbPath, name, migrate.Down) } // runMigrations will execute pending migrations if needed to keep // the database updated with the latest changes in either direction, // up or down. 
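For reference, the migration helpers shown in this hunk now take a SQLite path (or DSN) instead of a Postgres config. A minimal sketch of how a caller might bring the aggregator schema up against an in-memory database, reusing the DSN that the new tests use; the standalone main wrapper below is illustrative only, not part of the patch:

package main

import (
	"log"

	aggregatordb "github.com/0xPolygon/cdk/aggregator/db"
)

func main() {
	// "file::memory:?cache=shared" is the in-memory SQLite DSN used by the new
	// migration and storage tests; a regular file path works the same way.
	dbPath := "file::memory:?cache=shared"

	// Apply all pending aggregator migrations (proof and sequence tables).
	if err := aggregatordb.RunMigrationsUp(dbPath, aggregatordb.AggregatorMigrationName); err != nil {
		log.Fatalf("running aggregator migrations: %v", err)
	}
	log.Println("aggregator schema is up to date")
}
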
-func runMigrations(cfg Config, name string, direction migrate.MigrationDirection) error { - c, err := pgx.ParseConfig(fmt.Sprintf( - "postgres://%s:%s@%s:%s/%s", - cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name, - )) +func runMigrations(dbPath string, name string, direction migrate.MigrationDirection) error { + db, err := db.NewSQLiteDB(dbPath) if err != nil { return err } - db := stdlib.OpenDB(*c) - embedMigration, ok := embedMigrations[name] if !ok { return fmt.Errorf("migration not found with name: %v", name) @@ -70,7 +64,7 @@ func runMigrations(cfg Config, name string, direction migrate.MigrationDirection Root: "migrations", } - nMigrations, err := migrate.Exec(db, "postgres", migrations, direction) + nMigrations, err := migrate.Exec(db, "sqlite3", migrations, direction) if err != nil { return err } @@ -80,17 +74,12 @@ func runMigrations(cfg Config, name string, direction migrate.MigrationDirection return nil } -func checkMigrations(cfg Config, name string) error { - c, err := pgx.ParseConfig(fmt.Sprintf( - "postgres://%s:%s@%s:%s/%s", - cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name, - )) +func checkMigrations(dbPath string, name string) error { + db, err := db.NewSQLiteDB(dbPath) if err != nil { return err } - db := stdlib.OpenDB(*c) - embedMigration, ok := embedMigrations[name] if !ok { return fmt.Errorf("migration not found with name: %v", name) diff --git a/aggregator/db/migrations/0001.sql b/aggregator/db/migrations/0001.sql index 963dbea7..651597a3 100644 --- a/aggregator/db/migrations/0001.sql +++ b/aggregator/db/migrations/0001.sql @@ -1,32 +1,24 @@ -- +migrate Down -DROP SCHEMA IF EXISTS aggregator CASCADE; +DROP TABLE IF EXISTS proof; +DROP TABLE IF EXISTS sequence; -- +migrate Up -CREATE SCHEMA aggregator; - -CREATE TABLE IF NOT EXISTS aggregator.batch ( +CREATE TABLE IF NOT EXISTS proof ( batch_num BIGINT NOT NULL, - batch jsonb NOT NULL, - datastream varchar NOT NULL, - PRIMARY KEY (batch_num) -); - -CREATE TABLE IF NOT EXISTS aggregator.proof ( - batch_num BIGINT NOT NULL REFERENCES aggregator.batch (batch_num) ON DELETE CASCADE, batch_num_final BIGINT NOT NULL, - proof varchar NULL, - proof_id varchar NULL, - input_prover varchar NULL, - prover varchar NULL, - prover_id varchar NULL, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(), - updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(), - generating_since timestamptz NULL, + proof TEXT NULL, + proof_id TEXT NULL, + input_prover TEXT NULL, + prover TEXT NULL, + prover_id TEXT NULL, + created_at BIGINT NOT NULL, + updated_at BIGINT NOT NULL, + generating_since BIGINT DEFAULT NULL, PRIMARY KEY (batch_num, batch_num_final) ); -CREATE TABLE IF NOT EXISTS aggregator.sequence ( - from_batch_num BIGINT NOT NULL REFERENCES aggregator.batch (batch_num) ON DELETE CASCADE, +CREATE TABLE IF NOT EXISTS sequence ( + from_batch_num BIGINT NOT NULL, to_batch_num BIGINT NOT NULL, PRIMARY KEY (from_batch_num) ); diff --git a/aggregator/db/migrations/0002.sql b/aggregator/db/migrations/0002.sql deleted file mode 100644 index e2290e13..00000000 --- a/aggregator/db/migrations/0002.sql +++ /dev/null @@ -1,8 +0,0 @@ --- +migrate Up -DELETE FROM aggregator.batch; -ALTER TABLE aggregator.batch - ADD COLUMN IF NOT EXISTS witness varchar NOT NULL; - --- +migrate Down -ALTER TABLE aggregator.batch - DROP COLUMN IF EXISTS witness; diff --git a/aggregator/db/migrations/0003.sql b/aggregator/db/migrations/0003.sql deleted file mode 100644 index 5351f8e7..00000000 --- a/aggregator/db/migrations/0003.sql +++ /dev/null @@ 
-1,7 +0,0 @@ --- +migrate Up -ALTER TABLE aggregator.batch - ALTER COLUMN witness DROP NOT NULL; - --- +migrate Down -ALTER TABLE aggregator.batch - ALTER COLUMN witness SET NOT NULL; diff --git a/aggregator/db/migrations/0004.sql b/aggregator/db/migrations/0004.sql deleted file mode 100644 index cb186fc0..00000000 --- a/aggregator/db/migrations/0004.sql +++ /dev/null @@ -1,23 +0,0 @@ --- +migrate Down -CREATE TABLE IF NOT EXISTS aggregator.batch ( - batch_num BIGINT NOT NULL, - batch jsonb NOT NULL, - datastream varchar NOT NULL, - PRIMARY KEY (batch_num) -); - -ALTER TABLE aggregator.proof - ADD CONSTRAINT IF NOT EXISTS proof_batch_num_fkey FOREIGN KEY (batch_num) REFERENCES aggregator.batch (batch_num) ON DELETE CASCADE; - -ALTER TABLE aggregator.sequence - ADD CONSTRAINT IF NOT EXISTS sequence_from_batch_num_fkey FOREIGN KEY (from_batch_num) REFERENCES aggregator.batch (batch_num) ON DELETE CASCADE; - - --- +migrate Up -ALTER TABLE aggregator.proof - DROP CONSTRAINT IF EXISTS proof_batch_num_fkey; - -ALTER TABLE aggregator.sequence - DROP CONSTRAINT IF EXISTS sequence_from_batch_num_fkey; - -DROP TABLE IF EXISTS aggregator.batch; diff --git a/aggregator/db/migrations_test.go b/aggregator/db/migrations_test.go index 0a118c69..317178e9 100644 --- a/aggregator/db/migrations_test.go +++ b/aggregator/db/migrations_test.go @@ -16,3 +16,12 @@ func Test_checkMigrations(t *testing.T) { _, err := migrationSource.FileSystem.ReadFile("migrations/0001.sql") assert.NoError(t, err) } + +func Test_runMigrations(t *testing.T) { + dbPath := "file::memory:?cache=shared" + err := runMigrations(dbPath, AggregatorMigrationName, migrate.Up) + assert.NoError(t, err) + + err = runMigrations(dbPath, AggregatorMigrationName, migrate.Down) + assert.NoError(t, err) +} diff --git a/aggregator/interfaces.go b/aggregator/interfaces.go index f1673c46..5979272d 100644 --- a/aggregator/interfaces.go +++ b/aggregator/interfaces.go @@ -2,10 +2,12 @@ package aggregator import ( "context" + "database/sql" "math/big" ethmanTypes "github.com/0xPolygon/cdk/aggregator/ethmantypes" "github.com/0xPolygon/cdk/aggregator/prover" + "github.com/0xPolygon/cdk/db" "github.com/0xPolygon/cdk/rpc/types" "github.com/0xPolygon/cdk/state" "github.com/0xPolygon/zkevm-ethtx-manager/ethtxmanager" @@ -13,7 +15,6 @@ import ( "github.com/ethereum/go-ethereum/common" ethtypes "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto/kzg4844" - "github.com/jackc/pgx/v4" ) // Consumer interfaces required by the package. @@ -53,19 +54,19 @@ type aggregatorTxProfitabilityChecker interface { } // StateInterface gathers the methods to interact with the state. 
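The hunk just below swaps StateInterface for a StorageInterface whose methods all take a db.Txer obtained from BeginTx. A minimal sketch of the call pattern that implies, assuming Commit and Rollback on db.Txer take no arguments, as the new mocks expect; the helper name persistProof is hypothetical and not part of the patch:

package aggregator

import (
	"context"
	"fmt"

	"github.com/0xPolygon/cdk/state"
)

// persistProof is an illustrative helper: it opens a storage transaction,
// stores a generated proof, and commits, rolling back on any storage error.
func persistProof(ctx context.Context, storage StorageInterface, proof *state.Proof) error {
	tx, err := storage.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	if err := storage.AddGeneratedProof(ctx, proof, tx); err != nil {
		// Roll back so the proof table is left untouched on failure.
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("add generated proof: %w (rollback: %v)", err, rbErr)
		}
		return err
	}

	return tx.Commit()
}
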
-type StateInterface interface { - BeginStateTransaction(ctx context.Context) (pgx.Tx, error) - CheckProofContainsCompleteSequences(ctx context.Context, proof *state.Proof, dbTx pgx.Tx) (bool, error) - GetProofReadyToVerify(ctx context.Context, lastVerfiedBatchNumber uint64, dbTx pgx.Tx) (*state.Proof, error) - GetProofsToAggregate(ctx context.Context, dbTx pgx.Tx) (*state.Proof, *state.Proof, error) - AddGeneratedProof(ctx context.Context, proof *state.Proof, dbTx pgx.Tx) error - UpdateGeneratedProof(ctx context.Context, proof *state.Proof, dbTx pgx.Tx) error - DeleteGeneratedProofs(ctx context.Context, batchNumber uint64, batchNumberFinal uint64, dbTx pgx.Tx) error - DeleteUngeneratedProofs(ctx context.Context, dbTx pgx.Tx) error - CleanupGeneratedProofs(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error - CleanupLockedProofs(ctx context.Context, duration string, dbTx pgx.Tx) (int64, error) - CheckProofExistsForBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (bool, error) - AddSequence(ctx context.Context, sequence state.Sequence, dbTx pgx.Tx) error +type StorageInterface interface { + BeginTx(ctx context.Context, options *sql.TxOptions) (db.Txer, error) + CheckProofContainsCompleteSequences(ctx context.Context, proof *state.Proof, dbTx db.Txer) (bool, error) + GetProofReadyToVerify(ctx context.Context, lastVerfiedBatchNumber uint64, dbTx db.Txer) (*state.Proof, error) + GetProofsToAggregate(ctx context.Context, dbTx db.Txer) (*state.Proof, *state.Proof, error) + AddGeneratedProof(ctx context.Context, proof *state.Proof, dbTx db.Txer) error + UpdateGeneratedProof(ctx context.Context, proof *state.Proof, dbTx db.Txer) error + DeleteGeneratedProofs(ctx context.Context, batchNumber uint64, batchNumberFinal uint64, dbTx db.Txer) error + DeleteUngeneratedProofs(ctx context.Context, dbTx db.Txer) error + CleanupGeneratedProofs(ctx context.Context, batchNumber uint64, dbTx db.Txer) error + CleanupLockedProofs(ctx context.Context, duration string, dbTx db.Txer) (int64, error) + CheckProofExistsForBatch(ctx context.Context, batchNumber uint64, dbTx db.Txer) (bool, error) + AddSequence(ctx context.Context, sequence state.Sequence, dbTx db.Txer) error } // EthTxManagerClient represents the eth tx manager interface diff --git a/aggregator/mocks/mock_dbtx.go b/aggregator/mocks/mock_dbtx.go deleted file mode 100644 index f870cd57..00000000 --- a/aggregator/mocks/mock_dbtx.go +++ /dev/null @@ -1,350 +0,0 @@ -// Code generated by mockery v2.39.0. DO NOT EDIT. 
- -package mocks - -import ( - context "context" - - pgconn "github.com/jackc/pgconn" - mock "github.com/stretchr/testify/mock" - - pgx "github.com/jackc/pgx/v4" -) - -// DbTxMock is an autogenerated mock type for the Tx type -type DbTxMock struct { - mock.Mock -} - -// Begin provides a mock function with given fields: ctx -func (_m *DbTxMock) Begin(ctx context.Context) (pgx.Tx, error) { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Begin") - } - - var r0 pgx.Tx - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (pgx.Tx, error)); ok { - return rf(ctx) - } - if rf, ok := ret.Get(0).(func(context.Context) pgx.Tx); ok { - r0 = rf(ctx) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pgx.Tx) - } - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// BeginFunc provides a mock function with given fields: ctx, f -func (_m *DbTxMock) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { - ret := _m.Called(ctx, f) - - if len(ret) == 0 { - panic("no return value specified for BeginFunc") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, func(pgx.Tx) error) error); ok { - r0 = rf(ctx, f) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Commit provides a mock function with given fields: ctx -func (_m *DbTxMock) Commit(ctx context.Context) error { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Commit") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = rf(ctx) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Conn provides a mock function with given fields: -func (_m *DbTxMock) Conn() *pgx.Conn { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for Conn") - } - - var r0 *pgx.Conn - if rf, ok := ret.Get(0).(func() *pgx.Conn); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*pgx.Conn) - } - } - - return r0 -} - -// CopyFrom provides a mock function with given fields: ctx, tableName, columnNames, rowSrc -func (_m *DbTxMock) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { - ret := _m.Called(ctx, tableName, columnNames, rowSrc) - - if len(ret) == 0 { - panic("no return value specified for CopyFrom") - } - - var r0 int64 - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error)); ok { - return rf(ctx, tableName, columnNames, rowSrc) - } - if rf, ok := ret.Get(0).(func(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) int64); ok { - r0 = rf(ctx, tableName, columnNames, rowSrc) - } else { - r0 = ret.Get(0).(int64) - } - - if rf, ok := ret.Get(1).(func(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) error); ok { - r1 = rf(ctx, tableName, columnNames, rowSrc) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Exec provides a mock function with given fields: ctx, sql, arguments -func (_m *DbTxMock) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { - var _ca []interface{} - _ca = append(_ca, ctx, sql) - _ca = append(_ca, arguments...) - ret := _m.Called(_ca...) 
- - if len(ret) == 0 { - panic("no return value specified for Exec") - } - - var r0 pgconn.CommandTag - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (pgconn.CommandTag, error)); ok { - return rf(ctx, sql, arguments...) - } - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgconn.CommandTag); ok { - r0 = rf(ctx, sql, arguments...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pgconn.CommandTag) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string, ...interface{}) error); ok { - r1 = rf(ctx, sql, arguments...) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// LargeObjects provides a mock function with given fields: -func (_m *DbTxMock) LargeObjects() pgx.LargeObjects { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for LargeObjects") - } - - var r0 pgx.LargeObjects - if rf, ok := ret.Get(0).(func() pgx.LargeObjects); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(pgx.LargeObjects) - } - - return r0 -} - -// Prepare provides a mock function with given fields: ctx, name, sql -func (_m *DbTxMock) Prepare(ctx context.Context, name string, sql string) (*pgconn.StatementDescription, error) { - ret := _m.Called(ctx, name, sql) - - if len(ret) == 0 { - panic("no return value specified for Prepare") - } - - var r0 *pgconn.StatementDescription - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (*pgconn.StatementDescription, error)); ok { - return rf(ctx, name, sql) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string) *pgconn.StatementDescription); ok { - r0 = rf(ctx, name, sql) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*pgconn.StatementDescription) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, name, sql) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Query provides a mock function with given fields: ctx, sql, args -func (_m *DbTxMock) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { - var _ca []interface{} - _ca = append(_ca, ctx, sql) - _ca = append(_ca, args...) - ret := _m.Called(_ca...) - - if len(ret) == 0 { - panic("no return value specified for Query") - } - - var r0 pgx.Rows - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (pgx.Rows, error)); ok { - return rf(ctx, sql, args...) - } - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgx.Rows); ok { - r0 = rf(ctx, sql, args...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pgx.Rows) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string, ...interface{}) error); ok { - r1 = rf(ctx, sql, args...) 
- } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// QueryFunc provides a mock function with given fields: ctx, sql, args, scans, f -func (_m *DbTxMock) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - ret := _m.Called(ctx, sql, args, scans, f) - - if len(ret) == 0 { - panic("no return value specified for QueryFunc") - } - - var r0 pgconn.CommandTag - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, []interface{}, []interface{}, func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error)); ok { - return rf(ctx, sql, args, scans, f) - } - if rf, ok := ret.Get(0).(func(context.Context, string, []interface{}, []interface{}, func(pgx.QueryFuncRow) error) pgconn.CommandTag); ok { - r0 = rf(ctx, sql, args, scans, f) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pgconn.CommandTag) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string, []interface{}, []interface{}, func(pgx.QueryFuncRow) error) error); ok { - r1 = rf(ctx, sql, args, scans, f) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// QueryRow provides a mock function with given fields: ctx, sql, args -func (_m *DbTxMock) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { - var _ca []interface{} - _ca = append(_ca, ctx, sql) - _ca = append(_ca, args...) - ret := _m.Called(_ca...) - - if len(ret) == 0 { - panic("no return value specified for QueryRow") - } - - var r0 pgx.Row - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgx.Row); ok { - r0 = rf(ctx, sql, args...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pgx.Row) - } - } - - return r0 -} - -// Rollback provides a mock function with given fields: ctx -func (_m *DbTxMock) Rollback(ctx context.Context) error { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Rollback") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = rf(ctx) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// SendBatch provides a mock function with given fields: ctx, b -func (_m *DbTxMock) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { - ret := _m.Called(ctx, b) - - if len(ret) == 0 { - panic("no return value specified for SendBatch") - } - - var r0 pgx.BatchResults - if rf, ok := ret.Get(0).(func(context.Context, *pgx.Batch) pgx.BatchResults); ok { - r0 = rf(ctx, b) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pgx.BatchResults) - } - } - - return r0 -} - -// NewDbTxMock creates a new instance of DbTxMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. 
-func NewDbTxMock(t interface { - mock.TestingT - Cleanup(func()) -}) *DbTxMock { - mock := &DbTxMock{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/aggregator/mocks/mock_state.go b/aggregator/mocks/mock_storage.go similarity index 54% rename from aggregator/mocks/mock_state.go rename to aggregator/mocks/mock_storage.go index 74c9021b..405cba46 100644 --- a/aggregator/mocks/mock_state.go +++ b/aggregator/mocks/mock_storage.go @@ -5,19 +5,21 @@ package mocks import ( context "context" - pgx "github.com/jackc/pgx/v4" + db "github.com/0xPolygon/cdk/db" mock "github.com/stretchr/testify/mock" + sql "database/sql" + state "github.com/0xPolygon/cdk/state" ) -// StateInterfaceMock is an autogenerated mock type for the StateInterface type -type StateInterfaceMock struct { +// StorageInterfaceMock is an autogenerated mock type for the StorageInterface type +type StorageInterfaceMock struct { mock.Mock } // AddGeneratedProof provides a mock function with given fields: ctx, proof, dbTx -func (_m *StateInterfaceMock) AddGeneratedProof(ctx context.Context, proof *state.Proof, dbTx pgx.Tx) error { +func (_m *StorageInterfaceMock) AddGeneratedProof(ctx context.Context, proof *state.Proof, dbTx db.Txer) error { ret := _m.Called(ctx, proof, dbTx) if len(ret) == 0 { @@ -25,7 +27,7 @@ func (_m *StateInterfaceMock) AddGeneratedProof(ctx context.Context, proof *stat } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *state.Proof, pgx.Tx) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *state.Proof, db.Txer) error); ok { r0 = rf(ctx, proof, dbTx) } else { r0 = ret.Error(0) @@ -35,7 +37,7 @@ func (_m *StateInterfaceMock) AddGeneratedProof(ctx context.Context, proof *stat } // AddSequence provides a mock function with given fields: ctx, sequence, dbTx -func (_m *StateInterfaceMock) AddSequence(ctx context.Context, sequence state.Sequence, dbTx pgx.Tx) error { +func (_m *StorageInterfaceMock) AddSequence(ctx context.Context, sequence state.Sequence, dbTx db.Txer) error { ret := _m.Called(ctx, sequence, dbTx) if len(ret) == 0 { @@ -43,7 +45,7 @@ func (_m *StateInterfaceMock) AddSequence(ctx context.Context, sequence state.Se } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, state.Sequence, pgx.Tx) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, state.Sequence, db.Txer) error); ok { r0 = rf(ctx, sequence, dbTx) } else { r0 = ret.Error(0) @@ -52,29 +54,29 @@ func (_m *StateInterfaceMock) AddSequence(ctx context.Context, sequence state.Se return r0 } -// BeginStateTransaction provides a mock function with given fields: ctx -func (_m *StateInterfaceMock) BeginStateTransaction(ctx context.Context) (pgx.Tx, error) { - ret := _m.Called(ctx) +// BeginTx provides a mock function with given fields: ctx, options +func (_m *StorageInterfaceMock) BeginTx(ctx context.Context, options *sql.TxOptions) (db.Txer, error) { + ret := _m.Called(ctx, options) if len(ret) == 0 { - panic("no return value specified for BeginStateTransaction") + panic("no return value specified for BeginTx") } - var r0 pgx.Tx + var r0 db.Txer var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (pgx.Tx, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *sql.TxOptions) (db.Txer, error)); ok { + return rf(ctx, options) } - if rf, ok := ret.Get(0).(func(context.Context) pgx.Tx); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *sql.TxOptions) db.Txer); ok { + r0 = rf(ctx, options) } else { 
if ret.Get(0) != nil { - r0 = ret.Get(0).(pgx.Tx) + r0 = ret.Get(0).(db.Txer) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *sql.TxOptions) error); ok { + r1 = rf(ctx, options) } else { r1 = ret.Error(1) } @@ -83,7 +85,7 @@ func (_m *StateInterfaceMock) BeginStateTransaction(ctx context.Context) (pgx.Tx } // CheckProofContainsCompleteSequences provides a mock function with given fields: ctx, proof, dbTx -func (_m *StateInterfaceMock) CheckProofContainsCompleteSequences(ctx context.Context, proof *state.Proof, dbTx pgx.Tx) (bool, error) { +func (_m *StorageInterfaceMock) CheckProofContainsCompleteSequences(ctx context.Context, proof *state.Proof, dbTx db.Txer) (bool, error) { ret := _m.Called(ctx, proof, dbTx) if len(ret) == 0 { @@ -92,16 +94,16 @@ func (_m *StateInterfaceMock) CheckProofContainsCompleteSequences(ctx context.Co var r0 bool var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *state.Proof, pgx.Tx) (bool, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, *state.Proof, db.Txer) (bool, error)); ok { return rf(ctx, proof, dbTx) } - if rf, ok := ret.Get(0).(func(context.Context, *state.Proof, pgx.Tx) bool); ok { + if rf, ok := ret.Get(0).(func(context.Context, *state.Proof, db.Txer) bool); ok { r0 = rf(ctx, proof, dbTx) } else { r0 = ret.Get(0).(bool) } - if rf, ok := ret.Get(1).(func(context.Context, *state.Proof, pgx.Tx) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, *state.Proof, db.Txer) error); ok { r1 = rf(ctx, proof, dbTx) } else { r1 = ret.Error(1) @@ -111,7 +113,7 @@ func (_m *StateInterfaceMock) CheckProofContainsCompleteSequences(ctx context.Co } // CheckProofExistsForBatch provides a mock function with given fields: ctx, batchNumber, dbTx -func (_m *StateInterfaceMock) CheckProofExistsForBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (bool, error) { +func (_m *StorageInterfaceMock) CheckProofExistsForBatch(ctx context.Context, batchNumber uint64, dbTx db.Txer) (bool, error) { ret := _m.Called(ctx, batchNumber, dbTx) if len(ret) == 0 { @@ -120,16 +122,16 @@ func (_m *StateInterfaceMock) CheckProofExistsForBatch(ctx context.Context, batc var r0 bool var r1 error - if rf, ok := ret.Get(0).(func(context.Context, uint64, pgx.Tx) (bool, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, uint64, db.Txer) (bool, error)); ok { return rf(ctx, batchNumber, dbTx) } - if rf, ok := ret.Get(0).(func(context.Context, uint64, pgx.Tx) bool); ok { + if rf, ok := ret.Get(0).(func(context.Context, uint64, db.Txer) bool); ok { r0 = rf(ctx, batchNumber, dbTx) } else { r0 = ret.Get(0).(bool) } - if rf, ok := ret.Get(1).(func(context.Context, uint64, pgx.Tx) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, uint64, db.Txer) error); ok { r1 = rf(ctx, batchNumber, dbTx) } else { r1 = ret.Error(1) @@ -139,7 +141,7 @@ func (_m *StateInterfaceMock) CheckProofExistsForBatch(ctx context.Context, batc } // CleanupGeneratedProofs provides a mock function with given fields: ctx, batchNumber, dbTx -func (_m *StateInterfaceMock) CleanupGeneratedProofs(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error { +func (_m *StorageInterfaceMock) CleanupGeneratedProofs(ctx context.Context, batchNumber uint64, dbTx db.Txer) error { ret := _m.Called(ctx, batchNumber, dbTx) if len(ret) == 0 { @@ -147,7 +149,7 @@ func (_m *StateInterfaceMock) CleanupGeneratedProofs(ctx context.Context, batchN } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, 
uint64, pgx.Tx) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, uint64, db.Txer) error); ok { r0 = rf(ctx, batchNumber, dbTx) } else { r0 = ret.Error(0) @@ -157,7 +159,7 @@ func (_m *StateInterfaceMock) CleanupGeneratedProofs(ctx context.Context, batchN } // CleanupLockedProofs provides a mock function with given fields: ctx, duration, dbTx -func (_m *StateInterfaceMock) CleanupLockedProofs(ctx context.Context, duration string, dbTx pgx.Tx) (int64, error) { +func (_m *StorageInterfaceMock) CleanupLockedProofs(ctx context.Context, duration string, dbTx db.Txer) (int64, error) { ret := _m.Called(ctx, duration, dbTx) if len(ret) == 0 { @@ -166,16 +168,16 @@ func (_m *StateInterfaceMock) CleanupLockedProofs(ctx context.Context, duration var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, pgx.Tx) (int64, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, db.Txer) (int64, error)); ok { return rf(ctx, duration, dbTx) } - if rf, ok := ret.Get(0).(func(context.Context, string, pgx.Tx) int64); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, db.Txer) int64); ok { r0 = rf(ctx, duration, dbTx) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func(context.Context, string, pgx.Tx) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, string, db.Txer) error); ok { r1 = rf(ctx, duration, dbTx) } else { r1 = ret.Error(1) @@ -185,7 +187,7 @@ func (_m *StateInterfaceMock) CleanupLockedProofs(ctx context.Context, duration } // DeleteGeneratedProofs provides a mock function with given fields: ctx, batchNumber, batchNumberFinal, dbTx -func (_m *StateInterfaceMock) DeleteGeneratedProofs(ctx context.Context, batchNumber uint64, batchNumberFinal uint64, dbTx pgx.Tx) error { +func (_m *StorageInterfaceMock) DeleteGeneratedProofs(ctx context.Context, batchNumber uint64, batchNumberFinal uint64, dbTx db.Txer) error { ret := _m.Called(ctx, batchNumber, batchNumberFinal, dbTx) if len(ret) == 0 { @@ -193,7 +195,7 @@ func (_m *StateInterfaceMock) DeleteGeneratedProofs(ctx context.Context, batchNu } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, uint64, uint64, pgx.Tx) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, uint64, uint64, db.Txer) error); ok { r0 = rf(ctx, batchNumber, batchNumberFinal, dbTx) } else { r0 = ret.Error(0) @@ -203,7 +205,7 @@ func (_m *StateInterfaceMock) DeleteGeneratedProofs(ctx context.Context, batchNu } // DeleteUngeneratedProofs provides a mock function with given fields: ctx, dbTx -func (_m *StateInterfaceMock) DeleteUngeneratedProofs(ctx context.Context, dbTx pgx.Tx) error { +func (_m *StorageInterfaceMock) DeleteUngeneratedProofs(ctx context.Context, dbTx db.Txer) error { ret := _m.Called(ctx, dbTx) if len(ret) == 0 { @@ -211,7 +213,7 @@ func (_m *StateInterfaceMock) DeleteUngeneratedProofs(ctx context.Context, dbTx } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, pgx.Tx) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, db.Txer) error); ok { r0 = rf(ctx, dbTx) } else { r0 = ret.Error(0) @@ -221,7 +223,7 @@ func (_m *StateInterfaceMock) DeleteUngeneratedProofs(ctx context.Context, dbTx } // GetProofReadyToVerify provides a mock function with given fields: ctx, lastVerfiedBatchNumber, dbTx -func (_m *StateInterfaceMock) GetProofReadyToVerify(ctx context.Context, lastVerfiedBatchNumber uint64, dbTx pgx.Tx) (*state.Proof, error) { +func (_m *StorageInterfaceMock) GetProofReadyToVerify(ctx context.Context, lastVerfiedBatchNumber uint64, dbTx 
db.Txer) (*state.Proof, error) { ret := _m.Called(ctx, lastVerfiedBatchNumber, dbTx) if len(ret) == 0 { @@ -230,10 +232,10 @@ func (_m *StateInterfaceMock) GetProofReadyToVerify(ctx context.Context, lastVer var r0 *state.Proof var r1 error - if rf, ok := ret.Get(0).(func(context.Context, uint64, pgx.Tx) (*state.Proof, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, uint64, db.Txer) (*state.Proof, error)); ok { return rf(ctx, lastVerfiedBatchNumber, dbTx) } - if rf, ok := ret.Get(0).(func(context.Context, uint64, pgx.Tx) *state.Proof); ok { + if rf, ok := ret.Get(0).(func(context.Context, uint64, db.Txer) *state.Proof); ok { r0 = rf(ctx, lastVerfiedBatchNumber, dbTx) } else { if ret.Get(0) != nil { @@ -241,7 +243,7 @@ func (_m *StateInterfaceMock) GetProofReadyToVerify(ctx context.Context, lastVer } } - if rf, ok := ret.Get(1).(func(context.Context, uint64, pgx.Tx) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, uint64, db.Txer) error); ok { r1 = rf(ctx, lastVerfiedBatchNumber, dbTx) } else { r1 = ret.Error(1) @@ -251,7 +253,7 @@ func (_m *StateInterfaceMock) GetProofReadyToVerify(ctx context.Context, lastVer } // GetProofsToAggregate provides a mock function with given fields: ctx, dbTx -func (_m *StateInterfaceMock) GetProofsToAggregate(ctx context.Context, dbTx pgx.Tx) (*state.Proof, *state.Proof, error) { +func (_m *StorageInterfaceMock) GetProofsToAggregate(ctx context.Context, dbTx db.Txer) (*state.Proof, *state.Proof, error) { ret := _m.Called(ctx, dbTx) if len(ret) == 0 { @@ -261,10 +263,10 @@ func (_m *StateInterfaceMock) GetProofsToAggregate(ctx context.Context, dbTx pgx var r0 *state.Proof var r1 *state.Proof var r2 error - if rf, ok := ret.Get(0).(func(context.Context, pgx.Tx) (*state.Proof, *state.Proof, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, db.Txer) (*state.Proof, *state.Proof, error)); ok { return rf(ctx, dbTx) } - if rf, ok := ret.Get(0).(func(context.Context, pgx.Tx) *state.Proof); ok { + if rf, ok := ret.Get(0).(func(context.Context, db.Txer) *state.Proof); ok { r0 = rf(ctx, dbTx) } else { if ret.Get(0) != nil { @@ -272,7 +274,7 @@ func (_m *StateInterfaceMock) GetProofsToAggregate(ctx context.Context, dbTx pgx } } - if rf, ok := ret.Get(1).(func(context.Context, pgx.Tx) *state.Proof); ok { + if rf, ok := ret.Get(1).(func(context.Context, db.Txer) *state.Proof); ok { r1 = rf(ctx, dbTx) } else { if ret.Get(1) != nil { @@ -280,7 +282,7 @@ func (_m *StateInterfaceMock) GetProofsToAggregate(ctx context.Context, dbTx pgx } } - if rf, ok := ret.Get(2).(func(context.Context, pgx.Tx) error); ok { + if rf, ok := ret.Get(2).(func(context.Context, db.Txer) error); ok { r2 = rf(ctx, dbTx) } else { r2 = ret.Error(2) @@ -290,7 +292,7 @@ func (_m *StateInterfaceMock) GetProofsToAggregate(ctx context.Context, dbTx pgx } // UpdateGeneratedProof provides a mock function with given fields: ctx, proof, dbTx -func (_m *StateInterfaceMock) UpdateGeneratedProof(ctx context.Context, proof *state.Proof, dbTx pgx.Tx) error { +func (_m *StorageInterfaceMock) UpdateGeneratedProof(ctx context.Context, proof *state.Proof, dbTx db.Txer) error { ret := _m.Called(ctx, proof, dbTx) if len(ret) == 0 { @@ -298,7 +300,7 @@ func (_m *StateInterfaceMock) UpdateGeneratedProof(ctx context.Context, proof *s } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *state.Proof, pgx.Tx) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *state.Proof, db.Txer) error); ok { r0 = rf(ctx, proof, dbTx) } else { r0 = ret.Error(0) @@ -307,13 +309,13 @@ 
func (_m *StateInterfaceMock) UpdateGeneratedProof(ctx context.Context, proof *s return r0 } -// NewStateInterfaceMock creates a new instance of StateInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// NewStorageInterfaceMock creates a new instance of StorageInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. -func NewStateInterfaceMock(t interface { +func NewStorageInterfaceMock(t interface { mock.TestingT Cleanup(func()) -}) *StateInterfaceMock { - mock := &StateInterfaceMock{} +}) *StorageInterfaceMock { + mock := &StorageInterfaceMock{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) diff --git a/aggregator/mocks/mock_txer.go b/aggregator/mocks/mock_txer.go new file mode 100644 index 00000000..1de07124 --- /dev/null +++ b/aggregator/mocks/mock_txer.go @@ -0,0 +1,163 @@ +// Code generated by mockery v2.39.0. DO NOT EDIT. + +package mocks + +import ( + sql "database/sql" + + mock "github.com/stretchr/testify/mock" +) + +// TxerMock is an autogenerated mock type for the Txer type +type TxerMock struct { + mock.Mock +} + +// AddCommitCallback provides a mock function with given fields: cb +func (_m *TxerMock) AddCommitCallback(cb func()) { + _m.Called(cb) +} + +// AddRollbackCallback provides a mock function with given fields: cb +func (_m *TxerMock) AddRollbackCallback(cb func()) { + _m.Called(cb) +} + +// Commit provides a mock function with given fields: +func (_m *TxerMock) Commit() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Commit") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Exec provides a mock function with given fields: query, args +func (_m *TxerMock) Exec(query string, args ...interface{}) (sql.Result, error) { + var _ca []interface{} + _ca = append(_ca, query) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Exec") + } + + var r0 sql.Result + var r1 error + if rf, ok := ret.Get(0).(func(string, ...interface{}) (sql.Result, error)); ok { + return rf(query, args...) + } + if rf, ok := ret.Get(0).(func(string, ...interface{}) sql.Result); ok { + r0 = rf(query, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(sql.Result) + } + } + + if rf, ok := ret.Get(1).(func(string, ...interface{}) error); ok { + r1 = rf(query, args...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query provides a mock function with given fields: query, args +func (_m *TxerMock) Query(query string, args ...interface{}) (*sql.Rows, error) { + var _ca []interface{} + _ca = append(_ca, query) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Query") + } + + var r0 *sql.Rows + var r1 error + if rf, ok := ret.Get(0).(func(string, ...interface{}) (*sql.Rows, error)); ok { + return rf(query, args...) + } + if rf, ok := ret.Get(0).(func(string, ...interface{}) *sql.Rows); ok { + r0 = rf(query, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sql.Rows) + } + } + + if rf, ok := ret.Get(1).(func(string, ...interface{}) error); ok { + r1 = rf(query, args...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// QueryRow provides a mock function with given fields: query, args +func (_m *TxerMock) QueryRow(query string, args ...interface{}) *sql.Row { + var _ca []interface{} + _ca = append(_ca, query) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for QueryRow") + } + + var r0 *sql.Row + if rf, ok := ret.Get(0).(func(string, ...interface{}) *sql.Row); ok { + r0 = rf(query, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sql.Row) + } + } + + return r0 +} + +// Rollback provides a mock function with given fields: +func (_m *TxerMock) Rollback() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Rollback") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewTxerMock creates a new instance of TxerMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTxerMock(t interface { + mock.TestingT + Cleanup(func()) +}) *TxerMock { + mock := &TxerMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/aggregator/profitabilitychecker.go b/aggregator/profitabilitychecker.go deleted file mode 100644 index dc91a21e..00000000 --- a/aggregator/profitabilitychecker.go +++ /dev/null @@ -1,92 +0,0 @@ -package aggregator - -import ( - "context" - "math/big" - "time" -) - -// TxProfitabilityCheckerType checks profitability of batch validation -type TxProfitabilityCheckerType string - -const ( - // ProfitabilityBase checks pol collateral with min reward - ProfitabilityBase = "base" - // ProfitabilityAcceptAll validate batch anyway and don't check anything - ProfitabilityAcceptAll = "acceptall" -) - -// TxProfitabilityCheckerBase checks pol collateral with min reward -type TxProfitabilityCheckerBase struct { - State StateInterface - IntervalAfterWhichBatchSentAnyway time.Duration - MinReward *big.Int -} - -// NewTxProfitabilityCheckerBase init base tx profitability checker -func NewTxProfitabilityCheckerBase( - state StateInterface, interval time.Duration, minReward *big.Int, -) *TxProfitabilityCheckerBase { - return &TxProfitabilityCheckerBase{ - State: state, - IntervalAfterWhichBatchSentAnyway: interval, - MinReward: minReward, - } -} - -// IsProfitable checks pol collateral with min reward -func (pc *TxProfitabilityCheckerBase) IsProfitable(ctx context.Context, polCollateral *big.Int) (bool, error) { - // if pc.IntervalAfterWhichBatchSentAnyway != 0 { - // ok, err := isConsolidatedBatchAppeared(ctx, pc.State, pc.IntervalAfterWhichBatchSentAnyway) - // if err != nil { - // return false, err - // } - // if ok { - // return true, nil - // } - // } - return polCollateral.Cmp(pc.MinReward) >= 0, nil -} - -// TxProfitabilityCheckerAcceptAll validate batch anyway and don't check anything -type TxProfitabilityCheckerAcceptAll struct { - State StateInterface - IntervalAfterWhichBatchSentAnyway time.Duration -} - -// NewTxProfitabilityCheckerAcceptAll init tx profitability checker that accept all txs -func NewTxProfitabilityCheckerAcceptAll(state StateInterface, interval time.Duration) *TxProfitabilityCheckerAcceptAll { - return &TxProfitabilityCheckerAcceptAll{ - State: state, - IntervalAfterWhichBatchSentAnyway: interval, - } -} - -// IsProfitable validate batch anyway and don't 
check anything -func (pc *TxProfitabilityCheckerAcceptAll) IsProfitable(ctx context.Context, polCollateral *big.Int) (bool, error) { - // if pc.IntervalAfterWhichBatchSentAnyway != 0 { - // ok, err := isConsolidatedBatchAppeared(ctx, pc.State, pc.IntervalAfterWhichBatchSentAnyway) - // if err != nil { - // return false, err - // } - // if ok { - // return true, nil - // } - // } - return true, nil -} - -// TODO: now it's impossible to check, when batch got consolidated, bcs it's not saved -// func isConsolidatedBatchAppeared(ctx context.Context, state StateInterface, -// intervalAfterWhichBatchConsolidatedAnyway time.Duration) (bool, error) { -// batch, err := state.GetLastVerifiedBatch(ctx, nil) -// if err != nil { -// return false, fmt.Errorf("failed to get last verified batch, err: %v", err) -// } -// interval := intervalAfterWhichBatchConsolidatedAnyway * time.Minute -// if batch..Before(time.Now().Add(-interval)) { -// return true, nil -// } -// -// return false, err -// } diff --git a/cmd/run.go b/cmd/run.go index 727533e8..cff1188b 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -34,15 +34,12 @@ import ( "github.com/0xPolygon/cdk/rpc" "github.com/0xPolygon/cdk/sequencesender" "github.com/0xPolygon/cdk/sequencesender/txbuilder" - "github.com/0xPolygon/cdk/state" - "github.com/0xPolygon/cdk/state/pgstatestorage" "github.com/0xPolygon/cdk/translator" ethtxman "github.com/0xPolygon/zkevm-ethtx-manager/etherman" "github.com/0xPolygon/zkevm-ethtx-manager/etherman/etherscan" "github.com/0xPolygon/zkevm-ethtx-manager/ethtxmanager" ethtxlog "github.com/0xPolygon/zkevm-ethtx-manager/log" "github.com/ethereum/go-ethereum/ethclient" - "github.com/jackc/pgx/v4/pgxpool" "github.com/urfave/cli/v2" ) @@ -180,17 +177,8 @@ func createAggregator(ctx context.Context, c config.Config, runMigrations bool) logger := log.WithFields("module", cdkcommon.AGGREGATOR) // Migrations if runMigrations { - logger.Infof( - "Running DB migrations host: %s:%s db:%s user:%s", - c.Aggregator.DB.Host, c.Aggregator.DB.Port, c.Aggregator.DB.Name, c.Aggregator.DB.User, - ) - runAggregatorMigrations(c.Aggregator.DB) - } - - // DB - stateSQLDB, err := db.NewSQLDB(logger, c.Aggregator.DB) - if err != nil { - logger.Fatal(err) + logger.Infof("Running DB migrations. 
File %s", c.Aggregator.DBPath) + runAggregatorMigrations(c.Aggregator.DBPath) } etherman, err := newEtherman(c) @@ -209,9 +197,7 @@ func createAggregator(ctx context.Context, c config.Config, runMigrations bool) c.Aggregator.ChainID = l2ChainID } - st := newState(&c, c.Aggregator.ChainID, stateSQLDB) - - aggregator, err := aggregator.New(ctx, c.Aggregator, logger, st, etherman) + aggregator, err := aggregator.New(ctx, c.Aggregator, logger, etherman) if err != nil { logger.Fatal(err) } @@ -432,13 +418,13 @@ func newDataAvailability(c config.Config, etherman *etherman.Client) (*dataavail return dataavailability.New(daBackend) } -func runAggregatorMigrations(c db.Config) { - runMigrations(c, db.AggregatorMigrationName) +func runAggregatorMigrations(dbPath string) { + runMigrations(dbPath, db.AggregatorMigrationName) } -func runMigrations(c db.Config, name string) { +func runMigrations(dbPath string, name string) { log.Infof("running migrations for %v", name) - err := db.RunMigrationsUp(c, name) + err := db.RunMigrationsUp(dbPath, name) if err != nil { log.Fatal(err) } @@ -484,19 +470,6 @@ func waitSignal(cancelFuncs []context.CancelFunc) { } } -func newState(c *config.Config, l2ChainID uint64, sqlDB *pgxpool.Pool) *state.State { - stateCfg := state.Config{ - DB: c.Aggregator.DB, - ChainID: l2ChainID, - } - - stateDB := pgstatestorage.NewPostgresStorage(stateCfg, sqlDB) - - st := state.NewState(stateCfg, stateDB) - - return st -} - func newReorgDetector( cfg *reorgdetector.Config, client *ethclient.Client, diff --git a/config/default.go b/config/default.go index 61b099c8..316cfb76 100644 --- a/config/default.go +++ b/config/default.go @@ -137,17 +137,10 @@ SettlementBackend = "l1" AggLayerTxTimeout = "5m" AggLayerURL = "{{AggLayerURL}}" SyncModeOnlyEnabled = false +DBPath = "{{PathRWData}}/aggregator_db.sqlite" [Aggregator.SequencerPrivateKey] Path = "{{SequencerPrivateKeyPath}}" Password = "{{SequencerPrivateKeyPassword}}" - [Aggregator.DB] - Name = "aggregator_db" - User = "aggregator_user" - Password = "aggregator_password" - Host = "cdk-aggregator-db" - Port = "5432" - EnableLog = false - MaxConns = 200 [Aggregator.Log] Environment ="{{Log.Environment}}" # "production" or "development" Level = "{{Log.Level}}" diff --git a/scripts/local_config b/scripts/local_config index 5830b6e6..ca25dbbb 100755 --- a/scripts/local_config +++ b/scripts/local_config @@ -194,9 +194,6 @@ function export_values_of_cdk_node_config(){ if [ $? -ne 0 ]; then export_key_from_toml_file zkevm_l2_agglayer_address $_CDK_CONFIG_FILE "." SenderProofToL1Addr fi - export_key_from_toml_file_or_fatal aggregator_db_name $_CDK_CONFIG_FILE Aggregator.DB Name - export_key_from_toml_file_or_fatal aggregator_db_user $_CDK_CONFIG_FILE Aggregator.DB User - export_key_from_toml_file_or_fatal aggregator_db_password $_CDK_CONFIG_FILE Aggregator.DB Password export_obj_key_from_toml_file zkevm_l2_aggregator_keystore_password $_CDK_CONFIG_FILE Aggregator.EthTxManager PrivateKeys Password if [ $? -ne 0 ]; then export_key_from_toml_file zkevm_l2_aggregator_keystore_password $_CDK_CONFIG_FILE "." 
AggregatorPrivateKeyPassword diff --git a/state/config.go b/state/config.go deleted file mode 100644 index e5a65e8b..00000000 --- a/state/config.go +++ /dev/null @@ -1,13 +0,0 @@ -package state - -import ( - "github.com/0xPolygon/cdk/aggregator/db" -) - -// Config is state config -type Config struct { - // ChainID is the L2 ChainID provided by the Network Config - ChainID uint64 - // DB is the database configuration - DB db.Config `mapstructure:"DB"` -} diff --git a/state/interfaces.go b/state/interfaces.go deleted file mode 100644 index fc4eb495..00000000 --- a/state/interfaces.go +++ /dev/null @@ -1,26 +0,0 @@ -package state - -import ( - "context" - - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" -) - -type storage interface { - Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) - Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) - QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row - Begin(ctx context.Context) (pgx.Tx, error) - AddSequence(ctx context.Context, sequence Sequence, dbTx pgx.Tx) error - CheckProofContainsCompleteSequences(ctx context.Context, proof *Proof, dbTx pgx.Tx) (bool, error) - GetProofReadyToVerify(ctx context.Context, lastVerfiedBatchNumber uint64, dbTx pgx.Tx) (*Proof, error) - GetProofsToAggregate(ctx context.Context, dbTx pgx.Tx) (*Proof, *Proof, error) - AddGeneratedProof(ctx context.Context, proof *Proof, dbTx pgx.Tx) error - UpdateGeneratedProof(ctx context.Context, proof *Proof, dbTx pgx.Tx) error - DeleteGeneratedProofs(ctx context.Context, batchNumber uint64, batchNumberFinal uint64, dbTx pgx.Tx) error - DeleteUngeneratedProofs(ctx context.Context, dbTx pgx.Tx) error - CleanupGeneratedProofs(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error - CleanupLockedProofs(ctx context.Context, duration string, dbTx pgx.Tx) (int64, error) - CheckProofExistsForBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (bool, error) -} diff --git a/state/pgstatestorage/interfaces.go b/state/pgstatestorage/interfaces.go deleted file mode 100644 index e5f7402b..00000000 --- a/state/pgstatestorage/interfaces.go +++ /dev/null @@ -1,14 +0,0 @@ -package pgstatestorage - -import ( - "context" - - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" -) - -type ExecQuerier interface { - Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) - Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) - QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row -} diff --git a/state/pgstatestorage/pgstatestorage.go b/state/pgstatestorage/pgstatestorage.go deleted file mode 100644 index 7e294c6b..00000000 --- a/state/pgstatestorage/pgstatestorage.go +++ /dev/null @@ -1,29 +0,0 @@ -package pgstatestorage - -import ( - "github.com/0xPolygon/cdk/state" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgxpool" -) - -// PostgresStorage implements the Storage interface -type PostgresStorage struct { - cfg state.Config - *pgxpool.Pool -} - -// NewPostgresStorage creates a new StateDB -func NewPostgresStorage(cfg state.Config, db *pgxpool.Pool) *PostgresStorage { - return &PostgresStorage{ - cfg, - db, - } -} - -// getExecQuerier determines which execQuerier to use, dbTx or the main pgxpool -func (p *PostgresStorage) getExecQuerier(dbTx pgx.Tx) ExecQuerier { - if dbTx != nil { - return dbTx - } - return p -} diff --git a/state/pgstatestorage/proof.go b/state/pgstatestorage/proof.go 
deleted file mode 100644 index fa32fc99..00000000 --- a/state/pgstatestorage/proof.go +++ /dev/null @@ -1,266 +0,0 @@ -package pgstatestorage - -import ( - "context" - "errors" - "fmt" - "time" - - "github.com/0xPolygon/cdk/state" - "github.com/jackc/pgx/v4" -) - -// CheckProofExistsForBatch checks if the batch is already included in any proof -func (p *PostgresStorage) CheckProofExistsForBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (bool, error) { - const checkProofExistsForBatchSQL = ` - SELECT EXISTS (SELECT 1 FROM aggregator.proof p WHERE $1 >= p.batch_num AND $1 <= p.batch_num_final) - ` - e := p.getExecQuerier(dbTx) - var exists bool - err := e.QueryRow(ctx, checkProofExistsForBatchSQL, batchNumber).Scan(&exists) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { - return exists, err - } - return exists, nil -} - -// CheckProofContainsCompleteSequences checks if a recursive proof contains complete sequences -func (p *PostgresStorage) CheckProofContainsCompleteSequences( - ctx context.Context, proof *state.Proof, dbTx pgx.Tx, -) (bool, error) { - const getProofContainsCompleteSequencesSQL = ` - SELECT EXISTS (SELECT 1 FROM aggregator.sequence s1 WHERE s1.from_batch_num = $1) AND - EXISTS (SELECT 1 FROM aggregator.sequence s2 WHERE s2.to_batch_num = $2) - ` - e := p.getExecQuerier(dbTx) - var exists bool - err := e.QueryRow(ctx, getProofContainsCompleteSequencesSQL, proof.BatchNumber, proof.BatchNumberFinal).Scan(&exists) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { - return exists, err - } - return exists, nil -} - -// GetProofReadyToVerify return the proof that is ready to verify -func (p *PostgresStorage) GetProofReadyToVerify( - ctx context.Context, lastVerfiedBatchNumber uint64, dbTx pgx.Tx, -) (*state.Proof, error) { - const getProofReadyToVerifySQL = ` - SELECT - p.batch_num, - p.batch_num_final, - p.proof, - p.proof_id, - p.input_prover, - p.prover, - p.prover_id, - p.generating_since, - p.created_at, - p.updated_at - FROM aggregator.proof p - WHERE batch_num = $1 AND generating_since IS NULL AND - EXISTS (SELECT 1 FROM aggregator.sequence s1 WHERE s1.from_batch_num = p.batch_num) AND - EXISTS (SELECT 1 FROM aggregator.sequence s2 WHERE s2.to_batch_num = p.batch_num_final) - ` - - var proof = &state.Proof{} - - e := p.getExecQuerier(dbTx) - row := e.QueryRow(ctx, getProofReadyToVerifySQL, lastVerfiedBatchNumber+1) - err := row.Scan( - &proof.BatchNumber, &proof.BatchNumberFinal, &proof.Proof, &proof.ProofID, - &proof.InputProver, &proof.Prover, &proof.ProverID, &proof.GeneratingSince, - &proof.CreatedAt, &proof.UpdatedAt, - ) - - if errors.Is(err, pgx.ErrNoRows) { - return nil, state.ErrNotFound - } else if err != nil { - return nil, err - } - - return proof, err -} - -// GetProofsToAggregate return the next to proof that it is possible to aggregate -func (p *PostgresStorage) GetProofsToAggregate(ctx context.Context, dbTx pgx.Tx) (*state.Proof, *state.Proof, error) { - var ( - proof1 = &state.Proof{} - proof2 = &state.Proof{} - ) - - // TODO: add comments to explain the query - const getProofsToAggregateSQL = ` - SELECT - p1.batch_num as p1_batch_num, - p1.batch_num_final as p1_batch_num_final, - p1.proof as p1_proof, - p1.proof_id as p1_proof_id, - p1.input_prover as p1_input_prover, - p1.prover as p1_prover, - p1.prover_id as p1_prover_id, - p1.generating_since as p1_generating_since, - p1.created_at as p1_created_at, - p1.updated_at as p1_updated_at, - p2.batch_num as p2_batch_num, - p2.batch_num_final as p2_batch_num_final, - p2.proof as p2_proof, - 
p2.proof_id as p2_proof_id, - p2.input_prover as p2_input_prover, - p2.prover as p2_prover, - p2.prover_id as p2_prover_id, - p2.generating_since as p2_generating_since, - p2.created_at as p2_created_at, - p2.updated_at as p2_updated_at - FROM aggregator.proof p1 INNER JOIN aggregator.proof p2 ON p1.batch_num_final = p2.batch_num - 1 - WHERE p1.generating_since IS NULL AND p2.generating_since IS NULL AND - p1.proof IS NOT NULL AND p2.proof IS NOT NULL AND - ( - EXISTS ( - SELECT 1 FROM aggregator.sequence s - WHERE p1.batch_num >= s.from_batch_num AND p1.batch_num <= s.to_batch_num AND - p1.batch_num_final >= s.from_batch_num AND p1.batch_num_final <= s.to_batch_num AND - p2.batch_num >= s.from_batch_num AND p2.batch_num <= s.to_batch_num AND - p2.batch_num_final >= s.from_batch_num AND p2.batch_num_final <= s.to_batch_num - ) - OR - ( - EXISTS ( SELECT 1 FROM aggregator.sequence s WHERE p1.batch_num = s.from_batch_num) AND - EXISTS ( SELECT 1 FROM aggregator.sequence s WHERE p1.batch_num_final = s.to_batch_num) AND - EXISTS ( SELECT 1 FROM aggregator.sequence s WHERE p2.batch_num = s.from_batch_num) AND - EXISTS ( SELECT 1 FROM aggregator.sequence s WHERE p2.batch_num_final = s.to_batch_num) - ) - ) - ORDER BY p1.batch_num ASC - LIMIT 1 - ` - - e := p.getExecQuerier(dbTx) - row := e.QueryRow(ctx, getProofsToAggregateSQL) - err := row.Scan( - &proof1.BatchNumber, &proof1.BatchNumberFinal, &proof1.Proof, &proof1.ProofID, - &proof1.InputProver, &proof1.Prover, &proof1.ProverID, &proof1.GeneratingSince, - &proof1.CreatedAt, &proof1.UpdatedAt, - &proof2.BatchNumber, &proof2.BatchNumberFinal, &proof2.Proof, &proof2.ProofID, - &proof2.InputProver, &proof2.Prover, &proof2.ProverID, &proof2.GeneratingSince, - &proof2.CreatedAt, &proof2.UpdatedAt, - ) - - if errors.Is(err, pgx.ErrNoRows) { - return nil, nil, state.ErrNotFound - } else if err != nil { - return nil, nil, err - } - - return proof1, proof2, err -} - -// AddGeneratedProof adds a generated proof to the storage -func (p *PostgresStorage) AddGeneratedProof(ctx context.Context, proof *state.Proof, dbTx pgx.Tx) error { - const addGeneratedProofSQL = ` - INSERT INTO aggregator.proof ( - batch_num, batch_num_final, proof, proof_id, input_prover, prover, - prover_id, generating_since, created_at, updated_at - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 - ) - ` - e := p.getExecQuerier(dbTx) - now := time.Now().UTC().Round(time.Microsecond) - _, err := e.Exec( - ctx, addGeneratedProofSQL, proof.BatchNumber, proof.BatchNumberFinal, proof.Proof, proof.ProofID, - proof.InputProver, proof.Prover, proof.ProverID, proof.GeneratingSince, now, now, - ) - return err -} - -// UpdateGeneratedProof updates a generated proof in the storage -func (p *PostgresStorage) UpdateGeneratedProof(ctx context.Context, proof *state.Proof, dbTx pgx.Tx) error { - const addGeneratedProofSQL = ` - UPDATE aggregator.proof - SET proof = $3, - proof_id = $4, - input_prover = $5, - prover = $6, - prover_id = $7, - generating_since = $8, - updated_at = $9 - WHERE batch_num = $1 - AND batch_num_final = $2 - ` - e := p.getExecQuerier(dbTx) - now := time.Now().UTC().Round(time.Microsecond) - _, err := e.Exec( - ctx, addGeneratedProofSQL, proof.BatchNumber, proof.BatchNumberFinal, proof.Proof, proof.ProofID, - proof.InputProver, proof.Prover, proof.ProverID, proof.GeneratingSince, now, - ) - return err -} - -// DeleteGeneratedProofs deletes from the storage the generated proofs falling -// inside the batch numbers range. 
-func (p *PostgresStorage) DeleteGeneratedProofs( - ctx context.Context, batchNumber uint64, batchNumberFinal uint64, dbTx pgx.Tx, -) error { - const deleteGeneratedProofSQL = "DELETE FROM aggregator.proof WHERE batch_num >= $1 AND batch_num_final <= $2" - e := p.getExecQuerier(dbTx) - _, err := e.Exec(ctx, deleteGeneratedProofSQL, batchNumber, batchNumberFinal) - return err -} - -// CleanupGeneratedProofs deletes from the storage the generated proofs up to -// the specified batch number included. -func (p *PostgresStorage) CleanupGeneratedProofs(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error { - const deleteGeneratedProofSQL = "DELETE FROM aggregator.proof WHERE batch_num_final <= $1" - e := p.getExecQuerier(dbTx) - _, err := e.Exec(ctx, deleteGeneratedProofSQL, batchNumber) - return err -} - -// CleanupLockedProofs deletes from the storage the proofs locked in generating -// state for more than the provided threshold. -func (p *PostgresStorage) CleanupLockedProofs(ctx context.Context, duration string, dbTx pgx.Tx) (int64, error) { - interval, err := toPostgresInterval(duration) - if err != nil { - return 0, err - } - sql := fmt.Sprintf("DELETE FROM aggregator.proof WHERE generating_since < (NOW() - interval '%s')", interval) - e := p.getExecQuerier(dbTx) - ct, err := e.Exec(ctx, sql) - if err != nil { - return 0, err - } - return ct.RowsAffected(), nil -} - -// DeleteUngeneratedProofs deletes ungenerated proofs. -// This method is meant to be use during aggregator boot-up sequence -func (p *PostgresStorage) DeleteUngeneratedProofs(ctx context.Context, dbTx pgx.Tx) error { - const deleteUngeneratedProofsSQL = "DELETE FROM aggregator.proof WHERE generating_since IS NOT NULL" - e := p.getExecQuerier(dbTx) - _, err := e.Exec(ctx, deleteUngeneratedProofsSQL) - return err -} - -func toPostgresInterval(duration string) (string, error) { - unit := duration[len(duration)-1] - var pgUnit string - - switch unit { - case 's': - pgUnit = "second" - case 'm': - pgUnit = "minute" - case 'h': - pgUnit = "hour" - default: - return "", state.ErrUnsupportedDuration - } - - isMoreThanOne := duration[0] != '1' || len(duration) > 2 //nolint:mnd - if isMoreThanOne { - pgUnit += "s" - } - - return fmt.Sprintf("%s %s", duration[:len(duration)-1], pgUnit), nil -} diff --git a/state/pgstatestorage/sequence.go b/state/pgstatestorage/sequence.go deleted file mode 100644 index 7d5be9fb..00000000 --- a/state/pgstatestorage/sequence.go +++ /dev/null @@ -1,21 +0,0 @@ -package pgstatestorage - -import ( - "context" - - "github.com/0xPolygon/cdk/state" - "github.com/jackc/pgx/v4" -) - -// AddSequence stores the sequence information to allow the aggregator verify sequences. 
-func (p *PostgresStorage) AddSequence(ctx context.Context, sequence state.Sequence, dbTx pgx.Tx) error { - const addSequenceSQL = ` - INSERT INTO aggregator.sequence (from_batch_num, to_batch_num) - VALUES($1, $2) - ON CONFLICT (from_batch_num) DO UPDATE SET to_batch_num = $2 - ` - - e := p.getExecQuerier(dbTx) - _, err := e.Exec(ctx, addSequenceSQL, sequence.FromBatchNumber, sequence.ToBatchNumber) - return err -} diff --git a/state/state.go b/state/state.go deleted file mode 100644 index c9235ce4..00000000 --- a/state/state.go +++ /dev/null @@ -1,40 +0,0 @@ -package state - -import ( - "context" - - "github.com/ethereum/go-ethereum/common" - "github.com/jackc/pgx/v4" -) - -var ( - // ZeroHash is the hash 0x0000000000000000000000000000000000000000000000000000000000000000 - ZeroHash = common.Hash{} - // ZeroAddress is the address 0x0000000000000000000000000000000000000000 - ZeroAddress = common.Address{} -) - -// State is an implementation of the state -type State struct { - cfg Config - storage -} - -// NewState creates a new State -func NewState(cfg Config, storage storage) *State { - state := &State{ - cfg: cfg, - storage: storage, - } - - return state -} - -// BeginStateTransaction starts a state transaction -func (s *State) BeginStateTransaction(ctx context.Context) (pgx.Tx, error) { - tx, err := s.Begin(ctx) - if err != nil { - return nil, err - } - return tx, nil -} diff --git a/test/Makefile b/test/Makefile index 49c22f95..d4e3c274 100644 --- a/test/Makefile +++ b/test/Makefile @@ -54,10 +54,10 @@ generate-mocks-sync: ## Generates mocks for sync, using mockery tool generate-mocks-aggregator: ## Generates mocks for aggregator, using mockery tool export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=ProverInterface --dir=../aggregator --output=../aggregator/mocks --outpkg=mocks --structname=ProverInterfaceMock --filename=mock_prover.go export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=Etherman --dir=../aggregator --output=../aggregator/mocks --outpkg=mocks --structname=EthermanMock --filename=mock_etherman.go - export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=StateInterface --dir=../aggregator --output=../aggregator/mocks --outpkg=mocks --structname=StateInterfaceMock --filename=mock_state.go + export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=StorageInterface --dir=../aggregator --output=../aggregator/mocks --outpkg=mocks --structname=StorageInterfaceMock --filename=mock_storage.go export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=Synchronizer --srcpkg=github.com/0xPolygonHermez/zkevm-synchronizer-l1/synchronizer --output=../aggregator/mocks --outpkg=mocks --structname=SynchronizerInterfaceMock --filename=mock_synchronizer.go export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=EthTxManagerClient --dir=../aggregator --output=../aggregator/mocks --outpkg=mocks --structname=EthTxManagerClientMock --filename=mock_eth_tx_manager.go - export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=Tx --srcpkg=github.com/jackc/pgx/v4 --output=../aggregator/mocks --outpkg=mocks --structname=DbTxMock --filename=mock_dbtx.go + export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=Txer --dir=../db --output=../aggregator/mocks --outpkg=mocks --structname=TxerMock --filename=mock_txer.go export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=RPCInterface --dir=../aggregator --output=../aggregator/mocks 
--outpkg=mocks --structname=RPCInterfaceMock --filename=mock_rpc.go export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=AggregatorService_ChannelServer --dir=../aggregator/prover --output=../aggregator/prover/mocks --outpkg=mocks --structname=ChannelMock --filename=mock_channel.go diff --git a/test/config/kurtosis-cdk-node-config.toml.template b/test/config/kurtosis-cdk-node-config.toml.template index 4069b350..45ef7464 100644 --- a/test/config/kurtosis-cdk-node-config.toml.template +++ b/test/config/kurtosis-cdk-node-config.toml.template @@ -53,14 +53,6 @@ Outputs = ["stderr"] VerifyProofInterval = "10s" GasOffset = 150000 SettlementBackend = "agglayer" - [Aggregator.DB] - Name = "{{.aggregator_db.name}}" - User = "{{.aggregator_db.user}}" - Password = "{{.aggregator_db.password}}" - Host = "{{.aggregator_db.hostname}}" - Port = "{{.aggregator_db.port}}" - EnableLog = false - MaxConns = 200 [AggSender] CertificateSendInterval = "1m" diff --git a/test/config/test.config.toml b/test/config/test.config.toml index 94940469..9da00c79 100644 --- a/test/config/test.config.toml +++ b/test/config/test.config.toml @@ -58,14 +58,6 @@ AggLayerURL = "" SyncModeOnlyEnabled = false UseFullWitness = false SequencerPrivateKey = {} - [Aggregator.DB] - Name = "aggregator_db" - User = "aggregator_user" - Password = "aggregator_password" - Host = "cdk-aggregator-db" - Port = "5432" - EnableLog = false - MaxConns = 200 [Aggregator.Log] Environment = "development" # "production" or "development" Level = "info"
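
For reference, a minimal sketch of how the regenerated mocks in this patch might be wired together in a unit test. The constructors (NewStorageInterfaceMock, NewTxerMock) and the BeginTx/AddGeneratedProof/Commit signatures follow the generated code above; the test name, the sample state.Proof values, and the use of testify's require helpers are illustrative assumptions, not part of the change itself.

package mocks_test

import (
	"context"
	"testing"

	"github.com/0xPolygon/cdk/aggregator/mocks"
	"github.com/0xPolygon/cdk/state"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

func TestStorageMockSketch(t *testing.T) {
	ctx := context.Background()

	// Constructors register cleanup that asserts all expectations were met.
	storage := mocks.NewStorageInterfaceMock(t)
	tx := mocks.NewTxerMock(t)

	// BeginTx now takes a *sql.TxOptions and returns a db.Txer instead of pgx.Tx.
	storage.On("BeginTx", mock.Anything, mock.Anything).Return(tx, nil)
	storage.On("AddGeneratedProof", mock.Anything, mock.Anything, mock.Anything).Return(nil)
	tx.On("Commit").Return(nil)

	dbTx, err := storage.BeginTx(ctx, nil)
	require.NoError(t, err)

	// Field names follow the proof columns handled by the removed pgstatestorage code.
	proof := &state.Proof{BatchNumber: 1, BatchNumberFinal: 1}
	require.NoError(t, storage.AddGeneratedProof(ctx, proof, dbTx))
	require.NoError(t, dbTx.Commit())
}

The same pattern applies to the other StorageInterface methods that switched from pgx.Tx to db.Txer (DeleteGeneratedProofs, CleanupLockedProofs, and so on): expectations are set on the db.Txer-typed argument rather than a pgx transaction, which is why mock_dbtx.go could be dropped and mock_txer.go generated from the db package instead.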