Skip to content

Commit 3d3cab4

Browse files
authored
Merge pull request #203 from lincc-frameworks/extarr-list-struct
Use both struct-list and list-struct representations in the extension array
2 parents a416b6b + 1ef99be commit 3d3cab4

File tree

5 files changed

+417
-56
lines changed

5 files changed

+417
-56
lines changed

src/nested_pandas/series/ext_array.py

Lines changed: 73 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,14 @@
5151
from pandas.io.formats.format import format_array
5252

5353
from nested_pandas.series.dtype import NestedDtype
54-
from nested_pandas.series.utils import enumerate_chunks, is_pa_type_a_list
54+
from nested_pandas.series.utils import (
55+
enumerate_chunks,
56+
is_pa_type_a_list,
57+
transpose_list_struct_array,
58+
transpose_struct_list_array,
59+
transpose_struct_list_type,
60+
validate_struct_list_array_for_equal_lengths,
61+
)
5562

5663
__all__ = ["NestedExtensionArray"]
5764

@@ -549,6 +556,8 @@ def __arrow_array__(self, type=None):
549556
"""Convert the extension array to a PyArrow array."""
550557
if type is None:
551558
return self._chunked_array
559+
if isinstance(type, pa.ListType):
560+
return self._list_array.cast(type)
552561
return self._chunked_array.cast(type)
553562

554563
def __array__(self, dtype=None):
@@ -650,12 +659,27 @@ def __init__(self, values: pa.Array | pa.ChunkedArray, *, validate: bool = True)
650659
if isinstance(values, pa.Array):
651660
values = pa.chunked_array([values])
652661

653-
if validate:
662+
# Convert list-struct array to struct-list array
663+
if is_pa_type_a_list(values.type):
664+
struct_chunks = []
665+
for list_chunk in values.iterchunks():
666+
struct_chunks.append(transpose_list_struct_array(list_chunk))
667+
values = pa.chunked_array(struct_chunks)
668+
# Validate struct-array with list fields
669+
elif validate:
654670
self._validate(values)
655671

656672
self._chunked_array = values
657673
self._dtype = NestedDtype(values.type)
658674

675+
@property
def _list_array(self) -> pa.ChunkedArray:
    """Pyarrow chunked list-struct array representation

    Every struct-list chunk of the underlying chunked array is transposed
    into a list-struct chunk; validation of the chunks is skipped.

    Returns
    -------
    pa.ChunkedArray
        Chunked array with one list-struct chunk per underlying chunk.
    """
    list_chunks = [
        transpose_struct_list_array(struct_chunk, validate=False)
        for struct_chunk in self._chunked_array.iterchunks()
    ]
    if not list_chunks:
        # pa.chunked_array cannot infer a type from an empty list of chunks,
        # so derive the list-struct type from the struct-list type instead
        return pa.chunked_array([], type=transpose_struct_list_type(self._chunked_array.type))
    return pa.chunked_array(list_chunks)
682+
659683
@classmethod
660684
def from_sequence(cls, scalars, *, dtype: NestedDtype | pd.ArrowDtype | pa.DataType = None) -> Self: # type: ignore[name-defined] # noqa: F821
661685
"""Construct a NestedExtensionArray from a sequence of items
@@ -677,53 +701,65 @@ def from_sequence(cls, scalars, *, dtype: NestedDtype | pd.ArrowDtype | pa.DataT
677701
return cls._from_sequence(scalars, dtype=dtype)
678702

679703
@property
680-
def _pyarrow_dtype(self) -> pa.DataType:
704+
def _pyarrow_dtype(self) -> pa.StructType:
681705
"""PyArrow data type of the extension array"""
682706
return self._dtype.pyarrow_dtype
683707

708+
@property
def _pyarrow_list_struct_dtype(self) -> pa.ListType:
    """PyArrow list-struct type obtained by transposing the array's struct-list type"""
    struct_list_type = self._pyarrow_dtype
    return transpose_struct_list_type(struct_list_type)
712+
684713
@property
def chunked_array(self) -> pa.ChunkedArray:
    """PyArrow ChunkedArray stored by this extension array"""
    stored = self._chunked_array
    return stored
688717

718+
@property
def chunked_list_struct_array(self) -> pa.ChunkedArray:
    """Public accessor for the chunked list-struct representation of the array"""
    list_struct = self._list_array
    return list_struct
722+
689723
@staticmethod
690724
def _validate(array: pa.ChunkedArray) -> None:
691725
"""Raises ValueError if the input array is not a struct array with all fields being
692726
list arrays of the same lengths.
727+
728+
Parameters
729+
----------
730+
array : pa.ChunkedArray
731+
The array to validate.
732+
733+
Raises
734+
------
735+
ValueError
693736
"""
694737
for chunk in array.iterchunks():
695-
if not pa.types.is_struct(chunk.type):
696-
raise ValueError(f"Expected a StructArray, got {chunk.type}")
697-
struct_array = cast(pa.StructArray, chunk)
698-
699-
first_list_array: pa.ListArray | None = None
700-
for field in struct_array.type:
701-
inner_array = struct_array.field(field.name)
702-
if not is_pa_type_a_list(inner_array.type):
703-
raise ValueError(f"Expected a ListArray, got {inner_array.type}")
704-
list_array = cast(pa.ListArray, inner_array)
705-
706-
if first_list_array is None:
707-
first_list_array = list_array
708-
continue
709-
# compare offsets from the first list array with the current one
710-
if not first_list_array.offsets.equals(list_array.offsets):
711-
raise ValueError("Offsets of all ListArrays must be the same")
738+
validate_struct_list_array_for_equal_lengths(chunk)
712739

713740
@classmethod
def from_arrow_ext_array(cls, array: ArrowExtensionArray) -> Self:  # type: ignore[name-defined] # noqa: F821
    """Build a NestedExtensionArray from pandas' ArrowExtensionArray

    The PyArrow data held by `array` is handed to the class constructor.
    """
    pa_data = array._pa_array
    return cls(pa_data)
717744

718-
def to_arrow_ext_array(self) -> ArrowExtensionArray:
719-
"""Convert the extension array to pandas' ArrowExtensionArray"""
745+
def to_arrow_ext_array(self, list_struct: bool = False) -> ArrowExtensionArray:
746+
"""Convert the extension array to pandas' ArrowExtensionArray
747+
748+
Parameters
749+
----------
750+
list_struct : bool, optional
751+
If False (default), return struct-list array, otherwise return
752+
list-struct array.
753+
"""
754+
if list_struct:
755+
return ArrowExtensionArray(self._list_array)
720756
return ArrowExtensionArray(self._chunked_array)
721757

722758
def _replace_chunked_array(self, pa_array: pa.ChunkedArray, *, validate: bool) -> None:
723759
if validate:
724760
self._validate(pa_array)
725761
self._chunked_array = pa_array
726-
self._dtype = NestedDtype(pa_array.chunk(0).type)
762+
self._dtype = NestedDtype(pa_array.type)
727763

728764
@property
729765
def list_offsets(self) -> pa.Array:
@@ -737,48 +773,32 @@ def list_offsets(self) -> pa.Array:
737773
pa.ChunkedArray
738774
The list offsets of the field arrays.
739775
"""
740-
# Quick and cheap path for a single chunk
776+
# Cheap path for a single chunk
741777
if self._chunked_array.num_chunks == 1:
742778
struct_array = cast(pa.StructArray, self._chunked_array.chunk(0))
743779
return cast(pa.ListArray, struct_array.field(0)).offsets
744780

745-
chunks = []
746-
# The offset of the current chunk in the flat array.
747-
# Offset arrays use int32 type, so we cast to it
748-
chunk_offset = pa.scalar(0, type=pa.int32())
749-
for chunk in self._chunked_array.iterchunks():
750-
list_array = cast(pa.ListArray, chunk.field(0))
751-
if chunk_offset.equals(pa.scalar(0, type=pa.int32())):
752-
offsets = list_array.offsets
753-
else:
754-
offsets = pa.compute.add(list_array.offsets[1:], chunk_offset)
755-
chunks.append(offsets)
756-
chunk_offset = offsets[-1]
757-
return pa.concat_arrays(chunks)
781+
zero_and_lengths = pa.chunked_array(
782+
[pa.array([0], type=pa.int32()), pa.array(self.list_lengths, type=pa.int32())]
783+
)
784+
offsets = pa.compute.cumulative_sum(zero_and_lengths)
785+
return offsets.chunk(0) if offsets.num_chunks == 1 else offsets.combine_chunks()
758786

759787
@property
def field_names(self) -> list[str]:
    """Names of the nested columns"""
    # Use the type of the chunked array itself rather than of chunk(0):
    # the first chunk does not exist when the array has no chunks.
    return [field.name for field in self._chunked_array.type]
763791

764-
def _iter_list_lengths(self) -> Generator[int, None, None]:
765-
"""Iterate over the lengths of the list arrays"""
766-
for chunk in self._chunked_array.iterchunks():
767-
for length in chunk.field(0).value_lengths():
768-
if length.is_valid:
769-
yield length.as_py()
770-
else:
771-
yield 0
772-
773792
@property
774-
def list_lengths(self) -> list[int]:
793+
def list_lengths(self) -> np.ndarray:
775794
"""Lengths of the list arrays"""
776-
return list(self._iter_list_lengths())
795+
list_lengths = pa.compute.list_value_length(self._list_array)
796+
return np.asarray(list_lengths)
777797

778798
@property
779799
def flat_length(self) -> int:
780800
"""Length of the flat arrays"""
781-
return sum(self._iter_list_lengths())
801+
return pa.compute.sum(self.list_lengths).as_py()
782802

783803
@property
784804
def num_chunks(self) -> int:
@@ -790,8 +810,8 @@ def get_list_index(self) -> np.ndarray:
790810
if len(self) == 0:
791811
# Since we have no list offsets, return an empty array
792812
return np.array([], dtype=int)
793-
list_index = np.arange(len(self))
794-
return np.repeat(list_index, np.diff(self.list_offsets))
813+
list_index = np.arange(len(self), dtype=int)
814+
return np.repeat(list_index, self.list_lengths)
795815

796816
def iter_field_lists(self, field: str) -> Generator[np.ndarray, None, None]:
797817
"""Iterate over single field nested lists, as numpy arrays
@@ -813,7 +833,7 @@ def iter_field_lists(self, field: str) -> Generator[np.ndarray, None, None]:
813833
yield np.asarray(list_scalar.values)
814834

815835
def view_fields(self, fields: str | list[str]) -> Self: # type: ignore[name-defined] # noqa: F821
816-
"""Get a view of the extension array with only the specified fields
836+
"""Get a view of the extension array with the specified fields only
817837
818838
Parameters
819839
----------
@@ -842,7 +862,7 @@ def view_fields(self, fields: str | list[str]) -> Self: # type: ignore[name-def
842862
chunks.append(struct_array)
843863
pa_array = pa.chunked_array(chunks)
844864

845-
return self.__class__(pa_array, validate=False)
865+
return type(self)(pa_array, validate=False)
846866

847867
def set_flat_field(self, field: str, value: ArrayLike, *, keep_dtype: bool = False) -> None:
848868
"""Set the field from flat-array of values

src/nested_pandas/series/utils.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from __future__ import annotations # Python 3.9 requires it for X | Y type hints
2+
13
from collections.abc import Generator
4+
from typing import cast
25

36
import pyarrow as pa
47

58

6-
def is_pa_type_a_list(pa_type: type[pa.Array]) -> bool:
9+
def is_pa_type_a_list(pa_type: pa.DataType) -> bool:
710
"""Check if the given pyarrow type is a list type.
811
912
I.e. one of the following types: ListArray, LargeListArray,
@@ -39,3 +42,148 @@ def enumerate_chunks(array: pa.ChunkedArray) -> Generator[tuple[slice, pa.Array]
3942
index_stop = index_start + len(chunk)
4043
yield slice(index_start, index_stop), chunk
4144
index_start = index_stop
45+
46+
47+
def validate_struct_list_array_for_equal_lengths(array: pa.StructArray) -> None:
    """Ensure that all fields of a struct array are list arrays sharing the same offsets.

    Parameters
    ----------
    array : pa.StructArray
        Struct array whose fields are checked.

    Raises
    ------
    ValueError
        If `array` is not of a struct type, if any of its fields is not of
        a list type, or if the offsets of a field differ from the offsets
        of the first field.
    """
    if not pa.types.is_struct(array.type):
        raise ValueError(f"Expected a StructArray, got {array.type}")

    reference = None
    for struct_field in array.type:
        candidate = array.field(struct_field.name)
        if not is_pa_type_a_list(candidate.type):
            raise ValueError(f"Expected a ListArray, got {candidate.type}")
        candidate = cast(pa.ListArray, candidate)

        if reference is None:
            # The first field defines the offsets every other field must match
            reference = candidate
        elif not reference.offsets.equals(candidate.offsets):
            raise ValueError("Offsets of all ListArrays must be the same")
77+
78+
79+
def transpose_struct_list_type(t: pa.StructType) -> pa.ListType:
    """Build the list-struct type corresponding to a struct-list type.

    Parameters
    ----------
    t : pa.StructType
        Struct type whose fields are all of list types.

    Returns
    -------
    pa.ListType
        List type of structs; each struct field keeps the name of the input
        field and takes the value type of its list.

    Raises
    ------
    ValueError
        If `t` is not a struct type or any of its fields is not of a list type.
    """
    if not pa.types.is_struct(t):
        raise ValueError(f"Expected a StructType, got {t}")

    item_fields = []
    for struct_field in t:
        field_type = struct_field.type
        if not is_pa_type_a_list(field_type):
            raise ValueError(f"Expected a ListType, got {field_type}")
        value_type = cast(pa.ListType, field_type).value_type
        item_fields.append(pa.field(struct_field.name, value_type))

    return cast(pa.ListType, pa.list_(pa.struct(item_fields)))
109+
110+
111+
def transpose_struct_list_array(array: pa.StructArray, validate: bool = True) -> pa.ListArray:
    """Turn a struct array of lists into a list array of structs.

    Parameters
    ----------
    array : pa.StructArray
        Struct array of lists; the lists of each struct item must have
        equal lengths.
    validate : bool, default True
        If True, run `validate_struct_list_array_for_equal_lengths` on the
        input first, which raises ValueError for an invalid array.

    Returns
    -------
    pa.ListArray
        List array of structs built from the offsets of the first field and
        the values of every field.
    """
    if validate:
        validate_struct_list_array_for_equal_lengths(array)

    # All fields are expected to share offsets, so the first field's offsets are used for the output
    shared_offsets = array.field(0).offsets

    flat_columns = []
    for list_column in array.flatten():
        flat_columns.append(list_column.values)
    flat_structs = pa.StructArray.from_arrays(flat_columns, names=array.type.names)

    return pa.ListArray.from_arrays(shared_offsets, flat_structs)
137+
138+
139+
def transpose_list_struct_type(t: pa.ListType) -> pa.StructType:
    """Converts a type of list-struct array into a type of struct-list array.

    Parameters
    ----------
    t : pa.ListType
        Input type of list-struct array.

    Returns
    -------
    pa.StructType
        Type of struct-list array: each field keeps its name and gets
        a list type wrapping the original field type.

    Raises
    ------
    ValueError
        If the input type is not a list-struct type, i.e. it is not a list
        type or its value type is not a struct type.
    """
    if not is_pa_type_a_list(t):
        raise ValueError(f"Expected a ListType, got {t}")

    value_type = t.value_type
    # Without this check a non-struct value type would fail on iteration below
    # with an unrelated error instead of the documented ValueError
    if not pa.types.is_struct(value_type):
        raise ValueError(f"Expected a StructType, got {value_type}")

    struct_type = cast(pa.StructType, value_type)
    fields = [pa.field(field.name, pa.list_(field.type)) for field in struct_type]
    return cast(pa.StructType, pa.struct(fields))
167+
168+
169+
def transpose_list_struct_array(array: pa.ListArray) -> pa.StructArray:
    """Converts a list-array of structs into a struct-array of lists.

    Parameters
    ----------
    array : pa.ListArray
        Input list array of structs.

    Returns
    -------
    pa.StructArray
        Struct array of lists. Every field is a list array built from the
        offsets of the input and the values of the corresponding struct field.

    Raises
    ------
    ValueError
        If the values of the input list array are not of a struct type.
    """
    offsets, values = array.offsets, array.values

    # Fail early with a clear message instead of an obscure error from flatten()
    if not pa.types.is_struct(values.type):
        raise ValueError(f"Expected a StructType, got {values.type}")

    # All output fields reuse the offsets of the input list array
    fields = [pa.ListArray.from_arrays(offsets, field_values) for field_values in values.flatten()]

    return pa.StructArray.from_arrays(fields, names=array.type.value_type.names)

0 commit comments

Comments
 (0)