diff --git a/AUTHORS.rst b/AUTHORS.rst index e88061426..42ac8f4f0 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -167,3 +167,4 @@ Contributors (chronological) - Ryan Morehart `@traherom `_ - Ben Windsor `@bwindsor `_ - Kevin Kirsche `@kkirsche `_ +- Dusko Simidzija `@dsimidzija `_ diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index d1cd88cc0..7af5b4a1a 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -19,6 +19,8 @@ missing as missing_, resolve_field_instance, is_aware, + _get_value_for_key, + _get_value_for_keys, ) from marshmallow.exceptions import ( ValidationError, @@ -310,6 +312,50 @@ def _validate_missing(self, value): if value is None and not self.allow_none: raise self.make_error("null") + def get_serializer( + self, + attr: str, + accessor: typing.Optional[ + typing.Callable[[typing.Any, str, typing.Any], typing.Any] + ] = None, + **kwargs, + ) -> typing.Callable[[typing.Any], typing.Any]: + """Return an optimized serializer for this Field object. + + :param str attr: The attribute or key on the object to be serialized. + :param dict kwargs: Field-specific keyword arguments. + :return: Serializer function. + """ + if not self._CHECK_ATTRIBUTE: + return lambda obj: self._serialize(None, attr, obj, **kwargs) + + attribute = getattr(self, "attribute", None) + check_key = attr if attribute is None else attribute + dump_default = None + callable_default = False + has_default = hasattr(self, "dump_default") + if has_default: + dump_default = self.dump_default + callable_default = callable(dump_default) + if accessor: + accessor_func = accessor + else: + if not isinstance(check_key, int) and "." in check_key: + accessor_func = _get_value_for_keys + check_key = check_key.split(".") + else: + accessor_func = _get_value_for_key + + def _serializer(obj): + value = accessor_func(obj, check_key, missing_) + if value is missing_ and has_default: + value = dump_default() if callable_default else dump_default + if value is missing_: + return value + return self._serialize(value, attr, obj, **kwargs) + + return _serializer + def serialize( self, attr: str, diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 12bb2f4c7..b755d53f1 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -394,6 +394,9 @@ def __init__( self.fields = {} # type: typing.Dict[str, ma_fields.Field] self.load_fields = {} # type: typing.Dict[str, ma_fields.Field] self.dump_fields = {} # type: typing.Dict[str, ma_fields.Field] + self.dump_serializers = ( + self.dict_class() + ) # type: typing.Dict[str, typing.Callable] self._init_fields() messages = {} messages.update(self._default_error_messages) @@ -466,7 +469,7 @@ def handle_error( """ pass - def get_attribute(self, obj: typing.Any, attr: str, default: typing.Any): + def default_get_attribute(self, obj: typing.Any, attr: str, default: typing.Any): """Defines how to pull values from an object to serialize. .. versionadded:: 2.0.0 @@ -476,6 +479,8 @@ def get_attribute(self, obj: typing.Any, attr: str, default: typing.Any): """ return get_value(obj, attr, default) + get_attribute = default_get_attribute + ##### Serialization/Deserialization API ##### @staticmethod @@ -510,19 +515,41 @@ def _serialize( .. versionchanged:: 1.0.0 Renamed from ``marshal``. """ - if many and obj is not None: - return [ - self._serialize(d, many=False) - for d in typing.cast(typing.Iterable[_T], obj) - ] - ret = self.dict_class() - for attr_name, field_obj in self.dump_fields.items(): - value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute) - if value is missing: - continue - key = field_obj.data_key if field_obj.data_key is not None else attr_name - ret[key] = value - return ret + if not self.dump_serializers: + accessor = ( + None + ) # type: typing.Optional[typing.Callable[[typing.Any, str, typing.Any], typing.Any]] + if self.get_attribute != self.default_get_attribute: + accessor = self.get_attribute + + for field_name, field_obj in self.dump_fields.items(): + key = ( + field_obj.data_key if field_obj.data_key is not None else field_name + ) + self.dump_serializers[key] = field_obj.get_serializer( + field_name, accessor + ) + + source_obj = [None] # typing: typing.List[typing.Any] + + if not many: + source_obj = [typing.cast(typing.Any, obj)] + elif many and obj is not None: + source_obj = typing.cast(typing.List[typing.Any], obj) + + output = [] + for current_obj in source_obj: + ret = self.dict_class() + for key, serializer in self.dump_serializers.items(): + value = serializer(current_obj) + if value is missing: + continue + ret[key] = value + output.append(ret) + + if not many: + return output[0] + return output def dump(self, obj: typing.Any, *, many: typing.Optional[bool] = None): """Serialize an object to native Python data types according to this @@ -1015,6 +1042,7 @@ def _init_fields(self) -> None: self.fields = fields_dict self.dump_fields = dump_fields self.load_fields = load_fields + self.dump_serializers = self.dict_class() def on_bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None: """Hook to modify a field when it is bound to the `Schema`.