diff --git a/canal/canal_test.go b/canal/canal_test.go index 954c3648c..333729997 100644 --- a/canal/canal_test.go +++ b/canal/canal_test.go @@ -323,3 +323,85 @@ func TestWithoutSchemeExp(t *testing.T) { } } } + +func TestCreateIndexExp(t *testing.T) { + cases := []string{ + "create index test0 on test.test (id)", + "create index test0 ON test.test (id)", + "CREATE INDEX test0 on `test`.test (id)", + "CREATE INDEX test0 ON test.test (id)", + "CREATE index test0 on `test`.test (id)", + "CREATE index test0 ON test.test (id)", + "create INDEX test0 on `test`.test (id)", + "create INDEX test0 ON test.test (id)", + "CREATE INDEX `test0` ON `test`.`test` (`id`) /* generated by server */", + "CREATE /*generated by server */ INDEX `test0` ON `test`.`test` (`id`)", + "CREATE INDEX `test0` ON `test`.test (id)", + "CREATE INDEX `test0` ON test.`test` (id)", + "CREATE INDEX `test0` ON test.test (`id`)", + "CREATE INDEX test0 ON `test`.`test` (`id`)", + "CREATE INDEX test0 ON `test`.`test` (id)", + "CREATE INDEX test0 ON test.test (`id`)", + } + + baseTable := "test" + db := "test" + pr := parser.New() + for _, s := range cases { + stmts, _, err := pr.Parse(s, "", "") + require.NoError(t, err) + for _, st := range stmts { + nodes := parseStmt(st) + require.NotZero(t, nodes) + for _, node := range nodes { + rdb := node.db + rtable := node.table + require.Equal(t, db, rdb) + require.Equal(t, baseTable, rtable) + } + } + } +} + +func TestDropIndexExp(t *testing.T) { + cases := []string{ + "drop index test0 on test.test", + "DROP INDEX test0 ON test.test", + "drop INDEX test0 on test.test", + "DROP index test0 ON test.test", + "drop INDEX `test0` on `test`.`test`", + "drop INDEX test0 ON `test`.`test`", + "drop INDEX test0 on `test`.test", + "drop INDEX test0 on test.`test`", + "DROP index `test0` on `test`.`test`", + "DROP index test0 ON `test`.`test`", + "DROP index test0 on `test`.test", + "DROP index test0 on test.`test`", + "DROP INDEX `test0` ON `test`.`test` /* generated by server */", + "DROP /*generated by server */ INDEX `test0` ON `test`.`test`", + "DROP INDEX `test0` ON `test`.test", + "DROP INDEX `test0` ON test.`test`", + "DROP INDEX `test0` ON test.test", + "DROP INDEX test0 ON `test`.`test`", + "DROP INDEX test0 ON `test`.`test`", + "DROP INDEX test0 ON test.test", + } + + baseTable := "test" + db := "test" + pr := parser.New() + for _, s := range cases { + stmts, _, err := pr.Parse(s, "", "") + require.NoError(t, err) + for _, st := range stmts { + nodes := parseStmt(st) + require.NotZero(t, nodes) + for _, node := range nodes { + rdb := node.db + rtable := node.table + require.Equal(t, db, rdb) + require.Equal(t, baseTable, rtable) + } + } + } +} diff --git a/canal/sync.go b/canal/sync.go index b71be536a..b7006b17e 100644 --- a/canal/sync.go +++ b/canal/sync.go @@ -216,6 +216,18 @@ func parseStmt(stmt ast.StmtNode) (ns []*node) { table: t.Table.Name.String(), } ns = []*node{n} + case *ast.CreateIndexStmt: + n := &node{ + db: t.Table.Schema.String(), + table: t.Table.Name.String(), + } + ns = []*node{n} + case *ast.DropIndexStmt: + n := &node{ + db: t.Table.Schema.String(), + table: t.Table.Name.String(), + } + ns = []*node{n} } return ns }