Skip to content

Commit 54c9b5c

Browse files
committed
Push conditions into lookup when is possible.
1 parent 215d3be commit 54c9b5c

File tree

4 files changed

+594
-20
lines changed

4 files changed

+594
-20
lines changed

django_mongodb_backend/compiler.py

Lines changed: 59 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@
99
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
1010
from django.db.models.functions.comparison import Coalesce
1111
from django.db.models.functions.math import Power
12-
from django.db.models.lookups import IsNull, Lookup
12+
from django.db.models.lookups import IsNull
1313
from django.db.models.sql import compiler
1414
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
1515
from django.db.models.sql.datastructures import BaseTable
16-
from django.db.models.sql.where import AND, WhereNode
16+
from django.db.models.sql.where import AND, OR, XOR, NothingNode, WhereNode
1717
from django.utils.functional import cached_property
1818
from pymongo import ASCENDING, DESCENDING
1919

2020
from .expressions.search import SearchExpression, SearchVector
2121
from .query import MongoQuery, wrap_database_errors
22-
from .query_utils import is_direct_value
22+
from .query_utils import is_constant_value, is_direct_value
2323

2424

2525
class SQLCompiler(compiler.SQLCompiler):
@@ -661,27 +661,70 @@ def get_combinator_queries(self):
661661
combinator_pipeline.append({"$unset": "_id"})
662662
return combinator_pipeline
663663

664+
def _get_pushable_conditions(self):
665+
"""
666+
Return a dict mapping each collection alias to the set of
667+
conditions that can be safely pushed down into its pipeline.
668+
"""
669+
670+
def collect_pushable(expr, negated=False):
671+
if expr is None or isinstance(expr, NothingNode):
672+
return {}
673+
if isinstance(expr, WhereNode):
674+
# Apply De Morgan: track negation so connectors are flipped when needed.
675+
negated ^= expr.negated
676+
pushable_expressions = [
677+
collect_pushable(sub_expr, negated=negated)
678+
for sub_expr in expr.children
679+
if sub_expr is not None
680+
]
681+
operator = expr.connector
682+
if operator == XOR:
683+
return {}
684+
if negated:
685+
operator = OR if operator == AND else AND
686+
alias_children = defaultdict(list)
687+
for pe in pushable_expressions:
688+
for alias, expressions in pe.items():
689+
alias_children[alias].append(expressions)
690+
# Build per-alias pushable condition nodes.
691+
if operator == AND:
692+
return {
693+
alias: WhereNode(children=children, negated=False, connector=operator)
694+
for alias, children in alias_children.items()
695+
}
696+
# Only aliases shared across all branches are pushable under OR.
697+
shared_alias = (
698+
set.intersection(*(set(pe) for pe in pushable_expressions))
699+
if pushable_expressions
700+
else set()
701+
)
702+
return {
703+
alias: WhereNode(children=children, negated=False, connector=operator)
704+
for alias, children in alias_children.items()
705+
if alias in shared_alias
706+
}
707+
# A leaf is pushable only when comparing a field to a constant/simple value.
708+
if isinstance(expr.lhs, Col) and (
709+
is_constant_value(expr.rhs) or getattr(expr.rhs, "is_simple_column", False)
710+
):
711+
alias = expr.lhs.alias
712+
expr = WhereNode(children=[expr], negated=negated)
713+
return {alias: expr}
714+
return {}
715+
716+
return collect_pushable(self.get_where())
717+
664718
def get_lookup_pipeline(self):
665719
result = []
666720
# To improve join performance, push conditions (filters) from the
667721
# WHERE ($match) clause to the JOIN ($lookup) clause.
668-
where = self.get_where()
669-
pushed_filters = defaultdict(list)
670-
for expr in where.children if where and where.connector == AND else ():
671-
# Push only basic lookups; no subqueries or complex conditions.
672-
# To avoid duplication across subqueries, only use the LHS target
673-
# table.
674-
if (
675-
isinstance(expr, Lookup)
676-
and isinstance(expr.lhs, Col)
677-
and (is_direct_value(expr.rhs) or isinstance(expr.rhs, (Value, Col)))
678-
):
679-
pushed_filters[expr.lhs.alias].append(expr)
722+
pushed_filters = self._get_pushable_conditions()
680723
for alias in tuple(self.query.alias_map):
681724
if not self.query.alias_refcount[alias] or self.collection_name == alias:
682725
continue
683726
result += self.query.alias_map[alias].as_mql(
684-
self, self.connection, WhereNode(pushed_filters[alias], connector=AND)
727+
self, self.connection, pushed_filters.get(alias)
685728
)
686729
return result
687730

django_mongodb_backend/query_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from django.core.exceptions import FullResultSet
2+
from django.db.models import F
23
from django.db.models.aggregates import Aggregate
34
from django.db.models.expressions import CombinedExpression, Func, Value
45
from django.db.models.sql.query import Query
@@ -67,7 +68,7 @@ def is_constant_value(value):
6768
else:
6869
constants_sub_expressions = True
6970
constants_sub_expressions = constants_sub_expressions and not (
70-
isinstance(value, Query)
71+
isinstance(value, Query | F)
7172
or value.contains_aggregate
7273
or value.contains_over_clause
7374
or value.contains_column_references

django_mongodb_backend/test.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,28 @@ class MongoTestCaseMixin:
77
maxDiff = None
88
query_types = {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}
99

10+
COMMUTATIVE_OPERATORS = {"$and", "$or", "$all"}
11+
12+
@staticmethod
13+
def _normalize_query(obj):
14+
if isinstance(obj, dict):
15+
normalized = {}
16+
for k, v in obj.items():
17+
if k in MongoTestCaseMixin.COMMUTATIVE_OPERATORS and isinstance(v, list):
18+
# Only sort for commutative operators
19+
normalized[k] = sorted(
20+
(MongoTestCaseMixin._normalize_query(i) for i in v), key=lambda x: str(x)
21+
)
22+
else:
23+
normalized[k] = MongoTestCaseMixin._normalize_query(v)
24+
return normalized
25+
26+
if isinstance(obj, list):
27+
# Lists not under commutative ops keep their order
28+
return [MongoTestCaseMixin._normalize_query(i) for i in obj]
29+
30+
return obj
31+
1032
def assertAggregateQuery(self, query, expected_collection, expected_pipeline):
1133
"""
1234
Assert that the logged query is equal to:
@@ -16,7 +38,14 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline):
1638
_, collection, operator = prefix.split(".")
1739
self.assertEqual(operator, "aggregate")
1840
self.assertEqual(collection, expected_collection)
19-
self.assertEqual(eval(pipeline[:-1], self.query_types, {}), expected_pipeline) # noqa: S307
41+
self.assertEqual(
42+
self._normalize_query(
43+
eval( # noqa: S307
44+
pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}
45+
)
46+
),
47+
self._normalize_query(expected_pipeline),
48+
)
2049

2150
def assertInsertQuery(self, query, expected_collection, expected_documents):
2251
"""

0 commit comments

Comments
 (0)