|
5 | 5 |
|
6 | 6 | from bson import ObjectId |
7 | 7 | from django.db import DatabaseError |
8 | | -from django.db.models import Avg, F |
| 8 | +from django.db.models import Avg, F, Q |
9 | 9 |
|
10 | 10 | from django_mongodb_backend.fields import ( |
11 | 11 | EncryptedArrayField, |
@@ -94,18 +94,41 @@ def test_array(self): |
94 | 94 |
|
95 | 95 | class FieldTests(EncryptionTestCase): |
96 | 96 | def assertEquality(self, model_cls, val): |
97 | | - model_cls.objects.create(value=val) |
98 | | - fetched = model_cls.objects.get(value=val) |
99 | | - self.assertEqual(fetched.value, val) |
| 97 | + obj = model_cls.objects.create(value=val) |
| 98 | + self.assertEqual(model_cls.objects.get(value=val), obj) |
| 99 | + self.assertEqual(model_cls.objects.get(value__in=[val]), obj) |
| 100 | + self.assertQuerySetEqual(model_cls.objects.exclude(value=val), []) |
100 | 101 |
|
101 | 102 | def assertRange(self, model_cls, *, low, high, threshold): |
102 | | - model_cls.objects.create(value=low) |
103 | | - model_cls.objects.create(value=high) |
| 103 | + obj1 = model_cls.objects.create(value=low) |
| 104 | + obj2 = model_cls.objects.create(value=high) |
104 | 105 | self.assertEqual(model_cls.objects.get(value=low).value, low) |
105 | 106 | self.assertEqual(model_cls.objects.get(value=high).value, high) |
106 | | - objs = list(model_cls.objects.filter(value__gt=threshold)) |
107 | | - self.assertEqual(len(objs), 1) |
108 | | - self.assertEqual(objs[0].value, high) |
| 107 | + self.assertEqual(model_cls.objects.exclude(value=high).get().value, low) |
| 108 | + self.assertCountEqual(model_cls.objects.filter(Q(value=high) | Q(value=low)), [obj1, obj2]) |
| 109 | + self.assertQuerySetEqual( |
| 110 | + model_cls.objects.filter(value__gt=threshold), [high], attrgetter("value") |
| 111 | + ) |
| 112 | + self.assertQuerySetEqual( |
| 113 | + model_cls.objects.filter(value__gte=threshold), [high], attrgetter("value") |
| 114 | + ) |
| 115 | + self.assertQuerySetEqual( |
| 116 | + model_cls.objects.filter(value__lt=threshold), [low], attrgetter("value") |
| 117 | + ) |
| 118 | + self.assertQuerySetEqual( |
| 119 | + model_cls.objects.filter(value__lte=threshold), [low], attrgetter("value") |
| 120 | + ) |
| 121 | + self.assertQuerySetEqual( |
| 122 | + model_cls.objects.filter(value__in=[low]), [low], attrgetter("value") |
| 123 | + ) |
| 124 | + msg = ( |
| 125 | + "Comparison disallowed between Queryable Encryption encrypted " |
| 126 | + "fields and non-constant expressions; field 'value' is encrypted." |
| 127 | + ) |
| 128 | + with self.assertRaisesMessage(DatabaseError, msg): |
| 129 | + self.assertQuerySetEqual( |
| 130 | + model_cls.objects.filter(value__lte=F("value")), [low], attrgetter("value") |
| 131 | + ) |
109 | 132 |
|
110 | 133 | # Equality-only fields |
111 | 134 | def test_binary(self): |
|
0 commit comments