
Commit 870c2ff

Add support for horizontal string concatenation pl.concat_str (#19142)
Needed to get TPC-DS Query 5 running.

Authors:
- Matthew Murray (https://github.com/Matt711)
- Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
- Matthew Roeschke (https://github.com/mroeschke)

URL: #19142
1 parent 628bfc1 commit 870c2ff
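For context, the query shape this commit enables is polars' horizontal string concatenation executed through cudf_polars. A minimal sketch, assuming a polars installation with GPU engine support (the frame and column names are illustrative, not taken from the PR):

import polars as pl

# Illustrative data: one string column containing a null, one integer column.
ldf = pl.LazyFrame({"a": ["x", None, "z"], "c": [1, 2, 3]})

# Horizontal concatenation; per the diff below, non-string inputs are cast to strings.
q = ldf.select(pl.concat_str(["a", "c"], separator="*", ignore_nulls=False))

# With this change, the GPU engine should be able to execute the ConcatHorizontal
# expression instead of treating it as unsupported.
print(q.collect(engine="gpu"))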

File tree

4 files changed, +115 -68 lines changed


python/cudf_polars/cudf_polars/dsl/expressions/string.py

Lines changed: 31 additions & 3 deletions
@@ -12,20 +12,22 @@
 import pyarrow as pa
 import pyarrow.compute as pc
 
+import polars as pl
 from polars.exceptions import InvalidOperationError
 
 import pylibcudf as plc
 
-from cudf_polars.containers import Column
+from cudf_polars.containers import Column, DataType
 from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
 from cudf_polars.dsl.expressions.literal import Literal, LiteralColumn
+from cudf_polars.dsl.utils.reshape import broadcast
 
 if TYPE_CHECKING:
     from typing_extensions import Self
 
     from polars.polars import _expr_nodes as pl_expr
 
-    from cudf_polars.containers import DataFrame, DataType
+    from cudf_polars.containers import DataFrame
 
 __all__ = ["StringFunction"]
 
@@ -110,6 +112,7 @@ def __init__(
 
     def _validate_input(self) -> None:
         if self.name not in (
+            StringFunction.Name.ConcatHorizontal,
             StringFunction.Name.ConcatVertical,
             StringFunction.Name.Contains,
             StringFunction.Name.EndsWith,
@@ -212,7 +215,32 @@ def do_evaluate(
         self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
     ) -> Column:
         """Evaluate this expression given a dataframe for context."""
-        if self.name is StringFunction.Name.ConcatVertical:
+        if self.name is StringFunction.Name.ConcatHorizontal:
+            columns = [
+                Column(child.evaluate(df, context=context).obj).astype(
+                    DataType(pl.String())
+                )
+                for child in self.children
+            ]
+
+            broadcasted = broadcast(
+                *columns, target_length=max(col.size for col in columns)
+            )
+
+            delimiter, ignore_nulls = self.options
+
+            return Column(
+                plc.strings.combine.concatenate(
+                    plc.Table([col.obj for col in broadcasted]),
+                    plc.Scalar.from_py(delimiter, plc.DataType(plc.TypeId.STRING)),
+                    None
+                    if ignore_nulls
+                    else plc.Scalar.from_py(None, plc.DataType(plc.TypeId.STRING)),
+                    None,
+                    plc.strings.combine.SeparatorOnNulls.NO,
+                )
+            )
+        elif self.name is StringFunction.Name.ConcatVertical:
             (child,) = self.children
             column = child.evaluate(df, context=context)
             delimiter, ignore_nulls = self.options
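To make the new ConcatHorizontal branch easier to follow in isolation, here is a rough standalone sketch of the same pylibcudf call, using only constructors that already appear in this diff; the input columns are built by repeating scalars purely for illustration, and the positional argument order simply mirrors the call above:

import pylibcudf as plc

STRING = plc.DataType(plc.TypeId.STRING)

# Two length-3 string columns, built from scalars (illustration only).
left = plc.Column.from_scalar(plc.Scalar.from_py("foo", STRING), 3)
right = plc.Column.from_scalar(plc.Scalar.from_py("bar", STRING), 3)

# Mirror the branch: a separator scalar, and a null-replacement scalar only
# when nulls are not ignored.
ignore_nulls = False
narep = None if ignore_nulls else plc.Scalar.from_py(None, STRING)

result = plc.strings.combine.concatenate(
    plc.Table([left, right]),
    plc.Scalar.from_py("*", STRING),
    narep,
    None,
    plc.strings.combine.SeparatorOnNulls.NO,
)
# result is a single string column; with these all-valid inputs every row is "foo*bar".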

python/cudf_polars/cudf_polars/dsl/ir.py

Lines changed: 1 addition & 65 deletions
@@ -35,6 +35,7 @@
 from cudf_polars.dsl.nodebase import Node
 from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
 from cudf_polars.dsl.tracing import nvtx_annotate_cudf_polars
+from cudf_polars.dsl.utils.reshape import broadcast
 from cudf_polars.dsl.utils.windows import range_window_bounds
 from cudf_polars.utils import dtypes
 from cudf_polars.utils.versions import POLARS_VERSION_LT_128
@@ -80,71 +81,6 @@
 ]
 
 
-def broadcast(*columns: Column, target_length: int | None = None) -> list[Column]:
-    """
-    Broadcast a sequence of columns to a common length.
-
-    Parameters
-    ----------
-    columns
-        Columns to broadcast.
-    target_length
-        Optional length to broadcast to. If not provided, uses the
-        non-unit length of existing columns.
-
-    Returns
-    -------
-    List of broadcasted columns all of the same length.
-
-    Raises
-    ------
-    RuntimeError
-        If broadcasting is not possible.
-
-    Notes
-    -----
-    In evaluation of a set of expressions, polars type-puns length-1
-    columns with scalars. When we insert these into a DataFrame
-    object, we need to ensure they are of equal length. This function
-    takes some columns, some of which may be length-1 and ensures that
-    all length-1 columns are broadcast to the length of the others.
-
-    Broadcasting is only possible if the set of lengths of the input
-    columns is a subset of ``{1, n}`` for some (fixed) ``n``. If
-    ``target_length`` is provided and not all columns are length-1
-    (i.e. ``n != 1``), then ``target_length`` must be equal to ``n``.
-    """
-    if len(columns) == 0:
-        return []
-    lengths: set[int] = {column.size for column in columns}
-    if lengths == {1}:
-        if target_length is None:
-            return list(columns)
-        nrows = target_length
-    else:
-        try:
-            (nrows,) = lengths.difference([1])
-        except ValueError as e:
-            raise RuntimeError("Mismatching column lengths") from e
-    if target_length is not None and nrows != target_length:
-        raise RuntimeError(
-            f"Cannot broadcast columns of length {nrows=} to {target_length=}"
-        )
-    return [
-        column
-        if column.size != 1
-        else Column(
-            plc.Column.from_scalar(column.obj_scalar, nrows),
-            is_sorted=plc.types.Sorted.YES,
-            order=plc.types.Order.ASCENDING,
-            null_order=plc.types.NullOrder.BEFORE,
-            name=column.name,
-            dtype=column.dtype,
-        )
-        for column in columns
-    ]
-
-
 class IR(Node["IR"]):
     """Abstract plan node, representing an unevaluated dataframe."""

python/cudf_polars/cudf_polars/dsl/utils/reshape.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Utilities for reshaping Columns."""
+
+from __future__ import annotations
+
+import pylibcudf as plc
+
+from cudf_polars.containers import Column
+
+
+def broadcast(*columns: Column, target_length: int | None = None) -> list[Column]:
+    """
+    Broadcast a sequence of columns to a common length.
+
+    Parameters
+    ----------
+    columns
+        Columns to broadcast.
+    target_length
+        Optional length to broadcast to. If not provided, uses the
+        non-unit length of existing columns.
+
+    Returns
+    -------
+    List of broadcasted columns all of the same length.
+
+    Raises
+    ------
+    RuntimeError
+        If broadcasting is not possible.
+
+    Notes
+    -----
+    In evaluation of a set of expressions, polars type-puns length-1
+    columns with scalars. When we insert these into a DataFrame
+    object, we need to ensure they are of equal length. This function
+    takes some columns, some of which may be length-1 and ensures that
+    all length-1 columns are broadcast to the length of the others.
+
+    Broadcasting is only possible if the set of lengths of the input
+    columns is a subset of ``{1, n}`` for some (fixed) ``n``. If
+    ``target_length`` is provided and not all columns are length-1
+    (i.e. ``n != 1``), then ``target_length`` must be equal to ``n``.
+    """
+    if len(columns) == 0:
+        return []
+    lengths: set[int] = {column.size for column in columns}
+    if lengths == {1}:
+        if target_length is None:
+            return list(columns)
+        nrows = target_length
+    else:
+        try:
+            (nrows,) = lengths.difference([1])
+        except ValueError as e:
+            raise RuntimeError("Mismatching column lengths") from e
+    if target_length is not None and nrows != target_length:
+        raise RuntimeError(
+            f"Cannot broadcast columns of length {nrows=} to {target_length=}"
+        )
+    return [
+        column
+        if column.size != 1
+        else Column(
+            plc.Column.from_scalar(column.obj_scalar, nrows),
+            is_sorted=plc.types.Sorted.YES,
+            order=plc.types.Order.ASCENDING,
+            null_order=plc.types.NullOrder.BEFORE,
+            name=column.name,
+            dtype=column.dtype,
+        )
+        for column in columns
+    ]
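A rough usage sketch of the relocated broadcast helper, to show the {1, n} rule in action; the Column construction mirrors the patterns elsewhere in this diff (constructor details may vary by version), and values are purely illustrative:

import pylibcudf as plc

from cudf_polars.containers import Column
from cudf_polars.dsl.utils.reshape import broadcast

STRING = plc.DataType(plc.TypeId.STRING)

# A length-1 column (polars' scalar pun) and a length-4 column.
short = Column(plc.Column.from_scalar(plc.Scalar.from_py("sep", STRING), 1))
full = Column(plc.Column.from_scalar(plc.Scalar.from_py("x", STRING), 4))

out = broadcast(short, full)
assert [c.size for c in out] == [4, 4]  # the length-1 column is repeated to length 4

# Lengths outside {1, n} are rejected:
# broadcast(Column(plc.Column.from_scalar(plc.Scalar.from_py("y", STRING), 2)), full)
# -> RuntimeError: Mismatching column lengths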

python/cudf_polars/tests/expressions/test_stringfunction.py

Lines changed: 9 additions & 0 deletions
@@ -498,3 +498,12 @@ def test_string_tail(ldf, tail):
 def test_string_head(ldf, head):
     q = ldf.select(pl.col("a").str.head(head))
     assert_gpu_result_equal(q)
+
+
+@pytest.mark.parametrize("ignore_nulls", [True, False])
+@pytest.mark.parametrize("separator", ["*", ""])
+def test_concat_horizontal(ldf, ignore_nulls, separator):
+    q = ldf.select(
+        pl.concat_str(["a", "c"], separator=separator, ignore_nulls=ignore_nulls)
+    )
+    assert_gpu_result_equal(q)
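The two ignore_nulls settings exercised by this test differ in how nulls propagate; a small CPU-side sketch of the polars semantics the GPU path is compared against (the frame is illustrative, since the test's ldf fixture is not shown in this diff):

import polars as pl

ldf = pl.LazyFrame({"a": ["x", None], "c": ["1", "2"]})

# ignore_nulls=False: a null in any input makes that row's result null.
strict = ldf.select(pl.concat_str(["a", "c"], separator="*", ignore_nulls=False))
# expected rows: "x*1", null

# ignore_nulls=True: null inputs are skipped (no separator is emitted for them).
lenient = ldf.select(pl.concat_str(["a", "c"], separator="*", ignore_nulls=True))
# expected rows: "x*1", "2"

print(strict.collect(), lenient.collect())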
