From 27de3a298357e0a128b7663cce9f5d71b67f4250 Mon Sep 17 00:00:00 2001
From: tomsisso
Date: Sun, 15 Dec 2024 15:48:43 +0200
Subject: [PATCH] Fix ScalaAggregator serialization to support UDAF in
 Dataset.observe()

---
 .../spark/sql/execution/aggregate/udaf.scala  |  1 +
 .../sql/util/DataFrameCallbackSuite.scala     | 51 +++++++++++++++++--
 2 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 09d9915022a65..c5ce1ea3991fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -519,6 +519,7 @@ case class ScalaAggregator[IN, BUF, OUT](
   def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ScalaAggregator[IN, BUF, OUT] =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
+  @transient
   private[this] lazy val inputProjection = UnsafeProjection.create(children)
 
   def createAggregationBuffer(): BUF = agg.zero
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index be91f5e789e2c..657d78dda3cdd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -18,11 +18,9 @@
 package org.apache.spark.sql.util
 
 import java.lang.{Long => JLong}
-
 import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark._
-import org.apache.spark.sql.{functions, Dataset, QueryTest, Row, SparkSession}
+import org.apache.spark.sql.{Dataset, Encoder, Encoders, QueryTest, Row, SparkSession, functions}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
@@ -30,6 +28,8 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand}
 import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.functions.{expr, udaf}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StringType
@@ -339,6 +339,51 @@ class DataFrameCallbackSuite extends QueryTest
     }
   }
 
+  test("SPARK-50581: support observe with udaf") {
+    withUserDefinedFunction(("someUdaf", true)) {
+      spark.udf.register("someUdaf", udaf(new Aggregator[JLong, JLong, JLong] {
+        def zero: JLong = 0L
+        def reduce(b: JLong, a: JLong): JLong = a + b
+        def merge(b1: JLong, b2: JLong): JLong = b1 + b2
+        def finish(r: JLong): JLong = r
+        def bufferEncoder: Encoder[JLong] = Encoders.LONG
+        def outputEncoder: Encoder[JLong] = Encoders.LONG
+      }))
+
+      val df = spark.range(100)
+
+      val metricMaps = ArrayBuffer.empty[Map[String, Row]]
+      val listener = new QueryExecutionListener {
+        override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+          if (qe.observedMetrics.nonEmpty) {
+            metricMaps += qe.observedMetrics
+          }
+        }
+
+        override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+          // No-op
+        }
+      }
+      try {
+        spark.listenerManager.register(listener)
+
+        // Before this fix, a UDAF in observe() failed with a serialization exception.
+        df.observe(
+          name = "my_metrics",
+          expr("someUdaf(id)").as("agg")
+        )
+          .collect()
+
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(metricMaps.size === 1)
+        assert(metricMaps.head("my_metrics") === Row(4950L))
+
+      } finally {
+        spark.listenerManager.unregister(listener)
+      }
+    }
+  }
+
   private def validateObservedMetrics(df: Dataset[JLong]): Unit = {
     val metricMaps = ArrayBuffer.empty[Map[String, Row]]
     val listener = new QueryExecutionListener {
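
Note on why the one-line fix works: ScalaAggregator caches an UnsafeProjection built
from its children, and generated projections are not Java-serializable. Once the lazy
val had been initialized, serializing the aggregator (which observe() requires when the
metric expression is shipped off the driver) failed. Marking the lazy val @transient
drops the cached projection from the serialized form; it is rebuilt from `children` on
first use after deserialization. The following is a minimal standalone sketch of that
pattern in plain Scala, not Spark code: `Projection` and `Agg` are hypothetical
stand-ins for UnsafeProjection and ScalaAggregator.

import java.io._

// Hypothetical stand-in for a generated UnsafeProjection: cheap to rebuild,
// but NOT java.io.Serializable.
class Projection {
  def apply(row: Long): Long = row
}

// Mirrors the shape of ScalaAggregator: a serializable case class that
// caches a non-serializable helper in a lazy val.
case class Agg(children: Seq[String]) {
  // Without @transient, serializing an Agg whose lazy val is already
  // initialized throws java.io.NotSerializableException: Projection.
  // With @transient, the cached value is skipped during serialization and
  // re-created on first use after deserialization.
  @transient private lazy val inputProjection = new Projection

  def project(row: Long): Long = inputProjection(row)
}

object TransientLazyValDemo {
  def main(args: Array[String]): Unit = {
    val agg = Agg(Seq("id"))
    agg.project(1L) // force initialization of the lazy val

    // Java-serialization round trip, analogous to Spark shipping the
    // aggregate expression for observed metrics.
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(agg)
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val copy = in.readObject().asInstanceOf[Agg]

    assert(copy.project(2L) == 2L) // the lazy val was transparently rebuilt
    println("round trip ok")
  }
}

The trade-off of the pattern is only that each deserialized copy regenerates the
projection once, on first use, which matches how the other cached fields of the
aggregate machinery behave.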