Skip to content

Commit

Permalink
add more cases, and handle comment case in IsEndOFSqlStmt
Browse files Browse the repository at this point in the history
  • Loading branch information
priyanshi-yb committed Jan 21, 2025
1 parent 80de2d6 commit a7abfab
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 6 deletions.
7 changes: 4 additions & 3 deletions yb-voyager/cmd/analyzeSchema.go
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,6 @@ sqlParsingLoop:

stmt += currLine + " "
formattedStmt += currLine + "\n"

// Assuming that both the dollar quote strings will not be in same line
switch dollarQuoteFlag {
case CODE_BLOCK_NOT_STARTED:
Expand All @@ -929,8 +928,7 @@ sqlParsingLoop:
case CODE_BLOCK_STARTED:
if strings.Contains(currLine, codeBlockDelimiter) ||
strings.Contains(currLine, strings.ToLower(codeBlockDelimiter)) {
//TODO: anyways we should be using pg-parser: for the END sql body delimiter checking the UPPER and LOWER both
//not using regex as there are some issues while doing that (not debugged that)
//TODO: anyways we should be using pg-parser: but for now for the END sql body delimiter checking the UPPER and LOWER both
dollarQuoteFlag = 2 //denotes end of code/body part
if isEndOfSqlStmt(currLine) {
break sqlParsingLoop
Expand Down Expand Up @@ -979,6 +977,9 @@ func isEndOfSqlStmt(line string) bool {
line = line[0:cmtStartIdx] // ignore comment
line = strings.TrimRight(line, " ")
}
if len(line) == 0 {
return false
}
return line[len(line)-1] == ';'
}

Expand Down
80 changes: 77 additions & 3 deletions yb-voyager/cmd/analyzeSchema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func TestFunctionSQLFile(t *testing.T) {
BEGIN ATOMIC
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
END;
END;
CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$
DECLARE
Expand All @@ -180,11 +180,42 @@ CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text
begin atomic
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
end;
end;
CREATE FUNCTION public.asterisks1(n integer) RETURNS text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
RETURN repeat('*'::text, n);`
RETURN repeat('*'::text, n);
CREATE FUNCTION add(integer, integer) RETURNS integer
AS 'select test;'
LANGUAGE SQL
IMMUTABLE
RETURNS NULL ON NULL INPUT;
CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
BEGIN
RETURN i + 1;
END;
$$ LANGUAGE plpgsql;
CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record
LANGUAGE sql
AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$;
CREATE FUNCTION check_password(uname TEXT, pass TEXT)
RETURNS BOOLEAN AS $$
DECLARE passed BOOLEAN;
BEGIN
SELECT (pwd = $2) INTO passed
FROM pwds
WHERE username = $1;
RETURN passed;
END;
$$ LANGUAGE plpgsql
SECURITY DEFINER
-- Set a secure search_path: trusted schema(s), then 'pg_temp'.
SET search_path = admin, pg_temp;`

expectedSqlInfoArr := []sqlInfo{
sqlInfo{
Expand Down Expand Up @@ -239,6 +270,48 @@ end;`,
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
RETURN repeat('*'::text, n);`,
},
sqlInfo{
objName: "add",
stmt: "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select test;' LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT; ",
formattedStmt: `CREATE FUNCTION add(integer, integer) RETURNS integer
AS 'select test;'
LANGUAGE SQL
IMMUTABLE
RETURNS NULL ON NULL INPUT;`,
},
sqlInfo{
objName: "increment",
stmt: "CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$ BEGIN RETURN i + 1; END; $$ LANGUAGE plpgsql; ",
formattedStmt: `CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
BEGIN
RETURN i + 1;
END;
$$ LANGUAGE plpgsql;`,
},
sqlInfo{
objName: "public.dup",
stmt: "CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record LANGUAGE sql AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$; ",
formattedStmt: `CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record
LANGUAGE sql
AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$;`,
},
sqlInfo{
objName: "check_password",
stmt: "CREATE FUNCTION check_password(uname TEXT, pass TEXT) RETURNS BOOLEAN AS $$ DECLARE passed BOOLEAN; BEGIN SELECT (pwd = $2) INTO passed FROM pwds WHERE username = $1; RETURN passed; END; $$ LANGUAGE plpgsql SECURITY DEFINER -- Set a secure search_path: trusted schema(s), then 'pg_temp'. SET search_path = admin, pg_temp; ",
formattedStmt: `CREATE FUNCTION check_password(uname TEXT, pass TEXT)
RETURNS BOOLEAN AS $$
DECLARE passed BOOLEAN;
BEGIN
SELECT (pwd = $2) INTO passed
FROM pwds
WHERE username = $1;
RETURN passed;
END;
$$ LANGUAGE plpgsql
SECURITY DEFINER
-- Set a secure search_path: trusted schema(s), then 'pg_temp'.
SET search_path = admin, pg_temp;`,
},
}
objType := "FUNCTION"
sqlFile, err := setupFile(objType, functionFileContent)
Expand All @@ -250,6 +323,7 @@ end;`,

sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType)

fmt.Printf("%v", sqlInfoArr)
// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
Expand Down

0 comments on commit a7abfab

Please sign in to comment.