Skip to content

Commit

Permalink
[spark] Paimon parser only resolve own supported procedures (#4662)
Browse files Browse the repository at this point in the history
  • Loading branch information
askwang authored Dec 9, 2024
1 parent 9191e2e commit bfa0c5c
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;

/** The {@link Procedure}s including all the stored procedures. */
Expand All @@ -62,6 +63,10 @@ public static ProcedureBuilder newBuilder(String name) {
return builderSupplier != null ? builderSupplier.get() : null;
}

/**
 * Returns the names of every built-in procedure registered in {@code BUILDERS}.
 *
 * <p>Used by the Spark SQL parser extension to decide whether a {@code CALL} statement
 * targets a Paimon procedure (see {@code isPaimonProcedure}). The map is built via
 * {@code ImmutableMap.builder()}, so the returned key set is effectively read-only.
 */
public static Set<String> names() {
return BUILDERS.keySet();
}

private static Map<String, Supplier<ProcedureBuilder>> initProcedureBuilders() {
ImmutableMap.Builder<String, Supplier<ProcedureBuilder>> procedureBuilders =
ImmutableMap.builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package org.apache.spark.sql.catalyst.parser.extensions

import org.apache.paimon.spark.SparkProcedures

import org.antlr.v4.runtime._
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
Expand All @@ -34,6 +36,8 @@ import org.apache.spark.sql.types.{DataType, StructType}

import java.util.Locale

import scala.collection.JavaConverters._

/* This file is based on source code from the Iceberg Project (http://iceberg.apache.org/), licensed by the Apache
* Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for
* additional information regarding copyright ownership. */
Expand Down Expand Up @@ -100,8 +104,15 @@ abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterf
.replaceAll("--.*?\\n", " ")
.replaceAll("\\s+", " ")
.replaceAll("/\\*.*?\\*/", " ")
.replaceAll("`", "")
.trim()
normalized.startsWith("call") || isTagRefDdl(normalized)
isPaimonProcedure(normalized) || isTagRefDdl(normalized)
}

/**
 * Whether the (lowercased, whitespace/comment/backtick-normalized) statement is a CALL
 * of a builtin Paimon procedure. All builtin paimon procedures live under the 'sys'
 * namespace, so we only claim statements that mention a known "sys.<name>" procedure;
 * anything else is left for the delegate Spark parser.
 */
private def isPaimonProcedure(normalized: String): Boolean = {
  if (!normalized.startsWith("call")) {
    false
  } else {
    val qualifiedNames = SparkProcedures.names().asScala.map(name => s"sys.$name")
    qualifiedNames.exists(qualified => normalized.contains(qualified))
  }
}

private def isTagRefDdl(normalized: String): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,37 @@ public void stopSparkSession() {
}
}

@Test
public void testDelegateUnsupportedProcedure() {
    // A CALL that does not target a known Paimon 'sys.' procedure is handed to the
    // delegate Spark parser, which rejects CALL syntax with Spark's own ParseException
    // (error class PARSE_SYNTAX_ERROR) rather than a Paimon parse error.
    assertThatThrownBy(() -> parser.parsePlan("CALL cat.d.t()"))
            .isInstanceOf(ParseException.class)
            .satisfies(
                    thrown -> {
                        ParseException pe = (ParseException) thrown;
                        assertThat(pe.getErrorClass()).isEqualTo("PARSE_SYNTAX_ERROR");
                        assertThat(pe.getMessageParameters().get("error")).isEqualTo("'CALL'");
                    });
}

@Test
public void testCallWithBackticks() throws ParseException {
    // Backtick-quoted identifiers are normalized away, so the call still resolves
    // to the builtin cat.sys.rollback procedure with no arguments.
    PaimonCallStatement statement =
            (PaimonCallStatement) parser.parsePlan("CALL cat.`sys`.`rollback`()");
    java.util.List<String> nameParts = JavaConverters.seqAsJavaList(statement.name());
    assertThat(nameParts).isEqualTo(Arrays.asList("cat", "sys", "rollback"));
    assertThat(statement.args().size()).isEqualTo(0);
}

@Test
public void testCallWithNamedArguments() throws ParseException {
PaimonCallStatement callStatement =
(PaimonCallStatement)
parser.parsePlan(
"CALL catalog.system.named_args_func(arg1 => 1, arg2 => 'test', arg3 => true)");
"CALL catalog.sys.rollback(arg1 => 1, arg2 => 'test', arg3 => true)");
assertThat(JavaConverters.seqAsJavaList(callStatement.name()))
.isEqualTo(Arrays.asList("catalog", "system", "named_args_func"));
.isEqualTo(Arrays.asList("catalog", "sys", "rollback"));
assertThat(callStatement.args().size()).isEqualTo(3);
assertArgument(callStatement, 0, "arg1", 1, DataTypes.IntegerType);
assertArgument(callStatement, 1, "arg2", "test", DataTypes.StringType);
Expand All @@ -98,11 +121,11 @@ public void testCallWithPositionalArguments() throws ParseException {
PaimonCallStatement callStatement =
(PaimonCallStatement)
parser.parsePlan(
"CALL catalog.system.positional_args_func(1, '${spark.sql.extensions}', 2L, true, 3.0D, 4"
"CALL catalog.sys.rollback(1, '${spark.sql.extensions}', 2L, true, 3.0D, 4"
+ ".0e1,500e-1BD, "
+ "TIMESTAMP '2017-02-03T10:37:30.00Z')");
assertThat(JavaConverters.seqAsJavaList(callStatement.name()))
.isEqualTo(Arrays.asList("catalog", "system", "positional_args_func"));
.isEqualTo(Arrays.asList("catalog", "sys", "rollback"));
assertThat(callStatement.args().size()).isEqualTo(8);
assertArgument(callStatement, 0, 1, DataTypes.IntegerType);
assertArgument(
Expand All @@ -127,19 +150,19 @@ public void testCallWithPositionalArguments() throws ParseException {
// Verifies a CALL mixing a named argument (arg1 => 1) and a positional argument
// ('test') parses into a PaimonCallStatement with both arguments captured in order.
// NOTE(review): this span is a raw diff hunk — each duplicated parsePlan/isEqualTo
// pair below is an old (removed) line immediately followed by its replacement; only
// the 'catalog.sys.rollback' variants are the current code.
public void testCallWithMixedArguments() throws ParseException {
PaimonCallStatement callStatement =
(PaimonCallStatement)
parser.parsePlan("CALL catalog.system.mixed_function(arg1 => 1, 'test')");
parser.parsePlan("CALL catalog.sys.rollback(arg1 => 1, 'test')");
assertThat(JavaConverters.seqAsJavaList(callStatement.name()))
.isEqualTo(Arrays.asList("catalog", "system", "mixed_function"));
.isEqualTo(Arrays.asList("catalog", "sys", "rollback"));
assertThat(callStatement.args().size()).isEqualTo(2);
assertArgument(callStatement, 0, "arg1", 1, DataTypes.IntegerType);
assertArgument(callStatement, 1, "test", DataTypes.StringType);
}

@Test
// Verifies malformed CALL syntax (missing parentheses) raises PaimonParseException
// from the Paimon extension parser, since the target is a recognized sys procedure.
// NOTE(review): this span is a raw diff hunk — each duplicated line pair below is an
// old (removed) line followed by its replacement; only the 'catalog.sys.rollback'
// variants are the current code.
public void testCallWithParseException() {
assertThatThrownBy(() -> parser.parsePlan("CALL catalog.system func abc"))
assertThatThrownBy(() -> parser.parsePlan("CALL catalog.sys.rollback abc"))
.isInstanceOf(PaimonParseException.class)
.hasMessageContaining("missing '(' at 'func'");
.hasMessageContaining("missing '(' at 'abc'");
}

private void assertArgument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
package org.apache.paimon.spark.procedure

import org.apache.paimon.spark.PaimonSparkTestBase
import org.apache.paimon.spark.analysis.NoSuchProcedureException

import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.parser.extensions.PaimonParseException
import org.assertj.core.api.Assertions.assertThatThrownBy

Expand All @@ -32,7 +32,7 @@ abstract class ProcedureTestBase extends PaimonSparkTestBase {
|""".stripMargin)

assertThatThrownBy(() => spark.sql("CALL sys.unknown_procedure(table => 'test.T')"))
.isInstanceOf(classOf[NoSuchProcedureException])
.isInstanceOf(classOf[ParseException])
}

test(s"test parse exception") {
Expand Down

0 comments on commit bfa0c5c

Please sign in to comment.