
Commit d03d87c
Fix slicing after Join and GroupBy in streaming cudf-polars (#19187)
The streaming executor does not properly handle slicing after `Join` or `GroupBy`: rather than slicing the "reduced" join/groupby result, each partition is sliced individually. This PR includes a general fix, pulling the slice operation out of the `Join` and `GroupBy` nodes.

Authors:
- Richard (Rick) Zamora (https://github.com/rjzamora)
- Tom Augspurger (https://github.com/TomAugspurger)

Approvers:
- Tom Augspurger (https://github.com/TomAugspurger)

URL: #19187
1 parent 3c1e4db commit d03d87c
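
For context, a minimal sketch of the query shape this commit fixes. The frame contents here are illustrative, not from the PR; `executor="streaming"` follows the engine configuration used in the tests below. Before this fix, the trailing `slice` was applied to each partition of the join result rather than once to the combined result:

import polars as pl

left = pl.LazyFrame({"a": [1, 2, 3, 1], "b": [10, 20, 30, 40]})
right = pl.LazyFrame({"a": [1, 3, 1], "c": [100, 300, 101]})

# Slicing after a join: only the first two rows of the *overall* result
# should come back, regardless of how the join output is partitioned.
q = left.join(right, on="a", how="inner").slice(0, 2)

engine = pl.GPUEngine(executor="streaming")
print(q.collect(engine=engine))  # expected: exactly 2 rows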

File tree: 5 files changed (+115 −2 lines changed)

python/cudf_polars/cudf_polars/experimental/groupby.py
Lines changed: 20 additions & 1 deletion

@@ -14,7 +14,7 @@

 from cudf_polars.containers import DataType
 from cudf_polars.dsl.expr import Agg, BinOp, Col, Len, NamedExpr
-from cudf_polars.dsl.ir import GroupBy, Select
+from cudf_polars.dsl.ir import GroupBy, Select, Slice
 from cudf_polars.dsl.traversal import traversal
 from cudf_polars.dsl.utils.naming import unique_names
 from cudf_polars.experimental.base import PartitionInfo
@@ -143,6 +143,25 @@ def decompose(
 def _(
     ir: GroupBy, rec: LowerIRTransformer
 ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+    # Pull slice operations out of the GroupBy before lowering
+    if ir.zlice is not None:
+        offset, length = ir.zlice
+        if length is None:  # pragma: no cover
+            return _lower_ir_fallback(
+                ir,
+                rec,
+                msg="This slice not supported for multiple partitions.",
+            )
+        new_groupby = GroupBy(
+            ir.schema,
+            ir.keys,
+            ir.agg_requests,
+            ir.maintain_order,
+            None,  # zlice cleared; applied by the Slice node below
+            *ir.children,
+        )
+        return rec(Slice(ir.schema, offset, length, new_groupby))
+
     # Extract child partitioning
     child, partition_info = rec(ir.children[0])
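
The rewrite above is plain node surgery: rebuild the `GroupBy` without its `zlice`, then wrap it in a `Slice`. Here is a hypothetical miniature of the idea, with dataclasses standing in for the real IR types (`MiniGroupBy`, `MiniSlice`, and their fields are invented for illustration):

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class MiniGroupBy:
    keys: tuple[str, ...]
    zlice: tuple[int, int] | None  # (offset, length), like the real node

@dataclass(frozen=True)
class MiniSlice:
    offset: int
    length: int
    child: MiniGroupBy

node = MiniGroupBy(keys=("y",), zlice=(0, 2))
offset, length = node.zlice
# Clear the slice on the GroupBy, then apply it once, on top
rewritten = MiniSlice(offset, length, replace(node, zlice=None))
assert rewritten.child.zlice is None

Lowering then sees a slice-free `GroupBy`, which it already knows how to decompose across partitions, under a single `Slice` node.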

python/cudf_polars/cudf_polars/experimental/join.py
Lines changed: 19 additions & 1 deletion

@@ -8,7 +8,7 @@
 from functools import reduce
 from typing import TYPE_CHECKING, Any

-from cudf_polars.dsl.ir import ConditionalJoin, Join
+from cudf_polars.dsl.ir import ConditionalJoin, Join, Slice
 from cudf_polars.experimental.base import PartitionInfo, get_key_name
 from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
 from cudf_polars.experimental.repartition import Repartition
@@ -226,6 +226,24 @@ def _(
 def _(
     ir: Join, rec: LowerIRTransformer
 ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+    # Pull slice operations out of the Join before lowering
+    if (zlice := ir.options[2]) is not None:
+        offset, length = zlice
+        if length is None:  # pragma: no cover
+            return _lower_ir_fallback(
+                ir,
+                rec,
+                msg="This slice not supported for multiple partitions.",
+            )
+        new_join = Join(
+            ir.schema,
+            ir.left_on,
+            ir.right_on,
+            (*ir.options[:2], None, *ir.options[3:]),
+            *ir.children,
+        )
+        return rec(Slice(ir.schema, offset, length, new_join))
+
     # Lower children
     children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
     partition_info = reduce(operator.or_, _partition_info)
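
Unlike `GroupBy`, `Join` stores its slice inside its options tuple (index 2 here), so clearing it is tuple surgery rather than passing `None` positionally. A small standalone illustration (the surrounding values are hypothetical placeholders; only index 2, the slice, matters):

options = ("inner", True, (0, 2), "extra")  # hypothetical layout
zlice = options[2]                          # -> (0, 2)
cleared = (*options[:2], None, *options[3:])
assert cleared == ("inner", True, None, "extra")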

python/cudf_polars/cudf_polars/experimental/parallel.py
Lines changed: 25 additions & 0 deletions

@@ -24,6 +24,7 @@
     HStack,
     MapFunction,
     Projection,
+    Slice,
     Union,
 )
 from cudf_polars.dsl.traversal import CachingVisitor, traversal
@@ -32,6 +33,7 @@
     generate_ir_tasks,
     lower_ir_node,
 )
+from cudf_polars.experimental.repartition import Repartition
 from cudf_polars.experimental.utils import _concat, _lower_ir_fallback

 if TYPE_CHECKING:
@@ -336,6 +338,29 @@ def _lower_ir_pwise(
 lower_ir_node.register(HConcat, _lower_ir_pwise)


+@lower_ir_node.register(Slice)
+def _(
+    ir: Slice, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+    if ir.offset == 0:
+        # Taking the first N rows.
+        # We don't know how large each partition is, so we reduce.
+        new_node, partition_info = _lower_ir_pwise(ir, rec)
+        if partition_info[new_node].count > 1:
+            # Collapse down to single partition
+            inter = Repartition(new_node.schema, new_node)
+            partition_info[inter] = PartitionInfo(count=1)
+            # Slice reduced partition
+            new_node = ir.reconstruct([inter])
+            partition_info[new_node] = PartitionInfo(count=1)
+        return new_node, partition_info
+
+    # Fallback
+    return _lower_ir_fallback(
+        ir, rec, msg="This slice not supported for multiple partitions."
+    )
+
+
 @lower_ir_node.register(HStack)
 def _(
     ir: HStack, rec: LowerIRTransformer
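
The new `Slice` lowering only handles `offset == 0` (a head) across multiple partitions: since partition sizes are unknown up front, it collapses everything into one partition via `Repartition` and slices once; any other offset falls back to single-partition execution. A plain-Python analogy of that reduce-then-slice step (lists standing in for partitions; not the real IR):

def head(partitions: list[list[int]], n: int) -> list[int]:
    # "Repartition": collapse all partitions into a single one
    merged = [row for part in partitions for row in part]
    # Slice the reduced partition exactly once
    return merged[:n]

# Slicing each partition first would keep too many rows:
assert head([[1, 2, 3], [4, 5]], 2) == [1, 2]
assert [row for part in [[1, 2, 3], [4, 5]] for row in part[:2]] == [1, 2, 4, 5]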

python/cudf_polars/tests/experimental/test_groupby.py
Lines changed: 14 additions & 0 deletions

@@ -187,6 +187,20 @@ def test_groupby_agg_empty(df: pl.LazyFrame, engine: pl.GPUEngine) -> None:
     assert_gpu_result_equal(q, engine=engine, check_row_order=False)


+@pytest.mark.parametrize("zlice", [(0, 2), (2, 2), (-2, None)])
+def test_groupby_then_slice(
+    df: pl.LazyFrame, engine: pl.GPUEngine, zlice: tuple[int, int]
+) -> None:
+    df = pl.LazyFrame(
+        {
+            "x": [0, 1, 2, 3] * 2,
+            "y": [1, 2, 1, 2] * 2,
+        }
+    )
+    q = df.group_by("y", maintain_order=True).max().slice(*zlice)
+    assert_gpu_result_equal(q, engine=engine)
+
+
 def test_groupby_on_equality(df: pl.LazyFrame, engine: pl.GPUEngine) -> None:
     # See: https://github.com/rapidsai/cudf/issues/19152
     df = pl.LazyFrame(

python/cudf_polars/tests/experimental/test_join.py
Lines changed: 37 additions & 0 deletions

@@ -158,3 +158,40 @@ def test_join_conditional(reverse, max_rows_per_partition):
         left, right = right, left
     q = left.join_where(right, pl.col("y") < pl.col("yy"))
     assert_gpu_result_equal(q, engine=engine, check_row_order=False)
+
+
+@pytest.mark.parametrize("zlice", [(0, 2), (2, 2), (-2, None)])
+def test_join_and_slice(zlice):
+    engine = pl.GPUEngine(
+        raise_on_fail=True,
+        executor="streaming",
+        executor_options={
+            "max_rows_per_partition": 3,
+            "broadcast_join_limit": 100,
+            "scheduler": DEFAULT_SCHEDULER,
+            "shuffle_method": "tasks",
+            "fallback_mode": "warn" if zlice[0] == 0 else "silent",
+        },
+    )
+    left = pl.LazyFrame(
+        {
+            "a": [1, 2, 3, 1, None],
+            "b": [1, 2, 3, 4, 5],
+            "c": [2, 3, 4, 5, 6],
+        }
+    )
+    right = pl.LazyFrame(
+        {
+            "a": [1, 4, 3, 7, None, None, 1],
+            "c": [2, 3, 4, 5, 6, 7, 8],
+            "d": [6, None, 7, 8, -1, 2, 4],
+        }
+    )
+    q = left.join(right, on="a", how="inner").slice(*zlice)
+    # Check that we get the correct row count
+    # See: https://github.com/rapidsai/cudf/issues/19153
+    assert q.collect(engine=engine).height == q.collect().height
+
+    # Need sort to match order after a join
+    q = left.join(right, on="a", how="inner").sort(pl.col("a")).slice(*zlice)
+    assert_gpu_result_equal(q, engine=engine)
