Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 8 additions & 15 deletions django_mongodb_backend/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,7 @@
MONGO_AGGREGATIONS = {Count: "sum"}


def aggregate(
self,
compiler,
connection,
operator=None,
resolve_inner_expression=False,
**extra_context, # noqa: ARG001
):
def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False):
if self.filter:
node = self.copy()
node.filter = None
Expand All @@ -31,7 +24,7 @@ def aggregate(
return {f"${operator}": lhs_mql}


def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
def count(self, compiler, connection, resolve_inner_expression=False):
"""
When resolve_inner_expression=True, return the MQL that resolves as a
value. This is used to count different elements, so the inner values are
Expand Down Expand Up @@ -64,16 +57,16 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co
return {"$add": [{"$size": lhs_mql}, exits_null]}


def stddev_variance(self, compiler, connection, **extra_context):
def stddev_variance(self, compiler, connection):
if self.function.endswith("_SAMP"):
operator = "stdDevSamp"
elif self.function.endswith("_POP"):
operator = "stdDevPop"
return aggregate(self, compiler, connection, operator=operator, **extra_context)
return aggregate(self, compiler, connection, operator=operator)


def register_aggregates():
Aggregate.as_mql = aggregate
Count.as_mql = count
StdDev.as_mql = stddev_variance
Variance.as_mql = stddev_variance
Aggregate.as_mql_expr = aggregate
Count.as_mql_expr = count
StdDev.as_mql_expr = stddev_variance
Variance.as_mql_expr = stddev_variance
58 changes: 50 additions & 8 deletions django_mongodb_backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .query_utils import regex_match
from .query_utils import regex_expr, regex_match
from .schema import DatabaseSchemaEditor
from .utils import OperationDebugWrapper
from .validation import DatabaseValidation
Expand Down Expand Up @@ -108,7 +108,12 @@ def _isnull_operator(a, b):
}
return is_null if b else {"$not": is_null}

mongo_operators = {
def _isnull_operator_match(a, b):
if b:
return {"$or": [{a: {"$exists": False}}, {a: None}]}
return {"$and": [{a: {"$exists": True}}, {a: {"$ne": None}}]}

mongo_operators_expr = {
"exact": lambda a, b: {"$eq": [a, b]},
"gt": lambda a, b: {"$gt": [a, b]},
"gte": lambda a, b: {"$gte": [a, b]},
Expand All @@ -118,19 +123,56 @@ def _isnull_operator(a, b):
"lte": lambda a, b: {
"$and": [{"$lte": [a, b]}, DatabaseWrapper._isnull_operator(a, False)]
},
"in": lambda a, b: {"$in": [a, b]},
"in": lambda a, b: {"$in": (a, b)},
"isnull": _isnull_operator,
"range": lambda a, b: {
"$and": [
{"$or": [DatabaseWrapper._isnull_operator(b[0], True), {"$gte": [a, b[0]]}]},
{"$or": [DatabaseWrapper._isnull_operator(b[1], True), {"$lte": [a, b[1]]}]},
]
},
"iexact": lambda a, b: regex_match(a, ("^", b, {"$literal": "$"}), insensitive=True),
"startswith": lambda a, b: regex_match(a, ("^", b)),
"istartswith": lambda a, b: regex_match(a, ("^", b), insensitive=True),
"endswith": lambda a, b: regex_match(a, (b, {"$literal": "$"})),
"iendswith": lambda a, b: regex_match(a, (b, {"$literal": "$"}), insensitive=True),
"iexact": lambda a, b: regex_expr(a, ("^", b, {"$literal": "$"}), insensitive=True),
"startswith": lambda a, b: regex_expr(a, ("^", b)),
"istartswith": lambda a, b: regex_expr(a, ("^", b), insensitive=True),
"endswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"})),
"iendswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"}), insensitive=True),
"contains": lambda a, b: regex_expr(a, b),
"icontains": lambda a, b: regex_expr(a, b, insensitive=True),
"regex": lambda a, b: regex_expr(a, b),
"iregex": lambda a, b: regex_expr(a, b, insensitive=True),
}

def range_match(a, b):
## TODO: MAKE A TEST TO TEST WHEN BOTH ENDS ARE NONE. WHAT SHALL I RETURN?
conditions = []
if b[0] is not None:
conditions.append({a: {"$gte": b[0]}})
if b[1] is not None:
conditions.append({a: {"$lte": b[1]}})
if not conditions:
return {"$literal": True}
return {"$and": conditions}

mongo_operators_match = {
"exact": lambda a, b: {a: b},
"gt": lambda a, b: {a: {"$gt": b}},
"gte": lambda a, b: {a: {"$gte": b}},
# MongoDB considers null less than zero. Exclude null values to match
# SQL behavior.
"lt": lambda a, b: {
"$and": [{a: {"$lt": b}}, DatabaseWrapper._isnull_operator_match(a, False)]
},
"lte": lambda a, b: {
"$and": [{a: {"$lte": b}}, DatabaseWrapper._isnull_operator_match(a, False)]
},
"in": lambda a, b: {a: {"$in": tuple(b)}},
"isnull": _isnull_operator_match,
"range": range_match,
"iexact": lambda a, b: regex_match(a, f"^{b}$", insensitive=True),
"startswith": lambda a, b: regex_match(a, f"^{b}"),
"istartswith": lambda a, b: regex_match(a, f"^{b}", insensitive=True),
"endswith": lambda a, b: regex_match(a, f"{b}$"),
"iendswith": lambda a, b: regex_match(a, f"{b}$", insensitive=True),
"contains": lambda a, b: regex_match(a, b),
"icontains": lambda a, b: regex_match(a, b, insensitive=True),
"regex": lambda a, b: regex_match(a, b),
Expand Down
16 changes: 9 additions & 7 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,14 +327,14 @@ def pre_sql_setup(self, with_col_aliases=False):
pipeline = self._build_aggregation_pipeline(ids, group)
if self.having:
having = self.having.replace_expressions(all_replacements).as_mql(
self, self.connection
self, self.connection, as_path=True
)
# Add HAVING subqueries.
for query in self.subqueries or ():
pipeline.extend(query.get_pipeline())
# Remove the added subqueries.
self.subqueries = []
pipeline.append({"$match": {"$expr": having}})
pipeline.append({"$match": having})
self.aggregation_pipeline = pipeline
self.annotations = {
target: expr.replace_expressions(all_replacements)
Expand Down Expand Up @@ -481,11 +481,11 @@ def build_query(self, columns=None):
query.lookup_pipeline = self.get_lookup_pipeline()
where = self.get_where()
try:
expr = where.as_mql(self, self.connection) if where else {}
expr = where.as_mql(self, self.connection, as_path=True) if where else {}
except FullResultSet:
query.match_mql = {}
else:
query.match_mql = {"$expr": expr}
query.match_mql = expr
if extra_fields:
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
query.subqueries = self.subqueries
Expand Down Expand Up @@ -643,7 +643,9 @@ def get_combinator_queries(self):
for alias, expr in self.columns:
# Unfold foreign fields.
if isinstance(expr, Col) and expr.alias != self.collection_name:
ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection)
ids[expr.alias][expr.target.column] = expr.as_mql(
self, self.connection, as_path=False
)
else:
ids[alias] = f"${alias}"
# Convert defaultdict to dict so it doesn't appear as
Expand Down Expand Up @@ -714,9 +716,9 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
value = (
False if empty_result_set_value is NotImplemented else empty_result_set_value
)
fields[collection][name] = Value(value).as_mql(self, self.connection)
fields[collection][name] = Value(value).as_mql(self, self.connection, as_path=False)
except FullResultSet:
fields[collection][name] = Value(True).as_mql(self, self.connection)
fields[collection][name] = Value(True).as_mql(self, self.connection, as_path=False)
# Annotations (stored in None) and the main collection's fields
# should appear in the top-level of the fields dict.
fields.update(fields.pop(None, {}))
Expand Down
74 changes: 47 additions & 27 deletions django_mongodb_backend/expressions/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db import NotSupportedError
from django.db.models.expressions import (
BaseExpression,
Case,
Col,
ColPairs,
Expand Down Expand Up @@ -33,7 +34,7 @@ def case(self, compiler, connection):
for case in self.cases:
case_mql = {}
try:
case_mql["case"] = case.as_mql(compiler, connection)
case_mql["case"] = case.as_mql(compiler, connection, as_path=False)
except EmptyResultSet:
continue
except FullResultSet:
Expand Down Expand Up @@ -76,34 +77,34 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
return f"{prefix}{self.target.column}"


def col_pairs(self, compiler, connection):
def col_pairs(self, compiler, connection, as_path=False):
cols = self.get_cols()
if len(cols) > 1:
raise NotSupportedError("ColPairs is not supported.")
return cols[0].as_mql(compiler, connection)
return cols[0].as_mql(compiler, connection, as_path=as_path)


def combined_expression(self, compiler, connection):
def combined_expression(self, compiler, connection, as_path=False):
expressions = [
self.lhs.as_mql(compiler, connection),
self.rhs.as_mql(compiler, connection),
self.lhs.as_mql(compiler, connection, as_path=as_path),
self.rhs.as_mql(compiler, connection, as_path=as_path),
]
return connection.ops.combine_expression(self.connector, expressions)


def expression_wrapper(self, compiler, connection):
return self.expression.as_mql(compiler, connection)
def expression_wrapper_expr(self, compiler, connection):
return self.expression.as_mql(compiler, connection, as_path=False)


def negated_expression(self, compiler, connection):
return {"$not": expression_wrapper(self, compiler, connection)}
def negated_expression_expr(self, compiler, connection):
return {"$not": expression_wrapper_expr(self, compiler, connection)}


def order_by(self, compiler, connection):
return self.expression.as_mql(compiler, connection)


def query(self, compiler, connection, get_wrapping_pipeline=None):
def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
subquery_compiler = self.get_compiler(connection=connection)
subquery_compiler.pre_sql_setup(with_col_aliases=False)
field_name, expr = subquery_compiler.columns[0]
Expand Down Expand Up @@ -145,14 +146,16 @@ def query(self, compiler, connection, get_wrapping_pipeline=None):
# Erase project_fields since the required value is projected above.
subquery.project_fields = None
compiler.subqueries.append(subquery)
if as_path:
return f"{table_output}.{field_name}"
return f"${table_output}.{field_name}"


def raw_sql(self, compiler, connection): # noqa: ARG001
raise NotSupportedError("RawSQL is not supported on MongoDB.")


def ref(self, compiler, connection): # noqa: ARG001
def ref(self, compiler, connection, as_path=False): # noqa: ARG001
prefix = (
f"{self.source.alias}."
if isinstance(self.source, Col) and self.source.alias != compiler.collection_name
Expand All @@ -162,32 +165,36 @@ def ref(self, compiler, connection): # noqa: ARG001
refs, _ = compiler.columns[self.ordinal - 1]
else:
refs = self.refs
return f"${prefix}{refs}"
if not as_path:
prefix = f"${prefix}"
return f"{prefix}{refs}"


def star(self, compiler, connection): # noqa: ARG001
def star(self, compiler, connection, as_path=False): # noqa: ARG001
return {"$literal": True}


def subquery(self, compiler, connection, get_wrapping_pipeline=None):
return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
return self.query.as_mql(
compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_path=False
)


def exists(self, compiler, connection, get_wrapping_pipeline=None):
try:
lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
except EmptyResultSet:
return Value(False).as_mql(compiler, connection)
return connection.mongo_operators["isnull"](lhs_mql, False)
return connection.mongo_operators_expr["isnull"](lhs_mql, False)


def when(self, compiler, connection):
return self.condition.as_mql(compiler, connection)
def when(self, compiler, connection, as_path=False):
return self.condition.as_mql(compiler, connection, as_path=as_path)


def value(self, compiler, connection): # noqa: ARG001
def value(self, compiler, connection, as_path=False): # noqa: ARG001
value = self.value
if isinstance(value, (list, int)):
if isinstance(value, (list, int)) and not as_path:
# Wrap lists & numbers in $literal to prevent ambiguity when Value
# appears in $project.
return {"$literal": value}
Expand All @@ -209,21 +216,34 @@ def value(self, compiler, connection): # noqa: ARG001
return value


def base_expression(self, compiler, connection, as_path=False, **extra):
if (
as_path
and hasattr(self, "as_mql_path")
and getattr(self, "is_simple_expression", lambda: False)()
):
return self.as_mql_path(compiler, connection, **extra)

expr = self.as_mql_expr(compiler, connection, **extra)
return {"$expr": expr} if as_path else expr


def register_expressions():
Case.as_mql = case
Case.as_mql_expr = case
Col.as_mql = col
ColPairs.as_mql = col_pairs
CombinedExpression.as_mql = combined_expression
Exists.as_mql = exists
CombinedExpression.as_mql_expr = combined_expression
Exists.as_mql_expr = exists
ExpressionList.as_mql = process_lhs
ExpressionWrapper.as_mql = expression_wrapper
NegatedExpression.as_mql = negated_expression
OrderBy.as_mql = order_by
ExpressionWrapper.as_mql_expr = expression_wrapper_expr
NegatedExpression.as_mql_expr = negated_expression_expr
OrderBy.as_mql_expr = order_by
Query.as_mql = query
RawSQL.as_mql = raw_sql
Ref.as_mql = ref
ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql
Star.as_mql = star
Subquery.as_mql = subquery
Subquery.as_mql_expr = subquery
When.as_mql = when
Value.as_mql = value
BaseExpression.as_mql = base_expression
10 changes: 6 additions & 4 deletions django_mongodb_backend/expressions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,10 +933,12 @@ def __str__(self):
def __repr__(self):
return f"SearchText({self.lhs}, {self.rhs})"

def as_mql(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
value = process_rhs(self, compiler, connection)
return {"$gte": [lhs_mql, value]}
def as_mql(self, compiler, connection, as_path=False):
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
value = process_rhs(self, compiler, connection, as_path=as_path)
if as_path:
return {lhs_mql: {"$gte": value}}
return {"$expr": {"$gte": [lhs_mql, value]}}


CharField.register_lookup(SearchTextLookup)
Expand Down
3 changes: 0 additions & 3 deletions django_mongodb_backend/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ class DatabaseFeatures(GISFeatures, BaseDatabaseFeatures):
"auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key",
# GenericRelation.value_to_string() assumes integer pk.
"contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string",
# icontains doesn't work on ArrayField:
# Unsupported conversion from array to string in $convert
"model_fields_.test_arrayfield.QueryingTests.test_icontains",
# ArrayField's contained_by lookup crashes with Exists: "both operands "
# of $setIsSubset must be arrays. Second argument is of type: null"
# https://jira.mongodb.org/browse/SERVER-99186
Expand Down
Loading