 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
 import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
 import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
 import org.elasticsearch.xpack.esql.plan.logical.ChangePoint;
@@ -324,10 +326,10 @@ public record QueryProperties(boolean hasGrouping, boolean canDecreaseRowCount,
     private final EsqlExecutionInfo executionInfo;
     private final QueryProperties queryProperties;
     private final EsqlSession.PlanRunner runner;
-    private final LogicalPlanOptimizer logicalPlanOptimizer;
     private final Function<LogicalPlan, PhysicalPlan> toPhysicalPlan;
     private final Configuration configuration;
     private final FoldContext foldContext;
+    private final TransportVersion minimumVersion;
     private final PlanTimeProfile planTimeProfile;

     private final SetOnce<Long> sourceRowCount;
@@ -337,22 +339,22 @@ public Approximation(
         LogicalPlan logicalPlan,
         ApproximationSettings settings,
         EsqlExecutionInfo executionInfo,
-        LogicalPlanOptimizer logicalPlanOptimizer,
         Function<LogicalPlan, PhysicalPlan> toPhysicalPlan,
         EsqlSession.PlanRunner runner,
         Configuration configuration,
         FoldContext foldContext,
+        TransportVersion minimumVersion,
         PlanTimeProfile planTimeProfile
     ) {
         this.logicalPlan = logicalPlan;
         this.settings = settings;
         this.executionInfo = executionInfo;
         this.queryProperties = verifyPlan(logicalPlan);
-        this.logicalPlanOptimizer = logicalPlanOptimizer;
         this.toPhysicalPlan = toPhysicalPlan;
         this.runner = runner;
         this.configuration = configuration;
         this.foldContext = foldContext;
+        this.minimumVersion = minimumVersion;
         this.planTimeProfile = planTimeProfile;

         sourceRowCount = new SetOnce<>();
@@ -503,8 +505,8 @@ private LogicalPlan sourceCountPlan() {
             List.of(),
             List.of(new Alias(Source.EMPTY, "$count", COUNT_ALL_ROWS))
         );
-        sourceCountPlan.setPreOptimized();
-        return logicalPlanOptimizer.optimize(sourceCountPlan);
+        sourceCountPlan.setOptimized();
+        return sourceCountPlan;
     }

     /**
@@ -587,8 +589,8 @@ private LogicalPlan countPlan(double sampleProbability) {
             return plan;
         });

-        countPlan.setPreOptimized();
-        return logicalPlanOptimizer.optimize(countPlan);
+        countPlan.setOptimized();
+        return countPlan;
     }

     /**
@@ -748,7 +750,7 @@ private LogicalPlan approximationPlan(double sampleProbability) {
             return exactPlanWithConfidenceIntervals();
         }

-        logger.debug("generating approximate plan (p=[{}])", sampleProbability);
+        logger.debug("generating approximation plan (p=[{}])", sampleProbability);

         // Whether or not the first STATS command has been encountered yet.
         Holder<Boolean> encounteredStats = new Holder<>(false);
@@ -766,7 +768,7 @@ private LogicalPlan approximationPlan(double sampleProbability) {
         // is rewritten to AVG::double = SUM::double / COUNT::long.
         Map<NameId, NamedExpression> uncorrectedExpressions = new HashMap<>();

-        LogicalPlan approximatePlan = logicalPlan.transformUp(plan -> {
+        LogicalPlan approximationPlan = logicalPlan.transformUp(plan -> {
             if (plan instanceof LeafPlan) {
                 // The leaf plan should be appended by a SAMPLE.
                 return new Sample(Source.EMPTY, Literal.fromDouble(Source.EMPTY, sampleProbability), plan);
@@ -789,22 +791,49 @@ private LogicalPlan approximationPlan(double sampleProbability) {
         });

         // Add the confidence intervals for all fields with buckets.
-        approximatePlan = new Eval(Source.EMPTY, approximatePlan, getConfidenceIntervals(fieldBuckets));
+        approximationPlan = new Eval(Source.EMPTY, approximationPlan, getConfidenceIntervals(fieldBuckets));

         // Drop all bucket fields and uncorrected fields from the output.
         Set<Attribute> dropAttributes = Stream.concat(
             fieldBuckets.values().stream().flatMap(List::stream),
             uncorrectedExpressions.values().stream()
         ).map(NamedExpression::toAttribute).collect(Collectors.toSet());

-        List<Attribute> keepAttributes = new ArrayList<>(approximatePlan.output());
+        List<Attribute> keepAttributes = new ArrayList<>(approximationPlan.output());
         keepAttributes.removeAll(dropAttributes);
-        approximatePlan = new Project(Source.EMPTY, approximatePlan, keepAttributes);
+        approximationPlan = new Project(Source.EMPTY, approximationPlan, keepAttributes);
+        approximationPlan = optimize(approximationPlan);
+        logger.debug("approximation plan (after optimization):\n{}", approximationPlan);
+        return approximationPlan;
+    }

-        approximatePlan.setPreOptimized();
-        approximatePlan = logicalPlanOptimizer.optimize(approximatePlan);
-        logger.debug("approximate plan:\n{}", approximatePlan);
-        return approximatePlan;
+    /**
+     * Optimizes the plan by running just the operator batch. This is primarily
+     * to prune unnecessary columns generated in the approximation plan.
+     * <p>
+     * These unnecessary columns are generated in various ways, for example:
+     * - STATS AVG(x): the AVG is rewritten via a surrogate to SUM and COUNT.
+     *   Both the corrected and uncorrected sum and count columns are generated,
+     *   but the corrected ones aren't needed for AVG(x) = SUM(x)/COUNT().
+     * - STATS x=COUNT() | EVAL x=TO_STRING(x): bucket columns are generated
+     *   for the numeric count, but are not needed after the string conversion.
+     * <p>
+     * Note that this runs the operator batch on a plan that it normally doesn't
+     * run on (namely a cleaned-up plan, which can contain a TopN). This seems to
+     * work (at least for everything we've tested), but the process has poor test
+     * coverage.
+     * TODO: refactor so that approximation ideally happens halfway through
+     * optimization, before the cleanup step. Possibly make approximation an
+     * optimization rule itself in the substitutions batch.
+     */
+    private LogicalPlan optimize(LogicalPlan plan) {
+        LogicalPlanOptimizer optimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration, foldContext, minimumVersion)) {
+            @Override
+            protected List<Batch<LogicalPlan>> batches() {
+                return List.of(operators());
+            }
+        };
+        return optimizer.optimize(plan);
     }

     /**
@@ -819,12 +848,13 @@ private LogicalPlan approximationPlan(double sampleProbability) {
      * <pre>
      * {@code
      * STATS sampleSize = COUNT(*),
-     *       s = SUM(x) / prob,
-     *       `s$0` = SUM(x) / (prob/B)) WHERE MV_SLICE(bucketId, 0, 0) == 0
+     *       s = SUM(x),
+     *       `s$0` = SUM(x) WHERE MV_SLICE(bucketId, 0, 0) == 0
      *       ...,
      *       `s$T*B-1` = SUM(x) / (prob/B) WHERE MV_SLICE(bucketId, T-1, T-1) == B-1
      *       BY group
      * | WHERE sampleSize >= MIN_ROW_COUNT_FOR_RESULT_INCLUSION
+     * | EVAL s = s / prob, `s$0` = `s$0` / (prob/B), `s$T*B-1` = `s$T*B-1` / (prob/B)
      * | DROP sampleSize
      * }
      * </pre>
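To make the correction that moves into the EVAL above concrete, here is a worked instance with assumed numbers (prob = 0.01 and B = 10 buckets per trial are illustrative values, not taken from this change):

    | EVAL s = s / 0.01, `s$0` = `s$0` / (0.01 / 10)

A sum of 42 over the sampled rows becomes an estimate of 42 / 0.01 = 4200 for the full data set; each per-bucket sum, computed from roughly a tenth of the sampled rows, is scaled up by the larger factor of 1000.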
@@ -843,9 +873,15 @@ private LogicalPlan sampleCorrectedAggregateAndBuckets(
         // TODO: use theoretically non-conflicting names.
         Alias bucketIdField = new Alias(Source.EMPTY, "$bucket_id", bucketIds);

+        // The aggregate functions in the approximation plan.
         List<NamedExpression> aggregates = new ArrayList<>();
-        Alias sampleSize = new Alias(Source.EMPTY, "$sample_size", COUNT_ALL_ROWS);
-        aggregates.add(sampleSize);
+
+        // List of expressions that must be evaluated after the sampled aggregation.
+        // These consist of:
+        // - sample corrections (to correct counts/sums for sampling)
+        // - replacement of zero counts by NULLs (for the confidence interval computation)
+        // - the exact total row count if COUNT(*) is used (to avoid sampling errors there)
+        List<Alias> evals = new ArrayList<>();

         for (NamedExpression aggOrKey : aggregate.aggregates()) {
             if ((aggOrKey instanceof Alias alias && alias.child() instanceof AggregateFunction) == false) {
@@ -859,27 +895,24 @@ private LogicalPlan sampleCorrectedAggregateAndBuckets(
 
             // If the query is preserving all rows, and the aggregation function is
             // counting all rows, we know the exact result without sampling.
-            // TODO: COUNT("foobar"), which counts all rows, should also be detected.
-            // Note that this inserts a constant as an aggregation function. This
-            // works fine (empirically) even though it isn't an aggregation function.
-            // TODO: refactor into EVAL+PROJECT, instead of a constant aggregation.
             if (aggFn.equals(COUNT_ALL_ROWS)
                 && aggregate.groupings().isEmpty()
                 && queryProperties.canDecreaseRowCount == false
                 && queryProperties.canIncreaseRowCount == false) {
-                aggregates.add(aggAlias.replaceChild(Literal.fromLong(Source.EMPTY, sourceRowCount.get())));
+                evals.add(aggAlias.replaceChild(Literal.fromLong(Source.EMPTY, sourceRowCount.get())));
                 continue;
             }

             // Replace the original aggregation by a sample-corrected one if needed.
             if (SAMPLE_CORRECTED_AGGS.contains(aggFn.getClass()) == false) {
                 aggregates.add(aggAlias);
             } else {
-                Expression correctedAgg = correctForSampling(aggFn, sampleProbability);
-                aggregates.add(aggAlias.replaceChild(correctedAgg));
                 Alias uncorrectedAggAlias = new Alias(aggAlias.source(), aggAlias.name() + "$uncorrected", aggFn);
                 aggregates.add(uncorrectedAggAlias);
                 uncorrectedExpressions.put(aggAlias.id(), uncorrectedAggAlias);
+
+                Expression correctedAgg = correctForSampling(uncorrectedAggAlias.toAttribute(), sampleProbability);
+                evals.add(aggAlias.replaceChild(correctedAgg));
             }

             if (SUPPORTED_SINGLE_VALUED_AGGS.contains(aggFn.getClass())) {
@@ -901,34 +934,41 @@ private LogicalPlan sampleCorrectedAggregateAndBuckets(
                             Literal.integer(Source.EMPTY, bucketId)
                         )
                     );
-                    if (aggFn.equals(COUNT_ALL_ROWS)) {
-                        // For COUNT, no data should result in NULL, like in other aggregations.
-                        // Otherwise, the confidence interval computation breaks.
-                        bucket = new Case(
-                            Source.EMPTY,
-                            new Equals(Source.EMPTY, bucket, Literal.fromLong(Source.EMPTY, 0L)),
-                            List.of(Literal.NULL, bucket)
-                        );
-                    }
                     Alias bucketAlias = new Alias(Source.EMPTY, aggOrKey.name() + "$" + (trialId * BUCKET_COUNT + bucketId), bucket);
                     if (SAMPLE_CORRECTED_AGGS.contains(aggFn.getClass()) == false) {
                         buckets.add(bucketAlias);
                         aggregates.add(bucketAlias);
                     } else {
-                        Expression correctedBucket = correctForSampling(bucket, sampleProbability / BUCKET_COUNT);
-                        bucketAlias = bucketAlias.replaceChild(correctedBucket);
-                        buckets.add(bucketAlias);
-                        aggregates.add(bucketAlias);
                         Alias uncorrectedBucketAlias = new Alias(Source.EMPTY, bucketAlias.name() + "$uncorrected", bucket);
-                        uncorrectedExpressions.put(bucketAlias.id(), uncorrectedBucketAlias);
                         aggregates.add(uncorrectedBucketAlias);
+                        uncorrectedExpressions.put(bucketAlias.id(), uncorrectedBucketAlias);
+
+                        Expression uncorrectedBucket = uncorrectedBucketAlias.toAttribute();
+                        if (aggFn.equals(COUNT_ALL_ROWS)) {
+                            // For COUNT, no data should result in NULL, like in other aggregations.
+                            // Otherwise, the confidence interval computation breaks.
+                            uncorrectedBucket = new Case(
+                                Source.EMPTY,
+                                new Equals(Source.EMPTY, uncorrectedBucket, Literal.fromLong(Source.EMPTY, 0L)),
+                                List.of(Literal.NULL, uncorrectedBucket)
+                            );
+                        }
+
+                        Expression correctedBucket = correctForSampling(uncorrectedBucket, sampleProbability / BUCKET_COUNT);
+                        Alias correctedBucketAlias = bucketAlias.replaceChild(correctedBucket);
+                        evals.add(correctedBucketAlias);
+                        buckets.add(correctedBucketAlias);
                     }
                 }
             }
             fieldBuckets.put(aggOrKey.id(), buckets);
         }
     }

+        // Add the sample size per grouping to filter out groups with too few sampled rows.
+        Alias sampleSize = new Alias(Source.EMPTY, "$sample_size", COUNT_ALL_ROWS);
+        aggregates.add(sampleSize);
+
         // Add the bucket ID, do the aggregations (sample-corrected, including the buckets),
         // and filter out rows with too few sampled values.
         LogicalPlan plan = new Eval(Source.EMPTY, aggregate.child(), List.of(bucketIdField));
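In ESQL terms, the Case expression built above for COUNT buckets behaves like CASE(cnt == 0, NULL, cnt), with cnt standing for the uncorrected per-bucket count (the name is illustrative): an empty bucket yields NULL instead of 0, matching the behaviour of the other aggregations so that the downstream confidence-interval computation isn't fed spurious zero counts.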
@@ -942,6 +982,8 @@ private LogicalPlan sampleCorrectedAggregateAndBuckets(
                 Literal.integer(Source.EMPTY, MIN_ROW_COUNT_FOR_RESULT_INCLUSION)
             )
         );
+        plan = new Eval(Source.EMPTY, plan, evals);
+
         List<Attribute> keepAttributes = new ArrayList<>(plan.output());
         keepAttributes.remove(sampleSize.toAttribute());
         return new Project(Source.EMPTY, plan, keepAttributes);
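Putting this hunk together with the earlier ones, the rewritten sub-plan around a sampled STATS now has roughly the following shape (a sketch assembled from the pseudo-query and code above, using the field names that appear in the diff):

    EVAL $bucket_id
    | STATS <uncorrected aggs and buckets>, $sample_size = COUNT(*) BY <groups>
    | WHERE $sample_size >= MIN_ROW_COUNT_FOR_RESULT_INCLUSION
    | EVAL <sample corrections>
    | KEEP <everything except $sample_size>

The corrections thus run in an EVAL after the aggregation and the row-count filter, and the final Project drops the helper $sample_size column.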