Skip to content

Commit 0733fda

Browse files
committed
fix less than lookup on encrypted fields
1 parent 89322a2 commit 0733fda

File tree

3 files changed

+61
-14
lines changed

3 files changed

+61
-14
lines changed

django_mongodb_backend/lookups.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
BuiltinLookup,
55
FieldGetDbPrepValueIterableMixin,
66
IsNull,
7+
LessThan,
8+
LessThanOrEqual,
79
Lookup,
810
PatternLookup,
911
UUIDTextMixin,
@@ -101,6 +103,26 @@ def is_null_path(self, compiler, connection):
101103
return connection.mongo_operators["isnull"](lhs_mql, self.rhs)
102104

103105

106+
def less_than_path(self, compiler, connection):
107+
lhs_mql = process_lhs(self, compiler, connection)
108+
value = process_rhs(self, compiler, connection)
109+
# Encrypted fields don't support null and Automatic Encryption cannot
110+
# handle it ("csfle "analyze_query" failed: typenull type isn't supported
111+
# for the range encrypted index.), so omit the null check.
112+
if getattr(self.lhs.output_field, "encrypted", False):
113+
return {lhs_mql: {"$lt": value}}
114+
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
115+
116+
117+
def less_than_or_equal_path(self, compiler, connection):
118+
lhs_mql = process_lhs(self, compiler, connection)
119+
value = process_rhs(self, compiler, connection)
120+
# Same comment as less_than_path.
121+
if getattr(self.lhs.output_field, "encrypted", False):
122+
return {lhs_mql: {"$lte": value}}
123+
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
124+
125+
104126
# from https://www.pcre.org/current/doc/html/pcre2pattern.html#SEC4
105127
REGEX_MATCH_ESCAPE_CHARS = (
106128
("\\", r"\\"), # general escape character
@@ -157,6 +179,8 @@ def register_lookups():
157179
In.get_subquery_wrapping_pipeline = get_subquery_wrapping_pipeline
158180
IsNull.as_mql_expr = is_null_expr
159181
IsNull.as_mql_path = is_null_path
182+
LessThan.as_mql_path = less_than_path
183+
LessThanOrEqual.as_mql_path = less_than_or_equal_path
160184
Lookup.can_use_path = lookup_can_use_path
161185
PatternLookup.prep_lookup_value_mongo = pattern_lookup_prep_lookup_value
162186
UUIDTextMixin.as_mql = uuid_text_mixin

docs/ref/models/encrypted-fields.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,12 @@ The :ref:`available query types <manual:qe-fundamentals-encrypt-query>` depend
6565
on your version of MongoDB. For example, in MongoDB 8.0, the supported types
6666
are ``equality`` and ``range``.
6767

68-
.. admonition:: Query types vs. Django lookups
68+
The supported lookups for ``equality`` queries are: :lookup:`exact` and
69+
lookup:`in`. The supported operators are AND (``&``) and OR (``|``).
6970

70-
Range queries in Queryable Encryption are different from Django's
71-
:ref:`range lookups <django:field-lookups>`. Range queries allow you to
72-
perform comparisons on encrypted fields, while Django's range lookups are
73-
used for filtering based on a range of values.
71+
The supported lookups for ``range`` queries include those of ``equality``
72+
queries as well as :lookup:`lt`, :lookup:`lte`, :lookup:`gt`, and
73+
:lookup:`gte`.
7474

7575
\* These fields don't support the ``queries`` argument:
7676

tests/encryption_/test_fields.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from bson import ObjectId
77
from django.db import DatabaseError
8-
from django.db.models import Avg, F
8+
from django.db.models import Avg, F, Q
99

1010
from django_mongodb_backend.fields import (
1111
EncryptedArrayField,
@@ -94,18 +94,41 @@ def test_array(self):
9494

9595
class FieldTests(EncryptionTestCase):
9696
def assertEquality(self, model_cls, val):
97-
model_cls.objects.create(value=val)
98-
fetched = model_cls.objects.get(value=val)
99-
self.assertEqual(fetched.value, val)
97+
obj = model_cls.objects.create(value=val)
98+
self.assertEqual(model_cls.objects.get(value=val), obj)
99+
self.assertEqual(model_cls.objects.get(value__in=[val]), obj)
100+
self.assertQuerySetEqual(model_cls.objects.exclude(value=val), [])
100101

101102
def assertRange(self, model_cls, *, low, high, threshold):
102-
model_cls.objects.create(value=low)
103-
model_cls.objects.create(value=high)
103+
obj1 = model_cls.objects.create(value=low)
104+
obj2 = model_cls.objects.create(value=high)
104105
self.assertEqual(model_cls.objects.get(value=low).value, low)
105106
self.assertEqual(model_cls.objects.get(value=high).value, high)
106-
objs = list(model_cls.objects.filter(value__gt=threshold))
107-
self.assertEqual(len(objs), 1)
108-
self.assertEqual(objs[0].value, high)
107+
self.assertEqual(model_cls.objects.exclude(value=high).get().value, low)
108+
self.assertCountEqual(model_cls.objects.filter(Q(value=high) | Q(value=low)), [obj1, obj2])
109+
self.assertQuerySetEqual(
110+
model_cls.objects.filter(value__gt=threshold), [high], attrgetter("value")
111+
)
112+
self.assertQuerySetEqual(
113+
model_cls.objects.filter(value__gte=threshold), [high], attrgetter("value")
114+
)
115+
self.assertQuerySetEqual(
116+
model_cls.objects.filter(value__lt=threshold), [low], attrgetter("value")
117+
)
118+
self.assertQuerySetEqual(
119+
model_cls.objects.filter(value__lte=threshold), [low], attrgetter("value")
120+
)
121+
self.assertQuerySetEqual(
122+
model_cls.objects.filter(value__in=[low]), [low], attrgetter("value")
123+
)
124+
msg = (
125+
"Comparison disallowed between Queryable Encryption encrypted "
126+
"fields and non-constant expressions; field 'value' is encrypted."
127+
)
128+
with self.assertRaisesMessage(DatabaseError, msg):
129+
self.assertQuerySetEqual(
130+
model_cls.objects.filter(value__lte=F("value")), [low], attrgetter("value")
131+
)
109132

110133
# Equality-only fields
111134
def test_binary(self):

0 commit comments

Comments
 (0)