diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 4607c1f8..bd620332 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -34,6 +34,7 @@ - [eorm: ShardingInserter 修改为表维度执行](https://github.com/ecodeclub/eorm/pull/211) - [merger: 使用 sqlx.Scanner 来读取数据](https://github.com/ecodeclub/eorm/pull/216) - [rows, merger: 使用 sqlx.Rows 作为接口,并重构 merger 包 ](https://github.com/ecodeclub/eorm/pull/217) +- [rows: 同库事务语句合并执行,提前读取所有数据](https://github.com/ecodeclub/eorm/pull/219) ## v0.0.1: diff --git a/db.go b/db.go index 1439a80b..1cabd0f2 100644 --- a/db.go +++ b/db.go @@ -39,7 +39,7 @@ type DBOption func(db *DB) // DB represents a database type DB struct { - core + baseSession ds datasource.DataSource } @@ -62,14 +62,6 @@ func UseReflection() DBOption { } } -func (db *DB) queryContext(ctx context.Context, q datasource.Query) (*sql.Rows, error) { - return db.ds.Query(ctx, q) -} - -func (db *DB) execContext(ctx context.Context, q datasource.Query) (sql.Result, error) { - return db.ds.Exec(ctx, q) -} - // Open 创建一个 ORM 实例 // 注意该实例是一个无状态的对象,你应该尽可能复用它 func Open(driver string, dsn string, opts ...DBOption) (*DB, error) { @@ -86,12 +78,15 @@ func OpenDS(driver string, ds datasource.DataSource, opts ...DBOption) (*DB, err return nil, err } orm := &DB{ - core: core{ - metaRegistry: model.NewMetaRegistry(), - dialect: dl, - // 可以设为默认,因为原本这里也有默认 - valCreator: valuer.PrimitiveCreator{ - Creator: valuer.NewUnsafeValue, + baseSession: baseSession{ + executor: ds, + core: core{ + metaRegistry: model.NewMetaRegistry(), + dialect: dl, + // 可以设为默认,因为原本这里也有默认 + valCreator: valuer.PrimitiveCreator{ + Creator: valuer.NewUnsafeValue, + }, }, }, ds: ds, @@ -111,13 +106,12 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { if err != nil { return nil, err } - return &Tx{tx: tx, core: db.getCore()}, nil + return &Tx{tx: tx, baseSession: baseSession{ + executor: tx, + core: db.core, + }}, nil } func (db *DB) Close() error { return db.ds.Close() } - -func (db *DB) getCore() core { - return db.core -} diff --git a/go.mod b/go.mod index f02222e2..78b91bf5 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,8 @@ go 1.20 require ( github.com/DATA-DOG/go-sqlmock v1.5.0 - github.com/ecodeclub/ekit v0.0.4-0.20230904153403-e76aae064994 + github.com/ecodeclub/ekit v0.0.8-0.20231001021557-856d32ae850b github.com/go-sql-driver/mysql v1.6.0 - github.com/gotomicro/ekit v0.0.0-20230224040531-869798da3c4d github.com/mattn/go-sqlite3 v1.14.15 github.com/stretchr/testify v1.8.1 github.com/valyala/bytebufferpool v1.0.0 diff --git a/go.sum b/go.sum index 7b335064..d9d2533d 100644 --- a/go.sum +++ b/go.sum @@ -3,12 +3,10 @@ github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/ecodeclub/ekit v0.0.4-0.20230904153403-e76aae064994 h1:4Rp8WrJhISj8GDtnueoD22ygPuppajnCVZuEfRjg6w8= -github.com/ecodeclub/ekit v0.0.4-0.20230904153403-e76aae064994/go.mod h1:OqTojKeKFTxeeAAUwNIPKu339SRkX6KAuoK/8A5BCEs= +github.com/ecodeclub/ekit v0.0.8-0.20231001021557-856d32ae850b h1:T1OvEeJJEOhkrhkg55//A5kzX7lgdeX9gDJuVDahSpw= +github.com/ecodeclub/ekit v0.0.8-0.20231001021557-856d32ae850b/go.mod h1:OqTojKeKFTxeeAAUwNIPKu339SRkX6KAuoK/8A5BCEs= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/gotomicro/ekit v0.0.0-20230224040531-869798da3c4d h1:kmDgYRZ06UifBqAfew+cj02juQQ3Ko349NzsDIZ0QPw= -github.com/gotomicro/ekit v0.0.0-20230224040531-869798da3c4d/go.mod h1:ISYxgxcx3SOYGm/Hg9+M+pHVhN5G6W7p91/Pn7x6Hz8= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= diff --git a/internal/datasource/transaction/delay_transaction_test.go b/internal/datasource/transaction/delay_transaction_test.go index 3bdc5c4c..de283702 100644 --- a/internal/datasource/transaction/delay_transaction_test.go +++ b/internal/datasource/transaction/delay_transaction_test.go @@ -80,14 +80,14 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() { mockOrder: func(mock1, mock2 sqlmock.Sqlmock) {}, afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) {}, txFunc: func() (*eorm.Tx, error) { - s.DataSource = shardingsource.NewShardingDataSource(map[string]datasource.DataSource{ + ds := shardingsource.NewShardingDataSource(map[string]datasource.DataSource{ "1.db.cluster.company.com:3306": s.clusterDB, }) r := model.NewMetaRegistry() _, err := r.Register(&test.OrderDetail{}, model.WithTableShardingAlgorithm(s.algorithm)) require.NoError(t, err) - db, err := eorm.OpenDS("mysql", s.DataSource, eorm.DBWithMetaRegistry(r)) + db, err := eorm.OpenDS("mysql", ds, eorm.DBWithMetaRegistry(r)) require.NoError(t, err) return db.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) }, @@ -98,7 +98,7 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() { mockOrder: func(mock1, mock2 sqlmock.Sqlmock) {}, afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) {}, txFunc: func() (*eorm.Tx, error) { - s.DataSource = shardingsource.NewShardingDataSource(map[string]datasource.DataSource{ + ds := shardingsource.NewShardingDataSource(map[string]datasource.DataSource{ "0.db.cluster.company.com:3306": masterslave.NewMasterSlavesDB(s.mockMaster1DB, masterslave.MasterSlavesWithSlaves( newSlaves(t, s.mockSlave1DB, s.mockSlave2DB, s.mockSlave3DB))), }) @@ -106,7 +106,7 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() { _, err := r.Register(&test.OrderDetail{}, model.WithTableShardingAlgorithm(s.algorithm)) require.NoError(t, err) - db, err := eorm.OpenDS("mysql", s.DataSource, eorm.DBWithMetaRegistry(r)) + db, err := eorm.OpenDS("mysql", ds, eorm.DBWithMetaRegistry(r)) require.NoError(t, err) return db.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) }, @@ -123,14 +123,14 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() { "order_detail_db_0": masterslave.NewMasterSlavesDB(s.mockMaster1DB, masterslave.MasterSlavesWithSlaves( newSlaves(t, s.mockSlave1DB, s.mockSlave2DB, s.mockSlave3DB))), }) - s.DataSource = shardingsource.NewShardingDataSource(map[string]datasource.DataSource{ + ds := shardingsource.NewShardingDataSource(map[string]datasource.DataSource{ "0.db.cluster.company.com:3306": clusterDB, }) r := model.NewMetaRegistry() _, err := r.Register(&test.OrderDetail{}, model.WithTableShardingAlgorithm(s.algorithm)) require.NoError(t, err) - db, err := eorm.OpenDS("mysql", s.DataSource, eorm.DBWithMetaRegistry(r)) + db, err := eorm.OpenDS("mysql", ds, eorm.DBWithMetaRegistry(r)) require.NoError(t, err) return db.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{}) }, @@ -483,10 +483,6 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() { rows := s.mockMaster2.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"}) s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_2` WHERE (`order_id`=?) OR (`order_id`=?);SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE (`order_id`=?) OR (`order_id`=?);")). WithArgs(199, 299, 199, 299).WillReturnRows(rows) - - queryVal := s.findTgt(t, values) - var wantOds []*test.OrderDetail - assert.ElementsMatch(t, wantOds, queryVal) }, }, } @@ -496,10 +492,9 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() { tx, err := tc.txFunc() require.NoError(t, err) - // TODO GetMultiV2 待将 table 维度改成 db 维度 querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). Where(eorm.C("OrderId").NEQ(123)). - GetMultiV2(masterslave.UseMaster(context.Background())) + GetMulti(masterslave.UseMaster(context.Background())) assert.Equal(t, tc.wantErr, err) if err != nil { return diff --git a/internal/datasource/transaction/transaction_suite_test.go b/internal/datasource/transaction/transaction_suite_test.go index a7d768db..5496b227 100644 --- a/internal/datasource/transaction/transaction_suite_test.go +++ b/internal/datasource/transaction/transaction_suite_test.go @@ -161,9 +161,8 @@ func (s *ShardingTransactionSuite) findTgt(t *testing.T, values []*test.OrderDet od = values[i] pre = pre.Or(eorm.C(s.shardingKey).EQ(od.OrderId)) } - // TODO GetMultiV2 待将 table 维度改成 db 维度 querySet, err := eorm.NewShardingSelector[test.OrderDetail](s.shardingDB). - Where(pre).GetMultiV2(masterslave.UseMaster(context.Background())) + Where(pre).GetMulti(masterslave.UseMaster(context.Background())) require.NoError(t, err) return querySet } diff --git a/internal/integration/sharding_delay_transaction_test.go b/internal/integration/sharding_delay_transaction_test.go index 11b28253..7ac69adc 100644 --- a/internal/integration/sharding_delay_transaction_test.go +++ b/internal/integration/sharding_delay_transaction_test.go @@ -66,13 +66,13 @@ func (s *ShardingDelayTxTestSuite) TestDoubleShardingSelect() { defer tx.Commit() querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). Where(eorm.C("OrderId").NEQ(123)). - GetMultiV2(masterslave.UseMaster(context.Background())) + GetMulti(masterslave.UseMaster(context.Background())) require.NoError(t, err) assert.ElementsMatch(t, tc.querySet, querySet) querySet, err = eorm.NewShardingSelector[test.OrderDetail](tx). Where(eorm.C("OrderId").NEQ(123)). - GetMultiV2(masterslave.UseMaster(context.Background())) + GetMulti(masterslave.UseMaster(context.Background())) require.NoError(t, err) assert.ElementsMatch(t, tc.querySet, querySet) }) @@ -228,7 +228,7 @@ func (s *ShardingDelayTxTestSuite) TestShardingSelectUpdateInsert_Commit_Or_Roll tx := tc.txFunc(t) querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). Where(eorm.C("OrderId").NEQ(123)). - GetMultiV2(masterslave.UseMaster(context.Background())) + GetMulti(masterslave.UseMaster(context.Background())) require.NoError(t, err) assert.ElementsMatch(t, tc.querySet, querySet) diff --git a/internal/integration/sharding_single_transaction_test.go b/internal/integration/sharding_single_transaction_test.go index 127af56a..0832dba5 100644 --- a/internal/integration/sharding_single_transaction_test.go +++ b/internal/integration/sharding_single_transaction_test.go @@ -61,13 +61,13 @@ func (s *ShardingSingleTxTestSuite) TestDoubleShardingSelect() { defer tx.Commit() querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). Where(eorm.C("OrderId").EQ(123)). - GetMultiV2(masterslave.UseMaster(context.Background())) + GetMulti(masterslave.UseMaster(context.Background())) require.NoError(t, err) assert.ElementsMatch(t, tc.querySet, querySet) querySet, err = eorm.NewShardingSelector[test.OrderDetail](tx). Where(eorm.C("OrderId").EQ(123)). - GetMultiV2(masterslave.UseMaster(context.Background())) + GetMulti(masterslave.UseMaster(context.Background())) require.NoError(t, err) assert.ElementsMatch(t, tc.querySet, querySet) }) @@ -137,7 +137,7 @@ func (s *ShardingSingleTxTestSuite) TestShardingSelectInsert_Commit_Or_Rollback( tx := tc.txFunc(t) querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). Where(eorm.C("OrderId").EQ(123)). - GetMultiV2(masterslave.UseMaster(context.Background())) + GetMulti(masterslave.UseMaster(context.Background())) require.NoError(t, err) assert.ElementsMatch(t, tc.querySet, querySet) res := eorm.NewShardingInsert[test.OrderDetail](tx). @@ -220,7 +220,7 @@ func (s *ShardingSingleTxTestSuite) TestShardingSelectUpdate_Commit_Or_Rollback( tx := tc.txFunc(t) querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx). Where(eorm.C("OrderId").EQ(123)). - GetMultiV2(masterslave.UseMaster(context.Background())) + GetMulti(masterslave.UseMaster(context.Background())) require.NoError(t, err) assert.ElementsMatch(t, tc.querySet, querySet) res := eorm.NewShardingUpdater[test.OrderDetail](tx).Update(tc.target). diff --git a/internal/merger/groupby_merger/aggregator_merger.go b/internal/merger/groupby_merger/aggregator_merger.go index a2b1b16f..346a966d 100644 --- a/internal/merger/groupby_merger/aggregator_merger.go +++ b/internal/merger/groupby_merger/aggregator_merger.go @@ -29,10 +29,10 @@ import ( "go.uber.org/multierr" + "github.com/ecodeclub/ekit/mapx" "github.com/ecodeclub/eorm/internal/merger" "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" "github.com/ecodeclub/eorm/internal/merger/internal/errs" - "github.com/gotomicro/ekit/mapx" ) type AggregatorMerger struct { @@ -109,7 +109,7 @@ func (a *AggregatorMerger) getCols(rowsList []rows.Rows) (*mapx.TreeMap[Key, [][ val, ok := treeMap.Get(key) if ok { val = append(val, colData) - err = treeMap.Set(key, val) + err = treeMap.Put(key, val) if err != nil { return nil, nil, err } diff --git a/internal/merger/internal/errs/error.go b/internal/merger/internal/errs/error.go index 9982b037..9dc7e5f5 100644 --- a/internal/merger/internal/errs/error.go +++ b/internal/merger/internal/errs/error.go @@ -30,7 +30,6 @@ var ( ErrMergerAggregateHasEmptyRows = errors.New("merger: 聚合函数计算时rowsList有一个或多个为空") ErrMergerInvalidAggregateColumnIndex = errors.New("merger: ColumnInfo的index不合法") ErrMergerAggregateFuncNotFound = errors.New("merger: 聚合函数方法未找到") - ErrMergerNullable = errors.New("merger: 接收数据的类型需要为sql.Nullable") ) func NewRepeatSortColumn(column string) error { diff --git a/internal/rows/convert_assign.go b/internal/rows/convert_assign.go index 31ea326d..3e6795ee 100644 --- a/internal/rows/convert_assign.go +++ b/internal/rows/convert_assign.go @@ -15,6 +15,7 @@ package rows import ( + "database/sql" "database/sql/driver" _ "unsafe" ) @@ -31,5 +32,14 @@ func ConvertAssign(dest, src any) error { return err } } + // 预处理一下 sqlConvertAssign 不支持的转换,遇到一个加一个 + switch sv := src.(type) { + case sql.RawBytes: + switch dv := dest.(type) { + case *string: + *dv = string(sv) + return nil + } + } return sqlConvertAssign(dest, src) } diff --git a/internal/rows/data_rows.go b/internal/rows/data_rows.go new file mode 100644 index 00000000..a7a677ed --- /dev/null +++ b/internal/rows/data_rows.go @@ -0,0 +1,88 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rows + +import ( + "database/sql" + + "github.com/ecodeclub/eorm/internal/errs" +) + +var _ Rows = (*DataRows)(nil) + +// DataRows 直接传入数据,伪装成了一个 Rows +// 非线程安全实现 +type DataRows struct { + data [][]any + len int + columns []string + columnTypes []*sql.ColumnType + // 第几行 + idx int +} + +func (*DataRows) NextResultSet() bool { + return false +} + +func (d *DataRows) ColumnTypes() ([]*sql.ColumnType, error) { + return d.columnTypes, nil +} + +func NewDataRows(data [][]any, columns []string, columnTypes []*sql.ColumnType) *DataRows { + // 这里并没有什么必要检查 data 和 columns 的输入 + // 因为只有在很故意的情况下,data 和 columns 才可能会有问题 + return &DataRows{ + data: data, + len: len(data), + columns: columns, + idx: -1, + columnTypes: columnTypes, + } +} + +func (d *DataRows) Next() bool { + if d.idx >= d.len-1 { + return false + } + d.idx++ + return true +} + +func (d *DataRows) Scan(dest ...any) error { + // 不需要检测,作为内部代码我们可以预期用户会主动控制 + data := d.data[d.idx] + if len(data) != len(dest) { + return errs.NewErrScanWrongDestinationArguments(len(data), len(dest)) + } + for idx, dst := range dest { + if err := ConvertAssign(dst, data[idx]); err != nil { + return err + } + } + return nil +} + +func (*DataRows) Close() error { + return nil +} + +func (d *DataRows) Columns() ([]string, error) { + return d.columns, nil +} + +func (*DataRows) Err() error { + return nil +} diff --git a/internal/rows/data_rows_test.go b/internal/rows/data_rows_test.go new file mode 100644 index 00000000..b862ccf7 --- /dev/null +++ b/internal/rows/data_rows_test.go @@ -0,0 +1,162 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rows + +import ( + "errors" + "testing" + + "github.com/ecodeclub/ekit" + "github.com/ecodeclub/eorm/internal/errs" + "github.com/stretchr/testify/assert" +) + +func TestDataRows_Close(t *testing.T) { + rows := NewDataRows(nil, nil, nil) + assert.Nil(t, rows.Close()) +} + +func TestDataRows_Columns(t *testing.T) { + testCases := []struct { + name string + columns []string + }{ + { + name: "nil", + }, + { + name: "columns", + columns: []string{"column1"}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rows := NewDataRows(nil, tc.columns, nil) + columns, err := rows.Columns() + assert.NoError(t, err) + assert.Equal(t, tc.columns, columns) + }) + } +} + +func TestDataRows_Err(t *testing.T) { + rows := NewDataRows(nil, nil, nil) + assert.NoError(t, rows.Err()) +} + +func TestDataRows_Next(t *testing.T) { + testCases := []struct { + name string + data [][]any + beforeIdx int + + wantNext bool + afterIdx int + }{ + { + name: "nil", + wantNext: false, + beforeIdx: -1, + afterIdx: -1, + }, + { + name: "第一个", + data: [][]any{{1, 2, 3}}, + wantNext: true, + beforeIdx: -1, + afterIdx: 0, + }, + { + name: "还有一个", + data: [][]any{{1}, {2}}, + beforeIdx: 0, + wantNext: true, + afterIdx: 1, + }, + { + name: "到了最后一个", + data: [][]any{{1}, {2}}, + beforeIdx: 1, + wantNext: false, + afterIdx: 1, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rows := NewDataRows(tc.data, nil, nil) + rows.idx = tc.beforeIdx + assert.Equal(t, tc.wantNext, rows.Next()) + assert.Equal(t, tc.afterIdx, rows.idx) + }) + } +} + +func TestDataRows_Scan(t *testing.T) { + testCases := []struct { + name string + data [][]any + idx int + + input []any + wantRes []any + wantErr error + }{ + { + name: "获得了数据", + data: [][]any{{1, 2, 3}}, + input: []any{new(int), new(int32), new(int64)}, + wantRes: []any{ekit.ToPtr[int](1), + ekit.ToPtr[int32](2), ekit.ToPtr[int64](3)}, + wantErr: nil, + }, + { + name: "dst 过长", + data: [][]any{{1, 2, 3}}, + input: []any{new(int), new(int32), new(int64), new(int64)}, + wantErr: errs.NewErrScanWrongDestinationArguments(3, 4), + }, + { + name: "dst 过短", + data: [][]any{{1, 2, 3}}, + input: []any{new(int), new(int32)}, + wantErr: errs.NewErrScanWrongDestinationArguments(3, 2), + }, + { + name: "ConvertAndAssign错误", + data: [][]any{{1, "abc", 3}}, + input: []any{new(int), new(int64), new(int64)}, + wantErr: errors.New(`converting driver.Value type string ("abc") to a int64: invalid syntax`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rows := NewDataRows(tc.data, nil, nil) + rows.idx = tc.idx + err := rows.Scan(tc.input...) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantRes, tc.input) + }) + } +} + +func TestDataRows_NextResultSet(t *testing.T) { + // 固化行为,防止不小心改了 + rows := NewDataRows(nil, nil, nil) + assert.False(t, rows.NextResultSet()) +} diff --git a/internal/test/types.go b/internal/test/types.go index 641966fc..1ecf08bc 100644 --- a/internal/test/types.go +++ b/internal/test/types.go @@ -21,7 +21,7 @@ import ( "encoding/json" "fmt" - "github.com/gotomicro/ekit" + "github.com/ecodeclub/ekit" ) // SimpleStruct 包含所有 eorm 支持的类型 diff --git a/internal/test/types_test.go b/internal/test/types_test.go index 908caf48..4abf9b23 100644 --- a/internal/test/types_test.go +++ b/internal/test/types_test.go @@ -17,7 +17,7 @@ package test import ( "testing" - "github.com/gotomicro/ekit" + "github.com/ecodeclub/ekit" "github.com/stretchr/testify/assert" ) diff --git a/script/integrate_test.sh b/script/integrate_test.sh index 1eae897b..737553bb 100644 --- a/script/integrate_test.sh +++ b/script/integrate_test.sh @@ -4,5 +4,5 @@ set -e docker compose -f script/integration_test_compose.yml down docker compose -f script/integration_test_compose.yml up -d echo "127.0.0.1 slave.a.com" >> /etc/hosts -go test -race ./... -tags=e2e +go test -timeout=30m -race ./... -tags=e2e docker compose -f script/integration_test_compose.yml down diff --git a/session.go b/session.go new file mode 100644 index 00000000..034ad4ca --- /dev/null +++ b/session.go @@ -0,0 +1,62 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package eorm + +import ( + "context" + "database/sql" + + "github.com/ecodeclub/ekit/list" + "github.com/ecodeclub/eorm/internal/datasource" + "github.com/ecodeclub/eorm/internal/rows" + "golang.org/x/sync/errgroup" +) + +var _ Session = (*baseSession)(nil) + +type baseSession struct { + core + executor datasource.Executor +} + +func (sess *baseSession) queryContext(ctx context.Context, q Query) (rows.Rows, error) { + return sess.executor.Query(ctx, q) +} + +func (sess *baseSession) queryMulti(ctx context.Context, qs []Query) (list.List[rows.Rows], error) { + res := &list.ConcurrentList[rows.Rows]{ + List: list.NewArrayList[rows.Rows](len(qs)), + } + var eg errgroup.Group + for _, query := range qs { + q := query + eg.Go(func() error { + rs, err := sess.queryContext(ctx, q) + if err == nil { + return res.Append(rs) + } + return err + }) + } + return res, eg.Wait() +} + +func (sess *baseSession) execContext(ctx context.Context, q Query) (sql.Result, error) { + return sess.executor.Exec(ctx, q) +} + +func (sess *baseSession) getCore() core { + return sess.core +} diff --git a/sharding_select.go b/sharding_select.go index 539bae6d..13a9de7d 100644 --- a/sharding_select.go +++ b/sharding_select.go @@ -18,15 +18,12 @@ import ( "context" "sync" - "github.com/ecodeclub/eorm/internal/rows" - "github.com/ecodeclub/eorm/internal/merger/batchmerger" "github.com/ecodeclub/eorm/internal/sharding" "github.com/ecodeclub/eorm/internal/errs" "github.com/valyala/bytebufferpool" - "golang.org/x/sync/errgroup" ) type ShardingSelector[T any] struct { @@ -304,35 +301,16 @@ func (s *ShardingSelector[T]) GetMulti(ctx context.Context) ([]*T, error) { if err != nil { return nil, err } - - var rowsSlice []rows.Rows - var eg errgroup.Group - for _, query := range sdQs { - q := query - //fmt.Println(q.String()) - eg.Go(func() error { - //s.lock.Lock() - //defer s.lock.Unlock() - // TODO 利用 ctx 传递 DB name - rs, err := s.db.queryContext(ctx, q) - if err == nil { - s.lock.Lock() - rowsSlice = append(rowsSlice, rs) - s.lock.Unlock() - } - return err - }) - } - err = eg.Wait() + mgr := batchmerger.NewMerger() + rowsList, err := s.db.queryMulti(ctx, qs) if err != nil { return nil, err } - - mgr := batchmerger.NewMerger() - rows, err := mgr.Merge(ctx, rowsSlice) + rows, err := mgr.Merge(ctx, rowsList.AsSlice()) if err != nil { return nil, err } + defer rows.Close() var res []*T for rows.Next() { tp := new(T) @@ -345,78 +323,6 @@ func (s *ShardingSelector[T]) GetMulti(ctx context.Context) ([]*T, error) { return res, nil } -func (s *ShardingSelector[T]) GetMultiV2(ctx context.Context) ([]*T, error) { - qs, err := s.Build(ctx) - if err != nil { - return nil, err - } - var sdQs []sharding.Query - dsMap := make(map[string]map[string]sharding.Query, 8) - for _, q := range qs { - dbMap, ok := dsMap[q.Datasource] - if !ok { - dsMap[q.Datasource] = map[string]sharding.Query{q.DB: q} - continue - } - old, ok := dbMap[q.DB] - if !ok { - if dbMap == nil { - dbMap = make(map[string]sharding.Query, 8) - } - dbMap[q.DB] = q - continue - } - old.SQL = old.SQL + q.SQL - old.Args = append(old.Args, q.Args...) - dbMap[q.DB] = old - dsMap[q.Datasource] = dbMap - } - for _, dbMap := range dsMap { - for _, q := range dbMap { - sdQs = append(sdQs, q) - } - } - - var rowsSlice []rows.Rows - var eg errgroup.Group - for _, query := range sdQs { - q := query - //fmt.Println(q.String()) - eg.Go(func() error { - //s.lock.Lock() - //defer s.lock.Unlock() - // TODO 利用 ctx 传递 DB name - rs, err := s.db.queryContext(ctx, q) - if err == nil { - s.lock.Lock() - rowsSlice = append(rowsSlice, rs) - s.lock.Unlock() - } - return err - }) - } - err = eg.Wait() - if err != nil { - return nil, err - } - - mgr := batchmerger.NewMerger() - rs, err := mgr.Merge(ctx, rowsSlice) - if err != nil { - return nil, err - } - var res []*T - for rs.Next() { - tp := new(T) - val := s.valCreator.NewPrimitiveValue(tp, s.meta) - if err = val.SetColumns(rs); err != nil { - return nil, err - } - res = append(res, tp) - } - return res, nil -} - // Select 指定查询的列。 // 列可以是物理列,也可以是聚合函数,或者 RawExpr func (s *ShardingSelector[T]) Select(columns ...Selectable) *ShardingSelector[T] { diff --git a/transaction.go b/transaction.go index 681b5138..ee953d0c 100644 --- a/transaction.go +++ b/transaction.go @@ -18,24 +18,116 @@ import ( "context" "database/sql" + "github.com/ecodeclub/ekit/list" + "github.com/ecodeclub/ekit/mapx" + "github.com/ecodeclub/ekit/sqlx" + "github.com/ecodeclub/eorm/internal/rows" + "github.com/valyala/bytebufferpool" + "golang.org/x/sync/errgroup" + "github.com/ecodeclub/eorm/internal/datasource" ) type Tx struct { - core + baseSession tx datasource.Tx } -func (t *Tx) getCore() core { - return t.core +func (t *Tx) queryMulti(ctx context.Context, qs []Query) (list.List[rows.Rows], error) { + // 事务在查询的时候,需要将同一个 DB 上的语句合并在一起 + // 参考 https://github.com/ecodeclub/eorm/discussions/213 + mp := mapx.NewMultiBuiltinMap[string, Query](len(qs)) + for _, q := range qs { + if err := mp.Put(q.DB+"_"+q.Datasource, q); err != nil { + return nil, err + } + } + keys := mp.Keys() + rowsList := &list.ConcurrentList[rows.Rows]{ + List: list.NewArrayList[rows.Rows](len(keys)), + } + var eg errgroup.Group + for _, key := range keys { + dbQs, _ := mp.Get(key) + eg.Go(func() error { + return t.execDBQueries(ctx, dbQs, rowsList) + }) + } + return rowsList, eg.Wait() +} + +// execDBQueries 执行某个 DB 上的全部查询。 +// 执行结果会被加入进去 rowsList 里面。虽然这种修改传入参数的做法不是很好,但是作为一个内部方法还是可以接受的。 +func (t *Tx) execDBQueries(ctx context.Context, dbQs []Query, rowsList *list.ConcurrentList[rows.Rows]) error { + qsCnt := len(dbQs) + // 考虑到大部分都只有一个查询,我们做一个快路径的优化。 + if qsCnt == 1 { + rs, err := t.tx.Query(ctx, dbQs[0]) + if err != nil { + return err + } + return rowsList.Append(rs) + } + // 慢路径,也就是必须要把同一个库的查询合并在一起 + q := t.mergeDBQueries(dbQs) + rs, err := t.tx.Query(ctx, q) + if err != nil { + return err + } + // 查询之后,事务必须再次按照结果集分割开。 + // 这样是为了让结果集的数量和查询数量保持一致。 + return t.splitTxResultSet(rowsList, rs) } -func (t *Tx) queryContext(ctx context.Context, query datasource.Query) (*sql.Rows, error) { - return t.tx.Query(ctx, query) +func (t *Tx) splitTxResultSet(list list.List[rows.Rows], rs *sql.Rows) error { + cs, err := rs.Columns() + if err != nil { + return err + } + ct, err := rs.ColumnTypes() + if err != nil { + return err + } + scanner, err := sqlx.NewSQLRowsScanner(rs) + if err != nil { + return err + } + // 虽然这里我们可以尝试不读取最后一个 ResultSet + // 但是这个优化目前来说不准备做, + // 防止用户出现因为类型转换遇到一些潜在的问题 + // 数据库类型到 GO 类型再到用户希望的类型,是一个漫长的过程。 + hasNext := true + for hasNext { + var data [][]any + data, err = scanner.ScanAll() + if err != nil { + return err + } + err = list.Append(rows.NewDataRows(data, cs, ct)) + if err != nil { + return err + } + hasNext = scanner.NextResultSet() + } + return nil } -func (t *Tx) execContext(ctx context.Context, query datasource.Query) (sql.Result, error) { - return t.tx.Exec(ctx, query) +func (t *Tx) mergeDBQueries(dbQs []Query) Query { + buffer := bytebufferpool.Get() + defer bytebufferpool.Put(buffer) + first := dbQs[0] + // 预估有多少查询参数,一个查询的参数个数 * 查询个数 + args := make([]any, 0, len(first.Args)*len(dbQs)) + for _, dbQ := range dbQs { + _, _ = buffer.WriteString(dbQ.SQL) + args = append(args, dbQ.Args...) + } + return Query{ + SQL: buffer.String(), + Args: args, + DB: first.DB, + Datasource: first.Datasource, + } } func (t *Tx) Commit() error { diff --git a/types.go b/types.go index 373bf9de..22bf9c11 100644 --- a/types.go +++ b/types.go @@ -18,7 +18,8 @@ import ( "context" "database/sql" - "github.com/ecodeclub/eorm/internal/datasource" + "github.com/ecodeclub/ekit/list" + "github.com/ecodeclub/eorm/internal/rows" ) // Executor sql 语句执行器 @@ -34,6 +35,7 @@ type QueryBuilder interface { // Session 代表一个抽象的概念,即会话 type Session interface { getCore() core - queryContext(ctx context.Context, query datasource.Query) (*sql.Rows, error) - execContext(ctx context.Context, query datasource.Query) (sql.Result, error) + queryMulti(ctx context.Context, qs []Query) (list.List[rows.Rows], error) + queryContext(ctx context.Context, query Query) (rows.Rows, error) + execContext(ctx context.Context, query Query) (sql.Result, error) }