Skip to content

Commit

Permalink
[SPARK-50880][SQL] Add a new visitBinaryComparison method to V2Expres…
Browse files Browse the repository at this point in the history
…sionSQLBuilder
  • Loading branch information
beliefer committed Jan 18, 2025
1 parent 8bbec5d commit 2441ca0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public String build(Expression expr) {
case "ENDS_WITH" -> visitEndsWith(build(e.children()[0]), build(e.children()[1]));
case "CONTAINS" -> visitContains(build(e.children()[0]), build(e.children()[1]));
case "=", "<>", "<=>", "<", "<=", ">", ">=" ->
visitBinaryComparison(name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1]));
visitBinaryComparison(name, e.children()[0], e.children()[1]);
case "+", "*", "/", "%", "&", "|", "^" ->
visitBinaryArithmetic(name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1]));
case "-" -> {
Expand Down Expand Up @@ -219,6 +219,10 @@ protected String inputToSQL(Expression input) {
}
}

protected String visitBinaryComparison(String name, Expression le, Expression re) {
return visitBinaryComparison(name, inputToSQL(le), inputToSQL(re));
}

protected String visitBinaryComparison(String name, String l, String r) {
if (name.equals("<=>")) {
return "((" + l + " IS NOT NULL AND " + r + " IS NOT NULL AND " + l + " = " + r + ") " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.util.control.NonFatal

import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException}
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, Literal}
import org.apache.spark.sql.connector.expressions.{Expression, Literal}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.OracleDialect._
Expand Down Expand Up @@ -62,33 +62,27 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N
super.visitAggregateFunction(funcName, isDistinct, inputs)
}

override def visitBinaryComparison(name: String, le: Expression, re: Expression): String = {
(le, re) match {
case (lhs: Literal[_], rhs: Expression) if lhs.dataType == BinaryType =>
compareBlob(lhs, name, rhs)
case (lhs: Expression, rhs: Literal[_]) if rhs.dataType == BinaryType =>
compareBlob(lhs, name, rhs)
case _ =>
super.visitBinaryComparison(name, le, re);
}
}

private def compareBlob(lhs: Expression, operator: String, rhs: Expression): String = {
val l = inputToSQL(lhs)
val r = inputToSQL(rhs)
val op = if (operator == "<=>") "=" else operator
val compare = s"DBMS_LOB.COMPARE($l, $r) $op 0"
if (operator == "<=>") {
val compare = s"DBMS_LOB.COMPARE($l, $r) = 0"
s"(($l IS NOT NULL AND $r IS NOT NULL AND $compare) OR ($l IS NULL AND $r IS NULL))"
} else {
compare
s"DBMS_LOB.COMPARE($l, $r) $operator 0"
}
}

override def build(expr: Expression): String = expr match {
case e: GeneralScalarExpression =>
e.name() match {
case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" =>
(e.children()(0), e.children()(1)) match {
case (lhs: Literal[_], rhs: Expression) if lhs.dataType == BinaryType =>
compareBlob(lhs, e.name, rhs)
case (lhs: Expression, rhs: Literal[_]) if rhs.dataType == BinaryType =>
compareBlob(lhs, e.name, rhs)
case _ => super.build(expr)
}
case _ => super.build(expr)
}
case _ => super.build(expr)
}
}

override def compileExpression(expr: Expression): Option[String] = {
Expand Down

0 comments on commit 2441ca0

Please sign in to comment.