diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
index 0694d02550a90..561f9cd72efbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
@@ -91,7 +91,8 @@ case class UnionLoopExec(
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
-    "numIterations" -> SQLMetrics.createMetric(sparkContext, "number of recursive iterations"))
+    "numIterations" -> SQLMetrics.createMetric(sparkContext, "number of recursive iterations"),
+    "numAnchorOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of anchor output rows"))
 
   /**
    * This function executes the plan (optionally with appended limit node) and caches the result,
@@ -123,6 +124,7 @@ case class UnionLoopExec(
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     val numOutputRows = longMetric("numOutputRows")
     val numIterations = longMetric("numIterations")
+    val numAnchorOutputRows = longMetric("numAnchorOutputRows")
     val levelLimit = conf.getConf(SQLConf.CTE_RECURSION_LEVEL_LIMIT)
     val rowLimit = conf.getConf(SQLConf.CTE_RECURSION_ROW_LIMIT)
 
@@ -136,6 +138,8 @@ case class UnionLoopExec(
 
     var (prevDF, prevCount) = executeAndCacheAndCount(anchor, currentLimit)
 
+    numAnchorOutputRows += prevCount
+
     var currentLevel = 1
 
     var currentNumRows = 0
@@ -177,7 +181,6 @@ case class UnionLoopExec(
       // Update metrics
      numOutputRows += prevCount
      numIterations += 1
-      SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
 
       if (!limitReached) {
         // the current plan is created by substituting UnionLoopRef node with the project node of
@@ -200,6 +203,8 @@ case class UnionLoopExec(
       }
     }
 
+    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+
     if (unionChildren.isEmpty) {
       new EmptyRDD[InternalRow](sparkContext)
     } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 0dd90925d3c74..44c923095c348 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -107,6 +107,24 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
     }
   }
 
+  test("Recursive CTEs metrics") {
+    val df = sql("""WITH RECURSIVE t(n) AS(
+        | VALUES 1, 2
+        | UNION ALL
+        | SELECT n+1 FROM t WHERE n < 20
+        | )
+        | SELECT * FROM t""".stripMargin)
+    val unionLoopExec = df.queryExecution.executedPlan.collect {
+      case ule: UnionLoopExec => ule
+    }
+    sparkContext.listenerBus.waitUntilEmpty()
+    assert(unionLoopExec.size == 1)
+    val expected = Map("number of output rows" -> 39L, "number of recursive iterations" -> 20L,
+      "number of anchor output rows" -> 2L)
+    testSparkPlanMetrics(df, 22, Map(
+      2L -> (("UnionLoop", expected))))
+  }
+
   test("Filter metrics") {
     // Assume the execution plan is
     // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0)
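
Not part of the patch: a minimal spark-shell sketch of how the new "numAnchorOutputRows" metric could be read back from the executed plan after this change, assuming a Spark build where recursive CTEs (and thus UnionLoopExec) are available and a session named `spark`.

// Sketch only: reads the driver-side SQLMetric values off the physical plan.
import org.apache.spark.sql.execution.UnionLoopExec

val df = spark.sql(
  """WITH RECURSIVE t(n) AS (
    |  VALUES 1, 2
    |  UNION ALL
    |  SELECT n + 1 FROM t WHERE n < 20
    |)
    |SELECT * FROM t""".stripMargin)

// Run the query so the driver-side metrics are populated.
df.collect()

// Find the UnionLoopExec node and print the new metric; with the anchor
// VALUES 1, 2 this is expected to report 2 anchor output rows.
df.queryExecution.executedPlan.foreach {
  case ule: UnionLoopExec =>
    println(s"anchor output rows = ${ule.metrics("numAnchorOutputRows").value}")
  case _ =>
}

The patch also moves SQLMetrics.postDriverMetricUpdates out of the recursion loop, so all driver metrics (including the new one) are posted to the SQL UI once, after the recursion has finished.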