From b801d2a98a3f29cdd114b96d97924701e40cce0f Mon Sep 17 00:00:00 2001 From: Stone-afk <73482944+Stone-afk@users.noreply.github.com> Date: Mon, 10 Jul 2023 22:47:03 +0800 Subject: [PATCH 1/3] =?UTF-8?q?feature=EF=BC=9A=20=E5=88=86=E5=BA=93?= =?UTF-8?q?=E5=88=86=E8=A1=A8=EF=BC=9Adatasource-=E7=AE=80=E5=8D=95?= =?UTF-8?q?=E7=9A=84=E5=88=86=E5=B8=83=E5=BC=8F=E4=BA=8B=E5=8A=A1=E6=96=B9?= =?UTF-8?q?=E6=A1=88=E6=94=AF=E6=8C=81=20(#210)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feature: datasource-简单的分布式事务方案支持 * feature: datasource-简单的分布式事务方案支持 * feature: datasource-简单的分布式事务方案支持 * feat: add test * faeture: add test * faeture: add test * 解决了事务卡死的问题 * feature: modify test * feature: modify test * feature: modify test * feature: modify test --------- Co-authored-by: Deng Ming --- .CHANGELOG.md | 1 + db.go | 2 +- internal/datasource/cluster/cluster_db.go | 37 +- .../datasource/cluster/cluster_db_test.go | 4 +- .../shardingsource/sharding_datasource.go | 41 +- .../sharding_datasource_test.go | 4 +- internal/datasource/single/db.go | 27 +- .../transaction/delay_transaction.go | 114 ++++ .../transaction/delay_transaction_test.go | 524 ++++++++++++++++++ .../transaction/single_transaction.go | 113 ++++ .../transaction/single_transaction_test.go | 384 +++++++++++++ .../transaction/transaction_suite_test.go | 187 +++++++ .../transaction/transaction_test.go | 28 +- internal/datasource/transaction/types.go | 80 +++ internal/datasource/types.go | 4 + internal/errs/error.go | 39 +- .../sharding_delay_transaction_test.go | 256 +++++++++ internal/integration/sharding_select_test.go | 91 +-- .../sharding_single_transaction_test.go | 240 ++++++++ internal/integration/sharding_suite_test.go | 105 +++- internal/integration/sharding_update_test.go | 88 +-- internal/merger/batchmerger/merger.go | 16 + sharding_insert_test.go | 4 +- sharding_select.go | 79 ++- sharding_select_test.go | 20 +- sharding_update_test.go | 14 +- transaction_test.go | 40 +- 27 files changed, 2272 insertions(+), 270 deletions(-) create mode 100644 internal/datasource/transaction/delay_transaction.go create mode 100644 internal/datasource/transaction/delay_transaction_test.go create mode 100644 internal/datasource/transaction/single_transaction.go create mode 100644 internal/datasource/transaction/single_transaction_test.go create mode 100644 internal/datasource/transaction/transaction_suite_test.go create mode 100644 internal/datasource/transaction/types.go create mode 100644 internal/integration/sharding_delay_transaction_test.go create mode 100644 internal/integration/sharding_single_transaction_test.go diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 1584a5f5..df2a89fd 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -31,6 +31,7 @@ - [eorm: 分库分表:Inserter 支持分库分表](https://github.com/ecodeclub/eorm/pull/200) - [eorm: ShardingInserter 修改为表维度执行](https://github.com/ecodeclub/eorm/pull/211) - [eorm: 分库分表:ShardingUpdater 实现](https://github.com/ecodeclub/eorm/pull/201) +- [eorm: 分库分表:datasource-简单的分布式事务方案支持](https://github.com/ecodeclub/eorm/pull/204) ## v0.0.1: - [Init Project](https://github.com/ecodeclub/eorm/pull/1) diff --git a/db.go b/db.go index 3d0f0a67..1439a80b 100644 --- a/db.go +++ b/db.go @@ -50,7 +50,7 @@ func DBWithMiddlewares(ms ...Middleware) DBOption { } } -func DBOptionWithMetaRegistry(r model.MetaRegistry) DBOption { +func DBWithMetaRegistry(r model.MetaRegistry) DBOption { return func(db *DB) { db.metaRegistry = r } diff --git a/internal/datasource/cluster/cluster_db.go b/internal/datasource/cluster/cluster_db.go index fa4478a2..25c70853 100644 --- a/internal/datasource/cluster/cluster_db.go +++ b/internal/datasource/cluster/cluster_db.go @@ -19,13 +19,17 @@ import ( "database/sql" "fmt" + "github.com/ecodeclub/eorm/internal/datasource/transaction" + "github.com/ecodeclub/eorm/internal/datasource" "github.com/ecodeclub/eorm/internal/datasource/masterslave" "github.com/ecodeclub/eorm/internal/errs" "go.uber.org/multierr" ) +var _ datasource.TxBeginner = &clusterDB{} var _ datasource.DataSource = &clusterDB{} +var _ datasource.Finder = &clusterDB{} // clusterDB 以 DB 名称作为索引目标数据库 type clusterDB struct { @@ -34,9 +38,9 @@ type clusterDB struct { } func (c *clusterDB) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) { - ms, ok := c.masterSlavesDBs[query.DB] - if !ok { - return nil, errs.ErrNotFoundTargetDB + ms, err := c.getTgt(query) + if err != nil { + return nil, err } return ms.Query(ctx, query) } @@ -44,7 +48,7 @@ func (c *clusterDB) Query(ctx context.Context, query datasource.Query) (*sql.Row func (c *clusterDB) Exec(ctx context.Context, query datasource.Query) (sql.Result, error) { ms, ok := c.masterSlavesDBs[query.DB] if !ok { - return nil, errs.ErrNotFoundTargetDB + return nil, errs.NewErrNotFoundTargetDB(query.DB) } return ms.Exec(ctx, query) } @@ -60,6 +64,31 @@ func (c *clusterDB) Close() error { return err } +func (c *clusterDB) FindTgt(_ context.Context, query datasource.Query) (datasource.TxBeginner, error) { + db, err := c.getTgt(query) + if err != nil { + return nil, err + } + return db, nil +} + +func (c *clusterDB) getTgt(query datasource.Query) (*masterslave.MasterSlavesDB, error) { + db, ok := c.masterSlavesDBs[query.DB] + if !ok { + return nil, errs.NewErrNotFoundTargetDB(query.DB) + } + return db, nil +} + +func (c *clusterDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasource.Tx, error) { + facade, err := transaction.NewTxFacade(ctx, c) + if err != nil { + return nil, err + } + + return facade.BeginTx(ctx, opts) +} + func NewClusterDB(ms map[string]*masterslave.MasterSlavesDB) datasource.DataSource { return &clusterDB{masterSlavesDBs: ms} } diff --git a/internal/datasource/cluster/cluster_db_test.go b/internal/datasource/cluster/cluster_db_test.go index 65cf02dc..c398d910 100644 --- a/internal/datasource/cluster/cluster_db_test.go +++ b/internal/datasource/cluster/cluster_db_test.go @@ -128,7 +128,7 @@ func (c *ClusterSuite) TestClusterDbQuery() { masterSlaves := map[string]*masterslave.MasterSlavesDB{"order_db_0": db} return masterSlaves }(), - wantErr: errs.ErrNotFoundTargetDB, + wantErr: errs.NewErrNotFoundTargetDB("order_db_1"), }, { name: "select default use slave", @@ -219,7 +219,7 @@ func (c *ClusterSuite) TestClusterDbExec() { masterSlaves := map[string]*masterslave.MasterSlavesDB{"order_db_0": db} return masterSlaves }(), - wantErr: errs.ErrNotFoundTargetDB, + wantErr: errs.NewErrNotFoundTargetDB("order_db_1"), }, { name: "null slave", diff --git a/internal/datasource/shardingsource/sharding_datasource.go b/internal/datasource/shardingsource/sharding_datasource.go index a64d57d0..e8dcd90e 100644 --- a/internal/datasource/shardingsource/sharding_datasource.go +++ b/internal/datasource/shardingsource/sharding_datasource.go @@ -19,6 +19,8 @@ import ( "database/sql" "fmt" + "github.com/ecodeclub/eorm/internal/datasource/transaction" + "github.com/ecodeclub/eorm/internal/datasource" "go.uber.org/multierr" @@ -27,29 +29,54 @@ import ( var _ datasource.TxBeginner = &ShardingDataSource{} var _ datasource.DataSource = &ShardingDataSource{} +var _ datasource.Finder = &ShardingDataSource{} type ShardingDataSource struct { sources map[string]datasource.DataSource } func (s *ShardingDataSource) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) { - ds, ok := s.sources[query.Datasource] - if !ok { - return nil, errs.ErrNotFoundTargetDataSource + ds, err := s.getTgt(query) + if err != nil { + return nil, err } return ds.Query(ctx, query) } func (s *ShardingDataSource) Exec(ctx context.Context, query datasource.Query) (sql.Result, error) { + ds, err := s.getTgt(query) + if err != nil { + return nil, err + } + return ds.Exec(ctx, query) +} + +func (s *ShardingDataSource) FindTgt(ctx context.Context, query datasource.Query) (datasource.TxBeginner, error) { + ds, err := s.getTgt(query) + if err != nil { + return nil, err + } + f, ok := ds.(datasource.Finder) + if !ok { + return nil, errs.NewErrNotCompleteFinder(query.Datasource) + } + return f.FindTgt(ctx, query) +} + +func (s *ShardingDataSource) getTgt(query datasource.Query) (datasource.DataSource, error) { ds, ok := s.sources[query.Datasource] if !ok { - return nil, errs.ErrNotFoundTargetDataSource + return nil, errs.NewErrNotFoundTargetDataSource(query.Datasource) } - return ds.Exec(ctx, query) + return ds, nil } -func (*ShardingDataSource) BeginTx(_ context.Context, _ *sql.TxOptions) (datasource.Tx, error) { - panic("`BeginTx` must be completed") +func (s *ShardingDataSource) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasource.Tx, error) { + facade, err := transaction.NewTxFacade(ctx, s) + if err != nil { + return nil, err + } + return facade.BeginTx(ctx, opts) } func NewShardingDataSource(m map[string]datasource.DataSource) datasource.DataSource { diff --git a/internal/datasource/shardingsource/sharding_datasource_test.go b/internal/datasource/shardingsource/sharding_datasource_test.go index 283adeed..a3bf1111 100644 --- a/internal/datasource/shardingsource/sharding_datasource_test.go +++ b/internal/datasource/shardingsource/sharding_datasource_test.go @@ -180,7 +180,7 @@ func (c *ShardingDataSourceSuite) TestClusterDbQuery() { DB: "db_0", Datasource: "2.db.cluster.company.com:3306", }, - wantErr: errs.ErrNotFoundTargetDataSource, + wantErr: errs.NewErrNotFoundTargetDataSource("2.db.cluster.company.com:3306"), }, { name: "cluster0 select default use slave", @@ -280,7 +280,7 @@ func (c *ShardingDataSourceSuite) TestClusterDbExec() { DB: "db_0", Datasource: "2.db.cluster.company.com:3306", }, - wantErr: errs.ErrNotFoundTargetDataSource, + wantErr: errs.NewErrNotFoundTargetDataSource("2.db.cluster.company.com:3306"), }, { name: "cluster0 exec", diff --git a/internal/datasource/single/db.go b/internal/datasource/single/db.go index f895a3ad..8e06af62 100644 --- a/internal/datasource/single/db.go +++ b/internal/datasource/single/db.go @@ -31,7 +31,8 @@ var _ datasource.DataSource = &DB{} // DB represents a database type DB struct { - db *sql.DB + db *sql.DB + multiStatements bool } func (db *DB) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) { @@ -42,12 +43,22 @@ func (db *DB) Exec(ctx context.Context, query datasource.Query) (sql.Result, err return db.db.ExecContext(ctx, query.SQL, query.Args...) } -func OpenDB(driver string, dsn string) (*DB, error) { +func OpenDB(driver string, dsn string, opts ...Option) (*DB, error) { + res := &DB{} + for _, o := range opts { + o(res) + } + + if res.multiStatements { + dsn = dsn + "?multiStatements=true" + } + db, err := sql.Open(driver, dsn) if err != nil { return nil, err } - return &DB{db: db}, nil + res.db = db + return res, nil } func NewDB(db *sql.DB) *DB { @@ -77,3 +88,13 @@ func (db *DB) Wait() error { func (db *DB) Close() error { return db.db.Close() } + +type Option func(db *DB) + +// DBWithMultiStatements 在创建连接时 加入参数 multiStatements=true,允许多条语句查询 +// 当然 multi statements 可能会增加sql注入的风险,故该操作只允许一次性业务操作,连接使用完成后需要关闭连接 +func DBWithMultiStatements(m bool) Option { + return func(db *DB) { + db.multiStatements = m + } +} diff --git a/internal/datasource/transaction/delay_transaction.go b/internal/datasource/transaction/delay_transaction.go new file mode 100644 index 00000000..93d7b3da --- /dev/null +++ b/internal/datasource/transaction/delay_transaction.go @@ -0,0 +1,114 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transaction + +import ( + "context" + "database/sql" + "fmt" + "sync" + + "github.com/ecodeclub/eorm/internal/datasource" + "go.uber.org/multierr" +) + +type DelayTxFactory struct{} + +func (DelayTxFactory) TxOf(ctx Context, finder datasource.Finder) (datasource.Tx, error) { + return NewDelayTx(ctx, finder), nil +} + +type DelayTx struct { + ctx Context + lock sync.RWMutex + txs map[string]datasource.Tx + finder datasource.Finder +} + +func (t *DelayTx) findTgt(ctx context.Context, query datasource.Query) (datasource.TxBeginner, error) { + return t.finder.FindTgt(ctx, query) +} + +func (t *DelayTx) findOrBeginTx(ctx context.Context, query datasource.Query) (datasource.Tx, error) { + t.lock.RLock() + tx, ok := t.txs[query.DB] + t.lock.RUnlock() + if ok { + return tx, nil + } + t.lock.Lock() + defer t.lock.Unlock() + if tx, ok = t.txs[query.DB]; ok { + return tx, nil + } + var err error + db, err := t.findTgt(ctx, query) + if err != nil { + return nil, err + } + tx, err = db.BeginTx(t.ctx.TxCtx, t.ctx.Opts) + if err != nil { + return nil, err + } + t.txs[query.DB] = tx + return tx, nil +} + +func (t *DelayTx) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) { + // 防止 GetMulti 的查询重复创建多个事务 + tx, err := t.findOrBeginTx(ctx, query) + if err != nil { + return nil, err + } + return tx.Query(ctx, query) +} + +func (t *DelayTx) Exec(ctx context.Context, query datasource.Query) (sql.Result, error) { + tx, err := t.findOrBeginTx(ctx, query) + if err != nil { + return nil, err + } + return tx.Exec(ctx, query) +} + +func (t *DelayTx) Commit() error { + var err error + for name, tx := range t.txs { + if er := tx.Commit(); er != nil { + err = multierr.Combine( + err, fmt.Errorf("masterslave DB name [%s] Commit error: %w", name, er)) + } + } + return err +} + +func (t *DelayTx) Rollback() error { + var err error + for name, tx := range t.txs { + if er := tx.Rollback(); er != nil { + err = multierr.Combine( + err, fmt.Errorf("masterslave DB name [%s] Rollback error: %w", name, er)) + } + } + return err +} + +func NewDelayTx(ctx Context, finder datasource.Finder) *DelayTx { + return &DelayTx{ + ctx: ctx, + finder: finder, + txs: make(map[string]datasource.Tx, 8), + } +} diff --git a/internal/datasource/transaction/delay_transaction_test.go b/internal/datasource/transaction/delay_transaction_test.go new file mode 100644 index 00000000..3bdc5c4c --- /dev/null +++ b/internal/datasource/transaction/delay_transaction_test.go @@ -0,0 +1,524 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transaction_test + +import ( + "context" + "database/sql" + "errors" + "regexp" + "strings" + "testing" + + "github.com/ecodeclub/eorm/internal/datasource" + "github.com/ecodeclub/eorm/internal/datasource/cluster" + "github.com/ecodeclub/eorm/internal/datasource/shardingsource" + "github.com/ecodeclub/eorm/internal/errs" + "github.com/ecodeclub/eorm/internal/model" + "go.uber.org/multierr" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/ecodeclub/eorm" + "github.com/ecodeclub/eorm/internal/datasource/masterslave" + "github.com/ecodeclub/eorm/internal/datasource/transaction" + "github.com/ecodeclub/eorm/internal/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type TestDelayTxTestSuite struct { + ShardingTransactionSuite +} + +func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() { + t := s.T() + testCases := []struct { + name string + wantAffected int64 + wantErr error + values []*test.OrderDetail + querySet []*test.OrderDetail + txFunc func() (*eorm.Tx, error) + mockOrder func(mock1, mock2 sqlmock.Sqlmock) + afterFunc func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) + }{ + { + name: "begin err", + wantErr: errors.New("begin err"), + querySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Kawhi", UsingCol2: "Leonard"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock2.ExpectBegin().WillReturnError(errors.New("begin err")) + mock1.ExpectBegin().WillReturnError(errors.New("begin err")) + }, + txFunc: func() (*eorm.Tx, error) { + return s.shardingDB.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) {}, + }, + { + name: "not find data source err", + wantErr: errs.NewErrNotFoundTargetDataSource("0.db.cluster.company.com:3306"), + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) {}, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) {}, + txFunc: func() (*eorm.Tx, error) { + s.DataSource = shardingsource.NewShardingDataSource(map[string]datasource.DataSource{ + "1.db.cluster.company.com:3306": s.clusterDB, + }) + r := model.NewMetaRegistry() + _, err := r.Register(&test.OrderDetail{}, + model.WithTableShardingAlgorithm(s.algorithm)) + require.NoError(t, err) + db, err := eorm.OpenDS("mysql", s.DataSource, eorm.DBWithMetaRegistry(r)) + require.NoError(t, err) + return db.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + }, + }, + { + name: "not complete Finder err", + wantErr: errs.NewErrNotCompleteFinder("0.db.cluster.company.com:3306"), + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) {}, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) {}, + txFunc: func() (*eorm.Tx, error) { + s.DataSource = shardingsource.NewShardingDataSource(map[string]datasource.DataSource{ + "0.db.cluster.company.com:3306": masterslave.NewMasterSlavesDB(s.mockMaster1DB, masterslave.MasterSlavesWithSlaves( + newSlaves(t, s.mockSlave1DB, s.mockSlave2DB, s.mockSlave3DB))), + }) + r := model.NewMetaRegistry() + _, err := r.Register(&test.OrderDetail{}, + model.WithTableShardingAlgorithm(s.algorithm)) + require.NoError(t, err) + db, err := eorm.OpenDS("mysql", s.DataSource, eorm.DBWithMetaRegistry(r)) + require.NoError(t, err) + return db.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + }, + }, + { + name: "not find target db err", + wantErr: errs.NewErrNotFoundTargetDB("order_detail_db_1"), + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock1.ExpectBegin() + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) {}, + txFunc: func() (*eorm.Tx, error) { + clusterDB := cluster.NewClusterDB(map[string]*masterslave.MasterSlavesDB{ + "order_detail_db_0": masterslave.NewMasterSlavesDB(s.mockMaster1DB, masterslave.MasterSlavesWithSlaves( + newSlaves(t, s.mockSlave1DB, s.mockSlave2DB, s.mockSlave3DB))), + }) + s.DataSource = shardingsource.NewShardingDataSource(map[string]datasource.DataSource{ + "0.db.cluster.company.com:3306": clusterDB, + }) + r := model.NewMetaRegistry() + _, err := r.Register(&test.OrderDetail{}, + model.WithTableShardingAlgorithm(s.algorithm)) + require.NoError(t, err) + db, err := eorm.OpenDS("mysql", s.DataSource, eorm.DBWithMetaRegistry(r)) + require.NoError(t, err) + return db.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + }, + }, + { + name: "select insert all commit err", + wantAffected: 2, + values: []*test.OrderDetail{ + {OrderId: 288, ItemId: 101, UsingCol1: "Jimmy", UsingCol2: "Butler"}, + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Kawhi", UsingCol2: "Leonard"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock1.MatchExpectationsInOrder(false) + mock2.MatchExpectationsInOrder(false) + mock1.ExpectBegin() + mock2.ExpectBegin() + + mock1.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_1` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_2` WHERE `order_id`!=?;")). + WithArgs(123, 123, 123). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(234, 12, "Kevin", "Durant").AddRow(8, 6, "Kobe", "Bryant")) + + mock2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_2` WHERE `order_id`!=?;")). + WithArgs(123, 123, 123). + WillReturnRows(mock2.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(253, 8, "Stephen", "Curry").AddRow(181, 11, "Kawhi", "Leonard").AddRow(11, 8, "James", "Harden")) + + mock1.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_0`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(288, 101, "Jimmy", "Butler").WillReturnResult(sqlmock.NewResult(1, 1)) + mock2.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_1`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(33, 100, "Nikolai", "Jokic").WillReturnResult(sqlmock.NewResult(1, 1)) + + commitErr := errors.New("commit fail") + mock1.ExpectCommit().WillReturnError(commitErr) + mock2.ExpectCommit().WillReturnError(commitErr) + }, + txFunc: func() (*eorm.Tx, error) { + return s.shardingDB.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Commit() + newErr := errors.New("commit fail") + errSlice := strings.Split(err.Error(), "; ") + wantErrSlice := []string{ + newMockCommitErr("order_detail_db_0", newErr).Error(), + newMockCommitErr("order_detail_db_1", newErr).Error()} + assert.ElementsMatch(t, wantErrSlice, errSlice) + + s.mockMaster.MatchExpectationsInOrder(false) + s.mockMaster2.MatchExpectationsInOrder(false) + rows := s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}) + s.mockMaster.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(288, 33).WillReturnRows(rows) + + s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(288, 33).WillReturnRows(rows) + + queryVal := s.findTgt(t, values) + var wantOds []*test.OrderDetail + assert.ElementsMatch(t, wantOds, queryVal) + }, + }, + { + name: "select insert part commit err", + wantAffected: 2, + values: []*test.OrderDetail{ + {OrderId: 288, ItemId: 101, UsingCol1: "Jimmy", UsingCol2: "Butler"}, + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Kawhi", UsingCol2: "Leonard"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock1.MatchExpectationsInOrder(false) + mock2.MatchExpectationsInOrder(false) + mock1.ExpectBegin() + mock2.ExpectBegin() + + mock1.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_1` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_2` WHERE `order_id`!=?;")). + WithArgs(123, 123, 123). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(234, 12, "Kevin", "Durant").AddRow(8, 6, "Kobe", "Bryant")) + + mock2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_2` WHERE `order_id`!=?;")). + WithArgs(123, 123, 123). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(253, 8, "Stephen", "Curry").AddRow(181, 11, "Kawhi", "Leonard").AddRow(11, 8, "James", "Harden")) + + mock1.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_0`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(288, 101, "Jimmy", "Butler").WillReturnResult(sqlmock.NewResult(1, 1)) + mock2.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_1`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(33, 100, "Nikolai", "Jokic").WillReturnResult(sqlmock.NewResult(1, 1)) + + mock1.ExpectCommit() + mock2.ExpectCommit().WillReturnError(errors.New("commit fail")) + }, + txFunc: func() (*eorm.Tx, error) { + return s.shardingDB.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Commit() + wantErr := multierr.Combine(newMockCommitErr("order_detail_db_1", errors.New("commit fail"))) + assert.Equal(t, wantErr, err) + + s.mockMaster.MatchExpectationsInOrder(false) + s.mockMaster2.MatchExpectationsInOrder(false) + + rows := s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}) + s.mockMaster.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(288, 33).WillReturnRows(rows) + + s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(288, 33).WillReturnRows(rows) + + queryVal := s.findTgt(t, values) + var wantVal []*test.OrderDetail + assert.ElementsMatch(t, wantVal, queryVal) + }, + }, + { + name: "select insert all rollback err", + wantAffected: 2, + values: []*test.OrderDetail{ + {OrderId: 288, ItemId: 101, UsingCol1: "Jimmy", UsingCol2: "Butler"}, + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Kawhi", UsingCol2: "Leonard"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock1.MatchExpectationsInOrder(false) + mock2.MatchExpectationsInOrder(false) + mock1.ExpectBegin() + mock2.ExpectBegin() + + mock1.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_1` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_2` WHERE `order_id`!=?;")). + WithArgs(123, 123, 123). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(234, 12, "Kevin", "Durant").AddRow(8, 6, "Kobe", "Bryant")) + + mock2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_2` WHERE `order_id`!=?;")). + WithArgs(123, 123, 123). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(253, 8, "Stephen", "Curry").AddRow(181, 11, "Kawhi", "Leonard").AddRow(11, 8, "James", "Harden")) + + mock1.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_0`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(288, 101, "Jimmy", "Butler").WillReturnResult(sqlmock.NewResult(1, 1)) + mock2.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_1`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(33, 100, "Nikolai", "Jokic").WillReturnResult(sqlmock.NewResult(1, 1)) + + rollbackErr := errors.New("rollback fail") + mock1.ExpectRollback().WillReturnError(rollbackErr) + mock2.ExpectRollback().WillReturnError(rollbackErr) + }, + txFunc: func() (*eorm.Tx, error) { + return s.shardingDB.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Rollback() + newErr := errors.New("rollback fail") + errSlice := strings.Split(err.Error(), "; ") + wantErrSlice := []string{ + newMockRollbackErr("order_detail_db_0", newErr).Error(), + newMockRollbackErr("order_detail_db_1", newErr).Error()} + assert.ElementsMatch(t, wantErrSlice, errSlice) + + s.mockMaster.MatchExpectationsInOrder(false) + s.mockMaster2.MatchExpectationsInOrder(false) + + rows := s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}) + s.mockMaster.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(288, 33).WillReturnRows(rows) + + s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(288, 33).WillReturnRows(rows) + + queryVal := s.findTgt(t, values) + var wantOds []*test.OrderDetail + assert.ElementsMatch(t, wantOds, queryVal) + }, + }, + { + name: "select insert part rollback err", + wantAffected: 2, + values: []*test.OrderDetail{ + {OrderId: 288, ItemId: 101, UsingCol1: "Jimmy", UsingCol2: "Butler"}, + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Kawhi", UsingCol2: "Leonard"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock1.MatchExpectationsInOrder(false) + mock2.MatchExpectationsInOrder(false) + mock1.ExpectBegin() + mock2.ExpectBegin() + + mock1.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_1` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_2` WHERE `order_id`!=?;")). + WithArgs(123, 123, 123). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(234, 12, "Kevin", "Durant").AddRow(8, 6, "Kobe", "Bryant")) + + mock2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_2` WHERE `order_id`!=?;")). + WithArgs(123, 123, 123). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(253, 8, "Stephen", "Curry").AddRow(181, 11, "Kawhi", "Leonard").AddRow(11, 8, "James", "Harden")) + + mock1.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_0`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(288, 101, "Jimmy", "Butler").WillReturnResult(sqlmock.NewResult(1, 1)) + mock2.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_1`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(33, 100, "Nikolai", "Jokic").WillReturnResult(sqlmock.NewResult(1, 1)) + + mock1.ExpectRollback().WillReturnError(errors.New("rollback fail")) + mock2.ExpectRollback() + }, + txFunc: func() (*eorm.Tx, error) { + return s.shardingDB.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Rollback() + wantErr := multierr.Combine(newMockRollbackErr("order_detail_db_0", errors.New("rollback fail"))) + assert.Equal(t, wantErr, err) + + s.mockMaster.MatchExpectationsInOrder(false) + s.mockMaster2.MatchExpectationsInOrder(false) + + rows := s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}) + s.mockMaster.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(288, 33).WillReturnRows(rows) + + s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(288, 33).WillReturnRows(rows) + + queryVal := s.findTgt(t, values) + var wantOds []*test.OrderDetail + assert.ElementsMatch(t, wantOds, queryVal) + }, + }, + { + name: "select insert commit", + wantAffected: 2, + values: []*test.OrderDetail{ + {OrderId: 288, ItemId: 101, UsingCol1: "Jimmy", UsingCol2: "Butler"}, + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Kawhi", UsingCol2: "Leonard"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock1.MatchExpectationsInOrder(false) + mock2.MatchExpectationsInOrder(false) + mock1.ExpectBegin() + mock2.ExpectBegin() + + mock1.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_1` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_2` WHERE `order_id`!=?;")). + WithArgs(123, 123, 123). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(234, 12, "Kevin", "Durant").AddRow(8, 6, "Kobe", "Bryant")) + + mock2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_2` WHERE `order_id`!=?;")). + WithArgs(123, 123, 123). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(253, 8, "Stephen", "Curry").AddRow(181, 11, "Kawhi", "Leonard").AddRow(11, 8, "James", "Harden")) + + mock1.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_0`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(288, 101, "Jimmy", "Butler").WillReturnResult(sqlmock.NewResult(1, 1)) + mock2.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_1`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(33, 100, "Nikolai", "Jokic").WillReturnResult(sqlmock.NewResult(1, 1)) + + mock1.ExpectCommit() + mock2.ExpectCommit() + }, + txFunc: func() (*eorm.Tx, error) { + return s.shardingDB.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Commit() + require.NoError(t, err) + + s.mockMaster.MatchExpectationsInOrder(false) + s.mockMaster2.MatchExpectationsInOrder(false) + + s.mockMaster.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(288, 33).WillReturnRows(s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(288, 101, "Jimmy", "Butler")) + + s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(288, 33).WillReturnRows(s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(33, 100, "Nikolai", "Jokic")) + + queryVal := s.findTgt(t, values) + assert.ElementsMatch(t, values, queryVal) + }, + }, + { + name: "select insert rollback", + wantAffected: 2, + values: []*test.OrderDetail{ + {OrderId: 199, ItemId: 100, UsingCol1: "Jason", UsingCol2: "Tatum"}, + {OrderId: 299, ItemId: 101, UsingCol1: "Paul", UsingCol2: "George"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Kawhi", UsingCol2: "Leonard"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock1.MatchExpectationsInOrder(false) + mock2.MatchExpectationsInOrder(false) + mock1.ExpectBegin() + mock2.ExpectBegin() + + mock1.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_1` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_2` WHERE `order_id`!=?;")). + WithArgs(123, 123, 123). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(234, 12, "Kevin", "Durant").AddRow(8, 6, "Kobe", "Bryant")) + + mock2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE `order_id`!=?;SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_2` WHERE `order_id`!=?;")). + WithArgs(123, 123, 123). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(253, 8, "Stephen", "Curry").AddRow(181, 11, "Kawhi", "Leonard").AddRow(11, 8, "James", "Harden")) + + mock2.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_1`.`order_detail_tab_1`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(199, 100, "Jason", "Tatum").WillReturnResult(sqlmock.NewResult(1, 1)) + mock2.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_1`.`order_detail_tab_2`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(299, 101, "Paul", "George").WillReturnResult(sqlmock.NewResult(1, 1)) + + mock1.ExpectRollback() + mock2.ExpectRollback() + }, + txFunc: func() (*eorm.Tx, error) { + return s.shardingDB.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Rollback() + require.NoError(t, err) + s.mockMaster2.MatchExpectationsInOrder(false) + + rows := s.mockMaster2.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}) + s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_2` WHERE (`order_id`=?) OR (`order_id`=?);SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(199, 299, 199, 299).WillReturnRows(rows) + + queryVal := s.findTgt(t, values) + var wantOds []*test.OrderDetail + assert.ElementsMatch(t, wantOds, queryVal) + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.mockOrder(s.mockMaster, s.mockMaster2) + tx, err := tc.txFunc() + require.NoError(t, err) + + // TODO GetMultiV2 待将 table 维度改成 db 维度 + querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). + Where(eorm.C("OrderId").NEQ(123)). + GetMultiV2(masterslave.UseMaster(context.Background())) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.ElementsMatch(t, tc.querySet, querySet) + + values := tc.values + res := eorm.NewShardingInsert[test.OrderDetail](tx). + Values(values).Exec(context.Background()) + affected, err := res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, tc.wantAffected, affected) + tc.afterFunc(t, tx, values) + }) + } +} + +func TestDelayTransactionSuite(t *testing.T) { + suite.Run(t, &TestDelayTxTestSuite{ + ShardingTransactionSuite: newShardingTransactionSuite(), + }) +} diff --git a/internal/datasource/transaction/single_transaction.go b/internal/datasource/transaction/single_transaction.go new file mode 100644 index 00000000..5804ec37 --- /dev/null +++ b/internal/datasource/transaction/single_transaction.go @@ -0,0 +1,113 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transaction + +import ( + "context" + "database/sql" + "sync" + + "github.com/ecodeclub/eorm/internal/errs" + + "github.com/ecodeclub/eorm/internal/datasource" +) + +type SingleTxFactory struct{} + +func (SingleTxFactory) TxOf(ctx Context, finder datasource.Finder) (datasource.Tx, error) { + return NewSingleTx(ctx, finder), nil +} + +type SingleTx struct { + DB string + ctx Context + lock sync.RWMutex + tx datasource.Tx + finder datasource.Finder +} + +func (t *SingleTx) findTgt(ctx context.Context, query datasource.Query) (datasource.TxBeginner, error) { + return t.finder.FindTgt(ctx, query) +} + +func (t *SingleTx) findOrBeginTx(ctx context.Context, query datasource.Query) (datasource.Tx, error) { + t.lock.RLock() + if t.DB != "" && t.tx != nil { + if t.DB != query.DB { + t.lock.RUnlock() + return nil, errs.NewErrDBNotEqual(t.DB, query.DB) + } + t.lock.RUnlock() + return t.tx, nil + } + t.lock.RUnlock() + t.lock.Lock() + defer t.lock.Unlock() + if t.DB != "" && t.tx != nil { + if t.DB != query.DB { + return nil, errs.NewErrDBNotEqual(t.DB, query.DB) + } + return t.tx, nil + } + db, err := t.findTgt(ctx, query) + if err != nil { + return nil, err + } + tx, err := db.BeginTx(t.ctx.TxCtx, t.ctx.Opts) + if err != nil { + return nil, err + } + t.tx = tx + t.DB = query.DB + return tx, nil +} + +func (t *SingleTx) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) { + // 防止 GetMulti 的查询重复创建多个事务 + tx, err := t.findOrBeginTx(ctx, query) + if err != nil { + return nil, err + } + return tx.Query(ctx, query) +} + +func (t *SingleTx) Exec(ctx context.Context, query datasource.Query) (sql.Result, error) { + tx, err := t.findOrBeginTx(ctx, query) + if err != nil { + return nil, err + } + return tx.Exec(ctx, query) +} + +func (t *SingleTx) Commit() error { + if t.tx != nil { + return t.tx.Commit() + } + return nil +} + +func (t *SingleTx) Rollback() error { + if t.tx != nil { + return t.tx.Rollback() + } + return nil +} + +func NewSingleTx(ctx Context, finder datasource.Finder) *SingleTx { + return &SingleTx{ + ctx: ctx, + finder: finder, + } +} diff --git a/internal/datasource/transaction/single_transaction_test.go b/internal/datasource/transaction/single_transaction_test.go new file mode 100644 index 00000000..bb42afad --- /dev/null +++ b/internal/datasource/transaction/single_transaction_test.go @@ -0,0 +1,384 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transaction_test + +import ( + "context" + "database/sql" + "errors" + "regexp" + "strings" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/ecodeclub/eorm" + "github.com/ecodeclub/eorm/internal/datasource/masterslave" + "github.com/ecodeclub/eorm/internal/datasource/transaction" + "github.com/ecodeclub/eorm/internal/errs" + "github.com/ecodeclub/eorm/internal/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type TestSingleTxTestSuite struct { + ShardingTransactionSuite +} + +func (s *TestSingleTxTestSuite) TestExecute_Commit_Or_Rollback() { + t := s.T() + testCases := []struct { + name string + wantAffected int64 + wantErr error + shardingVal int + values []*test.OrderDetail + querySet []*test.OrderDetail + tx *eorm.Tx + mockOrder func(mock1, mock2 sqlmock.Sqlmock) + afterFunc func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) + }{ + { + name: "select insert commit", + wantAffected: 1, + shardingVal: 234, + values: []*test.OrderDetail{ + {OrderId: 288, ItemId: 101, UsingCol1: "Jimmy", UsingCol2: "Butler"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock1.MatchExpectationsInOrder(false) + mock1.ExpectBegin() + + mock1.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`=?;")). + WithArgs(234). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(234, 12, "Kevin", "Durant")) + + mock1.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_0`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(288, 101, "Jimmy", "Butler").WillReturnResult(sqlmock.NewResult(1, 1)) + + mock1.ExpectCommit() + }, + tx: func() *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Single), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }(), + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Commit() + require.NoError(t, err) + + s.mockMaster.MatchExpectationsInOrder(false) + + s.mockMaster.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`=?;")). + WithArgs(288).WillReturnRows(s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(288, 101, "Jimmy", "Butler")) + + queryVal := s.findTgt(t, values) + assert.ElementsMatch(t, values, queryVal) + }, + }, + { + name: "select insert rollback", + wantAffected: 1, + shardingVal: 253, + values: []*test.OrderDetail{ + {OrderId: 199, ItemId: 100, UsingCol1: "Jason", UsingCol2: "Tatum"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock2.MatchExpectationsInOrder(false) + mock2.ExpectBegin() + + mock2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE `order_id`=?;")). + WithArgs(253). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(253, 8, "Stephen", "Curry")) + + mock2.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_1`.`order_detail_tab_1`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(199, 100, "Jason", "Tatum").WillReturnResult(sqlmock.NewResult(1, 1)) + + mock2.ExpectRollback() + }, + tx: func() *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Single), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }(), + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Rollback() + require.NoError(t, err) + + s.mockMaster2.MatchExpectationsInOrder(false) + s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE `order_id`=?;")). + WithArgs(199).WillReturnRows(s.mockMaster2.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"})) + + queryVal := s.findTgt(t, values) + var wantOds []*test.OrderDetail + assert.ElementsMatch(t, wantOds, queryVal) + }, + }, + { + name: "insert use multi db err", + wantAffected: 2, + shardingVal: 234, + wantErr: errs.NewErrDBNotEqual("order_detail_db_0", "order_detail_db_1"), + values: []*test.OrderDetail{ + {OrderId: 288, ItemId: 101, UsingCol1: "Jimmy", UsingCol2: "Butler"}, + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock1.MatchExpectationsInOrder(false) + mock2.MatchExpectationsInOrder(false) + mock1.ExpectBegin() + mock2.ExpectBegin() + + mock1.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`=?;")). + WithArgs(234). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(234, 12, "Kevin", "Durant")) + + mock1.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_0`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(288, 101, "Jimmy", "Butler").WillReturnResult(sqlmock.NewResult(1, 1)) + mock2.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_1`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(33, 100, "Nikolai", "Jokic").WillReturnResult(sqlmock.NewResult(1, 1)) + + commitErr := errors.New("commit fail") + mock1.ExpectCommit().WillReturnError(commitErr) + mock2.ExpectCommit().WillReturnError(commitErr) + }, + tx: func() *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Single), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }(), + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Commit() + newErr := errors.New("commit fail") + errSlice := strings.Split(err.Error(), "; ") + wantErrSlice := []string{ + newMockCommitErr("order_detail_db_0", newErr).Error(), + newMockCommitErr("order_detail_db_1", newErr).Error()} + assert.ElementsMatch(t, wantErrSlice, errSlice) + + s.mockMaster.MatchExpectationsInOrder(false) + s.mockMaster2.MatchExpectationsInOrder(false) + + //row1 := s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}) + //row2 := s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}) + //s.mockMaster.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_1` WHERE (`order_id`=?) OR (`order_id`=?);SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_2` WHERE (`order_id`=?) OR (`order_id`=?);")). + // WithArgs(288, 33, 288, 33, 288, 33).WillReturnRows(row1) + // + //s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE (`order_id`=?) OR (`order_id`=?);SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_2` WHERE (`order_id`=?) OR (`order_id`=?);")). + // WithArgs(288, 33, 288, 33, 288, 33).WillReturnRows(row2) + + row1 := s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}) + row2 := s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}) + s.mockMaster.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(288, 33).WillReturnRows(row1) + + s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);")). + WithArgs(288, 33).WillReturnRows(row2) + + queryVal := s.findTgt(t, values) + var wantOds []*test.OrderDetail + assert.ElementsMatch(t, wantOds, queryVal) + }, + }, + { + name: "select and insert use multi db err", + wantAffected: 2, + shardingVal: 234, + wantErr: errs.NewErrDBNotEqual("order_detail_db_0", "order_detail_db_1"), + values: []*test.OrderDetail{ + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock1.MatchExpectationsInOrder(false) + mock1.ExpectBegin() + + mock1.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`=?;")). + WithArgs(234). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(234, 12, "Kevin", "Durant")) + + mock1.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_1`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(33, 100, "Nikolai", "Jokic").WillReturnResult(sqlmock.NewResult(1, 1)) + + mock1.ExpectCommit() + }, + tx: func() *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Single), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }(), + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Commit() + newErr := errors.New("commit fail") + errSlice := strings.Split(err.Error(), "; ") + wantErrSlice := []string{ + newMockCommitErr("order_detail_db_0", newErr).Error(), + newMockCommitErr("order_detail_db_1", newErr).Error()} + assert.ElementsMatch(t, wantErrSlice, errSlice) + + s.mockMaster.MatchExpectationsInOrder(false) + s.mockMaster2.MatchExpectationsInOrder(false) + + row1 := s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}) + row2 := s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}) + //s.mockMaster.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_1` WHERE (`order_id`=?) OR (`order_id`=?);SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_2` WHERE (`order_id`=?) OR (`order_id`=?);")). + // WithArgs(288, 33, 288, 33, 288, 33).WillReturnRows(row1) + // + //s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE (`order_id`=?) OR (`order_id`=?);SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE (`order_id`=?) OR (`order_id`=?);SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_2` WHERE (`order_id`=?) OR (`order_id`=?);")). + // WithArgs(288, 33, 288, 33, 288, 33).WillReturnRows(row2) + + s.mockMaster.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`=?;")). + WithArgs(33).WillReturnRows(row1) + + s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE `order_id`=?;")). + WithArgs(33).WillReturnRows(row2) + + queryVal := s.findTgt(t, values) + var wantOds []*test.OrderDetail + assert.ElementsMatch(t, wantOds, queryVal) + }, + }, + { + name: "select insert commit err", + wantAffected: 1, + shardingVal: 234, + values: []*test.OrderDetail{ + {OrderId: 288, ItemId: 101, UsingCol1: "Jimmy", UsingCol2: "Butler"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock1.MatchExpectationsInOrder(false) + mock1.ExpectBegin() + + mock1.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`=?;")). + WithArgs(234). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(234, 12, "Kevin", "Durant")) + + mock1.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_0`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(288, 101, "Jimmy", "Butler").WillReturnResult(sqlmock.NewResult(1, 1)) + + mock1.ExpectCommit().WillReturnError(errors.New("commit fail")) + }, + tx: func() *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Single), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }(), + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Commit() + wantErr := errors.New("commit fail") + assert.Equal(t, wantErr, err) + + s.mockMaster.MatchExpectationsInOrder(false) + s.mockMaster.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_0`.`order_detail_tab_0` WHERE `order_id`=?;")). + WithArgs(288).WillReturnRows(s.mockMaster.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"})) + + queryVal := s.findTgt(t, values) + var wantOds []*test.OrderDetail + assert.ElementsMatch(t, wantOds, queryVal) + }, + }, + { + name: "select insert rollback err", + wantAffected: 1, + shardingVal: 253, + values: []*test.OrderDetail{ + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + }, + mockOrder: func(mock1, mock2 sqlmock.Sqlmock) { + mock2.MatchExpectationsInOrder(false) + mock2.ExpectBegin() + + mock2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE `order_id`=?;")). + WithArgs(253). + WillReturnRows(mock1.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}).AddRow(253, 8, "Stephen", "Curry")) + + mock2.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_detail_db_1`.`order_detail_tab_0`(`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);")). + WithArgs(33, 100, "Nikolai", "Jokic").WillReturnResult(sqlmock.NewResult(1, 1)) + + mock2.ExpectRollback().WillReturnError(errors.New("rollback fail")) + }, + tx: func() *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Single), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }(), + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Rollback() + wantErr := errors.New("rollback fail") + assert.Equal(t, wantErr, err) + + s.mockMaster2.MatchExpectationsInOrder(false) + + s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_0` WHERE `order_id`=?")). + WithArgs(33).WillReturnRows(s.mockMaster2.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"})) + + queryVal := s.findTgt(t, values) + var wantOds []*test.OrderDetail + assert.ElementsMatch(t, wantOds, queryVal) + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.mockOrder(s.mockMaster, s.mockMaster2) + tx := tc.tx + querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). + Where(eorm.C("OrderId").EQ(tc.shardingVal)). + GetMulti(masterslave.UseMaster(context.Background())) + require.NoError(t, err) + assert.ElementsMatch(t, tc.querySet, querySet) + + values := tc.values + res := eorm.NewShardingInsert[test.OrderDetail](tx). + Values(values).Exec(context.Background()) + affected, err := res.RowsAffected() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantAffected, affected) + tc.afterFunc(t, tx, values) + }) + } +} + +func TestSingleTransactionSuite(t *testing.T) { + suite.Run(t, &TestSingleTxTestSuite{ + ShardingTransactionSuite: newShardingTransactionSuite(), + }) +} diff --git a/internal/datasource/transaction/transaction_suite_test.go b/internal/datasource/transaction/transaction_suite_test.go new file mode 100644 index 00000000..a7d768db --- /dev/null +++ b/internal/datasource/transaction/transaction_suite_test.go @@ -0,0 +1,187 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transaction_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/ecodeclub/eorm" + "github.com/ecodeclub/eorm/internal/datasource" + "github.com/ecodeclub/eorm/internal/datasource/cluster" + "github.com/ecodeclub/eorm/internal/datasource/masterslave" + "github.com/ecodeclub/eorm/internal/datasource/masterslave/slaves" + "github.com/ecodeclub/eorm/internal/datasource/masterslave/slaves/roundrobin" + "github.com/ecodeclub/eorm/internal/datasource/shardingsource" + "github.com/ecodeclub/eorm/internal/model" + "github.com/ecodeclub/eorm/internal/sharding" + "github.com/ecodeclub/eorm/internal/sharding/hash" + "github.com/ecodeclub/eorm/internal/test" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ShardingTransactionSuite struct { + shardingKey string + shardingDB *eorm.DB + algorithm sharding.Algorithm + + suite.Suite + datasource.DataSource + clusterDB datasource.DataSource + mockMaster1DB *sql.DB + mockMaster sqlmock.Sqlmock + + mockSlave1DB *sql.DB + mockSlave1 sqlmock.Sqlmock + + mockSlave2DB *sql.DB + mockSlave2 sqlmock.Sqlmock + + mockSlave3DB *sql.DB + mockSlave3 sqlmock.Sqlmock + + mockMaster2DB *sql.DB + mockMaster2 sqlmock.Sqlmock + + mockSlave4DB *sql.DB + mockSlave4 sqlmock.Sqlmock + + mockSlave5DB *sql.DB + mockSlave5 sqlmock.Sqlmock + + mockSlave6DB *sql.DB + mockSlave6 sqlmock.Sqlmock +} + +func (s *ShardingTransactionSuite) SetupTest() { + t := s.T() + s.initMock(t) +} + +func (s *ShardingTransactionSuite) initMock(t *testing.T) { + var err error + s.mockMaster1DB, s.mockMaster, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + s.mockSlave1DB, s.mockSlave1, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + s.mockSlave2DB, s.mockSlave2, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + s.mockSlave3DB, s.mockSlave3, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + + s.mockMaster2DB, s.mockMaster2, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + s.mockSlave4DB, s.mockSlave4, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + s.mockSlave5DB, s.mockSlave5, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + s.mockSlave6DB, s.mockSlave6, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + + db1 := masterslave.NewMasterSlavesDB(s.mockMaster1DB, masterslave.MasterSlavesWithSlaves( + newSlaves(t, s.mockSlave1DB, s.mockSlave2DB, s.mockSlave3DB))) + + db2 := masterslave.NewMasterSlavesDB(s.mockMaster2DB, masterslave.MasterSlavesWithSlaves( + newSlaves(t, s.mockSlave4DB, s.mockSlave5DB, s.mockSlave6DB))) + + s.clusterDB = cluster.NewClusterDB(map[string]*masterslave.MasterSlavesDB{ + "order_detail_db_0": db1, + "order_detail_db_1": db2, + }) + + s.DataSource = shardingsource.NewShardingDataSource(map[string]datasource.DataSource{ + "0.db.cluster.company.com:3306": s.clusterDB, + }) + + r := model.NewMetaRegistry() + sk := "OrderId" + s.algorithm = &hash.Hash{ + ShardingKey: sk, + DBPattern: &hash.Pattern{Name: "order_detail_db_%d", Base: 2}, + TablePattern: &hash.Pattern{Name: "order_detail_tab_%d", Base: 3}, + DsPattern: &hash.Pattern{Name: "0.db.cluster.company.com:3306", NotSharding: true}, + } + s.shardingKey = sk + _, err = r.Register(&test.OrderDetail{}, + model.WithTableShardingAlgorithm(s.algorithm)) + require.NoError(t, err) + db, err := eorm.OpenDS("mysql", s.DataSource, eorm.DBWithMetaRegistry(r)) + require.NoError(t, err) + s.shardingDB = db +} + +func (s *ShardingTransactionSuite) TearDownTest() { + _ = s.mockMaster1DB.Close() + _ = s.mockSlave1DB.Close() + _ = s.mockSlave2DB.Close() + _ = s.mockSlave3DB.Close() + + _ = s.mockMaster2DB.Close() + _ = s.mockSlave4DB.Close() + _ = s.mockSlave5DB.Close() + _ = s.mockSlave6DB.Close() +} + +func (s *ShardingTransactionSuite) findTgt(t *testing.T, values []*test.OrderDetail) []*test.OrderDetail { + od := values[0] + pre := eorm.C(s.shardingKey).EQ(od.OrderId) + for i := 1; i < len(values); i++ { + od = values[i] + pre = pre.Or(eorm.C(s.shardingKey).EQ(od.OrderId)) + } + // TODO GetMultiV2 待将 table 维度改成 db 维度 + querySet, err := eorm.NewShardingSelector[test.OrderDetail](s.shardingDB). + Where(pre).GetMultiV2(masterslave.UseMaster(context.Background())) + require.NoError(t, err) + return querySet +} + +func newShardingTransactionSuite() ShardingTransactionSuite { + return ShardingTransactionSuite{} +} + +func newSlaves(t *testing.T, dbs ...*sql.DB) slaves.Slaves { + res, err := roundrobin.NewSlaves(dbs...) + require.NoError(t, err) + return res +} + +func newMockCommitErr(dbName string, err error) error { + return fmt.Errorf("masterslave DB name [%s] Commit error: %w", dbName, err) +} + +func newMockRollbackErr(dbName string, err error) error { + return fmt.Errorf("masterslave DB name [%s] Rollback error: %w", dbName, err) +} diff --git a/internal/datasource/transaction/transaction_test.go b/internal/datasource/transaction/transaction_test.go index b69a83c9..e22bd95a 100644 --- a/internal/datasource/transaction/transaction_test.go +++ b/internal/datasource/transaction/transaction_test.go @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -package transaction +package transaction_test import ( "context" "database/sql" "testing" + "github.com/ecodeclub/eorm/internal/datasource/transaction" + "github.com/stretchr/testify/suite" "github.com/DATA-DOG/go-sqlmock" @@ -96,7 +98,7 @@ func (db *testMockDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasou if err != nil { return nil, err } - return NewTx(tx, db), nil + return transaction.NewTx(tx, db), nil } func (db *testMockDB) Close() error { @@ -146,7 +148,7 @@ func (s *TransactionSuite) TestDBQuery() { //s.mock.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("value")) testCases := []struct { name string - tx *Tx + tx *transaction.Tx query datasource.Query mockRows *sqlmock.Rows wantResp []string @@ -157,14 +159,14 @@ func (s *TransactionSuite) TestDBQuery() { query: datasource.Query{ SQL: "SELECT `first_name` FROM `test_model`", }, - tx: func() *Tx { + tx: func() *transaction.Tx { s.mock1.ExpectBegin() s.mock1.ExpectQuery("SELECT *").WillReturnRows( sqlmock.NewRows([]string{"first_name"}).AddRow("value")) s.mock1.ExpectCommit() tx, err := s.mockDB1.BeginTx(context.Background(), &sql.TxOptions{}) assert.Nil(s.T(), err) - return NewTx(tx, NewMockDB(s.mockDB1)) + return transaction.NewTx(tx, NewMockDB(s.mockDB1)) }(), wantResp: []string{"value"}, }, @@ -202,7 +204,7 @@ func (s *TransactionSuite) TestDBExec() { rowsAffected int64 wantErr error isCommit bool - tx *Tx + tx *transaction.Tx query datasource.Query }{ { @@ -210,14 +212,14 @@ func (s *TransactionSuite) TestDBExec() { query: datasource.Query{ SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)", }, - tx: func() *Tx { + tx: func() *transaction.Tx { s.mock1.ExpectBegin() s.mock1.ExpectExec("^INSERT INTO (.+)"). WillReturnResult(sqlmock.NewResult(2, 1)) s.mock1.ExpectRollback() tx, err := s.mockDB1.BeginTx(context.Background(), &sql.TxOptions{}) assert.Nil(s.T(), err) - return NewTx(tx, NewMockDB(s.mockDB1)) + return transaction.NewTx(tx, NewMockDB(s.mockDB1)) }(), lastInsertId: int64(2), rowsAffected: int64(1), @@ -227,14 +229,14 @@ func (s *TransactionSuite) TestDBExec() { query: datasource.Query{ SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)", }, - tx: func() *Tx { + tx: func() *transaction.Tx { s.mock2.ExpectBegin() s.mock2.ExpectExec("^INSERT INTO (.+)"). WillReturnResult(sqlmock.NewResult(2, 1)) s.mock2.ExpectCommit() tx, err := s.mockDB2.BeginTx(context.Background(), &sql.TxOptions{}) assert.Nil(s.T(), err) - return NewTx(tx, NewMockDB(s.mockDB2)) + return transaction.NewTx(tx, NewMockDB(s.mockDB2)) }(), isCommit: true, lastInsertId: int64(2), @@ -245,14 +247,14 @@ func (s *TransactionSuite) TestDBExec() { query: datasource.Query{ SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4) (1,2,3,4)", }, - tx: func() *Tx { + tx: func() *transaction.Tx { s.mock3.ExpectBegin() s.mock3.ExpectExec("^INSERT INTO (.+)"). WillReturnResult(sqlmock.NewResult(4, 2)) s.mock3.ExpectCommit() tx, err := s.mockDB3.BeginTx(context.Background(), &sql.TxOptions{}) assert.Nil(s.T(), err) - return NewTx(tx, NewMockDB(s.mockDB3)) + return transaction.NewTx(tx, NewMockDB(s.mockDB3)) }(), isCommit: true, lastInsertId: int64(4), @@ -300,7 +302,7 @@ func (m *mockDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasource.T if err != nil { return nil, err } - return NewTx(tx, m), nil + return transaction.NewTx(tx, m), nil } func (m *mockDB) Close() error { diff --git a/internal/datasource/transaction/types.go b/internal/datasource/transaction/types.go new file mode 100644 index 00000000..d0c0510c --- /dev/null +++ b/internal/datasource/transaction/types.go @@ -0,0 +1,80 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transaction + +import ( + "context" + "database/sql" + + "github.com/ecodeclub/eorm/internal/errs" + + "github.com/ecodeclub/eorm/internal/datasource" +) + +// 为了方便管理不同类型的 分布式 Tx,所以这里引入 TxType 常量来支持创建不同的 分布式Tx类型 以便提高后续引入 XA 方案的扩展性。 +const ( + Delay = "delay" + Single = "single" +) + +type TxFactory interface { + TxOf(ctx Context, finder datasource.Finder) (datasource.Tx, error) +} + +type Context struct { + TxName string + TxCtx context.Context + Opts *sql.TxOptions +} + +type TypeKey struct{} + +func UsingTxType(ctx context.Context, val string) context.Context { + return context.WithValue(ctx, TypeKey{}, val) +} + +func GetCtxTypeKey(ctx context.Context) any { + return ctx.Value(TypeKey{}) +} + +type TxFacade struct { + factory TxFactory + finder datasource.Finder +} + +func NewTxFacade(ctx context.Context, finder datasource.Finder) (TxFacade, error) { + res := TxFacade{ + finder: finder, + } + switch GetCtxTypeKey(ctx).(string) { + case Delay: + res.factory = DelayTxFactory{} + return res, nil + case Single: + res.factory = SingleTxFactory{} + return res, nil + default: + return TxFacade{}, errs.ErrUnsupportedDistributedTransaction + } +} + +func (t *TxFacade) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasource.Tx, error) { + dsCtx := Context{ + TxCtx: ctx, + Opts: opts, + TxName: GetCtxTypeKey(ctx).(string), + } + return t.factory.TxOf(dsCtx, t.finder) +} diff --git a/internal/datasource/types.go b/internal/datasource/types.go index 465a144f..5e8495d5 100644 --- a/internal/datasource/types.go +++ b/internal/datasource/types.go @@ -30,6 +30,10 @@ type TxBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) } +type Finder interface { + FindTgt(ctx context.Context, query Query) (TxBeginner, error) +} + type Tx interface { Executor Commit() error diff --git a/internal/errs/error.go b/internal/errs/error.go index e39e0a05..0ef1b155 100644 --- a/internal/errs/error.go +++ b/internal/errs/error.go @@ -28,20 +28,35 @@ var ( ErrTooManyColumns = errors.New("eorm: 过多列") // ErrCombinationIsNotStruct 不支持的组合类型,eorm 只支持结构体组合 - ErrCombinationIsNotStruct = errors.New("eorm: 不支持的组合类型,eorm 只支持结构体组合") - ErrMissingShardingKey = errors.New("eorm: sharding key 未设置") - ErrOnlyResultOneQuery = errors.New("eorm: 只能生成一个 SQL") - ErrUnsupportedTooComplexQuery = errors.New("eorm: 暂未支持太复杂的查询") - ErrSlaveNotFound = errors.New("eorm: slave不存在") - ErrNotFoundTargetDataSource = errors.New("eorm: 未发现目标 data source") - ErrNotFoundTargetDB = errors.New("eorm: 未发现目标 DB") - ErrNotGenShardingQuery = errors.New("eorm: 未生成 sharding query") - ErrNotCompleteTxBeginner = errors.New("eorm: 未实现 TxBeginner 接口") - ErrInsertShardingKeyNotFound = errors.New("eorm: insert语句中未包含sharding key") - ErrInsertFindingDst = errors.New("eorm: 一行数据只能插入一个表") - ErrUnsupportedAssignment = errors.New("eorm: 不支持的 assignment") + ErrCombinationIsNotStruct = errors.New("eorm: 不支持的组合类型,eorm 只支持结构体组合") + ErrMissingShardingKey = errors.New("eorm: sharding key 未设置") + ErrOnlyResultOneQuery = errors.New("eorm: 只能生成一个 SQL") + ErrUnsupportedTooComplexQuery = errors.New("eorm: 暂未支持太复杂的查询") + ErrSlaveNotFound = errors.New("eorm: slave不存在") + ErrNotGenShardingQuery = errors.New("eorm: 未生成 sharding query") + ErrNotCompleteTxBeginner = errors.New("eorm: 未实现 TxBeginner 接口") + ErrInsertShardingKeyNotFound = errors.New("eorm: insert语句中未包含sharding key") + ErrInsertFindingDst = errors.New("eorm: 一行数据只能插入一个表") + ErrUnsupportedAssignment = errors.New("eorm: 不支持的 assignment") + ErrUnsupportedDistributedTransaction = errors.New("eorm: 不支持的分布式事务类型") ) +func NewErrDBNotEqual(oldDB, tgtDB string) error { + return fmt.Errorf("eorm:禁止跨库操作: %s 不等于 %s ", oldDB, tgtDB) +} + +func NewErrNotCompleteFinder(name string) error { + return fmt.Errorf("eorm: %s 未实现 Finder 接口", name) +} + +func NewErrNotFoundTargetDataSource(name string) error { + return fmt.Errorf("eorm: 未发现目标 data dource %s", name) +} + +func NewErrNotFoundTargetDB(name string) error { + return fmt.Errorf("eorm: 未发现目标 DB %s", name) +} + func NewErrUpdateShardingKeyUnsupported(field string) error { return fmt.Errorf("eorm: ShardingKey `%s` 不支持更新", field) } diff --git a/internal/integration/sharding_delay_transaction_test.go b/internal/integration/sharding_delay_transaction_test.go new file mode 100644 index 00000000..11b28253 --- /dev/null +++ b/internal/integration/sharding_delay_transaction_test.go @@ -0,0 +1,256 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build e2e + +package integration + +import ( + "context" + "database/sql" + "testing" + + "github.com/ecodeclub/eorm/internal/datasource/masterslave" + "github.com/ecodeclub/eorm/internal/datasource/transaction" + + "github.com/ecodeclub/eorm" + "github.com/ecodeclub/eorm/internal/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ShardingDelayTxTestSuite struct { + ShardingSelectUpdateInsertSuite +} + +func (s *ShardingDelayTxTestSuite) TestDoubleShardingSelect() { + t := s.T() + testCases := []struct { + name string + wantErr error + querySet []*test.OrderDetail + txFunc func(t *testing.T) *eorm.Tx + }{ + { + name: "double select", + querySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Kawhi", UsingCol2: "Leonard"}, + }, + txFunc: func(t *testing.T) *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx := tc.txFunc(t) + defer tx.Commit() + querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). + Where(eorm.C("OrderId").NEQ(123)). + GetMultiV2(masterslave.UseMaster(context.Background())) + require.NoError(t, err) + assert.ElementsMatch(t, tc.querySet, querySet) + + querySet, err = eorm.NewShardingSelector[test.OrderDetail](tx). + Where(eorm.C("OrderId").NEQ(123)). + GetMultiV2(masterslave.UseMaster(context.Background())) + require.NoError(t, err) + assert.ElementsMatch(t, tc.querySet, querySet) + }) + } +} + +func (s *ShardingDelayTxTestSuite) TestShardingSelectUpdateInsert_Commit_Or_Rollback() { + t := s.T() + testCases := []struct { + name string + updateAffected int64 + insertAffected int64 + target *test.OrderDetail + upPre eorm.Predicate + insertValues []*test.OrderDetail + querySet []*test.OrderDetail + upQuerySet []*test.OrderDetail + txFunc func(t *testing.T) *eorm.Tx + afterFunc func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) + }{ + { + name: "select insert update commit", + upPre: eorm.C("OrderId").EQ(181), + updateAffected: 1, + insertAffected: 2, + target: &test.OrderDetail{UsingCol1: "Jordan"}, + insertValues: []*test.OrderDetail{ + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + {OrderId: 288, ItemId: 101, UsingCol1: "Jimmy", UsingCol2: "Butler"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Kawhi", UsingCol2: "Leonard"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + }, + upQuerySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Jordan", UsingCol2: "Leonard"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + {OrderId: 288, ItemId: 101, UsingCol1: "Jimmy", UsingCol2: "Butler"}, + }, + txFunc: func(t *testing.T) *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Commit() + require.NoError(t, err) + + queryVal := s.findTgt(t, values) + assert.ElementsMatch(t, values, queryVal) + }, + }, + { + name: "select insert update broadcast commit", + upPre: eorm.C("OrderId").GTEQ(253), + updateAffected: 2, + insertAffected: 2, + target: &test.OrderDetail{UsingCol1: "Jordan"}, + insertValues: []*test.OrderDetail{ + {OrderId: 199, ItemId: 100, UsingCol1: "Jason", UsingCol2: "Tatum"}, + {OrderId: 299, ItemId: 101, UsingCol1: "Paul", UsingCol2: "George"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Jordan", UsingCol2: "Leonard"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + {OrderId: 288, ItemId: 101, UsingCol1: "Jimmy", UsingCol2: "Butler"}, + }, + upQuerySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Jordan", UsingCol2: "Leonard"}, + {OrderId: 199, ItemId: 100, UsingCol1: "Jason", UsingCol2: "Tatum"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Jordan", UsingCol2: "Curry"}, + {OrderId: 288, ItemId: 101, UsingCol1: "Jordan", UsingCol2: "Butler"}, + {OrderId: 299, ItemId: 101, UsingCol1: "Paul", UsingCol2: "George"}, + }, + txFunc: func(t *testing.T) *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Commit() + require.NoError(t, err) + + queryVal := s.findTgt(t, values) + assert.ElementsMatch(t, values, queryVal) + }, + }, + { + name: "select insert update rollback", + upPre: eorm.C("OrderId").EQ(299), + updateAffected: 1, + insertAffected: 1, + target: &test.OrderDetail{UsingCol1: "Jordan"}, + insertValues: []*test.OrderDetail{ + {OrderId: 48, ItemId: 100}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Jordan", UsingCol2: "Leonard"}, + {OrderId: 199, ItemId: 100, UsingCol1: "Jason", UsingCol2: "Tatum"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Jordan", UsingCol2: "Curry"}, + {OrderId: 288, ItemId: 101, UsingCol1: "Jordan", UsingCol2: "Butler"}, + {OrderId: 299, ItemId: 101, UsingCol1: "Paul", UsingCol2: "George"}, + }, + upQuerySet: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Jordan", UsingCol2: "Leonard"}, + {OrderId: 199, ItemId: 100, UsingCol1: "Jason", UsingCol2: "Tatum"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Jordan", UsingCol2: "Curry"}, + {OrderId: 288, ItemId: 101, UsingCol1: "Jordan", UsingCol2: "Butler"}, + {OrderId: 299, ItemId: 101, UsingCol1: "Paul", UsingCol2: "George"}, + }, + txFunc: func(t *testing.T) *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Rollback() + require.NoError(t, err) + + queryVal := s.findTgt(t, values) + assert.ElementsMatch(t, values, queryVal) + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx := tc.txFunc(t) + querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). + Where(eorm.C("OrderId").NEQ(123)). + GetMultiV2(masterslave.UseMaster(context.Background())) + require.NoError(t, err) + assert.ElementsMatch(t, tc.querySet, querySet) + + res := eorm.NewShardingUpdater[test.OrderDetail](tx).Update(tc.target). + Set(eorm.C("UsingCol1")).Where(tc.upPre).Exec(context.Background()) + affected, err := res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, tc.updateAffected, affected) + + res = eorm.NewShardingInsert[test.OrderDetail](tx). + Values(tc.insertValues).Exec(context.Background()) + affected, err = res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, tc.insertAffected, affected) + + tc.afterFunc(t, tx, tc.upQuerySet) + }) + } +} + +func TestMySQL8ShardingDelayTxTestSuite(t *testing.T) { + suite.Run(t, &ShardingDelayTxTestSuite{ + ShardingSelectUpdateInsertSuite: newShardingSelectUpdateInsertSuite(), + }) +} diff --git a/internal/integration/sharding_select_test.go b/internal/integration/sharding_select_test.go index c31a5135..79c5301e 100644 --- a/internal/integration/sharding_select_test.go +++ b/internal/integration/sharding_select_test.go @@ -18,17 +18,11 @@ package integration import ( "context" - "fmt" "testing" - "time" "github.com/ecodeclub/eorm" - "github.com/ecodeclub/eorm/internal/datasource" "github.com/ecodeclub/eorm/internal/datasource/masterslave" "github.com/ecodeclub/eorm/internal/model" - operator "github.com/ecodeclub/eorm/internal/operator" - "github.com/ecodeclub/eorm/internal/sharding" - "github.com/ecodeclub/eorm/internal/sharding/hash" "github.com/ecodeclub/eorm/internal/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,32 +30,7 @@ import ( ) type ShardingSelectTestSuite struct { - ShardingSuite - data []*test.OrderDetail -} - -func (s *ShardingSelectTestSuite) SetupSuite() { - t := s.T() - s.ShardingSuite.SetupSuite() - for _, item := range s.data { - shardingRes, err := s.algorithm.Sharding( - context.Background(), sharding.Request{Op: operator.OpEQ, SkValues: map[string]any{"OrderId": item.OrderId}}) - require.NoError(t, err) - require.NotNil(t, shardingRes.Dsts) - for _, dst := range shardingRes.Dsts { - tbl := fmt.Sprintf("`%s`.`%s`", dst.DB, dst.Table) - sql := fmt.Sprintf("INSERT INTO %s (`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);", tbl) - args := []any{item.OrderId, item.ItemId, item.UsingCol1, item.UsingCol2} - source, ok := s.dataSources[dst.Name] - require.True(t, ok) - _, err := source.Exec(context.Background(), datasource.Query{SQL: sql, Args: args, DB: dst.DB}) - if err != nil { - t.Fatal(err) - } - } - } - // 防止主从延迟 - time.Sleep(1) + ShardingSelectUpdateInsertSuite } func (s *ShardingSelectTestSuite) TestSardingSelectorGet() { @@ -70,7 +39,7 @@ func (s *ShardingSelectTestSuite) TestSardingSelectorGet() { _, err := r.Register(&test.OrderDetail{}, model.WithTableShardingAlgorithm(s.algorithm)) require.NoError(t, err) - eorm.DBOptionWithMetaRegistry(r)(s.shardingDB) + eorm.DBWithMetaRegistry(r)(s.shardingDB) testCases := []struct { name string @@ -127,7 +96,7 @@ func (s *ShardingSelectTestSuite) TestSardingSelectorGetMulti() { _, err := r.Register(&test.OrderDetail{}, model.WithTableShardingAlgorithm(s.algorithm)) require.NoError(t, err) - eorm.DBOptionWithMetaRegistry(r)(s.shardingDB) + eorm.DBWithMetaRegistry(r)(s.shardingDB) testCases := []struct { name string @@ -583,60 +552,8 @@ func (s *ShardingSelectTestSuite) TestSardingSelectorGetMulti() { } } -func (s *ShardingSelectTestSuite) TearDownSuite() { - t := s.T() - for _, item := range s.data { - shardingRes, err := s.algorithm.Sharding( - context.Background(), sharding.Request{Op: operator.OpEQ, SkValues: map[string]any{"OrderId": item.OrderId}}) - require.NoError(t, err) - require.NotNil(t, shardingRes.Dsts) - for _, dst := range shardingRes.Dsts { - tbl := fmt.Sprintf("`%s`.`%s`", dst.DB, dst.Table) - sql := fmt.Sprintf("DELETE FROM %s", tbl) - source, ok := s.dataSources[dst.Name] - require.True(t, ok) - _, err := source.Exec(context.Background(), datasource.Query{SQL: sql, DB: dst.DB}) - if err != nil { - t.Fatal(err) - } - } - } -} - func TestMySQL8ShardingSelect(t *testing.T) { - m := []*masterSalvesDriver{ - { - masterdsn: "root:root@tcp(localhost:13307)/order_detail_db_0", - slavedsns: []string{"root:root@tcp(localhost:13308)/order_detail_db_0"}, - }, - { - masterdsn: "root:root@tcp(localhost:13307)/order_detail_db_1", - slavedsns: []string{"root:root@tcp(localhost:13308)/order_detail_db_1"}, - }, - } - clusterDr := &clusterDriver{msDrivers: m} suite.Run(t, &ShardingSelectTestSuite{ - ShardingSuite: ShardingSuite{ - driver: "mysql", - algorithm: &hash.Hash{ - ShardingKey: "OrderId", - DBPattern: &hash.Pattern{Name: "order_detail_db_%d", Base: 2}, - TablePattern: &hash.Pattern{Name: "order_detail_tab_%d", Base: 3}, - DsPattern: &hash.Pattern{Name: "root:root@tcp(localhost:13307).0", NotSharding: true}, - }, - DBPattern: "order_detail_db_%d", - DsPattern: "root:root@tcp(localhost:13307).%d", - clusters: &clusterDrivers{ - clDrivers: []*clusterDriver{clusterDr}, - }, - }, - data: []*test.OrderDetail{ - {8, 6, "Kobe", "Bryant"}, - {11, 8, "James", "Harden"}, - {123, 10, "LeBron", "James"}, - {234, 12, "Kevin", "Durant"}, - {253, 8, "Stephen", "Curry"}, - {181, 11, "Kawhi", "Leonard"}, - }, + ShardingSelectUpdateInsertSuite: newShardingSelectUpdateInsertSuite(), }) } diff --git a/internal/integration/sharding_single_transaction_test.go b/internal/integration/sharding_single_transaction_test.go new file mode 100644 index 00000000..127af56a --- /dev/null +++ b/internal/integration/sharding_single_transaction_test.go @@ -0,0 +1,240 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//go:build e2e + +package integration + +import ( + "context" + "database/sql" + "testing" + + "github.com/ecodeclub/eorm/internal/datasource/masterslave" + "github.com/ecodeclub/eorm/internal/datasource/transaction" + + "github.com/ecodeclub/eorm" + "github.com/ecodeclub/eorm/internal/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ShardingSingleTxTestSuite struct { + ShardingSelectUpdateInsertSuite +} + +func (s *ShardingSingleTxTestSuite) TestDoubleShardingSelect() { + t := s.T() + testCases := []struct { + name string + querySet []*test.OrderDetail + txFunc func(t *testing.T) *eorm.Tx + }{ + { + name: "double select", + querySet: []*test.OrderDetail{ + {OrderId: 123, ItemId: 10, UsingCol1: "LeBron", UsingCol2: "James"}, + }, + txFunc: func(t *testing.T) *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Single), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx := tc.txFunc(t) + defer tx.Commit() + querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). + Where(eorm.C("OrderId").EQ(123)). + GetMultiV2(masterslave.UseMaster(context.Background())) + require.NoError(t, err) + assert.ElementsMatch(t, tc.querySet, querySet) + + querySet, err = eorm.NewShardingSelector[test.OrderDetail](tx). + Where(eorm.C("OrderId").EQ(123)). + GetMultiV2(masterslave.UseMaster(context.Background())) + require.NoError(t, err) + assert.ElementsMatch(t, tc.querySet, querySet) + }) + } +} + +func (s *ShardingSingleTxTestSuite) TestShardingSelectInsert_Commit_Or_Rollback() { + t := s.T() + testCases := []struct { + name string + wantAffected int64 + values []*test.OrderDetail + querySet []*test.OrderDetail + txFunc func(t *testing.T) *eorm.Tx + afterFunc func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) + }{ + { + name: "select insert commit", + wantAffected: 1, + values: []*test.OrderDetail{ + {OrderId: 33, ItemId: 100, UsingCol1: "Nikolai", UsingCol2: "Jokic"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 123, ItemId: 10, UsingCol1: "LeBron", UsingCol2: "James"}, + }, + txFunc: func(t *testing.T) *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Single), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Commit() + require.NoError(t, err) + + queryVal := s.findTgt(t, values) + assert.ElementsMatch(t, values, queryVal) + }, + }, + { + name: "select insert rollback", + wantAffected: 1, + values: []*test.OrderDetail{ + {OrderId: 199, ItemId: 100, UsingCol1: "Jason", UsingCol2: "Tatum"}, + }, + querySet: []*test.OrderDetail{ + {OrderId: 123, ItemId: 10, UsingCol1: "LeBron", UsingCol2: "James"}, + }, + txFunc: func(t *testing.T) *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Single), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + var wantOds []*test.OrderDetail + err := tx.Rollback() + require.NoError(t, err) + + queryVal := s.findTgt(t, values) + assert.ElementsMatch(t, wantOds, queryVal) + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx := tc.txFunc(t) + querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). + Where(eorm.C("OrderId").EQ(123)). + GetMultiV2(masterslave.UseMaster(context.Background())) + require.NoError(t, err) + assert.ElementsMatch(t, tc.querySet, querySet) + res := eorm.NewShardingInsert[test.OrderDetail](tx). + Values(tc.values).Exec(context.Background()) + affected, err := res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, tc.wantAffected, affected) + tc.afterFunc(t, tx, tc.values) + }) + } +} + +func (s *ShardingSingleTxTestSuite) TestShardingSelectUpdate_Commit_Or_Rollback() { + t := s.T() + testCases := []struct { + name string + wantAffected int64 + target *test.OrderDetail + upPre eorm.Predicate + querySet []*test.OrderDetail + upQuerySet []*test.OrderDetail + txFunc func(t *testing.T) *eorm.Tx + afterFunc func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) + }{ + { + name: "select update where eq commit", + upPre: eorm.C("OrderId").EQ(11), + wantAffected: 1, + target: &test.OrderDetail{UsingCol1: "ben"}, + querySet: []*test.OrderDetail{ + {OrderId: 123, ItemId: 10, UsingCol1: "LeBron", UsingCol2: "James"}, + }, + upQuerySet: []*test.OrderDetail{ + {OrderId: 11, ItemId: 8, UsingCol1: "ben", UsingCol2: "Harden"}, + {OrderId: 123, ItemId: 10, UsingCol1: "LeBron", UsingCol2: "James"}, + }, + txFunc: func(t *testing.T) *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Single), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Commit() + require.NoError(t, err) + + queryVal := s.findTgt(t, values) + assert.ElementsMatch(t, values, queryVal) + }, + }, + { + name: "select update rollback", + upPre: eorm.C("OrderId").EQ(181), + wantAffected: 1, + target: &test.OrderDetail{UsingCol1: "Jordan"}, + querySet: []*test.OrderDetail{ + {OrderId: 123, ItemId: 10, UsingCol1: "LeBron", UsingCol2: "James"}, + }, + upQuerySet: []*test.OrderDetail{ + {OrderId: 181, ItemId: 11, UsingCol1: "Kawhi", UsingCol2: "Leonard"}, + {OrderId: 123, ItemId: 10, UsingCol1: "LeBron", UsingCol2: "James"}, + }, + txFunc: func(t *testing.T) *eorm.Tx { + tx, er := s.shardingDB.BeginTx( + transaction.UsingTxType(context.Background(), transaction.Single), &sql.TxOptions{}) + require.NoError(t, er) + return tx + }, + afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) { + err := tx.Rollback() + require.NoError(t, err) + + queryVal := s.findTgt(t, values) + assert.ElementsMatch(t, values, queryVal) + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx := tc.txFunc(t) + querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). + Where(eorm.C("OrderId").EQ(123)). + GetMultiV2(masterslave.UseMaster(context.Background())) + require.NoError(t, err) + assert.ElementsMatch(t, tc.querySet, querySet) + res := eorm.NewShardingUpdater[test.OrderDetail](tx).Update(tc.target). + Set(eorm.C("UsingCol1")).Where(tc.upPre).Exec(context.Background()) + affected, err := res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, tc.wantAffected, affected) + tc.afterFunc(t, tx, tc.upQuerySet) + }) + } +} + +func TestMySQL8ShardingSingleTxTestSuite(t *testing.T) { + suite.Run(t, &ShardingSingleTxTestSuite{ + ShardingSelectUpdateInsertSuite: newShardingSelectUpdateInsertSuite(), + }) +} diff --git a/internal/integration/sharding_suite_test.go b/internal/integration/sharding_suite_test.go index 5bdb253c..538e21ac 100644 --- a/internal/integration/sharding_suite_test.go +++ b/internal/integration/sharding_suite_test.go @@ -22,6 +22,7 @@ import ( "database/sql/driver" "fmt" "log" + "testing" "time" "github.com/ecodeclub/eorm" @@ -32,6 +33,7 @@ import ( "github.com/ecodeclub/eorm/internal/datasource/masterslave/slaves/roundrobin" "github.com/ecodeclub/eorm/internal/datasource/shardingsource" "github.com/ecodeclub/eorm/internal/model" + operator "github.com/ecodeclub/eorm/internal/operator" "github.com/ecodeclub/eorm/internal/sharding" "github.com/ecodeclub/eorm/internal/sharding/hash" "github.com/ecodeclub/eorm/internal/test" @@ -51,17 +53,18 @@ type ShardingSuite struct { DBPattern string DsPattern string TablePattern string + ShardingKey string } func newDefaultShardingSuite() ShardingSuite { m := []*masterSalvesDriver{ { - masterdsn: "root:root@tcp(localhost:13307)/order_detail_db_0", - slavedsns: []string{"root:root@tcp(localhost:13308)/order_detail_db_0"}, + masterdsn: "root:root@tcp(localhost:13307)/order_detail_db_0?multiStatements=true&interpolateParams=true", + slavedsns: []string{"root:root@tcp(localhost:13308)/order_detail_db_0?multiStatements=true&interpolateParams=true"}, }, { - masterdsn: "root:root@tcp(localhost:13307)/order_detail_db_1", - slavedsns: []string{"root:root@tcp(localhost:13308)/order_detail_db_1"}, + masterdsn: "root:root@tcp(localhost:13307)/order_detail_db_1?multiStatements=true&interpolateParams=true", + slavedsns: []string{"root:root@tcp(localhost:13308)/order_detail_db_1?multiStatements=true&interpolateParams=true"}, }, } clusterDr := &clusterDriver{msDrivers: m} @@ -69,13 +72,7 @@ func newDefaultShardingSuite() ShardingSuite { dsPattern := "root:root@tcp(localhost:13307).%d" tablePattern := "order_detail_tab_%d" return ShardingSuite{ - driver: "mysql", - algorithm: &hash.Hash{ - ShardingKey: "OrderId", - DBPattern: &hash.Pattern{Name: dbPattern, Base: 2}, - TablePattern: &hash.Pattern{Name: tablePattern, Base: 3}, - DsPattern: &hash.Pattern{Name: "root:root@tcp(localhost:13307).0", NotSharding: true}, - }, + driver: "mysql", DBPattern: dbPattern, DsPattern: dsPattern, TablePattern: tablePattern, @@ -148,15 +145,99 @@ func (s *ShardingSuite) initDB(r model.MetaRegistry) (*eorm.DB, error) { } s.dataSources = sourceMap dataSource := shardingsource.NewShardingDataSource(sourceMap) - return eorm.OpenDS(s.driver, dataSource, eorm.DBOptionWithMetaRegistry(r)) + return eorm.OpenDS(s.driver, dataSource, eorm.DBWithMetaRegistry(r)) } func (s *ShardingSuite) SetupSuite() { t := s.T() r := model.NewMetaRegistry() + sk := "OrderId" + s.algorithm = &hash.Hash{ + ShardingKey: sk, + DBPattern: &hash.Pattern{Name: s.DBPattern, Base: 2}, + TablePattern: &hash.Pattern{Name: s.TablePattern, Base: 3}, + DsPattern: &hash.Pattern{Name: "root:root@tcp(localhost:13307).0", NotSharding: true}, + } + s.ShardingKey = sk _, err := r.Register(&test.OrderDetail{}, model.WithTableShardingAlgorithm(s.algorithm)) db, err := s.initDB(r) require.NoError(t, err) s.shardingDB = db } + +type ShardingSelectUpdateInsertSuite struct { + ShardingSuite + data []*test.OrderDetail +} + +func newShardingSelectUpdateInsertSuite() ShardingSelectUpdateInsertSuite { + return ShardingSelectUpdateInsertSuite{ + ShardingSuite: newDefaultShardingSuite(), + data: []*test.OrderDetail{ + {OrderId: 8, ItemId: 6, UsingCol1: "Kobe", UsingCol2: "Bryant"}, + {OrderId: 11, ItemId: 8, UsingCol1: "James", UsingCol2: "Harden"}, + {OrderId: 123, ItemId: 10, UsingCol1: "LeBron", UsingCol2: "James"}, + {OrderId: 234, ItemId: 12, UsingCol1: "Kevin", UsingCol2: "Durant"}, + {OrderId: 253, ItemId: 8, UsingCol1: "Stephen", UsingCol2: "Curry"}, + {OrderId: 181, ItemId: 11, UsingCol1: "Kawhi", UsingCol2: "Leonard"}, + }, + } +} + +func (s *ShardingSelectUpdateInsertSuite) SetupSuite() { + t := s.T() + s.ShardingSuite.SetupSuite() + for _, item := range s.data { + shardingRes, err := s.algorithm.Sharding( + context.Background(), sharding.Request{Op: operator.OpEQ, SkValues: map[string]any{s.ShardingKey: item.OrderId}}) + require.NoError(t, err) + require.NotNil(t, shardingRes.Dsts) + for _, dst := range shardingRes.Dsts { + tbl := fmt.Sprintf("`%s`.`%s`", dst.DB, dst.Table) + sql := fmt.Sprintf("INSERT INTO %s (`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);", tbl) + args := []any{item.OrderId, item.ItemId, item.UsingCol1, item.UsingCol2} + source, ok := s.dataSources[dst.Name] + require.True(t, ok) + _, err := source.Exec(context.Background(), datasource.Query{SQL: sql, Args: args, DB: dst.DB}) + if err != nil { + t.Fatal(err) + } + } + } + // 防止主从延迟 + time.Sleep(1) +} + +func (s *ShardingSelectUpdateInsertSuite) findTgt(t *testing.T, values []*test.OrderDetail) []*test.OrderDetail { + od := values[0] + pre := eorm.C(s.ShardingKey).EQ(od.OrderId) + for i := 1; i < len(values); i++ { + od = values[i] + pre = pre.Or(eorm.C(s.ShardingKey).EQ(od.OrderId)) + } + querySet, err := eorm.NewShardingSelector[test.OrderDetail](s.shardingDB). + Where(pre).GetMulti(masterslave.UseMaster(context.Background())) + require.NoError(t, err) + return querySet +} + +func (s *ShardingSelectUpdateInsertSuite) TearDownSuite() { + t := s.T() + for _, item := range s.data { + shardingRes, err := s.algorithm.Sharding( + context.Background(), sharding.Request{Op: operator.OpEQ, SkValues: map[string]any{"OrderId": item.OrderId}}) + require.NoError(t, err) + require.NotNil(t, shardingRes.Dsts) + for _, dst := range shardingRes.Dsts { + tbl := fmt.Sprintf("`%s`.`%s`", dst.DB, dst.Table) + sql := fmt.Sprintf("DELETE FROM %s", tbl) + source, ok := s.dataSources[dst.Name] + require.True(t, ok) + _, err := source.Exec(context.Background(), datasource.Query{SQL: sql, DB: dst.DB}) + if err != nil { + t.Fatal(err) + } + } + } +} diff --git a/internal/integration/sharding_update_test.go b/internal/integration/sharding_update_test.go index ca06e0a0..fda248da 100644 --- a/internal/integration/sharding_update_test.go +++ b/internal/integration/sharding_update_test.go @@ -18,17 +18,12 @@ package integration import ( "context" - "fmt" "testing" - "time" "github.com/ecodeclub/eorm" - "github.com/ecodeclub/eorm/internal/datasource" "github.com/ecodeclub/eorm/internal/datasource/masterslave" "github.com/ecodeclub/eorm/internal/model" - operator "github.com/ecodeclub/eorm/internal/operator" "github.com/ecodeclub/eorm/internal/sharding" - "github.com/ecodeclub/eorm/internal/sharding/hash" "github.com/ecodeclub/eorm/internal/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,32 +31,7 @@ import ( ) type ShardingUpdateTestSuite struct { - ShardingSuite - data []*test.OrderDetail -} - -func (s *ShardingUpdateTestSuite) SetupSuite() { - t := s.T() - s.ShardingSuite.SetupSuite() - for _, item := range s.data { - shardingRes, err := s.algorithm.Sharding( - context.Background(), sharding.Request{Op: operator.OpEQ, SkValues: map[string]any{"OrderId": item.OrderId}}) - require.NoError(t, err) - require.NotNil(t, shardingRes.Dsts) - for _, dst := range shardingRes.Dsts { - tbl := fmt.Sprintf("`%s`.`%s`", dst.DB, dst.Table) - sql := fmt.Sprintf("INSERT INTO %s (`order_id`,`item_id`,`using_col1`,`using_col2`) VALUES(?,?,?,?);", tbl) - args := []any{item.OrderId, item.ItemId, item.UsingCol1, item.UsingCol2} - source, ok := s.dataSources[dst.Name] - require.True(t, ok) - _, err := source.Exec(context.Background(), datasource.Query{SQL: sql, Args: args, DB: dst.DB}) - if err != nil { - t.Fatal(err) - } - } - } - // 防止主从延迟 - time.Sleep(1) + ShardingSelectUpdateInsertSuite } func (s *ShardingUpdateTestSuite) TestShardingUpdater_Exec() { @@ -70,7 +40,7 @@ func (s *ShardingUpdateTestSuite) TestShardingUpdater_Exec() { _, err := r.Register(&test.OrderDetail{}, model.WithTableShardingAlgorithm(s.algorithm)) require.NoError(t, err) - eorm.DBOptionWithMetaRegistry(r)(s.shardingDB) + eorm.DBWithMetaRegistry(r)(s.shardingDB) testCases := []struct { name string wantAffectedRows int64 @@ -149,60 +119,8 @@ func (s *ShardingUpdateTestSuite) TestShardingUpdater_Exec() { } } -func (s *ShardingUpdateTestSuite) TearDownSuite() { - t := s.T() - for _, item := range s.data { - shardingRes, err := s.algorithm.Sharding( - context.Background(), sharding.Request{Op: operator.OpEQ, SkValues: map[string]any{"OrderId": item.OrderId}}) - require.NoError(t, err) - require.NotNil(t, shardingRes.Dsts) - for _, dst := range shardingRes.Dsts { - tbl := fmt.Sprintf("`%s`.`%s`", dst.DB, dst.Table) - sql := fmt.Sprintf("DELETE FROM %s", tbl) - source, ok := s.dataSources[dst.Name] - require.True(t, ok) - _, err := source.Exec(context.Background(), datasource.Query{SQL: sql, DB: dst.DB}) - if err != nil { - t.Fatal(err) - } - } - } -} - func TestMySQL8ShardingUpdate(t *testing.T) { - m := []*masterSalvesDriver{ - { - masterdsn: "root:root@tcp(localhost:13307)/order_detail_db_0", - slavedsns: []string{"root:root@tcp(localhost:13308)/order_detail_db_0"}, - }, - { - masterdsn: "root:root@tcp(localhost:13307)/order_detail_db_1", - slavedsns: []string{"root:root@tcp(localhost:13308)/order_detail_db_1"}, - }, - } - clusterDr := &clusterDriver{msDrivers: m} suite.Run(t, &ShardingUpdateTestSuite{ - ShardingSuite: ShardingSuite{ - driver: "mysql", - algorithm: &hash.Hash{ - ShardingKey: "OrderId", - DBPattern: &hash.Pattern{Name: "order_detail_db_%d", Base: 2}, - TablePattern: &hash.Pattern{Name: "order_detail_tab_%d", Base: 3}, - DsPattern: &hash.Pattern{Name: "root:root@tcp(localhost:13307).0", NotSharding: true}, - }, - DBPattern: "order_detail_db_%d", - DsPattern: "root:root@tcp(localhost:13307).%d", - clusters: &clusterDrivers{ - clDrivers: []*clusterDriver{clusterDr}, - }, - }, - data: []*test.OrderDetail{ - {8, 6, "Kobe", "Bryant"}, - {11, 8, "James", "Harden"}, - {123, 10, "LeBron", "James"}, - {234, 12, "Kevin", "Durant"}, - {253, 8, "Stephen", "Curry"}, - {181, 11, "Kawhi", "Leonard"}, - }, + ShardingSelectUpdateInsertSuite: newShardingSelectUpdateInsertSuite(), }) } diff --git a/internal/merger/batchmerger/merger.go b/internal/merger/batchmerger/merger.go index 37098e30..9b8e07e1 100644 --- a/internal/merger/batchmerger/merger.go +++ b/internal/merger/batchmerger/merger.go @@ -112,23 +112,39 @@ func (r *Rows) Next() bool { func (r *Rows) nextRows() (bool, error) { row := r.rowsList[r.cnt] + if row.Next() { return true, nil } + + for row.NextResultSet() { + if row.Next() { + return true, nil + } + } + if row.Err() != nil { return false, row.Err() } + for { r.cnt++ if r.cnt >= len(r.rowsList) { break } row = r.rowsList[r.cnt] + if row.Next() { return true, nil } else if row.Err() != nil { return false, row.Err() } + + for row.NextResultSet() { + if row.Next() { + return true, nil + } + } } return false, nil } diff --git a/sharding_insert_test.go b/sharding_insert_test.go index 4b9aec99..234f7bdd 100644 --- a/sharding_insert_test.go +++ b/sharding_insert_test.go @@ -71,7 +71,7 @@ func TestShardingInsert_Build(t *testing.T) { "1.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(t, err) testCases := []struct { name string @@ -255,7 +255,7 @@ func (s *ShardingInsertSuite) TestShardingInsert_Exec() { "0.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("mysql", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(s.T(), err) testcases := []struct { name string diff --git a/sharding_select.go b/sharding_select.go index be5a1f02..b305e7e8 100644 --- a/sharding_select.go +++ b/sharding_select.go @@ -303,17 +303,92 @@ func (s *ShardingSelector[T]) GetMulti(ctx context.Context) ([]*T, error) { if err != nil { return nil, err } + var rowsSlice []*sql.Rows var eg errgroup.Group for _, query := range qs { q := query eg.Go(func() error { - s.lock.Lock() - defer s.lock.Unlock() + //s.lock.Lock() + //defer s.lock.Unlock() + // TODO 利用 ctx 传递 DB name + rows, err := s.db.queryContext(ctx, q) + if err == nil { + s.lock.Lock() + rowsSlice = append(rowsSlice, rows) + s.lock.Unlock() + } + return err + }) + } + err = eg.Wait() + if err != nil { + return nil, err + } + + mgr := batchmerger.NewMerger() + rows, err := mgr.Merge(ctx, rowsSlice) + if err != nil { + return nil, err + } + var res []*T + for rows.Next() { + tp := new(T) + val := s.valCreator.NewPrimitiveValue(tp, s.meta) + if err = val.SetColumns(rows); err != nil { + return nil, err + } + res = append(res, tp) + } + return res, nil +} + +func (s *ShardingSelector[T]) GetMultiV2(ctx context.Context) ([]*T, error) { + qs, err := s.Build(ctx) + if err != nil { + return nil, err + } + var sdQs []sharding.Query + dsMap := make(map[string]map[string]sharding.Query, 8) + for _, q := range qs { + dbMap, ok := dsMap[q.Datasource] + if !ok { + dsMap[q.Datasource] = map[string]sharding.Query{q.DB: q} + continue + } + old, ok := dbMap[q.DB] + if !ok { + if dbMap == nil { + dbMap = make(map[string]sharding.Query, 8) + } + dbMap[q.DB] = q + continue + } + old.SQL = old.SQL + q.SQL + old.Args = append(old.Args, q.Args...) + dbMap[q.DB] = old + dsMap[q.Datasource] = dbMap + } + for _, dbMap := range dsMap { + for _, q := range dbMap { + sdQs = append(sdQs, q) + } + } + + var rowsSlice []*sql.Rows + var eg errgroup.Group + for _, query := range sdQs { + q := query + //fmt.Println(q.String()) + eg.Go(func() error { + //s.lock.Lock() + //defer s.lock.Unlock() // TODO 利用 ctx 传递 DB name rows, err := s.db.queryContext(ctx, q) if err == nil { + s.lock.Lock() rowsSlice = append(rowsSlice, rows) + s.lock.Unlock() } return err }) diff --git a/sharding_select_test.go b/sharding_select_test.go index 543d7292..e61d4f21 100644 --- a/sharding_select_test.go +++ b/sharding_select_test.go @@ -64,7 +64,7 @@ func TestShardingSelector_shadow_Build(t *testing.T) { "0.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(t, err) testCases := []struct { @@ -625,7 +625,7 @@ func TestShardingSelector_onlyDataSource_Build(t *testing.T) { "1.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(t, err) testCases := []struct { @@ -1492,7 +1492,7 @@ func TestShardingSelector_onlyTable_Build(t *testing.T) { "0.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(t, err) testCases := []struct { @@ -2376,7 +2376,7 @@ func TestShardingSelector_onlyDB_Build(t *testing.T) { "0.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(t, err) testCases := []struct { @@ -3290,7 +3290,7 @@ func TestShardingSelector_all_Build(t *testing.T) { "1.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(t, err) testCases := []struct { @@ -4546,7 +4546,7 @@ func TestShardingSelector_Build(t *testing.T) { "0.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(t, err) testCases := []struct { @@ -5802,7 +5802,7 @@ func TestShardingSelector_Build_Error(t *testing.T) { "0.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(t, err) testCases := []struct { @@ -5907,7 +5907,7 @@ func TestShardingSelector_Build_Error(t *testing.T) { shardingsource.NewShardingDataSource(map[string]datasource.DataSource{ "0.db.cluster.company.com:3306": MasterSlavesMemoryDB(), }), - DBOptionWithMetaRegistry(reg)) + DBWithMetaRegistry(reg)) require.NoError(t, err) s := NewShardingSelector[Order](db).Where(C("UserId").EQ(123)) return s @@ -5954,7 +5954,7 @@ func TestShardingSelector_Get(t *testing.T) { "0.db.slave.company.com:3306": masterSlaveDB, } shardingDB, err := OpenDS("mysql", - shardingsource.NewShardingDataSource(m), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(m), DBWithMetaRegistry(r)) require.NoError(t, err) testCases := []struct { @@ -6143,7 +6143,7 @@ func TestShardingSelector_GetMulti(t *testing.T) { "0.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("mysql", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(t, err) testCases := []struct { diff --git a/sharding_update_test.go b/sharding_update_test.go index b65c490a..5e4296c4 100644 --- a/sharding_update_test.go +++ b/sharding_update_test.go @@ -73,10 +73,10 @@ func TestShardingUpdater_Build(t *testing.T) { "0.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(t, err) shardingDB2, err := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r2)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r2)) require.NoError(t, err) testCases := []struct { name string @@ -640,7 +640,7 @@ func TestShardingUpdater_Build_Error(t *testing.T) { "0.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(t, err) testCases := []struct { name string @@ -725,7 +725,7 @@ func TestShardingUpdater_Build_Error(t *testing.T) { shardingsource.NewShardingDataSource(map[string]datasource.DataSource{ "0.db.cluster.company.com:3306": MasterSlavesMemoryDB(), }), - DBOptionWithMetaRegistry(reg)) + DBWithMetaRegistry(reg)) require.NoError(t, err) s := NewShardingUpdater[Order](db). Update(&Order{Content: "1", Account: 1.0}). @@ -796,7 +796,7 @@ func (s *ShardingUpdaterSuite) TestShardingUpdater_Exec() { "0.db.cluster.company.com:3306": clusterDB, } shardingDB, err := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) require.NoError(t, err) testCases := []struct { name string @@ -887,7 +887,7 @@ func ExampleShardingUpdater_SkipNilValue() { "0.db.cluster.company.com:3306": clusterDB, } shardingDB, _ := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) query, _ := NewShardingUpdater[OrderDetail](shardingDB).Update(&OrderDetail{ UsingCol1: "Jack", ItemId: 11, }).SkipNilValue().Where(C("OrderId").EQ(1)).Build(context.Background()) @@ -915,7 +915,7 @@ func ExampleShardingUpdater_SkipZeroValue() { "0.db.cluster.company.com:3306": clusterDB, } shardingDB, _ := OpenDS("sqlite3", - shardingsource.NewShardingDataSource(ds), DBOptionWithMetaRegistry(r)) + shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) query, _ := NewShardingUpdater[OrderDetail](shardingDB).Update(&OrderDetail{ UsingCol1: "Jack", }).SkipZeroValue().Where(C("OrderId").EQ(1)).Build(context.Background()) diff --git a/transaction_test.go b/transaction_test.go index 026442e8..43319918 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -22,8 +22,6 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/ecodeclub/eorm/internal/datasource" - "github.com/ecodeclub/eorm/internal/datasource/cluster" - "github.com/ecodeclub/eorm/internal/datasource/masterslave" "github.com/ecodeclub/eorm/internal/datasource/single" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -182,25 +180,25 @@ func TestTx_ExecContext(t *testing.T) { wantErr error isCommit bool }{ - { - name: "source err", - mockOrder: func(mock sqlmock.Sqlmock) { - mock.ExpectBegin() - mock.ExpectExec("DELETE FROM `test_model` WHERE `id`=").WithArgs(1).WillReturnResult(sqlmock.NewResult(10, 20)) - mock.ExpectCommit() - }, - sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource { - clusterDB := cluster.NewClusterDB(map[string]*masterslave.MasterSlavesDB{ - "db0": masterslave.NewMasterSlavesDB(db), - }) - return clusterDB - }, - query: Query{ - SQL: "DELETE FROM `test_model` WHERE `id`=", - Args: []any{1}, - }, - wantBeginTxErr: errors.New("eorm: 未实现 TxBeginner 接口"), - }, + //{ + // name: "source err", + // mockOrder: func(mock sqlmock.Sqlmock) { + // mock.ExpectBegin() + // mock.ExpectExec("DELETE FROM `test_model` WHERE `id`=").WithArgs(1).WillReturnResult(sqlmock.NewResult(10, 20)) + // mock.ExpectCommit() + // }, + // sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource { + // clusterDB := cluster.NewClusterDB(map[string]*masterslave.MasterSlavesDB{ + // "db0": masterslave.NewMasterSlavesDB(db), + // }) + // return clusterDB + // }, + // query: Query{ + // SQL: "DELETE FROM `test_model` WHERE `id`=", + // Args: []any{1}, + // }, + // wantBeginTxErr: errors.New("eorm: 未实现 TxBeginner 接口"), + //}, { name: "commit err", mockOrder: func(mock sqlmock.Sqlmock) { From d10f1b74c1e03a94fb264d5b3e9de29862c9a9f4 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 28 Aug 2023 00:41:11 +0800 Subject: [PATCH 2/3] =?UTF-8?q?merger:=20=E4=BD=BF=E7=94=A8=20sqlx.Scanner?= =?UTF-8?q?=20=E6=9D=A5=E8=AF=BB=E5=8F=96=E6=95=B0=E6=8D=AE=20(#216)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .CHANGELOG.md | 1 + .gitignore | 1 + internal/merger/aggregatemerger/merger.go | 35 +-- .../groupby_merger/aggregator_merger.go | 42 +-- internal/merger/utils/scan.go | 48 ---- internal/merger/utils/scan_test.go | 249 ------------------ 6 files changed, 26 insertions(+), 350 deletions(-) delete mode 100644 internal/merger/utils/scan.go delete mode 100644 internal/merger/utils/scan_test.go diff --git a/.CHANGELOG.md b/.CHANGELOG.md index df2a89fd..2487da69 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -32,6 +32,7 @@ - [eorm: ShardingInserter 修改为表维度执行](https://github.com/ecodeclub/eorm/pull/211) - [eorm: 分库分表:ShardingUpdater 实现](https://github.com/ecodeclub/eorm/pull/201) - [eorm: 分库分表:datasource-简单的分布式事务方案支持](https://github.com/ecodeclub/eorm/pull/204) +- [merger: 使用 sqlx.Scanner 来读取数据](https://github.com/ecodeclub/eorm/pull/216) ## v0.0.1: - [Init Project](https://github.com/ecodeclub/eorm/pull/1) diff --git a/.gitignore b/.gitignore index d0fd7913..0bb3b4dd 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ .idea *.iml fuzz +go.work \ No newline at end of file diff --git a/internal/merger/aggregatemerger/merger.go b/internal/merger/aggregatemerger/merger.go index 7cc44c8d..5ebff9e1 100644 --- a/internal/merger/aggregatemerger/merger.go +++ b/internal/merger/aggregatemerger/merger.go @@ -17,9 +17,12 @@ package aggregatemerger import ( "context" "database/sql" + "errors" "sync" _ "unsafe" + "github.com/ecodeclub/ekit/sqlx" + "github.com/ecodeclub/eorm/internal/merger" "github.com/ecodeclub/eorm/internal/merger/utils" @@ -54,9 +57,8 @@ func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, e return nil, errs.ErrMergerEmptyRows } for _, res := range results { - err := m.checkColumns(res) - if err != nil { - return nil, err + if res == nil { + return nil, errs.ErrMergerRowsIsNull } } @@ -70,12 +72,6 @@ func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, e }, nil } -func (m *Merger) checkColumns(rows *sql.Rows) error { - if rows == nil { - return errs.ErrMergerRowsIsNull - } - return nil -} type Rows struct { rowsList []*sql.Rows @@ -149,23 +145,18 @@ func (r *Rows) getSqlRowsData() ([][]any, error) { } return rowsData, nil } -func (r *Rows) getSqlRowData(row *sql.Rows) ([]any, error) { - +func (*Rows) getSqlRowData(row *sql.Rows) ([]any, error) { var colsData []any var err error - if row.Next() { - colsData, err = utils.Scan(row) - if err != nil { - return nil, err - } - } else { - // sql.Rows迭代过程中发生报错,返回报错 - if row.Err() != nil { - return nil, row.Err() - } + scanner, err := sqlx.NewSQLRowsScanner(row) + if err != nil { + return nil, err + } + colsData, err = scanner.Scan() + if errors.Is(err, sqlx.ErrNoMoreRows) { return nil, errs.ErrMergerAggregateHasEmptyRows } - return colsData, nil + return colsData, err } func (r *Rows) Scan(dest ...any) error { diff --git a/internal/merger/groupby_merger/aggregator_merger.go b/internal/merger/groupby_merger/aggregator_merger.go index 1d60e878..56eae0af 100644 --- a/internal/merger/groupby_merger/aggregator_merger.go +++ b/internal/merger/groupby_merger/aggregator_merger.go @@ -21,6 +21,10 @@ import ( "sync" _ "unsafe" + "github.com/ecodeclub/ekit/slice" + + "github.com/ecodeclub/ekit/sqlx" + "github.com/ecodeclub/eorm/internal/merger/utils" "go.uber.org/multierr" @@ -61,11 +65,8 @@ func (a *AggregatorMerger) Merge(ctx context.Context, results []*sql.Rows) (merg return nil, errs.ErrMergerEmptyRows } - for _, res := range results { - err := a.checkColumns(res) - if err != nil { - return nil, err - } + if slice.Contains[*sql.Rows](results, nil) { + return nil, errs.ErrMergerRowsIsNull } dataMap, dataIndex, err := a.getCols(results) if err != nil { @@ -82,13 +83,6 @@ func (a *AggregatorMerger) Merge(ctx context.Context, results []*sql.Rows) (merg cur: -1, cols: a.columnsName, }, nil - -} -func (a *AggregatorMerger) checkColumns(rows *sql.Rows) error { - if rows == nil { - return errs.ErrMergerRowsIsNull - } - return nil } func (a *AggregatorMerger) getCols(rowsList []*sql.Rows) (*mapx.TreeMap[Key, [][]any], []Key, error) { @@ -98,7 +92,11 @@ func (a *AggregatorMerger) getCols(rowsList []*sql.Rows) (*mapx.TreeMap[Key, [][ } keys := make([]Key, 0, 16) for _, res := range rowsList { - colsData, err := a.getCol(res) + scanner, err := sqlx.NewSQLRowsScanner(res) + if err != nil { + return nil, nil, err + } + colsData, err := scanner.ScanAll() if err != nil { return nil, nil, err } @@ -108,7 +106,6 @@ func (a *AggregatorMerger) getCols(rowsList []*sql.Rows) (*mapx.TreeMap[Key, [][ key.columnValues = append(key.columnValues, colData[groupByCol.Index]) } val, ok := treeMap.Get(key) - if ok { val = append(val, colData) err = treeMap.Set(key, val) @@ -127,23 +124,6 @@ func (a *AggregatorMerger) getCols(rowsList []*sql.Rows) (*mapx.TreeMap[Key, [][ return treeMap, keys, nil } -func (a *AggregatorMerger) getCol(row *sql.Rows) ([][]any, error) { - ans := make([][]any, 0, 16) - for row.Next() { - colsData, err := utils.Scan(row) - if err != nil { - return nil, err - } - ans = append(ans, colsData) - } - if row.Err() != nil { - return nil, row.Err() - } - - return ans, nil - -} - type AggregatorRows struct { rowsList []*sql.Rows aggregators []aggregator.Aggregator diff --git a/internal/merger/utils/scan.go b/internal/merger/utils/scan.go deleted file mode 100644 index 3412a902..00000000 --- a/internal/merger/utils/scan.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2021 ecodeclub -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "database/sql" - "reflect" -) - -func Scan(row *sql.Rows) ([]any, error) { - colsInfo, err := row.ColumnTypes() - if err != nil { - return nil, err - } - colsData := make([]any, 0, len(colsInfo)) - // 拿到sql.Rows字段的类型然后初始化 - for _, colInfo := range colsInfo { - typ := colInfo.ScanType() - // sqlite3的驱动返回的是指针。循环的去除指针 - for typ.Kind() == reflect.Pointer { - typ = typ.Elem() - } - newData := reflect.New(typ).Interface() - colsData = append(colsData, newData) - } - // 通过Scan赋值 - err = row.Scan(colsData...) - if err != nil { - return nil, err - } - // 去掉reflect.New的指针 - for i := 0; i < len(colsData); i++ { - colsData[i] = reflect.ValueOf(colsData[i]).Elem().Interface() - } - return colsData, nil -} diff --git a/internal/merger/utils/scan_test.go b/internal/merger/utils/scan_test.go deleted file mode 100644 index 61fc0890..00000000 --- a/internal/merger/utils/scan_test.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright 2021 ecodeclub -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "context" - "database/sql" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -type ScanSuite struct { - suite.Suite - mockDB01 *sql.DB - mock01 sqlmock.Sqlmock - db02 *sql.DB -} - -func (ms *ScanSuite) SetupTest() { - t := ms.T() - ms.initMock(t) -} - -func (ms *ScanSuite) TearDownTest() { - _ = ms.mockDB01.Close() - _ = ms.db02.Close() -} -func (ms *ScanSuite) initMock(t *testing.T) { - var err error - query := "CREATE TABLE t1 (\n id int primary key,\n `int` int,\n `integer` integer,\n `tinyint` TINYINT,\n `smallint` smallint,\n `MEDIUMINT` MEDIUMINT,\n `BIGINT` BIGINT,\n `UNSIGNED_BIG_INT` UNSIGNED BIG INT,\n `INT2` INT2,\n `INT8` INT8,\n `VARCHAR` VARCHAR(20),\n \t\t`CHARACTER` CHARACTER(20),\n `VARYING_CHARACTER` VARYING_CHARACTER(20),\n `NCHAR` NCHAR(23),\n `TEXT` TEXT,\n `CLOB` CLOB,\n `REAL` REAL,\n `DOUBLE` DOUBLE,\n `DOUBLE_PRECISION` DOUBLE PRECISION,\n `FLOAT` FLOAT,\n `DATETIME` DATETIME \n );" - ms.mockDB01, ms.mock01, err = sqlmock.New() - if err != nil { - t.Fatal(err) - } - db02, err := sql.Open("sqlite3", "file:test01.db?cache=shared&mode=memory") - if err != nil { - t.Fatal(err) - } - ms.db02 = db02 - _, err = db02.ExecContext(context.Background(), query) - if err != nil { - t.Fatal(err) - } -} -func (ms *ScanSuite) TestScan() { - testcases := []struct { - name string - rows *sql.Rows - want []any - err error - afterFunc func() - }{ - { - name: "浮点数", - rows: func() *sql.Rows { - cols := []string{"float64"} - query := "SELECT float64 FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(float64(1.1))) - rows, err := ms.mockDB01.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - return rows - }(), - want: []any{float64(1.1)}, - }, - { - name: "int64", - rows: func() *sql.Rows { - cols := []string{"int64"} - query := "SELECT int64 FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(int64(1))) - rows, err := ms.mockDB01.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - return rows - }(), - want: []any{int64(1)}, - }, - { - name: "int32", - rows: func() *sql.Rows { - cols := []string{"int32"} - query := "SELECT int32 FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(int32(1))) - rows, err := ms.mockDB01.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - return rows - }(), - want: []any{int32(1)}, - }, - { - name: "int16", - rows: func() *sql.Rows { - cols := []string{"int16"} - query := "SELECT int16 FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(int16(1))) - rows, err := ms.mockDB01.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - return rows - }(), - want: []any{int16(1)}, - }, - { - name: "int8", - rows: func() *sql.Rows { - cols := []string{"int8"} - query := "SELECT int8 FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(int8(1))) - rows, err := ms.mockDB01.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - return rows - }(), - want: []any{int8(1)}, - }, - { - name: "int", - rows: func() *sql.Rows { - cols := []string{"int"} - query := "SELECT FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) - rows, err := ms.mockDB01.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - return rows - }(), - want: []any{1}, - }, - { - name: "string", - rows: func() *sql.Rows { - cols := []string{"string"} - query := "SELECT string FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xx")) - rows, err := ms.mockDB01.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - return rows - }(), - want: []any{"string"}, - }, - { - name: "bool", - rows: func() *sql.Rows { - cols := []string{"bool"} - query := "SELECT bool FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(true)) - rows, err := ms.mockDB01.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - return rows - }(), - want: []any{true}, - }, - { - name: "sqlite3 int类型", - rows: func() *sql.Rows { - _, err := ms.db02.Exec("INSERT INTO `t1` (`int`,`integer`,`tinyint`,`smallint`,`MEDIUMINT`,`BIGINT`,`UNSIGNED_BIG_INT`,`INT2`) VALUES (1,1,1,1,1,1,1,1);") - require.NoError(ms.T(), err) - query := "SELECT `int`,`integer`,`tinyint`,`smallint`,`MEDIUMINT`,`BIGINT`,`UNSIGNED_BIG_INT`,`INT2`,`INT8` FROM `t1`;" - rows, err := ms.db02.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - return rows - }(), - want: []any{sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: false, Int64: 0}}, - afterFunc: func() { - _, err := ms.db02.Exec("truncate table `t1`") - require.NoError(ms.T(), err) - }, - }, - { - name: "sqlite3 string类型", - rows: func() *sql.Rows { - _, err := ms.db02.Exec("INSERT INTO `t1` (`VARCHAR`,`CHARACTER`,`VARYING_CHARACTER`,`NCHAR`,`TEXT`) VALUES ('zwl','zwl','zwl','zwl','zwl');") - require.NoError(ms.T(), err) - query := "SELECT `VARCHAR`,`CHARACTER`,`VARYING_CHARACTER`,`NCHAR`,`TEXT`,`CLOB` FROM `t1`;" - rows, err := ms.db02.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - return rows - }(), - want: []any{sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: false, String: ""}}, - afterFunc: func() { - _, err := ms.db02.Exec("truncate table `t1`") - require.NoError(ms.T(), err) - }, - }, - { - name: "sqlite3 浮点类型", - rows: func() *sql.Rows { - _, err := ms.db02.Exec("INSERT INTO `t1` (`REAL`,`DOUBLE`,`DOUBLE_PRECISION`) VALUES (1.0,1.0,1.0);") - require.NoError(ms.T(), err) - query := "SELECT `REAL`,`DOUBLE`,`DOUBLE_PRECISION`,`FLOAT` FROM `t1`;" - rows, err := ms.db02.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - return rows - }(), - want: []any{sql.NullFloat64{Valid: true, Float64: 1.0}, sql.NullFloat64{Valid: true, Float64: 1.0}, sql.NullFloat64{Valid: true, Float64: 1.0}, sql.NullFloat64{Valid: false, Float64: 0}}, - afterFunc: func() { - _, err := ms.db02.Exec("truncate table `t1`") - require.NoError(ms.T(), err) - }, - }, - { - name: "sqlite3时间类型", - rows: func() *sql.Rows { - _, err := ms.db02.Exec("INSERT INTO `t1` (`DATETIME`) VALUES ('2022-01-01 12:00:00');") - require.NoError(ms.T(), err) - query := "SELECT `DATETIME` FROM `t1`;" - rows, err := ms.db02.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - return rows - }(), - want: []any{sql.NullTime{Valid: true, Time: func() time.Time { - t, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local) - require.NoError(ms.T(), err) - return t - - }()}}, - }, - } - for _, tc := range testcases { - ms.T().Run(tc.name, func(t *testing.T) { - rows := tc.rows - require.True(t, rows.Next()) - got, err := Scan(rows) - require.Equal(t, tc.err, err) - if err == nil { - return - } - require.Equal(t, tc.want, got) - tc.afterFunc() - }) - } -} - -func TestMerger(t *testing.T) { - suite.Run(t, &ScanSuite{}) -} From 07dc4162a78f3ce310929bc470c50096b50b0025 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 5 Sep 2023 23:43:04 +0800 Subject: [PATCH 3/3] =?UTF-8?q?rows,=20merger:=20=E4=BD=BF=E7=94=A8=20sqlx?= =?UTF-8?q?.Rows=20=E4=BD=9C=E4=B8=BA=E6=8E=A5=E5=8F=A3=EF=BC=8C=E5=B9=B6?= =?UTF-8?q?=E9=87=8D=E6=9E=84=20merger=20=E5=8C=85=20(#217)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .CHANGELOG.md | 1 + .github/workflows/go-fmt.yml | 2 +- .github/workflows/go.yml | 2 +- .github/workflows/golangci-lint.yml | 2 +- .github/workflows/integration_test.yml | 2 +- README.md | 2 +- go.mod | 4 +- go.sum | 4 +- internal/errs/error.go | 4 + internal/merger/aggregatemerger/merger.go | 22 +-- .../merger/aggregatemerger/merger_test.go | 92 ++++++------ internal/merger/batchmerger/merger.go | 16 ++- internal/merger/batchmerger/merger_test.go | 96 +++++++------ .../groupby_merger/aggregator_merger.go | 21 ++- .../groupby_merger/aggregator_merger_test.go | 60 ++++---- internal/merger/pagedmerger/merger.go | 15 +- internal/merger/pagedmerger/merger_test.go | 56 ++++---- internal/merger/sortmerger/merger.go | 26 ++-- internal/merger/sortmerger/merger_test.go | 134 +++++++++--------- internal/merger/type.go | 5 +- .../{merger/utils => rows}/convert_assign.go | 2 +- .../utils => rows}/convert_assign_test.go | 2 +- internal/rows/types.go | 14 +- sharding_select.go | 21 +-- 24 files changed, 335 insertions(+), 270 deletions(-) rename internal/{merger/utils => rows}/convert_assign.go (98%) rename internal/{merger/utils => rows}/convert_assign_test.go (99%) diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 2487da69..b3d7ca5d 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -33,6 +33,7 @@ - [eorm: 分库分表:ShardingUpdater 实现](https://github.com/ecodeclub/eorm/pull/201) - [eorm: 分库分表:datasource-简单的分布式事务方案支持](https://github.com/ecodeclub/eorm/pull/204) - [merger: 使用 sqlx.Scanner 来读取数据](https://github.com/ecodeclub/eorm/pull/216) +- [rows, merger: 使用 sqlx.Rows 作为接口,并重构 merger 包 ](https://github.com/ecodeclub/eorm/pull/217) ## v0.0.1: - [Init Project](https://github.com/ecodeclub/eorm/pull/1) diff --git a/.github/workflows/go-fmt.yml b/.github/workflows/go-fmt.yml index 69fea98a..23f9523b 100644 --- a/.github/workflows/go-fmt.yml +++ b/.github/workflows/go-fmt.yml @@ -28,7 +28,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: ">=1.18.0" + go-version: ">=1.20.0" - name: Install goimports run: go install golang.org/x/tools/cmd/goimports@latest diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index a443303c..aa724fe3 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -28,7 +28,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: '1.20' - name: Build run: go build -v ./... diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 4d8f64a4..30abeff0 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -33,7 +33,7 @@ jobs: steps: - uses: actions/setup-go@v3 with: - go-version: '1.18' + go-version: '1.20' - uses: actions/checkout@v3 - name: golangci-lint uses: golangci/golangci-lint-action@v3 diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index 5588ed9f..a8b8f050 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -28,7 +28,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: '1.20' - name: Test run: sudo sh ./script/integrate_test.sh \ No newline at end of file diff --git a/README.md b/README.md index d024e365..b9c6fbb4 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ ### Go 版本 -请使用 Go 1.18 以上版本。 +请使用 Go 1.20 以上版本。 ### SQL 2003 标准 理论上来说,我们计划支持 [SQL 2003 standard](https://ronsavage.github.io/SQL/sql-2003-2.bnf.html#query%20specification). 不过据我们所知,并不是所有的数据库都支持全部的 SQL 2003 标准,所以用户还是需要进一步检查目标数据库的语法。 diff --git a/go.mod b/go.mod index 49d794ee..f02222e2 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,10 @@ module github.com/ecodeclub/eorm -go 1.18 +go 1.20 require ( github.com/DATA-DOG/go-sqlmock v1.5.0 - github.com/ecodeclub/ekit v0.0.4-0.20230530053225-e671c5fdd2d1 + github.com/ecodeclub/ekit v0.0.4-0.20230904153403-e76aae064994 github.com/go-sql-driver/mysql v1.6.0 github.com/gotomicro/ekit v0.0.0-20230224040531-869798da3c4d github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index fa9fef4d..7b335064 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,8 @@ github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/ecodeclub/ekit v0.0.4-0.20230530053225-e671c5fdd2d1 h1:a1Dbg0zZOQPfG3pgFqZjkQM2ty1ZABewjzRK970OQ8w= -github.com/ecodeclub/ekit v0.0.4-0.20230530053225-e671c5fdd2d1/go.mod h1:OqTojKeKFTxeeAAUwNIPKu339SRkX6KAuoK/8A5BCEs= +github.com/ecodeclub/ekit v0.0.4-0.20230904153403-e76aae064994 h1:4Rp8WrJhISj8GDtnueoD22ygPuppajnCVZuEfRjg6w8= +github.com/ecodeclub/ekit v0.0.4-0.20230904153403-e76aae064994/go.mod h1:OqTojKeKFTxeeAAUwNIPKu339SRkX6KAuoK/8A5BCEs= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/gotomicro/ekit v0.0.0-20230224040531-869798da3c4d h1:kmDgYRZ06UifBqAfew+cj02juQQ3Ko349NzsDIZ0QPw= diff --git a/internal/errs/error.go b/internal/errs/error.go index 0ef1b155..e60fdedf 100644 --- a/internal/errs/error.go +++ b/internal/errs/error.go @@ -112,3 +112,7 @@ func NewInvalidDSNError(dsn string) error { func NewFailedToGetSlavesFromDNS(err error) error { return fmt.Errorf("eorm: 从DNS中解析从库失败 %w", err) } + +func NewErrScanWrongDestinationArguments(expect int, actual int) error { + return fmt.Errorf("eorm: Scan 方法收到过多或者过少的参数,预期 %d,实际 %d", expect, actual) +} diff --git a/internal/merger/aggregatemerger/merger.go b/internal/merger/aggregatemerger/merger.go index 5ebff9e1..515529e6 100644 --- a/internal/merger/aggregatemerger/merger.go +++ b/internal/merger/aggregatemerger/merger.go @@ -21,11 +21,9 @@ import ( "sync" _ "unsafe" - "github.com/ecodeclub/ekit/sqlx" - - "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/rows" - "github.com/ecodeclub/eorm/internal/merger/utils" + "github.com/ecodeclub/ekit/sqlx" "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" "github.com/ecodeclub/eorm/internal/merger/internal/errs" @@ -49,7 +47,7 @@ func NewMerger(aggregators ...aggregator.Aggregator) *Merger { } } -func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, error) { +func (m *Merger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) { if ctx.Err() != nil { return nil, ctx.Err() } @@ -74,7 +72,7 @@ func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, e } type Rows struct { - rowsList []*sql.Rows + rowsList []rows.Rows aggregators []aggregator.Aggregator closed bool mu *sync.RWMutex @@ -84,6 +82,14 @@ type Rows struct { nextCalled bool } +func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) { + return r.rowsList[0].ColumnTypes() +} + +func (*Rows) NextResultSet() bool { + return false +} + func (r *Rows) Next() bool { r.mu.Lock() if r.closed || r.lastErr != nil { @@ -145,7 +151,7 @@ func (r *Rows) getSqlRowsData() ([][]any, error) { } return rowsData, nil } -func (*Rows) getSqlRowData(row *sql.Rows) ([]any, error) { +func (*Rows) getSqlRowData(row rows.Rows) ([]any, error) { var colsData []any var err error scanner, err := sqlx.NewSQLRowsScanner(row) @@ -173,7 +179,7 @@ func (r *Rows) Scan(dest ...any) error { return errs.ErrMergerScanNotNext } for i := 0; i < len(dest); i++ { - err := utils.ConvertAssign(dest[i], r.cur[i]) + err := rows.ConvertAssign(dest[i], r.cur[i]) if err != nil { return err } diff --git a/internal/merger/aggregatemerger/merger_test.go b/internal/merger/aggregatemerger/merger_test.go index ea12b127..9b1cf756 100644 --- a/internal/merger/aggregatemerger/merger_test.go +++ b/internal/merger/aggregatemerger/merger_test.go @@ -21,12 +21,14 @@ import ( "fmt" "testing" + "github.com/ecodeclub/eorm/internal/rows" + _ "github.com/mattn/go-sqlite3" + "github.com/ecodeclub/eorm/internal/merger" "github.com/DATA-DOG/go-sqlmock" "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" "github.com/ecodeclub/eorm/internal/merger/internal/errs" - _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -34,8 +36,8 @@ import ( ) var ( - nextMockErr error = errors.New("rows: MockNextErr") - aggregatorErr error = errors.New("aggregator: MockAggregatorErr") + nextMockErr = errors.New("rows: MockNextErr") + aggregatorErr = errors.New("aggregator: MockAggregatorErr") ) func newCloseMockErr(dbName string) error { @@ -109,7 +111,7 @@ func TestMerger(t *testing.T) { func (ms *MergerSuite) TestRows_NextAndScan() { testcases := []struct { name string - sqlRows func() []*sql.Rows + sqlRows func() []rows.Rows wantVal []any aggregators func() []aggregator.Aggregator gotVal []any @@ -117,7 +119,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }{ { name: "sqlite的ColumnType 使用了多级指针", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query1 := "insert into `t1` values (1,10),(2,20),(3,30)" _, err := ms.db05.ExecContext(context.Background(), query1) require.NoError(ms.T(), err) @@ -127,7 +129,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.db05} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -149,14 +151,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "SUM(id)", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"SUM(id)"} query := "SELECT SUM(`id`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -179,14 +181,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { { name: "MAX(id)", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"MAX(id)"} query := "SELECT MAX(`id`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -208,14 +210,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "MIN(id)", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"MIN(id)"} query := "SELECT MIN(`id`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -237,14 +239,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "COUNT(id)", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"COUNT(id)"} query := "SELECT COUNT(`id`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -266,14 +268,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "AVG(grade)", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"SUM(grade)", "COUNT(grade)"} query := "SELECT SUM(`grade`),COUNT(`grade`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 10)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 20)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 10)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -300,14 +302,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { // 1.每种聚合函数出现一次 { name: "COUNT(id),MAX(id),MIN(id),SUM(id),AVG(grade)", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"COUNT(id)", "MAX(id)", "MIN(id)", "SUM(id)", "SUM(grade)", "COUNT(grade)"} query := "SELECT COUNT(`id`),MAX(`id`),MIN(`id`),SUM(`id`),SUM(`grade`),COUNT(`student`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10, 20, 1, 100, 2000, 20)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20, 30, 0, 200, 800, 10)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10, 40, 2, 300, 1800, 20)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -337,14 +339,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { // 两个avg会包含sum列在前,和sum列在后的状态。并且有完全相同的列出现 { name: "AVG(grade),SUM(grade),AVG(grade),MIN(id),MIN(userid),MAX(id),COUNT(id)", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"SUM(grade)", "COUNT(grade)", "SUM(grade)", "COUNT(grade)", "SUM(grade)", "MIN(id)", "MIN(userid)", "MAX(id)", "COUNT(id)"} query := "SELECT SUM(`grade`),COUNT(`grade`),SUM(`grade`),COUNT(`grade`),SUM(`grade`),MIN(`id`),MIN(`userid`),MAX(`id`),COUNT(`id`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 20, 2000, 20, 2000, 10, 20, 200, 200)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1000, 10, 1000, 10, 1000, 20, 30, 300, 300)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(800, 10, 800, 10, 800, 5, 6, 100, 200)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -378,7 +380,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { // 1. Rows 列表中有一个Rows返回行数为空,在前面会返回错误 { name: "RowsList有一个Rows为空,在前面", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"SUM(id)"} query := "SELECT SUM(`id`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) @@ -386,7 +388,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB04, ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -410,7 +412,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { // 2. Rows 列表中有一个Rows返回行数为空,在中间会返回错误 { name: "RowsList有一个Rows为空,在中间", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"SUM(id)"} query := "SELECT SUM(`id`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) @@ -418,7 +420,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB04, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -442,7 +444,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { // 3. Rows 列表中有一个Rows返回行数为空,在后面会返回错误 { name: "RowsList有一个Rows为空,在最后", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"SUM(id)"} query := "SELECT SUM(`id`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) @@ -450,7 +452,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -474,7 +476,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { // 4. Rows 列表中全部Rows返回的行数为空,不会返回错误 { name: "RowsList全部为空", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"SUM(id)"} query := "SELECT SUM(`id`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) @@ -482,7 +484,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -523,13 +525,13 @@ func (ms *MergerSuite) TestRows_NextAndScan() { func (ms *MergerSuite) TestRows_NextAndErr() { testcases := []struct { name string - rowsList func() []*sql.Rows + rowsList func() []rows.Rows wantErr error aggregators []aggregator.Aggregator }{ { name: "sqlRows列表中有一个返回error", - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { cols := []string{"COUNT(id)"} query := "SELECT COUNT(`id`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) @@ -537,7 +539,7 @@ func (ms *MergerSuite) TestRows_NextAndErr() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4).RowError(0, nextMockErr)) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -554,7 +556,7 @@ func (ms *MergerSuite) TestRows_NextAndErr() { }, { name: "有一个aggregator返回error", - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { cols := []string{"COUNT(id)"} query := "SELECT COUNT(`id`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) @@ -562,7 +564,7 @@ func (ms *MergerSuite) TestRows_NextAndErr() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4)) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -601,7 +603,7 @@ func (ms *MergerSuite) TestRows_Close() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3).CloseError(newCloseMockErr("db03"))) merger := NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -656,7 +658,7 @@ func (ms *MergerSuite) TestRows_Columns() { } merger := NewMerger(aggregators...) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -688,7 +690,7 @@ func (ms *MergerSuite) TestMerger_Merge() { merger func() *Merger ctx func() (context.Context, context.CancelFunc) wantErr error - sqlRows func() []*sql.Rows + sqlRows func() []rows.Rows }{ { name: "超时", @@ -700,10 +702,10 @@ func (ms *MergerSuite) TestMerger_Merge() { return ctx, cancel }, wantErr: context.DeadlineExceeded, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT SUM(`id`) FROM `t1`;" cols := []string{"SUM(id)"} - res := make([]*sql.Rows, 0, 1) + res := make([]rows.Rows, 0, 1) ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) @@ -720,8 +722,8 @@ func (ms *MergerSuite) TestMerger_Merge() { return ctx, cancel }, wantErr: errs.ErrMergerEmptyRows, - sqlRows: func() []*sql.Rows { - return []*sql.Rows{} + sqlRows: func() []rows.Rows { + return []rows.Rows{} }, }, { @@ -734,8 +736,8 @@ func (ms *MergerSuite) TestMerger_Merge() { return ctx, cancel }, wantErr: errs.ErrMergerRowsIsNull, - sqlRows: func() []*sql.Rows { - return []*sql.Rows{nil} + sqlRows: func() []rows.Rows { + return []rows.Rows{nil} }, }, } @@ -763,6 +765,10 @@ func (m *mockAggregate) Aggregate(cols [][]any) (any, error) { return nil, aggregatorErr } -func (m *mockAggregate) ColumnName() string { +func (*mockAggregate) ColumnName() string { return "mockAggregate" } + +func TestRows_NextResultSet(t *testing.T) { + assert.False(t, (&Rows{}).NextResultSet()) +} diff --git a/internal/merger/batchmerger/merger.go b/internal/merger/batchmerger/merger.go index 9b8e07e1..c017cb34 100644 --- a/internal/merger/batchmerger/merger.go +++ b/internal/merger/batchmerger/merger.go @@ -19,7 +19,7 @@ import ( "database/sql" "sync" - "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/rows" "go.uber.org/multierr" @@ -34,7 +34,7 @@ func NewMerger() *Merger { return &Merger{} } -func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, error) { +func (m *Merger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) { if ctx.Err() != nil { return nil, ctx.Err() } @@ -55,7 +55,7 @@ func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, e } // checkColumns 检查sql.Rows列表中sql.Rows的列集是否相同,并且sql.Rows不能为nil -func (m *Merger) checkColumns(rows *sql.Rows) error { +func (m *Merger) checkColumns(rows rows.Rows) error { if rows == nil { return errs.ErrMergerRowsIsNull } @@ -79,7 +79,7 @@ func (m *Merger) checkColumns(rows *sql.Rows) error { } type Rows struct { - rowsList []*sql.Rows + rowsList []rows.Rows cnt int mu *sync.RWMutex columns []string @@ -87,6 +87,14 @@ type Rows struct { lastErr error } +func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) { + return r.rowsList[0].ColumnTypes() +} + +func (*Rows) NextResultSet() bool { + return false +} + func (r *Rows) Next() bool { r.mu.Lock() if r.closed { diff --git a/internal/merger/batchmerger/merger_test.go b/internal/merger/batchmerger/merger_test.go index 5bc6c11b..253bf0ce 100644 --- a/internal/merger/batchmerger/merger_test.go +++ b/internal/merger/batchmerger/merger_test.go @@ -21,6 +21,8 @@ import ( "fmt" "testing" + "github.com/ecodeclub/eorm/internal/rows" + "go.uber.org/multierr" "github.com/DATA-DOG/go-sqlmock" @@ -85,14 +87,14 @@ func (ms *MergerSuite) initMock(t *testing.T) { func (ms *MergerSuite) TestMerger_Merge() { testcases := []struct { name string - rowsList func() []*sql.Rows + rowsList func() []rows.Rows ctx func() (context.Context, context.CancelFunc) wantErr error }{ { name: "sql.Rows列表中没有元素", - rowsList: func() []*sql.Rows { - return []*sql.Rows{} + rowsList: func() []rows.Rows { + return []rows.Rows{} }, ctx: func() (context.Context, context.CancelFunc) { return context.WithCancel(context.Background()) @@ -101,8 +103,8 @@ func (ms *MergerSuite) TestMerger_Merge() { }, { name: "sql.Rows列表中有元素为nil", - rowsList: func() []*sql.Rows { - return []*sql.Rows{nil} + rowsList: func() []rows.Rows { + return []rows.Rows{nil} }, ctx: func() (context.Context, context.CancelFunc) { return context.WithCancel(context.Background()) @@ -114,12 +116,12 @@ func (ms *MergerSuite) TestMerger_Merge() { ctx: func() (context.Context, context.CancelFunc) { return context.WithCancel(context.Background()) }, - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(3, "alex").AddRow(4, "x")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -134,12 +136,12 @@ func (ms *MergerSuite) TestMerger_Merge() { ctx: func() (context.Context, context.CancelFunc) { return context.WithCancel(context.Background()) }, - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "email"}).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -151,12 +153,12 @@ func (ms *MergerSuite) TestMerger_Merge() { }, { name: "正常的案例", - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -175,10 +177,10 @@ func (ms *MergerSuite) TestMerger_Merge() { return ctx, cancel }, wantErr: context.DeadlineExceeded, - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - res := make([]*sql.Rows, 0, 1) + res := make([]rows.Rows, 0, 1) ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) @@ -205,21 +207,21 @@ func (ms *MergerSuite) TestRows_NextAndScan() { testCases := []struct { name string - sqlRows func() []*sql.Rows + sqlRows func() []rows.Rows wantVal []string wantErr error scanErr error }{ { name: "sqlRows列表中没有空行", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1").AddRow("2")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2").AddRow("2")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - res := make([]*sql.Rows, 0, len(dbs)) + res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, _ := db.QueryContext(context.Background(), query) res = append(res, row) @@ -230,7 +232,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "sqlRows列表中,在前面有一个sqlRows返回空行在前面", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) @@ -238,7 +240,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB04, ms.mockDB01, ms.mockDB02, ms.mockDB03} - res := make([]*sql.Rows, 0, len(dbs)) + res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, _ := db.QueryContext(context.Background(), query) res = append(res, row) @@ -249,7 +251,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "sqlRows列表中,在前面有多个sqlRows返回空行", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) @@ -257,7 +259,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB04, ms.mockDB01, ms.mockDB02, ms.mockDB03} - res := make([]*sql.Rows, 0, len(dbs)) + res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, _ := db.QueryContext(context.Background(), query) res = append(res, row) @@ -268,14 +270,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "sqlRows列表中,在中间有一个sqlRows返回空行", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) dbs := []*sql.DB{ms.mockDB02, ms.mockDB01, ms.mockDB03} - res := make([]*sql.Rows, 0, len(dbs)) + res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, _ := db.QueryContext(context.Background(), query) res = append(res, row) @@ -286,7 +288,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "sqlRows列表中,在中间有多个sqlRows返回空行", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) @@ -294,7 +296,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB02, ms.mockDB01, ms.mockDB04, ms.mockDB03} - res := make([]*sql.Rows, 0, len(dbs)) + res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, _ := db.QueryContext(context.Background(), query) res = append(res, row) @@ -305,14 +307,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "sqlRows列表中,在后面有一个sqlRows返回空行", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) dbs := []*sql.DB{ms.mockDB02, ms.mockDB03, ms.mockDB01} - res := make([]*sql.Rows, 0, len(dbs)) + res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, _ := db.QueryContext(context.Background(), query) res = append(res, row) @@ -323,7 +325,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "sqlRows列表中,在后面有多个个sqlRows返回空行", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) @@ -331,7 +333,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB02, ms.mockDB03, ms.mockDB01, ms.mockDB04} - res := make([]*sql.Rows, 0, len(dbs)) + res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, _ := db.QueryContext(context.Background(), query) res = append(res, row) @@ -342,11 +344,11 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "sqlRows列表中的元素均返回空行", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id"})) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id"})) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id"})) - res := make([]*sql.Rows, 0, 3) + res := make([]rows.Rows, 0, 3) row01, _ := ms.mockDB01.QueryContext(context.Background(), "SELECT * FROM `t1`;") res = append(res, row01) row02, _ := ms.mockDB02.QueryContext(context.Background(), "SELECT * FROM `t1`;") @@ -386,12 +388,12 @@ func (ms *MergerSuite) TestRows_NextAndScan() { func (ms *MergerSuite) TestRows_NextAndErr() { testcases := []struct { name string - rowsList func() []*sql.Rows + rowsList func() []rows.Rows wantErr error }{ { name: "sqlRows列表中有一个返回error", - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { cols := []string{"id"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) @@ -399,7 +401,7 @@ func (ms *MergerSuite) TestRows_NextAndErr() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").RowError(1, nextMockErr)) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("5")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB04, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -429,7 +431,7 @@ func (ms *MergerSuite) TestRows_ScanAndErr() { ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(5)) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) - rowsList := []*sql.Rows{r} + rowsList := []rows.Rows{r} merger := NewMerger() rows, err := merger.Merge(context.Background(), rowsList) require.NoError(t, err) @@ -443,7 +445,7 @@ func (ms *MergerSuite) TestRows_ScanAndErr() { ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).RowError(0, nextMockErr)) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) - rowsList := []*sql.Rows{r} + rowsList := []rows.Rows{r} merger := NewMerger() rows, err := merger.Merge(context.Background(), rowsList) require.NoError(t, err) @@ -464,20 +466,20 @@ func (ms *MergerSuite) TestRows_Close() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03"))) merger := NewMerger() dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) rowsList = append(rowsList, row) } - rows, err := merger.Merge(context.Background(), rowsList) + rs, err := merger.Merge(context.Background(), rowsList) require.NoError(ms.T(), err) // 判断当前是可以正常读取的 - require.True(ms.T(), rows.Next()) + require.True(ms.T(), rs.Next()) var id int - err = rows.Scan(&id) + err = rs.Scan(&id) require.NoError(ms.T(), err) - err = rows.Close() + err = rs.Close() ms.T().Run("close返回multierror", func(t *testing.T) { assert.Equal(ms.T(), multierr.Combine(newCloseMockErr("db02"), newCloseMockErr("db03")), err) }) @@ -485,20 +487,20 @@ func (ms *MergerSuite) TestRows_Close() { for i := 0; i < len(rowsList); i++ { require.False(ms.T(), rowsList[i].Next()) } - require.False(ms.T(), rows.Next()) + require.False(ms.T(), rs.Next()) }) ms.T().Run("close之后Scan返回迭代过程中的错误", func(t *testing.T) { var id int - err := rows.Scan(&id) + err := rs.Scan(&id) assert.Equal(t, errs.ErrMergerRowsClosed, err) }) ms.T().Run("close之后调用Columns方法返回错误", func(t *testing.T) { - _, err := rows.Columns() + _, err := rs.Columns() require.Error(t, err) }) ms.T().Run("close多次是等效的", func(t *testing.T) { for i := 0; i < 4; i++ { - err = rows.Close() + err = rs.Close() require.NoError(t, err) } }) @@ -512,7 +514,7 @@ func (ms *MergerSuite) TestRows_Columns() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) merger := NewMerger() dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -541,3 +543,7 @@ func (ms *MergerSuite) TestRows_Columns() { func TestMerger(t *testing.T) { suite.Run(t, &MergerSuite{}) } + +func TestRows_NextResultSet(t *testing.T) { + assert.False(t, (&Rows{}).NextResultSet()) +} diff --git a/internal/merger/groupby_merger/aggregator_merger.go b/internal/merger/groupby_merger/aggregator_merger.go index 56eae0af..a2b1b16f 100644 --- a/internal/merger/groupby_merger/aggregator_merger.go +++ b/internal/merger/groupby_merger/aggregator_merger.go @@ -21,11 +21,12 @@ import ( "sync" _ "unsafe" + "github.com/ecodeclub/eorm/internal/rows" + "github.com/ecodeclub/ekit/slice" "github.com/ecodeclub/ekit/sqlx" - "github.com/ecodeclub/eorm/internal/merger/utils" "go.uber.org/multierr" "github.com/ecodeclub/eorm/internal/merger" @@ -57,7 +58,7 @@ func NewAggregatorMerger(aggregators []aggregator.Aggregator, groupColumns []mer } // Merge 该实现会全部拿取results里面的数据,由于sql.Rows数据拿完之后会自动关闭,所以这边隐式的关闭了所有的sql.Rows -func (a *AggregatorMerger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, error) { +func (a *AggregatorMerger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) { if ctx.Err() != nil { return nil, ctx.Err() } @@ -65,7 +66,7 @@ func (a *AggregatorMerger) Merge(ctx context.Context, results []*sql.Rows) (merg return nil, errs.ErrMergerEmptyRows } - if slice.Contains[*sql.Rows](results, nil) { + if slice.Contains[rows.Rows](results, nil) { return nil, errs.ErrMergerRowsIsNull } dataMap, dataIndex, err := a.getCols(results) @@ -85,7 +86,7 @@ func (a *AggregatorMerger) Merge(ctx context.Context, results []*sql.Rows) (merg }, nil } -func (a *AggregatorMerger) getCols(rowsList []*sql.Rows) (*mapx.TreeMap[Key, [][]any], []Key, error) { +func (a *AggregatorMerger) getCols(rowsList []rows.Rows) (*mapx.TreeMap[Key, [][]any], []Key, error) { treeMap, err := mapx.NewTreeMap[Key, [][]any](compareKey) if err != nil { return nil, nil, err @@ -125,7 +126,7 @@ func (a *AggregatorMerger) getCols(rowsList []*sql.Rows) (*mapx.TreeMap[Key, [][ } type AggregatorRows struct { - rowsList []*sql.Rows + rowsList []rows.Rows aggregators []aggregator.Aggregator groupColumns []merger.ColumnInfo dataMap *mapx.TreeMap[Key, [][]any] @@ -138,6 +139,14 @@ type AggregatorRows struct { cols []string } +func (a *AggregatorRows) ColumnTypes() ([]*sql.ColumnType, error) { + return a.rowsList[0].ColumnTypes() +} + +func (*AggregatorRows) NextResultSet() bool { + return false +} + // Next 返回列的顺序先分组信息然后是聚合函数信息 func (a *AggregatorRows) Next() bool { a.mu.Lock() @@ -184,7 +193,7 @@ func (a *AggregatorRows) Scan(dest ...any) error { return errs.ErrMergerScanNotNext } for i := 0; i < len(dest); i++ { - err := utils.ConvertAssign(dest[i], a.curData[i]) + err := rows.ConvertAssign(dest[i], a.curData[i]) if err != nil { return err } diff --git a/internal/merger/groupby_merger/aggregator_merger_test.go b/internal/merger/groupby_merger/aggregator_merger_test.go index 1f63bd32..deb1d6aa 100644 --- a/internal/merger/groupby_merger/aggregator_merger_test.go +++ b/internal/merger/groupby_merger/aggregator_merger_test.go @@ -20,6 +20,8 @@ import ( "errors" "testing" + "github.com/ecodeclub/eorm/internal/rows" + "github.com/ecodeclub/eorm/internal/merger" "github.com/DATA-DOG/go-sqlmock" @@ -87,7 +89,7 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { testcases := []struct { name string aggregators []aggregator.Aggregator - rowsList []*sql.Rows + rowsList []rows.Rows GroupByColumns []merger.ColumnInfo wantErr error ctx func() (context.Context, context.CancelFunc) @@ -101,14 +103,14 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { merger.NewColumnInfo(0, "county"), merger.NewColumnInfo(1, "gender"), }, - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { query := "SELECT `county`,`gender`,SUM(`id`) FROM `t1` GROUP BY `country`,`gender`" cols := []string{"county", "gender", "SUM(id)"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10).AddRow("hangzhou", "female", 20).AddRow("shanghai", "female", 30)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40).AddRow("shanghai", "female", 50).AddRow("hangzhou", "female", 60)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70).AddRow("shanghai", "female", 80)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -130,14 +132,14 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { GroupByColumns: []merger.ColumnInfo{ merger.NewColumnInfo(0, "user_name"), }, - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { query := "SELECT `user_name`,SUM(`id`) FROM `t1` GROUP BY `user_name`" cols := []string{"user_name", "SUM(id)"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -159,8 +161,8 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { GroupByColumns: []merger.ColumnInfo{ merger.NewColumnInfo(0, "user_name"), }, - rowsList: func() []*sql.Rows { - return []*sql.Rows{} + rowsList: func() []rows.Rows { + return []rows.Rows{} }(), ctx: func() (context.Context, context.CancelFunc) { ctx, cancel := context.WithCancel(context.Background()) @@ -176,8 +178,8 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { GroupByColumns: []merger.ColumnInfo{ merger.NewColumnInfo(0, "user_name"), }, - rowsList: func() []*sql.Rows { - return []*sql.Rows{nil} + rowsList: func() []rows.Rows { + return []rows.Rows{nil} }(), ctx: func() (context.Context, context.CancelFunc) { ctx, cancel := context.WithCancel(context.Background()) @@ -193,14 +195,14 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { GroupByColumns: []merger.ColumnInfo{ merger.NewColumnInfo(0, "user_name"), }, - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { query := "SELECT `user_name`,SUM(`id`) FROM `t1` GROUP BY `user_name`" cols := []string{"user_name", "SUM(id)"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20).RowError(1, nextMockErr)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -234,7 +236,7 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { testcases := []struct { name string aggregators []aggregator.Aggregator - rowsList []*sql.Rows + rowsList []rows.Rows wantVal [][]any gotVal [][]any GroupByColumns []merger.ColumnInfo @@ -248,14 +250,14 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { GroupByColumns: []merger.ColumnInfo{ merger.NewColumnInfo(0, "user_name"), }, - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { query := "SELECT `user_name`,COUNT(`id`) FROM `t1` GROUP BY `user_name`" cols := []string{"user_name", "SUM(id)"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -282,14 +284,14 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { GroupByColumns: []merger.ColumnInfo{ merger.NewColumnInfo(0, "user_name"), }, - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { query := "SELECT `user_name`,COUNT(`id`) FROM `t1` GROUP BY `user_name`" cols := []string{"user_name", "SUM(id)"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("xm", 20)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("xx", 20)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -321,14 +323,14 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { merger.NewColumnInfo(0, "county"), merger.NewColumnInfo(1, "gender"), }, - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { query := "SELECT `county`,`gender`,SUM(`id`) FROM `t1` GROUP BY `country`,`gender`" cols := []string{"county", "gender", "SUM(id)"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10).AddRow("hangzhou", "female", 20).AddRow("shanghai", "female", 30)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40).AddRow("shanghai", "female", 50).AddRow("hangzhou", "female", 60)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70).AddRow("shanghai", "female", 80)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -376,14 +378,14 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { merger.NewColumnInfo(1, "gender"), }, - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { query := "SELECT `county`,`gender`,SUM(`id`),SUM(`age`),COUNT(`age`) FROM `t1` GROUP BY `country`,`gender`" cols := []string{"county", "gender", "SUM(id)", "SUM(age)", "COUNT(age)"} ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10, 100, 2).AddRow("hangzhou", "female", 20, 120, 3).AddRow("shanghai", "female", 30, 90, 3)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40, 120, 5).AddRow("shanghai", "female", 50, 120, 4).AddRow("hangzhou", "female", 60, 150, 3)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70, 100, 5).AddRow("shanghai", "female", 80, 150, 5)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -457,7 +459,7 @@ func (ms *MergerSuite) TestAggregatorRows_ScanAndErr() { ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 10).AddRow(5, 20)) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) - rowsList := []*sql.Rows{r} + rowsList := []rows.Rows{r} merger := NewAggregatorMerger([]aggregator.Aggregator{aggregator.NewSum(merger.NewColumnInfo(1, "SUM(id)"))}, []merger.ColumnInfo{merger.NewColumnInfo(0, "userid")}) rows, err := merger.Merge(context.Background(), rowsList) require.NoError(t, err) @@ -472,7 +474,7 @@ func (ms *MergerSuite) TestAggregatorRows_ScanAndErr() { ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 10).AddRow(5, 20)) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) - rowsList := []*sql.Rows{r} + rowsList := []rows.Rows{r} merger := NewAggregatorMerger([]aggregator.Aggregator{&mockAggregate{}}, []merger.ColumnInfo{merger.NewColumnInfo(0, "userid")}) rows, err := merger.Merge(context.Background(), rowsList) require.NoError(t, err) @@ -488,14 +490,14 @@ func (ms *MergerSuite) TestAggregatorRows_ScanAndErr() { func (ms *MergerSuite) TestAggregatorRows_NextAndErr() { testcases := []struct { name string - rowsList func() []*sql.Rows + rowsList func() []rows.Rows wantErr error aggregators []aggregator.Aggregator GroupByColumns []merger.ColumnInfo }{ { name: "有一个aggregator返回error", - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { cols := []string{"username", "COUNT(id)"} query := "SELECT username,COUNT(`id`) FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 1)) @@ -503,7 +505,7 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndErr() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("wu", 4)) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("ming", 5)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -556,7 +558,7 @@ func (ms *MergerSuite) TestAggregatorRows_Columns() { } merger := NewAggregatorMerger(aggregators, groupbyColumns) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -591,6 +593,10 @@ func (m *mockAggregate) Aggregate(cols [][]any) (any, error) { return nil, aggregatorErr } -func (m *mockAggregate) ColumnName() string { +func (*mockAggregate) ColumnName() string { return "mockAggregate" } + +func TestAggregatorRows_NextResultSet(t *testing.T) { + assert.False(t, (&AggregatorRows{}).NextResultSet()) +} diff --git a/internal/merger/pagedmerger/merger.go b/internal/merger/pagedmerger/merger.go index ccb296f4..4f6c3387 100644 --- a/internal/merger/pagedmerger/merger.go +++ b/internal/merger/pagedmerger/merger.go @@ -19,6 +19,8 @@ import ( "database/sql" "sync" + "github.com/ecodeclub/eorm/internal/rows" + "github.com/ecodeclub/eorm/internal/merger" "github.com/ecodeclub/eorm/internal/merger/internal/errs" ) @@ -41,7 +43,7 @@ func NewMerger(m merger.Merger, offset int, limit int) (*Merger, error) { }, nil } -func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, error) { +func (m *Merger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) { rs, err := m.m.Merge(ctx, results) if err != nil { return nil, err @@ -58,7 +60,7 @@ func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, e } // nextOffset 会把游标挪到 offset 所指定的位置。 -func (m *Merger) nextOffset(ctx context.Context, rows merger.Rows) error { +func (m *Merger) nextOffset(ctx context.Context, rows rows.Rows) error { offset := m.offset for i := 0; i < offset; i++ { if ctx.Err() != nil { @@ -73,7 +75,7 @@ func (m *Merger) nextOffset(ctx context.Context, rows merger.Rows) error { } type Rows struct { - rows merger.Rows + rows rows.Rows limit int cnt int lastErr error @@ -81,6 +83,10 @@ type Rows struct { mu *sync.RWMutex } +func (*Rows) NextResultSet() bool { + return false +} + func (r *Rows) Next() bool { r.mu.Lock() if r.closed { @@ -137,6 +143,9 @@ func (r *Rows) Close() error { return r.rows.Close() } +func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) { + return r.rows.ColumnTypes() +} func (r *Rows) Columns() ([]string, error) { return r.rows.Columns() } diff --git a/internal/merger/pagedmerger/merger_test.go b/internal/merger/pagedmerger/merger_test.go index 52d25d57..9499cb5d 100644 --- a/internal/merger/pagedmerger/merger_test.go +++ b/internal/merger/pagedmerger/merger_test.go @@ -21,6 +21,8 @@ import ( "fmt" "testing" + "github.com/ecodeclub/eorm/internal/rows" + "github.com/DATA-DOG/go-sqlmock" "github.com/ecodeclub/eorm/internal/merger" "github.com/ecodeclub/eorm/internal/merger/internal/errs" @@ -132,7 +134,7 @@ func (ms *MergerSuite) TestMerger_Merge() { testcases := []struct { name string getMerger func() (merger.Merger, error) - GetRowsList func() []*sql.Rows + GetRowsList func() []rows.Rows wantErr error ctx func() (context.Context, context.CancelFunc) limit int @@ -143,8 +145,8 @@ func (ms *MergerSuite) TestMerger_Merge() { getMerger: func() (merger.Merger, error) { return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) }, - GetRowsList: func() []*sql.Rows { - return []*sql.Rows{} + GetRowsList: func() []rows.Rows { + return []rows.Rows{} }, wantErr: errs.ErrMergerEmptyRows, ctx: func() (context.Context, context.CancelFunc) { @@ -158,14 +160,14 @@ func (ms *MergerSuite) TestMerger_Merge() { getMerger: func() (merger.Merger, error) { return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) }, - GetRowsList: func() []*sql.Rows { + GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn").RowError(1, offsetMockErr)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -185,14 +187,14 @@ func (ms *MergerSuite) TestMerger_Merge() { getMerger: func() (merger.Merger, error) { return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) }, - GetRowsList: func() []*sql.Rows { + GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -211,14 +213,14 @@ func (ms *MergerSuite) TestMerger_Merge() { getMerger: func() (merger.Merger, error) { return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) }, - GetRowsList: func() []*sql.Rows { + GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -258,7 +260,7 @@ func (ms *MergerSuite) TestMerger_NextAndScan() { testcases := []struct { name string getMerger func() (merger.Merger, error) - GetRowsList func() []*sql.Rows + GetRowsList func() []rows.Rows wantVal []TestModel limit int offset int @@ -268,14 +270,14 @@ func (ms *MergerSuite) TestMerger_NextAndScan() { getMerger: func() (merger.Merger, error) { return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) }, - GetRowsList: func() []*sql.Rows { + GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -318,14 +320,14 @@ func (ms *MergerSuite) TestMerger_NextAndScan() { getMerger: func() (merger.Merger, error) { return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) }, - GetRowsList: func() []*sql.Rows { + GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -353,14 +355,14 @@ func (ms *MergerSuite) TestMerger_NextAndScan() { getMerger: func() (merger.Merger, error) { return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) }, - GetRowsList: func() []*sql.Rows { + GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -377,14 +379,14 @@ func (ms *MergerSuite) TestMerger_NextAndScan() { getMerger: func() (merger.Merger, error) { return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) }, - GetRowsList: func() []*sql.Rows { + GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -454,7 +456,7 @@ func (ms *MergerSuite) TestRows_NextAndErr() { testcases := []struct { name string getMerger func() (merger.Merger, error) - GetRowsList func() []*sql.Rows + GetRowsList func() []rows.Rows wantErr error limit int offset int @@ -464,14 +466,14 @@ func (ms *MergerSuite) TestRows_NextAndErr() { getMerger: func() (merger.Merger, error) { return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) }, - GetRowsList: func() []*sql.Rows { + GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn").RowError(1, limitMockErr)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -507,7 +509,7 @@ func (ms *MergerSuite) TestRows_ScanAndErr() { ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(5)) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) - rowsList := []*sql.Rows{r} + rowsList := []rows.Rows{r} merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) require.NoError(t, err) limitMerger, err := NewMerger(merger, 0, 1) @@ -524,7 +526,7 @@ func (ms *MergerSuite) TestRows_ScanAndErr() { ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(2).RowError(1, limitMockErr)) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) - rowsList := []*sql.Rows{r} + rowsList := []rows.Rows{r} merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) require.NoError(t, err) limitMerger, err := NewMerger(merger, 0, 1) @@ -550,7 +552,7 @@ func (ms *MergerSuite) TestRows_Close() { limitMerger, err := NewMerger(merger, 1, 6) require.NoError(ms.T(), err) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -601,7 +603,7 @@ func (ms *MergerSuite) TestRows_Columns() { limitMerger, err := NewMerger(merger, 0, 10) require.NoError(ms.T(), err) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -623,3 +625,7 @@ type TestModel struct { Name string Address string } + +func TestRows_NextResultSet(t *testing.T) { + assert.False(t, (&Rows{}).NextResultSet()) +} diff --git a/internal/merger/sortmerger/merger.go b/internal/merger/sortmerger/merger.go index 003ef3b9..8f8b04d7 100644 --- a/internal/merger/sortmerger/merger.go +++ b/internal/merger/sortmerger/merger.go @@ -21,9 +21,7 @@ import ( "reflect" "sync" - "github.com/ecodeclub/eorm/internal/merger" - - "github.com/ecodeclub/eorm/internal/merger/utils" + "github.com/ecodeclub/eorm/internal/rows" "go.uber.org/multierr" @@ -111,7 +109,7 @@ func newSortColumns(sortCols ...SortColumn) (sortColumns, error) { return scs, nil } -func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, error) { +func (m *Merger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) { // 检测results是否符合条件 if ctx.Err() != nil { return nil, ctx.Err() @@ -130,7 +128,7 @@ func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, e return m.initRows(results) } -func (m *Merger) initRows(results []*sql.Rows) (*Rows, error) { +func (m *Merger) initRows(results []rows.Rows) (*Rows, error) { rs := &Rows{ rowsList: results, sortColumns: m.sortColumns, @@ -152,7 +150,7 @@ func (m *Merger) initRows(results []*sql.Rows) (*Rows, error) { return rs, nil } -func (m *Merger) checkColumns(rows *sql.Rows) error { +func (m *Merger) checkColumns(rows rows.Rows) error { if rows == nil { return errs.ErrMergerRowsIsNull } @@ -183,7 +181,7 @@ func (m *Merger) checkColumns(rows *sql.Rows) error { return nil } -func newNode(row *sql.Rows, sortCols sortColumns, index int) (*node, error) { +func newNode(row rows.Rows, sortCols sortColumns, index int) (*node, error) { colsInfo, err := row.ColumnTypes() if err != nil { return nil, err @@ -221,7 +219,7 @@ func newNode(row *sql.Rows, sortCols sortColumns, index int) (*node, error) { } type Rows struct { - rowsList []*sql.Rows + rowsList []rows.Rows sortColumns sortColumns hp *Heap cur *node @@ -231,6 +229,14 @@ type Rows struct { columns []string } +func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) { + return r.rowsList[0].ColumnTypes() +} + +func (*Rows) NextResultSet() bool { + return false +} + func (r *Rows) Next() bool { r.mu.Lock() if r.closed { @@ -255,7 +261,7 @@ func (r *Rows) Next() bool { return true } -func (r *Rows) nextRows(row *sql.Rows, index int) error { +func (r *Rows) nextRows(row rows.Rows, index int) error { if row.Next() { n, err := newNode(row, r.sortColumns, index) if err != nil { @@ -282,7 +288,7 @@ func (r *Rows) Scan(dest ...any) error { } var err error for i := 0; i < len(dest); i++ { - err = utils.ConvertAssign(dest[i], r.cur.columns[i]) + err = rows.ConvertAssign(dest[i], r.cur.columns[i]) if err != nil { return err } diff --git a/internal/merger/sortmerger/merger_test.go b/internal/merger/sortmerger/merger_test.go index 399f1039..8ea3be1d 100644 --- a/internal/merger/sortmerger/merger_test.go +++ b/internal/merger/sortmerger/merger_test.go @@ -21,9 +21,12 @@ import ( "fmt" "testing" + _ "github.com/mattn/go-sqlite3" + + "github.com/ecodeclub/eorm/internal/rows" + "github.com/DATA-DOG/go-sqlmock" "github.com/ecodeclub/eorm/internal/merger/internal/errs" - _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -131,7 +134,7 @@ func (ms *MergerSuite) TestMerger_Merge() { merger func() (*Merger, error) ctx func() (context.Context, context.CancelFunc) wantErr error - sqlRows func() []*sql.Rows + sqlRows func() []rows.Rows }{ { name: "sqlRows字段不同", @@ -141,12 +144,12 @@ func (ms *MergerSuite) TestMerger_Merge() { ctx: func() (context.Context, context.CancelFunc) { return context.WithCancel(context.Background()) }, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "email"}).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -164,12 +167,12 @@ func (ms *MergerSuite) TestMerger_Merge() { ctx: func() (context.Context, context.CancelFunc) { return context.WithCancel(context.Background()) }, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(3, "alex").AddRow(4, "x")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -189,10 +192,10 @@ func (ms *MergerSuite) TestMerger_Merge() { return ctx, cancel }, wantErr: context.DeadlineExceeded, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - res := make([]*sql.Rows, 0, 1) + res := make([]rows.Rows, 0, 1) ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) @@ -207,8 +210,8 @@ func (ms *MergerSuite) TestMerger_Merge() { merger: func() (*Merger, error) { return NewMerger(NewSortColumn("id", ASC)) }, - sqlRows: func() []*sql.Rows { - return []*sql.Rows{} + sqlRows: func() []rows.Rows { + return []rows.Rows{} }, wantErr: errs.ErrMergerEmptyRows, }, @@ -220,8 +223,8 @@ func (ms *MergerSuite) TestMerger_Merge() { ctx: func() (context.Context, context.CancelFunc) { return context.WithCancel(context.Background()) }, - sqlRows: func() []*sql.Rows { - return []*sql.Rows{nil} + sqlRows: func() []rows.Rows { + return []rows.Rows{nil} }, wantErr: errs.ErrMergerRowsIsNull, }, @@ -230,10 +233,10 @@ func (ms *MergerSuite) TestMerger_Merge() { merger: func() (*Merger, error) { return NewMerger(NewSortColumn("age", ASC)) }, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - res := make([]*sql.Rows, 0, 1) + res := make([]rows.Rows, 0, 1) ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) @@ -249,10 +252,10 @@ func (ms *MergerSuite) TestMerger_Merge() { merger: func() (*Merger, error) { return NewMerger(NewSortColumn("id", ASC), NewSortColumn("age", ASC)) }, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - res := make([]*sql.Rows, 0, 1) + res := make([]rows.Rows, 0, 1) ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) @@ -268,10 +271,10 @@ func (ms *MergerSuite) TestMerger_Merge() { merger: func() (*Merger, error) { return NewMerger(NewSortColumn("age", ASC)) }, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id", "name", "address"} - res := make([]*sql.Rows, 0, 1) + res := make([]rows.Rows, 0, 1) ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) @@ -287,10 +290,10 @@ func (ms *MergerSuite) TestMerger_Merge() { merger: func() (*Merger, error) { return NewMerger(NewSortColumn("id", ASC), NewSortColumn("age", ASC), NewSortColumn("name", ASC)) }, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id", "name", "address"} - res := make([]*sql.Rows, 0, 1) + res := make([]rows.Rows, 0, 1) ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) @@ -306,10 +309,10 @@ func (ms *MergerSuite) TestMerger_Merge() { merger: func() (*Merger, error) { return NewMerger(NewSortColumn("id", ASC), NewSortColumn("name", ASC), NewSortColumn("age", ASC)) }, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id", "name", "address"} - res := make([]*sql.Rows, 0, 1) + res := make([]rows.Rows, 0, 1) ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) @@ -325,10 +328,10 @@ func (ms *MergerSuite) TestMerger_Merge() { merger: func() (*Merger, error) { return NewMerger(NewSortColumn("id", ASC)) }, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - res := make([]*sql.Rows, 0, 1) + res := make([]rows.Rows, 0, 1) ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) @@ -343,10 +346,10 @@ func (ms *MergerSuite) TestMerger_Merge() { merger: func() (*Merger, error) { return NewMerger(NewSortColumn("id", ASC), NewSortColumn("age", ASC)) }, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id", "age"} - res := make([]*sql.Rows, 0, 1) + res := make([]rows.Rows, 0, 1) ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 18)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) @@ -361,10 +364,10 @@ func (ms *MergerSuite) TestMerger_Merge() { merger: func() (*Merger, error) { return NewMerger(NewSortColumn("id", ASC), NewSortColumn("name", ASC)) }, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id", "name", "address"} - res := make([]*sql.Rows, 0, 1) + res := make([]rows.Rows, 0, 1) ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) @@ -379,10 +382,10 @@ func (ms *MergerSuite) TestMerger_Merge() { merger: func() (*Merger, error) { return NewMerger(NewSortColumn("id", ASC)) }, - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id", "name", "address"} - res := make([]*sql.Rows, 0, 1) + res := make([]rows.Rows, 0, 1) ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh").RowError(0, nextMockErr)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) @@ -412,24 +415,23 @@ func (ms *MergerSuite) TestMerger_Merge() { } func (ms *MergerSuite) TestRows_NextAndScan() { - testCases := []struct { name string - sqlRows func() []*sql.Rows + sqlRows func() []rows.Rows wantVal []TestModel sortColumns []SortColumn wantErr error }{ { name: "完全交叉读,sqlRows返回行数相同", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -475,14 +477,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "完全交叉读,sqlRows返回行数部分不同", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn").AddRow(6, "x", "cn").AddRow(1, "x", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(8, "alex", "cn").AddRow(4, "bruce", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(9, "a", "cn").AddRow(5, "abex", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -534,7 +536,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { { // 包含一个sqlRows返回的行数为0,在前面 name: "完全交叉读,sqlRows返回行数完全不同", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "c", "cn").AddRow(2, "bruce", "cn").AddRow(2, "zwl", "cn")) @@ -542,7 +544,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "c", "cn").AddRow(3, "b", "cn").AddRow(5, "c", "cn").AddRow(7, "c", "cn")) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB04, ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -604,14 +606,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "部分交叉读,sqlRows返回行数相同", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(5, "bruce", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4, "x", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -657,14 +659,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "部分交叉读,sqlRows返回行数部分相同", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn").AddRow(5, "bruce", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn").AddRow(8, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -716,7 +718,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { { // 包含一个sqlRows返回的行数为0,在中间 name: "部分交叉读,sqlRows返回行数完全不同", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn").AddRow(5, "bruce", "cn")) @@ -724,7 +726,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn")) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB04, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -770,14 +772,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "顺序读,sqlRows返回行数相同", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "bruce", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -823,14 +825,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "顺序读,sqlRows返回行数部分不同", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "bruce", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -873,7 +875,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { { // 包含一个sqlRows返回的行数为0,在后面 name: "顺序读,sqlRows返回行数完全不同", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn")) @@ -881,7 +883,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4, "x", "cn").AddRow(5, "bruce", "cn").AddRow(7, "b", "cn")) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -928,7 +930,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { { name: "所有sqlRows返回的行数均为空", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) @@ -936,7 +938,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -952,14 +954,14 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }, { name: "排序列返回的顺序和数据库里的字段顺序不一致", - sqlRows: func() []*sql.Rows { + sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "hz").AddRow(3, "b", "hz").AddRow(2, "b", "cs")) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "a", "cs").AddRow(1, "a", "cs").AddRow(3, "e", "cn")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "d", "hm").AddRow(5, "k", "xx").AddRow(4, "k", "xz")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -1050,7 +1052,7 @@ func (ms *MergerSuite) TestRows_Columns() { merger, err := NewMerger(NewSortColumn("id", DESC)) require.NoError(ms.T(), err) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -1086,7 +1088,7 @@ func (ms *MergerSuite) TestRows_Close() { merger, err := NewMerger(NewSortColumn("id", DESC)) require.NoError(ms.T(), err) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -1130,13 +1132,13 @@ func (ms *MergerSuite) TestRows_Close() { func (ms *MergerSuite) TestRows_NextAndErr() { testcases := []struct { name string - rowsList func() []*sql.Rows + rowsList func() []rows.Rows wantErr error sortColumns []SortColumn }{ { name: "sqlRows列表中有一个返回error", - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { cols := []string{"id"} query := "SELECT * FROM `t1`" ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) @@ -1144,7 +1146,7 @@ func (ms *MergerSuite) TestRows_NextAndErr() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").RowError(1, nextMockErr)) ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("5")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { row, err := db.QueryContext(context.Background(), query) require.NoError(ms.T(), err) @@ -1179,7 +1181,7 @@ func (ms *MergerSuite) TestRows_ScanErr() { ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) - rowsList := []*sql.Rows{r} + rowsList := []rows.Rows{r} merger, err := NewMerger(NewSortColumn("id", DESC)) require.NoError(t, err) rows, err := merger.Merge(context.Background(), rowsList) @@ -1194,7 +1196,7 @@ func (ms *MergerSuite) TestRows_ScanErr() { ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn").RowError(1, nextMockErr)) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) - rowsList := []*sql.Rows{r} + rowsList := []rows.Rows{r} merger, err := NewMerger(NewSortColumn("id", DESC)) require.NoError(t, err) rows, err := merger.Merge(context.Background(), rowsList) @@ -1267,7 +1269,7 @@ func (ms *NullableMergerSuite) TearDownSuite() { func (ms *NullableMergerSuite) TestRows_Nullable() { testcases := []struct { name string - rowsList func() []*sql.Rows + rowsList func() []rows.Rows sortColumns []SortColumn wantErr error afterFunc func() @@ -1275,7 +1277,7 @@ func (ms *NullableMergerSuite) TestRows_Nullable() { }{ { name: "多个nullable类型排序 age asc,name desc", - rowsList: func() []*sql.Rows { + rowsList: func() []rows.Rows { db1InsertSql := []string{ "insert into t1 (id, name) values (1, 'zwl')", "insert into t1 (id, age, name) values (2, 10, 'zwl')", @@ -1304,7 +1306,7 @@ func (ms *NullableMergerSuite) TestRows_Nullable() { require.NoError(ms.T(), err) } dbs := []*sql.DB{ms.db01, ms.db02, ms.db03} - rowsList := make([]*sql.Rows, 0, len(dbs)) + rowsList := make([]rows.Rows, 0, len(dbs)) query := "SELECT `id`, `age`,`name` FROM `t1` order by age asc,name desc" for _, db := range dbs { rows, err := db.QueryContext(context.Background(), query) @@ -1401,3 +1403,7 @@ type Nullable struct { Age sql.NullInt64 Name sql.NullString } + +func TestRows_NextResultSet(t *testing.T) { + assert.False(t, (&Rows{}).NextResultSet()) +} diff --git a/internal/merger/type.go b/internal/merger/type.go index b0dd1a09..8cd5a7bd 100644 --- a/internal/merger/type.go +++ b/internal/merger/type.go @@ -16,7 +16,6 @@ package merger import ( "context" - "database/sql" "github.com/ecodeclub/eorm/internal/rows" ) @@ -24,7 +23,7 @@ import ( // Merger 将sql.Rows列表里的元素合并,返回一个类似sql.Rows的迭代器 // Merger sql.Rows列表中每个sql.Rows仅支持单个结果集且每个sql.Rows中列集必须完全相同。 type Merger interface { - Merge(ctx context.Context, results []*sql.Rows) (Rows, error) + Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) } type ColumnInfo struct { @@ -38,5 +37,3 @@ func NewColumnInfo(index int, name string) ColumnInfo { Name: name, } } - -type Rows = rows.Rows diff --git a/internal/merger/utils/convert_assign.go b/internal/rows/convert_assign.go similarity index 98% rename from internal/merger/utils/convert_assign.go rename to internal/rows/convert_assign.go index 9bb28287..31ea326d 100644 --- a/internal/merger/utils/convert_assign.go +++ b/internal/rows/convert_assign.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package utils +package rows import ( "database/sql/driver" diff --git a/internal/merger/utils/convert_assign_test.go b/internal/rows/convert_assign_test.go similarity index 99% rename from internal/merger/utils/convert_assign_test.go rename to internal/rows/convert_assign_test.go index 10599679..781c1fa1 100644 --- a/internal/merger/utils/convert_assign_test.go +++ b/internal/rows/convert_assign_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package utils +package rows import ( "database/sql" diff --git a/internal/rows/types.go b/internal/rows/types.go index 85fabd9d..8d1e2bde 100644 --- a/internal/rows/types.go +++ b/internal/rows/types.go @@ -14,15 +14,9 @@ package rows -import "database/sql" - -var _ Rows = &sql.Rows{} +import ( + "github.com/ecodeclub/ekit/sqlx" +) // Rows 各方法用法及语义尽可能与sql.Rows相同 -type Rows interface { - Next() bool - Scan(dest ...any) error - Close() error - Columns() ([]string, error) - Err() error -} +type Rows = sqlx.Rows diff --git a/sharding_select.go b/sharding_select.go index b305e7e8..e66d345c 100644 --- a/sharding_select.go +++ b/sharding_select.go @@ -16,9 +16,10 @@ package eorm import ( "context" - "database/sql" "sync" + "github.com/ecodeclub/eorm/internal/rows" + "github.com/ecodeclub/eorm/internal/merger/batchmerger" "github.com/ecodeclub/eorm/internal/sharding" @@ -304,7 +305,7 @@ func (s *ShardingSelector[T]) GetMulti(ctx context.Context) ([]*T, error) { return nil, err } - var rowsSlice []*sql.Rows + var rowsSlice []rows.Rows var eg errgroup.Group for _, query := range qs { q := query @@ -312,10 +313,10 @@ func (s *ShardingSelector[T]) GetMulti(ctx context.Context) ([]*T, error) { //s.lock.Lock() //defer s.lock.Unlock() // TODO 利用 ctx 传递 DB name - rows, err := s.db.queryContext(ctx, q) + rs, err := s.db.queryContext(ctx, q) if err == nil { s.lock.Lock() - rowsSlice = append(rowsSlice, rows) + rowsSlice = append(rowsSlice, rs) s.lock.Unlock() } return err @@ -375,7 +376,7 @@ func (s *ShardingSelector[T]) GetMultiV2(ctx context.Context) ([]*T, error) { } } - var rowsSlice []*sql.Rows + var rowsSlice []rows.Rows var eg errgroup.Group for _, query := range sdQs { q := query @@ -384,10 +385,10 @@ func (s *ShardingSelector[T]) GetMultiV2(ctx context.Context) ([]*T, error) { //s.lock.Lock() //defer s.lock.Unlock() // TODO 利用 ctx 传递 DB name - rows, err := s.db.queryContext(ctx, q) + rs, err := s.db.queryContext(ctx, q) if err == nil { s.lock.Lock() - rowsSlice = append(rowsSlice, rows) + rowsSlice = append(rowsSlice, rs) s.lock.Unlock() } return err @@ -399,15 +400,15 @@ func (s *ShardingSelector[T]) GetMultiV2(ctx context.Context) ([]*T, error) { } mgr := batchmerger.NewMerger() - rows, err := mgr.Merge(ctx, rowsSlice) + rs, err := mgr.Merge(ctx, rowsSlice) if err != nil { return nil, err } var res []*T - for rows.Next() { + for rs.Next() { tp := new(T) val := s.valCreator.NewPrimitiveValue(tp, s.meta) - if err = val.SetColumns(rows); err != nil { + if err = val.SetColumns(rs); err != nil { return nil, err } res = append(res, tp)