Skip to content

Commit 349a91d

Browse files
authored
Merge pull request #260 from lincc-frameworks/nf-autocast-list-struct
Autocast list-struct columns to nested
2 parents 9dafe49 + b72ed72 commit 349a91d

File tree

3 files changed

+85
-7
lines changed

3 files changed

+85
-7
lines changed

src/nested_pandas/nestedframe/core.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from nested_pandas.series.dtype import NestedDtype
2020
from nested_pandas.series.packer import pack, pack_lists, pack_sorted_df_into_struct
21+
from nested_pandas.series.utils import is_pa_type_a_list
2122

2223
pd.set_option("display.max_rows", 30)
2324
pd.set_option("display.min_rows", 5)
@@ -41,6 +42,32 @@ class NestedFrame(pd.DataFrame):
4142
# contains those aliases, keyed by the cleaned name.
4243
_metadata = ["_aliases"]
4344

45+
def __init__(self, *args, **kwargs) -> None:
    """Construct the frame and auto-cast list-struct columns to nested.

    All positional and keyword arguments are forwarded unchanged to the
    parent ``pd.DataFrame`` constructor.
    """
    super().__init__(*args, **kwargs)
    # Only list-struct Arrow columns are converted at construction time;
    # struct-list columns are left as they are (struct_list=False).
    self._cast_cols_to_nested(struct_list=False)
48+
49+
def _cast_cols_to_nested(self, *, struct_list: bool) -> None:
50+
"""Cast arrow columns to nested.
51+
52+
Parameters
53+
----------
54+
struct_list : bool
55+
If `False` cast list-struct columns only. If `True`, also
56+
try to cast struct-list columns validating if they have
57+
valid nested structure.
58+
"""
59+
for column, dtype in self.dtypes.items():
60+
if not isinstance(dtype, pd.ArrowDtype):
61+
continue
62+
pa_type = dtype.pyarrow_dtype
63+
if not is_pa_type_a_list(pa_type) and not (struct_list and pa.types.is_struct(pa_type)):
64+
continue
65+
try:
66+
nested_dtype = NestedDtype(pa_type)
67+
except (TypeError, ValueError):
68+
continue
69+
self[column] = self[column].astype(nested_dtype)
70+
4471
@property
def _constructor(self) -> Self:  # type: ignore[name-defined] # noqa: F821
    """Return ``NestedFrame`` as the constructor class (pandas subclassing hook)."""
    return NestedFrame
@@ -224,7 +251,8 @@ def __setitem__(self, key, value):
224251
self._update_inplace(new_df)
225252
return None
226253

227-
return super().__setitem__(key, value)
254+
super().__setitem__(key, value)
255+
self._cast_cols_to_nested(struct_list=False)
228256

229257
def add_nested(
230258
self,

src/nested_pandas/series/ext_array.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,11 @@ def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False) -> Self: #
240240
"""
241241
del copy
242242

243-
if isinstance(dtype, NestedDtype) and isinstance(scalars, (cls, pa.Array, pa.ChunkedArray)):
244-
if is_pa_type_a_list(scalars.type):
245-
return cls(scalars.cast(dtype.list_struct_pa_dtype))
246-
elif pa.types.is_struct(scalars.type):
247-
return cls(scalars.cast(dtype.pyarrow_dtype))
243+
if isinstance(dtype, NestedDtype):
244+
try:
245+
return cls._from_arrow_like(scalars, dtype=dtype)
246+
except ValueError:
247+
pass
248248

249249
pa_type = to_pyarrow_dtype(dtype)
250250
pa_array = cls._box_pa_array(scalars, pa_type=pa_type)
@@ -663,6 +663,29 @@ def _box_pa_array(cls, value, *, pa_type: pa.DataType | None) -> pa.Array | pa.C
663663

664664
return pa_array
665665

666+
@classmethod
def _from_arrow_like(cls, arraylike, dtype: NestedDtype | None = None) -> Self:  # type: ignore[name-defined] # noqa: F821
    """Build an instance of this class from an Arrow-compatible input.

    Parameters
    ----------
    arraylike
        An instance of this class, a ``pa.Array`` / ``pa.ChunkedArray``,
        or any other object accepted by ``pa.array``.
    dtype : NestedDtype or None, default None
        Target dtype. When `None` no cast is performed.

    Returns
    -------
    An instance of this class. If `arraylike` is already an instance and
    `dtype` is `None` or equal to its dtype, the very same object is
    returned.

    Raises
    ------
    ValueError
        If the data can be cast neither to ``dtype.pyarrow_dtype`` nor to
        ``dtype.list_struct_pa_dtype``.
    """
    if isinstance(arraylike, cls):
        nothing_to_do = dtype is None or dtype == arraylike.dtype
        if nothing_to_do:
            return arraylike
        source = arraylike.list_array
    elif not isinstance(arraylike, (pa.Array, pa.ChunkedArray)):
        source = pa.array(arraylike)
    else:
        source = arraylike

    if dtype is None:
        return cls(source)

    cast_errors = (ValueError, TypeError, KeyError, pa.ArrowNotImplementedError)
    # First try the dtype's own Arrow type, then fall back to its
    # list-struct Arrow type.
    try:
        converted = source.cast(dtype.pyarrow_dtype)
    except cast_errors:
        try:
            converted = source.cast(dtype.list_struct_pa_dtype)
        except cast_errors:
            raise ValueError(f"Cannot cast input to {dtype}") from None
    return cls(converted)
688+
666689
@classmethod
667690
def _convert_struct_scalar_to_df(cls, value: pa.StructScalar, *, copy: bool, na_value: Any = None) -> Any:
668691
"""Converts a struct scalar of equal-length list scalars to a pd.DataFrame

tests/nested_pandas/nestedframe/test_nestedframe.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pandas as pd
33
import pyarrow as pa
44
import pytest
5-
from nested_pandas import NestedFrame
5+
from nested_pandas import NestedDtype, NestedFrame
66
from nested_pandas.datasets import generate_data
77
from nested_pandas.nestedframe.core import _SeriesFromNest
88
from nested_pandas.series.packer import pack_lists
@@ -14,6 +14,17 @@ def test_nestedframe_construction():
1414
base = NestedFrame(data={"a": [1, 2, 3], "b": [2, 4, 6]}, index=[0, 1, 2])
1515

1616
assert isinstance(base, NestedFrame)
17+
assert_frame_equal(base, pd.DataFrame({"a": [1, 2, 3], "b": [2, 4, 6]}, index=[0, 1, 2]))
18+
19+
list_struct_array = pa.array(
20+
[[{"x": 1, "y": 1.0}], [{"x": 2, "y": 2.0}], [{"x": 3, "y": 3.0}, {"x": 4, "y": 4.0}]]
21+
)
22+
list_struct_series = pd.Series(list_struct_array, dtype=pd.ArrowDtype(list_struct_array.type))
23+
nested_series = pd.Series(list_struct_series, dtype=NestedDtype(list_struct_array.type))
24+
25+
nf = NestedFrame(base.to_dict(orient="series") | {"list_struct": list_struct_series})
26+
# Test auto-cast to nested
27+
assert_frame_equal(nf, base.assign(list_struct=nested_series))
1728

1829

1930
def test_nestedseries_construction():
@@ -223,6 +234,22 @@ def test_set_new_nested_col():
223234
)
224235

225236

237+
def test_set_list_struct_col():
    """Check that list-struct columns set via ``__setitem__`` or ``assign`` become nested."""
    nf = generate_data(10, 3)
    # Re-type the base columns as Arrow float64.
    for base_col in ("a", "b"):
        nf[base_col] = nf[base_col].astype(pd.ArrowDtype(pa.float64()))

    # Plain Arrow list-struct copy of the existing nested column.
    arrow_values = pa.array(nf.nested)
    arrow_series = pd.Series(arrow_values, dtype=pd.ArrowDtype(arrow_values.type))

    # Item assignment: the new column must flatten to the same frame as the original.
    nf["nested2"] = arrow_series
    assert_frame_equal(nf.nested.nest.to_flat(), nf.nested2.nest.to_flat())

    # Same check for assign().
    nf = nf.assign(nested3=arrow_series)
    assert_frame_equal(nf.nested.nest.to_flat(), nf.nested3.nest.to_flat())
251+
252+
226253
def test_get_dot_names():
227254
"""Test the ability to still work with column names with '.' characters outside of nesting"""
228255
nf = NestedFrame.from_flat(

0 commit comments

Comments
 (0)