diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
index 505634d5bb048..45bc0962efa6f 100644
--- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
+++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
@@ -88,5 +88,7 @@ private[spark] object InternalAccumulator {
     val RECORDS_READ = INPUT_METRICS_PREFIX + "recordsRead"
   }
 
+  val COLLECT_METRICS_ACCUMULATOR = METRICS_PREFIX + "collectMetricsAccumulator"
+
   // scalastyle:on
 }
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 97754d5457bec..3b066e15386b8 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -1271,12 +1271,13 @@ private[spark] class Executor(
       if (taskRunner.task != null) {
         taskRunner.task.metrics.mergeShuffleReadMetrics()
         taskRunner.task.metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
-        val accumulatorsToReport =
+        val accumulatorsToReport = {
           if (HEARTBEAT_DROP_ZEROES) {
             taskRunner.task.metrics.accumulators().filterNot(_.isZero)
           } else {
             taskRunner.task.metrics.accumulators()
           }
+        }.filterNot(_.excludeFromHeartbeat)
         accumUpdates += ((taskRunner.taskId, accumulatorsToReport))
       }
     }
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index 383a89d40ecee..1745498456213 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -45,6 +45,8 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
   private[spark] var metadata: AccumulatorMetadata = _
   private[this] var atDriverSide = true
 
+  def excludeFromHeartbeat: Boolean = false
+
   private[spark] def register(
       sc: SparkContext,
       name: Option[String] = None,
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index df809f4fad745..d5d7c449f23d6 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -488,7 +488,10 @@ private[spark] object JsonProtocol extends JsonUtils {
     g.writeEndObject()
   }
 
-  private[util] val accumulableExcludeList = Set(InternalAccumulator.UPDATED_BLOCK_STATUSES)
+  private[util] val accumulableExcludeList = Set(
+    InternalAccumulator.UPDATED_BLOCK_STATUSES,
+    InternalAccumulator.COLLECT_METRICS_ACCUMULATOR
+  )
 
   private[this] val taskMetricAccumulableNames = TaskMetrics.empty.nameToAccums.keySet.toSet
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
index 667d1a67b3932..bb785bfd18f02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
@@ -44,6 +44,8 @@ class AggregatingAccumulator private(
   assert(bufferSchema.size == updateExpressions.size)
   assert(mergeExpressions == null || bufferSchema.size == mergeExpressions.size)
 
+  override def excludeFromHeartbeat: Boolean = true
+
   @transient
   private var joinedRow: JoinedRow = _
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
index 0a487bac77696..fd89be2368af6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.sql.execution
 
-import org.apache.spark.TaskContext
+import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -38,10 +38,14 @@ case class CollectMetricsExec(
 
   private lazy val accumulator: AggregatingAccumulator = {
     val acc = AggregatingAccumulator(metricExpressions, child.output)
-    acc.register(sparkContext, Option("Collected metrics"))
+    acc.register(sparkContext, Option(CollectMetricsExec.ACCUMULATOR_NAME))
     acc
   }
 
+  private[sql] def accumulatorId: Long = {
+    accumulator.id
+  }
+
   val metricsSchema: StructType = {
     DataTypeUtils.fromAttributes(metricExpressions.map(_.toAttribute))
   }
@@ -95,6 +99,9 @@ case class CollectMetricsExec(
 }
 
 object CollectMetricsExec extends AdaptiveSparkPlanHelper {
+
+  val ACCUMULATOR_NAME: String = InternalAccumulator.COLLECT_METRICS_ACCUMULATOR
+
   /**
    * Recursively collect all collected metrics from a query tree.
    */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 099a09d7784d7..1ec9aca857e22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -18,19 +18,24 @@
 package org.apache.spark.sql.util
 
 import java.lang.{Long => JLong}
+import java.util.concurrent.{CopyOnWriteArrayList, CountDownLatch, TimeUnit}
 
 import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
 
 import org.apache.spark._
+import org.apache.spark.internal.config.{EXECUTOR_HEARTBEAT_DROP_ZERO_ACCUMULATOR_UPDATES, EXECUTOR_HEARTBEAT_INTERVAL}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerExecutorMetricsUpdate}
 import org.apache.spark.sql.{functions, Encoder, Encoders, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.classic.Dataset
-import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{CollectMetricsExec, QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand}
 import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -42,6 +47,10 @@ class DataFrameCallbackSuite extends QueryTest
   import testImplicits._
   import functions._
 
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf.set(EXECUTOR_HEARTBEAT_DROP_ZERO_ACCUMULATOR_UPDATES, false)
+  }
+
   test("execute callback functions when a DataFrame action finished successfully") {
     val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]
     val listener = new QueryExecutionListener {
@@ -341,6 +350,55 @@
     }
   }
 
+  test("SPARK-52006: executor heartbeat should exclude observable metrics") {
+    val metricMaps = ArrayBuffer.empty[Map[String, Row]]
+    @volatile var accumulatorId = 0L
+    val listener = new SparkListener {
+      override def onExecutorMetricsUpdate(msg: SparkListenerExecutorMetricsUpdate): Unit = {
+        HeartbeatMonitor.heartbeatReceived(msg)
+      }
+
+      override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+        case e: SparkListenerSQLExecutionEnd =>
+          metricMaps += e.qe.observedMetrics
+          val accumulators = e.qe.executedPlan
+            .collect { case exec: CollectMetricsExec => exec.accumulatorId }
+          assert(accumulators.length === 1)
+          accumulatorId = accumulators.head
+        case _ => // Ignore
+      }
+    }
+    sparkContext.listenerBus.waitUntilEmpty()
+    sparkContext.addSparkListener(listener)
+
+    try {
+      val heartbeatInterval = sparkContext.getConf.get(EXECUTOR_HEARTBEAT_INTERVAL)
+      val df = spark.range(0, 100, 1, 1)
+        .mapPartitions { iter =>
+          TaskContext.get().addTaskCompletionListener[Unit] { _ =>
+            // Wait for heartbeat sent, 30s timeout by default
+            assert(HeartbeatMonitor.await(heartbeatInterval * 3, TimeUnit.MILLISECONDS))
+          }
+          iter
+        }.toDF("id")
+        .observe(
+          name = "my_event",
+          max($"id").as("max_val"),
+          percentile_approx($"id", lit(0.5), lit(100)),
+          percentile_approx($"id", lit(0.5), lit(100)),
+          min($"id").as("min_val"))
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty()
+
+      val msgs = HeartbeatMonitor.msgs.asScala
+      val accumulatorIds = msgs.flatMap(_.accumUpdates.flatMap(_._4)).map(_.id).toSet
+      assert(accumulatorId != 0)
+      assert(msgs.nonEmpty && !accumulatorIds.contains(accumulatorId))
+    } finally {
+      sparkContext.removeSparkListener(listener)
+    }
+  }
+
   test("SPARK-50581: support observe with udaf") {
     withUserDefinedFunction(("someUdaf", true)) {
       spark.udf.register("someUdaf", functions.udaf(new Aggregator[JLong, JLong, JLong] {
@@ -452,3 +510,20 @@ case class ErrorTestCommand(foo: String) extends LeafRunnableCommand {
   override def run(sparkSession: org.apache.spark.sql.SparkSession): Seq[Row] =
     throw new java.lang.Error(foo)
 }
+
+/** Singleton utils for testing SPARK-52006 */
+object HeartbeatMonitor {
+  private val heartbeatLatch = new CountDownLatch(1)
+  val msgs = new CopyOnWriteArrayList[SparkListenerExecutorMetricsUpdate]()
+
+  def heartbeatReceived(msg: SparkListenerExecutorMetricsUpdate): Unit = {
+    if (msg.accumUpdates.nonEmpty) {
+      msgs.add(msg)
+      heartbeatLatch.countDown()
+    }
+  }
+
+  def await(timeout: Long, timeUnit: TimeUnit): Boolean = {
+    heartbeatLatch.await(timeout, timeUnit)
+  }
+}