Skip to content

Commit

Permalink
support native hive udaf
Browse files Browse the repository at this point in the history
  • Loading branch information
marin-ma committed Aug 27, 2024
1 parent c44f8a4 commit 5b27456
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 239 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import org.apache.gluten.utils.VeloxIntermediateData
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.expression.UDFResolver
import org.apache.spark.sql.hive.HiveUDAFInspector
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -681,14 +683,25 @@ object VeloxAggregateFunctionsBuilder {
aggregateFunc: AggregateFunction,
mode: AggregateMode): Long = {
val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
val sigName = AggregateFunctionsBuilder.getSubstraitFunctionName(aggregateFunc)
val (sigName, aggFunc) =
try {
(AggregateFunctionsBuilder.getSubstraitFunctionName(aggregateFunc), aggregateFunc)
} catch {
case e: GlutenNotSupportException =>
HiveUDAFInspector.getUDAFClassName(aggregateFunc) match {
case Some(udafClass) if UDFResolver.UDAFNames.contains(udafClass) =>
(udafClass, UDFResolver.getUdafExpression(udafClass)(aggregateFunc.children))
case _ => throw e
}
case e: Throwable => throw e
}

ExpressionBuilder.newScalarFunction(
functionMap,
ConverterUtils.makeFuncName(
// Substrait-to-Velox procedure will choose appropriate companion function if needed.
sigName,
VeloxIntermediateData.getInputTypes(aggregateFunc, mode == PartialMerge || mode == Final),
VeloxIntermediateData.getInputTypes(aggFunc, mode == PartialMerge || mode == Final),
FunctionConfig.REQ
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/
package org.apache.gluten.expression

import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
import org.apache.gluten.tags.{SkipTestTags, UDFTest}

import org.apache.spark.SparkConf
Expand Down Expand Up @@ -91,26 +90,50 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
.set("spark.memory.offHeap.size", "1024MB")
}

ignore("test udaf") {
val df = spark.sql("""select
| myavg(1),
| myavg(1L),
| myavg(cast(1.0 as float)),
| myavg(cast(1.0 as double)),
| mycount_if(true)
|""".stripMargin)
df.collect()
assert(
df.collect()
.sameElements(Array(Row(1.0, 1.0, 1.0, 1.0, 1L))))
}
test("test native hive udaf") {
val tbl = "test_hive_udaf_replacement"
withTempPath {
dir =>
try {
// Check native hive udaf has been registered.
val udafClass = "test.org.apache.spark.sql.MyDoubleAvg"
assert(UDFResolver.UDAFNames.contains(udafClass))

ignore("test udaf allow type conversion") {
withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "true") {
val df = spark.sql("""select myavg("1"), myavg("1.0"), mycount_if("true")""")
assert(
df.collect()
.sameElements(Array(Row(1.0, 1.0, 1L))))
spark.sql(s"""
|CREATE TEMPORARY FUNCTION my_double_avg
|AS '$udafClass'
|""".stripMargin)
spark.sql(s"""
|CREATE EXTERNAL TABLE $tbl
|LOCATION 'file://$dir'
|AS select * from values (1, '1'), (2, '2'), (3, '3')
|""".stripMargin)
val df = spark.sql(s"""select
| my_double_avg(cast(col1 as double)),
| my_double_avg(cast(col2 as double))
| from $tbl
|""".stripMargin)
val nativeImplicitConversionDF = spark.sql(s"""select
| my_double_avg(col1),
| my_double_avg(col2)
| from $tbl
|""".stripMargin)
val nativeResult = df.collect()
val nativeImplicitConversionResult = nativeImplicitConversionDF.collect()

UDFResolver.UDAFNames.remove(udafClass)
val fallbackDF = spark.sql(s"""select
| my_double_avg(cast(col1 as double)),
| my_double_avg(cast(col2 as double))
| from $tbl
|""".stripMargin)
val fallbackResult = fallbackDF.collect()
assert(nativeResult.sameElements(fallbackResult))
assert(nativeImplicitConversionResult.sameElements(fallbackResult))
} finally {
spark.sql(s"DROP TABLE IF EXISTS $tbl")
spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS my_double_avg")
}
}
}

Expand Down Expand Up @@ -205,6 +228,7 @@ class VeloxUdfSuiteLocal extends VeloxUdfSuite {
super.sparkConf
.set("spark.files", udfLibPath)
.set("spark.gluten.sql.columnar.backend.velox.udfLibraryPaths", udfLibRelativePath)
.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
}
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1162,11 +1162,11 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag
"regr_sxy",
"regr_replacement"};

auto udfFuncs = UdfLoader::getInstance()->getRegisteredUdafNames();
auto udafFuncs = UdfLoader::getInstance()->getRegisteredUdafNames();

for (const auto& funcSpec : funcSpecs) {
auto funcName = SubstraitParser::getNameBeforeDelimiter(funcSpec);
if (supportedAggFuncs.find(funcName) == supportedAggFuncs.end() && udfFuncs.find(funcName) == udfFuncs.end()) {
if (supportedAggFuncs.find(funcName) == supportedAggFuncs.end() && udafFuncs.find(funcName) == udafFuncs.end()) {
LOG_VALIDATION_MSG(funcName + " was not supported in AggregateRel.");
return false;
}
Expand Down
Loading

0 comments on commit 5b27456

Please sign in to comment.