diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8fe284bc84..de737fc3b9 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1569,6 +1569,17 @@ def get_validators(self): self.get_unique_for_date_validators() ) + def _get_constraint_violation_error_message(self, constraint): + """ + Returns the violation error message for the UniqueConstraint, + or None if the message is the default. + """ + violation_error_message = constraint.get_violation_error_message() + default_error_message = constraint.default_violation_error_message % {"name": constraint.name} + if violation_error_message == default_error_message: + return None + return violation_error_message + def get_unique_together_validators(self): """ Determine a default set of validators for any unique_together constraints. @@ -1595,6 +1606,13 @@ def get_unique_together_validators(self): for name, source in field_sources.items(): source_map[source].append(name) + unique_constraint_by_fields = { + constraint.fields: constraint + for model_cls in (self.Meta.model, *self.Meta.model._meta.parents) + for constraint in model_cls._meta.constraints + if isinstance(constraint, models.UniqueConstraint) + } + # Note that we make sure to check `unique_together` both on the # base model class, but also on any parent classes. validators = [] @@ -1621,11 +1639,17 @@ def get_unique_together_validators(self): ) field_names = tuple(source_map[f][0] for f in unique_together) + + constraint = unique_constraint_by_fields.get(tuple(unique_together)) + violation_error_message = self._get_constraint_violation_error_message(constraint) if constraint else None + validator = UniqueTogetherValidator( queryset=queryset, fields=field_names, condition_fields=tuple(source_map[f][0] for f in condition_fields), condition=condition, + message=violation_error_message, + code=getattr(constraint, 'violation_error_code', None), ) validators.append(validator) return validators diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 76d2a2159f..cc759b39cc 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -111,13 +111,15 @@ class UniqueTogetherValidator: message = _('The fields {field_names} must make a unique set.') missing_message = _('This field is required.') requires_context = True + code = 'unique' - def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None): + def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None, code=None): self.queryset = queryset self.fields = fields self.message = message or self.message self.condition_fields = [] if condition_fields is None else condition_fields self.condition = condition + self.code = code or self.code def enforce_required_fields(self, attrs, serializer): """ @@ -198,7 +200,7 @@ def __call__(self, attrs, serializer): if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs): field_names = ', '.join(self.fields) message = self.message.format(field_names=field_names) - raise ValidationError(message, code='unique') + raise ValidationError(message, code=self.code) def __repr__(self): return '<{}({})>'.format( @@ -217,6 +219,7 @@ def __eq__(self, other): and self.missing_message == other.missing_message and self.queryset == other.queryset and self.fields == other.fields + and self.code == other.code ) diff --git a/tests/test_validators.py b/tests/test_validators.py index ea5bf3a4dd..79d4c0cf81 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -616,6 +616,26 @@ class Meta: ] +class UniqueConstraintCustomMessageCodeModel(models.Model): + username = models.CharField(max_length=32) + company_id = models.IntegerField() + role = models.CharField(max_length=32) + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=("username", "company_id"), + name="unique_username_company_custom_msg", + violation_error_message="Username must be unique within a company.", + **(dict(violation_error_code="duplicate_username") if django_version[0] >= 5 else {}), + ), + models.UniqueConstraint( + fields=("company_id", "role"), + name="unique_company_role_default_msg", + ), + ] + + class UniqueConstraintSerializer(serializers.ModelSerializer): class Meta: model = UniqueConstraintModel @@ -628,6 +648,12 @@ class Meta: fields = ('title', 'age', 'tag') +class UniqueConstraintCustomMessageCodeSerializer(serializers.ModelSerializer): + class Meta: + model = UniqueConstraintCustomMessageCodeModel + fields = ('username', 'company_id', 'role') + + class TestUniqueConstraintValidation(TestCase): def setUp(self): self.instance = UniqueConstraintModel.objects.create( @@ -778,6 +804,31 @@ class Meta: ) assert serializer.is_valid() + def test_unique_constraint_custom_message_code(self): + UniqueConstraintCustomMessageCodeModel.objects.create(username="Alice", company_id=1, role="member") + expected_code = "duplicate_username" if django_version[0] >= 5 else UniqueTogetherValidator.code + + serializer = UniqueConstraintCustomMessageCodeSerializer(data={ + "username": "Alice", + "company_id": 1, + "role": "admin", + }) + assert not serializer.is_valid() + assert serializer.errors == {"non_field_errors": ["Username must be unique within a company."]} + assert serializer.errors["non_field_errors"][0].code == expected_code + + def test_unique_constraint_default_message_code(self): + UniqueConstraintCustomMessageCodeModel.objects.create(username="Alice", company_id=1, role="member") + serializer = UniqueConstraintCustomMessageCodeSerializer(data={ + "username": "John", + "company_id": 1, + "role": "member", + }) + expected_message = UniqueTogetherValidator.message.format(field_names=', '.join(("company_id", "role"))) + assert not serializer.is_valid() + assert serializer.errors == {"non_field_errors": [expected_message]} + assert serializer.errors["non_field_errors"][0].code == UniqueTogetherValidator.code + # Tests for `UniqueForDateValidator` # ----------------------------------