diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteViewCommands.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteViewCommands.scala
index 151622191fad..e48357971211 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteViewCommands.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteViewCommands.scala
@@ -33,11 +33,13 @@ import org.apache.spark.sql.catalyst.plans.logical.views.DropIcebergView
 import org.apache.spark.sql.catalyst.plans.logical.views.ResolvedV2View
 import org.apache.spark.sql.catalyst.plans.logical.views.ShowIcebergViews
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_FUNCTION
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.connector.catalog.CatalogPlugin
 import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.LookupCatalog
 import org.apache.spark.sql.connector.catalog.ViewCatalog
+import scala.collection.mutable
 
 /**
  * ResolveSessionCatalog exits early for some v2 View commands,
@@ -45,6 +47,8 @@ import org.apache.spark.sql.connector.catalog.ViewCatalog
  */
 case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] with LookupCatalog {
 
+  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
   protected lazy val catalogManager: CatalogManager = spark.sessionState.catalogManager
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
@@ -85,6 +89,13 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
     catalogManager.v1SessionCatalog.isTempView(nameParts)
   }
 
+  private def isTempFunction(nameParts: Seq[String]): Boolean = {
+    if (nameParts.size > 1) {
+      return false
+    }
+    catalogManager.v1SessionCatalog.isTemporaryFunction(nameParts.asFunctionIdentifier)
+  }
+
   object ResolvedIdent {
     def unapply(unresolved: UnresolvedIdentifier): Option[ResolvedIdentifier] = unresolved match {
       case UnresolvedIdentifier(nameParts, true) if isTempView(nameParts) =>
@@ -104,18 +115,20 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
   private def verifyTemporaryObjectsDontExist(
       name: Identifier,
       child: LogicalPlan): Unit = {
-    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-
     val tempViews = collectTemporaryViews(child)
-    tempViews.foreach { nameParts =>
-      throw new AnalysisException(String.format("Cannot create the persistent object %s" +
-        " of the type VIEW because it references to the temporary object %s of" +
-        " the type VIEW. Please make the temporary object %s" +
-        " persistent, or make the persistent object %s temporary",
-        name.name(), nameParts.quoted, nameParts.quoted, name.name()))
-    };
-
-    // TODO: check for temp function names
+    if (tempViews.nonEmpty) {
+      throw invalidRefToTempObject(name, tempViews.map(v => v.quoted).mkString("[", ", ", "]"), "view")
+    }
+
+    val tempFunctions = collectTemporaryFunctions(child)
+    if (tempFunctions.nonEmpty) {
+      throw invalidRefToTempObject(name, tempFunctions.mkString("[", ", ", "]"), "function")
+    }
+  }
+
+  private def invalidRefToTempObject(name: Identifier, tempObjectNames: String, tempObjectType: String) = {
+    new AnalysisException(String.format("Cannot create view %s that references temporary %s: %s",
+      name, tempObjectType, tempObjectNames))
   }
 
   /**
@@ -149,4 +162,20 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
       None
     }
   }
+
+  /**
+   * Collect the names of all temporary functions.
+   */
+  private def collectTemporaryFunctions(child: LogicalPlan): Seq[String] = {
+    val tempFunctions = new mutable.HashSet[String]()
+    child.resolveExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_FUNCTION)) {
+      case f @ UnresolvedFunction(nameParts, _, _, _, _) if isTempFunction(nameParts) =>
+        tempFunctions += nameParts.head
+        f
+      case e: SubqueryExpression =>
+        tempFunctions ++= collectTemporaryFunctions(e.plan)
+        e
+    }
+    tempFunctions.toSeq
+  }
 }
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestViews.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestViews.java
index ea9ccc9133fe..26bcb03ba101 100644
--- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestViews.java
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestViews.java
@@ -477,6 +477,39 @@ public void readFromViewReferencingGlobalTempView() throws NoSuchTableException
         .hasMessageContaining("cannot be found");
   }
 
+  @Test
+  public void readFromViewReferencingTempFunction() throws NoSuchTableException {
+    insertRows(10);
+    String viewName = viewName("viewReferencingTempFunction");
+    String functionName = "test_avg";
+    String sql = String.format("SELECT %s(id) FROM %s", functionName, tableName);
+    sql(
+        "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
+        functionName);
+
+    ViewCatalog viewCatalog = viewCatalog();
+    Schema schema = tableCatalog().loadTable(TableIdentifier.of(NAMESPACE, tableName)).schema();
+
+    // it wouldn't be possible to reference a TEMP FUNCTION if the view had been created via SQL,
+    // but this can't be prevented when using the API directly
+    viewCatalog
+        .buildView(TableIdentifier.of(NAMESPACE, viewName))
+        .withQuery("spark", sql)
+        .withDefaultNamespace(NAMESPACE)
+        .withDefaultCatalog(catalogName)
+        .withSchema(schema)
+        .create();
+
+    assertThat(sql(sql)).hasSize(1).containsExactly(row(5.5));
+
+    // reading from a view that references a TEMP FUNCTION shouldn't be possible
+    assertThatThrownBy(() -> sql("SELECT * FROM %s", viewName))
+        .isInstanceOf(AnalysisException.class)
+        .hasMessageContaining("The function")
+        .hasMessageContaining(functionName)
+        .hasMessageContaining("cannot be found");
+  }
+
   @Test
   public void readFromViewWithCTE() throws NoSuchTableException {
     insertRows(10);
@@ -947,9 +980,9 @@ public void createViewReferencingTempView() throws NoSuchTableException {
     assertThatThrownBy(
            () -> sql("CREATE VIEW %s AS SELECT id FROM %s", viewReferencingTempView, tempView))
         .isInstanceOf(AnalysisException.class)
-        .hasMessageContaining("Cannot create the persistent object")
-        .hasMessageContaining(viewReferencingTempView)
-        .hasMessageContaining("of the type VIEW because it references to the temporary object")
+        .hasMessageContaining(
+            String.format("Cannot create view %s.%s", NAMESPACE, viewReferencingTempView))
+        .hasMessageContaining("that references temporary view:")
         .hasMessageContaining(tempView);
   }
 
@@ -970,10 +1003,59 @@ public void createViewReferencingGlobalTempView() throws NoSuchTableException {
                     "CREATE VIEW %s AS SELECT id FROM global_temp.%s",
                     viewReferencingTempView, globalTempView))
         .isInstanceOf(AnalysisException.class)
-        .hasMessageContaining("Cannot create the persistent object")
-        .hasMessageContaining(viewReferencingTempView)
-        .hasMessageContaining("of the type VIEW because it references to the temporary object")
-        .hasMessageContaining(globalTempView);
+        .hasMessageContaining(
+            String.format("Cannot create view %s.%s", NAMESPACE, viewReferencingTempView))
+        .hasMessageContaining("that references temporary view:")
+        .hasMessageContaining(String.format("%s.%s", "global_temp", globalTempView));
+  }
+
+  @Test
+  public void createViewReferencingTempFunction() {
+    String viewName = viewName("viewReferencingTemporaryFunction");
+    String functionName = "test_avg_func";
+
+    sql(
+        "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
+        functionName);
+
+    // creating a view that references a TEMP FUNCTION shouldn't be possible
+    assertThatThrownBy(
+            () -> sql("CREATE VIEW %s AS SELECT %s(id) FROM %s", viewName, functionName, tableName))
+        .isInstanceOf(AnalysisException.class)
+        .hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
+        .hasMessageContaining("that references temporary function:")
+        .hasMessageContaining(functionName);
+  }
+
+  @Test
+  public void createViewReferencingQualifiedTempFunction() {
+    String viewName = viewName("viewReferencingTemporaryFunction");
+    String functionName = "test_avg_func_qualified";
+
+    sql(
+        "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
+        functionName);
+
+    // TEMP Function can't be referenced using catalog.schema.name
+    assertThatThrownBy(
+            () ->
+                sql(
+                    "CREATE VIEW %s AS SELECT %s.%s.%s(id) FROM %s",
+                    viewName, catalogName, NAMESPACE, functionName, tableName))
+        .isInstanceOf(AnalysisException.class)
+        .hasMessageContaining("Cannot resolve function")
+        .hasMessageContaining(
+            String.format("`%s`.`%s`.`%s`", catalogName, NAMESPACE, functionName));
+
+    // TEMP Function can't be referenced using schema.name
+    assertThatThrownBy(
+            () ->
+                sql(
+                    "CREATE VIEW %s AS SELECT %s.%s(id) FROM %s",
+                    viewName, NAMESPACE, functionName, tableName))
+        .isInstanceOf(AnalysisException.class)
+        .hasMessageContaining("Cannot resolve function")
+        .hasMessageContaining(String.format("`%s`.`%s`", NAMESPACE, functionName));
   }
 
   @Test
@@ -1118,12 +1200,32 @@ public void createViewWithCTEReferencingTempView() {
     assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
         .isInstanceOf(AnalysisException.class)
-        .hasMessageContaining("Cannot create the persistent object")
-        .hasMessageContaining(viewName)
-        .hasMessageContaining("of the type VIEW because it references to the temporary object")
+        .hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
+        .hasMessageContaining("that references temporary view:")
         .hasMessageContaining(tempViewInCTE);
   }
 
+  @Test
+  public void createViewWithCTEReferencingTempFunction() {
+    String viewName = "viewWithCTEReferencingTempFunction";
+    String functionName = "avg_function_in_cte";
+    String sql =
+        String.format(
+            "WITH avg_data AS (SELECT %s(id) as avg FROM %s) "
+                + "SELECT avg, count(1) AS count FROM avg_data GROUP BY avg",
+            functionName, tableName);
+
+    sql(
+        "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
+        functionName);
+
+    assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
+        .isInstanceOf(AnalysisException.class)
+        .hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
+        .hasMessageContaining("that references temporary function:")
+        .hasMessageContaining(functionName);
+  }
+
   @Test
   public void createViewWithNonExistingQueryColumn() {
     assertThatThrownBy(
             () ->
@@ -1147,9 +1249,9 @@ public void createViewWithSubqueryExpressionUsingTempView() {
 
     assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
         .isInstanceOf(AnalysisException.class)
-        .hasMessageContaining(String.format("Cannot create the persistent object %s", viewName))
-        .hasMessageContaining(
-            String.format("because it references to the temporary object %s", tempView));
+        .hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
+        .hasMessageContaining("that references temporary view:")
+        .hasMessageContaining(tempView);
   }
 
   @Test
@@ -1167,10 +1269,29 @@ public void createViewWithSubqueryExpressionUsingGlobalTempView() {
 
     assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
         .isInstanceOf(AnalysisException.class)
-        .hasMessageContaining(String.format("Cannot create the persistent object %s", viewName))
-        .hasMessageContaining(
-            String.format(
-                "because it references to the temporary object global_temp.%s", globalTempView));
+        .hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
+        .hasMessageContaining("that references temporary view:")
+        .hasMessageContaining(String.format("%s.%s", "global_temp", globalTempView));
+  }
+
+  @Test
+  public void createViewWithSubqueryExpressionUsingTempFunction() {
+    String viewName = viewName("viewWithSubqueryExpression");
+    String functionName = "avg_function_in_subquery";
+    String sql =
+        String.format(
+            "SELECT * FROM %s WHERE id < (SELECT %s(id) FROM %s)",
+            tableName, functionName, tableName);
+
+    sql(
+        "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
+        functionName);
+
+    assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
+        .isInstanceOf(AnalysisException.class)
+        .hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
+        .hasMessageContaining("that references temporary function:")
+        .hasMessageContaining(functionName);
   }
 
   @Test