
Fix slicing after Join and GroupBy in streaming cudf-polars #19187


Merged: 9 commits, Jun 25, 2025
21 changes: 20 additions & 1 deletion python/cudf_polars/cudf_polars/experimental/groupby.py
@@ -14,7 +14,7 @@

from cudf_polars.containers import DataType
from cudf_polars.dsl.expr import Agg, BinOp, Col, Len, NamedExpr
from cudf_polars.dsl.ir import GroupBy, Select
from cudf_polars.dsl.ir import GroupBy, Select, Slice
from cudf_polars.dsl.traversal import traversal
from cudf_polars.dsl.utils.naming import unique_names
from cudf_polars.experimental.base import PartitionInfo
@@ -143,6 +143,25 @@ def decompose(
def _(
ir: GroupBy, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
# Pull slice operations out of the GroupBy before lowering
if ir.zlice is not None:
offset, length = ir.zlice
if length is None: # pragma: no cover
return _lower_ir_fallback(
ir,
rec,
msg="This slice not supported for multiple partitions.",
)
new_join = GroupBy(
ir.schema,
ir.keys,
ir.agg_requests,
ir.maintain_order,
None,
*ir.children,
)
return rec(Slice(ir.schema, offset, length, new_join))

# Extract child partitioning
child, partition_info = rec(ir.children[0])

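
For orientation, the rewritten GroupBy lowering targets queries like the one below, where a `slice` follows a grouped aggregation under the streaming executor. This is a minimal sketch based on the test added in this PR; the bare `pl.GPUEngine(executor="streaming")` configuration is an assumption (the tests pass additional executor options), and running it requires a working cudf-polars GPU setup.

```python
import polars as pl

# A slice applied after a group_by aggregation. With this change, the
# streaming lowering pulls the slice out of the GroupBy node and lowers
# it as a separate Slice IR node instead of falling back.
df = pl.LazyFrame({"x": [0, 1, 2, 3] * 2, "y": [1, 2, 1, 2] * 2})
q = df.group_by("y", maintain_order=True).max().slice(0, 2)

engine = pl.GPUEngine(executor="streaming")  # assumed minimal configuration
print(q.collect(engine=engine))
```
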
20 changes: 19 additions & 1 deletion python/cudf_polars/cudf_polars/experimental/join.py
@@ -8,7 +8,7 @@
from functools import reduce
from typing import TYPE_CHECKING, Any

from cudf_polars.dsl.ir import ConditionalJoin, Join
from cudf_polars.dsl.ir import ConditionalJoin, Join, Slice
from cudf_polars.experimental.base import PartitionInfo, get_key_name
from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
from cudf_polars.experimental.repartition import Repartition
@@ -226,6 +226,24 @@ def _(
def _(
ir: Join, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
# Pull slice operations out of the Join before lowering
if (zlice := ir.options[2]) is not None:
offset, length = zlice
if length is None: # pragma: no cover
return _lower_ir_fallback(
ir,
rec,
msg="This slice not supported for multiple partitions.",
)
new_join = Join(
ir.schema,
ir.left_on,
ir.right_on,
(*ir.options[:2], None, *ir.options[3:]),
*ir.children,
)
return rec(Slice(ir.schema, offset, length, new_join))

# Lower children
children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
partition_info = reduce(operator.or_, _partition_info)
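
The Join lowering gets the analogous treatment: the slice stored in `ir.options` is cleared and re-applied as a Slice node on top of the join. Below is a minimal sketch of the kind of query this enables, adapted from the new join test; the engine configuration is again an assumption rather than the exact options used in the tests.

```python
import polars as pl

# A slice applied after an inner join. The streaming lowering now rewrites
# Join(..., slice) into Slice(Join(..., slice=None)) before partitioning.
left = pl.LazyFrame({"a": [1, 2, 3, 1, None], "b": [1, 2, 3, 4, 5]})
right = pl.LazyFrame({"a": [1, 4, 3, 7, None, None, 1], "d": [6, None, 7, 8, -1, 2, 4]})

# Sort before slicing so the row order (and thus the sliced rows) is well defined.
q = left.join(right, on="a", how="inner").sort("a").slice(0, 2)

engine = pl.GPUEngine(executor="streaming")  # assumed minimal configuration
print(q.collect(engine=engine))
```
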
25 changes: 25 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/parallel.py
@@ -24,6 +24,7 @@
HStack,
MapFunction,
Projection,
Slice,
Union,
)
from cudf_polars.dsl.traversal import CachingVisitor, traversal
@@ -32,6 +33,7 @@
generate_ir_tasks,
lower_ir_node,
)
from cudf_polars.experimental.repartition import Repartition
from cudf_polars.experimental.utils import _concat, _lower_ir_fallback

if TYPE_CHECKING:
@@ -336,6 +338,29 @@ def _lower_ir_pwise(
lower_ir_node.register(HConcat, _lower_ir_pwise)


@lower_ir_node.register(Slice)
def _(
ir: Slice, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
if ir.offset == 0:
# Taking the first N rows.
# We don't know how large each partition is, so we reduce.
new_node, partition_info = _lower_ir_pwise(ir, rec)
if partition_info[new_node].count > 1:
# Collapse down to single partition
inter = Repartition(new_node.schema, new_node)
partition_info[inter] = PartitionInfo(count=1)
# Slice reduced partition
new_node = ir.reconstruct([inter])
partition_info[new_node] = PartitionInfo(count=1)
return new_node, partition_info

# Fallback
return _lower_ir_fallback(
ir, rec, msg="This slice not supported for multiple partitions."
)


@lower_ir_node.register(HStack)
def _(
ir: HStack, rec: LowerIRTransformer
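
The Slice lowering itself only handles head-like slices (offset 0): per-partition row counts are unknown at lowering time, so all partitions are first collapsed into one via Repartition and the slice is then applied to that single partition. A rough, framework-free sketch of that idea, using a hypothetical helper name and plain polars DataFrames standing in for partitions:

```python
import polars as pl

def head_across_partitions(partitions: list[pl.DataFrame], length: int) -> pl.DataFrame:
    # Per-partition row counts are unknown up front, so first collapse
    # everything into a single partition (the Repartition step) ...
    combined = pl.concat(partitions)
    # ... then apply the slice to that single partition (the Slice step).
    return combined.slice(0, length)

parts = [pl.DataFrame({"x": [1, 2]}), pl.DataFrame({"x": [3, 4, 5]})]
print(head_across_partitions(parts, 3))  # first 3 rows overall
```
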
14 changes: 14 additions & 0 deletions python/cudf_polars/tests/experimental/test_groupby.py
@@ -187,6 +187,20 @@ def test_groupby_agg_empty(df: pl.LazyFrame, engine: pl.GPUEngine) -> None:
assert_gpu_result_equal(q, engine=engine, check_row_order=False)


@pytest.mark.parametrize("zlice", [(0, 2), (2, 2), (-2, None)])
def test_groupby_then_slice(
df: pl.LazyFrame, engine: pl.GPUEngine, zlice: tuple[int, int]
) -> None:
df = pl.LazyFrame(
{
"x": [0, 1, 2, 3] * 2,
"y": [1, 2, 1, 2] * 2,
}
)
q = df.group_by("y", maintain_order=True).max().slice(*zlice)
assert_gpu_result_equal(q, engine=engine)


def test_groupby_on_equality(df: pl.LazyFrame, engine: pl.GPUEngine) -> None:
# See: https://github.com/rapidsai/cudf/issues/19152
df = pl.LazyFrame(
37 changes: 37 additions & 0 deletions python/cudf_polars/tests/experimental/test_join.py
@@ -158,3 +158,40 @@ def test_join_conditional(reverse, max_rows_per_partition):
left, right = right, left
q = left.join_where(right, pl.col("y") < pl.col("yy"))
assert_gpu_result_equal(q, engine=engine, check_row_order=False)


@pytest.mark.parametrize("zlice", [(0, 2), (2, 2), (-2, None)])
def test_join_and_slice(zlice):
engine = pl.GPUEngine(
raise_on_fail=True,
executor="streaming",
executor_options={
"max_rows_per_partition": 3,
"broadcast_join_limit": 100,
"scheduler": DEFAULT_SCHEDULER,
"shuffle_method": "tasks",
"fallback_mode": "warn" if zlice[0] == 0 else "silent",
},
)
left = pl.LazyFrame(
{
"a": [1, 2, 3, 1, None],
"b": [1, 2, 3, 4, 5],
"c": [2, 3, 4, 5, 6],
}
)
right = pl.LazyFrame(
{
"a": [1, 4, 3, 7, None, None, 1],
"c": [2, 3, 4, 5, 6, 7, 8],
"d": [6, None, 7, 8, -1, 2, 4],
}
)
q = left.join(right, on="a", how="inner").slice(*zlice)
# Check that we get the correct row count
# See: https://github.com/rapidsai/cudf/issues/19153
assert q.collect(engine=engine).height == q.collect().height

# Need sort to match order after a join
q = left.join(right, on="a", how="inner").sort(pl.col("a")).slice(*zlice)
assert_gpu_result_equal(q, engine=engine)