Skip to content

Commit 6f52547

Browse files
committed
Simplify agg with single return
1 parent ee31659 commit 6f52547

File tree

1 file changed

+43
-7
lines changed

1 file changed

+43
-7
lines changed

mars/dataframe/groupby/aggregation.py

Lines changed: 43 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -695,7 +695,28 @@ def _pack_inputs(agg_funcs: List[ReductionAggStep], in_data):
695695
return out_dict
696696

697697
@staticmethod
698-
def _do_custom_agg(op, custom_reduction, *input_objs):
698+
def _do_custom_agg_single(op, custom_reduction, input_obj):
    """Apply a custom reduction to one input object and return a 1-tuple.

    The callable handed to ``input_obj.apply`` is selected from ``op.stage``:

    * ``OperandStage.map``: ``custom_reduction.pre`` alone when
      ``custom_reduction.pre_with_agg`` is truthy, otherwise ``pre``
      followed by ``agg``.
    * ``OperandStage.agg``: ``agg`` followed by ``post``.
    * any other stage: ``custom_reduction.agg`` alone.

    :param op: operand whose ``stage`` attribute selects the callable
    :param custom_reduction: object providing ``pre``, ``agg``, ``post``
        and the ``pre_with_agg`` flag
    :param input_obj: object exposing ``apply`` (presumably a groupby
        object -- confirm against callers)
    :return: tuple holding the single result of ``input_obj.apply``
    """

    def _pre_then_agg(group):
        # map stage when ``pre`` alone is not enough: chain pre -> agg
        prepared = custom_reduction.pre(group)
        return custom_reduction.agg(prepared)

    def _agg_then_post(group):
        # agg stage: chain agg -> post
        reduced = custom_reduction.agg(group)
        return custom_reduction.post(reduced)

    if op.stage == OperandStage.map:
        reducer = (
            custom_reduction.pre
            if custom_reduction.pre_with_agg
            else _pre_then_agg
        )
    elif op.stage == OperandStage.agg:
        reducer = _agg_then_post
    else:
        reducer = custom_reduction.agg

    # Callers extend() their result list with this return value, so the
    # single result is wrapped in a tuple.
    return (input_obj.apply(reducer),)
717+
718+
@staticmethod
719+
def _do_custom_agg_multiple(op, custom_reduction, *input_objs):
699720
xdf = cudf if op.gpu else pd
700721
results = []
701722
out = op.outputs[0]
@@ -763,6 +784,13 @@ def _do_custom_agg(op, custom_reduction, *input_objs):
763784
concat_result = tuple(xdf.concat(parts) for parts in zip(*results))
764785
return concat_result
765786

787+
@classmethod
788+
def _do_custom_agg(cls, op, custom_reduction, *input_objs, output_limit: int = 1):
    """Dispatch a custom reduction to the single- or multi-output routine.

    :param op: operand forwarded unchanged to the chosen routine
    :param custom_reduction: reduction object forwarded unchanged
    :param input_objs: input objects; only the first one is used when
        ``output_limit`` equals 1
    :param output_limit: keyword-only; ``1`` selects
        ``_do_custom_agg_single``, any other value selects
        ``_do_custom_agg_multiple``
    :return: whatever the selected routine returns
    """
    if output_limit != 1:
        # Multi-output path receives every input object.
        return cls._do_custom_agg_multiple(op, custom_reduction, *input_objs)

    # Single-output path: hand over just the first input object.
    run_single = cls._do_custom_agg_single
    return run_single(op, custom_reduction, input_objs[0])
793+
766794
@staticmethod
767795
def _do_predefined_agg(input_obj, agg_func, single_func=False, **kwds):
768796
ndim = getattr(input_obj, "ndim", None) or input_obj.obj.ndim
@@ -857,12 +885,16 @@ def _wrapped_func(col):
857885
_agg_func_name,
858886
custom_reduction,
859887
_output_key,
860-
_output_limit,
888+
output_limit,
861889
kwds,
862890
) in op.agg_funcs:
863891
input_obj = ret_map_groupbys[input_key]
864892
if map_func_name == "custom_reduction":
865-
agg_dfs.extend(cls._do_custom_agg(op, custom_reduction, input_obj))
893+
agg_dfs.extend(
894+
cls._do_custom_agg(
895+
op, custom_reduction, input_obj, output_limit=output_limit
896+
)
897+
)
866898
else:
867899
single_func = map_func_name == op.raw_func
868900
agg_dfs.append(
@@ -903,12 +935,16 @@ def _execute_combine(cls, ctx, op: "DataFrameGroupByAgg"):
903935
agg_func_name,
904936
custom_reduction,
905937
output_key,
906-
_output_limit,
938+
output_limit,
907939
kwds,
908940
) in op.agg_funcs:
909941
input_obj = in_data_dict[output_key]
910942
if agg_func_name == "custom_reduction":
911-
combines.extend(cls._do_custom_agg(op, custom_reduction, *input_obj))
943+
combines.extend(
944+
cls._do_custom_agg(
945+
op, custom_reduction, *input_obj, output_limit=output_limit
946+
)
947+
)
912948
else:
913949
combines.append(
914950
cls._do_predefined_agg(input_obj, agg_func_name, **kwds)
@@ -943,15 +979,15 @@ def _execute_agg(cls, ctx, op: "DataFrameGroupByAgg"):
943979
agg_func_name,
944980
custom_reduction,
945981
output_key,
946-
_output_limit,
982+
output_limit,
947983
kwds,
948984
) in op.agg_funcs:
949985
if agg_func_name == "custom_reduction":
950986
input_obj = tuple(
951987
cls._get_grouped(op, o, ctx) for o in in_data_dict[output_key]
952988
)
953989
in_data_dict[output_key] = cls._do_custom_agg(
954-
op, custom_reduction, *input_obj
990+
op, custom_reduction, *input_obj, output_limit=output_limit
955991
)[0]
956992
else:
957993
input_obj = cls._get_grouped(op, in_data_dict[output_key], ctx)

0 commit comments

Comments
 (0)