diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
index 951c3005bc1c0..de0262edccf69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
@@ -17,11 +17,15 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression}
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression, SubqueryExpression, VariableReference}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{CreateView, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StringType
 
 /**
@@ -34,15 +38,70 @@ class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch]
     override def batches: Seq[Batch] = earlyBatches.asInstanceOf[Seq[Batch]]
   }
 
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
-    _.containsPattern(UNRESOLVED_IDENTIFIER)) {
-    case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && p.childrenResolved =>
-      executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children))
-    case other =>
-      other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
-        case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved =>
-          e.exprBuilder.apply(evalIdentifierExpr(e.identifierExpr), e.otherExprs)
-      }
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    plan match {
+      case createView: CreateView =>
+        if (conf.getConf(SQLConf.VARIABLES_UNDER_IDENTIFIER_IN_VIEW)) {
+          apply0(createView)
+        } else {
+          val referredTempVars = new mutable.ArrayBuffer[Seq[String]]
+          val analyzedChild = apply0(createView.child)
+          val analyzedQuery = apply0(createView.query, Some(referredTempVars))
+          if (referredTempVars.nonEmpty) {
+            throw QueryCompilationErrors.notAllowedToCreatePermanentViewByReferencingTempVarError(
+              Seq("unknown"),
+              referredTempVars.head
+            )
+          }
+          createView.copy(child = analyzedChild, query = analyzedQuery)
+        }
+      case _ => apply0(plan)
+    }
+  }
+
+  private def apply0(
+      plan: LogicalPlan,
+      referredTempVars: Option[mutable.ArrayBuffer[Seq[String]]] = None): LogicalPlan =
+    plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_IDENTIFIER)) {
+      case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && p.childrenResolved =>
+
+        if (referredTempVars.isDefined) {
+          referredTempVars.get ++= collectTemporaryVariablesInLogicalPlan(p)
+        }
+
+        executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children))
+      case other =>
+        other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
+          case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved =>
+
+            if (referredTempVars.isDefined) {
+              referredTempVars.get ++= collectTemporaryVariablesInExpressionTree(e)
+            }
+
+            e.exprBuilder.apply(evalIdentifierExpr(e.identifierExpr), e.otherExprs)
+        }
+    }
+
+  private def collectTemporaryVariablesInLogicalPlan(child: LogicalPlan): Seq[Seq[String]] = {
+    def collectTempVars(child: LogicalPlan): Seq[Seq[String]] = {
+      child.flatMap { plan =>
+        plan.expressions.flatMap { e => collectTemporaryVariablesInExpressionTree(e) }
+      }.distinct
+    }
+    collectTempVars(child)
+  }
+
+  private def collectTemporaryVariablesInExpressionTree(child: Expression): Seq[Seq[String]] = {
+    def collectTempVars(child: Expression): Seq[Seq[String]] = {
+      child.flatMap { expr =>
+        expr.children.flatMap(_.flatMap {
+          case e: SubqueryExpression => collectTemporaryVariablesInLogicalPlan(e.plan)
+          case r: VariableReference => Seq(r.originalNameParts)
+          case _ => Seq.empty
+        })
+      }.distinct
+    }
+    collectTempVars(child)
   }
 
   private def evalIdentifierExpr(expr: Expression): Seq[String] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index fa0a90135934c..bd554def2c2f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -3179,13 +3179,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
   }
 
   def notAllowedToCreatePermanentViewByReferencingTempVarError(
-      name: TableIdentifier,
-      varName: String): Throwable = {
+      nameParts: Seq[String],
+      varName: Seq[String]): Throwable = {
     new AnalysisException(
       errorClass = "INVALID_TEMP_OBJ_REFERENCE",
       messageParameters = Map(
         "obj" -> "VIEW",
-        "objName" -> toSQLId(name.nameParts),
+        "objName" -> toSQLId(nameParts),
         "tempObj" -> "VARIABLE",
         "tempObjName" -> toSQLId(varName)))
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 19fa4c574221b..4429a50bf35c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -5651,6 +5651,17 @@
       .booleanConf
       .createWithDefault(true)
 
+  val VARIABLES_UNDER_IDENTIFIER_IN_VIEW =
+    buildConf("spark.sql.legacy.allowSessionVariableInPersistedView")
+      .internal()
+      .doc(
+        "When set to true, session variables can be referenced under IDENTIFIER clauses in " +
+        "a view query; otherwise an error is thrown."
+      )
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
   *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index dbf98c70504d8..008f21b9a04db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -675,7 +675,7 @@ object ViewHelper extends SQLConfHelper with Logging {
       val tempVars = collectTemporaryVariables(child)
       tempVars.foreach { nameParts =>
         throw QueryCompilationErrors.notAllowedToCreatePermanentViewByReferencingTempVarError(
-          name, nameParts.quoted)
+          name.nameParts, nameParts)
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala
index 0faace9227dd1..0d57b815a192c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala
@@ -793,6 +793,40 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession {
     }
   }
 
+  test("SPARK-51552: Temporary variables under identifiers are not allowed in persisted view") {
+    sql("declare table_name = 'table';")
+    sql("create table identifier(table_name) (c1 int);")
+    sql("create view v_table_1 as select * from table")
+    sql("create view identifier('v_' || table_name || '_2') as select * from table")
+    checkError(
+      exception = intercept[AnalysisException] {
+        sql("create view v_table_3 as select * from identifier(table_name)")
+      },
+      condition = "INVALID_TEMP_OBJ_REFERENCE",
+      parameters = Map(
+        "obj" -> "VIEW",
+        "objName" -> "`unknown`",
+        "tempObj" -> "VARIABLE",
+        "tempObjName" -> "`table_name`"
+      )
+    )
+    checkError(
+      exception = intercept[AnalysisException] {
+        sql(
+          """create view identifier('v_' || table_name || '_4')
+            |as select * from identifier(table_name);
+            |""".stripMargin)
+      },
+      condition = "INVALID_TEMP_OBJ_REFERENCE",
+      parameters = Map(
+        "obj" -> "VIEW",
+        "objName" -> "`unknown`",
+        "tempObj" -> "VARIABLE",
+        "tempObjName" -> "`table_name`"
+      )
+    )
+  }
+
   def getShowCreateDDL(view: String, serde: Boolean = false): String = {
     val result = if (serde) {
       sql(s"SHOW CREATE TABLE $view AS SERDE")
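Note (not part of the patch): a minimal Scala sketch of the behaviour the new test exercises, assuming a SparkSession named `spark`. The DECLARE / IDENTIFIER pattern and the INVALID_TEMP_OBJ_REFERENCE condition are taken from the test hunk above; treating a runtime SET of the internal flag spark.sql.legacy.allowSessionVariableInPersistedView as a way to restore the pre-patch behaviour is an assumption based on the SQLConf change, and the object names are illustrative only.

    // Assumes a SparkSession `spark`; names are hypothetical.
    spark.sql("DECLARE table_name = 'tab'")
    spark.sql("CREATE TABLE IDENTIFIER(table_name) (c1 INT)")

    // Fails with INVALID_TEMP_OBJ_REFERENCE: the persisted view definition would
    // capture a session variable referenced under the IDENTIFIER clause.
    spark.sql("CREATE VIEW v1 AS SELECT * FROM IDENTIFIER(table_name)")

    // Assumed escape hatch via the new legacy conf (default false):
    spark.sql("SET spark.sql.legacy.allowSessionVariableInPersistedView=true")
    spark.sql("CREATE VIEW v1 AS SELECT * FROM IDENTIFIER(table_name)")  // old behaviour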