diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 0d833d7bf3928..689794cfcfc02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -736,6 +736,39 @@ case class WithCTE(plan: LogicalPlan, cteDefs: Seq[CTERelationDef]) extends Logi
   def withNewPlan(newPlan: LogicalPlan): WithCTE = {
     withNewChildren(children.init :+ newPlan).asInstanceOf[WithCTE]
   }
+
+  /**
+   * Canonicalizes this node so that plans which differ only in their CTE ids become identical.
+   *
+   * The id of every definition in `cteDefs` is replaced by the position of that definition, and
+   * each [[CTERelationRef]] pointing at one of these definitions, including references inside
+   * the plans of subquery expressions, is rewritten to the same positional id.
+   */
+  override def doCanonicalize(): LogicalPlan = {
+    val canonicalized = super.doCanonicalize().asInstanceOf[WithCTE]
+    // Number the definitions by position rather than with a mutable counter, so the result
+    // depends only on the plan and is the same however often this method runs.
+    val defIdToNewId: Map[Long, Long] = canonicalized.cteDefs.map(_.id).zipWithIndex.map {
+      case (id, index) => id -> index.toLong
+    }.toMap
+
+    def canonicalizeCTE(plan: LogicalPlan): LogicalPlan = {
+      plan.transformDownWithPruning(_.containsAnyPattern(CTE, PLAN_EXPRESSION)) {
+        case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) =>
+          ref.copy(cteId = defIdToNewId(ref.cteId))
+        case other =>
+          // Also rewrite references inside the plans held by subquery expressions.
+          other.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
+            case e: SubqueryExpression => e.withNewPlan(canonicalizeCTE(e.plan))
+          }
+      }
+    }
+
+    val newCteDefs = canonicalized.cteDefs.map { cteDef =>
+      cteDef.copy(child = canonicalizeCTE(cteDef.child), id = defIdToNewId(cteDef.id))
+    }
+    canonicalized.copy(plan = canonicalizeCTE(canonicalized.plan), cteDefs = newCteDefs)
+  }
 }
 
 case class WithWindowDefinition(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index d27821b59bea0..1b7d56d8e227e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -1668,4 +1668,101 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
       }
     }
   }
+
+  test("Cache reuse in case of multiple cte") {
+    sql(
+      """
+        |CACHE TABLE cached_cte_multiple AS
+        | WITH cte1 AS (
+        | SELECT 1 AS id, 'Alice' AS name
+        | UNION ALL
+        | SELECT 2 AS id, 'Bob' AS name
+        | ),
+        | cte2 AS (
+        | SELECT 1 AS id, 10 AS score
+        | UNION ALL
+        | SELECT 2 AS id, 20 AS score
+        | )
+        |SELECT cte1.id, cte1.name, cte2.score
+        |FROM cte1
+        |JOIN cte2 ON cte1.id = cte2.id
+        |""".stripMargin
+    )
+
+    try {
+      val ds = sql("SELECT * FROM cached_cte_multiple")
+      assert(getNumInMemoryRelations(ds) == 1)
+
+      // Assert that the same query, issued again on its own, reuses the cached data
+      assertCached(
+        sql(
+          """
+            | WITH cte1 AS (
+            | SELECT 1 AS id, 'Alice' AS name
+            | UNION ALL
+            | SELECT 2 AS id, 'Bob' AS name
+            | ),
+            | cte2 AS (
+            | SELECT 1 AS id, 10 AS score
+            | UNION ALL
+            | SELECT 2 AS id, 20 AS score
+            | )
+            |SELECT cte1.id, cte1.name, cte2.score
+            |FROM cte1
+            |JOIN cte2 ON cte1.id = cte2.id
+            |""".stripMargin
+        )
+      )
+    } finally {
+      // Release the cache entry even if an assertion above fails
+      uncacheTable("cached_cte_multiple")
+    }
+  }
+
+  test("Cache reuse in case of nested cte") {
+    // Create a cached view from nested CTEs
+    sql(
+      """
+        |CACHE TABLE cached_cte_nested AS
+        | WITH t1 AS (
+        | SELECT 1
+        | ),
+        | t2 AS (
+        | WITH t3 AS (
+        | SELECT * FROM t1
+        | )
+        | SELECT * FROM t3
+        | )
+        |SELECT *
+        |FROM t2
+        |""".stripMargin
+    )
+
+    try {
+      val ds = sql("SELECT * FROM cached_cte_nested")
+      assert(getNumInMemoryRelations(ds) == 1)
+
+      // Assert that the same query, issued again on its own, reuses the cached data
+      assertCached(
+        sql(
+          """
+            |WITH t1 AS (
+            | SELECT 1
+            |),
+            |t2 AS (
+            | WITH t3 AS (
+            | SELECT * FROM t1
+            | )
+            | SELECT * FROM t3
+            |)
+            |SELECT *
+            |FROM t2
+            |""".stripMargin
+        )
+      )
+    } finally {
+      // Release the cache entry even if an assertion above fails
+      uncacheTable("cached_cte_nested")
+    }
+  }
 }