Skip to content

Commit

Permalink
Fix regex parser for parsing functions having SQL body with language …
Browse files Browse the repository at this point in the history
…sql (PG 15 feature) (#2201)

SQL body syntax: https://www.postgresql.org/docs/15/sql-createfunction.html#:~:text=a%20new%20session.-,sql_body,-The%20body%20of
With LANGUAGE SQL, the function body can be written either as
```
BEGIN ATOMIC;
....
END;
```
or
```
language sql
RETURN ...;
```
  • Loading branch information
priyanshi-yb authored Jan 22, 2025
1 parent d2aa0ad commit d8f1f49
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 8 deletions.
29 changes: 24 additions & 5 deletions yb-voyager/cmd/analyzeSchema.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ var (
parserIssueDetector = queryissue.NewParserIssueDetector()
multiRegex = regexp.MustCompile(`([a-zA-Z0-9_\.]+[,|;])`)
dollarQuoteRegex = regexp.MustCompile(`(\$.*\$)`)
sqlBodyBeginRegex = re("BEGIN", "ATOMIC")
//TODO: optional but replace every possible space or new line char with [\s\n]+ in all regexs
viewWithCheckRegex = re("VIEW", capture(ident), anything, "WITH", opt(commonClause), "CHECK", "OPTION")
rangeRegex = re("PRECEDING", "and", anything, ":float")
Expand Down Expand Up @@ -911,7 +912,6 @@ sqlParsingLoop:

stmt += currLine + " "
formattedStmt += currLine + "\n"

// Assuming that both the dollar quote strings will not be in same line
switch dollarQuoteFlag {
case CODE_BLOCK_NOT_STARTED:
Expand All @@ -920,14 +920,30 @@ sqlParsingLoop:
} else if matches := dollarQuoteRegex.FindStringSubmatch(currLine); matches != nil {
dollarQuoteFlag = 1 //denotes start of the code/body part
codeBlockDelimiter = matches[0]
} else if matches := sqlBodyBeginRegex.FindStringSubmatch(currLine); matches != nil {
dollarQuoteFlag = 1 //denotes start of the sql body part https://www.postgresql.org/docs/15/sql-createfunction.html#:~:text=a%20new%20session.-,sql_body,-The%20body%20of
codeBlockDelimiter = "END" //SQL body to determine the end of BEGIN ATOMIC ... END; sql body
}
case CODE_BLOCK_STARTED:
if strings.Contains(currLine, codeBlockDelimiter) {
dollarQuoteFlag = 2 //denotes end of code/body part
if isEndOfSqlStmt(currLine) {
break sqlParsingLoop
switch codeBlockDelimiter {
case "END":
if strings.Contains(currLine, codeBlockDelimiter) ||
strings.Contains(currLine, strings.ToLower(codeBlockDelimiter)) {
//TODO: anyways we should be using pg-parser: but for now for the END sql body delimiter checking the UPPER and LOWER both
dollarQuoteFlag = 2 //denotes end of code/body part
if isEndOfSqlStmt(currLine) {
break sqlParsingLoop
}
}
default:
if strings.Contains(currLine, codeBlockDelimiter) {
dollarQuoteFlag = 2 //denotes end of code/body part
if isEndOfSqlStmt(currLine) {
break sqlParsingLoop
}
}
}

case CODE_BLOCK_COMPLETED:
if isEndOfSqlStmt(currLine) {
break sqlParsingLoop
Expand Down Expand Up @@ -971,6 +987,9 @@ func isEndOfSqlStmt(line string) bool {
line = line[0:cmtStartIdx] // ignore comment
line = strings.TrimRight(line, " ")
}
if len(line) == 0 {
return false
}
return line[len(line)-1] == ';'
}

Expand Down
191 changes: 188 additions & 3 deletions yb-voyager/cmd/analyzeSchema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,12 @@ CREATE TABLE another_table (
t.Errorf("Error creating file for the objType %s: %v", objType, err)
}


defer os.Remove(sqlFile.Name())

sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType)
// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr))
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
Expand Down Expand Up @@ -140,9 +139,195 @@ $$ LANGUAGE plpgsql;`,

// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr))
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName)
assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt)
assert.Equal(t, expectedSqlInfo.formattedStmt, sqlInfoArr[i].formattedStmt)
}

}

func TestFunctionSQLFile(t *testing.T) {
functionFileContent := `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
BEGIN ATOMIC
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
END;
CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$
DECLARE
temp_salary employees.salary%TYPE;
BEGIN
CREATE TEMP TABLE temp_high_earners AS
SELECT * FROM employees WHERE salary > threshold;
FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP
RAISE NOTICE 'High earner salary: %', temp_salary;
END LOOP;
END;
$$ LANGUAGE plpgsql;
CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;
CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE
BEGIN ATOMIC; SELECT $1 + $2; END;
CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
begin atomic
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
end;
CREATE FUNCTION public.asterisks1(n integer) RETURNS text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
RETURN repeat('*'::text, n);
CREATE FUNCTION add(integer, integer) RETURNS integer
AS 'select test;'
LANGUAGE SQL
IMMUTABLE
RETURNS NULL ON NULL INPUT;
CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
BEGIN
RETURN i + 1;
END;
$$ LANGUAGE plpgsql;
CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record
LANGUAGE sql
AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$;
CREATE FUNCTION check_password(uname TEXT, pass TEXT)
RETURNS BOOLEAN AS $$
DECLARE passed BOOLEAN;
BEGIN
SELECT (pwd = $2) INTO passed
FROM pwds
WHERE username = $1;
RETURN passed;
END;
$$ LANGUAGE plpgsql
SECURITY DEFINER
-- Set a secure search_path: trusted schema(s), then 'pg_temp'.
SET search_path = admin, pg_temp;`

expectedSqlInfoArr := []sqlInfo{
sqlInfo{
objName: "public.asterisks",
stmt: "CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE BEGIN ATOMIC SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); END; ",
formattedStmt: `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
BEGIN ATOMIC
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
END;`,
},
sqlInfo{
objName: "copy_high_earners",
stmt: "CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$ DECLARE temp_salary employees.salary%TYPE; BEGIN CREATE TEMP TABLE temp_high_earners AS SELECT * FROM employees WHERE salary > threshold; FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP RAISE NOTICE 'High earner salary: %', temp_salary; END LOOP; END; $$ LANGUAGE plpgsql; ",
formattedStmt: `CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$
DECLARE
temp_salary employees.salary%TYPE;
BEGIN
CREATE TEMP TABLE temp_high_earners AS
SELECT * FROM employees WHERE salary > threshold;
FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP
RAISE NOTICE 'High earner salary: %', temp_salary;
END LOOP;
END;
$$ LANGUAGE plpgsql;`,
},
sqlInfo{
objName: "add",
stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ",
formattedStmt: `CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;`,
},
sqlInfo{
objName: "add",
stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ",
formattedStmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE\nBEGIN ATOMIC; SELECT $1 + $2; END;",
},
sqlInfo{
objName: "public.case_sensitive_test",
stmt: "CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE begin atomic SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); end; ",
formattedStmt: `CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
begin atomic
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
end;`,
},
sqlInfo{
objName: "public.asterisks1",
stmt: "CREATE FUNCTION public.asterisks1(n integer) RETURNS text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE RETURN repeat('*'::text, n); ",
formattedStmt: `CREATE FUNCTION public.asterisks1(n integer) RETURNS text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
RETURN repeat('*'::text, n);`,
},
sqlInfo{
objName: "add",
stmt: "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select test;' LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT; ",
formattedStmt: `CREATE FUNCTION add(integer, integer) RETURNS integer
AS 'select test;'
LANGUAGE SQL
IMMUTABLE
RETURNS NULL ON NULL INPUT;`,
},
sqlInfo{
objName: "increment",
stmt: "CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$ BEGIN RETURN i + 1; END; $$ LANGUAGE plpgsql; ",
formattedStmt: `CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
BEGIN
RETURN i + 1;
END;
$$ LANGUAGE plpgsql;`,
},
sqlInfo{
objName: "public.dup",
stmt: "CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record LANGUAGE sql AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$; ",
formattedStmt: `CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record
LANGUAGE sql
AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$;`,
},
sqlInfo{
objName: "check_password",
stmt: "CREATE FUNCTION check_password(uname TEXT, pass TEXT) RETURNS BOOLEAN AS $$ DECLARE passed BOOLEAN; BEGIN SELECT (pwd = $2) INTO passed FROM pwds WHERE username = $1; RETURN passed; END; $$ LANGUAGE plpgsql SECURITY DEFINER -- Set a secure search_path: trusted schema(s), then 'pg_temp'. SET search_path = admin, pg_temp; ",
formattedStmt: `CREATE FUNCTION check_password(uname TEXT, pass TEXT)
RETURNS BOOLEAN AS $$
DECLARE passed BOOLEAN;
BEGIN
SELECT (pwd = $2) INTO passed
FROM pwds
WHERE username = $1;
RETURN passed;
END;
$$ LANGUAGE plpgsql
SECURITY DEFINER
-- Set a secure search_path: trusted schema(s), then 'pg_temp'.
SET search_path = admin, pg_temp;`,
},
}
objType := "FUNCTION"
sqlFile, err := setupFile(objType, functionFileContent)
if err != nil {
t.Errorf("Error creating file for the objType %s: %v", objType, err)
}

defer os.Remove(sqlFile.Name())

sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType)

// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName)
assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt)
Expand Down

0 comments on commit d8f1f49

Please sign in to comment.