Skip to content

Commit

Permalink
增加在SQL审核阶段自动识别并合并相同表的alter table语句的功能 (#669)
Browse files Browse the repository at this point in the history
* 增加在SQL审核阶段自动识别并合并相同表的alter table语句的功能
  • Loading branch information
jiweixiao authored Aug 25, 2024
1 parent 261854d commit 7c6a1c0
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 7 deletions.
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ type Binlog struct {

// Inc is the inception section of the config.
type Inc struct {
AlterAutoMerge bool `toml:"alter_auto_merge" json:"alter_auto_merge"`
BackupHost string `toml:"backup_host" json:"backup_host"` // 远程备份库信息
BackupPassword string `toml:"backup_password" json:"backup_password"`
BackupPort uint `toml:"backup_port" json:"backup_port"`
Expand Down
9 changes: 8 additions & 1 deletion session/inception_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ type Record struct {
// delete多表时,默认delete后第一个表为主表,其余表才会记录到该处
// 仅在发现多表操作时,初始化该参数
MultiTables map[string]*TableInfo

// 判断该语句是否是需要被合并的(只有 alter table, create index, drop index三种语句需要被合并),不需要为0,已经被合并过的SQL会被设置为-1,需要的数字为对应的合并后的SQL的行号
NeedMerge int
}

func (r *Record) appendWarningMessage(msg string) {
Expand Down Expand Up @@ -297,7 +300,7 @@ func NewRecordSets() *MyRecordSets {
fieldCount: 0,
}

rc.fields = make([]*ast.ResultField, 12)
rc.fields = make([]*ast.ResultField, 13)

// 序号
rc.CreateFiled("order_id", mysql.TypeLong)
Expand All @@ -321,6 +324,8 @@ func NewRecordSets() *MyRecordSets {
rc.CreateFiled("sqlsha1", mysql.TypeString)
// 备份用时
rc.CreateFiled("backup_time", mysql.TypeString)
// 判断该语句是否是需要被合并的(只有 alter table, create index, drop index三种语句需要被合并),不需要为0,已经被合并过的SQL会被设置为-1,需要的数字为对应的合并后的SQL的行号
rc.CreateFiled("needMerge", mysql.TypeTiny)

t.rc = rc
return t
Expand Down Expand Up @@ -394,6 +399,8 @@ func (s *MyRecordSets) setFields(r *Record) {
row[11].SetString(r.BackupCostTime)
}

row[12].SetValue(r.NeedMerge)

s.rc.data[s.rc.count] = row
s.rc.count++
}
Expand Down
12 changes: 12 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,19 @@ func (h *StmtHistory) Count() int {
return len(h.history)
}

// jwx added
type alterTableInfo struct {
Name string
alterStmtList []ast.AlterTableStmt
mergedSql string
recordSetsPosList []int // 记录当前语句在s.recordSets里的位置,用于修改needMerge字段
}

type session struct {

//jwx added
alterTableInfoList []alterTableInfo

// processInfo is used by ShowProcess(), and should be modified atomically.
processInfo atomic.Value
txn TxnState
Expand Down
127 changes: 122 additions & 5 deletions session/session_inception.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,36 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle
s.initDisableTypes()
continue
case *ast.InceptionCommitStmt:
/******* jwx added 将对同一个表的多条alter语句合并成一条 ******/
if s.inc.AlterAutoMerge {
for _, info := range s.alterTableInfoList {
if len(info.alterStmtList) >= 2 {
merged := info.alterStmtList[0]
for seq, alterStmt := range info.alterStmtList {
if seq > 0 {
merged.Specs = append(merged.Specs, alterStmt.Specs...)
}
}
var builder strings.Builder
_ = merged.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &builder))
info.mergedSql = builder.String()
mergedRecord := &Record{
Sql: info.mergedSql,
Buf: new(bytes.Buffer),
Type: &merged,
Stage: StageCheck,
ErrorMessage: "MERGED",
NeedMerge: -1,
}
s.recordSets.Append(mergedRecord)
for _, pos := range info.recordSetsPosList {
s.recordSets.records[pos].NeedMerge = s.recordSets.SeqNo
}
}

}
}
/****************/

if !s.haveBegin {
s.appendErrorMsg("Must start as begin statement.")
Expand Down Expand Up @@ -606,7 +636,7 @@ func (s *session) processCommand(ctx context.Context, stmtNode ast.StmtNode,
case *ast.CreateTableStmt:
s.checkCreateTable(node, currentSql)
case *ast.AlterTableStmt:
s.checkAlterTable(node, currentSql)
s.checkAlterTable(node, currentSql, false)
case *ast.DropTableStmt:
s.checkDropTable(node, currentSql)
case *ast.RenameTableStmt:
Expand All @@ -629,11 +659,24 @@ func (s *session) processCommand(ctx context.Context, stmtNode ast.StmtNode,
if node.KeyType == ast.IndexKeyTypeFullText {
tp = ast.ConstraintFulltext
}
s.checkCreateIndex(node.Table, node.IndexName,
node.IndexColNames, node.IndexOption, nil, node.Unique, tp)
if !s.inc.AlterAutoMerge { // jwx added
s.checkCreateIndex(node.Table, node.IndexName,
node.IndexColNames, node.IndexOption, nil, node.Unique, tp)
} else {
alter := s.convertCreateIndexToAlterTable(node)
s.checkAlterTable(alter, node.Text(), true)
s.checkCreateIndex(node.Table, node.IndexName,
node.IndexColNames, node.IndexOption, nil, node.Unique, tp)
}

case *ast.DropIndexStmt:
s.checkDropIndex(node, currentSql)
if !s.inc.AlterAutoMerge { // jwx added
s.checkDropIndex(node, currentSql)
} else {
alter := s.convertDropIndexToAlterTable(node)
s.checkAlterTable(alter, node.Text(), true)
s.checkDropIndex(node, currentSql)
}

case *ast.CreateViewStmt:
s.checkCreateView(node, currentSql)
Expand Down Expand Up @@ -3294,7 +3337,7 @@ func (s *session) checkTableCharsetCollation(character, collation string) {
}
}

func (s *session) checkAlterTable(node *ast.AlterTableStmt, sql string) {
func (s *session) checkAlterTable(node *ast.AlterTableStmt, sql string, mergeOnly bool) {
log.Debug("checkAlterTable")

if node.Table.Schema.O == "" {
Expand All @@ -3310,6 +3353,34 @@ func (s *session) checkAlterTable(node *ast.AlterTableStmt, sql string) {
return
}

/*********** jwx added **********/
if s.inc.AlterAutoMerge {
tableNameInString := fmt.Sprintf("%s.%s", node.Table.Schema.O, node.Table.Name.O)
var found bool = false
var seq int = 0
for j, i := range s.alterTableInfoList {
if tableNameInString == i.Name {
found = true
seq = j
break
}
}
if found {
s.alterTableInfoList[seq].alterStmtList = append(s.alterTableInfoList[seq].alterStmtList, *node)
s.alterTableInfoList[seq].recordSetsPosList = append(s.alterTableInfoList[seq].recordSetsPosList, s.recordSets.SeqNo)
} else {
var info alterTableInfo = alterTableInfo{Name: tableNameInString}
info.alterStmtList = append(info.alterStmtList, *node)
info.recordSetsPosList = append(info.recordSetsPosList, s.recordSets.SeqNo)
s.alterTableInfoList = append(s.alterTableInfoList, info)
}

if mergeOnly {
return
}
}
/******************************/

table.AlterCount += 1

if table.AlterCount > 1 {
Expand Down Expand Up @@ -5508,6 +5579,52 @@ func (s *session) checkAddConstraint(t *TableInfo, c *ast.AlterTableSpec) {
}
}

func (s *session) convertCreateIndexToAlterTable(node *ast.CreateIndexStmt) *ast.AlterTableStmt {
log.Debug("convertCreateIndexToAlterTable")
var alter *ast.AlterTableStmt = &ast.AlterTableStmt{Specs: []*ast.AlterTableSpec{}}
var spec *ast.AlterTableSpec = &ast.AlterTableSpec{Tp: ast.AlterTableAddConstraint, Constraint: &ast.Constraint{}}
spec.IfNotExists = node.IfNotExists
spec.Constraint.Name = node.IndexName
if node.Unique {
spec.Constraint.Tp = ast.ConstraintUniq
} else {
spec.Constraint.Tp = ast.ConstraintIndex
}
spec.Constraint.Keys = node.IndexColNames
spec.Constraint.Option = node.IndexOption
if node.LockAlg != nil {
spec.LockType = node.LockAlg.LockTp
spec.Algorithm = node.LockAlg.AlgorithmTp
} else {
spec.LockType = 0
spec.Algorithm = 0
}
spec.Partition = node.Partition
alter.SetText(node.Text())
alter.Table = node.Table
alter.Specs = append(alter.Specs, spec)
return alter
}

func (s *session) convertDropIndexToAlterTable(node *ast.DropIndexStmt) *ast.AlterTableStmt {
log.Debug("convertDropIndexToAlterTable")
var alter *ast.AlterTableStmt = &ast.AlterTableStmt{Specs: []*ast.AlterTableSpec{}}
var spec *ast.AlterTableSpec = &ast.AlterTableSpec{Tp: ast.AlterTableDropIndex}
spec.IfExists = node.IfExists
spec.Name = node.IndexName
if node.LockAlg != nil {
spec.LockType = node.LockAlg.LockTp
spec.Algorithm = node.LockAlg.AlgorithmTp
} else {
spec.LockType = 0
spec.Algorithm = 0
}
alter.SetText(node.Text())
alter.Table = node.Table
alter.Specs = append(alter.Specs, spec)
return alter
}

func (s *session) checkDBExists(db string, reportNotExists bool) bool {

if db == "" {
Expand Down
3 changes: 2 additions & 1 deletion session/tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ func RegisterStore(name string, driver kv.Driver) error {
// session.Open() but with the dbname cut off.
// Examples:
// goleveldb://relative/path
// boltdb:///absolute/path

// boltdb:///absolute/path
//
// The engine should be registered before creating storage.
func NewStore(path string) (kv.Storage, error) {
Expand Down

0 comments on commit 7c6a1c0

Please sign in to comment.