Commit 8839984

pavle-martinovic_data authored and cloud-fan committed
[SPARK-51655][SQL] Fix metric collection in UnionLoopExec and add test
### What changes were proposed in this pull request?

Fix the metrics collection method for recursive CTEs. Also, add a new metric that tracks the number of rows the anchor returns.

### Why are the changes needed?

The current way of collecting metrics for recursive CTEs is incorrect.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New test in SQLMetricsSuite.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50449 from Pajaraja/pavle-martinovic_data/MetricFixAndTestRecursiveCTE.

Authored-by: pavle-martinovic_data <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 4e5ed45 commit 8839984
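
For context, this is the kind of query the fix affects. Below is a minimal sketch (not part of the commit) of running a recursive CTE and reading the `UnionLoopExec` metrics off the executed plan; it assumes a Spark build where recursive CTEs are available and a live `SparkSession` named `spark`.

```scala
import org.apache.spark.sql.execution.UnionLoopExec

val df = spark.sql(
  """WITH RECURSIVE t(n) AS (
    |  VALUES 1, 2
    |  UNION ALL
    |  SELECT n + 1 FROM t WHERE n < 20
    |)
    |SELECT * FROM t""".stripMargin)
df.collect()  // run the query so the driver-side metrics are populated

// After this change, "numAnchorOutputRows" shows up next to
// "numOutputRows" and "numIterations" on the UnionLoop node.
df.queryExecution.executedPlan.collect { case loop: UnionLoopExec => loop }
  .foreach(_.metrics.foreach { case (key, metric) =>
    println(s"$key = ${metric.value}")
  })
```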

File tree

2 files changed (+25 −2 lines)


Diff for: sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala

+7 −2
```diff
@@ -91,7 +91,8 @@ case class UnionLoopExec(
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
-    "numIterations" -> SQLMetrics.createMetric(sparkContext, "number of recursive iterations"))
+    "numIterations" -> SQLMetrics.createMetric(sparkContext, "number of recursive iterations"),
+    "numAnchorOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of anchor output rows"))
 
   /**
    * This function executes the plan (optionally with appended limit node) and caches the result,
@@ -123,6 +124,7 @@
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     val numOutputRows = longMetric("numOutputRows")
     val numIterations = longMetric("numIterations")
+    val numAnchorOutputRows = longMetric("numAnchorOutputRows")
     val levelLimit = conf.getConf(SQLConf.CTE_RECURSION_LEVEL_LIMIT)
     val rowLimit = conf.getConf(SQLConf.CTE_RECURSION_ROW_LIMIT)
 
@@ -136,6 +138,8 @@
 
     var (prevDF, prevCount) = executeAndCacheAndCount(anchor, currentLimit)
 
+    numAnchorOutputRows += prevCount
+
     var currentLevel = 1
 
     var currentNumRows = 0
@@ -177,7 +181,6 @@
       // Update metrics
       numOutputRows += prevCount
       numIterations += 1
-      SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
 
       if (!limitReached) {
        // the current plan is created by substituting UnionLoopRef node with the project node of
@@ -200,6 +203,8 @@
       }
     }
 
+    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+
     if (unionChildren.isEmpty) {
       new EmptyRDD[InternalRow](sparkContext)
     } else {
```
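
Taken together, the hunks above make two changes: the anchor's row count now feeds the dedicated `numAnchorOutputRows` metric, and the `postDriverMetricUpdates` call moves from the loop body to a single call after the loop, so the driver posts the final totals once rather than re-posting mid-loop. The following standalone sketch (my illustration, not Spark code; `MetricPostingSketch` is a made-up name) mirrors that accumulate-then-post-once pattern with the same metric set:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.metric.SQLMetrics

object MetricPostingSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("metric-sketch").getOrCreate()
  val sc = spark.sparkContext

  // The same metric set UnionLoopExec declares after this change.
  val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sc, "number of output rows"),
    "numIterations" -> SQLMetrics.createMetric(sc, "number of recursive iterations"),
    "numAnchorOutputRows" -> SQLMetrics.createMetric(sc, "number of anchor output rows"))

  // Accumulate on the driver while looping, as UnionLoopExec does...
  metrics("numAnchorOutputRows") += 2
  (1 to 20).foreach { _ => metrics("numIterations") += 1 }

  // ...and post once after the loop is done. Inside a real query this call
  // needs the SQL execution id, so it is left commented out here:
  // SQLMetrics.postDriverMetricUpdates(sc, executionId, metrics.values.toSeq)

  metrics.foreach { case (key, metric) => println(s"$key = ${metric.value}") }
  spark.stop()
}
```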

Diff for: sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

+18
```diff
@@ -107,6 +107,24 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
     }
   }
 
+  test("Recursive CTEs metrics") {
+    val df = sql("""WITH RECURSIVE t(n) AS(
+      | VALUES 1, 2
+      | UNION ALL
+      | SELECT n+1 FROM t WHERE n < 20
+      | )
+      | SELECT * FROM t""".stripMargin)
+    val unionLoopExec = df.queryExecution.executedPlan.collect {
+      case ule: UnionLoopExec => ule
+    }
+    sparkContext.listenerBus.waitUntilEmpty()
+    assert(unionLoopExec.size == 1)
+    val expected = Map("number of output rows" -> 39L, "number of recursive iterations" -> 20L,
+      "number of anchor output rows" -> 2L)
+    testSparkPlanMetrics(df, 22, Map(
+      2L -> (("UnionLoop", expected))))
+  }
+
   test("Filter metrics") {
     // Assume the execution plan is
     // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0)
```
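
For reference, the expected values check out by hand: the anchor `VALUES 1, 2` contributes 2 rows; iterations 1 through 18 each produce 2 rows ({2, 3} up through {19, 20}), iteration 19 produces the single row {20}, and iteration 20 produces nothing and ends the loop. That gives 2 + 18 × 2 + 1 = 39 output rows over 20 iterations (the final, empty one included, matching the expected value above). A tiny standalone trace of that recursion (plain Scala, no Spark needed; `RecursiveCteTrace` is a made-up name) reproduces the same totals:

```scala
object RecursiveCteTrace extends App {
  var frontier = Seq(1, 2)        // anchor output: "number of anchor output rows" = 2
  var totalRows = frontier.size   // "number of output rows" includes the anchor
  var iterations = 0
  while (frontier.nonEmpty) {
    frontier = frontier.filter(_ < 20).map(_ + 1)  // SELECT n+1 FROM t WHERE n < 20
    totalRows += frontier.size
    iterations += 1               // the final, empty iteration is still counted
  }
  // Prints: iterations = 20, totalRows = 39
  println(s"iterations = $iterations, totalRows = $totalRows")
}
```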
