[WIP] Add support for BNLJ #343

Draft · wants to merge 1 commit into main
22 changes: 21 additions & 1 deletion core/src/execution/datafusion/planner.rs
@@ -40,7 +40,7 @@ use datafusion::{
physical_plan::{
aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
filter::FilterExec,
joins::{utils::JoinFilter, HashJoinExec, PartitionMode, SortMergeJoinExec},
joins::{utils::JoinFilter, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec},
limit::LocalLimitExec,
projection::ProjectionExec,
sorts::sort::SortExec,
@@ -977,6 +977,26 @@ impl PhysicalPlanner {
)?);
Ok((scans, join))
}

OpStruct::BroadcastNestedLoopJoin(join) => {
// Create the DataFusion physical operator.
let empty_keys: &[Expr] = &[];
let (join_params, scans) = self.parse_join_parameters(
inputs,
children,
&empty_keys, // BNLJ has no equi-join keys
&empty_keys, // BNLJ has no equi-join keys
join.join_type,
&join.condition,
)?;
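// NestedLoopJoinExec evaluates the optional join filter against each pair of
// left/right rows, so no equi-join keys are needed.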
let join = Arc::new(NestedLoopJoinExec::try_new(
join_params.left,
join_params.right,
join_params.join_filter,
&join_params.join_type
)?);
Ok((scans, join))
}
}
}

7 changes: 7 additions & 0 deletions core/src/execution/proto/operator.proto
@@ -42,6 +42,7 @@ message Operator {
Expand expand = 107;
SortMergeJoin sort_merge_join = 108;
HashJoin hash_join = 109;
BroadcastNestedLoopJoin broadcast_nested_loop_join = 110;
}
}

@@ -104,6 +105,12 @@ message SortMergeJoin {
repeated spark.spark_expression.Expr sort_options = 4;
}

message BroadcastNestedLoopJoin {
// BNLJ carries no join keys; only the join type and an optional non-equi condition.
JoinType join_type = 1;
optional spark.spark_expression.Expr condition = 2;
}

enum JoinType {
Inner = 0;
LeftOuter = 1;
spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -470,6 +470,25 @@ class CometSparkSessionExtensions
op
}

case op: BroadcastNestedLoopJoinExec
if isCometOperatorEnabled(conf, "broadcast_nested_loop_join") &&
op.children.forall(isCometNative(_)) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometBroadcastNestedLoopJoinExec(
nativeOp,
op,
op.joinType,
op.condition,
op.buildSide,
op.left,
op.right,
SerializedPlan(None))
case None =>
op
}

case op: BroadcastHashJoinExec if !isCometOperatorEnabled(conf, "broadcast_hash_join") =>
withInfo(op, "BroadcastHashJoin is not enabled")
op
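For context, a minimal sketch of exercising this rewrite from a Spark session follows. The per-operator config key below is an assumption based on Comet's usual `spark.comet.exec.<operator>.enabled` naming and is not taken from this diff.

// Hypothetical usage sketch (not part of this change).
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").getOrCreate()
spark.conf.set("spark.comet.enabled", "true")
spark.conf.set("spark.comet.exec.enabled", "true")
// Assumed key, following the isCometOperatorEnabled(conf, "broadcast_nested_loop_join") check above.
spark.conf.set("spark.comet.exec.broadcast_nested_loop_join.enabled", "true")

spark.range(0, 100).createOrReplaceTempView("a")
spark.range(0, 10).createOrReplaceTempView("b")

// A non-equi join with a broadcast hint plans as BroadcastNestedLoopJoinExec, which
// this rule rewrites to CometBroadcastNestedLoopJoinExec when its children are
// Comet-native (e.g. Parquet scans handled by Comet).
spark.sql("SELECT /*+ BROADCAST(b) */ * FROM a JOIN b ON a.id > b.id").explain()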
43 changes: 42 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -2356,6 +2356,47 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
None
}

case join: BroadcastNestedLoopJoinExec
if isCometOperatorEnabled(op.conf, "broadcast_nested_loop_join") =>
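// Only build-side / join-type combinations that the native nested loop join
// supports are converted; everything else falls back to Spark by returning None.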
if (join.buildSide == BuildRight) {
if (join.joinType != Inner && join.joinType != LeftOuter
&& join.joinType != LeftSemi && join.joinType != LeftAnti) {
return None
}
} else {
if (join.joinType != RightOuter && join.joinType != FullOuter) {
return None
}
}

val joinType = join.joinType match {
case Inner => JoinType.Inner
case LeftOuter => JoinType.LeftOuter
case RightOuter => JoinType.RightOuter
case FullOuter => JoinType.FullOuter
case LeftSemi => JoinType.LeftSemi
case LeftAnti => JoinType.LeftAnti
case _ => return None // other join types are not supported here
}

val condition = join.condition.map { cond =>
val condProto = exprToProto(cond, join.left.output ++ join.right.output)
if (condProto.isEmpty) {
return None
}
condProto.get
}

if (childOp.nonEmpty) {
val joinBuilder = OperatorOuterClass.BroadcastNestedLoopJoin
.newBuilder()
.setJoinType(joinType)
condition.foreach(joinBuilder.setCondition)
Some(result.setBroadcastNestedLoopJoin(joinBuilder).build())
} else {
None
}

case join: SortMergeJoinExec if isCometOperatorEnabled(op.conf, "sort_merge_join") =>
// `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec.
def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
33 changes: 33 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -848,6 +848,39 @@ case class CometBroadcastHashJoinExec(
Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right)
}

case class CometBroadcastNestedLoopJoinExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
joinType: JoinType,
condition: Option[Expression],
buildSide: BuildSide,
override val left: SparkPlan,
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {
override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)

override def stringArgs: Iterator[Any] =
Iterator(joinType, condition, left, right)

override def equals(obj: Any): Boolean = {
obj match {
case other: CometBroadcastNestedLoopJoinExec =>
this.joinType == other.joinType &&
this.condition == other.condition &&
this.buildSide == other.buildSide &&
this.left == other.left &&
this.right == other.right &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int =
Objects.hashCode(joinType, condition, buildSide, left, right)
}

case class CometSortMergeJoinExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
30 changes: 29 additions & 1 deletion spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -23,7 +23,7 @@ import org.scalactic.source.Position
import org.scalatest.Tag

import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometBroadcastNestedLoopJoinExec}
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.CometConf
@@ -232,4 +232,32 @@ class CometJoinSuite extends CometTestBase {
}
}
}

test("BroadcastNestedLoopJoin without filter") {
assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4")
withSQLConf(
CometConf.COMET_BATCH_SIZE.key -> "100",
SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
"spark.sql.join.forceApplyShuffledHashJoin" -> "true",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") {
withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") {
// Inner join: build right
val df1 =
sql("SELECT /*+ BROADCAST(tbl_b) */ * FROM tbl_a JOIN tbl_b")
checkSparkAnswerAndOperator(
df1,
Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastNestedLoopJoinExec]))

// Right join: build left
val df2 =
sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b")
checkSparkAnswerAndOperator(
df2,
Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastNestedLoopJoinExec]))
}
}
}
}
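  // A possible follow-up test, sketched here rather than taken from this change:
  // exercise the LeftOuter path that the serde accepts for a build-right BNLJ.
  // Column names _1/_2 come from the tuple-backed tables created by withParquetTable.
  test("BroadcastNestedLoopJoin left outer without equi-join keys (sketch)") {
    assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4")
    withParquetTable((0 until 100).map(i => (i, i % 5)), "tbl_a") {
      withParquetTable((0 until 100).map(i => (i % 10, i + 2)), "tbl_b") {
        val df = sql(
          "SELECT /*+ BROADCAST(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._1 > tbl_b._1")
        checkSparkAnswerAndOperator(
          df,
          Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastNestedLoopJoinExec]))
      }
    }
  }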
}