Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Zouxxyy committed Dec 19, 2024
1 parent 19b59ee commit 7fa4f5f
Showing 1 changed file with 46 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import org.apache.paimon.spark.catalog.functions.PaimonFunctions

import org.apache.spark.sql.Row
import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, StructType}

class PaimonFunctionTest extends PaimonHiveTestBase {

Expand Down Expand Up @@ -85,4 +87,48 @@ class PaimonFunctionTest extends PaimonHiveTestBase {
)
}
}

test("Paimon function: show and load function with SparkGenericCatalog") {
sql(s"USE $sparkCatalogName")
sql(s"USE $hiveDbName")
sql("CREATE FUNCTION myIntSum AS 'org.apache.paimon.spark.sql.MyIntSum'")
checkAnswer(
sql(s"SHOW FUNCTIONS FROM $hiveDbName LIKE 'myIntSum'"),
Row("spark_catalog.test_hive.myintsum"))

withTable("t") {
sql("CREATE TABLE t (id INT)")
sql("INSERT INTO t VALUES (1), (2), (3)")
checkAnswer(sql("SELECT myIntSum(id) FROM t"), Row(6))
}

sql("DROP FUNCTION myIntSum")
checkAnswer(sql(s"SHOW FUNCTIONS FROM $hiveDbName LIKE 'myIntSum'"), Seq.empty)
}
}

private class MyIntSum extends UserDefinedAggregateFunction {
override def inputSchema: StructType = new StructType().add("input", IntegerType)

override def bufferSchema: StructType = new StructType().add("buffer", IntegerType)

override def dataType: DataType = IntegerType

override def deterministic: Boolean = true

override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0)
}

override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0, input.getInt(0) + buffer.getInt(0))
}

override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0))
}

override def evaluate(buffer: Row): Any = {
buffer.getInt(0)
}
}

0 comments on commit 7fa4f5f

Please sign in to comment.