Skip to content

Commit 925c181

Browse files
authored
Avoid pylibcudf.interop.to_arrow in DataFrame.to_polars in cudf_polars (#19198)
This PR and #18564 are the last PRs removing `pylibcudf.interop.to_arrow` from `cudf_polars`. Towards #18534. Authors: Matthew Roeschke (https://github.com/mroeschke). Approvers: Vyas Ramasubramani (https://github.com/vyasr). URL: #19198
1 parent fb1628a commit 925c181

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

python/cudf_polars/cudf_polars/containers/dataframe.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
if TYPE_CHECKING:
1919
from collections.abc import Iterable, Mapping, Sequence, Set
2020

21-
from typing_extensions import Any, Self
21+
from typing_extensions import Any, CapsuleType, Self
2222

2323
from cudf_polars.typing import ColumnOptions, DataFrameHeader, Slice
2424

@@ -40,6 +40,20 @@ def _create_polars_column_metadata(
4040
return plc.interop.ColumnMetadata(name=name, children_meta=children_meta)
4141

4242

43+
# This is also defined in pylibcudf.interop
44+
class _ObjectWithArrowMetadata:
45+
def __init__(
46+
self, obj: plc.Table, metadata: list[plc.interop.ColumnMetadata]
47+
) -> None:
48+
self.obj = obj
49+
self.metadata = metadata
50+
51+
def __arrow_c_array__(
52+
self, requested_schema: None = None
53+
) -> tuple[CapsuleType, CapsuleType]:
54+
return self.obj._to_schema(self.metadata), self.obj._to_host_array()
55+
56+
4357
# Pacify the type checker. DataFrame init asserts that all the columns
4458
# have a string name, so let's narrow the type.
4559
class NamedColumn(Column):
@@ -82,8 +96,8 @@ def to_polars(self) -> pl.DataFrame:
8296
)
8397
for name, col in zip(name_map, self.columns, strict=True)
8498
]
85-
table = plc.interop.to_arrow(self.table, metadata=metadata)
86-
df: pl.DataFrame = pl.from_arrow(table)
99+
table_with_metadata = _ObjectWithArrowMetadata(self.table, metadata)
100+
df = pl.DataFrame(table_with_metadata)
87101
return df.rename(name_map).with_columns(
88102
pl.col(c.name).set_sorted(descending=c.order == plc.types.Order.DESCENDING)
89103
if c.is_sorted

0 commit comments

Comments
 (0)