Skip to content

Commit 1ccf6dc

Browse files
authored
Remove pyarrow from cudf_polars tests (#19219)
With #19193 and this PR, `pyarrow` is no longer imported explicitly in `cudf_polars`. xref #18534. Authors: Matthew Roeschke (https://github.com/mroeschke). Approvers: Matthew Murray (https://github.com/Matt711). URL: #19219
1 parent 7b536ee commit 1ccf6dc

File tree

4 files changed

+40
-55
lines changed

4 files changed

+40
-55
lines changed

python/cudf_polars/tests/containers/test_column.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from __future__ import annotations
55

6-
import pyarrow
76
import pytest
87

98
import polars as pl
@@ -71,21 +70,23 @@ def test_shallow_copy():
7170
@pytest.mark.parametrize("typeid", [pl.Int8(), pl.Float32()])
7271
def test_mask_nans(typeid):
7372
dtype = DataType(typeid)
74-
values = pyarrow.array([0, 0, 0], type=plc.interop.to_arrow(dtype.plc))
75-
column = Column(plc.Column.from_arrow(values), dtype=dtype)
73+
column = Column(
74+
plc.Column.from_iterable_of_py([0, 0, 0], dtype=dtype.plc), dtype=dtype
75+
)
7676
masked = column.mask_nans()
7777
assert column.null_count == masked.null_count
7878

7979

8080
def test_mask_nans_float():
8181
dtype = DataType(pl.Float32())
82-
values = pyarrow.array([0, 0, float("nan")], type=plc.interop.to_arrow(dtype.plc))
83-
column = Column(plc.Column.from_arrow(values), dtype=dtype)
82+
column = Column(
83+
plc.Column.from_iterable_of_py([0, 0, float("nan")], dtype=dtype.plc),
84+
dtype=dtype,
85+
)
8486
masked = column.mask_nans()
85-
expect = pyarrow.array([0, 0, None], type=plc.interop.to_arrow(dtype.plc))
86-
got = pyarrow.array(plc.interop.to_arrow(masked.obj))
87-
88-
assert expect == got
87+
assert masked.nan_count == 0
88+
assert masked.slice((0, 2)).null_count == 0
89+
assert masked.slice((2, 1)).null_count == 1
8990

9091

9192
def test_slice_none_returns_self():

python/cudf_polars/tests/containers/test_dataframe.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from __future__ import annotations
55

6-
import pyarrow as pa
76
import pytest
87

98
import polars as pl
@@ -175,20 +174,18 @@ def test_empty_name_roundtrips_no_overlap():
175174

176175

177176
@pytest.mark.parametrize(
178-
"arrow_tbl",
177+
"polars_tbl",
179178
[
180-
pa.table([]),
181-
pa.table({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}),
182-
pa.table({"a": [1, 2, 3]}),
183-
pa.table({"a": [1], "b": [2], "c": [3]}),
184-
pa.table({"a": ["a", "bb", "ccc"]}),
185-
pa.table({"a": [1, 2, None], "b": [None, 3, 4]}),
179+
pl.DataFrame(),
180+
pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}),
181+
pl.DataFrame({"a": [1, 2, 3]}),
182+
pl.DataFrame({"a": [1], "b": [2], "c": [3]}),
183+
pl.DataFrame({"a": ["a", "bb", "ccc"]}),
184+
pl.DataFrame({"a": [1, 2, None], "b": [None, 3, 4]}),
186185
],
187186
)
188-
def test_serialization_roundtrip(arrow_tbl):
189-
plc_tbl = plc.Table.from_arrow(arrow_tbl)
190-
dtypes = [DataType(pl_type) for pl_type in pl.from_arrow(arrow_tbl).dtypes]
191-
df = DataFrame.from_table(plc_tbl, names=arrow_tbl.column_names, dtypes=dtypes)
187+
def test_serialization_roundtrip(polars_tbl):
188+
df = DataFrame.from_polars(polars_tbl)
192189

193190
header, frames = df.serialize()
194191
res = DataFrame.deserialize(header, frames)

python/cudf_polars/tests/experimental/test_dask_serialize.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,16 @@
33

44
from __future__ import annotations
55

6-
import numpy as np
7-
import pyarrow as pa
86
import pytest
97
from distributed.protocol import deserialize, serialize
108

119
import polars as pl
1210
from polars.testing.asserts import assert_frame_equal
1311

14-
import pylibcudf as plc
1512
import rmm
1613
from rmm.pylibrmm.stream import DEFAULT_STREAM
1714

18-
from cudf_polars.containers import DataFrame, DataType
15+
from cudf_polars.containers import DataFrame
1916
from cudf_polars.experimental.dask_registers import register
2017

2118
# Must register serializers before running tests
@@ -33,15 +30,15 @@ def convert_to_rmm(frame):
3330

3431

3532
@pytest.mark.parametrize(
36-
"arrow_tbl",
33+
"polars_tbl",
3734
[
38-
pa.table([]),
39-
pa.table({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}),
40-
pa.table({"a": [1, 2, 3]}),
41-
pa.table({"a": [1], "b": [2], "c": [3]}),
42-
pa.table({"a": ["a", "bb", "ccc"]}),
43-
pa.table({"a": [1, 2, None], "b": [None, 3, 4]}),
44-
pa.table({"a": pa.array(np.arange(1e7))}),
35+
pl.DataFrame(),
36+
pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}),
37+
pl.DataFrame({"a": [1, 2, 3]}),
38+
pl.DataFrame({"a": [1], "b": [2], "c": [3]}),
39+
pl.DataFrame({"a": ["a", "bb", "ccc"]}),
40+
pl.DataFrame({"a": [1, 2, None], "b": [None, 3, 4]}),
41+
pl.DataFrame({"a": range(int(1e7))}),
4542
],
4643
)
4744
@pytest.mark.parametrize("protocol", ["cuda", "cuda_rmm", "dask"])
@@ -57,10 +54,8 @@ def convert_to_rmm(frame):
5754
},
5855
],
5956
)
60-
def test_dask_serialization_roundtrip(arrow_tbl, protocol, context):
61-
plc_tbl = plc.Table.from_arrow(arrow_tbl)
62-
dtypes = [DataType(pl_type) for pl_type in pl.from_arrow(arrow_tbl).dtypes]
63-
df = DataFrame.from_table(plc_tbl, names=arrow_tbl.column_names, dtypes=dtypes)
57+
def test_dask_serialization_roundtrip(polars_tbl, protocol, context):
58+
df = DataFrame.from_polars(polars_tbl)
6459

6560
cuda_rmm = protocol == "cuda_rmm"
6661
protocol = "cuda" if protocol == "cuda_rmm" else protocol
@@ -91,10 +86,7 @@ def test_dask_serialization_roundtrip(arrow_tbl, protocol, context):
9186

9287

9388
def test_dask_serialization_error():
94-
arrow_tbl = pa.table({"a": [1, 2, 3]})
95-
plc_tbl = plc.Table.from_arrow(arrow_tbl)
96-
dtypes = [DataType(pl_type) for pl_type in pl.from_arrow(arrow_tbl).dtypes]
97-
df = DataFrame.from_table(plc_tbl, names=arrow_tbl.column_names, dtypes=dtypes)
89+
df = DataFrame.from_polars(pl.DataFrame({"a": [1, 2, 3]}))
9890

9991
header, frames = serialize(
10092
df,

python/cudf_polars/tests/experimental/test_dask_sizeof.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,29 @@
33

44
from __future__ import annotations
55

6-
import pyarrow as pa
76
import pytest
87
from dask.sizeof import sizeof
98

109
import polars as pl
1110

12-
import pylibcudf as plc
13-
14-
from cudf_polars.containers import DataFrame, DataType
11+
from cudf_polars.containers import DataFrame
1512
from cudf_polars.experimental.dask_registers import register
1613

1714
# Must register sizeof dispatch before running tests
1815
register()
1916

2017

2118
@pytest.mark.parametrize(
22-
"arrow_tbl, size",
19+
"polars_tbl, size",
2320
[
24-
(pa.table([]), 0),
25-
(pa.table({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), 9 * 8),
26-
(pa.table({"a": [1, 2, 3]}), 3 * 8),
27-
(pa.table({"a": ["a"], "b": ["bc"]}), 2 * 8 + 3),
28-
(pa.table({"a": [1, 2, None]}), 88),
21+
(pl.DataFrame(), 0),
22+
(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), 9 * 8),
23+
(pl.DataFrame({"a": [1, 2, 3]}), 3 * 8),
24+
(pl.DataFrame({"a": ["a"], "b": ["bc"]}), 2 * 8 + 3),
25+
(pl.DataFrame({"a": [1, 2, None]}), 88),
2926
],
3027
)
31-
def test_dask_sizeof(arrow_tbl, size):
32-
plc_tbl = plc.Table.from_arrow(arrow_tbl)
33-
dtypes = [DataType(pl_type) for pl_type in pl.from_arrow(arrow_tbl).dtypes]
34-
df = DataFrame.from_table(plc_tbl, names=arrow_tbl.column_names, dtypes=dtypes)
28+
def test_dask_sizeof(polars_tbl, size):
29+
df = DataFrame.from_polars(polars_tbl)
3530
assert sizeof(df) == size
3631
assert sum(sizeof(c) for c in df.columns) == size

0 commit comments

Comments
 (0)