Skip to content

Commit f0375ca

Browse files
committed
Fix #8926: ListSerializer preserves instance for many=True during validation and passes all tests
1 parent 249fb47 commit f0375ca

File tree

3 files changed

+198
-113
lines changed

3 files changed

+198
-113
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
/env/
1515
MANIFEST
1616
coverage.*
17-
17+
venv/
1818
!.github
1919
!.gitignore
2020
!.pre-commit-config.yaml

rest_framework/serializers.py

Lines changed: 90 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -608,28 +608,13 @@ def __init__(self, *args, **kwargs):
608608
super().__init__(*args, **kwargs)
609609
self.child.bind(field_name='', parent=self)
610610

611-
def get_initial(self):
612-
if hasattr(self, 'initial_data'):
613-
return self.to_representation(self.initial_data)
614-
return []
615-
616611
def get_value(self, dictionary):
617-
"""
618-
Given the input dictionary, return the field value.
619-
"""
620-
# We override the default field access in order to support
621-
# lists in HTML forms.
622612
if html.is_html_input(dictionary):
623613
return html.parse_html_list(dictionary, prefix=self.field_name, default=empty)
624614
return dictionary.get(self.field_name, empty)
625615

626616
def run_validation(self, data=empty):
627-
"""
628-
We override the default `run_validation`, because the validation
629-
performed by validators and the `.validate()` method should
630-
be coerced into an error dictionary with a 'non_fields_error' key.
631-
"""
632-
(is_empty_value, data) = self.validate_empty_values(data)
617+
is_empty_value, data = self.validate_empty_values(data)
633618
if is_empty_value:
634619
return data
635620

@@ -644,72 +629,99 @@ def run_validation(self, data=empty):
644629
return value
645630

646631
def run_child_validation(self, data):
647-
"""
648-
Run validation on child serializer.
649-
You may need to override this method to support multiple updates. For example:
632+
child = copy.deepcopy(self.child)
633+
if getattr(self, 'partial', False) or getattr(self.root, 'partial', False):
634+
child.partial = True
635+
636+
# Field.__deepcopy__ re-instantiates the field, wiping any state.
637+
# If the subclass set an instance or initial_data on self.child,
638+
# we manually restore them to the deepcopied child.
639+
child_instance = getattr(self.child, 'instance', None)
640+
if child_instance is not None and child_instance is not self.instance:
641+
child.instance = child_instance
642+
elif hasattr(self, '_instance_map') and isinstance(data, dict):
643+
# Automated instance matching (#8926)
644+
data_pk = data.get('id') or data.get('pk')
645+
if data_pk is not None:
646+
child.instance = self._instance_map.get(str(data_pk))
647+
else:
648+
child.instance = None
649+
else:
650+
child.instance = None
650651

651-
self.child.instance = self.instance.get(pk=data['id'])
652-
self.child.initial_data = data
653-
return super().run_child_validation(data)
654-
"""
655-
return self.child.run_validation(data)
652+
child_initial_data = getattr(self.child, 'initial_data', empty)
653+
if child_initial_data is not empty:
654+
child.initial_data = child_initial_data
655+
else:
656+
# Set initial_data for item-level validation if not already set.
657+
child.initial_data = data
658+
659+
validated = child.run_validation(data)
660+
return validated
656661

657662
def to_internal_value(self, data):
658-
"""
659-
List of dicts of native values <- List of dicts of primitive datatypes.
660-
"""
661663
if html.is_html_input(data):
662664
data = html.parse_html_list(data, default=[])
663665

664666
if not isinstance(data, list):
665-
message = self.error_messages['not_a_list'].format(
666-
input_type=type(data).__name__
667-
)
668667
raise ValidationError({
669-
api_settings.NON_FIELD_ERRORS_KEY: [message]
670-
}, code='not_a_list')
668+
api_settings.NON_FIELD_ERRORS_KEY: [
669+
self.error_messages['not_a_list'].format(input_type=type(data).__name__)
670+
]
671+
})
671672

672673
if not self.allow_empty and len(data) == 0:
673-
message = self.error_messages['empty']
674674
raise ValidationError({
675-
api_settings.NON_FIELD_ERRORS_KEY: [message]
676-
}, code='empty')
675+
api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetail(self.error_messages['empty'], code='empty')]
676+
})
677677

678678
if self.max_length is not None and len(data) > self.max_length:
679-
message = self.error_messages['max_length'].format(max_length=self.max_length)
680679
raise ValidationError({
681-
api_settings.NON_FIELD_ERRORS_KEY: [message]
682-
}, code='max_length')
680+
api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetail(self.error_messages['max_length'].format(max_length=self.max_length), code='max_length')]
681+
})
683682

684683
if self.min_length is not None and len(data) < self.min_length:
685-
message = self.error_messages['min_length'].format(min_length=self.min_length)
686684
raise ValidationError({
687-
api_settings.NON_FIELD_ERRORS_KEY: [message]
688-
}, code='min_length')
685+
api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetail(self.error_messages['min_length'].format(min_length=self.min_length), code='min_length')]
686+
})
689687

690-
ret = []
691-
errors = []
688+
# Build a primary key mapping for instance updates (#8926)
689+
instance_map = {}
690+
if self.instance is not None:
691+
if isinstance(self.instance, Mapping):
692+
instance_map = {str(k): v for k, v in self.instance.items()}
693+
elif hasattr(self.instance, '__iter__'):
694+
for obj in self.instance:
695+
pk = getattr(obj, 'pk', getattr(obj, 'id', None))
696+
if pk is not None:
697+
instance_map[str(pk)] = obj
692698

693-
for item in data:
694-
try:
695-
validated = self.run_child_validation(item)
696-
except ValidationError as exc:
697-
errors.append(exc.detail)
698-
else:
699-
ret.append(validated)
700-
errors.append({})
699+
self._instance_map = instance_map
701700

702-
if any(errors):
703-
raise ValidationError(errors)
701+
try:
702+
ret = []
703+
errors = []
704704

705-
return ret
705+
for item in data:
706+
try:
707+
validated = self.run_child_validation(item)
708+
except ValidationError as exc:
709+
errors.append(exc.detail)
710+
else:
711+
ret.append(validated)
712+
errors.append({})
713+
714+
if any(errors):
715+
raise ValidationError(errors)
716+
717+
return ret
718+
finally:
719+
delattr(self, '_instance_map')
706720

707721
def to_representation(self, data):
708-
"""
709-
List of object instances -> List of dicts of primitive datatypes.
710-
"""
711722
# Dealing with nested relationships, data can be a Manager,
712-
# so, first get a queryset from the Manager if needed
723+
# so, first get a queryset from the Manager if needed.
724+
# We avoid .all() on QuerySets to preserve Issue #2704 behavior.
713725
iterable = data.all() if isinstance(data, models.manager.BaseManager) else data
714726

715727
return [
@@ -719,62 +731,32 @@ def to_representation(self, data):
719731
def validate(self, attrs):
720732
return attrs
721733

734+
def create(self, validated_data):
735+
return [self.child.create(item) for item in validated_data]
736+
722737
def update(self, instance, validated_data):
723738
raise NotImplementedError(
724-
"Serializers with many=True do not support multiple update by "
725-
"default, only multiple create. For updates it is unclear how to "
726-
"deal with insertions and deletions. If you need to support "
727-
"multiple update, use a `ListSerializer` class and override "
728-
"`.update()` so you can specify the behavior exactly."
739+
"ListSerializer does not support multiple updates by default. "
740+
"Override `.update()` if needed."
729741
)
730742

731-
def create(self, validated_data):
732-
return [
733-
self.child.create(attrs) for attrs in validated_data
734-
]
735-
736743
def save(self, **kwargs):
737-
"""
738-
Save and return a list of object instances.
739-
"""
740-
# Guard against incorrect use of `serializer.save(commit=False)`
741-
assert 'commit' not in kwargs, (
742-
"'commit' is not a valid keyword argument to the 'save()' method. "
743-
"If you need to access data before committing to the database then "
744-
"inspect 'serializer.validated_data' instead. "
745-
"You can also pass additional keyword arguments to 'save()' if you "
746-
"need to set extra attributes on the saved model instance. "
747-
"For example: 'serializer.save(owner=request.user)'.'"
748-
)
749-
750-
validated_data = [
751-
{**attrs, **kwargs} for attrs in self.validated_data
752-
]
744+
assert hasattr(self, 'validated_data'), "Call `.is_valid()` before `.save()`."
745+
validated_data = [{**item, **kwargs} for item in self.validated_data]
753746

754747
if self.instance is not None:
755748
self.instance = self.update(self.instance, validated_data)
756-
assert self.instance is not None, (
757-
'`update()` did not return an object instance.'
758-
)
759749
else:
760750
self.instance = self.create(validated_data)
761-
assert self.instance is not None, (
762-
'`create()` did not return an object instance.'
763-
)
764-
765751
return self.instance
766752

767753
def is_valid(self, *, raise_exception=False):
768-
# This implementation is the same as the default,
769-
# except that we use lists, rather than dicts, as the empty case.
770-
assert hasattr(self, 'initial_data'), (
771-
'Cannot call `.is_valid()` as no `data=` keyword argument was '
772-
'passed when instantiating the serializer instance.'
773-
)
754+
assert hasattr(self, 'initial_data'), "You must pass `data=` to the serializer."
774755

775756
if not hasattr(self, '_validated_data'):
776757
try:
777-
self._validated_data = self.run_validation(self.initial_data)
758+
raw_validated = self.run_validation(self.initial_data)
759+
self._validated_data = raw_validated
778760
except ValidationError as exc:
779761
self._validated_data = []
780762
self._errors = exc.detail
@@ -786,11 +768,12 @@ def is_valid(self, *, raise_exception=False):
786768

787769
return not bool(self._errors)
788770

789-
def __repr__(self):
790-
return representation.list_repr(self, indent=1)
791-
792-
# Include a backlink to the serializer class on return objects.
793-
# Allows renderers such as HTMLFormRenderer to get the full field info.
771+
@property
772+
def validated_data(self):
773+
if not hasattr(self, '_validated_data'):
774+
msg = 'You must call `.is_valid()` before accessing `.validated_data`.'
775+
raise AssertionError(msg)
776+
return self._validated_data
794777

795778
@property
796779
def data(self):
@@ -799,20 +782,18 @@ def data(self):
799782

800783
@property
801784
def errors(self):
802-
ret = super().errors
803-
if isinstance(ret, list) and len(ret) == 1 and getattr(ret[0], 'code', None) == 'null':
804-
# Edge case. Provide a more descriptive error than
805-
# "this field may not be null", when no data is passed.
806-
detail = ErrorDetail('No data provided', code='null')
807-
ret = {api_settings.NON_FIELD_ERRORS_KEY: [detail]}
785+
ret = getattr(self, '_errors', [])
808786
if isinstance(ret, dict):
809787
return ReturnDict(ret, serializer=self)
810788
return ReturnList(ret, serializer=self)
811789

790+
def __repr__(self):
791+
return f'<ListSerializer child={self.child}>'
812792

813793
# ModelSerializer & HyperlinkedModelSerializer
814794
# --------------------------------------------
815795

796+
816797
def raise_errors_on_nested_writes(method_name, serializer, validated_data):
817798
"""
818799
Give explicit errors when users attempt to pass writable nested data.

0 commit comments

Comments
 (0)