Skip to content

Require dtype argument to cudf_polars Column container #19193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
fba9e00
Support pl.struct
mroeschke Jun 2, 2025
1268bb5
Pass along dtype in Column container
mroeschke Jun 3, 2025
685c3fa
Merge remote-tracking branch 'upstream/branch-25.08' into feat/cudf_p…
mroeschke Jun 3, 2025
57e58b0
Recursively set struct fields in ColumnMetadata
mroeschke Jun 3, 2025
ac841f2
Merge branch 'branch-25.08' into feat/cudf_polars/struct_expr
mroeschke Jun 3, 2025
2bd55a6
Merge remote-tracking branch 'upstream/branch-25.08' into feat/cudf_p…
mroeschke Jun 4, 2025
91563c1
Merge remote-tracking branch 'upstream/branch-25.08' into feat/cudf_p…
mroeschke Jun 12, 2025
c0a412b
Merge branch 'feat/cudf_polars/struct_expr' of https://github.com/mro…
mroeschke Jun 12, 2025
0364f45
Merge branch 'branch-25.08' into feat/cudf_polars/struct_expr
mroeschke Jun 13, 2025
83137d7
Merge remote-tracking branch 'upstream/branch-25.08' into feat/cudf_p…
mroeschke Jun 13, 2025
d45201f
Make can_cast return False for nested types
mroeschke Jun 13, 2025
446f4da
Merge remote-tracking branch 'upstream/branch-25.08' into feat/cudf_p…
mroeschke Jun 16, 2025
05a607e
Add tests to xfail list
mroeschke Jun 16, 2025
aeeb515
Merge remote-tracking branch 'upstream/branch-25.08' into feat/cudf_p…
mroeschke Jun 16, 2025
aab241e
Merge remote-tracking branch 'upstream/branch-25.08' into feat/cudf_p…
mroeschke Jun 17, 2025
753ba58
Replace with polars issue
mroeschke Jun 17, 2025
87793b5
Add test for nested struct
mroeschke Jun 17, 2025
9b0f2ee
Merge branch 'branch-25.08' into feat/cudf_polars/struct_expr
mroeschke Jun 17, 2025
87af891
Add dtypes argument to DataFrame.from_table
mroeschke Jun 17, 2025
2b6f848
Remove tests that require from_table to have a dtype
mroeschke Jun 17, 2025
92b23a3
Make dtype a required argument in Column
mroeschke Jun 17, 2025
32b9949
Remove getattr workaround
mroeschke Jun 17, 2025
0c611cd
Merge remote-tracking branch 'upstream/branch-25.08' into ref/cudf_po…
mroeschke Jun 18, 2025
1327b44
Add _dtype_short_repr_to_dtype to address nested type limitations
mroeschke Jun 18, 2025
37f76ac
Merge remote-tracking branch 'upstream/branch-25.08' into ref/cudf_po…
mroeschke Jun 18, 2025
2184dfd
Use _ for value_counts unused vars
mroeschke Jun 18, 2025
aa147f1
Add div by zero value_counts case
mroeschke Jun 18, 2025
a44f5e4
Add test for value_counts not implemented in groupby
mroeschke Jun 18, 2025
289f81b
Add back xfailing tests that require nested list[struct] support
mroeschke Jun 18, 2025
59036a2
Merge remote-tracking branch 'upstream/branch-25.08' into ref/cudf_po…
mroeschke Jun 18, 2025
1bccd42
Ensure we return a pl.DataType instance
mroeschke Jun 18, 2025
6516650
Merge remote-tracking branch 'upstream/branch-25.08' into ref/cudf_po…
mroeschke Jun 18, 2025
2292a95
Raise during deserialization for invalid dtype
mroeschke Jun 18, 2025
2cecd8a
Raise during deserialization for invalid dtype
mroeschke Jun 18, 2025
0560c3b
Add test for custom list parsing
mroeschke Jun 18, 2025
e320742
Merge remote-tracking branch 'upstream/branch-25.08' into ref/cudf_po…
mroeschke Jun 25, 2025
760c084
Remove pyarrow import from boolean.py
mroeschke Jun 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions python/cudf_polars/cudf_polars/containers/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import functools
import inspect
from typing import TYPE_CHECKING

import polars as pl
Expand Down Expand Up @@ -37,6 +38,19 @@
__all__: list[str] = ["Column"]


def _dtype_short_repr_to_dtype(dtype_str: str) -> pl.DataType:
    """Convert a Polars dtype short repr to a Polars dtype."""
    # dtype_short_repr_to_dtype cannot parse nested list reprs itself
    # (limitations described in py-polars/polars/datatypes/convert.py#L299),
    # so peel off each "list[...]" layer here and recurse on the inner repr.
    prefix = "list["
    if dtype_str.startswith(prefix):
        inner_repr = dtype_str.removeprefix(prefix).removesuffix("]")
        return pl.List(_dtype_short_repr_to_dtype(inner_repr))
    parsed = pl.datatypes.convert.dtype_short_repr_to_dtype(dtype_str)
    if parsed is None:
        raise ValueError(f"{dtype_str} was not able to be parsed by Polars.")
    # Polars may hand back either a DataType class or an instance;
    # normalize to an instance before returning.
    if inspect.isclass(parsed):
        return parsed()
    return parsed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How safe is pl_type(), without any arguments, here? Some types (Array, Enum) require additional arguments. Maybe we don't support those yet?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the types we support in DataType, I believe this is fairly safe as I'm hoping that dtype_short_repr_to_dtype will return instances for types with parameters (polars.Datetime and polars.Duration).

For those types that we don't support that take arguments, those should be rejected when constructing a DataType



class Column:
"""An immutable column with sortedness metadata."""

Expand All @@ -48,19 +62,17 @@ class Column:
# Optional name, only ever set by evaluation of NamedExpr nodes
# The internal evaluation should not care about the name.
name: str | None
# Optional dtype, used for preserving dtype metadata like
# struct fields
dtype: DataType | None
dtype: DataType

def __init__(
self,
column: plc.Column,
dtype: DataType,
*,
is_sorted: plc.types.Sorted = plc.types.Sorted.NO,
order: plc.types.Order = plc.types.Order.ASCENDING,
null_order: plc.types.NullOrder = plc.types.NullOrder.BEFORE,
name: str | None = None,
dtype: DataType | None = None,
):
self.obj = column
self.is_scalar = self.size == 1
Expand Down Expand Up @@ -98,12 +110,9 @@ def deserialize_ctor_kwargs(
column_kwargs: ColumnOptions,
) -> DeserializedColumnOptions:
"""Deserialize the constructor kwargs for a Column."""
if (serialized_dtype := column_kwargs.get("dtype", None)) is not None:
dtype: DataType | None = DataType( # pragma: no cover
pl.datatypes.convert.dtype_short_repr_to_dtype(serialized_dtype)
)
else: # pragma: no cover
dtype = None # pragma: no cover
dtype = DataType( # pragma: no cover
_dtype_short_repr_to_dtype(column_kwargs["dtype"])
)
return {
"is_sorted": column_kwargs["is_sorted"],
"order": column_kwargs["order"],
Expand Down Expand Up @@ -142,15 +151,12 @@ def serialize(

def serialize_ctor_kwargs(self) -> ColumnOptions:
"""Serialize the constructor kwargs for self."""
serialized_dtype = (
None if self.dtype is None else pl.polars.dtype_str_repr(self.dtype.polars)
)
return {
"is_sorted": self.is_sorted,
"order": self.order,
"null_order": self.null_order,
"name": self.name,
"dtype": serialized_dtype,
"dtype": pl.polars.dtype_str_repr(self.dtype.polars),
}

@functools.cached_property
Expand Down Expand Up @@ -406,7 +412,7 @@ def mask_nans(self) -> Self:
if plc.traits.is_floating_point(self.obj.type()):
old_count = self.null_count
mask, new_count = plc.transform.nans_to_nulls(self.obj)
result = type(self)(self.obj.with_mask(mask, new_count))
result = type(self)(self.obj.with_mask(mask, new_count), self.dtype)
if old_count == new_count:
return result.sorted_like(self)
return result
Expand Down Expand Up @@ -454,4 +460,4 @@ def slice(self, zlice: Slice | None) -> Self:
conversion.from_polars_slice(zlice, num_rows=self.size),
)
(column,) = table.columns()
return type(self)(column, name=self.name).sorted_like(self)
return type(self)(column, name=self.name, dtype=self.dtype).sorted_like(self)
34 changes: 21 additions & 13 deletions python/cudf_polars/cudf_polars/containers/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


def _create_polars_column_metadata(
name: str | None, dtype: pl.DataType | None
name: str | None, dtype: pl.DataType
) -> plc.interop.ColumnMetadata:
"""Create ColumnMetadata preserving pl.Struct field names."""
if isinstance(dtype, pl.Struct):
Expand Down Expand Up @@ -72,6 +72,7 @@ def __init__(self, columns: Iterable[Column]) -> None:
if any(c.name is None for c in columns):
raise ValueError("All columns must have a name")
self.columns = [cast(NamedColumn, c) for c in columns]
self.dtypes = [c.dtype for c in self.columns]
self.column_map = {c.name: c for c in self.columns}
self.table = plc.Table([c.obj for c in self.columns])

Expand All @@ -89,12 +90,8 @@ def to_polars(self) -> pl.DataFrame:
# serialise with names we control and rename with that map.
name_map = {f"column_{i}": name for i, name in enumerate(self.column_map)}
metadata = [
_create_polars_column_metadata(
name,
# Can remove the getattr if we ever consistently set Column.dtype
getattr(col.dtype, "polars", None),
)
for name, col in zip(name_map, self.columns, strict=True)
_create_polars_column_metadata(name, dtype.polars)
for name, dtype in zip(name_map, self.dtypes, strict=True)
]
table_with_metadata = _ObjectWithArrowMetadata(self.table, metadata)
df = pl.DataFrame(table_with_metadata)
Expand Down Expand Up @@ -148,7 +145,9 @@ def from_polars(cls, df: pl.DataFrame) -> Self:
)

@classmethod
def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self:
def from_table(
cls, table: plc.Table, names: Sequence[str], dtypes: Sequence[DataType]
) -> Self:
"""
Create from a pylibcudf table.

Expand All @@ -158,6 +157,8 @@ def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self:
Pylibcudf table to obtain columns from
names
Names for the columns
dtypes
Dtypes for the columns

Returns
-------
Expand All @@ -172,9 +173,8 @@ def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self:
if table.num_columns() != len(names):
raise ValueError("Mismatching name and table length.")
return cls(
# TODO: Pass along dtypes here
Column(c, name=name)
for c, name in zip(table.columns(), names, strict=True)
Column(c, name=name, dtype=dtype)
for c, name, dtype in zip(table.columns(), names, dtypes, strict=True)
)

@classmethod
Expand Down Expand Up @@ -317,7 +317,11 @@ def select_columns(self, names: Set[str]) -> list[Column]:
def filter(self, mask: Column) -> Self:
"""Return a filtered table given a mask."""
table = plc.stream_compaction.apply_boolean_mask(self.table, mask.obj)
return type(self).from_table(table, self.column_names).sorted_like(self)
return (
type(self)
.from_table(table, self.column_names, self.dtypes)
.sorted_like(self)
)

def slice(self, zlice: Slice | None) -> Self:
"""
Expand All @@ -338,4 +342,8 @@ def slice(self, zlice: Slice | None) -> Self:
(table,) = plc.copying.slice(
self.table, conversion.from_polars_slice(zlice, num_rows=self.num_rows)
)
return type(self).from_table(table, self.column_names).sorted_like(self)
return (
type(self)
.from_table(table, self.column_names, self.dtypes)
.sorted_like(self)
)
24 changes: 6 additions & 18 deletions python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,13 @@
from functools import partial, reduce
from typing import TYPE_CHECKING, Any, ClassVar

import pyarrow as pa

import pylibcudf as plc

from cudf_polars.containers import Column
from cudf_polars.containers import Column, DataType
from cudf_polars.dsl.expressions.base import (
ExecutionContext,
Expr,
)
from cudf_polars.dsl.expressions.literal import LiteralColumn
from cudf_polars.utils.versions import POLARS_VERSION_LT_128

if TYPE_CHECKING:
Expand All @@ -28,7 +25,7 @@
import polars.type_aliases as pl_types
from polars.polars import _expr_nodes as pl_expr

from cudf_polars.containers import DataFrame, DataType
from cudf_polars.containers import DataFrame

__all__ = ["BooleanFunction"]

Expand Down Expand Up @@ -99,15 +96,6 @@ def __init__(
# TODO: If polars IR doesn't put the casts in, we need to
# mimic the supertype promotion rules.
raise NotImplementedError("IsIn doesn't support supertype casting")
if self.name is BooleanFunction.Name.IsIn:
_, haystack = self.children
# TODO: Use pl.List isinstance check once we have https://github.com/rapidsai/cudf/pull/18564
if isinstance(haystack, LiteralColumn) and isinstance(
haystack.value, pa.ListArray
):
raise NotImplementedError(
"IsIn does not support nested list column input"
) # pragma: no cover

@staticmethod
def _distinct(
Expand Down Expand Up @@ -302,10 +290,10 @@ def do_evaluate(
needles, haystack = columns
if haystack.obj.type().id() == plc.TypeId.LIST:
# Unwrap values from the list column
haystack = Column(haystack.obj.children()[1])
# TODO: Remove check once Column's require dtype
if needles.dtype is not None:
haystack = haystack.astype(needles.dtype)
haystack = Column(
haystack.obj.children()[1],
dtype=DataType(haystack.dtype.polars.inner),
).astype(needles.dtype)
if haystack.size:
return Column(
plc.search.contains(haystack.obj, needles.obj), dtype=self.dtype
Expand Down
42 changes: 22 additions & 20 deletions python/cudf_polars/cudf_polars/dsl/expressions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Any

import polars as pl
from polars.exceptions import InvalidOperationError

import pylibcudf as plc

from cudf_polars.containers import Column, DataType
from cudf_polars.containers import Column
from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
from cudf_polars.dsl.expressions.literal import Literal, LiteralColumn
from cudf_polars.dsl.utils.reshape import broadcast
Expand All @@ -24,7 +23,7 @@

from polars.polars import _expr_nodes as pl_expr

from cudf_polars.containers import DataFrame
from cudf_polars.containers import DataFrame, DataType

__all__ = ["StringFunction"]

Expand Down Expand Up @@ -211,9 +210,9 @@ def do_evaluate(
"""Evaluate this expression given a dataframe for context."""
if self.name is StringFunction.Name.ConcatHorizontal:
columns = [
Column(child.evaluate(df, context=context).obj).astype(
DataType(pl.String())
)
Column(
child.evaluate(df, context=context).obj, dtype=child.dtype
).astype(self.dtype)
for child in self.children
]

Expand All @@ -226,13 +225,12 @@ def do_evaluate(
return Column(
plc.strings.combine.concatenate(
plc.Table([col.obj for col in broadcasted]),
plc.Scalar.from_py(delimiter, plc.DataType(plc.TypeId.STRING)),
None
if ignore_nulls
else plc.Scalar.from_py(None, plc.DataType(plc.TypeId.STRING)),
plc.Scalar.from_py(delimiter, self.dtype.plc),
None if ignore_nulls else plc.Scalar.from_py(None, self.dtype.plc),
None,
plc.strings.combine.SeparatorOnNulls.NO,
)
),
dtype=self.dtype,
)
elif self.name is StringFunction.Name.ConcatVertical:
(child,) = self.children
Expand Down Expand Up @@ -323,20 +321,21 @@ def do_evaluate(
if self.children[1].value is None:
return Column(
plc.Column.from_scalar(
plc.Scalar.from_py(None, plc.DataType(plc.TypeId.STRING)),
plc.Scalar.from_py(None, self.dtype.plc),
column.size,
)
),
self.dtype,
)
elif self.children[1].value == 0:
result = plc.Column.from_scalar(
plc.Scalar.from_py("", plc.DataType(plc.TypeId.STRING)),
plc.Scalar.from_py("", self.dtype.plc),
column.size,
)
if column.obj.null_mask():
result = result.with_mask(
column.obj.null_mask(), column.obj.null_count()
)
return Column(result)
return Column(result, self.dtype)

else:
start = -(self.children[1].value)
Expand All @@ -347,7 +346,8 @@ def do_evaluate(
plc.Scalar.from_py(start, plc.DataType(plc.TypeId.INT32)),
plc.Scalar.from_py(end, plc.DataType(plc.TypeId.INT32)),
None,
)
),
self.dtype,
)
elif self.name is StringFunction.Name.Head:
column = self.children[0].evaluate(df, context=context)
Expand All @@ -358,16 +358,18 @@ def do_evaluate(
if end is None:
return Column(
plc.Column.from_scalar(
plc.Scalar.from_py(None, plc.DataType(plc.TypeId.STRING)),
plc.Scalar.from_py(None, self.dtype.plc),
column.size,
)
),
self.dtype,
)
return Column(
plc.strings.slice.slice_strings(
column.obj,
plc.Scalar.from_py(0, plc.DataType(plc.TypeId.INT32)),
plc.Scalar.from_py(end, plc.DataType(plc.TypeId.INT32)),
)
),
self.dtype,
)

columns = [child.evaluate(df, context=context) for child in self.children]
Expand Down Expand Up @@ -446,7 +448,7 @@ def do_evaluate(
)
elif self.name is StringFunction.Name.Titlecase:
(column,) = columns
return Column(plc.strings.capitalize.title(column.obj))
return Column(plc.strings.capitalize.title(column.obj), dtype=self.dtype)
raise NotImplementedError(
f"StringFunction {self.name}"
) # pragma: no cover; handled by init raising
2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/dsl/expressions/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def do_evaluate(
null_order=null_order,
)
elif self.name == "value_counts":
(sort, parallel, name, normalize) = self.options
(sort, _, _, normalize) = self.options
count_agg = [plc.aggregation.count(plc.types.NullPolicy.INCLUDE)]
gb_requests = [
plc.groupby.GroupByRequest(
Expand Down
Loading