[SPARK-52006][SQL][CORE] Exclude CollectMetricsExec accumulator from Spark UI + event logs + metric heartbeats #50812

Closed (4 commits)
Changes from all commits
@@ -88,5 +88,7 @@ private[spark] object InternalAccumulator {
val RECORDS_READ = INPUT_METRICS_PREFIX + "recordsRead"
}

val COLLECT_METRICS_ACCUMULATOR = METRICS_PREFIX + "collectMetricsAccumulator"

// scalastyle:on
}
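
Assuming METRICS_PREFIX keeps its usual value of "internal.metrics." (the prefix is not shown in this hunk), the new constant resolves to the accumulator name that the event-log exclude list further below matches on:

// Sketch only; the METRICS_PREFIX value is an assumption, not part of this diff.
val METRICS_PREFIX = "internal.metrics."
val COLLECT_METRICS_ACCUMULATOR = METRICS_PREFIX + "collectMetricsAccumulator"
// => "internal.metrics.collectMetricsAccumulator"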
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -1271,12 +1271,13 @@ private[spark] class Executor(
if (taskRunner.task != null) {
taskRunner.task.metrics.mergeShuffleReadMetrics()
taskRunner.task.metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
val accumulatorsToReport = {
if (HEARTBEAT_DROP_ZEROES) {
taskRunner.task.metrics.accumulators().filterNot(_.isZero)
} else {
taskRunner.task.metrics.accumulators()
}
}.filterNot(_.excludeFromHeartbeat)
accumUpdates += ((taskRunner.taskId, accumulatorsToReport))
}
}
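
Net effect of the change above: zero-valued accumulator updates are dropped only when the drop-zeroes setting is enabled, while accumulators that opt out via excludeFromHeartbeat are always dropped from heartbeats. A standalone sketch of that rule (illustration only, not the Executor code itself):

import org.apache.spark.util.AccumulatorV2

// Optionally drop zero-valued accumulators, then unconditionally drop any
// accumulator that excludes itself from heartbeat reporting.
def accumulatorsToReport(
    accums: Seq[AccumulatorV2[_, _]],
    dropZeroes: Boolean): Seq[AccumulatorV2[_, _]] = {
  val kept = if (dropZeroes) accums.filterNot(_.isZero) else accums
  kept.filterNot(_.excludeFromHeartbeat)
}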
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -45,6 +45,8 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
private[spark] var metadata: AccumulatorMetadata = _
private[this] var atDriverSide = true

def excludeFromHeartbeat: Boolean = false

private[spark] def register(
sc: SparkContext,
name: Option[String] = None,
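
The new hook defaults to false, so existing accumulators keep reporting through heartbeats; an accumulator opts out by overriding it, as AggregatingAccumulator does further below. A minimal sketch of a custom accumulator using the hook; the class name and payload are hypothetical, not part of this PR, and in this PR the hook is only consulted on the heartbeat path, not for the accumulator updates sent back with task results:

import org.apache.spark.util.AccumulatorV2

// Hypothetical accumulator carrying a potentially large payload; it opts out
// of per-heartbeat reporting via the new excludeFromHeartbeat hook.
class LargePayloadAccumulator extends AccumulatorV2[String, Seq[String]] {
  private var buffer: Seq[String] = Seq.empty

  override def excludeFromHeartbeat: Boolean = true

  override def isZero: Boolean = buffer.isEmpty
  override def copy(): LargePayloadAccumulator = {
    val acc = new LargePayloadAccumulator
    acc.buffer = buffer
    acc
  }
  override def reset(): Unit = buffer = Seq.empty
  override def add(v: String): Unit = buffer = buffer :+ v
  override def merge(other: AccumulatorV2[String, Seq[String]]): Unit =
    buffer = buffer ++ other.value
  override def value: Seq[String] = buffer
}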
5 changes: 4 additions & 1 deletion core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -488,7 +488,10 @@ private[spark] object JsonProtocol extends JsonUtils {
g.writeEndObject()
}

private[util] val accumulableExcludeList = Set(
InternalAccumulator.UPDATED_BLOCK_STATUSES,
InternalAccumulator.COLLECT_METRICS_ACCUMULATOR
)

private[this] val taskMetricAccumulableNames = TaskMetrics.empty.nameToAccums.keySet.toSet

@@ -44,6 +44,8 @@ class AggregatingAccumulator private(
assert(bufferSchema.size == updateExpressions.size)
assert(mergeExpressions == null || bufferSchema.size == mergeExpressions.size)

override def excludeFromHeartbeat: Boolean = true

@transient
private var joinedRow: JoinedRow = _

@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution

import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -38,10 +38,14 @@ case class CollectMetricsExec(

private lazy val accumulator: AggregatingAccumulator = {
val acc = AggregatingAccumulator(metricExpressions, child.output)
acc.register(sparkContext, Option("Collected metrics"))
acc.register(sparkContext, Option(CollectMetricsExec.ACCUMULATOR_NAME))
acc
}

private[sql] def accumulatorId: Long = {
accumulator.id
}

val metricsSchema: StructType = {
DataTypeUtils.fromAttributes(metricExpressions.map(_.toAttribute))
}
@@ -95,6 +99,9 @@ case class CollectMetricsExec(
}

object CollectMetricsExec extends AdaptiveSparkPlanHelper {

val ACCUMULATOR_NAME: String = InternalAccumulator.COLLECT_METRICS_ACCUMULATOR

/**
* Recursively collect all collected metrics from a query tree.
*/
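
For reference, the accumulator renamed and hidden here backs the Dataset.observe API; the observed values still reach the driver through the normal task-result path and are exposed as QueryExecution.observedMetrics, which the test below exercises. A minimal driver-side sketch, assuming an active SparkSession named spark:

import org.apache.spark.sql.functions.{max, min}
import spark.implicits._  // for the $"..." column syntax

// Attach a named observation; CollectMetricsExec aggregates it through the
// accumulator above. After this PR that accumulator no longer appears in the
// UI, event logs, or executor heartbeats.
val observed = spark.range(0, 100)
  .observe("my_event", max($"id").as("max_val"), min($"id").as("min_val"))
observed.collect()
// The observed row can then be read from QueryExecution.observedMetrics("my_event"),
// e.g. inside a QueryExecutionListener or a SparkListenerSQLExecutionEnd handler.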
@@ -18,19 +18,24 @@
package org.apache.spark.sql.util

import java.lang.{Long => JLong}
import java.util.concurrent.{CopyOnWriteArrayList, CountDownLatch, TimeUnit}

import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._

import org.apache.spark._
import org.apache.spark.internal.config.{EXECUTOR_HEARTBEAT_DROP_ZERO_ACCUMULATOR_UPDATES, EXECUTOR_HEARTBEAT_INTERVAL}
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerExecutorMetricsUpdate}
import org.apache.spark.sql.{functions, Encoder, Encoders, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.classic.Dataset
import org.apache.spark.sql.execution.{CollectMetricsExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand}
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -42,6 +47,10 @@ class DataFrameCallbackSuite extends QueryTest
import testImplicits._
import functions._

override protected def sparkConf: SparkConf = {
super.sparkConf.set(EXECUTOR_HEARTBEAT_DROP_ZERO_ACCUMULATOR_UPDATES, false)
}

test("execute callback functions when a DataFrame action finished successfully") {
val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]
val listener = new QueryExecutionListener {
@@ -341,6 +350,55 @@ class DataFrameCallbackSuite extends QueryTest
}
}

test("SPARK-52006: executor heartbeat should exclude observable metrics") {
val metricMaps = ArrayBuffer.empty[Map[String, Row]]
@volatile var accumulatorId = 0L
val listener = new SparkListener {
override def onExecutorMetricsUpdate(msg: SparkListenerExecutorMetricsUpdate): Unit = {
HeartbeatMonitor.heartbeatReceived(msg)
}

override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
case e: SparkListenerSQLExecutionEnd =>
metricMaps += e.qe.observedMetrics
val accumulators = e.qe.executedPlan
.collect { case exec: CollectMetricsExec => exec.accumulatorId }
assert(accumulators.length === 1)
accumulatorId = accumulators.head
case _ => // Ignore
}
}
sparkContext.listenerBus.waitUntilEmpty()
sparkContext.addSparkListener(listener)

try {
val heartbeatInterval = sparkContext.getConf.get(EXECUTOR_HEARTBEAT_INTERVAL)
val df = spark.range(0, 100, 1, 1)
.mapPartitions { iter =>
TaskContext.get().addTaskCompletionListener[Unit] { _ =>
// Wait until a heartbeat has been sent; times out after 3x the heartbeat interval (~30s by default)
assert(HeartbeatMonitor.await(heartbeatInterval * 3, TimeUnit.MILLISECONDS))
}
iter
}.toDF("id")
.observe(
name = "my_event",
max($"id").as("max_val"),
percentile_approx($"id", lit(0.5), lit(100)),
percentile_approx($"id", lit(0.5), lit(100)),
min($"id").as("min_val"))
df.collect()
sparkContext.listenerBus.waitUntilEmpty()

val msgs = HeartbeatMonitor.msgs.asScala
val accumulatorIds = msgs.flatMap(_.accumUpdates.flatMap(_._4)).map(_.id).toSet
assert(accumulatorId != 0)
assert(msgs.nonEmpty && !accumulatorIds.contains(accumulatorId))
} finally {
sparkContext.removeSparkListener(listener)
}
}

test("SPARK-50581: support observe with udaf") {
withUserDefinedFunction(("someUdaf", true)) {
spark.udf.register("someUdaf", functions.udaf(new Aggregator[JLong, JLong, JLong] {
@@ -452,3 +510,20 @@ case class ErrorTestCommand(foo: String) extends LeafRunnableCommand {
override def run(sparkSession: org.apache.spark.sql.SparkSession): Seq[Row] =
throw new java.lang.Error(foo)
}

/** Singleton utils for testing SPARK-52006 */
object HeartbeatMonitor {
private val heartbeatLatch = new CountDownLatch(1)
val msgs = new CopyOnWriteArrayList[SparkListenerExecutorMetricsUpdate]()

def heartbeatReceived(msg: SparkListenerExecutorMetricsUpdate): Unit = {
if (msg.accumUpdates.nonEmpty) {
msgs.add(msg)
heartbeatLatch.countDown()
}
}

def await(timeout: Long, timeUnit: TimeUnit): Boolean = {
heartbeatLatch.await(timeout, timeUnit)
}
}