Skip to content

Commit

Permalink
[GLUTEN-6951][CORE][CH] Move CustomerExpressionTransformer to CH back…
Browse files Browse the repository at this point in the history
…end (#6993)

Closes #6951
  • Loading branch information
zhztheplayer authored Aug 26, 2024
1 parent 603faba commit 9bb4b28
Show file tree
Hide file tree
Showing 18 changed files with 323 additions and 215 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.gluten.backendsapi.ListenerApi
import org.apache.gluten.execution.CHBroadcastBuildSideCache
import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, GlutenParquetWriterInjects, GlutenRowSplitter}
import org.apache.gluten.expression.UDFMappings
import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.vectorized.{CHNativeExpressionEvaluator, JniLibLoader}

import org.apache.spark.{SparkConf, SparkContext}
Expand All @@ -30,6 +31,7 @@ import org.apache.spark.listener.CHGlutenSQLAppStatusListener
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rpc.{GlutenDriverEndpoint, GlutenExecutorEndpoint}
import org.apache.spark.sql.execution.datasources.v1._
import org.apache.spark.sql.utils.ExpressionUtil
import org.apache.spark.util.SparkDirectoryUtil

import org.apache.commons.lang3.StringUtils
Expand All @@ -42,6 +44,13 @@ class CHListenerApi extends ListenerApi with Logging {
GlutenDriverEndpoint.glutenDriverEndpointRef = (new GlutenDriverEndpoint).self
CHGlutenSQLAppStatusListener.registerListener(sc)
initialize(pc.conf, isDriver = true)

val expressionExtensionTransformer = ExpressionUtil.extendedExpressionTransformer(
pc.conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "")
)
if (expressionExtensionTransformer != null) {
ExpressionExtensionTrait.expressionExtensionTransformer = expressionExtensionTransformer
}
}

override def onDriverShutdown(): Unit = shutdown()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ package org.apache.gluten.backendsapi.clickhouse

import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi}
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.extension.columnar.AddFallbackTagRule
import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
import org.apache.gluten.extension.columnar.transition.Convention
Expand Down Expand Up @@ -558,9 +558,25 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
Sig[CollectList](ExpressionNames.COLLECT_LIST),
Sig[CollectSet](ExpressionNames.COLLECT_SET)
) ++
ExpressionExtensionTrait.expressionExtensionTransformer.expressionSigList ++
SparkShimLoader.getSparkShims.bloomFilterExpressionMappings()
}

/** Define backend-specific expression converter. */
override def extraExpressionConverter(
substraitExprName: String,
expr: Expression,
attributeSeq: Seq[Attribute]): Option[ExpressionTransformer] = expr match {
case e
if ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping
.contains(e.getClass) =>
// Use extended expression transformer to replace custom expression first
Some(
ExpressionExtensionTrait.expressionExtensionTransformer
.replaceWithExtensionExpressionTransformer(substraitExprName, e, attributeSeq))
case _ => None
}

override def genStringTranslateTransformer(
substraitExprName: String,
srcExpr: ExpressionTransformer,
Expand Down Expand Up @@ -700,7 +716,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
.doTransform(args)))

val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
AggregateFunctionsBuilder.create(args, aggExpression.aggregateFunction).toInt,
CHExpressions.createAggregateFunction(args, aggExpression.aggregateFunction).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution.CHHashAggregateExecTransformer.getAggregateResultAttributes
import org.apache.gluten.expression._
import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode}
import org.apache.gluten.substrait.{AggregationParams, SubstraitContext}
import org.apache.gluten.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode}
Expand Down Expand Up @@ -249,7 +250,7 @@ case class CHHashAggregateExecTransformer(
childrenNodeList.add(node)
}
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
AggregateFunctionsBuilder.create(args, aggregateFunc),
CHExpressions.createAggregateFunction(args, aggregateFunc),
childrenNodeList,
modeToKeyWord(aggExpr.mode),
ConverterUtils.getTypeNode(aggregateFunc.dataType, aggregateFunc.nullable)
Expand Down Expand Up @@ -286,10 +287,10 @@ case class CHHashAggregateExecTransformer(
val aggregateFunc = aggExpr.aggregateFunction
var aggFunctionName =
if (
ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
aggregateFunc.getClass)
ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping
.contains(aggregateFunc.getClass)
) {
ExpressionMappings.expressionExtensionTransformer
ExpressionExtensionTrait.expressionExtensionTransformer
.buildCustomAggregateFunction(aggregateFunc)
._1
.get
Expand Down Expand Up @@ -437,10 +438,10 @@ case class CHHashAggregateExecPullOutHelper(
val aggregateFunc = exp.aggregateFunction
// First handle the custom aggregate functions
if (
ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping.contains(
aggregateFunc.getClass)
) {
ExpressionMappings.expressionExtensionTransformer
ExpressionExtensionTrait.expressionExtensionTransformer
.getAttrsIndexForExtensionAggregateExpr(
aggregateFunc,
exp.mode,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.expression

import org.apache.gluten.expression.ConverterUtils.FunctionConfig
import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.substrait.expression.ExpressionBuilder

import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction

// Static helper object for handling expressions that are specifically used in CH backend.
object CHExpressions {
// Since https://github.com/apache/incubator-gluten/pull/1937.
def createAggregateFunction(args: java.lang.Object, aggregateFunc: AggregateFunction): Long = {
val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
if (
ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping.contains(
aggregateFunc.getClass)
) {
val (substraitAggFuncName, inputTypes) =
ExpressionExtensionTrait.expressionExtensionTransformer.buildCustomAggregateFunction(
aggregateFunc)
assert(substraitAggFuncName.isDefined)
return ExpressionBuilder.newScalarFunction(
functionMap,
ConverterUtils.makeFuncName(substraitAggFuncName.get, inputTypes, FunctionConfig.REQ))
}

AggregateFunctionsBuilder.create(args, aggregateFunc)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,13 @@ trait ExpressionExtensionTrait {
}
}

case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging {
object ExpressionExtensionTrait {
var expressionExtensionTransformer: ExpressionExtensionTrait =
DefaultExpressionExtensionTransformer()

/** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */
override def expressionSigList: Seq[Sig] = Seq.empty[Sig]
case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging {

/** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */
override def expressionSigList: Seq[Sig] = Seq.empty[Sig]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
*/
package org.apache.spark.sql.utils

import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer, ExpressionExtensionTrait}
import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.extension.ExpressionExtensionTrait.DefaultExpressionExtensionTransformer

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,25 @@
*/
package org.apache.spark.sql.extension

import org.apache.gluten.execution.ProjectExecTransformer
import org.apache.gluten.execution.{GlutenClickHouseWholeStageTransformerSuite, ProjectExecTransformer}
import org.apache.gluten.expression.ExpressionConverter

import org.apache.spark.SparkConf
import org.apache.spark.sql.{GlutenSQLTestsTrait, Row}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{AbstractDataType, CalendarIntervalType, DayTimeIntervalType, TypeCollection, YearMonthIntervalType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

case class CustomAdd(
left: Expression,
right: Expression,
failOnError: Boolean = SQLConf.get.ansiEnabled)
extends BinaryArithmetic {
override val failOnError: Boolean = SQLConf.get.ansiEnabled)
extends BinaryArithmetic
with CustomAdd.Compatibility {

def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)

Expand Down Expand Up @@ -69,9 +70,18 @@ case class CustomAdd(
newLeft: Expression,
newRight: Expression
): CustomAdd = copy(left = newLeft, right = newRight)

override protected val evalMode: EvalMode.Value = EvalMode.LEGACY
}

object CustomAdd {
trait Compatibility {
protected val evalMode: EvalMode.Value
}
}

class GlutenCustomerExpressionTransformerSuite extends GlutenSQLTestsTrait {
class GlutenClickhouseCustomerExpressionTransformerSuite
extends GlutenClickHouseWholeStageTransformerSuite {

override def sparkConf: SparkConf = {
super.sparkConf
Expand All @@ -92,7 +102,7 @@ class GlutenCustomerExpressionTransformerSuite extends GlutenSQLTestsTrait {
)
}

testGluten("test custom expression transformer") {
test("test custom expression transformer") {
spark
.createDataFrame(Seq((1, 1.1), (2, 2.2)))
.createOrReplaceTempView("custom_table")
Expand Down
12 changes: 1 addition & 11 deletions gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import org.apache.gluten.GlutenConfig.GLUTEN_DEFAULT_SESSION_TIMEZONE_KEY
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.events.GlutenBuildInfoEvent
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.expression.ExpressionMappings
import org.apache.gluten.extension.GlutenSessionExtensions.{GLUTEN_SESSION_EXTENSION_NAME, SPARK_SESSION_EXTS_KEY}
import org.apache.gluten.test.TestStats
import org.apache.gluten.utils.TaskListener
Expand All @@ -32,7 +31,6 @@ import org.apache.spark.listener.GlutenListenerFactory
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.execution.ui.GlutenEventUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.utils.ExpressionUtil
import org.apache.spark.util.{SparkResourceUtil, TaskResources}

import java.util
Expand Down Expand Up @@ -73,14 +71,6 @@ private[gluten] class GlutenDriverPlugin extends DriverPlugin with Logging {
BackendsApiManager.getListenerApiInstance.onDriverStart(sc, pluginContext)
GlutenListenerFactory.addToSparkListenerBus(sc)

val expressionExtensionTransformer = ExpressionUtil.extendedExpressionTransformer(
conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "")
)

if (expressionExtensionTransformer != null) {
ExpressionMappings.expressionExtensionTransformer = expressionExtensionTransformer
}

Collections.emptyMap()
}

Expand Down Expand Up @@ -275,7 +265,7 @@ private[gluten] class GlutenDriverPlugin extends DriverPlugin with Logging {
}

private[gluten] class GlutenExecutorPlugin extends ExecutorPlugin {
private val taskListeners: Seq[TaskListener] = Array(TaskResources)
private val taskListeners: Seq[TaskListener] = Seq(TaskResources)

/** Initialize the executor plugin. */
override def init(ctx: PluginContext, extraConf: util.Map[String, String]): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,16 @@ trait SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, exprs, original)
}

/** Define backend specfic expression mappings. */
/** Define backend-specific expression mappings. */
def extraExpressionMappings: Seq[Sig] = Seq.empty

/** Define backend-specific expression converter. */
def extraExpressionConverter(
substraitExprName: String,
expr: Expression,
attributeSeq: Seq[Attribute]): Option[ExpressionTransformer] =
None

/**
* Define whether the join operator is fallback because of the join operator is not supported by
* backend
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,18 @@ object AggregateFunctionsBuilder {
val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]

// First handle the custom aggregate functions
val (substraitAggFuncName, inputTypes) =
if (
ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
aggregateFunc.getClass)
) {
val (substraitAggFuncName, inputTypes) =
ExpressionMappings.expressionExtensionTransformer.buildCustomAggregateFunction(
aggregateFunc)
assert(substraitAggFuncName.isDefined)
(substraitAggFuncName.get, inputTypes)
} else {
val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc)
val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc)

// Check whether each backend supports this aggregate function.
if (
!BackendsApiManager.getValidatorApiInstance.doExprValidate(
substraitAggFuncName,
aggregateFunc)
) {
throw new GlutenNotSupportException(
s"Aggregate function not supported for $aggregateFunc.")
}
// Check whether each backend supports this aggregate function.
if (
!BackendsApiManager.getValidatorApiInstance.doExprValidate(
substraitAggFuncName,
aggregateFunc)
) {
throw new GlutenNotSupportException(s"Aggregate function not supported for $aggregateFunc.")
}

val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType)
(substraitAggFuncName, inputTypes)
}
val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType)

ExpressionBuilder.newScalarFunction(
functionMap,
Expand Down
Loading

0 comments on commit 9bb4b28

Please sign in to comment.