Skip to content

Commit 3fd1743

Browse files
committed
reduce output nesting inference
1 parent 6b7a45f commit 3fd1743

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

src/nested_pandas/nestedframe/core.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,7 @@ def sort_values(
845845
return None
846846
return new_df
847847

848-
def reduce(self, func, *args, **kwargs) -> NestedFrame: # type: ignore[override]
848+
def reduce(self, func, *args, infer_nesting=True, **kwargs) -> NestedFrame: # type: ignore[override]
849849
"""
850850
Takes a function and applies it to each top-level row of the NestedFrame.
851851
@@ -862,6 +862,12 @@ def reduce(self, func, *args, **kwargs) -> NestedFrame: # type: ignore[override
862862
args : positional arguments
863863
Positional arguments to pass to the function, the first *args should be the names of the
864864
columns to apply the function to.
865+
infer_nesting : bool, default True
866+
If True, the function will pack output columns into nested
867+
structures based on column names adhering to a nested naming
868+
scheme. E.g. "nested.b" and "nested.c" will be packed into a column
869+
called "nested" with columns "b" and "c". If False, all outputs
870+
will be returned as base columns.
865871
kwargs : keyword arguments, optional
866872
Keyword arguments to pass to the function.
867873
@@ -915,7 +921,25 @@ def reduce(self, func, *args, **kwargs) -> NestedFrame: # type: ignore[override
915921
iterators.append(self[layer].array.iter_field_lists(col))
916922

917923
results = [func(*cols, *extra_args, **kwargs) for cols in zip(*iterators)]
918-
return NestedFrame(results, index=self.index)
924+
results_nf = NestedFrame(results, index=self.index)
925+
926+
if infer_nesting:
927+
# find potential nested structures
928+
nested_cols = []
929+
for column in results_nf.columns:
930+
if isinstance(column, str) and "." in column:
931+
layer, col = column.split(".", 1)
932+
nested_cols.append(layer)
933+
nested_cols = np.unique(nested_cols)
934+
935+
# pack results into nested structures
936+
for layer in nested_cols:
937+
layer_cols = [col for col in results_nf.columns if col.startswith(f"{layer}.")]
938+
rename_df = results_nf[layer_cols].rename(columns=lambda x: x.split(".", 1)[1])
939+
nested_col = pack_lists(rename_df, name=layer)
940+
results_nf = results_nf[[col for col in results_nf.columns if not col.startswith(f"{layer}.")]].join(nested_col)
941+
942+
return results_nf
919943

920944
def to_parquet(self, path, by_layer=False, **kwargs) -> None:
921945
"""Creates parquet file(s) with the data of a NestedFrame, either

tests/nested_pandas/nestedframe/test_nestedframe.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,57 @@ def cols_allclose(col1, col2):
10211021
result, pd.DataFrame({"allclose": [True, True, True]}, index=pd.Index([0, 1, 2], name="idx"))
10221022
)
10231023

1024+
def test_reduce_infer_nesting():
    """Test that nesting inference works in reduce.

    Output columns named "<layer>.<field>" should be packed into a nested
    column called "<layer>"; columns without a "." (or with non-string
    names) should be returned as base columns.
    """

    ndf = generate_data(3, 20, seed=1)

    quantiles = [0.1, 0.2, 0.3, 0.4, 0.5]

    # Test simple case: one base column plus one single-field nested structure
    def single_nested_field(flux):
        return {
            "max_flux": np.max(flux),
            "lc.flux_quantiles": np.quantile(flux, quantiles),
        }

    result = ndf.reduce(single_nested_field, "nested.flux")
    assert list(result.columns) == ["max_flux", "lc"]
    assert list(result.lc.nest.fields) == ["flux_quantiles"]

    # Test multi-column nested output
    def multi_nested_fields(flux):
        return {
            "max_flux": np.max(flux),
            "lc.flux_quantiles": np.quantile(flux, quantiles),
            "lc.labels": list(quantiles),
        }

    result = ndf.reduce(multi_nested_fields, "nested.flux")
    assert list(result.columns) == ["max_flux", "lc"]
    assert list(result.lc.nest.fields) == ["flux_quantiles", "labels"]

    # Test integer names: tuple outputs give integer column labels, which
    # must pass through nesting inference untouched
    def tuple_output(flux):
        return np.max(flux), np.quantile(flux, quantiles), list(quantiles)

    result = ndf.reduce(tuple_output, "nested.flux")
    assert list(result.columns) == [0, 1, 2]

    # Test multiple nested structure output
    def two_nested_structures(flux):
        return {
            "max_flux": np.max(flux),
            "lc.flux_quantiles": np.quantile(flux, quantiles),
            "lc.labels": list(quantiles),
            "meta.colors": ["green", "red", "blue"],
        }

    result = ndf.reduce(two_nested_structures, "nested.flux")
    assert list(result.columns) == ["max_flux", "lc", "meta"]
    assert list(result.lc.nest.fields) == ["flux_quantiles", "labels"]
    # "meta" is its own packed column, so inspect it through the same
    # .nest accessor used for "lc" (it is not an attribute of result.lc)
    assert list(result.meta.nest.fields) == ["colors"]

    # Test only nested structure output
    def only_nested_output(flux):
        return {
            "lc.flux_quantiles": np.quantile(flux, quantiles),
            "lc.labels": list(quantiles),
        }

    result = ndf.reduce(only_nested_output, "nested.flux")
    assert list(result.columns) == ["lc"]
    assert list(result.lc.nest.fields) == ["flux_quantiles", "labels"]
1074+
10241075

10251076
def test_scientific_notation():
10261077
"""

0 commit comments

Comments
 (0)