From daa2f6b194efd319c319025e3fcbd573146a6c1b Mon Sep 17 00:00:00 2001
From: Prashant Singh
Date: Fri, 26 Apr 2024 15:03:34 -0700
Subject: [PATCH] [WIP] Add support for BroadcastNestedLoopJoin (BNLJ)

---
 core/src/execution/datafusion/planner.rs        | 22 +++++++++-
 core/src/execution/proto/operator.proto         |  7 +++
 .../comet/CometSparkSessionExtensions.scala     | 21 ++++++++-
 .../apache/comet/serde/QueryPlanSerde.scala     | 43 ++++++++++++++++++-
 .../apache/spark/sql/comet/operators.scala      | 33 ++++++++++++++
 .../apache/comet/exec/CometJoinSuite.scala      | 30 ++++++++++++-
 6 files changed, 152 insertions(+), 4 deletions(-)

diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 72174790b..4dc58116a 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -40,7 +40,7 @@ use datafusion::{
     physical_plan::{
         aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
         filter::FilterExec,
-        joins::{utils::JoinFilter, HashJoinExec, PartitionMode, SortMergeJoinExec},
+        joins::{utils::JoinFilter, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec},
         limit::LocalLimitExec,
         projection::ProjectionExec,
         sorts::sort::SortExec,
@@ -977,6 +977,26 @@ impl PhysicalPlanner {
                 )?);
                 Ok((scans, join))
             }
+
+            OpStruct::BroadcastNestedLoopJoin(join) => {
+                // Create the corresponding DataFusion physical operator.
+                // BNLJ has no equi-join keys, so both key lists are empty;
+                // only the optional non-equi condition is passed through.
+                let empty_keys: &[Expr] = &[];
+                let (join_params, scans) = self.parse_join_parameters(
+                    inputs,
+                    children,
+                    empty_keys,
+                    empty_keys,
+                    join.join_type,
+                    &join.condition,
+                )?;
+                let join = Arc::new(NestedLoopJoinExec::try_new(
+                    join_params.left,
+                    join_params.right,
+                    join_params.join_filter,
+                    &join_params.join_type,
+                )?);
+                Ok((scans, join))
+            }
         }
     }

diff --git a/core/src/execution/proto/operator.proto b/core/src/execution/proto/operator.proto
index 6080c5668..c9f0f3999 100644
--- a/core/src/execution/proto/operator.proto
+++ b/core/src/execution/proto/operator.proto
@@ -42,6 +42,7 @@ message Operator {
     Expand expand = 107;
     SortMergeJoin sort_merge_join = 108;
     HashJoin hash_join = 109;
+    BroadcastNestedLoopJoin broadcast_nested_loop_join = 110;
   }
 }

@@ -104,6 +105,12 @@ message SortMergeJoin {
   repeated spark.spark_expression.Expr sort_options = 4;
 }

+message BroadcastNestedLoopJoin {
+  // BNLJ has no equi-join keys; only the join type and an optional
+  // non-equi condition are serialized.
+  JoinType join_type = 1;
+  optional spark.spark_expression.Expr condition = 2;
+}
+
 enum JoinType {
   Inner = 0;
   LeftOuter = 1;
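For reviewers, a minimal sketch of what the new message looks like when built through the generated protobuf bindings. The builder method names follow standard protobuf-java codegen and the `OperatorOuterClass` usage in the serde change below; treat the exact import paths as assumptions:

    // Hypothetical construction of a BNLJ operator node via the generated
    // builders; only the join type and an optional condition exist.
    import org.apache.comet.serde.OperatorOuterClass
    import org.apache.comet.serde.OperatorOuterClass.JoinType

    val bnlj = OperatorOuterClass.BroadcastNestedLoopJoin
      .newBuilder()
      .setJoinType(JoinType.Inner)
      .build() // condition is optional, so a cross join simply omits it

    val op = OperatorOuterClass.Operator
      .newBuilder()
      .setBroadcastNestedLoopJoin(bnlj)
      .build()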
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 8ef8cb83e..0735102cb 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

@@ -470,6 +470,25 @@ class CometSparkSessionExtensions
             op
         }

+      case op: BroadcastNestedLoopJoinExec
+          if isCometOperatorEnabled(conf, "broadcast_nested_loop_join") &&
+            op.children.forall(isCometNative(_)) =>
+        val newOp = transform1(op)
+        newOp match {
+          case Some(nativeOp) =>
+            CometBroadcastNestedLoopJoinExec(
+              nativeOp,
+              op,
+              op.joinType,
+              op.condition,
+              op.buildSide,
+              op.left,
+              op.right,
+              SerializedPlan(None))
+          case None =>
+            op
+        }
+
       case op: BroadcastHashJoinExec if !isCometOperatorEnabled(conf, "broadcast_hash_join") =>
         withInfo(op, "BroadcastHashJoin is not enabled")
         op
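To try the operator end to end, it has to be switched on through the per-operator gate checked by isCometOperatorEnabled above. A rough spark-shell sketch; the exact keys are assumptions derived from the "spark.comet.exec.<name>.enabled" convention used for the other operators, so verify them against CometConf:

    // Assumed config keys, not confirmed by this patch.
    spark.conf.set("spark.comet.exec.enabled", "true")
    spark.conf.set("spark.comet.exec.broadcast_nested_loop_join.enabled", "true")

    val df = spark.sql("SELECT /*+ BROADCAST(b) */ * FROM a JOIN b")
    df.explain() // expect CometBroadcastNestedLoopJoinExec once both children are native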
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 57b15e2f5..cd86c91fa 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -2356,6 +2356,47 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
           None
         }

+      case join: BroadcastNestedLoopJoinExec
+          if isCometOperatorEnabled(op.conf, "broadcast_nested_loop_join") =>
+        // Only the build-side / join-type combinations supported by the
+        // native operator are converted; everything else falls back to Spark.
+        if (join.buildSide == BuildRight) {
+          if (join.joinType != Inner && join.joinType != LeftOuter &&
+            join.joinType != LeftSemi && join.joinType != LeftAnti) {
+            return None
+          }
+        } else {
+          if (join.joinType != RightOuter && join.joinType != FullOuter) {
+            return None
+          }
+        }
+
+        val joinType = join.joinType match {
+          case Inner => JoinType.Inner
+          case LeftOuter => JoinType.LeftOuter
+          case RightOuter => JoinType.RightOuter
+          case FullOuter => JoinType.FullOuter
+          case LeftSemi => JoinType.LeftSemi
+          case LeftAnti => JoinType.LeftAnti
+          case _ => return None // other join types are not supported by Comet's native BNLJ
+        }
+
+        val condition = join.condition.map { cond =>
+          val condProto = exprToProto(cond, join.left.output ++ join.right.output)
+          if (condProto.isEmpty) {
+            return None
+          }
+          condProto.get
+        }
+
+        if (childOp.nonEmpty) {
+          val joinBuilder = OperatorOuterClass.BroadcastNestedLoopJoin
+            .newBuilder()
+            .setJoinType(joinType)
+          condition.foreach(joinBuilder.setCondition)
+          Some(result.setBroadcastNestedLoopJoin(joinBuilder).build())
+        } else {
+          None
+        }
+
       case join: SortMergeJoinExec if isCometOperatorEnabled(op.conf, "sort_merge_join") =>
         // `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec.
         def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 39ffef140..6d8724809 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -848,6 +848,39 @@ case class CometBroadcastHashJoinExec(
     Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right)
 }

+case class CometBroadcastNestedLoopJoinExec(
+    override val nativeOp: Operator,
+    override val originalPlan: SparkPlan,
+    joinType: JoinType,
+    condition: Option[Expression],
+    buildSide: BuildSide,
+    override val left: SparkPlan,
+    override val right: SparkPlan,
+    override val serializedPlanOpt: SerializedPlan)
+    extends CometBinaryExec {
+
+  override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
+    this.copy(left = newLeft, right = newRight)
+
+  override def stringArgs: Iterator[Any] =
+    Iterator(joinType, condition, left, right)
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometBroadcastNestedLoopJoinExec =>
+        this.joinType == other.joinType &&
+          this.condition == other.condition &&
+          this.buildSide == other.buildSide &&
+          this.left == other.left &&
+          this.right == other.right &&
+          this.serializedPlanOpt == other.serializedPlanOpt
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int =
+    Objects.hashCode(joinType, condition, buildSide, left, right)
+}
+
 case class CometSortMergeJoinExec(
   override val nativeOp: Operator,
   override val originalPlan: SparkPlan,
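The build-side check in QueryPlanSerde above encodes a small support matrix. Restated as a hypothetical helper (not part of the patch) purely for readability:

    import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
    import org.apache.spark.sql.catalyst.plans._

    // Which Spark build-side / join-type combinations the serde will convert:
    // build right pairs with the left-oriented join types, and build left
    // pairs with right outer and full outer.
    def supportedBnlj(buildSide: BuildSide, joinType: JoinType): Boolean =
      (buildSide, joinType) match {
        case (BuildRight, Inner | LeftOuter | LeftSemi | LeftAnti) => true
        case (BuildLeft, RightOuter | FullOuter) => true
        case _ => false
      }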
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index 54c0baf16..917517baf 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -23,7 +23,7 @@ import org.scalactic.source.Position
 import org.scalatest.Tag

 import org.apache.spark.sql.CometTestBase
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometBroadcastNestedLoopJoinExec}
 import org.apache.spark.sql.internal.SQLConf

 import org.apache.comet.CometConf
@@ -232,4 +232,32 @@ class CometJoinSuite extends CometTestBase {
       }
     }
   }
+
+  test("BroadcastNestedLoopJoin without filter") {
+    assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4")
+    withSQLConf(
+      CometConf.COMET_BATCH_SIZE.key -> "100",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Inner join without a condition, broadcasting the right side (build right).
+          val df1 = sql("SELECT /*+ BROADCAST(tbl_b) */ * FROM tbl_a JOIN tbl_b")
+          checkSparkAnswerAndOperator(
+            df1,
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastNestedLoopJoinExec]))
+
+          // Right outer join, broadcasting the left side (build left).
+          val df2 = sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b")
+          checkSparkAnswerAndOperator(
+            df2,
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastNestedLoopJoinExec]))
+        }
+      }
+    }
+  }
 }
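The test above only exercises the no-condition (cross product) case. A natural follow-up once conditions are verified end to end would be something like the sketch below; it is hypothetical, reusing the suite's own helpers, and assumes withParquetTable exposes the tuple columns as _1 and _2:

    test("BroadcastNestedLoopJoin with non-equi condition") {
      assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4")
      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
        withParquetTable((0 until 100).map(i => (i, i % 5)), "tbl_a") {
          withParquetTable((0 until 100).map(i => (i % 10, i + 2)), "tbl_b") {
            // A non-equi predicate forces BNLJ rather than a hash join.
            val df = sql(
              "SELECT /*+ BROADCAST(tbl_b) */ * FROM tbl_a JOIN tbl_b ON tbl_a._1 > tbl_b._1")
            checkSparkAnswerAndOperator(
              df,
              Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastNestedLoopJoinExec]))
          }
        }
      }
    }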