fix: window function range offset should be long instead of int #733

Open · wants to merge 7 commits into base: main
58 changes: 45 additions & 13 deletions native/core/src/execution/datafusion/planner.rs
@@ -1692,16 +1692,33 @@ impl PhysicalPlanner {
.and_then(|inner| inner.lower_frame_bound_struct.as_ref())
{
Some(l) => match l {
LowerFrameBoundStruct::UnboundedPreceding(_) => {
WindowFrameBound::Preceding(ScalarValue::UInt64(None))
}
LowerFrameBoundStruct::UnboundedPreceding(_) => match units {
WindowFrameUnits::Rows => {
WindowFrameBound::Preceding(ScalarValue::UInt64(None))
}
WindowFrameUnits::Range | WindowFrameUnits::Groups => {
// Reviewer (Member): Do we use Groups?

WindowFrameBound::Preceding(ScalarValue::Int64(None))
}
},
LowerFrameBoundStruct::Preceding(offset) => {
let offset_value = offset.offset.unsigned_abs() as u64;
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(offset_value)))
let offset_value = offset.offset.abs();
match units {
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(
Some(offset_value as u64),
)),
WindowFrameUnits::Range | WindowFrameUnits::Groups => {
WindowFrameBound::Preceding(ScalarValue::Int64(Some(offset_value)))
}
}
}
LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
None => match units {
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Range | WindowFrameUnits::Groups => {
WindowFrameBound::Preceding(ScalarValue::Int64(None))
}
},
};

let upper_bound: WindowFrameBound = match spark_window_frame
@@ -1710,15 +1727,30 @@ impl PhysicalPlanner {
.and_then(|inner| inner.upper_frame_bound_struct.as_ref())
{
Some(u) => match u {
UpperFrameBoundStruct::UnboundedFollowing(_) => {
WindowFrameBound::Following(ScalarValue::UInt64(None))
}
UpperFrameBoundStruct::Following(offset) => {
WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
}
UpperFrameBoundStruct::UnboundedFollowing(_) => match units {
WindowFrameUnits::Rows => {
WindowFrameBound::Following(ScalarValue::UInt64(None))
}
WindowFrameUnits::Range | WindowFrameUnits::Groups => {
WindowFrameBound::Following(ScalarValue::Int64(None))
}
},
UpperFrameBoundStruct::Following(offset) => match units {
WindowFrameUnits::Rows => {
WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
}
WindowFrameUnits::Range | WindowFrameUnits::Groups => {
WindowFrameBound::Following(ScalarValue::Int64(Some(offset.offset)))
}
},
UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => WindowFrameBound::Following(ScalarValue::UInt64(None)),
None => match units {
WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Range | WindowFrameUnits::Groups => {
WindowFrameBound::Following(ScalarValue::Int64(None))
}
},
};

let window_frame = WindowFrame::new_bounds(units, lower_bound, upper_bound);
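On the Spark side the two frame types carry offsets of different shapes, which is why the planner above maps ROWS bounds to unsigned values and RANGE/GROUPS bounds to signed 64-bit values. A minimal sketch using the public DataFrame API (the column name is illustrative; the long boundaries are the ones exercised by this PR's tests):

import org.apache.spark.sql.expressions.Window

// ROWS offsets are row counts relative to the current row; the native side
// stores their magnitudes as ScalarValue::UInt64.
val rowsFrame = Window.orderBy("key").rowsBetween(-1, 1)

// RANGE offsets are values in the ORDER BY column's domain; over a LongType
// column they can be negative or exceed Int.MaxValue, hence ScalarValue::Int64.
val rangeFrame = Window.orderBy("key").rangeBetween(-2147483649L, 2147483648L)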
4 changes: 2 additions & 2 deletions native/proto/src/proto/operator.proto
@@ -161,11 +161,11 @@ message UpperWindowFrameBound {
}

message Preceding {
int32 offset = 1;
int64 offset = 1;
}

message Following {
int32 offset = 1;
int64 offset = 1;
}

message UnboundedPreceding {}
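Widening these fields matters because Spark's rangeBetween accepts Long boundaries, which an int32 field would silently truncate. A small Scala illustration (the value is the one exercised by the test changes below):

// Int.MaxValue + 1 cannot round-trip through the old int32 field.
val offset = 2147483648L
println(offset.toInt) // -2147483648: what an int32 field would carry
println(offset)       // 2147483648: preserved by the widened int64 field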
65 changes: 60 additions & 5 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -255,15 +255,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
(None, exprToProto(windowExpr.windowFunction, output))
}

if (aggExpr.isEmpty && builtinFunc.isEmpty) {
return None
}

val f = windowExpr.windowSpec.frameSpecification

val (frameType, lowerBound, upperBound) = f match {
case SpecifiedWindowFrame(frameType, lBound, uBound) =>
val frameProto = frameType match {
case RowFrame => OperatorOuterClass.WindowFrameType.Rows
case RangeFrame =>
withInfo(windowExpr, "Range frame is not supported")
return None
case RangeFrame => OperatorOuterClass.WindowFrameType.Range
}

val lBoundProto = lBound match {
@@ -278,12 +280,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
.setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
.build()
case e =>
val offset = e.eval() match {
case i: Integer => i.toLong
case l: Long => l
case _ => return None
}
OperatorOuterClass.LowerWindowFrameBound
.newBuilder()
.setPreceding(
OperatorOuterClass.Preceding
.newBuilder()
.setOffset(e.eval().asInstanceOf[Int])
.setOffset(offset)
.build())
.build()
}
@@ -300,12 +307,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
.setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
.build()
case e =>
val offset = e.eval() match {
case i: Integer => i.toLong
case l: Long => l
case _ => return None
}

OperatorOuterClass.UpperWindowFrameBound
.newBuilder()
.setFollowing(
OperatorOuterClass.Following
.newBuilder()
.setOffset(e.eval().asInstanceOf[Int])
.setOffset(offset)
.build())
.build()
}
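The match on e.eval() is needed because Catalyst evaluates frame-boundary expressions to boxed values whose runtime class follows the literal's type: Integer for int boundaries (typical for ROWS) and Long for long boundaries (possible for RANGE). A self-contained sketch of that widening, where toLongOffset is a hypothetical helper mirroring the inline matches above:

import org.apache.spark.sql.catalyst.expressions.Literal

// Widen a boundary value to Long, declining anything that is neither
// an Integer nor a Long (the serde falls back to Spark in that case).
def toLongOffset(bound: Any): Option[Long] = bound match {
  case i: Integer => Some(i.toLong)
  case l: Long    => Some(l)
  case _          => None
}

assert(toLongOffset(Literal(1).eval()) == Some(1L))
assert(toLongOffset(Literal(2147483648L).eval()) == Some(2147483648L))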
@@ -2774,6 +2787,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
return None
}

if (partitionSpec.nonEmpty && orderSpec.nonEmpty &&
!validatePartitionAndSortSpecsForWindowFunc(partitionSpec, orderSpec, op)) {
return None
}

val windowExprProto = winExprs.map(windowExprToProto(_, output, op.conf))
val partitionExprs = partitionSpec.map(exprToProto(_, child.output))

@@ -3277,4 +3295,41 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
true
}
}

private def validatePartitionAndSortSpecsForWindowFunc(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
op: SparkPlan): Boolean = {
if (partitionSpec.length != orderSpec.length) {
withInfo(op, "Partitioning and sorting specifications do not match")
return false
} else {
val partitionColumnNames = partitionSpec.collect { case a: AttributeReference =>
a.name
}

if (partitionColumnNames.length != partitionSpec.length) {
withInfo(op, "Unsupported partitioning specification")
return false
}

val orderColumnNames = orderSpec.map(_.child).collect { case a: AttributeReference =>
a.name
}

if (orderColumnNames.length != orderSpec.length) {
withInfo(op, "Unsupported SortOrder")
return false
}

if (partitionColumnNames.toSet != orderColumnNames.toSet) {
withInfo(op, "Partitioning and sorting specifications do not match")
return false
}
// Reviewer (Member), on lines +3327 to +3330: Maybe check the partition column
// and order column one by one instead of as a set? I'm not sure
// (PARTITION BY k, v ORDER BY v, k) would work.


true
}
}
}
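For reference against the review note above, a hedged sketch of the pairwise alternative (columnsMatchPairwise is hypothetical, not part of this PR); it rejects PARTITION BY k, v ORDER BY v, k, which the set comparison accepts:

// Require the i-th partition column to equal the i-th order column,
// instead of comparing the two column sets.
def columnsMatchPairwise(partitionCols: Seq[String], orderCols: Seq[String]): Boolean =
  partitionCols.length == orderCols.length &&
    partitionCols.zip(orderCols).forall { case (p, o) => p == o }

assert(Set("k", "v") == Set("v", "k"))                      // set check: passes
assert(!columnsMatchPairwise(Seq("k", "v"), Seq("v", "k"))) // pairwise: rejected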
39 changes: 26 additions & 13 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -149,6 +149,22 @@ class CometExecSuite extends CometTestBase {
}
}

test(
"fall back to Spark when the partition spec and order spec are not the same for window function") {
withTempView("test") {
sql("""
|CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES
| (1, true), (1, false),
|(2, true), (3, false), (4, true) AS test(k, v)
|""".stripMargin)

val df = sql("""
SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
|""".stripMargin)
checkSparkAnswer(df)
}
}

test("Native window operator should be CometUnaryExec") {
withTempView("testData") {
sql("""
@@ -164,11 +180,11 @@
|(3, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), null)
|AS testData(val, val_long, val_double, val_date, val_timestamp, cate)
|""".stripMargin)
val df = sql("""
val df1 = sql("""
|SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW)
|FROM testData ORDER BY cate, val
|""".stripMargin)
checkSparkAnswer(df)
checkSparkAnswer(df1)
}
}

@@ -193,23 +209,21 @@
}
}

test("Window range frame should fall back to Spark") {
test("Window range frame with long boundary should not fail") {
val df =
Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2"))
.toDF("key", "value")

checkAnswer(
checkSparkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))),
Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)))
checkAnswer(
Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))))
checkSparkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))),
Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)))
Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))))
}

test("Unsupported window expression should fall back to Spark") {
@@ -1769,10 +1783,9 @@ class CometExecSuite extends CometTestBase {
aggregateFunctions.foreach { function =>
val queries = Seq(
s"SELECT $function OVER() FROM t1",
// TODO: Range frame is not supported yet.
// s"SELECT $function OVER(order by _2) FROM t1",
// s"SELECT $function OVER(order by _2 desc) FROM t1",
// s"SELECT $function OVER(partition by _2 order by _2) FROM t1",
s"SELECT $function OVER(order by _2) FROM t1",
s"SELECT $function OVER(order by _2 desc) FROM t1",
s"SELECT $function OVER(partition by _2 order by _2) FROM t1",
s"SELECT $function OVER(rows between 1 preceding and 1 following) FROM t1",
s"SELECT $function OVER(order by _2 rows between 1 preceding and current row) FROM t1",
s"SELECT $function OVER(order by _2 rows between current row and 1 following) FROM t1")