Skip to content

Commit

Permalink
remove native registration
Browse files Browse the repository at this point in the history
  • Loading branch information
marin-ma committed Aug 30, 2024
1 parent e7caab1 commit 94771a9
Show file tree
Hide file tree
Showing 13 changed files with 53 additions and 371 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@

public class UdfJniWrapper {

public static native void getFunctionSignatures();
public static native void registerFunctionSignatures();
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.gluten.backendsapi.ListenerApi
import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, GlutenParquetWriterInjects, GlutenRowSplitter}
import org.apache.gluten.expression.UDFMappings
import org.apache.gluten.init.NativeBackendInitializer
import org.apache.gluten.udf.UdfJniWrapper
import org.apache.gluten.utils._
import org.apache.gluten.vectorized.{JniLibLoader, JniWorkspace}

Expand Down Expand Up @@ -81,6 +82,7 @@ class VeloxListenerApi extends ListenerApi with Logging {
SparkDirectoryUtil.init(conf)
UDFResolver.resolveUdfConf(conf, isDriver = true)
initialize(conf)
UdfJniWrapper.registerFunctionSignatures()
}

override def onDriverShutdown(): Unit = shutdown()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector, RasI
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter}
import org.apache.spark.sql.expression.UDFResolver

class VeloxRuleApi extends RuleApi {
import VeloxRuleApi._
Expand All @@ -47,7 +46,6 @@ private object VeloxRuleApi {
// Regular Spark rules.
injector.injectOptimizerRule(CollectRewriteRule.apply)
injector.injectOptimizerRule(HLLRewriteRule.apply)
UDFResolver.getFunctionSignatures().foreach(injector.injectFunction)
injector.injectPostHocResolutionRule(ArrowConvertorRule.apply)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,14 @@ package org.apache.spark.sql.expression

import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer, ExpressionType, GenericExpressionTransformer, Transformable}
import org.apache.gluten.udf.UdfJniWrapper
import org.apache.gluten.expression._
import org.apache.gluten.vectorized.JniWorkspace

import org.apache.spark.{SparkConf, SparkContext, SparkFiles}
import org.apache.spark.{SparkConf, SparkFiles}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, ExpressionInfo, Unevaluable}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
Expand Down Expand Up @@ -334,32 +332,6 @@ object UDFResolver extends Logging {
.mkString(",")
}

def getFunctionSignatures(): Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
val sparkContext = SparkContext.getActive.get
val sparkConf = sparkContext.conf
val udfLibPaths = sparkConf.getOption(VeloxBackendSettings.GLUTEN_VELOX_UDF_LIB_PATHS)

udfLibPaths match {
case None =>
Seq.empty
case Some(_) =>
UdfJniWrapper.getFunctionSignatures()
UDFNames.map {
name =>
(
new FunctionIdentifier(name),
new ExpressionInfo(classOf[UDFExpression].getName, name),
(e: Seq[Expression]) => getUdfExpression(name, name)(e))
}.toSeq ++ UDAFNames.map {
name =>
(
new FunctionIdentifier(name),
new ExpressionInfo(classOf[UserDefinedAggregateFunction].getName, name),
(e: Seq[Expression]) => getUdafExpression(name)(e))
}.toSeq
}
}

private def checkAllowTypeConversion: Boolean = {
SQLConf.get
.getConfString(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION, "false")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.expression.UDFResolver

import java.nio.file.Paths
import java.sql.Date

abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {

Expand Down Expand Up @@ -92,41 +91,7 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
.set("spark.memory.offHeap.size", "1024MB")
}

test("test udf") {
val df = spark.sql("""select
| myudf1(100L),
| myudf2(1),
| myudf2(1L),
| myudf3(),
| myudf3(1),
| myudf3(1, 2, 3),
| myudf3(1L),
| myudf3(1L, 2L, 3L),
| mydate(cast('2024-03-25' as date), 5)
|""".stripMargin)
assert(
df.collect()
.sameElements(Array(Row(105L, 6, 6L, 5, 6, 11, 6L, 11L, Date.valueOf("2024-03-30")))))
}

test("test udf allow type conversion") {
withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "true") {
val df = spark.sql("""select myudf1("100"), myudf1(1), mydate('2024-03-25', 5)""")
assert(
df.collect()
.sameElements(Array(Row(105L, 6L, Date.valueOf("2024-03-30")))))
}

withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "false") {
assert(
spark
.sql("select mydate2('2024-03-25', 5)")
.collect()
.sameElements(Array(Row(Date.valueOf("2024-03-30")))))
}
}

test("test udaf") {
ignore("test udaf") {
val df = spark.sql("""select
| myavg(1),
| myavg(1L),
Expand All @@ -140,7 +105,7 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
.sameElements(Array(Row(1.0, 1.0, 1.0, 1.0, 1L))))
}

test("test udaf allow type conversion") {
ignore("test udaf allow type conversion") {
withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "true") {
val df = spark.sql("""select myavg("1"), myavg("1.0"), mycount_if("true")""")
assert(
Expand All @@ -149,7 +114,7 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
}
}

test("test hive udf replacement") {
test("test native hive udf") {
val tbl = "test_hive_udf_replacement"
withTempPath {
dir =>
Expand All @@ -169,12 +134,15 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
|AS 'org.apache.spark.sql.hive.execution.UDFStringString'
|""".stripMargin)

val nativeResultWithImplicitConversion =
spark.sql(s"""SELECT hive_string_string(col1, 'a') FROM $tbl""").collect()
val nativeResult =
spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""").collect()
// Unregister native hive udf to fallback.
UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString")
val fallbackResult =
spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""").collect()
assert(nativeResultWithImplicitConversion.sameElements(fallbackResult))
assert(nativeResult.sameElements(fallbackResult))

// Add an unimplemented udf to the map to test fallback of registered native hive udf.
Expand Down
2 changes: 1 addition & 1 deletion cpp/velox/jni/JniUdf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void gluten::finalizeVeloxJniUDF(JNIEnv* env) {
env->DeleteGlobalRef(udfResolverClass);
}

void gluten::jniGetFunctionSignatures(JNIEnv* env) {
void gluten::jniRegisterFunctionSignatures(JNIEnv* env) {
auto udfLoader = gluten::UdfLoader::getInstance();

const auto& signatures = udfLoader->getRegisteredUdfSignatures();
Expand Down
2 changes: 1 addition & 1 deletion cpp/velox/jni/JniUdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ void initVeloxJniUDF(JNIEnv* env);

void finalizeVeloxJniUDF(JNIEnv* env);

void jniGetFunctionSignatures(JNIEnv* env);
void jniRegisterFunctionSignatures(JNIEnv* env);

} // namespace gluten
4 changes: 2 additions & 2 deletions cpp/velox/jni/VeloxJniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ JNIEXPORT void JNICALL Java_org_apache_gluten_init_NativeBackendInitializer_shut
JNI_METHOD_END()
}

JNIEXPORT void JNICALL Java_org_apache_gluten_udf_UdfJniWrapper_getFunctionSignatures( // NOLINT
JNIEXPORT void JNICALL Java_org_apache_gluten_udf_UdfJniWrapper_registerFunctionSignatures( // NOLINT
JNIEnv* env,
jclass) {
JNI_METHOD_START
gluten::jniGetFunctionSignatures(env);
gluten::jniRegisterFunctionSignatures(env);
JNI_METHOD_END()
}

Expand Down
21 changes: 13 additions & 8 deletions cpp/velox/tests/MyUdfTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,17 @@ class MyUdfTest : public FunctionBaseTest {
}
};

TEST_F(MyUdfTest, myudf1) {
const auto myudf1 = [&](const int64_t& number) {
return evaluateOnce<int64_t>("myudf1(c0)", BIGINT(), std::make_optional(number));
};

EXPECT_EQ(5, myudf1(0));
EXPECT_EQ(105, myudf1(100));
EXPECT_EQ(3147483652, myudf1(3147483647)); // int64
TEST_F(MyUdfTest, hivestringstring) {
auto map = facebook::velox::exec::vectorFunctionFactories();
const std::string candidate = {"org.apache.spark.sql.hive.execution.UDFStringString"};
ASSERT(map.withRLock([&candidate](auto& self) -> bool {
auto iter = self.find(candidate);
std::unordered_map<std::string, std::string> values;
const facebook::velox::core::QueryConfig config(std::move(values));
iter->second.factory(
candidate,
{facebook::velox::exec::VectorFunctionArg{facebook::velox::VARCHAR()},
facebook::velox::exec::VectorFunctionArg{facebook::velox::VARCHAR()}},
config) != nullptr;
});)
}
3 changes: 0 additions & 3 deletions cpp/velox/udf/examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,3 @@ target_link_libraries(myudf velox)

add_library(myudaf SHARED "MyUDAF.cc")
target_link_libraries(myudaf velox)

add_executable(test_myudf "TestMyUDF.cc")
target_link_libraries(test_myudf velox)
Loading

0 comments on commit 94771a9

Please sign in to comment.