Skip to content

Commit 8a38cbb

Browse files
authored
[ES|QL] Query approximation: don't run full LogicalPlanOptimizer twice (#141814)
* don't run optimizer twice * CSV test for COUNT(*) * add mini-optimizer * polish * more polish
1 parent 9d155cd commit 8a38cbb

File tree

5 files changed

+101
-61
lines changed

5 files changed

+101
-61
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/resources/approximation.csv-spec

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@
1414
// which would make the tests flaky.
1515

1616

17+
total count
18+
required_capability: approximation
19+
20+
SET approximation={"rows":10000}\;
21+
FROM many_numbers | STATS count=COUNT(*)
22+
;
23+
24+
count:long
25+
200100
26+
;
27+
28+
1729
approximate stats on large data
1830
required_capability: approximation
1931

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximation/Approximation.java

Lines changed: 83 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
1212
import org.apache.lucene.util.SetOnce;
13+
import org.elasticsearch.TransportVersion;
1314
import org.elasticsearch.action.ActionListener;
1415
import org.elasticsearch.common.util.set.Sets;
1516
import org.elasticsearch.compute.data.LongBlock;
@@ -52,6 +53,7 @@
5253
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
5354
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
5455
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
56+
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
5557
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
5658
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
5759
import org.elasticsearch.xpack.esql.plan.logical.ChangePoint;
@@ -324,10 +326,10 @@ public record QueryProperties(boolean hasGrouping, boolean canDecreaseRowCount,
324326
private final EsqlExecutionInfo executionInfo;
325327
private final QueryProperties queryProperties;
326328
private final EsqlSession.PlanRunner runner;
327-
private final LogicalPlanOptimizer logicalPlanOptimizer;
328329
private final Function<LogicalPlan, PhysicalPlan> toPhysicalPlan;
329330
private final Configuration configuration;
330331
private final FoldContext foldContext;
332+
private final TransportVersion minimumVersion;
331333
private final PlanTimeProfile planTimeProfile;
332334

333335
private final SetOnce<Long> sourceRowCount;
@@ -337,22 +339,22 @@ public Approximation(
337339
LogicalPlan logicalPlan,
338340
ApproximationSettings settings,
339341
EsqlExecutionInfo executionInfo,
340-
LogicalPlanOptimizer logicalPlanOptimizer,
341342
Function<LogicalPlan, PhysicalPlan> toPhysicalPlan,
342343
EsqlSession.PlanRunner runner,
343344
Configuration configuration,
344345
FoldContext foldContext,
346+
TransportVersion minimumVersion,
345347
PlanTimeProfile planTimeProfile
346348
) {
347349
this.logicalPlan = logicalPlan;
348350
this.settings = settings;
349351
this.executionInfo = executionInfo;
350352
this.queryProperties = verifyPlan(logicalPlan);
351-
this.logicalPlanOptimizer = logicalPlanOptimizer;
352353
this.toPhysicalPlan = toPhysicalPlan;
353354
this.runner = runner;
354355
this.configuration = configuration;
355356
this.foldContext = foldContext;
357+
this.minimumVersion = minimumVersion;
356358
this.planTimeProfile = planTimeProfile;
357359

358360
sourceRowCount = new SetOnce<>();
@@ -503,8 +505,8 @@ private LogicalPlan sourceCountPlan() {
503505
List.of(),
504506
List.of(new Alias(Source.EMPTY, "$count", COUNT_ALL_ROWS))
505507
);
506-
sourceCountPlan.setPreOptimized();
507-
return logicalPlanOptimizer.optimize(sourceCountPlan);
508+
sourceCountPlan.setOptimized();
509+
return sourceCountPlan;
508510
}
509511

510512
/**
@@ -587,8 +589,8 @@ private LogicalPlan countPlan(double sampleProbability) {
587589
return plan;
588590
});
589591

590-
countPlan.setPreOptimized();
591-
return logicalPlanOptimizer.optimize(countPlan);
592+
countPlan.setOptimized();
593+
return countPlan;
592594
}
593595

594596
/**
@@ -748,7 +750,7 @@ private LogicalPlan approximationPlan(double sampleProbability) {
748750
return exactPlanWithConfidenceIntervals();
749751
}
750752

751-
logger.debug("generating approximate plan (p=[{}])", sampleProbability);
753+
logger.debug("generating approximation plan (p=[{}])", sampleProbability);
752754

753755
// Whether or not the first STATS command has been encountered yet.
754756
Holder<Boolean> encounteredStats = new Holder<>(false);
@@ -766,7 +768,7 @@ private LogicalPlan approximationPlan(double sampleProbability) {
766768
// is rewritten to AVG::double = SUM::double / COUNT::long.
767769
Map<NameId, NamedExpression> uncorrectedExpressions = new HashMap<>();
768770

769-
LogicalPlan approximatePlan = logicalPlan.transformUp(plan -> {
771+
LogicalPlan approximationPlan = logicalPlan.transformUp(plan -> {
770772
if (plan instanceof LeafPlan) {
771773
// The leaf plan should be appended by a SAMPLE.
772774
return new Sample(Source.EMPTY, Literal.fromDouble(Source.EMPTY, sampleProbability), plan);
@@ -789,22 +791,49 @@ private LogicalPlan approximationPlan(double sampleProbability) {
789791
});
790792

791793
// Add the confidence intervals for all fields with buckets.
792-
approximatePlan = new Eval(Source.EMPTY, approximatePlan, getConfidenceIntervals(fieldBuckets));
794+
approximationPlan = new Eval(Source.EMPTY, approximationPlan, getConfidenceIntervals(fieldBuckets));
793795

794796
// Drop all bucket fields and uncorrected fields from the output.
795797
Set<Attribute> dropAttributes = Stream.concat(
796798
fieldBuckets.values().stream().flatMap(List::stream),
797799
uncorrectedExpressions.values().stream()
798800
).map(NamedExpression::toAttribute).collect(Collectors.toSet());
799801

800-
List<Attribute> keepAttributes = new ArrayList<>(approximatePlan.output());
802+
List<Attribute> keepAttributes = new ArrayList<>(approximationPlan.output());
801803
keepAttributes.removeAll(dropAttributes);
802-
approximatePlan = new Project(Source.EMPTY, approximatePlan, keepAttributes);
804+
approximationPlan = new Project(Source.EMPTY, approximationPlan, keepAttributes);
805+
approximationPlan = optimize(approximationPlan);
806+
logger.debug("approximation plan (after):\n{}", approximationPlan);
807+
return approximationPlan;
808+
}
803809

804-
approximatePlan.setPreOptimized();
805-
approximatePlan = logicalPlanOptimizer.optimize(approximatePlan);
806-
logger.debug("approximate plan:\n{}", approximatePlan);
807-
return approximatePlan;
810+
/**
811+
* Optimizes the plan by running just the operator batch. This is primarily
812+
* to prune unnecessary columns generated in the approximation plan.
813+
* <p>
814+
* These unnecessary columns are generated in various ways, for example:
815+
* - STATS AVG(x): the AVG is rewritten via a surrogate to SUM and COUNT.
816+
* Both the corrected and uncorrected sum and count column are generated,
817+
* but the corrected ones aren't needed for AVG(x) = SUM(x)/COUNT().
818+
* - STATS x=COUNT() | EVAL x=TO_STRING(x): bucket columns are generated
819+
* for the numeric count, but are not needed after the string conversion.
820+
* <p>
821+
* Note this is running the operator batch on top of a plan that it normally
822+
* doesn't run on (namely a cleaned-up plan, which can contain a TopN). It
823+
* seems like that works (at least for everything we've tested), but this
824+
* process has poor test coverage.
825+
* TODO: refactor so that approximation ideally happens halfway-through
826+
* optimization, before cleanup step. Possibly make approximation an
827+
* optimization rule itself in the substitutions batch.
828+
*/
829+
private LogicalPlan optimize(LogicalPlan plan) {
830+
LogicalPlanOptimizer optimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration, foldContext, minimumVersion)) {
831+
@Override
832+
protected List<Batch<LogicalPlan>> batches() {
833+
return List.of(operators());
834+
}
835+
};
836+
return optimizer.optimize(plan);
808837
}
809838

810839
/**
@@ -819,12 +848,13 @@ private LogicalPlan approximationPlan(double sampleProbability) {
819848
* <pre>
820849
* {@code
821850
* STATS sampleSize = COUNT(*),
822-
* s = SUM(x) / prob,
823-
* `s$0` = SUM(x) / (prob/B)) WHERE MV_SLICE(bucketId, 0, 0) == 0
851+
* s = SUM(x),
852+
* `s$0` = SUM(x) WHERE MV_SLICE(bucketId, 0, 0) == 0
824853
* ...,
825854
* `s$T*B-1` = SUM(x) / (prob/B) WHERE MV_SLICE(bucketId, T-1, T-1) == B-1
826855
* BY group
827856
* | WHERE sampleSize >= MIN_ROW_COUNT_FOR_RESULT_INCLUSION
857+
* | EVAL s = s / prob, `s$0` = `s$0` / (prob/B), `s$T*B-1` = `s$T*B-1` / (prob/B)
828858
* | DROP sampleSize
829859
* }
830860
* </pre>
@@ -843,9 +873,15 @@ private LogicalPlan sampleCorrectedAggregateAndBuckets(
843873
// TODO: use theoretically non-conflicting names.
844874
Alias bucketIdField = new Alias(Source.EMPTY, "$bucket_id", bucketIds);
845875

876+
// The aggregate functions in the approximation plan.
846877
List<NamedExpression> aggregates = new ArrayList<>();
847-
Alias sampleSize = new Alias(Source.EMPTY, "$sample_size", COUNT_ALL_ROWS);
848-
aggregates.add(sampleSize);
878+
879+
// List of expressions that must be evaluated after the sampled aggregation.
880+
// These consist of:
881+
// - sample corrections (to correct counts/sums for sampling)
882+
// - replace zero counts by NULLs (for confidence interval computation)
883+
// - exact total row count if COUNT(*) is used (to avoid sampling errors there)
884+
List<Alias> evals = new ArrayList<>();
849885

850886
for (NamedExpression aggOrKey : aggregate.aggregates()) {
851887
if ((aggOrKey instanceof Alias alias && alias.child() instanceof AggregateFunction) == false) {
@@ -859,27 +895,24 @@ private LogicalPlan sampleCorrectedAggregateAndBuckets(
859895

860896
// If the query is preserving all rows, and the aggregation function is
861897
// counting all rows, we know the exact result without sampling.
862-
// TODO: COUNT("foobar"), which counts all rows, should also be detected.
863-
// Note that this inserts a constant as an aggregation function. This
864-
// works fine (empirically) even though it isn't an aggregation function.
865-
// TODO: refactor into EVAL+PROJECT, instead of a constant aggregation.
866898
if (aggFn.equals(COUNT_ALL_ROWS)
867899
&& aggregate.groupings().isEmpty()
868900
&& queryProperties.canDecreaseRowCount == false
869901
&& queryProperties.canIncreaseRowCount == false) {
870-
aggregates.add(aggAlias.replaceChild(Literal.fromLong(Source.EMPTY, sourceRowCount.get())));
902+
evals.add(aggAlias.replaceChild(Literal.fromLong(Source.EMPTY, sourceRowCount.get())));
871903
continue;
872904
}
873905

874906
// Replace the original aggregation by a sample-corrected one if needed.
875907
if (SAMPLE_CORRECTED_AGGS.contains(aggFn.getClass()) == false) {
876908
aggregates.add(aggAlias);
877909
} else {
878-
Expression correctedAgg = correctForSampling(aggFn, sampleProbability);
879-
aggregates.add(aggAlias.replaceChild(correctedAgg));
880910
Alias uncorrectedAggAlias = new Alias(aggAlias.source(), aggAlias.name() + "$uncorrected", aggFn);
881911
aggregates.add(uncorrectedAggAlias);
882912
uncorrectedExpressions.put(aggAlias.id(), uncorrectedAggAlias);
913+
914+
Expression correctedAgg = correctForSampling(uncorrectedAggAlias.toAttribute(), sampleProbability);
915+
evals.add(aggAlias.replaceChild(correctedAgg));
883916
}
884917

885918
if (SUPPORTED_SINGLE_VALUED_AGGS.contains(aggFn.getClass())) {
@@ -901,34 +934,41 @@ private LogicalPlan sampleCorrectedAggregateAndBuckets(
901934
Literal.integer(Source.EMPTY, bucketId)
902935
)
903936
);
904-
if (aggFn.equals(COUNT_ALL_ROWS)) {
905-
// For COUNT, no data should result in NULL, like in other aggregations.
906-
// Otherwise, the confidence interval computation breaks.
907-
bucket = new Case(
908-
Source.EMPTY,
909-
new Equals(Source.EMPTY, bucket, Literal.fromLong(Source.EMPTY, 0L)),
910-
List.of(Literal.NULL, bucket)
911-
);
912-
}
913937
Alias bucketAlias = new Alias(Source.EMPTY, aggOrKey.name() + "$" + (trialId * BUCKET_COUNT + bucketId), bucket);
914938
if (SAMPLE_CORRECTED_AGGS.contains(aggFn.getClass()) == false) {
915939
buckets.add(bucketAlias);
916940
aggregates.add(bucketAlias);
917941
} else {
918-
Expression correctedBucket = correctForSampling(bucket, sampleProbability / BUCKET_COUNT);
919-
bucketAlias = bucketAlias.replaceChild(correctedBucket);
920-
buckets.add(bucketAlias);
921-
aggregates.add(bucketAlias);
922942
Alias uncorrectedBucketAlias = new Alias(Source.EMPTY, bucketAlias.name() + "$uncorrected", bucket);
923-
uncorrectedExpressions.put(bucketAlias.id(), uncorrectedBucketAlias);
924943
aggregates.add(uncorrectedBucketAlias);
944+
uncorrectedExpressions.put(bucketAlias.id(), uncorrectedBucketAlias);
945+
946+
Expression uncorrectedBucket = uncorrectedBucketAlias.toAttribute();
947+
if (aggFn.equals(COUNT_ALL_ROWS)) {
948+
// For COUNT, no data should result in NULL, like in other aggregations.
949+
// Otherwise, the confidence interval computation breaks.
950+
uncorrectedBucket = new Case(
951+
Source.EMPTY,
952+
new Equals(Source.EMPTY, uncorrectedBucket, Literal.fromLong(Source.EMPTY, 0L)),
953+
List.of(Literal.NULL, uncorrectedBucket)
954+
);
955+
}
956+
957+
Expression correctedBucket = correctForSampling(uncorrectedBucket, sampleProbability / BUCKET_COUNT);
958+
Alias correctedBucketAlias = bucketAlias.replaceChild(correctedBucket);
959+
evals.add(correctedBucketAlias);
960+
buckets.add(correctedBucketAlias);
925961
}
926962
}
927963
}
928964
fieldBuckets.put(aggOrKey.id(), buckets);
929965
}
930966
}
931967

968+
// Add the sample size per grouping to filter out groups with too few sampled rows.
969+
Alias sampleSize = new Alias(Source.EMPTY, "$sample_size", COUNT_ALL_ROWS);
970+
aggregates.add(sampleSize);
971+
932972
// Add the bucket ID, do the aggregations (sampled corrected, including the buckets),
933973
// and filter out rows with too few sampled values.
934974
LogicalPlan plan = new Eval(Source.EMPTY, aggregate.child(), List.of(bucketIdField));
@@ -942,6 +982,8 @@ private LogicalPlan sampleCorrectedAggregateAndBuckets(
942982
Literal.integer(Source.EMPTY, MIN_ROW_COUNT_FOR_RESULT_INCLUSION)
943983
)
944984
);
985+
plan = new Eval(Source.EMPTY, plan, evals);
986+
945987
List<Attribute> keepAttributes = new ArrayList<>(plan.output());
946988
keepAttributes.remove(sampleSize.toAttribute());
947989
return new Project(Source.EMPTY, plan, keepAttributes);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
import org.elasticsearch.xpack.esql.analysis.UnmappedResolution;
5353
import org.elasticsearch.xpack.esql.analysis.Verifier;
5454
import org.elasticsearch.xpack.esql.approximation.Approximation;
55-
import org.elasticsearch.xpack.esql.approximation.ApproximationSettings;
5655
import org.elasticsearch.xpack.esql.core.expression.Attribute;
5756
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
5857
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
@@ -245,7 +244,6 @@ public void execute(
245244
ZoneId timeZone = request.timeZone() == null
246245
? statement.setting(QuerySettings.TIME_ZONE)
247246
: statement.settingOrDefault(QuerySettings.TIME_ZONE, request.timeZone());
248-
ApproximationSettings approximationSettings = statement.setting(QuerySettings.APPROXIMATION);
249247

250248
Configuration configuration = new Configuration(
251249
timeZone,
@@ -326,7 +324,6 @@ public void onResponse(Versioned<LogicalPlan> analyzedPlan) {
326324
foldContext,
327325
minimumVersion,
328326
planTimeProfile,
329-
logicalPlanOptimizer,
330327
l
331328
)
332329
)
@@ -363,7 +360,6 @@ public void executeOptimizedPlan(
363360
FoldContext foldContext,
364361
TransportVersion minimumVersion,
365362
PlanTimeProfile planTimeProfile,
366-
LogicalPlanOptimizer logicalPlanOptimizer,
367363
ActionListener<Result> listener
368364
) {
369365
assert ThreadPool.assertCurrentThreadPool(
@@ -396,11 +392,11 @@ public void executeOptimizedPlan(
396392
optimizedPlan,
397393
configuration,
398394
foldContext,
395+
minimumVersion,
399396
planRunner,
400397
executionInfo,
401398
request,
402399
statement,
403-
logicalPlanOptimizer,
404400
physicalPlanOptimizer,
405401
planTimeProfile,
406402
listener
@@ -412,11 +408,11 @@ private void executeSubPlans(
412408
LogicalPlan optimizedPlan,
413409
Configuration configuration,
414410
FoldContext foldContext,
411+
TransportVersion minimumVersion,
415412
PlanRunner runner,
416413
EsqlExecutionInfo executionInfo,
417414
EsqlQueryRequest request,
418415
EsqlStatement statement,
419-
LogicalPlanOptimizer logicalPlanOptimizer,
420416
PhysicalPlanOptimizer physicalPlanOptimizer,
421417
PlanTimeProfile planTimeProfile,
422418
ActionListener<Result> listener
@@ -448,17 +444,11 @@ private void executeSubPlans(
448444
optimizedPlan,
449445
statement.setting(QuerySettings.APPROXIMATION),
450446
executionInfo,
451-
logicalPlanOptimizer,
452-
p -> logicalPlanToPhysicalPlan(
453-
// TODO: don't run the full optimizer twice, because it may break things.
454-
optimizedPlan(p, logicalPlanOptimizer, planTimeProfile),
455-
request,
456-
physicalPlanOptimizer,
457-
planTimeProfile
458-
),
447+
p -> logicalPlanToPhysicalPlan(p, request, physicalPlanOptimizer, planTimeProfile),
459448
runner,
460449
configuration,
461450
foldContext,
451+
minimumVersion,
462452
planTimeProfile
463453
).approximate(listener);
464454
} else {

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -782,7 +782,6 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception {
782782
foldCtx,
783783
minimumVersion,
784784
planTimeProfile,
785-
logicalPlanOptimizer,
786785
listener.delegateFailureAndWrap(
787786
// Wrap so we can capture the warnings in the calling thread
788787
(next, result) -> next.onResponse(

0 commit comments

Comments
 (0)