From 59dd379f4331ee36c14a1fc3679cb5b5367e4205 Mon Sep 17 00:00:00 2001
From: Konstantin Alekseev <konstantin@kotify.com>
Date: Sun, 23 Jun 2024 16:54:08 +0300
Subject: [PATCH] Fix unique together validator doesn't respect condition's
 fields

---
 rest_framework/compat.py      | 36 +++++++++++++++++
 rest_framework/serializers.py | 40 +++++++++++--------
 rest_framework/validators.py  | 30 +++++++++++---
 tests/test_validators.py      | 74 +++++++++++++++++++++++++++--------
 4 files changed, 140 insertions(+), 40 deletions(-)

diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 27c5632be5..ff21bacff4 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -3,6 +3,9 @@
 versions of Django/Python, and compatibility wrappers around optional packages.
 """
 import django
+from django.db import models
+from django.db.models.constants import LOOKUP_SEP
+from django.db.models.sql.query import Node
 from django.views.generic import View
 
 
@@ -157,6 +160,10 @@ def md_filter_add_syntax_highlight(md):
     #       1) the list of validators and 2) the error message. Starting from
     #       Django 5.1 ip_address_validators only returns the list of validators
     from django.core.validators import ip_address_validators
+
+    def get_referenced_base_fields_from_q(q):
+        return q.referenced_base_fields
+
 else:
     # Django <= 5.1: create a compatibility shim for ip_address_validators
     from django.core.validators import \
@@ -165,6 +172,35 @@ def md_filter_add_syntax_highlight(md):
     def ip_address_validators(protocol, unpack_ipv4):
         return _ip_address_validators(protocol, unpack_ipv4)[0]
 
+    # Django < 5.1: create a compatibility shim for Q.referenced_base_fields
+    # https://github.com/django/django/blob/5.1a1/django/db/models/query_utils.py#L179
+    def _get_paths_from_expression(expr):
+        if isinstance(expr, models.F):
+            yield expr.name
+        elif hasattr(expr, 'flatten'):
+            for child in expr.flatten():
+                if isinstance(child, models.F):
+                    yield child.name
+                elif isinstance(child, models.Q):
+                    yield from _get_children_from_q(child)
+
+    def _get_children_from_q(q):
+        for child in q.children:
+            if isinstance(child, Node):
+                yield from _get_children_from_q(child)
+            elif isinstance(child, tuple):
+                lhs, rhs = child
+                yield lhs
+                if hasattr(rhs, 'resolve_expression'):
+                    yield from _get_paths_from_expression(rhs)
+            elif hasattr(child, 'resolve_expression'):
+                yield from _get_paths_from_expression(child)
+
+    def get_referenced_base_fields_from_q(q):
+        return {
+            child.split(LOOKUP_SEP, 1)[0] for child in _get_children_from_q(q)
+        }
+
 
 # `separators` argument to `json.dumps()` differs between 2.x and 3.x
 # See: https://bugs.python.org/issue22767
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index f37bd3a3d6..0b87aa8fc1 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -26,7 +26,9 @@
 from django.utils.functional import cached_property
 from django.utils.translation import gettext_lazy as _
 
-from rest_framework.compat import postgres_fields
+from rest_framework.compat import (
+    get_referenced_base_fields_from_q, postgres_fields
+)
 from rest_framework.exceptions import ErrorDetail, ValidationError
 from rest_framework.fields import get_error_detail
 from rest_framework.settings import api_settings
@@ -1425,20 +1427,20 @@ def get_extra_kwargs(self):
 
     def get_unique_together_constraints(self, model):
         """
-        Returns iterator of (fields, queryset), each entry describes an unique together
-        constraint on `fields` in `queryset`.
+        Returns iterator of (fields, queryset, condition_fields, condition),
+        each entry describes an unique together constraint on `fields` in `queryset`
+        with respect of constraint's `condition`.
         """
         for parent_class in [model] + list(model._meta.parents):
             for unique_together in parent_class._meta.unique_together:
-                yield unique_together, model._default_manager
+                yield unique_together, model._default_manager, [], None
             for constraint in parent_class._meta.constraints:
                 if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
-                    yield (
-                        constraint.fields,
-                        model._default_manager
-                        if constraint.condition is None
-                        else model._default_manager.filter(constraint.condition)
-                    )
+                    if constraint.condition is None:
+                        condition_fields = []
+                    else:
+                        condition_fields = list(get_referenced_base_fields_from_q(constraint.condition))
+                    yield (constraint.fields, model._default_manager, condition_fields, constraint.condition)
 
     def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
         """
@@ -1470,9 +1472,10 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs
 
         # Include each of the `unique_together` and `UniqueConstraint` field names,
         # so long as all the field names are included on the serializer.
-        for unique_together_list, queryset in self.get_unique_together_constraints(model):
-            if set(field_names).issuperset(unique_together_list):
-                unique_constraint_names |= set(unique_together_list)
+        for unique_together_list, queryset, condition_fields, condition in self.get_unique_together_constraints(model):
+            unique_together_list_and_condition_fields = set(unique_together_list) | set(condition_fields)
+            if set(field_names).issuperset(unique_together_list_and_condition_fields):
+                unique_constraint_names |= unique_together_list_and_condition_fields
 
         # Now we have all the field names that have uniqueness constraints
         # applied, we can add the extra 'required=...' or 'default=...'
@@ -1594,12 +1597,13 @@ def get_unique_together_validators(self):
         # Note that we make sure to check `unique_together` both on the
         # base model class, but also on any parent classes.
         validators = []
-        for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
+        for unique_together, queryset, condition_fields, condition in self.get_unique_together_constraints(self.Meta.model):
             # Skip if serializer does not map to all unique together sources
-            if not set(source_map).issuperset(unique_together):
+            unique_together_and_condition_fields = set(unique_together) | set(condition_fields)
+            if not set(source_map).issuperset(unique_together_and_condition_fields):
                 continue
 
-            for source in unique_together:
+            for source in unique_together_and_condition_fields:
                 assert len(source_map[source]) == 1, (
                     "Unable to create `UniqueTogetherValidator` for "
                     "`{model}.{field}` as `{serializer}` has multiple "
@@ -1618,7 +1622,9 @@ def get_unique_together_validators(self):
             field_names = tuple(source_map[f][0] for f in unique_together)
             validator = UniqueTogetherValidator(
                 queryset=queryset,
-                fields=field_names
+                fields=field_names,
+                condition_fields=tuple(source_map[f][0] for f in condition_fields),
+                condition=condition,
             )
             validators.append(validator)
         return validators
diff --git a/rest_framework/validators.py b/rest_framework/validators.py
index 71ebc2ca9f..a152c6362f 100644
--- a/rest_framework/validators.py
+++ b/rest_framework/validators.py
@@ -6,7 +6,9 @@
 object creation, and makes it possible to switch between using the implicit
 `ModelSerializer` class and an equivalent explicit `Serializer` class.
 """
+from django.core.exceptions import FieldError
 from django.db import DataError
+from django.db.models import Exists
 from django.utils.translation import gettext_lazy as _
 
 from rest_framework.exceptions import ValidationError
@@ -23,6 +25,17 @@ def qs_exists(queryset):
         return False
 
 
+def qs_exists_with_condition(queryset, condition, against):
+    if condition is None:
+        return qs_exists(queryset)
+    try:
+        # use the same query as UniqueConstraint.validate
+        # https://github.com/django/django/blob/7ba2a0db20c37a5b1500434ca4ed48022311c171/django/db/models/constraints.py#L672
+        return (condition & Exists(queryset.filter(condition))).check(against)
+    except (TypeError, ValueError, DataError, FieldError):
+        return False
+
+
 def qs_filter(queryset, **kwargs):
     try:
         return queryset.filter(**kwargs)
@@ -99,10 +112,12 @@ class UniqueTogetherValidator:
     missing_message = _('This field is required.')
     requires_context = True
 
-    def __init__(self, queryset, fields, message=None):
+    def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None):
         self.queryset = queryset
         self.fields = fields
         self.message = message or self.message
+        self.condition_fields = [] if condition_fields is None else condition_fields
+        self.condition = condition
 
     def enforce_required_fields(self, attrs, serializer):
         """
@@ -114,7 +129,7 @@ def enforce_required_fields(self, attrs, serializer):
 
         missing_items = {
             field_name: self.missing_message
-            for field_name in self.fields
+            for field_name in (*self.fields, *self.condition_fields)
             if serializer.fields[field_name].source not in attrs
         }
         if missing_items:
@@ -173,16 +188,19 @@ def __call__(self, attrs, serializer):
                 if attrs[field_name] != getattr(serializer.instance, field_name)
             ]
 
-        if checked_values and None not in checked_values and qs_exists(queryset):
+        condition_kwargs = {source: attrs[source] for source in self.condition_fields}
+        if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs):
             field_names = ', '.join(self.fields)
             message = self.message.format(field_names=field_names)
             raise ValidationError(message, code='unique')
 
     def __repr__(self):
-        return '<%s(queryset=%s, fields=%s)>' % (
+        return '<{}({})>'.format(
             self.__class__.__name__,
-            smart_repr(self.queryset),
-            smart_repr(self.fields)
+            ', '.join(
+                f'{attr}={smart_repr(getattr(self, attr))}'
+                for attr in ('queryset', 'fields', 'condition')
+                if getattr(self, attr) is not None)
         )
 
     def __eq__(self, other):
diff --git a/tests/test_validators.py b/tests/test_validators.py
index 9c1a0eac31..5b6cd973ca 100644
--- a/tests/test_validators.py
+++ b/tests/test_validators.py
@@ -521,7 +521,7 @@ class UniqueConstraintModel(models.Model):
     race_name = models.CharField(max_length=100)
     position = models.IntegerField()
     global_id = models.IntegerField()
-    fancy_conditions = models.IntegerField(null=True)
+    fancy_conditions = models.IntegerField()
 
     class Meta:
         constraints = [
@@ -543,7 +543,12 @@ class Meta:
                 name="unique_constraint_model_together_uniq",
                 fields=('race_name', 'position'),
                 condition=models.Q(race_name='example'),
-            )
+            ),
+            models.UniqueConstraint(
+                name='unique_constraint_model_together_uniq2',
+                fields=('race_name', 'position'),
+                condition=models.Q(fancy_conditions__gte=10),
+            ),
         ]
 
 
@@ -576,17 +581,20 @@ def setUp(self):
         self.instance = UniqueConstraintModel.objects.create(
             race_name='example',
             position=1,
-            global_id=1
+            global_id=1,
+            fancy_conditions=1
         )
         UniqueConstraintModel.objects.create(
             race_name='example',
             position=2,
-            global_id=2
+            global_id=2,
+            fancy_conditions=1
         )
         UniqueConstraintModel.objects.create(
             race_name='other',
             position=1,
-            global_id=3
+            global_id=3,
+            fancy_conditions=1
         )
 
     def test_repr(self):
@@ -601,22 +609,55 @@ def test_repr(self):
                 position = IntegerField\(.*required=True\)
                 global_id = IntegerField\(.*validators=\[<UniqueValidator\(queryset=UniqueConstraintModel.objects.all\(\)\)>\]\)
                 class Meta:
-                    validators = \[<UniqueTogetherValidator\(queryset=<QuerySet \[<UniqueConstraintModel: UniqueConstraintModel object \(1\)>, <UniqueConstraintModel: UniqueConstraintModel object \(2\)>\]>, fields=\('race_name', 'position'\)\)>\]
+                    validators = \[<UniqueTogetherValidator\(queryset=UniqueConstraintModel.objects.all\(\), fields=\('race_name', 'position'\), condition=<Q: \(AND: \('race_name', 'example'\)\)>\)>\]
         """)
         assert re.search(expected, repr(serializer)) is not None
 
-    def test_unique_together_field(self):
+    def test_unique_together_condition(self):
         """
-        UniqueConstraint fields and condition attributes must be passed
-        to UniqueTogetherValidator as fields and queryset
+        Fields used in UniqueConstraint's condition must be included
+        into queryset existence check
         """
-        serializer = UniqueConstraintSerializer()
-        assert len(serializer.validators) == 1
-        validator = serializer.validators[0]
-        assert validator.fields == ('race_name', 'position')
-        assert set(validator.queryset.values_list(flat=True)) == set(
-            UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
+        UniqueConstraintModel.objects.create(
+            race_name='condition',
+            position=1,
+            global_id=10,
+            fancy_conditions=10,
         )
+        serializer = UniqueConstraintSerializer(data={
+            'race_name': 'condition',
+            'position': 1,
+            'global_id': 11,
+            'fancy_conditions': 9,
+        })
+        assert serializer.is_valid()
+        serializer = UniqueConstraintSerializer(data={
+            'race_name': 'condition',
+            'position': 1,
+            'global_id': 11,
+            'fancy_conditions': 11,
+        })
+        assert not serializer.is_valid()
+
+    def test_unique_together_condition_fields_required(self):
+        """
+        Fields used in UniqueConstraint's condition must be present in serializer
+        """
+        serializer = UniqueConstraintSerializer(data={
+            'race_name': 'condition',
+            'position': 1,
+            'global_id': 11,
+        })
+        assert not serializer.is_valid()
+        assert serializer.errors == {'fancy_conditions': ['This field is required.']}
+
+        class NoFieldsSerializer(serializers.ModelSerializer):
+            class Meta:
+                model = UniqueConstraintModel
+                fields = ('race_name', 'position', 'global_id')
+
+        serializer = NoFieldsSerializer()
+        assert len(serializer.validators) == 1
 
     def test_single_field_uniq_validators(self):
         """
@@ -625,9 +666,8 @@ def test_single_field_uniq_validators(self):
         """
         # Django 5 includes Max and Min values validators for IntergerField
         extra_validators_qty = 2 if django_version[0] >= 5 else 0
-        #
         serializer = UniqueConstraintSerializer()
-        assert len(serializer.validators) == 1
+        assert len(serializer.validators) == 2
         validators = serializer.fields['global_id'].validators
         assert len(validators) == 1 + extra_validators_qty
         assert validators[0].queryset == UniqueConstraintModel.objects