diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java index 35b65a7b530b..21f14e5d7a38 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java @@ -48,6 +48,7 @@ import java.util.Locale; import java.util.Map; +import java.util.Set; import java.util.function.Supplier; /** The {@link Procedure}s including all the stored procedures. */ @@ -62,6 +63,10 @@ public static ProcedureBuilder newBuilder(String name) { return builderSupplier != null ? builderSupplier.get() : null; } + public static Set names() { + return BUILDERS.keySet(); + } + private static Map> initProcedureBuilders() { ImmutableMap.Builder> procedureBuilders = ImmutableMap.builder(); diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala index c1d61e973834..557b0735c74d 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.parser.extensions +import org.apache.paimon.spark.SparkProcedures + import org.antlr.v4.runtime._ import org.antlr.v4.runtime.atn.PredictionMode import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException} @@ -34,6 +36,8 @@ import org.apache.spark.sql.types.{DataType, StructType} import java.util.Locale +import scala.collection.JavaConverters._ + /* This file is based on source code from the Iceberg Project (http://iceberg.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ @@ -100,8 +104,15 @@ abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterf .replaceAll("--.*?\\n", " ") .replaceAll("\\s+", " ") .replaceAll("/\\*.*?\\*/", " ") + .replaceAll("`", "") .trim() - normalized.startsWith("call") || isTagRefDdl(normalized) + isPaimonProcedure(normalized) || isTagRefDdl(normalized) + } + + // All builtin paimon procedures are under the 'sys' namespace + private def isPaimonProcedure(normalized: String): Boolean = { + normalized.startsWith("call") && + SparkProcedures.names().asScala.map("sys." + _).exists(normalized.contains) } private def isTagRefDdl(normalized: String): Boolean = { diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/extensions/CallStatementParserTest.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/extensions/CallStatementParserTest.java index 61e06016cbd3..e4e571e96bc9 100644 --- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/extensions/CallStatementParserTest.java +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/extensions/CallStatementParserTest.java @@ -79,14 +79,37 @@ public void stopSparkSession() { } } + @Test + public void testDelegateUnsupportedProcedure() { + assertThatThrownBy(() -> parser.parsePlan("CALL cat.d.t()")) + .isInstanceOf(ParseException.class) + .satisfies( + exception -> { + ParseException parseException = (ParseException) exception; + assertThat(parseException.getErrorClass()) + .isEqualTo("PARSE_SYNTAX_ERROR"); + assertThat(parseException.getMessageParameters().get("error")) + .isEqualTo("'CALL'"); + }); + } + + @Test + public void testCallWithBackticks() throws ParseException { + PaimonCallStatement call = + (PaimonCallStatement) parser.parsePlan("CALL cat.`sys`.`rollback`()"); + assertThat(JavaConverters.seqAsJavaList(call.name())) + .isEqualTo(Arrays.asList("cat", "sys", "rollback")); + assertThat(call.args().size()).isEqualTo(0); + } + @Test public void testCallWithNamedArguments() throws ParseException { PaimonCallStatement callStatement = (PaimonCallStatement) parser.parsePlan( - "CALL catalog.system.named_args_func(arg1 => 1, arg2 => 'test', arg3 => true)"); + "CALL catalog.sys.rollback(arg1 => 1, arg2 => 'test', arg3 => true)"); assertThat(JavaConverters.seqAsJavaList(callStatement.name())) - .isEqualTo(Arrays.asList("catalog", "system", "named_args_func")); + .isEqualTo(Arrays.asList("catalog", "sys", "rollback")); assertThat(callStatement.args().size()).isEqualTo(3); assertArgument(callStatement, 0, "arg1", 1, DataTypes.IntegerType); assertArgument(callStatement, 1, "arg2", "test", DataTypes.StringType); @@ -98,11 +121,11 @@ public void testCallWithPositionalArguments() throws ParseException { PaimonCallStatement callStatement = (PaimonCallStatement) parser.parsePlan( - "CALL catalog.system.positional_args_func(1, '${spark.sql.extensions}', 2L, true, 3.0D, 4" + "CALL catalog.sys.rollback(1, '${spark.sql.extensions}', 2L, true, 3.0D, 4" + ".0e1,500e-1BD, " + "TIMESTAMP '2017-02-03T10:37:30.00Z')"); assertThat(JavaConverters.seqAsJavaList(callStatement.name())) - .isEqualTo(Arrays.asList("catalog", "system", "positional_args_func")); + .isEqualTo(Arrays.asList("catalog", "sys", "rollback")); assertThat(callStatement.args().size()).isEqualTo(8); assertArgument(callStatement, 0, 1, DataTypes.IntegerType); assertArgument( @@ -127,9 +150,9 @@ public void testCallWithPositionalArguments() throws ParseException { public void testCallWithMixedArguments() throws ParseException { PaimonCallStatement callStatement = (PaimonCallStatement) - parser.parsePlan("CALL catalog.system.mixed_function(arg1 => 1, 'test')"); + parser.parsePlan("CALL catalog.sys.rollback(arg1 => 1, 'test')"); assertThat(JavaConverters.seqAsJavaList(callStatement.name())) - .isEqualTo(Arrays.asList("catalog", "system", "mixed_function")); + .isEqualTo(Arrays.asList("catalog", "sys", "rollback")); assertThat(callStatement.args().size()).isEqualTo(2); assertArgument(callStatement, 0, "arg1", 1, DataTypes.IntegerType); assertArgument(callStatement, 1, "test", DataTypes.StringType); @@ -137,9 +160,9 @@ public void testCallWithMixedArguments() throws ParseException { @Test public void testCallWithParseException() { - assertThatThrownBy(() -> parser.parsePlan("CALL catalog.system func abc")) + assertThatThrownBy(() -> parser.parsePlan("CALL catalog.sys.rollback abc")) .isInstanceOf(PaimonParseException.class) - .hasMessageContaining("missing '(' at 'func'"); + .hasMessageContaining("missing '(' at 'abc'"); } private void assertArgument( diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/ProcedureTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/ProcedureTestBase.scala index f3cb7fa26665..a5f9f3ffa01b 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/ProcedureTestBase.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/ProcedureTestBase.scala @@ -19,8 +19,8 @@ package org.apache.paimon.spark.procedure import org.apache.paimon.spark.PaimonSparkTestBase -import org.apache.paimon.spark.analysis.NoSuchProcedureException +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.parser.extensions.PaimonParseException import org.assertj.core.api.Assertions.assertThatThrownBy @@ -32,7 +32,7 @@ abstract class ProcedureTestBase extends PaimonSparkTestBase { |""".stripMargin) assertThatThrownBy(() => spark.sql("CALL sys.unknown_procedure(table => 'test.T')")) - .isInstanceOf(classOf[NoSuchProcedureException]) + .isInstanceOf(classOf[ParseException]) } test(s"test parse exception") {