Skip to content

Commit 9e0ce49

Browse files
committed
CSHARP-5529: Optimize grouping.First().X to not retrieve the entire $$ROOT
1 parent 6216289 commit 9e0ce49

File tree

5 files changed

+151
-26
lines changed

5 files changed

+151
-26
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs

+35
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,47 @@ public override AstNode VisitFilterField(AstFilterField node)
352352

353353
public override AstNode VisitGetFieldExpression(AstGetFieldExpression node)
354354
{
355+
// { $getField : { field : <elementField>, input : { $firstOrLast : "$_elements" } } } => { __agg0 : { $firstOrLast : <rootField> } } + "$__agg0"
356+
if (IsGetFieldChainOnFirstOrLastElement(node, out var firstOrLastOperator, out var rootFieldExpression))
357+
{
358+
var unaryAccumulatorOperator = firstOrLastOperator == AstUnaryOperator.First ? AstUnaryAccumulatorOperator.First : AstUnaryAccumulatorOperator.Last;
359+
var accumulatorExpression = AstExpression.UnaryAccumulator(unaryAccumulatorOperator, rootFieldExpression);
360+
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
361+
return AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
362+
}
363+
355364
if (node.FieldName.IsStringConstant("_elements"))
356365
{
357366
throw new UnableToRemoveReferenceToElementsException();
358367
}
359368

360369
return base.VisitGetFieldExpression(node);
370+
371+
bool IsGetFieldChainOnFirstOrLastElement(AstGetFieldExpression getFieldExpression, out AstUnaryOperator firstOrLastOperator, out AstExpression rootFieldExpression)
372+
{
373+
if (getFieldExpression.Input is AstGetFieldExpression innerGetFieldExpression &&
374+
IsGetFieldChainOnFirstOrLastElement(innerGetFieldExpression, out firstOrLastOperator, out rootFieldExpression))
375+
{
376+
rootFieldExpression = AstExpression.GetField(rootFieldExpression, getFieldExpression.FieldName);
377+
return true;
378+
}
379+
380+
if (getFieldExpression.Input is AstUnaryExpression unaryExpression &&
381+
unaryExpression.Operator is var unaryOperator &&
382+
(unaryOperator is AstUnaryOperator.First or AstUnaryOperator.Last) &&
383+
unaryExpression.Arg is AstGetFieldExpression innerMostGetFieldExpression &&
384+
innerMostGetFieldExpression.Input.IsRootVar() &&
385+
innerMostGetFieldExpression.FieldName.IsStringConstant("_elements"))
386+
{
387+
firstOrLastOperator = unaryOperator;
388+
rootFieldExpression = AstExpression.GetField(AstExpression.RootVar, getFieldExpression.FieldName);
389+
return true;
390+
}
391+
392+
firstOrLastOperator = default;
393+
rootFieldExpression = null;
394+
return false;
395+
}
361396
}
362397

363398
public override AstNode VisitMapExpression(AstMapExpression node)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Collections.Generic;
18+
using System.Linq;
19+
using MongoDB.Driver.TestHelpers;
20+
using FluentAssertions;
21+
using Xunit;
22+
23+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira;
24+
25+
public class CSharp5529Tests : LinqIntegrationTest<CSharp5529Tests.ClassFixture>
26+
{
27+
public CSharp5529Tests(ClassFixture fixture)
28+
: base(fixture)
29+
{
30+
}
31+
32+
[Theory]
33+
[InlineData(1, 1, """{ $group: { _id : 1, __agg0 : { $first : "$X" } } }""", 1)]
34+
[InlineData(1, 2, """{ $group: { _id : 1, __agg0 : { $last : "$X" } } }""", 2)]
35+
[InlineData(2, 1, """{ $group: { _id : 1, __agg0 : { $first : "$D.Y" } } }""", 11)]
36+
[InlineData(2, 2, """{ $group: { _id : 1, __agg0 : { $last : "$D.Y" } } }""", 22)]
37+
[InlineData(3, 1, """{ $group: { _id : 1, __agg0 : { $first : "$D.E.Z" } } }""", 111)]
38+
[InlineData(3, 2, """{ $group: { _id : 1, __agg0 : { $last : "$D.E.Z" } } }""", 222)]
39+
public void First_or_Last_optimization_should_work(int level, int firstOrLast, string expectedGroupStage, int expectedResult)
40+
{
41+
var collection = Fixture.Collection;
42+
43+
var queryable = (level, firstOrLast) switch
44+
{
45+
(1, 1) => collection.Aggregate().Group(x => 1, g => g.First().X),
46+
(1, 2) => collection.Aggregate().Group(x => 1, g => g.Last().X),
47+
(2, 1) => collection.Aggregate().Group(x => 1, g => g.First().D.Y),
48+
(2, 2) => collection.Aggregate().Group(x => 1, g => g.Last().D.Y),
49+
(3, 1) => collection.Aggregate().Group(x => 1, g => g.First().D.E.Z),
50+
(3, 2) => collection.Aggregate().Group(x => 1, g => g.Last().D.E.Z),
51+
_ => throw new ArgumentException()
52+
};
53+
54+
var stages = Translate(collection,queryable);
55+
AssertStages(
56+
stages,
57+
expectedGroupStage,
58+
"""{ $project : { _v : "$__agg0", _id : 0 } }""");
59+
60+
var result = queryable.Single();
61+
result.Should().Be(expectedResult);
62+
}
63+
public class C
64+
{
65+
public int Id { get; set; }
66+
public int X { get; set; }
67+
68+
public D D { get; set; }
69+
}
70+
71+
public class D
72+
{
73+
public E E { get; set; }
74+
public int Y { get; set; }
75+
}
76+
77+
public class E
78+
{
79+
public int Z { get; set; }
80+
}
81+
82+
public sealed class ClassFixture : MongoCollectionFixture<C>
83+
{
84+
protected override IEnumerable<C> InitialData =>
85+
[
86+
new C { Id = 1, X = 1, D = new D { E = new E { Z = 111 }, Y = 11 } },
87+
new C { Id = 2, X = 2, D = new D { E = new E { Z = 222 }, Y = 22 } },
88+
];
89+
}
90+
}

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/MongoQueryableTests.cs

+10-10
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,8 @@ public void GroupBy_select_anonymous_type_method()
411411

412412
Assert(query,
413413
2,
414-
"{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'} } }",
415-
"{ $project: { Key: '$_id', FirstB: '$__agg0.B', _id: 0 } }");
414+
"{ $group: { _id: '$A', __agg0: { $first: '$B'} } }",
415+
"{ $project: { Key: '$_id', FirstB: '$__agg0', _id: 0 } }");
416416

417417
query = CreateQuery()
418418
.GroupBy(x => x.A)
@@ -434,8 +434,8 @@ group p by p.A into g
434434

435435
Assert(query,
436436
2,
437-
"{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'} } }",
438-
"{ $project: { Key: '$_id', FirstB: '$__agg0.B', _id: 0 } }");
437+
"{ $group: { _id: '$A', __agg0: { $first: '$B'} } }",
438+
"{ $project: { Key: '$_id', FirstB: '$__agg0', _id: 0 } }");
439439

440440
query = from p in CreateQuery()
441441
group p by p.A into g
@@ -484,9 +484,9 @@ public void GroupBy_where_select_anonymous_type_with_duplicate_accumulators_meth
484484

485485
Assert(query,
486486
1,
487-
"{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'} } }",
487+
"{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'}, __agg1 : { $first : '$B' } } }",
488488
"{ $match: { '__agg0.B' : 'Balloon' } }",
489-
"{ $project: { Key: '$_id', FirstB: '$__agg0.B', _id: 0 } }");
489+
"{ $project: { Key: '$_id', FirstB: '$__agg1', _id: 0 } }");
490490

491491
query = CreateQuery()
492492
.GroupBy(x => x.A)
@@ -511,9 +511,9 @@ where g.First().B == "Balloon"
511511

512512
Assert(query,
513513
1,
514-
"{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'} } }",
514+
"{ $group: { _id: '$A', __agg0: { $first: '$$ROOT' }, __agg1 : { $first : '$B' } } }",
515515
"{ $match: { '__agg0.B' : 'Balloon' } }",
516-
"{ $project: { Key: '$_id', FirstB: '$__agg0.B', _id: 0 } }");
516+
"{ $project: { Key: '$_id', FirstB: '$__agg1', _id: 0 } }");
517517
}
518518
#endif
519519

@@ -525,8 +525,8 @@ public void GroupBy_with_resultSelector_anonymous_type_method()
525525

526526
Assert(query,
527527
2,
528-
"{ $group: { _id : '$A', __agg0 : { $first: '$$ROOT'} } }",
529-
"{ $project : { Key : '$_id', FirstB : '$__agg0.B', _id : 0 } }");
528+
"{ $group: { _id : '$A', __agg0 : { $first: '$B'} } }",
529+
"{ $project : { Key : '$_id', FirstB : '$__agg0', _id : 0 } }");
530530

531531
query = CreateQuery()
532532
.GroupBy(x => x.A, (k, s) => new { Key = k, FirstB = s.Select(x => x.B).First() });

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs

+12-12
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ public void Should_translate_using_non_anonymous_type_with_default_constructor()
3939

4040
AssertStages(
4141
result.Stages,
42-
"{ $group : { _id : '$A', __agg0 : { $first : '$$ROOT' } } }",
43-
"{ $project : { Property : '$_id', Field : '$__agg0.B', _id : 0 } }");
42+
"{ $group : { _id : '$A', __agg0 : { $first : '$B' } } }",
43+
"{ $project : { Property : '$_id', Field : '$__agg0', _id : 0 } }");
4444

4545
result.Value.Property.Should().Be("Amazing");
4646
result.Value.Field.Should().Be("Baby");
@@ -53,8 +53,8 @@ public void Should_translate_using_non_anonymous_type_with_parameterized_constru
5353

5454
AssertStages(
5555
result.Stages,
56-
"{ $group : { _id : '$A', __agg0 : { $first : '$$ROOT' } } }",
57-
"{ $project : { Property : '$_id', Field : '$__agg0.B', _id : 0 } }");
56+
"{ $group : { _id : '$A', __agg0 : { $first : '$B' } } }",
57+
"{ $project : { Property : '$_id', Field : '$__agg0', _id : 0 } }");
5858

5959
result.Value.Property.Should().Be("Amazing");
6060
result.Value.Field.Should().Be("Baby");
@@ -236,8 +236,8 @@ public void Should_translate_first_with_normalization()
236236

237237
AssertStages(
238238
result.Stages,
239-
"{ $group : { _id : '$A', __agg0 : { $first : '$$ROOT' } } }",
240-
"{ $project : { B : '$__agg0.B', _id : 0 } }");
239+
"{ $group : { _id : '$A', __agg0 : { $first : '$B' } } }",
240+
"{ $project : { B : '$__agg0', _id : 0 } }");
241241

242242
result.Value.B.Should().Be("Baby");
243243
}
@@ -262,8 +262,8 @@ public void Should_translate_last_with_normalization()
262262

263263
AssertStages(
264264
result.Stages,
265-
"{ $group : { _id : '$A', __agg0 : { $last : '$$ROOT' } } }",
266-
"{ $project : { B : '$__agg0.B', _id : 0 } }");
265+
"{ $group : { _id : '$A', __agg0 : { $last : '$B' } } }",
266+
"{ $project : { B : '$__agg0', _id : 0 } }");
267267

268268
result.Value.B.Should().Be("Baby");
269269
}
@@ -492,8 +492,8 @@ public void Should_translate_complex_selector()
492492
_id : '$A',
493493
__agg0 : { $sum : 1 },
494494
__agg1 : { $sum : { $add : ['$C.E.F', '$C.E.H'] } },
495-
__agg2 : { $first : '$$ROOT' },
496-
__agg3 : { $last : '$$ROOT' },
495+
__agg2 : { $first : '$B' },
496+
__agg3 : { $last : '$K' },
497497
__agg4 : { $min : { $add : ['$C.E.F', '$C.E.H'] } },
498498
__agg5 : { $max : { $add : ['$C.E.F', '$C.E.H'] } }
499499
}
@@ -503,8 +503,8 @@ public void Should_translate_complex_selector()
503503
$project : {
504504
Count : '$__agg0',
505505
Sum : '$__agg1',
506-
First : '$__agg2.B',
507-
Last : '$__agg3.K',
506+
First : '$__agg2',
507+
Last : '$__agg3',
508508
Min : '$__agg4',
509509
Max : '$__agg5',
510510
_id : 0

tests/MongoDB.Driver.Tests/Samples/AggregationSample.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ where g.Sum(x => x.Population) > 20000
108108
select new { State = g.Key, TotalPopulation = g.Sum(x => x.Population) };
109109

110110
var stages = Linq3TestHelpers.Translate(collection, queryable);
111-
var expectedStages =
111+
var expectedStages =
112112
new[]
113113
{
114114
"{ $group : { _id : '$state', __agg0 : { $sum : '$pop' } } }",
@@ -173,13 +173,13 @@ public async Task Largest_and_smallest_cities_by_state()
173173
.SortBy(x => x.State);
174174

175175
var pipelineTranslation = pipeline.ToString();
176-
var expectedTranslation =
176+
var expectedTranslation =
177177
"aggregate([" +
178178
"{ \"$group\" : { \"_id\" : { \"State\" : \"$state\", \"City\" : \"$city\" }, \"__agg0\" : { \"$sum\" : \"$pop\" } } }, " +
179179
"{ \"$project\" : { \"StateAndCity\" : \"$_id\", \"Population\" : \"$__agg0\", \"_id\" : 0 } }, " +
180180
"{ \"$sort\" : { \"Population\" : 1 } }, " +
181-
"{ \"$group\" : { \"_id\" : \"$StateAndCity.State\", \"__agg0\" : { \"$last\" : \"$$ROOT\" }, \"__agg1\" : { \"$first\" : \"$$ROOT\" } } }, " +
182-
"{ \"$project\" : { \"State\" : \"$_id\", \"BiggestCity\" : \"$__agg0.StateAndCity.City\", \"BiggestPopulation\" : \"$__agg0.Population\", \"SmallestCity\" : \"$__agg1.StateAndCity.City\", \"SmallestPopulation\" : \"$__agg1.Population\", \"_id\" : 0 } }, " +
181+
"{ \"$group\" : { \"_id\" : \"$StateAndCity.State\", \"__agg0\" : { \"$last\" : \"$StateAndCity.City\" }, \"__agg1\" : { \"$last\" : \"$Population\" }, \"__agg2\" : { \"$first\" : \"$StateAndCity.City\" }, \"__agg3\" : { \"$first\" : \"$Population\" } } }, " +
182+
"{ \"$project\" : { \"State\" : \"$_id\", \"BiggestCity\" : \"$__agg0\", \"BiggestPopulation\" : \"$__agg1\", \"SmallestCity\" : \"$__agg2\", \"SmallestPopulation\" : \"$__agg3\", \"_id\" : 0 } }, " +
183183
"{ \"$project\" : { \"State\" : \"$State\", \"BiggestCity\" : { \"Name\" : \"$BiggestCity\", \"Population\" : \"$BiggestPopulation\" }, \"SmallestCity\" : { \"Name\" : \"$SmallestCity\", \"Population\" : \"$SmallestPopulation\" }, \"_id\" : 0 } }, " +
184184
"{ \"$sort\" : { \"State\" : 1 } }])";
185185
pipelineTranslation.Should().Be(expectedTranslation);

0 commit comments

Comments
 (0)