fix ScalaAggregator serialization to support UDAF in Dataset.observe()
toms-definity committed Dec 15, 2024
1 parent d2965ae commit 27de3a2
Showing 2 changed files with 49 additions and 3 deletions.
@@ -519,6 +519,7 @@ case class ScalaAggregator[IN, BUF, OUT](
  def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ScalaAggregator[IN, BUF, OUT] =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  @transient
  private[this] lazy val inputProjection = UnsafeProjection.create(children)

  def createAggregationBuffer(): BUF = agg.zero
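Why the one-line fix works: inputProjection caches an UnsafeProjection, a runtime-generated class that is not Java-serializable. Dataset.observe() evaluates its metric expressions through an accumulator that is serialized and shipped with each task, so a ScalaAggregator used there has to survive a serialization round trip. Marking the cached projection @transient keeps it out of the serialized form, and because it is a lazy val it is simply rebuilt on first use after deserialization on the executor. Below is a minimal, self-contained sketch of the @transient lazy val pattern; the class and object names (GeneratedHelper, MyAggregator, TransientLazyDemo) are illustrative, not Spark's.

import java.io._

// Stands in for a runtime-generated, non-serializable helper such as
// the projection returned by UnsafeProjection.create (hypothetical).
class GeneratedHelper {
  def project(x: Long): Long = x
}

case class MyAggregator(children: Seq[String]) {
  // @transient: excluded from Java serialization.
  // lazy val: re-created on first access after deserialization.
  @transient private[this] lazy val helper = new GeneratedHelper
  def apply(x: Long): Long = helper.project(x)
}

object TransientLazyDemo extends App {
  val agg = MyAggregator(Seq("id"))
  agg(1L) // forces initialization of the transient field, as on the driver

  // Round-trip through Java serialization, as task dispatch would do.
  val out = new ByteArrayOutputStream()
  new ObjectOutputStream(out).writeObject(agg)
  val in = new ObjectInputStream(new ByteArrayInputStream(out.toByteArray))
  val copy = in.readObject().asInstanceOf[MyAggregator]

  assert(copy(2L) == 2L) // helper was rebuilt lazily after deserialization
}

Without the @transient annotation, writeObject would hit the initialized non-serializable field and throw java.io.NotSerializableException, which is the failure mode the commit message describes for UDAFs in observe().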
@@ -18,18 +18,18 @@
package org.apache.spark.sql.util

import java.lang.{Long => JLong}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark._
import org.apache.spark.sql.{functions, Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.{Dataset, Encoder, Encoders, QueryTest, Row, SparkSession, functions}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand}
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{expr, udaf}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StringType
@@ -339,6 +339,51 @@ class DataFrameCallbackSuite extends QueryTest
    }
  }

test("SPARK-50581: support observe with udaf") {
withUserDefinedFunction(("someUdaf", true)) {
spark.udf.register("someUdaf", udaf(new Aggregator[JLong, JLong, JLong] {
def zero: JLong = 0L
def reduce(b: JLong, a: JLong): JLong = a + b
def merge(b1: JLong, b2: JLong): JLong = b1 + b2
def finish(r: JLong): JLong = r
def bufferEncoder: Encoder[JLong] = Encoders.LONG
def outputEncoder: Encoder[JLong] = Encoders.LONG
}))

val df = spark.range(100)

val metricMaps = ArrayBuffer.empty[Map[String, Row]]
val listener = new QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
if (qe.observedMetrics.nonEmpty) {
metricMaps += qe.observedMetrics
}
}

override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
// No-op
}
}
try {
spark.listenerManager.register(listener)

// udaf usage in observe is not working (serialization exception)
df.observe(
name = "my_metrics",
expr("someUdaf(id)").as("agg")
)
.collect()

sparkContext.listenerBus.waitUntilEmpty()
assert(metricMaps.size === 1)
assert(metricMaps.head("my_metrics") === Row(4950L))

} finally {
spark.listenerManager.unregister(listener)
}
}
}

  private def validateObservedMetrics(df: Dataset[JLong]): Unit = {
    val metricMaps = ArrayBuffer.empty[Map[String, Row]]
    val listener = new QueryExecutionListener {
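With the fix, a UDAF built from a typed Aggregator can also be passed to observe() straight from application code, with no SQL registration. A hedged sketch, not part of the commit (assumes a live SparkSession in scope as spark; the sumAgg and obs names are illustrative, while Observation is the standard helper for reading observed metrics without a listener):

import org.apache.spark.sql.{Encoder, Encoders, Observation}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

// A typed "sum" Aggregator wrapped as a UDAF and passed directly to observe().
val sumAgg = udaf(new Aggregator[Long, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Long): Long = b + a
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(r: Long): Long = r
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
})

val obs = Observation("my_metrics")
spark.range(100).observe(obs, sumAgg(col("id")).as("agg")).collect()
println(obs.get) // Map(agg -> 4950); before the fix this failed with a serialization error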
