diff --git a/yb-voyager/cmd/analyzeSchema.go b/yb-voyager/cmd/analyzeSchema.go index 1b52a9611..4ccf8f99d 100644 --- a/yb-voyager/cmd/analyzeSchema.go +++ b/yb-voyager/cmd/analyzeSchema.go @@ -134,6 +134,7 @@ var ( parserIssueDetector = queryissue.NewParserIssueDetector() multiRegex = regexp.MustCompile(`([a-zA-Z0-9_\.]+[,|;])`) dollarQuoteRegex = regexp.MustCompile(`(\$.*\$)`) + sqlBodyBeginRegex = re("BEGIN", "ATOMIC") //TODO: optional but replace every possible space or new line char with [\s\n]+ in all regexs viewWithCheckRegex = re("VIEW", capture(ident), anything, "WITH", opt(commonClause), "CHECK", "OPTION") rangeRegex = re("PRECEDING", "and", anything, ":float") @@ -911,7 +912,6 @@ sqlParsingLoop: stmt += currLine + " " formattedStmt += currLine + "\n" - // Assuming that both the dollar quote strings will not be in same line switch dollarQuoteFlag { case CODE_BLOCK_NOT_STARTED: @@ -920,14 +920,30 @@ sqlParsingLoop: } else if matches := dollarQuoteRegex.FindStringSubmatch(currLine); matches != nil { dollarQuoteFlag = 1 //denotes start of the code/body part codeBlockDelimiter = matches[0] + } else if matches := sqlBodyBeginRegex.FindStringSubmatch(currLine); matches != nil { + dollarQuoteFlag = 1 //denotes start of the sql body part https://www.postgresql.org/docs/15/sql-createfunction.html#:~:text=a%20new%20session.-,sql_body,-The%20body%20of + codeBlockDelimiter = "END" //SQL body to determine the end of BEGIN ATOMIC ... END; sql body } case CODE_BLOCK_STARTED: - if strings.Contains(currLine, codeBlockDelimiter) { - dollarQuoteFlag = 2 //denotes end of code/body part - if isEndOfSqlStmt(currLine) { - break sqlParsingLoop + switch codeBlockDelimiter { + case "END": + if strings.Contains(currLine, codeBlockDelimiter) || + strings.Contains(currLine, strings.ToLower(codeBlockDelimiter)) { + //TODO: anyways we should be using pg-parser: but for now for the END sql body delimiter checking the UPPER and LOWER both + dollarQuoteFlag = 2 //denotes end of code/body part + if isEndOfSqlStmt(currLine) { + break sqlParsingLoop + } + } + default: + if strings.Contains(currLine, codeBlockDelimiter) { + dollarQuoteFlag = 2 //denotes end of code/body part + if isEndOfSqlStmt(currLine) { + break sqlParsingLoop + } } } + case CODE_BLOCK_COMPLETED: if isEndOfSqlStmt(currLine) { break sqlParsingLoop @@ -971,6 +987,9 @@ func isEndOfSqlStmt(line string) bool { line = line[0:cmtStartIdx] // ignore comment line = strings.TrimRight(line, " ") } + if len(line) == 0 { + return false + } return line[len(line)-1] == ';' } diff --git a/yb-voyager/cmd/analyzeSchema_test.go b/yb-voyager/cmd/analyzeSchema_test.go index d768b94b0..97bd80a3b 100644 --- a/yb-voyager/cmd/analyzeSchema_test.go +++ b/yb-voyager/cmd/analyzeSchema_test.go @@ -76,13 +76,12 @@ CREATE TABLE another_table ( t.Errorf("Error creating file for the objType %s: %v", objType, err) } - defer os.Remove(sqlFile.Name()) sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType) // Validate the number of SQL statements found if len(sqlInfoArr) != len(expectedSqlInfoArr) { - t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr)) + t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr)) } for i, expectedSqlInfo := range expectedSqlInfoArr { @@ -140,9 +139,195 @@ $$ LANGUAGE plpgsql;`, // Validate the number of SQL statements found if len(sqlInfoArr) != len(expectedSqlInfoArr) { - t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr)) + t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr)) + } + + for i, expectedSqlInfo := range expectedSqlInfoArr { + assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName) + assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt) + assert.Equal(t, expectedSqlInfo.formattedStmt, sqlInfoArr[i].formattedStmt) + } + +} + +func TestFunctionSQLFile(t *testing.T) { + functionFileContent := `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text + LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE + BEGIN ATOMIC + SELECT repeat('*'::text, g.g) AS repeat + FROM generate_series(1, asterisks.n) g(g); +END; + +CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$ +DECLARE + temp_salary employees.salary%TYPE; +BEGIN + CREATE TEMP TABLE temp_high_earners AS + SELECT * FROM employees WHERE salary > threshold; + FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP + RAISE NOTICE 'High earner salary: %', temp_salary; + END LOOP; +END; +$$ LANGUAGE plpgsql; + +CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; + +CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE +BEGIN ATOMIC; SELECT $1 + $2; END; + +CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text + LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE + begin atomic + SELECT repeat('*'::text, g.g) AS repeat + FROM generate_series(1, asterisks.n) g(g); +end; + +CREATE FUNCTION public.asterisks1(n integer) RETURNS text + LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE + RETURN repeat('*'::text, n); + +CREATE FUNCTION add(integer, integer) RETURNS integer + AS 'select test;' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT; + +CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; +$$ LANGUAGE plpgsql; + +CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record + LANGUAGE sql + AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$; + +CREATE FUNCTION check_password(uname TEXT, pass TEXT) +RETURNS BOOLEAN AS $$ +DECLARE passed BOOLEAN; +BEGIN + SELECT (pwd = $2) INTO passed + FROM pwds + WHERE username = $1; + + RETURN passed; +END; +$$ LANGUAGE plpgsql + SECURITY DEFINER + -- Set a secure search_path: trusted schema(s), then 'pg_temp'. + SET search_path = admin, pg_temp;` + + expectedSqlInfoArr := []sqlInfo{ + sqlInfo{ + objName: "public.asterisks", + stmt: "CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE BEGIN ATOMIC SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); END; ", + formattedStmt: `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text + LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE + BEGIN ATOMIC + SELECT repeat('*'::text, g.g) AS repeat + FROM generate_series(1, asterisks.n) g(g); +END;`, + }, + sqlInfo{ + objName: "copy_high_earners", + stmt: "CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$ DECLARE temp_salary employees.salary%TYPE; BEGIN CREATE TEMP TABLE temp_high_earners AS SELECT * FROM employees WHERE salary > threshold; FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP RAISE NOTICE 'High earner salary: %', temp_salary; END LOOP; END; $$ LANGUAGE plpgsql; ", + formattedStmt: `CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$ +DECLARE + temp_salary employees.salary%TYPE; +BEGIN + CREATE TEMP TABLE temp_high_earners AS + SELECT * FROM employees WHERE salary > threshold; + FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP + RAISE NOTICE 'High earner salary: %', temp_salary; + END LOOP; +END; +$$ LANGUAGE plpgsql;`, + }, + sqlInfo{ + objName: "add", + stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ", + formattedStmt: `CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;`, + }, + sqlInfo{ + objName: "add", + stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ", + formattedStmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE\nBEGIN ATOMIC; SELECT $1 + $2; END;", + }, + sqlInfo{ + objName: "public.case_sensitive_test", + stmt: "CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE begin atomic SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); end; ", + formattedStmt: `CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text + LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE + begin atomic + SELECT repeat('*'::text, g.g) AS repeat + FROM generate_series(1, asterisks.n) g(g); +end;`, + }, + sqlInfo{ + objName: "public.asterisks1", + stmt: "CREATE FUNCTION public.asterisks1(n integer) RETURNS text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE RETURN repeat('*'::text, n); ", + formattedStmt: `CREATE FUNCTION public.asterisks1(n integer) RETURNS text + LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE + RETURN repeat('*'::text, n);`, + }, + sqlInfo{ + objName: "add", + stmt: "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select test;' LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT; ", + formattedStmt: `CREATE FUNCTION add(integer, integer) RETURNS integer + AS 'select test;' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT;`, + }, + sqlInfo{ + objName: "increment", + stmt: "CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$ BEGIN RETURN i + 1; END; $$ LANGUAGE plpgsql; ", + formattedStmt: `CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; +$$ LANGUAGE plpgsql;`, + }, + sqlInfo{ + objName: "public.dup", + stmt: "CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record LANGUAGE sql AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$; ", + formattedStmt: `CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record + LANGUAGE sql + AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$;`, + }, + sqlInfo{ + objName: "check_password", + stmt: "CREATE FUNCTION check_password(uname TEXT, pass TEXT) RETURNS BOOLEAN AS $$ DECLARE passed BOOLEAN; BEGIN SELECT (pwd = $2) INTO passed FROM pwds WHERE username = $1; RETURN passed; END; $$ LANGUAGE plpgsql SECURITY DEFINER -- Set a secure search_path: trusted schema(s), then 'pg_temp'. SET search_path = admin, pg_temp; ", + formattedStmt: `CREATE FUNCTION check_password(uname TEXT, pass TEXT) +RETURNS BOOLEAN AS $$ +DECLARE passed BOOLEAN; +BEGIN + SELECT (pwd = $2) INTO passed + FROM pwds + WHERE username = $1; + RETURN passed; +END; +$$ LANGUAGE plpgsql + SECURITY DEFINER + -- Set a secure search_path: trusted schema(s), then 'pg_temp'. + SET search_path = admin, pg_temp;`, + }, + } + objType := "FUNCTION" + sqlFile, err := setupFile(objType, functionFileContent) + if err != nil { + t.Errorf("Error creating file for the objType %s: %v", objType, err) } + defer os.Remove(sqlFile.Name()) + + sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType) + + // Validate the number of SQL statements found + if len(sqlInfoArr) != len(expectedSqlInfoArr) { + t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr)) + } + for i, expectedSqlInfo := range expectedSqlInfoArr { assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName) assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt)