Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-51068][SQL] Canonicalized CTEs to avoid cached result not being used and recomputed #50360

Open
wants to merge 2 commits into
base: branch-3.3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,8 @@ case class CTERelationRef(
*/
case class WithCTE(plan: LogicalPlan, cteDefs: Seq[CTERelationDef]) extends LogicalPlan {

val curId = new java.util.concurrent.atomic.AtomicLong()

final override val nodePatterns: Seq[TreePattern] = Seq(CTE)

override def output: Seq[Attribute] = plan.output
Expand All @@ -736,6 +738,28 @@ case class WithCTE(plan: LogicalPlan, cteDefs: Seq[CTERelationDef]) extends Logi
def withNewPlan(newPlan: LogicalPlan): WithCTE = {
withNewChildren(children.init :+ newPlan).asInstanceOf[WithCTE]
}

override def doCanonicalize(): LogicalPlan = {
def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: Map[Long, Long]): LogicalPlan = {
plan.transformDownWithPruning(
_.containsAnyPattern(CTE, PLAN_EXPRESSION)) {
case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) =>
ref.copy(cteId = defIdToNewId(ref.cteId))
case other =>
other.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
case e: SubqueryExpression => e.withNewPlan(canonicalizeCTE(e.plan, defIdToNewId))
}
}
}
val canonicalize = super.doCanonicalize().asInstanceOf[WithCTE]
val defIdToNewId = canonicalize.cteDefs.map(_.id).map((_, curId.getAndIncrement())).toMap
val normalizedPlan = canonicalizeCTE(canonicalize.plan, defIdToNewId)
val newCteDefs = canonicalize.cteDefs.map { cteDef =>
val normalizedCteDef = canonicalizeCTE(cteDef.child, defIdToNewId)
cteDef.copy(child = normalizedCteDef, id = defIdToNewId(cteDef.id))
}
canonicalize.copy(plan = normalizedPlan, cteDefs = newCteDefs)
}
}

case class WithWindowDefinition(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1668,4 +1668,95 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
}
}
}

test("Cache reuse in case of multiple cte") {
sql(
"""
|CACHE TABLE cached_cte_multiple AS
| WITH cte1 AS (
| SELECT 1 AS id, 'Alice' AS name
| UNION ALL
| SELECT 2 AS id, 'Bob' AS name
| ),
| cte2 AS (
| SELECT 1 AS id, 10 AS score
| UNION ALL
| SELECT 2 AS id, 20 AS score
| )
|SELECT cte1.id, cte1.name, cte2.score
|FROM cte1
|JOIN cte2 ON cte1.id = cte2.id
|""".stripMargin
)

val ds = sql("SELECT * FROM cached_cte_multiple")
assert(getNumInMemoryRelations(ds) == 1)

// Assert we can reuse the cached data
assertCached(
sql(
"""
| WITH cte1 AS (
| SELECT 1 AS id, 'Alice' AS name
| UNION ALL
| SELECT 2 AS id, 'Bob' AS name
| ),
| cte2 AS (
| SELECT 1 AS id, 10 AS score
| UNION ALL
| SELECT 2 AS id, 20 AS score
| )
|SELECT cte1.id, cte1.name, cte2.score
|FROM cte1
|JOIN cte2 ON cte1.id = cte2.id
|""".stripMargin
)
)

uncacheTable("cached_cte_multiple")
}

test("Cache reuse in case of nested cte") {
// Create a cached view from nested CTEs
sql(
"""
|CACHE TABLE cached_cte_nested AS
| WITH t1 AS (
| SELECT 1
| ),
| t2 AS (
| WITH t3 AS (
| SELECT * FROM t1
| )
| SELECT * FROM t3
| )
|SELECT *
|FROM t2
|""".stripMargin
)

val ds = sql("SELECT * FROM cached_cte_nested")
assert(getNumInMemoryRelations(ds) == 1)

// Assert we can reuse the cached data
assertCached(
sql(
"""
|WITH t1 AS (
| SELECT 1
|),
|t2 AS (
| WITH t3 AS (
| SELECT * FROM t1
| )
| SELECT * FROM t3
|)
|SELECT *
|FROM t2
|""".stripMargin
)
)

uncacheTable("cached_cte_nested")
}
}
Loading