Skip to content

Commit

Permalink
sql: support query strings containing multiple statements
Browse files Browse the repository at this point in the history
  • Loading branch information
ligfx committed Jul 17, 2024
1 parent 29b0189 commit 6136116
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 6 deletions.
4 changes: 2 additions & 2 deletions internal/impl/sql/conn_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,14 @@ func (c *connSettings) apply(ctx context.Context, db *sql.DB, log *service.Logge

c.initOnce.Do(func() {
for _, fileStmt := range c.initFileStatements {
if _, err := db.ExecContext(ctx, fileStmt[1]); err != nil {
if err := execMultiWithContext(db, ctx, fileStmt[1]); err != nil {
log.Warnf("Failed to execute init_file '%v': %v", fileStmt[0], err)
} else {
log.Debugf("Successfully ran init_file '%v'", fileStmt[0])
}
}
if c.initStatement != "" {
if _, err := db.ExecContext(ctx, c.initStatement); err != nil {
if err := execMultiWithContext(db, ctx, c.initStatement); err != nil {
log.Warnf("Failed to execute init_statement: %v", err)
} else {
log.Debug("Successfully ran init_statement")
Expand Down
2 changes: 1 addition & 1 deletion internal/impl/sql/input_sql_raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (s *sqlRawInput) Connect(ctx context.Context) (err error) {
}

var rows *sql.Rows
if rows, err = db.QueryContext(ctx, s.queryStatic, args...); err != nil {
if rows, err = queryMultiWithContext(db, ctx, s.queryStatic, args...); err != nil {
return
} else if err = rows.Err(); err != nil {
s.logger.With("err", err).Warnf("unexpected error while execute raw query %q", s.queryStatic)
Expand Down
136 changes: 136 additions & 0 deletions internal/impl/sql/multi_statement.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Copyright 2024 Redpanda Data, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
"context"
"database/sql"
"strings"
)

func splitSQLStatements(statement string) []string {
var result []string
startp := 0
p := 0
sawNonCommentOrSpace := false
for {
if p == len(statement) || statement[p] == ';' {
if p != len(statement) && statement[p] == ';' {
// include trailing semicolon
p++
}
statementPart := statement[startp:p]
if sawNonCommentOrSpace {
result = append(result, strings.TrimSpace(statementPart))
} else {
// coalesce any functionally "empty" statements into the previous statement
// so any configurations that have something like "statement; -- final comment"
// will still work
result[len(result)-1] += statementPart
}
if p == len(statement) {
break
}
startp = p
sawNonCommentOrSpace = false
} else if statement[p] == '\'' || statement[p] == '"' || statement[p] == '`' {
// single-quoted strings, double-quoted identifiers, and backtick-quoted identifiers
sentinel := statement[p]
p++
for p < len(statement) && statement[p] != sentinel {
p++
}
sawNonCommentOrSpace = true
} else if statement[p] == '#' ||
(p+1 < len(statement) && statement[p:p+2] == "--") ||
(p+1 < len(statement) && statement[p:p+2] == "//") {
// single-line comments starting with hash, double-dash, or double-slash
for p < len(statement) && statement[p] != '\n' {
p++
}
} else if p+1 < len(statement) && statement[p:p+2] == "/*" {
// multi-line comments starting with slash-asterisk
for p+1 < len(statement) && statement[p:p+2] != "*/" {
p++
}
} else if !(statement[p] == ' ' || statement[p] == '\t' || statement[p] == '\r' || statement[p] == '\n') {
sawNonCommentOrSpace = true
}
if p != len(statement) {
p++
}
}

return result
}

func execMultiWithContext(db *sql.DB, ctx context.Context, query string, args ...any) error {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() {
_ = tx.Rollback()
}()

statements := splitSQLStatements(query)
for _, part := range statements {
if _, err = tx.ExecContext(ctx, part, args...); err != nil {
return err
}
args = []any{}
}

if err = tx.Commit(); err != nil {
return err
}

// TODO: should this return anything for a result?
return nil
}

func queryMultiWithContext(db *sql.DB, ctx context.Context, query string, args ...any) (*sql.Rows, error) {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
_ = tx.Rollback()
}()

statements := splitSQLStatements(query)
var rows *sql.Rows
for i, part := range statements {
// this may not be useful to only give the args to the first query. but, principle of least surprise,
// make it act the same way that execMultiWithContext and the various drivers do.
if i < len(statements)-1 {
if _, err = tx.ExecContext(ctx, part, args...); err != nil {
return nil, err
}
} else {
rows, err = tx.QueryContext(ctx, part, args...)
if err != nil {
return nil, err
}
}
args = []any{}
}

if err = tx.Commit(); err != nil {
return nil, err
}

return rows, nil
}
84 changes: 84 additions & 0 deletions internal/impl/sql/multi_statement_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2024 Redpanda Data, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
"testing"

"github.com/stretchr/testify/assert"
)

func assertSplitEquals(t *testing.T, message string, statement string, wanted []string) {
result := splitSQLStatements(statement)
assert.Equal(t, wanted, result, message)
}

func TestSplitStatements(t *testing.T) {
assertSplitEquals(t, "no semicolon", "select null", []string{"select null"})

assertSplitEquals(t, "basic semicolon", "select 1; select 2", []string{"select 1;", "select 2"})

assertSplitEquals(t, "semicolon in single-quoted string",
"select 'singlequoted;string'; select null",
[]string{"select 'singlequoted;string';", "select null"})

assertSplitEquals(t, "semicolon in double-quoted identifier",
"select \"doublequoted;ident\"; select null",
[]string{"select \"doublequoted;ident\";", "select null"})

assertSplitEquals(t, "semicolon in backtick-quoted identifier",
"select `backtick;ident`; select null",
[]string{"select `backtick;ident`;", "select null"})

assertSplitEquals(t, "semicolon in hash-comment", `
select #hash;comment
1; select 2
`, []string{"select #hash;comment\n\t\t1;", "select 2"})

assertSplitEquals(t, "semicolon in double-dash comment", `
select --double-dash;comment
1; select 2
`, []string{"select --double-dash;comment\n\t\t1;", "select 2"})

assertSplitEquals(t, "semicolon in double-slash comment", `
select //double-slash;comment
1; select 2
`, []string{"select //double-slash;comment\n\t\t1;", "select 2"})

assertSplitEquals(t, "semicolon in multi-line comment", `
select /*multi;
line;comment*/
1; select 2
`, []string{"select /*multi;\n\t\tline;comment*/\n\t\t1;", "select 2"})

assertSplitEquals(t, "semicolon at end should be single statement",
"select null;",
[]string{"select null;"})

assertSplitEquals(t, "comment with no newline should not fail",
"select null // comment with no newline",
[]string{"select null // comment with no newline"})

assertSplitEquals(t, "semicolon followed by comment at end should be single statement",
"select null; // trailing comment",
[]string{"select null; // trailing comment"})

assertSplitEquals(t, "coalesce empty statements into previous but not nonempty statements",
`select 1; // comment
;
select 2;`,
[]string{"select 1; // comment\n\t\t;", "select 2;"})

}
2 changes: 1 addition & 1 deletion internal/impl/sql/output_sql_raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func (s *sqlRawOutput) WriteBatch(ctx context.Context, batch service.MessageBatc
}
}

if _, err := s.db.ExecContext(ctx, queryStr, args...); err != nil {
if err := execMultiWithContext(s.db, ctx, queryStr, args...); err != nil {
return err
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/impl/sql/processor_sql_raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,13 @@ func (s *sqlRawProcessor) ProcessBatch(ctx context.Context, batch service.Messag
}

if s.onlyExec {
if _, err := s.db.ExecContext(ctx, queryStr, args...); err != nil {
if err := execMultiWithContext(s.db, ctx, queryStr, args...); err != nil {
s.logger.Debugf("Failed to run query: %v", err)
msg.SetError(err)
continue
}
} else {
rows, err := s.db.QueryContext(ctx, queryStr, args...)
rows, err := queryMultiWithContext(s.db, ctx, queryStr, args...)
if err != nil {
s.logger.Debugf("Failed to run query: %v", err)
msg.SetError(err)
Expand Down

0 comments on commit 6136116

Please sign in to comment.