diff --git a/common/utils/src/main/scala/org/apache/spark/SparkException.scala b/common/utils/src/main/scala/org/apache/spark/SparkException.scala index 00989fd29095c..a6a788685fe37 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkException.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkException.scala @@ -133,6 +133,11 @@ object SparkException { } } +/** + * Exception which indicates that the queryStage should be cancelled. + */ +private[spark] class SparkAQEStageCancelException extends RuntimeException + /** * Exception thrown when execution of some user code in the driver process fails, e.g. * accumulator update fails or failure in takeOrdered (user supplies an Ordering implementation diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index e4edd7c8419d6..926ebc87d7464 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -41,7 +41,7 @@ trait FutureAction[T] extends Future[T] { /** * Cancels the execution of this action with an optional reason. */ - def cancel(reason: Option[String]): Unit + def cancel(reason: Option[String], quiet: Boolean = false): Unit /** * Cancels the execution of this action. @@ -119,9 +119,9 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: @volatile private var _cancelled: Boolean = false - override def cancel(reason: Option[String]): Unit = { + override def cancel(reason: Option[String], quiet: Boolean = false): Unit = { _cancelled = true - jobWaiter.cancel(reason) + jobWaiter.cancel(reason, quiet) } override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { @@ -193,10 +193,11 @@ class ComplexFutureAction[T](run : JobSubmitter => Future[T]) // A promise used to signal the future. private val p = Promise[T]().completeWith(run(jobSubmitter)) - override def cancel(reason: Option[String]): Unit = synchronized { - _cancelled = true - p.tryFailure(new SparkException("Action has been cancelled")) - subActions.foreach(_.cancel(reason)) + override def cancel(reason: Option[String], quiet: Boolean = false): Unit = + synchronized { + _cancelled = true + p.tryFailure(new SparkException("Action has been cancelled")) + subActions.foreach(_.cancel(reason, quiet = quiet)) } private def jobSubmitter = new JobSubmitter { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index baf0ed4df5309..41dc889f464ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1099,9 +1099,9 @@ private[spark] class DAGScheduler( /** * Cancel a job that is running or waiting in the queue. 
*/ - def cancelJob(jobId: Int, reason: Option[String]): Unit = { + def cancelJob(jobId: Int, reason: Option[String], quiet: Boolean = false): Unit = { logInfo(log"Asked to cancel job ${MDC(JOB_ID, jobId)}") - eventProcessLoop.post(JobCancelled(jobId, reason)) + eventProcessLoop.post(JobCancelled(jobId, reason, quiet)) } /** @@ -2856,13 +2856,20 @@ private[spark] class DAGScheduler( } } - private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]): Unit = { + private[scheduler] def handleJobCancellation( + jobId: Int, reason: Option[String], quiet: Boolean = false): Unit = { if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { + val error = if (quiet) { + new SparkException("Job %d cancelled %s".format(jobId, reason.getOrElse(""))) + } else { + SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null) + } failJobAndIndependentStages( job = jobIdToActiveJob(jobId), - error = SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null) + error = error, + quiet = quiet ) } } @@ -2996,12 +3003,17 @@ private[spark] class DAGScheduler( /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ private def failJobAndIndependentStages( job: ActiveJob, - error: Exception): Unit = { + error: Exception, + quiet: Boolean = false): Unit = { if (cancelRunningIndependentStages(job, error.getMessage)) { // SPARK-15783 important to cleanup state first, just for tests where we have some asserts // against the state. Otherwise we have a *little* bit of flakiness in the tests. cleanupStateForJobAndIndependentStages(job) - job.listener.jobFailed(error) + if (quiet) { + job.listener.jobCancel(error) + } else { + job.listener.jobFailed(error) + } listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } } @@ -3156,8 +3168,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case StageCancelled(stageId, reason) => dagScheduler.handleStageCancellation(stageId, reason) - case JobCancelled(jobId, reason) => - dagScheduler.handleJobCancellation(jobId, reason) + case JobCancelled(jobId, reason, quiet) => + dagScheduler.handleJobCancellation(jobId, reason, quiet) case JobGroupCancelled(groupId, cancelFutureJobs, reason) => dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 8932d2ef323ba..08f925ca67566 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -62,7 +62,8 @@ private[scheduler] case class StageCancelled( private[scheduler] case class JobCancelled( jobId: Int, - reason: Option[String]) + reason: Option[String], + quiet: Boolean = false) extends DAGSchedulerEvent private[scheduler] case class JobGroupCancelled( diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala index e0f7c8f02132d..89fe9003b126c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala @@ -25,4 +25,6 @@ package org.apache.spark.scheduler private[spark] trait JobListener { def taskSucceeded(index: Int, result: Any): Unit def jobFailed(exception: Exception): Unit + + def 
jobCancel(exception: Exception): Unit = {} } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index bfd6759387034..c2aad64dd0332 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{Future, Promise} +import org.apache.spark.SparkAQEStageCancelException import org.apache.spark.internal.Logging /** @@ -49,8 +50,8 @@ private[spark] class JobWaiter[T]( * cancellation itself is handled asynchronously. After the low level scheduler cancels * all the tasks belonging to this job, it will fail this job with a SparkException. */ - def cancel(reason: Option[String]): Unit = { - dagScheduler.cancelJob(jobId, reason) + def cancel(reason: Option[String] = None, quiet: Boolean = false): Unit = { + dagScheduler.cancelJob(jobId, reason, quiet) } /** @@ -76,4 +77,10 @@ private[spark] class JobWaiter[T]( } } + override def jobCancel(exception: Exception): Unit = { + if (!jobPromise.tryFailure(new SparkAQEStageCancelException())) { + logWarning("Ignore failure", exception) + } + } + } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5f1c0cbca0d0b..4b748d2dd6106 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -39,7 +39,13 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Vector.getSizeInBytes"), // [SPARK-52221][SQL] Refactor SqlScriptingLocalVariableManager into more generic context manager - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.scripting.SqlScriptingExecution.withLocalVariableManager") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.scripting.SqlScriptingExecution.withLocalVariableManager"), + + // [SPARK-52024][SQL] Support cancel ShuffleQueryStage when propagate empty relations + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.FutureAction.cancel"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.cancel"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SimpleFutureAction.cancel"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.cancel") ) // Default exclude rules diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index 86316494f6ff8..19466ffa2384f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -61,6 +61,8 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup protected def empty(plan: LogicalPlan): LogicalPlan = LocalRelation(plan.output, data = Seq.empty, isStreaming = plan.isStreaming) + protected def collectCancelableCandidates(maybeCancel: LogicalPlan*): Unit = {} + // Construct a project list from plan's output, while the value is always NULL. 
private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] = plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) } @@ -69,7 +71,8 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup protected def commonApplyFunc: PartialFunction[LogicalPlan, LogicalPlan] = { case p: Union if p.children.exists(isEmpty) => - val newChildren = p.children.filterNot(isEmpty) + val (candidates, newChildren) = p.children.partition(isEmpty) + collectCancelableCandidates(candidates: _*) if (newChildren.isEmpty) { empty(p) } else { @@ -106,22 +109,39 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup } if (isLeftEmpty || isRightEmpty || isFalseCondition) { joinType match { - case _: InnerLike => empty(p) + case _: InnerLike => + collectCancelableCandidates(p.left, p.right) + empty(p) // Intersect is handled as LeftSemi by `ReplaceIntersectWithSemiJoin` rule. // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule. - case LeftOuter | LeftSemi | LeftAnti if isLeftEmpty => empty(p) - case LeftSemi if isRightEmpty | isFalseCondition => empty(p) + case LeftOuter | LeftSemi | LeftAnti if isLeftEmpty => + collectCancelableCandidates(p.right) + empty(p) + case LeftSemi if isRightEmpty | isFalseCondition => + if (isRightEmpty) { + collectCancelableCandidates(p.left) + } else { + collectCancelableCandidates(p.left, p.right) + } + empty(p) case LeftAnti if (isRightEmpty | isFalseCondition) && canExecuteWithoutJoin(p.left) => + if (!isRightEmpty) { + collectCancelableCandidates(p.right) + } p.left case FullOuter if isLeftEmpty && isRightEmpty => empty(p) case LeftOuter | FullOuter if isRightEmpty && canExecuteWithoutJoin(p.left) => Project(p.left.output ++ nullValueProjectList(p.right), p.left) - case RightOuter if isRightEmpty => empty(p) + case RightOuter if isRightEmpty => + collectCancelableCandidates(p.left) + empty(p) case RightOuter | FullOuter if isLeftEmpty && canExecuteWithoutJoin(p.right) => Project(nullValueProjectList(p.left) ++ p.right.output, p.right) case LeftOuter if isFalseCondition && canExecuteWithoutJoin(p.left) => + collectCancelableCandidates(p.right) Project(p.left.output ++ nullValueProjectList(p.right), p.left) case RightOuter if isFalseCondition && canExecuteWithoutJoin(p.right) => + collectCancelableCandidates(p.left) Project(nullValueProjectList(p.left) ++ p.right.output, p.right) case _ => p } @@ -129,6 +149,7 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup nonEmpty(p.right) && canExecuteWithoutJoin(p.left)) { p.left } else if (joinType == LeftAnti && conditionOpt.isEmpty && nonEmpty(p.right)) { + collectCancelableCandidates(p.left) empty(p) } else { p diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 478b92de0b8e4..b3c08100314bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -930,6 +930,14 @@ object SQLConf { .checkValue(_ > 0, "The initial number of partitions must be positive.") .createOptional + val ADAPTIVE_EMPTY_TRIGGER_CANCEL_ENABLED = + buildConf("spark.sql.adaptive.empty.trigger.cancel.enabled") + .doc(s"When true and '${ADAPTIVE_EXECUTION_ENABLED.key}' is true, Spark will try to cancel " + "query stages that become unnecessary when propagating empty relations.") + .version("3.5.5") + .booleanConf + 
.createWithDefault(true) + lazy val ALLOW_COLLATIONS_IN_MAP_KEYS = buildConf("spark.sql.collation.allowInMapKeys") .doc("Allow for non-UTF8_BINARY collated strings inside of map's keys") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala index 0f1743eeaacfb..ee19c90e55eeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala @@ -17,19 +17,26 @@ package org.apache.spark.sql.execution.adaptive +import scala.collection.concurrent.TrieMap + import org.apache.spark.internal.LogKeys.{BATCH_NAME, RULE_NAME} import org.apache.spark.internal.MDC import org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, EliminateLimits, OptimizeOneRowPlan} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils /** * The optimizer for re-optimizing the logical plan used by AdaptiveSparkPlanExec. */ -class AQEOptimizer(conf: SQLConf, extendedRuntimeOptimizerRules: Seq[Rule[LogicalPlan]]) +class AQEOptimizer( + conf: SQLConf, + stageReuse: TrieMap[SparkPlan, Unit], + stagesToCancel: collection.mutable.Map[Int, (String, ExchangeQueryStageExec)], + extendedRuntimeOptimizerRules: Seq[Rule[LogicalPlan]]) extends RuleExecutor[LogicalPlan] { private def fixedPoint = @@ -39,7 +46,7 @@ class AQEOptimizer(conf: SQLConf, extendedRuntimeOptimizerRules: Seq[Rule[Logica private val defaultBatches = Seq( Batch("Propagate Empty Relations", fixedPoint, - AQEPropagateEmptyRelation, + AQEPropagateEmptyRelation(stageReuse, stagesToCancel), ConvertToLocalRelation, UpdateAttributeNullability), Batch("Dynamic Join Selection", Once, DynamicJoinSelection), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala index 7b3e0cd549b85..9e165e4cb7c92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala @@ -17,14 +17,18 @@ package org.apache.spark.sql.execution.adaptive +import scala.collection.concurrent.TrieMap + import org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelationBase import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin import org.apache.spark.sql.catalyst.plans.logical.EmptyRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, LOGICAL_QUERY_STAGE, TRUE_OR_FALSE_LITERAL} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys +import org.apache.spark.sql.internal.SQLConf /** * This rule runs in the AQE optimizer and optimizes more cases @@ -33,7 +37,10 @@ import 
org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys * Broadcasted [[HashedRelation]] is [[HashedRelationWithAllNullKeys]]. Eliminate join to an * empty [[LocalRelation]]. */ -object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase { +case class AQEPropagateEmptyRelation( + stageReuse: TrieMap[SparkPlan, Unit], + stagesToCancel: collection.mutable.Map[Int, (String, ExchangeQueryStageExec)]) + extends PropagateEmptyRelationBase { override protected def isEmpty(plan: LogicalPlan): Boolean = super.isEmpty(plan) || (!isRootRepartition(plan) && getEstimatedRowCount(plan).contains(0)) @@ -42,6 +49,18 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase { override protected def empty(plan: LogicalPlan): LogicalPlan = EmptyRelation(plan) + override protected def collectCancelableCandidates(candidates: LogicalPlan*): Unit = { + if (!conf.getConf(SQLConf.ADAPTIVE_EMPTY_TRIGGER_CANCEL_ENABLED)) return + candidates.foreach(_.foreach { + case LogicalQueryStage(_, physicalPlan: SparkPlan) => + physicalPlan.collect { + case s: ShuffleQueryStageExec if !s.isMaterialized && + !stageReuse.contains(s.plan.canonicalized) => s + }.foreach(s => stagesToCancel(s.id) = ("empty relation", s)) + case _ => + }) + } + private def isRootRepartition(plan: LogicalPlan): Boolean = plan match { case l: LogicalQueryStage if l.getTagValue(ROOT_REPARTITION).isDefined => true case _ => false @@ -77,6 +96,7 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase { private def eliminateSingleColumnNullAwareAntiJoin: PartialFunction[LogicalPlan, LogicalPlan] = { case j @ ExtractSingleColumnNullAwareAntiJoin(_, _) if isRelationWithAllNullKeys(j.right) => + collectCancelableCandidates(j.left) empty(j) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 996e01a0ea936..5962fdec74c72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -26,8 +26,7 @@ import scala.concurrent.ExecutionContext import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal -import org.apache.spark.SparkException -import org.apache.spark.broadcast +import org.apache.spark.{broadcast, SparkAQEStageCancelException, SparkException} import org.apache.spark.internal.{MDC, MessageWithContext} import org.apache.spark.internal.LogKeys._ import org.apache.spark.rdd.RDD @@ -81,8 +80,13 @@ case class AdaptiveSparkPlanExec( @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]() + @transient private val stagesToCancel: + collection.mutable.Map[Int, (String, ExchangeQueryStageExec)] = + new collection.mutable.HashMap[Int, (String, ExchangeQueryStageExec)]() + // The logical plan optimizer for re-optimizing the current logical plan. 
- @transient private val optimizer = new AQEOptimizer(conf, + @transient private val optimizer = new AQEOptimizer( + conf, context.stageReuse, stagesToCancel, context.session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules) // `EnsureRequirements` may remove user-specified repartition and assume the query plan won't @@ -309,7 +313,12 @@ case class AdaptiveSparkPlanExec( } events.offer(StageSuccess(stage, res.get)) } else { - events.offer(StageFailure(stage, res.failed.get)) + res.failed.get match { + // There is no need to trigger a new round to reOptimize + case _: SparkAQEStageCancelException => // ignore + case err: Throwable => + events.offer(StageFailure(stage, err)) + } } // explicitly clean up the resources in this stage stage.cleanupResources() @@ -367,7 +376,15 @@ case class AdaptiveSparkPlanExec( currentPhysicalPlan = newPhysicalPlan currentLogicalPlan = newLogicalPlan stagesToReplace = Seq.empty[QueryStageExec] + + stagesToCancel.values.foreach(reasonAndStage => { + if (!reasonAndStage._2.isCancelled) { + reasonAndStage._2.cancel(reasonAndStage._1, quiet = true) + context.stageCache.remove(reasonAndStage._2.plan.canonicalized) + } + }) } + stagesToCancel.clear() } } // Now that some stages have finished, we can try creating new stages. @@ -582,7 +599,9 @@ case class AdaptiveSparkPlanExec( // First have a quick check in the `stageCache` without having to traverse down the node. context.stageCache.get(e.canonicalized) match { case Some(existingStage) if conf.exchangeReuseEnabled => + context.stageReuse.put(e.canonicalized, ()) val stage = reuseQueryStage(existingStage, e) + context.stageReuse.put(stage.plan.canonicalized, ()) val isMaterialized = stage.isMaterialized CreateStageResult( newPlan = stage, @@ -602,7 +621,9 @@ case class AdaptiveSparkPlanExec( val queryStage = context.stageCache.getOrElseUpdate( newStage.plan.canonicalized, newStage) if (queryStage.ne(newStage)) { + context.stageReuse.put(newStage.plan.canonicalized, ()) newStage = reuseQueryStage(queryStage, e) + context.stageReuse.put(newStage.plan.canonicalized, ()) } } val isMaterialized = newStage.isMaterialized @@ -954,6 +975,8 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) { val stageCache: TrieMap[SparkPlan, ExchangeQueryStageExec] = new TrieMap[SparkPlan, ExchangeQueryStageExec]() + val stageReuse: TrieMap[SparkPlan, Unit] = new TrieMap[SparkPlan, Unit]() + val shuffleIds: ConcurrentHashMap[Int, Boolean] = new ConcurrentHashMap[Int, Boolean]() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index 0a5bdefea7bc5..48dbad1b9b33a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -165,15 +165,23 @@ abstract class QueryStageExec extends LeafExecNode { */ abstract class ExchangeQueryStageExec extends QueryStageExec { + @transient + var isCancelled: Boolean = false + /** * Cancel the stage materialization if in progress with a reason; otherwise do nothing. 
*/ - final def cancel(reason: String): Unit = { - logDebug(s"Cancel query stage: $name") - doCancel(reason) + final def cancel(reason: String, quiet: Boolean = false): Unit = { + this.synchronized { + if (!isCancelled) { + isCancelled = true + logDebug(s"Cancel query stage, quiet: $quiet") + doCancel(reason, quiet) + } + } } - protected def doCancel(reason: String): Unit + protected def doCancel(reason: String, quiet: Boolean): Unit /** * The canonicalized plan before applying query stage optimizer rules. @@ -219,7 +227,8 @@ case class ShuffleQueryStageExec( reuse } - override protected def doCancel(reason: String): Unit = shuffle.cancelShuffleJob(Option(reason)) + override protected def doCancel(reason: String, quiet: Boolean): Unit = + shuffle.cancelShuffleJob(Option(reason), quiet) /** * Returns the Option[MapOutputStatistics]. If the shuffle map stage has no partition, @@ -266,7 +275,8 @@ case class BroadcastQueryStageExec( reuse } - override protected def doCancel(reason: String): Unit = + // TODO: currently broadcast job cannot be cancelled quietly + override protected def doCancel(reason: String, quiet: Boolean): Unit = broadcast.cancelBroadcastJob(Option(reason)) override def getRuntimeStatistics: Statistics = broadcast.runtimeStatistics diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 31a3f53eb7191..4d40cb328af3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -78,9 +78,14 @@ trait ShuffleExchangeLike extends Exchange { private[sql] // Exposed for testing val futureAction = new AtomicReference[Option[FutureAction[MapOutputStatistics]]](None) + @volatile @transient private var isCancelled: Boolean = false + @volatile + @transient + private var quietly: Boolean = false + @transient private lazy val triggerFuture: java.util.concurrent.Future[Any] = { SQLExecution.withThreadLocalCaptured(session, ShuffleExchangeExec.executionContext) { @@ -90,7 +95,7 @@ trait ShuffleExchangeLike extends Exchange { executeQuery(null) // Submit shuffle job if not cancelled. this.synchronized { - if (isCancelled) { + if (isCancelled && !quietly) { promise.tryFailure(new SparkException("Shuffle cancelled.")) } else { val shuffleJob = RDDOperationScope.withScope(sparkContext, nodeName, false, true) { @@ -124,10 +129,16 @@ trait ShuffleExchangeLike extends Exchange { /** * Cancels the shuffle job with an optional reason. 
*/ - final def cancelShuffleJob(reason: Option[String]): Unit = this.synchronized { - if (!isCancelled) { - isCancelled = true - futureAction.get().foreach(_.cancel(reason)) + final def cancelShuffleJob(reason: Option[String], quiet: Boolean): Unit = this.synchronized { + if (!isCancelled) { + isCancelled = true + if (quiet) { + quietly = quiet + promise.tryFailure(new SparkAQEStageCancelException) + } + futureAction.get().foreach(_.cancel(reason, quiet = quiet)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala index c1a001117be9f..cafb04932011a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala @@ -179,7 +179,10 @@ abstract class CTEInlineSuiteBase test("SPARK-36447: With in subquery of main query") { withSQLConf( - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> + AQEPropagateEmptyRelation( + scala.collection.concurrent.TrieMap.empty, + collection.mutable.Map.empty).ruleName) { withTempView("t") { Seq((2, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t") val df = sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 68299804ea877..b522646ec59bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -840,7 +840,9 @@ class CachedTableSuite extends QueryTest with SQLTestUtils withSQLConf( SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> enabled.toString, SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> - AQEPropagateEmptyRelation.ruleName) { + AQEPropagateEmptyRelation( + scala.collection.concurrent.TrieMap.empty, + collection.mutable.Map.empty).ruleName) { Seq(1).toDF("c1").createOrReplaceTempView("t1") spark.catalog.cacheTable("t1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala index 1ed5ea4216a9f..7f11632d1cd20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -456,7 +456,10 @@ abstract class DynamicPartitionPruningSuiteBase withSQLConf( SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false", - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> + AQEPropagateEmptyRelation( + scala.collection.concurrent.TrieMap.empty, + collection.mutable.Map.empty).ruleName) { Given("no stats and selective predicate") withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true") { @@ -1153,7 +1156,10 @@ abstract class DynamicPartitionPruningSuiteBase test("join key with multiple references on the filtering plan") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName, + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> + AQEPropagateEmptyRelation( + 
scala.collection.concurrent.TrieMap.empty, + collection.mutable.Map.empty).ruleName, SQLConf.ANSI_ENABLED.key -> "false" // ANSI mode doesn't support "String + String" ) { // when enable AQE, the reusedExchange is inserted when executed. @@ -1316,7 +1322,10 @@ abstract class DynamicPartitionPruningSuiteBase withSQLConf( SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> + AQEPropagateEmptyRelation( + scala.collection.concurrent.TrieMap.empty, + collection.mutable.Map.empty).ruleName) { val df = sql( """ |SELECT * FROM fact_sk f diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala index 7d7185ae6c139..6e1e463741fff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala @@ -562,7 +562,10 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true", // Re-enable `MergeScalarSubqueries` SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "", - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> + AQEPropagateEmptyRelation( + scala.collection.concurrent.TrieMap.empty, + collection.mutable.Map.empty).ruleName) { val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + "bf1.b1 = bf2.b2 where bf2.a2 = 62" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 13de81065cb77..9833f018eecf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -317,7 +317,10 @@ class AdaptiveQueryExecSuite withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> + AQEPropagateEmptyRelation( + scala.collection.concurrent.TrieMap.empty, + collection.mutable.Map.empty).ruleName) { val df1 = spark.range(10).withColumn("a", $"id") val df2 = spark.range(10).withColumn("b", $"id") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { @@ -1471,7 +1474,10 @@ class AdaptiveQueryExecSuite SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString, // This test is a copy of test(SPARK-32573), in order to test the configuration // `spark.sql.adaptive.optimizer.excludedRules` works as expect. 
- SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> + AQEPropagateEmptyRelation( + scala.collection.concurrent.TrieMap.empty, + collection.mutable.Map.empty).ruleName) { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)") val bhj = findTopLevelBroadcastHashJoin(plan) @@ -2101,7 +2107,10 @@ class AdaptiveQueryExecSuite withTable("t") { withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", SQLConf.SHUFFLE_PARTITIONS.key -> "2", - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> + AQEPropagateEmptyRelation( + scala.collection.concurrent.TrieMap.empty, + collection.mutable.Map.empty).ruleName) { spark.sql("CREATE TABLE t (c1 int) USING PARQUET") val (_, adaptive) = runAdaptiveAndVerifyResult("SELECT c1, count(*) FROM t GROUP BY c1") assert( @@ -2281,7 +2290,10 @@ class AdaptiveQueryExecSuite withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> + AQEPropagateEmptyRelation( + scala.collection.concurrent.TrieMap.empty, + collection.mutable.Map.empty).ruleName) { withTempView("t2") { // create a temp view with 0 partition spark.createDataFrame(sparkContext.emptyRDD[Row], new StructType().add("b", IntegerType)) @@ -2600,7 +2612,10 @@ class AdaptiveQueryExecSuite withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584", - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> + AQEPropagateEmptyRelation( + scala.collection.concurrent.TrieMap.empty, + collection.mutable.Map.empty).ruleName) { // Spark estimates a string column as 20 bytes so with 60k rows, these relations should be // estimated at ~120m bytes which is greater than the broadcast join threshold. val joinKeyOne = "00112233445566778899"
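
Example (not part of the diff): a minimal sketch of when the quiet-cancel path introduced by this patch is expected to kick in. It assumes an active SparkSession named `spark`; the DataFrame names and sizes are illustrative, and the cancellation only happens if the surviving shuffle stage has not yet materialized when the empty side finishes.

// Enable AQE and the new flag added in SQLConf above; disable broadcast joins so both join sides
// become shuffle query stages that are submitted concurrently.
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.empty.trigger.cancel.enabled", "true")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

val big = spark.range(0L, 10000000L).toDF("id")
val probablyEmpty = spark.range(0L, 10000000L).toDF("id").filter("id < 0")  // empty only at runtime

// Inner join: once the filtered side materializes with zero rows, AQEPropagateEmptyRelation rewrites
// the join to an EmptyRelation and collectCancelableCandidates picks up the other side's
// still-running ShuffleQueryStageExec; AdaptiveSparkPlanExec then cancels it with quiet = true, so
// the stage fails its promise with SparkAQEStageCancelException instead of surfacing a job failure
// or triggering another re-optimization round.
big.join(probablyEmpty, "id").collect()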