Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

增加在SQL审核阶段自动识别并合并相同表的alter table语句的功能 #669

Merged
merged 4 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ type Binlog struct {

// Inc is the inception section of the config.
type Inc struct {
AlterAutoMerge bool `toml:"alter_auto_merge" json:"alter_auto_merge"`
BackupHost string `toml:"backup_host" json:"backup_host"` // 远程备份库信息
BackupPassword string `toml:"backup_password" json:"backup_password"`
BackupPort uint `toml:"backup_port" json:"backup_port"`
Expand Down
9 changes: 8 additions & 1 deletion session/inception_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ type Record struct {
// delete多表时,默认delete后第一个表为主表,其余表才会记录到该处
// 仅在发现多表操作时,初始化该参数
MultiTables map[string]*TableInfo

// 判断该语句是否是需要被合并的(只有 alter table, create index, drop index三种语句需要被合并),不需要为0,已经被合并过的SQL会被设置为-1,需要的数字为对应的合并后的SQL的行号
NeedMerge int
}

func (r *Record) appendWarningMessage(msg string) {
Expand Down Expand Up @@ -297,7 +300,7 @@ func NewRecordSets() *MyRecordSets {
fieldCount: 0,
}

rc.fields = make([]*ast.ResultField, 12)
rc.fields = make([]*ast.ResultField, 13)

// 序号
rc.CreateFiled("order_id", mysql.TypeLong)
Expand All @@ -321,6 +324,8 @@ func NewRecordSets() *MyRecordSets {
rc.CreateFiled("sqlsha1", mysql.TypeString)
// 备份用时
rc.CreateFiled("backup_time", mysql.TypeString)
// 判断该语句是否是需要被合并的(只有 alter table, create index, drop index三种语句需要被合并),不需要为0,已经被合并过的SQL会被设置为-1,需要的数字为对应的合并后的SQL的行号
rc.CreateFiled("needMerge", mysql.TypeTiny)

t.rc = rc
return t
Expand Down Expand Up @@ -394,6 +399,8 @@ func (s *MyRecordSets) setFields(r *Record) {
row[11].SetString(r.BackupCostTime)
}

row[12].SetValue(r.NeedMerge)

s.rc.data[s.rc.count] = row
s.rc.count++
}
Expand Down
12 changes: 12 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,19 @@ func (h *StmtHistory) Count() int {
return len(h.history)
}

// jwx added
type alterTableInfo struct {
Name string
alterStmtList []ast.AlterTableStmt
mergedSql string
recordSetsPosList []int // 记录当前语句在s.recordSets里的位置,用于修改needMerge字段
}

type session struct {

//jwx added
alterTableInfoList []alterTableInfo

// processInfo is used by ShowProcess(), and should be modified atomically.
processInfo atomic.Value
txn TxnState
Expand Down
127 changes: 122 additions & 5 deletions session/session_inception.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,36 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle
s.initDisableTypes()
continue
case *ast.InceptionCommitStmt:
/******* jwx added 将对同一个表的多条alter语句合并成一条 ******/
if s.inc.AlterAutoMerge {
for _, info := range s.alterTableInfoList {
if len(info.alterStmtList) >= 2 {
merged := info.alterStmtList[0]
for seq, alterStmt := range info.alterStmtList {
if seq > 0 {
merged.Specs = append(merged.Specs, alterStmt.Specs...)
}
}
var builder strings.Builder
_ = merged.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &builder))
info.mergedSql = builder.String()
mergedRecord := &Record{
Sql: info.mergedSql,
Buf: new(bytes.Buffer),
Type: &merged,
Stage: StageCheck,
ErrorMessage: "MERGED",
NeedMerge: -1,
}
s.recordSets.Append(mergedRecord)
for _, pos := range info.recordSetsPosList {
s.recordSets.records[pos].NeedMerge = s.recordSets.SeqNo
}
}

}
}
/****************/

if !s.haveBegin {
s.appendErrorMsg("Must start as begin statement.")
Expand Down Expand Up @@ -606,7 +636,7 @@ func (s *session) processCommand(ctx context.Context, stmtNode ast.StmtNode,
case *ast.CreateTableStmt:
s.checkCreateTable(node, currentSql)
case *ast.AlterTableStmt:
s.checkAlterTable(node, currentSql)
s.checkAlterTable(node, currentSql, false)
case *ast.DropTableStmt:
s.checkDropTable(node, currentSql)
case *ast.RenameTableStmt:
Expand All @@ -629,11 +659,24 @@ func (s *session) processCommand(ctx context.Context, stmtNode ast.StmtNode,
if node.KeyType == ast.IndexKeyTypeFullText {
tp = ast.ConstraintFulltext
}
s.checkCreateIndex(node.Table, node.IndexName,
node.IndexColNames, node.IndexOption, nil, node.Unique, tp)
if !s.inc.AlterAutoMerge { // jwx added
s.checkCreateIndex(node.Table, node.IndexName,
node.IndexColNames, node.IndexOption, nil, node.Unique, tp)
} else {
alter := s.convertCreateIndexToAlterTable(node)
s.checkAlterTable(alter, node.Text(), true)
s.checkCreateIndex(node.Table, node.IndexName,
node.IndexColNames, node.IndexOption, nil, node.Unique, tp)
}

case *ast.DropIndexStmt:
s.checkDropIndex(node, currentSql)
if !s.inc.AlterAutoMerge { // jwx added
s.checkDropIndex(node, currentSql)
} else {
alter := s.convertDropIndexToAlterTable(node)
s.checkAlterTable(alter, node.Text(), true)
s.checkDropIndex(node, currentSql)
}

case *ast.CreateViewStmt:
s.checkCreateView(node, currentSql)
Expand Down Expand Up @@ -3294,7 +3337,7 @@ func (s *session) checkTableCharsetCollation(character, collation string) {
}
}

func (s *session) checkAlterTable(node *ast.AlterTableStmt, sql string) {
func (s *session) checkAlterTable(node *ast.AlterTableStmt, sql string, mergeOnly bool) {
log.Debug("checkAlterTable")

if node.Table.Schema.O == "" {
Expand All @@ -3310,6 +3353,34 @@ func (s *session) checkAlterTable(node *ast.AlterTableStmt, sql string) {
return
}

/*********** jwx added **********/
if s.inc.AlterAutoMerge {
tableNameInString := fmt.Sprintf("%s.%s", node.Table.Schema.O, node.Table.Name.O)
var found bool = false
var seq int = 0
for j, i := range s.alterTableInfoList {
if tableNameInString == i.Name {
found = true
seq = j
break
}
}
if found {
s.alterTableInfoList[seq].alterStmtList = append(s.alterTableInfoList[seq].alterStmtList, *node)
s.alterTableInfoList[seq].recordSetsPosList = append(s.alterTableInfoList[seq].recordSetsPosList, s.recordSets.SeqNo)
} else {
var info alterTableInfo = alterTableInfo{Name: tableNameInString}
info.alterStmtList = append(info.alterStmtList, *node)
info.recordSetsPosList = append(info.recordSetsPosList, s.recordSets.SeqNo)
s.alterTableInfoList = append(s.alterTableInfoList, info)
}

if mergeOnly {
return
}
}
/******************************/

table.AlterCount += 1

if table.AlterCount > 1 {
Expand Down Expand Up @@ -5508,6 +5579,52 @@ func (s *session) checkAddConstraint(t *TableInfo, c *ast.AlterTableSpec) {
}
}

func (s *session) convertCreateIndexToAlterTable(node *ast.CreateIndexStmt) *ast.AlterTableStmt {
log.Debug("convertCreateIndexToAlterTable")
var alter *ast.AlterTableStmt = &ast.AlterTableStmt{Specs: []*ast.AlterTableSpec{}}
var spec *ast.AlterTableSpec = &ast.AlterTableSpec{Tp: ast.AlterTableAddConstraint, Constraint: &ast.Constraint{}}
spec.IfNotExists = node.IfNotExists
spec.Constraint.Name = node.IndexName
if node.Unique {
spec.Constraint.Tp = ast.ConstraintUniq
} else {
spec.Constraint.Tp = ast.ConstraintIndex
}
spec.Constraint.Keys = node.IndexColNames
spec.Constraint.Option = node.IndexOption
if node.LockAlg != nil {
spec.LockType = node.LockAlg.LockTp
spec.Algorithm = node.LockAlg.AlgorithmTp
} else {
spec.LockType = 0
spec.Algorithm = 0
}
spec.Partition = node.Partition
alter.SetText(node.Text())
alter.Table = node.Table
alter.Specs = append(alter.Specs, spec)
return alter
}

func (s *session) convertDropIndexToAlterTable(node *ast.DropIndexStmt) *ast.AlterTableStmt {
log.Debug("convertDropIndexToAlterTable")
var alter *ast.AlterTableStmt = &ast.AlterTableStmt{Specs: []*ast.AlterTableSpec{}}
var spec *ast.AlterTableSpec = &ast.AlterTableSpec{Tp: ast.AlterTableDropIndex}
spec.IfExists = node.IfExists
spec.Name = node.IndexName
if node.LockAlg != nil {
spec.LockType = node.LockAlg.LockTp
spec.Algorithm = node.LockAlg.AlgorithmTp
} else {
spec.LockType = 0
spec.Algorithm = 0
}
alter.SetText(node.Text())
alter.Table = node.Table
alter.Specs = append(alter.Specs, spec)
return alter
}

func (s *session) checkDBExists(db string, reportNotExists bool) bool {

if db == "" {
Expand Down
3 changes: 2 additions & 1 deletion session/tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ func RegisterStore(name string, driver kv.Driver) error {
// session.Open() but with the dbname cut off.
// Examples:
// goleveldb://relative/path
// boltdb:///absolute/path

// boltdb:///absolute/path
//
// The engine should be registered before creating storage.
func NewStore(path string) (kv.Storage, error) {
Expand Down
Loading