Skip to content

Commit c56b5bc

Browse files
committed
initial commit
1 parent f4493fd commit c56b5bc

File tree

2 files changed

+93
-4
lines changed

2 files changed

+93
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala

+76-4
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,101 @@
1717

1818
package org.apache.spark.sql.catalyst.analysis
1919

20-
import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression}
20+
import scala.collection.mutable
21+
22+
import org.apache.spark.sql.catalyst.TableIdentifier
23+
import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression, SubqueryExpression, VariableReference}
2124
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
22-
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
25+
import org.apache.spark.sql.catalyst.plans.logical.{CreateView, LogicalPlan}
2326
import org.apache.spark.sql.catalyst.rules.Rule
2427
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER
28+
import org.apache.spark.sql.connector.catalog.CatalogV2Util.isSessionCatalog
29+
import org.apache.spark.sql.errors.QueryCompilationErrors
2530
import org.apache.spark.sql.types.StringType
2631

2732
/**
2833
* Resolves the identifier expressions and builds the original plans/expressions.
2934
*/
3035
object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with EvalHelper {
3136

32-
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
33-
_.containsPattern(UNRESOLVED_IDENTIFIER)) {
37+
override def apply(plan: LogicalPlan): LogicalPlan = {
38+
val referredTempVars = new mutable.ArrayBuffer[Seq[String]]
39+
val analyzedPlan = apply0(plan, referredTempVars)
40+
41+
analyzedPlan match {
42+
case cv @ CreateView(
43+
ResolvedIdentifierInSessionCatalog(ident), _, _, _, _, _, _, _, _, _) =>
44+
if (referredTempVars.nonEmpty) {
45+
throw QueryCompilationErrors.notAllowedToCreatePermanentViewByReferencingTempVarError(
46+
ident,
47+
referredTempVars.head.quoted
48+
)
49+
}
50+
cv
51+
case _ => analyzedPlan
52+
}
53+
}
54+
55+
object ResolvedIdentifierInSessionCatalog{
56+
def unapply(resolved: LogicalPlan): Option[TableIdentifier] = resolved match {
57+
case ResolvedIdentifier(catalog, ident) if isSessionCatalog(catalog) =>
58+
if (ident.namespace().length != 1) {
59+
throw QueryCompilationErrors
60+
.requiresSinglePartNamespaceError(ident.namespace().toImmutableArraySeq)
61+
}
62+
Some(TableIdentifier(ident.name, Some(ident.namespace.head), Some(catalog.name)))
63+
case _ => None
64+
}
65+
}
66+
67+
private def apply0(
68+
plan: LogicalPlan,
69+
referredTempVars: mutable.ArrayBuffer[Seq[String]]): LogicalPlan =
70+
plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_IDENTIFIER)) {
3471
case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && p.childrenResolved =>
72+
73+
referredTempVars ++= collectTemporaryVariablesInLogicalPlan(p)
74+
3575
p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children)
3676
case other =>
3777
other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
3878
case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved =>
79+
80+
referredTempVars ++= collectTemporaryVariablesInExpressionTree(e)
81+
3982
e.exprBuilder.apply(evalIdentifierExpr(e.identifierExpr), e.otherExprs)
4083
}
4184
}
4285

86+
/**
87+
* Collect all temporary SQL variables and return the identifiers separately.
88+
*/
89+
private def collectTemporaryVariablesInLogicalPlan(child: LogicalPlan): Seq[Seq[String]] = {
90+
def collectTempVars(child: LogicalPlan): Seq[Seq[String]] = {
91+
child.flatMap { plan =>
92+
plan.expressions.flatMap(_.flatMap {
93+
case e: SubqueryExpression => collectTempVars(e.plan)
94+
case r: VariableReference => Seq(r.originalNameParts)
95+
case _ => Seq.empty
96+
})
97+
}.distinct
98+
}
99+
collectTempVars(child)
100+
}
101+
102+
private def collectTemporaryVariablesInExpressionTree(child: Expression): Seq[Seq[String]] = {
103+
def collectTempVars(child: Expression): Seq[Seq[String]] = {
104+
child.flatMap { expr =>
105+
expr.children.flatMap(_.flatMap {
106+
case e: SubqueryExpression => collectTemporaryVariablesInLogicalPlan(e.plan)
107+
case r: VariableReference => Seq(r.originalNameParts)
108+
case _ => Seq.empty
109+
})
110+
}.distinct
111+
}
112+
collectTempVars(child)
113+
}
114+
43115
private def evalIdentifierExpr(expr: Expression): Seq[String] = {
44116
trimAliases(prepareForEval(expr)) match {
45117
case e if !e.foldable => expr.failAnalysis(

sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala

+17
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,23 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession {
793793
}
794794
}
795795

796+
test("SPARK-51552: Temporary variables under identifiers are not allowed in persisted view") {
797+
sql("declare table_name = 'table';")
798+
sql("create table identifier(table_name) (c1 int);")
799+
checkError(
800+
exception = intercept[AnalysisException] {
801+
sql("create view identifier('v_' || table_name) as select * from identifier(table_name);")
802+
},
803+
condition = "INVALID_TEMP_OBJ_REFERENCE",
804+
parameters = Map(
805+
"obj" -> "VIEW",
806+
"objName" -> "`spark_catalog`.`default`.`v_table`",
807+
"tempObj" -> "VARIABLE",
808+
"tempObjName" -> "`table_name`"
809+
)
810+
)
811+
}
812+
796813
def getShowCreateDDL(view: String, serde: Boolean = false): String = {
797814
val result = if (serde) {
798815
sql(s"SHOW CREATE TABLE $view AS SERDE")

0 commit comments

Comments
 (0)