From 9bb4b28a52193aa446be90706df1465ca35460b5 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 26 Aug 2024 15:29:13 +0800 Subject: [PATCH] [GLUTEN-6951][CORE][CH] Move CustomerExpressionTransformer to CH backend (#6993) Closes #6951 --- .../clickhouse/CHListenerApi.scala | 9 + .../clickhouse/CHSparkPlanExecApi.scala | 22 +- .../CHHashAggregateExecTransformer.scala | 13 +- .../gluten/expression/CHExpressions.scala | 45 ++++ .../extension/ExpressionExtensionTrait.scala | 11 +- .../spark/sql/utils/ExpressionUtil.scala | 3 +- .../CustomerExpressionTransformer.scala | 0 ...seCustomerExpressionTransformerSuite.scala | 24 +- .../org/apache/gluten/GlutenPlugin.scala | 12 +- .../gluten/backendsapi/SparkPlanExecApi.scala | 9 +- .../AggregateFunctionsBuilder.scala | 34 +-- .../expression/ExpressionConverter.scala | 254 ++++++++---------- .../expression/ExpressionMappings.scala | 22 +- .../clickhouse/ClickHouseTestSettings.scala | 3 +- .../utils/velox/VeloxTestSettings.scala | 3 +- .../org/apache/gluten/GlutenConfig.scala | 2 + .../sql/catalyst/expressions/EvalMode.scala | 36 +++ .../sql/catalyst/expressions/EvalMode.scala | 36 +++ 18 files changed, 323 insertions(+), 215 deletions(-) create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala rename {gluten-core => backends-clickhouse}/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala (86%) rename {gluten-core => backends-clickhouse}/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala (92%) rename {gluten-ut/spark32 => backends-clickhouse}/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala (100%) rename gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenCustomerExpressionTransformerSuite.scala => backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala (87%) create mode 100644 shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala create mode 100644 shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala index 69797feb65fb..60dc3dad0b87 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala @@ -21,6 +21,7 @@ import org.apache.gluten.backendsapi.ListenerApi import org.apache.gluten.execution.CHBroadcastBuildSideCache import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, GlutenParquetWriterInjects, GlutenRowSplitter} import org.apache.gluten.expression.UDFMappings +import org.apache.gluten.extension.ExpressionExtensionTrait import org.apache.gluten.vectorized.{CHNativeExpressionEvaluator, JniLibLoader} import org.apache.spark.{SparkConf, SparkContext} @@ -30,6 +31,7 @@ import org.apache.spark.listener.CHGlutenSQLAppStatusListener import org.apache.spark.network.util.JavaUtils import org.apache.spark.rpc.{GlutenDriverEndpoint, GlutenExecutorEndpoint} import org.apache.spark.sql.execution.datasources.v1._ +import org.apache.spark.sql.utils.ExpressionUtil import org.apache.spark.util.SparkDirectoryUtil import org.apache.commons.lang3.StringUtils @@ -42,6 +44,13 @@ class CHListenerApi extends ListenerApi with Logging { 
GlutenDriverEndpoint.glutenDriverEndpointRef = (new GlutenDriverEndpoint).self CHGlutenSQLAppStatusListener.registerListener(sc) initialize(pc.conf, isDriver = true) + + val expressionExtensionTransformer = ExpressionUtil.extendedExpressionTransformer( + pc.conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") + ) + if (expressionExtensionTransformer != null) { + ExpressionExtensionTrait.expressionExtensionTransformer = expressionExtensionTransformer + } } override def onDriverShutdown(): Unit = shutdown() diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index bfa59aee7318..a8996c4d2e83 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -18,10 +18,10 @@ package org.apache.gluten.backendsapi.clickhouse import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi} -import org.apache.gluten.exception.GlutenException -import org.apache.gluten.exception.GlutenNotSupportException +import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException} import org.apache.gluten.execution._ import org.apache.gluten.expression._ +import org.apache.gluten.extension.ExpressionExtensionTrait import org.apache.gluten.extension.columnar.AddFallbackTagRule import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides import org.apache.gluten.extension.columnar.transition.Convention @@ -558,9 +558,25 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { Sig[CollectList](ExpressionNames.COLLECT_LIST), Sig[CollectSet](ExpressionNames.COLLECT_SET) ) ++ + ExpressionExtensionTrait.expressionExtensionTransformer.expressionSigList ++ SparkShimLoader.getSparkShims.bloomFilterExpressionMappings() } + /** Define backend-specific expression converter. 
*/ + override def extraExpressionConverter( + substraitExprName: String, + expr: Expression, + attributeSeq: Seq[Attribute]): Option[ExpressionTransformer] = expr match { + case e + if ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping + .contains(e.getClass) => + // Use extended expression transformer to replace custom expression first + Some( + ExpressionExtensionTrait.expressionExtensionTransformer + .replaceWithExtensionExpressionTransformer(substraitExprName, e, attributeSeq)) + case _ => None + } + override def genStringTranslateTransformer( substraitExprName: String, srcExpr: ExpressionTransformer, @@ -700,7 +716,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { .doTransform(args))) val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - AggregateFunctionsBuilder.create(args, aggExpression.aggregateFunction).toInt, + CHExpressions.createAggregateFunction(args, aggExpression.aggregateFunction).toInt, childrenNodeList, columnName, ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable), diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala index 6c1fee39c423..d641c05cd62e 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala @@ -20,6 +20,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution.CHHashAggregateExecTransformer.getAggregateResultAttributes import org.apache.gluten.expression._ +import org.apache.gluten.extension.ExpressionExtensionTrait import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode} import org.apache.gluten.substrait.{AggregationParams, SubstraitContext} import org.apache.gluten.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode} @@ -249,7 +250,7 @@ case class CHHashAggregateExecTransformer( childrenNodeList.add(node) } val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - AggregateFunctionsBuilder.create(args, aggregateFunc), + CHExpressions.createAggregateFunction(args, aggregateFunc), childrenNodeList, modeToKeyWord(aggExpr.mode), ConverterUtils.getTypeNode(aggregateFunc.dataType, aggregateFunc.nullable) @@ -286,10 +287,10 @@ case class CHHashAggregateExecTransformer( val aggregateFunc = aggExpr.aggregateFunction var aggFunctionName = if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( - aggregateFunc.getClass) + ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping + .contains(aggregateFunc.getClass) ) { - ExpressionMappings.expressionExtensionTransformer + ExpressionExtensionTrait.expressionExtensionTransformer .buildCustomAggregateFunction(aggregateFunc) ._1 .get @@ -437,10 +438,10 @@ case class CHHashAggregateExecPullOutHelper( val aggregateFunc = exp.aggregateFunction // First handle the custom aggregate functions if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( + ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping.contains( aggregateFunc.getClass) ) { - ExpressionMappings.expressionExtensionTransformer + ExpressionExtensionTrait.expressionExtensionTransformer 
.getAttrsIndexForExtensionAggregateExpr( aggregateFunc, exp.mode, diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala new file mode 100644 index 000000000000..af1ac52b1e40 --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.expression + +import org.apache.gluten.expression.ConverterUtils.FunctionConfig +import org.apache.gluten.extension.ExpressionExtensionTrait +import org.apache.gluten.substrait.expression.ExpressionBuilder + +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction + +// Static helper object for handling expressions that are specifically used in CH backend. +object CHExpressions { + // Since https://github.com/apache/incubator-gluten/pull/1937. + def createAggregateFunction(args: java.lang.Object, aggregateFunc: AggregateFunction): Long = { + val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + if ( + ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping.contains( + aggregateFunc.getClass) + ) { + val (substraitAggFuncName, inputTypes) = + ExpressionExtensionTrait.expressionExtensionTransformer.buildCustomAggregateFunction( + aggregateFunc) + assert(substraitAggFuncName.isDefined) + return ExpressionBuilder.newScalarFunction( + functionMap, + ConverterUtils.makeFuncName(substraitAggFuncName.get, inputTypes, FunctionConfig.REQ)) + } + + AggregateFunctionsBuilder.create(args, aggregateFunc) + } +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala similarity index 86% rename from gluten-core/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala rename to backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala index 89bcb70641bd..c64f26869eb6 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala @@ -63,8 +63,13 @@ trait ExpressionExtensionTrait { } } -case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging { +object ExpressionExtensionTrait { + var expressionExtensionTransformer: ExpressionExtensionTrait = + DefaultExpressionExtensionTransformer() - /** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */ - override def expressionSigList: Seq[Sig] = 
Seq.empty[Sig] + case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging { + + /** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */ + override def expressionSigList: Seq[Sig] = Seq.empty[Sig] + } } diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala similarity index 92% rename from gluten-core/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala rename to backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala index b5c45e090f38..852b34a099f2 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala @@ -16,7 +16,8 @@ */ package org.apache.spark.sql.utils -import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer, ExpressionExtensionTrait} +import org.apache.gluten.extension.ExpressionExtensionTrait +import org.apache.gluten.extension.ExpressionExtensionTrait.DefaultExpressionExtensionTransformer import org.apache.spark.internal.Logging import org.apache.spark.util.Utils diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala similarity index 100% rename from gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala rename to backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenCustomerExpressionTransformerSuite.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala similarity index 87% rename from gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenCustomerExpressionTransformerSuite.scala rename to backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala index 91344f8778ca..cd8bf579fa66 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenCustomerExpressionTransformerSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala @@ -16,24 +16,25 @@ */ package org.apache.spark.sql.extension -import org.apache.gluten.execution.ProjectExecTransformer +import org.apache.gluten.execution.{GlutenClickHouseWholeStageTransformerSuite, ProjectExecTransformer} import org.apache.gluten.expression.ExpressionConverter import org.apache.spark.SparkConf -import org.apache.spark.sql.{GlutenSQLTestsTrait, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{AbstractDataType, CalendarIntervalType, DayTimeIntervalType, TypeCollection, YearMonthIntervalType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval case class CustomAdd( left: Expression, right: Expression, - failOnError: Boolean = 
SQLConf.get.ansiEnabled) - extends BinaryArithmetic { + override val failOnError: Boolean = SQLConf.get.ansiEnabled) + extends BinaryArithmetic + with CustomAdd.Compatibility { def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) @@ -69,9 +70,18 @@ case class CustomAdd( newLeft: Expression, newRight: Expression ): CustomAdd = copy(left = newLeft, right = newRight) + + override protected val evalMode: EvalMode.Value = EvalMode.LEGACY +} + +object CustomAdd { + trait Compatibility { + protected val evalMode: EvalMode.Value + } } -class GlutenCustomerExpressionTransformerSuite extends GlutenSQLTestsTrait { +class GlutenClickhouseCustomerExpressionTransformerSuite + extends GlutenClickHouseWholeStageTransformerSuite { override def sparkConf: SparkConf = { super.sparkConf @@ -92,7 +102,7 @@ class GlutenCustomerExpressionTransformerSuite extends GlutenSQLTestsTrait { ) } - testGluten("test custom expression transformer") { + test("test custom expression transformer") { spark .createDataFrame(Seq((1, 1.1), (2, 2.2))) .createOrReplaceTempView("custom_table") diff --git a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala index f775d78a15ac..6d0cdd0f8a05 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala @@ -20,7 +20,6 @@ import org.apache.gluten.GlutenConfig.GLUTEN_DEFAULT_SESSION_TIMEZONE_KEY import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.events.GlutenBuildInfoEvent import org.apache.gluten.exception.GlutenException -import org.apache.gluten.expression.ExpressionMappings import org.apache.gluten.extension.GlutenSessionExtensions.{GLUTEN_SESSION_EXTENSION_NAME, SPARK_SESSION_EXTS_KEY} import org.apache.gluten.test.TestStats import org.apache.gluten.utils.TaskListener @@ -32,7 +31,6 @@ import org.apache.spark.listener.GlutenListenerFactory import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.execution.ui.GlutenEventUtils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.utils.ExpressionUtil import org.apache.spark.util.{SparkResourceUtil, TaskResources} import java.util @@ -73,14 +71,6 @@ private[gluten] class GlutenDriverPlugin extends DriverPlugin with Logging { BackendsApiManager.getListenerApiInstance.onDriverStart(sc, pluginContext) GlutenListenerFactory.addToSparkListenerBus(sc) - val expressionExtensionTransformer = ExpressionUtil.extendedExpressionTransformer( - conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") - ) - - if (expressionExtensionTransformer != null) { - ExpressionMappings.expressionExtensionTransformer = expressionExtensionTransformer - } - Collections.emptyMap() } @@ -275,7 +265,7 @@ private[gluten] class GlutenDriverPlugin extends DriverPlugin with Logging { } private[gluten] class GlutenExecutorPlugin extends ExecutorPlugin { - private val taskListeners: Seq[TaskListener] = Array(TaskResources) + private val taskListeners: Seq[TaskListener] = Seq(TaskResources) /** Initialize the executor plugin. 
*/ override def init(ctx: PluginContext, extraConf: util.Map[String, String]): Unit = { diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index fb87a9ac93c0..a55926d76d12 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -448,9 +448,16 @@ trait SparkPlanExecApi { GenericExpressionTransformer(substraitExprName, exprs, original) } - /** Define backend specfic expression mappings. */ + /** Define backend-specific expression mappings. */ def extraExpressionMappings: Seq[Sig] = Seq.empty + /** Define backend-specific expression converter. */ + def extraExpressionConverter( + substraitExprName: String, + expr: Expression, + attributeSeq: Seq[Attribute]): Option[ExpressionTransformer] = + None + /** * Define whether the join operator is fallback because of the join operator is not supported by * backend diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala index 6ac2c67eb086..bd73b7b7aa54 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala @@ -29,32 +29,18 @@ object AggregateFunctionsBuilder { val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] // First handle the custom aggregate functions - val (substraitAggFuncName, inputTypes) = - if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( - aggregateFunc.getClass) - ) { - val (substraitAggFuncName, inputTypes) = - ExpressionMappings.expressionExtensionTransformer.buildCustomAggregateFunction( - aggregateFunc) - assert(substraitAggFuncName.isDefined) - (substraitAggFuncName.get, inputTypes) - } else { - val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc) + val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc) - // Check whether each backend supports this aggregate function. - if ( - !BackendsApiManager.getValidatorApiInstance.doExprValidate( - substraitAggFuncName, - aggregateFunc) - ) { - throw new GlutenNotSupportException( - s"Aggregate function not supported for $aggregateFunc.") - } + // Check whether each backend supports this aggregate function. 
+ if ( + !BackendsApiManager.getValidatorApiInstance.doExprValidate( + substraitAggFuncName, + aggregateFunc) + ) { + throw new GlutenNotSupportException(s"Aggregate function not supported for $aggregateFunc.") + } - val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType) - (substraitAggFuncName, inputTypes) - } + val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType) ExpressionBuilder.newScalarFunction( functionMap, diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index d5ca31bb5e78..c5ba3a8a7839 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -43,16 +43,14 @@ object ExpressionConverter extends SQLConfHelper with Logging { exprs: Seq[Expression], attributeSeq: Seq[Attribute]): Seq[ExpressionTransformer] = { val expressionsMap = ExpressionMappings.expressionsMap - exprs.map { - expr => replaceWithExpressionTransformerInternal(expr, attributeSeq, expressionsMap) - } + exprs.map(expr => replaceWithExpressionTransformer0(expr, attributeSeq, expressionsMap)) } def replaceWithExpressionTransformer( expr: Expression, attributeSeq: Seq[Attribute]): ExpressionTransformer = { val expressionsMap = ExpressionMappings.expressionsMap - replaceWithExpressionTransformerInternal(expr, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(expr, attributeSeq, expressionsMap) } private def replacePythonUDFWithExpressionTransformer( @@ -64,8 +62,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { case Some(name) => GenericExpressionTransformer( name, - udf.children.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + udf.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), udf) case _ => throw new GlutenNotSupportException(s"Not supported python udf: $udf.") @@ -84,8 +81,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { case Some(name) => GenericExpressionTransformer( name, - udf.children.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + udf.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), udf) case _ => throw new GlutenNotSupportException(s"Not supported scala udf: $udf.") @@ -108,13 +104,13 @@ object ExpressionConverter extends SQLConfHelper with Logging { ) val leftChild = - replaceWithExpressionTransformerInternal(left, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(left, attributeSeq, expressionsMap) val rightChild = - replaceWithExpressionTransformerInternal(right, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(right, attributeSeq, expressionsMap) DecimalArithmeticExpressionTransformer(substraitName, leftChild, rightChild, resultType, b) } - private def replaceWithExpressionTransformerInternal( + private def replaceWithExpressionTransformer0( expr: Expression, attributeSeq: Seq[Attribute], expressionsMap: Map[Class[_], String]): ExpressionTransformer = { @@ -139,14 +135,12 @@ object ExpressionConverter extends SQLConfHelper with Logging { case "decode" => return GenericExpressionTransformer( ExpressionNames.URL_DECODE, - child.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + child.map(replaceWithExpressionTransformer0(_, attributeSeq, 
expressionsMap)), i) case "encode" => return GenericExpressionTransformer( ExpressionNames.URL_ENCODE, - child.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + child.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), i) } } @@ -154,61 +148,61 @@ object ExpressionConverter extends SQLConfHelper with Logging { } val substraitExprName: String = getAndCheckSubstraitName(expr, expressionsMap) - + val backendConverted = BackendsApiManager.getSparkPlanExecApiInstance.extraExpressionConverter( + substraitExprName, + expr, + attributeSeq) + if (backendConverted.isDefined) { + return backendConverted.get + } expr match { - case extendedExpr - if ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( - extendedExpr.getClass) => - // Use extended expression transformer to replace custom expression first - ExpressionMappings.expressionExtensionTransformer - .replaceWithExtensionExpressionTransformer(substraitExprName, extendedExpr, attributeSeq) case c: CreateArray => val children = - c.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) + c.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)) CreateArrayTransformer(substraitExprName, children, c) case g: GetArrayItem => BackendsApiManager.getSparkPlanExecApiInstance.genGetArrayItemTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(g.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(g.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(g.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(g.right, attributeSeq, expressionsMap), g ) case c: CreateMap => val children = - c.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) + c.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)) CreateMapTransformer(substraitExprName, children, c) case g: GetMapValue => BackendsApiManager.getSparkPlanExecApiInstance.genGetMapValueTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(g.child, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(g.key, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(g.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(g.key, attributeSeq, expressionsMap), g ) case m: MapEntries => BackendsApiManager.getSparkPlanExecApiInstance.genMapEntriesTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(m.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(m.child, attributeSeq, expressionsMap), m) case e: Explode => ExplodeTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(e.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(e.child, attributeSeq, expressionsMap), e) case p: PosExplode => BackendsApiManager.getSparkPlanExecApiInstance.genPosExplodeTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(p.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(p.child, attributeSeq, expressionsMap), p, attributeSeq) case i: Inline => BackendsApiManager.getSparkPlanExecApiInstance.genInlineTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(i.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(i.child, attributeSeq, expressionsMap), i) case a: Alias => 
BackendsApiManager.getSparkPlanExecApiInstance.genAliasTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.child, attributeSeq, expressionsMap), a) case a: AttributeReference => if (attributeSeq == null) { @@ -233,14 +227,14 @@ object ExpressionConverter extends SQLConfHelper with Logging { case d: DateDiff => BackendsApiManager.getSparkPlanExecApiInstance.genDateDiffTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(d.endDate, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(d.startDate, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(d.endDate, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(d.startDate, attributeSeq, expressionsMap), d ) case r: Round if r.child.dataType.isInstanceOf[DecimalType] => DecimalRoundTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(r.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(r.child, attributeSeq, expressionsMap), r) case t: ToUnixTimestamp => // The failOnError depends on the config for ANSI. ANSI is not supported currently. @@ -248,8 +242,8 @@ object ExpressionConverter extends SQLConfHelper with Logging { GenericExpressionTransformer( substraitExprName, Seq( - replaceWithExpressionTransformerInternal(t.timeExp, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(t.format, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(t.timeExp, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(t.format, attributeSeq, expressionsMap) ), t ) @@ -257,33 +251,33 @@ object ExpressionConverter extends SQLConfHelper with Logging { GenericExpressionTransformer( substraitExprName, Seq( - replaceWithExpressionTransformerInternal(u.timeExp, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(u.format, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(u.timeExp, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(u.format, attributeSeq, expressionsMap) ), ToUnixTimestamp(u.timeExp, u.format, u.timeZoneId, u.failOnError) ) case t: TruncTimestamp => BackendsApiManager.getSparkPlanExecApiInstance.genTruncTimestampTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(t.format, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(t.timestamp, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(t.format, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(t.timestamp, attributeSeq, expressionsMap), t.timeZoneId, t ) case m: MonthsBetween => MonthsBetweenTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(m.date1, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(m.date2, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(m.roundOff, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(m.date1, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(m.date2, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(m.roundOff, attributeSeq, expressionsMap), m ) case i: If => IfTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(i.predicate, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(i.trueValue, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(i.falseValue, attributeSeq, expressionsMap), + 
replaceWithExpressionTransformer0(i.predicate, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(i.trueValue, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(i.falseValue, attributeSeq, expressionsMap), i ) case cw: CaseWhen => @@ -293,14 +287,14 @@ object ExpressionConverter extends SQLConfHelper with Logging { expr => { ( - replaceWithExpressionTransformerInternal(expr._1, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(expr._2, attributeSeq, expressionsMap)) + replaceWithExpressionTransformer0(expr._1, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(expr._2, attributeSeq, expressionsMap)) } }, cw.elseValue.map { expr => { - replaceWithExpressionTransformerInternal(expr, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(expr, attributeSeq, expressionsMap) } }, cw @@ -312,12 +306,12 @@ object ExpressionConverter extends SQLConfHelper with Logging { } InTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(i.value, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(i.value, attributeSeq, expressionsMap), i) case i: InSet => InSetTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(i.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(i.child, attributeSeq, expressionsMap), i) case s: ScalarSubquery => ScalarSubqueryTransformer(substraitExprName, s) @@ -327,7 +321,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { BackendsApiManager.getSparkPlanExecApiInstance.genCastWithNewChild(c) CastTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(newCast.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(newCast.child, attributeSeq, expressionsMap), newCast) case s: String2TrimExpression => val (srcStr, trimStr) = s match { @@ -336,9 +330,9 @@ object ExpressionConverter extends SQLConfHelper with Logging { case StringTrimRight(srcStr, trimStr) => (srcStr, trimStr) } val children = trimStr - .map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) + .map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)) .toSeq ++ - Seq(replaceWithExpressionTransformerInternal(srcStr, attributeSeq, expressionsMap)) + Seq(replaceWithExpressionTransformer0(srcStr, attributeSeq, expressionsMap)) GenericExpressionTransformer( substraitExprName, children, @@ -348,23 +342,20 @@ object ExpressionConverter extends SQLConfHelper with Logging { BackendsApiManager.getSparkPlanExecApiInstance.genHashExpressionTransformer( substraitExprName, m.children.map( - expr => replaceWithExpressionTransformerInternal(expr, attributeSeq, expressionsMap)), + expr => replaceWithExpressionTransformer0(expr, attributeSeq, expressionsMap)), m) case getStructField: GetStructField => // Different backends may have different result. 
BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer( substraitExprName, - replaceWithExpressionTransformerInternal( - getStructField.child, - attributeSeq, - expressionsMap), + replaceWithExpressionTransformer0(getStructField.child, attributeSeq, expressionsMap), getStructField.ordinal, getStructField) case getArrayStructFields: GetArrayStructFields => GenericExpressionTransformer( substraitExprName, Seq( - replaceWithExpressionTransformerInternal( + replaceWithExpressionTransformer0( getArrayStructFields.child, attributeSeq, expressionsMap), @@ -374,26 +365,26 @@ object ExpressionConverter extends SQLConfHelper with Logging { case t: StringTranslate => BackendsApiManager.getSparkPlanExecApiInstance.genStringTranslateTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(t.srcExpr, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(t.matchingExpr, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(t.replaceExpr, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(t.srcExpr, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(t.matchingExpr, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(t.replaceExpr, attributeSeq, expressionsMap), t ) case r: RegExpReplace => BackendsApiManager.getSparkPlanExecApiInstance.genRegexpReplaceTransformer( substraitExprName, Seq( - replaceWithExpressionTransformerInternal(r.subject, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(r.regexp, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(r.rep, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(r.pos, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(r.subject, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(r.regexp, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(r.rep, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(r.pos, attributeSeq, expressionsMap) ), r ) case size: Size => // Covers Spark ArraySize which is replaced by Size(child, false). 
val child = - replaceWithExpressionTransformerInternal(size.child, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(size.child, attributeSeq, expressionsMap) GenericExpressionTransformer( substraitExprName, Seq(child, LiteralTransformer(size.legacySizeOfNull)), @@ -402,7 +393,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { BackendsApiManager.getSparkPlanExecApiInstance.genNamedStructTransformer( substraitExprName, namedStruct.children.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), namedStruct, attributeSeq) case namedLambdaVariable: NamedLambdaVariable => @@ -415,64 +406,57 @@ object ExpressionConverter extends SQLConfHelper with Logging { case lambdaFunction: LambdaFunction => LambdaFunctionTransformer( substraitExprName, - function = replaceWithExpressionTransformerInternal( + function = replaceWithExpressionTransformer0( lambdaFunction.function, attributeSeq, expressionsMap), arguments = lambdaFunction.arguments.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), original = lambdaFunction ) case j: JsonTuple => val children = - j.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) + j.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)) JsonTupleExpressionTransformer(substraitExprName, children, j) case l: Like => BackendsApiManager.getSparkPlanExecApiInstance.genLikeTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(l.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(l.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(l.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(l.right, attributeSeq, expressionsMap), l ) case m: MakeDecimal => GenericExpressionTransformer( substraitExprName, Seq( - replaceWithExpressionTransformerInternal(m.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(m.child, attributeSeq, expressionsMap), LiteralTransformer(m.nullOnOverflow)), m ) case _: NormalizeNaNAndZero | _: PromotePrecision | _: TaggingExpression => ChildTransformer( substraitExprName, - replaceWithExpressionTransformerInternal( - expr.children.head, - attributeSeq, - expressionsMap), + replaceWithExpressionTransformer0(expr.children.head, attributeSeq, expressionsMap), expr ) case _: GetDateField | _: GetTimeField => ExtractDateTransformer( substraitExprName, - replaceWithExpressionTransformerInternal( - expr.children.head, - attributeSeq, - expressionsMap), + replaceWithExpressionTransformer0(expr.children.head, attributeSeq, expressionsMap), expr) case _: StringToMap => BackendsApiManager.getSparkPlanExecApiInstance.genStringToMapTransformer( substraitExprName, - expr.children.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), expr) case CheckOverflow(b: BinaryArithmetic, decimalType, _) if !BackendsApiManager.getSettings.transformCheckOverflow && DecimalArithmeticUtil.isDecimalArithmetic(b) => DecimalArithmeticUtil.checkAllowDecimalArithmetic() val leftChild = - replaceWithExpressionTransformerInternal(b.left, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(b.left, attributeSeq, expressionsMap) val rightChild = - 
replaceWithExpressionTransformerInternal(b.right, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(b.right, attributeSeq, expressionsMap) DecimalArithmeticExpressionTransformer( getAndCheckSubstraitName(b, expressionsMap), leftChild, @@ -482,15 +466,14 @@ object ExpressionConverter extends SQLConfHelper with Logging { case c: CheckOverflow => CheckOverflowTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(c.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(c.child, attributeSeq, expressionsMap), c) case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) => DecimalArithmeticUtil.checkAllowDecimalArithmetic() if (!BackendsApiManager.getSettings.transformCheckOverflow) { GenericExpressionTransformer( substraitExprName, - expr.children.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), expr ) } else { @@ -501,14 +484,14 @@ object ExpressionConverter extends SQLConfHelper with Logging { case n: NaNvl => BackendsApiManager.getSparkPlanExecApiInstance.genNaNvlTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(n.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(n.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(n.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(n.right, attributeSeq, expressionsMap), n ) case m: MakeTimestamp => BackendsApiManager.getSparkPlanExecApiInstance.genMakeTimestampTransformer( substraitExprName, - m.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + m.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), m) case timestampAdd if timestampAdd.getClass.getSimpleName.equals("TimestampAdd") => // for spark3.3 @@ -520,111 +503,99 @@ object ExpressionConverter extends SQLConfHelper with Logging { TimestampAddTransformer( substraitExprName, extract.get.head, - replaceWithExpressionTransformerInternal(add.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(add.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(add.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(add.right, attributeSeq, expressionsMap), extract.get.last, add ) case e: Transformable => val childrenTransformers = - e.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) + e.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)) e.getTransformer(childrenTransformers) case u: Uuid => BackendsApiManager.getSparkPlanExecApiInstance.genUuidTransformer(substraitExprName, u) case f: ArrayFilter => BackendsApiManager.getSparkPlanExecApiInstance.genArrayFilterTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(f.argument, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(f.function, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(f.argument, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(f.function, attributeSeq, expressionsMap), f ) case arrayTransform: ArrayTransform => BackendsApiManager.getSparkPlanExecApiInstance.genArrayTransformTransformer( substraitExprName, - replaceWithExpressionTransformerInternal( - arrayTransform.argument, - attributeSeq, - expressionsMap), - replaceWithExpressionTransformerInternal( - arrayTransform.function, 
- attributeSeq, - expressionsMap), + replaceWithExpressionTransformer0(arrayTransform.argument, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(arrayTransform.function, attributeSeq, expressionsMap), arrayTransform ) case arraySort: ArraySort => BackendsApiManager.getSparkPlanExecApiInstance.genArraySortTransformer( substraitExprName, - replaceWithExpressionTransformerInternal( - arraySort.argument, - attributeSeq, - expressionsMap), - replaceWithExpressionTransformerInternal( - arraySort.function, - attributeSeq, - expressionsMap), + replaceWithExpressionTransformer0(arraySort.argument, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(arraySort.function, attributeSeq, expressionsMap), arraySort ) case tryEval @ TryEval(a: Add) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), tryEval, ExpressionNames.CHECKED_ADD ) case tryEval @ TryEval(a: Subtract) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), tryEval, ExpressionNames.CHECKED_SUBTRACT ) case tryEval @ TryEval(a: Divide) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), tryEval, ExpressionNames.CHECKED_DIVIDE ) case tryEval @ TryEval(a: Multiply) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), tryEval, ExpressionNames.CHECKED_MULTIPLY ) case a: Add => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), a, ExpressionNames.CHECKED_ADD ) case a: Subtract => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + 
replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), a, ExpressionNames.CHECKED_SUBTRACT ) case a: Multiply => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), a, ExpressionNames.CHECKED_MULTIPLY ) case a: Divide => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), a, ExpressionNames.CHECKED_DIVIDE ) @@ -632,34 +603,34 @@ object ExpressionConverter extends SQLConfHelper with Logging { // This is a placeholder to handle try_eval(other expressions). BackendsApiManager.getSparkPlanExecApiInstance.genTryEvalTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(tryEval.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(tryEval.child, attributeSeq, expressionsMap), tryEval ) case a: ArrayForAll => BackendsApiManager.getSparkPlanExecApiInstance.genArrayForAllTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.argument, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.function, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.argument, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.function, attributeSeq, expressionsMap), a ) case a: ArrayExists => BackendsApiManager.getSparkPlanExecApiInstance.genArrayExistsTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.argument, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.function, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.argument, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.function, attributeSeq, expressionsMap), a ) case s: Shuffle => GenericExpressionTransformer( substraitExprName, Seq( - replaceWithExpressionTransformerInternal(s.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(s.child, attributeSeq, expressionsMap), LiteralTransformer(Literal(s.randomSeed.get))), s) case c: PreciseTimestampConversion => BackendsApiManager.getSparkPlanExecApiInstance.genPreciseTimestampConversionTransformer( substraitExprName, - Seq(replaceWithExpressionTransformerInternal(c.child, attributeSeq, expressionsMap)), + Seq(replaceWithExpressionTransformer0(c.child, attributeSeq, expressionsMap)), c ) case t: TransformKeys => @@ -674,7 +645,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { } GenericExpressionTransformer( substraitExprName, - t.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + t.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), t ) case e: EulerNumber => @@ -700,8 +671,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { case expr => GenericExpressionTransformer( substraitExprName, - expr.children.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, 
expressionsMap)), + expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), expr ) } diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index f2bb4a90621a..38f9de629a16 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -19,7 +19,6 @@ package org.apache.gluten.expression import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.expression.ExpressionNames._ -import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer, ExpressionExtensionTrait} import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.catalyst.expressions._ @@ -338,22 +337,19 @@ object ExpressionMappings { def expressionsMap: Map[Class[_], String] = { val blacklist = GlutenConfig.getConf.expressionBlacklist - val supportedExprs = defaultExpressionsMap ++ - expressionExtensionTransformer.extensionExpressionsMapping - if (blacklist.isEmpty) { - supportedExprs - } else { - supportedExprs.filterNot(kv => blacklist.contains(kv._2)) - } + val filtered = (defaultExpressionsMap ++ toMap( + BackendsApiManager.getSparkPlanExecApiInstance.extraExpressionMappings)).filterNot( + kv => blacklist.contains(kv._2)) + filtered } private lazy val defaultExpressionsMap: Map[Class[_], String] = { - (SCALAR_SIGS ++ AGGREGATE_SIGS ++ WINDOW_SIGS ++ - BackendsApiManager.getSparkPlanExecApiInstance.extraExpressionMappings) + toMap(SCALAR_SIGS ++ AGGREGATE_SIGS ++ WINDOW_SIGS) + } + + private def toMap(sigs: Seq[Sig]): Map[Class[_], String] = { + sigs .map(s => (s.expClass, s.name)) .toMap[Class[_], String] } - - var expressionExtensionTransformer: ExpressionExtensionTrait = - DefaultExpressionExtensionTransformer() } diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 5c2833de4bc0..9b2e2ab95bc9 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.text.{GlutenTextV1Suite, Glute import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.execution.exchange.GlutenEnsureRequirementsSuite import org.apache.spark.sql.execution.joins.{GlutenExistenceJoinSuite, GlutenInnerJoinSuite, GlutenOuterJoinSuite} -import org.apache.spark.sql.extension.{GlutenCustomerExpressionTransformerSuite, GlutenCustomerExtensionSuite, GlutenSessionExtensionSuite} +import org.apache.spark.sql.extension.{GlutenCustomerExtensionSuite, GlutenSessionExtensionSuite} import org.apache.spark.sql.hive.execution.GlutenHiveSQLQueryCHSuite import org.apache.spark.sql.sources._ import org.apache.spark.sql.statistics.SparkFunctionStatistics @@ -2133,7 +2133,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("right outer join with unique keys using ShuffledHashJoin (whole-stage-codegen on)") .exclude("right outer join with unique keys using SortMergeJoin (whole-stage-codegen off)") .exclude("right outer join with unique keys using SortMergeJoin (whole-stage-codegen on)") - 
enableSuite[GlutenCustomerExpressionTransformerSuite] enableSuite[GlutenCustomerExtensionSuite] enableSuite[GlutenSessionExtensionSuite] enableSuite[GlutenBucketedReadWithoutHiveSupportSuite] diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index c4799366dc96..e064f2afc9d7 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.text.{GlutenTextV1Suite, Glute import org.apache.spark.sql.execution.datasources.v2.GlutenFileTableSuite import org.apache.spark.sql.execution.exchange.GlutenEnsureRequirementsSuite import org.apache.spark.sql.execution.joins.{GlutenBroadcastJoinSuite, GlutenExistenceJoinSuite, GlutenInnerJoinSuite, GlutenOuterJoinSuite} -import org.apache.spark.sql.extension.{GlutenCollapseProjectExecTransformerSuite, GlutenCustomerExpressionTransformerSuite, GlutenSessionExtensionSuite} +import org.apache.spark.sql.extension.{GlutenCollapseProjectExecTransformerSuite, GlutenSessionExtensionSuite} import org.apache.spark.sql.hive.execution.GlutenHiveSQLQuerySuite import org.apache.spark.sql.sources._ @@ -44,7 +44,6 @@ import org.apache.spark.sql.sources._ class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenSessionExtensionSuite] - enableSuite[GlutenCustomerExpressionTransformerSuite] enableSuite[GlutenDataFrameAggregateSuite] .exclude( diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index bb0e683c25eb..5c032d4b0ef4 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -610,6 +610,7 @@ object GlutenConfig { val GLUTEN_SUPPORTED_PYTHON_UDFS = "spark.gluten.supported.python.udfs" val GLUTEN_SUPPORTED_SCALA_UDFS = "spark.gluten.supported.scala.udfs" + // FIXME: This only works with CH backend. val GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF = "spark.gluten.sql.columnar.extended.expressions.transformer" @@ -1686,6 +1687,7 @@ object GlutenConfig { .stringConf .createWithDefaultString("") + // FIXME: This only works with CH backend. val EXTENDED_EXPRESSION_TRAN_CONF = buildConf(GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF) .doc("A class for the extended expressions transformer.") diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala new file mode 100644 index 000000000000..0a3c63ccd8b9 --- /dev/null +++ b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.internal.SQLConf
+
+/** For compatibility with Spark versions <= 3.3. The class was added to vanilla Spark in 3.4. */
+object EvalMode extends Enumeration {
+  val LEGACY, ANSI, TRY = Value
+
+  def fromSQLConf(conf: SQLConf): Value = if (conf.ansiEnabled) {
+    ANSI
+  } else {
+    LEGACY
+  }
+
+  def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) {
+    ANSI
+  } else {
+    LEGACY
+  }
+}
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala
new file mode 100644
index 000000000000..0a3c63ccd8b9
--- /dev/null
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.internal.SQLConf
+
+/** For compatibility with Spark versions <= 3.3. The class was added to vanilla Spark in 3.4. */
+object EvalMode extends Enumeration {
+  val LEGACY, ANSI, TRY = Value
+
+  def fromSQLConf(conf: SQLConf): Value = if (conf.ansiEnabled) {
+    ANSI
+  } else {
+    LEGACY
+  }
+
+  def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) {
+    ANSI
+  } else {
+    LEGACY
+  }
+}
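
For context, the sketch below shows how a custom expression plugs into the extension point this patch consolidates in the CH backend. It is a minimal, illustrative sketch only: the class name CustomAddTransformer, the Substrait name "custom_add", and the body of the override are assumptions; only Sig, ExpressionExtensionTrait, replaceWithExtensionExpressionTransformer, GenericExpressionTransformer, ExpressionConverter, and the config key appear in the patch itself.

// Illustrative sketch only: trait members beyond those visible in this patch
// are assumptions, not the authoritative ExpressionExtensionTrait API.
package org.example

import org.apache.gluten.expression.{ExpressionConverter, ExpressionTransformer, GenericExpressionTransformer, Sig}
import org.apache.gluten.extension.ExpressionExtensionTrait

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.extension.CustomAdd

class CustomAddTransformer extends ExpressionExtensionTrait {

  // Register the custom Catalyst expression under a Substrait function name,
  // following the Sig[Expression]("name") pattern used throughout this patch.
  override def expressionSigList: Seq[Sig] = Seq(Sig[CustomAdd]("custom_add"))

  // Invoked from extraExpressionConverter once the expression class matches
  // extensionExpressionsMapping; here it simply converts the children and
  // delegates to the generic transformer.
  override def replaceWithExtensionExpressionTransformer(
      substraitExprName: String,
      expr: Expression,
      attributeSeq: Seq[Attribute]): ExpressionTransformer =
    GenericExpressionTransformer(
      substraitExprName,
      expr.children.map(ExpressionConverter.replaceWithExpressionTransformer(_, attributeSeq)),
      expr)
}

Such a transformer would then be enabled per application with spark.gluten.sql.columnar.extended.expressions.transformer=org.example.CustomAddTransformer, which CHListenerApi.onDriverStart now resolves through ExpressionUtil.extendedExpressionTransformer and installs into ExpressionExtensionTrait.expressionExtensionTransformer.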