@@ -695,7 +695,28 @@ def _pack_inputs(agg_funcs: List[ReductionAggStep], in_data):
695
695
return out_dict
696
696
697
697
@staticmethod
698
- def _do_custom_agg (op , custom_reduction , * input_objs ):
698
def _do_custom_agg_single(op, custom_reduction, input_obj):
    """Run a custom reduction over a single input and wrap the result.

    The callable passed to ``input_obj.apply`` is chosen from ``op.stage``:

    * map stage -- ``custom_reduction.pre`` alone when
      ``custom_reduction.pre_with_agg`` is truthy, otherwise ``pre``
      followed by ``agg``;
    * agg stage -- ``agg`` followed by ``post``;
    * any other stage -- ``agg`` alone.

    Parameters
    ----------
    op
        Operand whose ``stage`` attribute selects the callable.
    custom_reduction
        Object exposing ``pre``, ``agg``, ``post`` and ``pre_with_agg``.
    input_obj
        Object providing an ``apply`` method (presumably a groupby
        object -- confirm against callers).

    Returns
    -------
    tuple
        A one-element tuple holding whatever ``input_obj.apply`` returned.
    """
    stage = op.stage

    if stage == OperandStage.map:
        if custom_reduction.pre_with_agg:
            # ``pre`` is used on its own here; ``agg`` is not chained.
            reducer = custom_reduction.pre
        else:

            def reducer(chunk):
                return custom_reduction.agg(custom_reduction.pre(chunk))

    elif stage == OperandStage.agg:

        def reducer(chunk):
            return custom_reduction.post(custom_reduction.agg(chunk))

    else:
        reducer = custom_reduction.agg

    # Callers extend a list with the return value, hence the tuple wrapper.
    return (input_obj.apply(reducer),)
717
+
718
+ @staticmethod
719
+ def _do_custom_agg_multiple (op , custom_reduction , * input_objs ):
699
720
xdf = cudf if op .gpu else pd
700
721
results = []
701
722
out = op .outputs [0 ]
@@ -763,6 +784,13 @@ def _do_custom_agg(op, custom_reduction, *input_objs):
763
784
concat_result = tuple (xdf .concat (parts ) for parts in zip (* results ))
764
785
return concat_result
765
786
787
+ @classmethod
788
def _do_custom_agg(cls, op, custom_reduction, *input_objs, output_limit: int = 1):
    """Dispatch a custom reduction to the single- or multi-output path.

    Parameters
    ----------
    op
        Operand forwarded unchanged to the selected implementation.
    custom_reduction
        Reduction object forwarded unchanged to the selected implementation.
    *input_objs
        Positional inputs of the reduction.
    output_limit : int, default 1
        When equal to 1 the single-output implementation is used;
        any other value selects the multi-output implementation.

    Returns
    -------
    The return value of ``cls._do_custom_agg_single`` or
    ``cls._do_custom_agg_multiple``.
    """
    if output_limit != 1:
        return cls._do_custom_agg_multiple(op, custom_reduction, *input_objs)

    # Only the first positional input is forwarded on the single-output path.
    first_input = input_objs[0]
    return cls._do_custom_agg_single(op, custom_reduction, first_input)
793
+
766
794
@staticmethod
767
795
def _do_predefined_agg (input_obj , agg_func , single_func = False , ** kwds ):
768
796
ndim = getattr (input_obj , "ndim" , None ) or input_obj .obj .ndim
@@ -857,12 +885,16 @@ def _wrapped_func(col):
857
885
_agg_func_name ,
858
886
custom_reduction ,
859
887
_output_key ,
860
- _output_limit ,
888
+ output_limit ,
861
889
kwds ,
862
890
) in op .agg_funcs :
863
891
input_obj = ret_map_groupbys [input_key ]
864
892
if map_func_name == "custom_reduction" :
865
- agg_dfs .extend (cls ._do_custom_agg (op , custom_reduction , input_obj ))
893
+ agg_dfs .extend (
894
+ cls ._do_custom_agg (
895
+ op , custom_reduction , input_obj , output_limit = output_limit
896
+ )
897
+ )
866
898
else :
867
899
single_func = map_func_name == op .raw_func
868
900
agg_dfs .append (
@@ -903,12 +935,16 @@ def _execute_combine(cls, ctx, op: "DataFrameGroupByAgg"):
903
935
agg_func_name ,
904
936
custom_reduction ,
905
937
output_key ,
906
- _output_limit ,
938
+ output_limit ,
907
939
kwds ,
908
940
) in op .agg_funcs :
909
941
input_obj = in_data_dict [output_key ]
910
942
if agg_func_name == "custom_reduction" :
911
- combines .extend (cls ._do_custom_agg (op , custom_reduction , * input_obj ))
943
+ combines .extend (
944
+ cls ._do_custom_agg (
945
+ op , custom_reduction , * input_obj , output_limit = output_limit
946
+ )
947
+ )
912
948
else :
913
949
combines .append (
914
950
cls ._do_predefined_agg (input_obj , agg_func_name , ** kwds )
@@ -943,15 +979,15 @@ def _execute_agg(cls, ctx, op: "DataFrameGroupByAgg"):
943
979
agg_func_name ,
944
980
custom_reduction ,
945
981
output_key ,
946
- _output_limit ,
982
+ output_limit ,
947
983
kwds ,
948
984
) in op .agg_funcs :
949
985
if agg_func_name == "custom_reduction" :
950
986
input_obj = tuple (
951
987
cls ._get_grouped (op , o , ctx ) for o in in_data_dict [output_key ]
952
988
)
953
989
in_data_dict [output_key ] = cls ._do_custom_agg (
954
- op , custom_reduction , * input_obj
990
+ op , custom_reduction , * input_obj , output_limit = output_limit
955
991
)[0 ]
956
992
else :
957
993
input_obj = cls ._get_grouped (op , in_data_dict [output_key ], ctx )
0 commit comments