diff --git a/CHANGELOG.md b/CHANGELOG.md index fe63988..51e0221 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Added support for overriding field kwargs for autogenerated serializer fields via `out` mechanism ([#99](https://github.com/dabapps/django-readers/pull/99)). + ## [2.2.0] - 2024-01-12 ### Added diff --git a/django_readers/rest_framework.py b/django_readers/rest_framework.py index de1eb22..98f523b 100644 --- a/django_readers/rest_framework.py +++ b/django_readers/rest_framework.py @@ -9,6 +9,20 @@ from rest_framework.utils import model_meta +def add_annotation(obj, key, value): + obj._readers_annotation = getattr(obj, "_readers_annotation", None) or {} + obj._readers_annotation[key] = value + + +def get_annotation(obj, key): + # Either the item itself or (if this is a pair) just the + # producer/projector function may have been decorated + if value := getattr(obj, "_readers_annotation", {}).get(key): + return value + if isinstance(obj, tuple): + return get_annotation(obj[1], key) + + class ProjectionSerializer: def __init__(self, data=None, many=False, context=None): self.many = many @@ -77,46 +91,43 @@ def __init__(self, model, name): def _lowercase_with_underscores_to_capitalized_words(self, string): return "".join(part.title() for part in string.split("_")) - def _prepare_field(self, field): + def _prepare_field(self, field, kwargs=None): # We copy the field so its _creation_counter is correct and # it appears in the right order in the resulting serializer. # We also force it to be read_only field = deepcopy(field) + if kwargs: + field._kwargs.update(kwargs) field._kwargs["read_only"] = True return field - def _get_out_value(self, item): - # Either the item itself or (if this is a pair) just the - # producer/projector function may have been decorated - if hasattr(item, "out"): - return item.out - if isinstance(item, tuple) and hasattr(item[1], "out"): - return item[1].out - return None - def visit_str(self, item): return self.visit_dict_item_str(item, item) def visit_dict_item_str(self, key, value): # This is a model field name. First, check if the # field has been explicitly overridden - if hasattr(value, "out"): - field = self._prepare_field(value.out) - self.fields[str(key)] = field - return key, field - - # No explicit override, so we can use ModelSerializer - # machinery to figure out which field type to use - field_class, field_kwargs = self.field_builder.build_field( - value, - self.info, - self.model, - 0, - ) - if key != value: - field_kwargs["source"] = value - field_kwargs.setdefault("read_only", True) - self.fields[key] = field_class(**field_kwargs) + if out := get_annotation(value, "field"): + field = self._prepare_field(out, kwargs=get_annotation(value, "kwargs")) + + else: + # No explicit override, so we can use ModelSerializer + # machinery to figure out which field type to use + field_class, field_kwargs = self.field_builder.build_field( + value, + self.info, + self.model, + 0, + ) + if key != value: + field_kwargs["source"] = value + field_kwargs.setdefault("read_only", True) + + if kwargs := get_annotation(value, "kwargs"): + field_kwargs.update(kwargs) + field = field_class(**field_kwargs) + + self.fields[str(key)] = field return key, value def _get_child_serializer_kwargs(self, rel_info): @@ -173,24 +184,26 @@ def visit_dict_item_dict(self, key, value): def visit_dict_item_tuple(self, key, value): # This is a producer pair. - out = self._get_out_value(value) + out = get_annotation(value, "field") + kwargs = get_annotation(value, "kwargs") or {} if out: - field = self._prepare_field(out) + field = self._prepare_field(out, kwargs) self.fields[key] = field else: # Fallback case: we don't know what field type to use - self.fields[key] = serializers.ReadOnlyField() + self.fields[key] = serializers.ReadOnlyField(**kwargs) return key, value visit_dict_item_callable = visit_dict_item_tuple def visit_tuple(self, item): # This is a projector pair. - out = self._get_out_value(item) + out = get_annotation(item, "field") + kwargs = get_annotation(item, "kwargs") or {} if out: # `out` is a dictionary mapping field names to Fields for name, field in out.items(): - field = self._prepare_field(field) + field = self._prepare_field(field, kwargs) self.fields[name] = field # There is no fallback case because we have no way of knowing the shape # of the returned dictionary, so the schema will be unavoidably incorrect. @@ -231,22 +244,28 @@ def serializer_class_for_view(view): return serializer_class_for_spec(name_prefix, model, view.spec) -class PairWithOutAttribute(tuple): - out = None +class PairWithAnnotation(tuple): + _readers_annotation = None -class StringWithOutAttribute(str): - out = None +class StringWithAnnotation(str): + _readers_annotation = None -def out(field_or_dict): - if isinstance(field_or_dict, dict): - if not all( - isinstance(item, serializers.Field) for item in field_or_dict.values() - ): - raise TypeError("Each value must be an instance of Field") - elif not isinstance(field_or_dict, serializers.Field): - raise TypeError("Must be an instance of Field") +def out(*args, **kwargs): + if args: + if len(args) != 1: + raise TypeError("Provide a single field or dictionary of fields") + field_or_dict = args[0] + if isinstance(field_or_dict, dict): + if not all( + isinstance(item, serializers.Field) for item in field_or_dict.values() + ): + raise TypeError("Each value must be an instance of Field") + elif not isinstance(field_or_dict, serializers.Field): + raise TypeError("Must be an instance of Field") + else: + field_or_dict = None class ShiftableDecorator: def __call__(self, item): @@ -257,15 +276,18 @@ def wrapper(*args, **kwargs): result = item(*args, **kwargs) return self(result) - wrapper.out = field_or_dict + add_annotation(wrapper, "field", field_or_dict) + add_annotation(wrapper, "kwargs", kwargs) return wrapper else: if isinstance(item, str): - item = StringWithOutAttribute(item) - item.out = field_or_dict + item = StringWithAnnotation(item) + add_annotation(item, "field", field_or_dict) + add_annotation(item, "kwargs", kwargs) if isinstance(item, tuple): - item = PairWithOutAttribute(item) - item.out = field_or_dict + item = PairWithAnnotation(item) + add_annotation(item, "field", field_or_dict) + add_annotation(item, "kwargs", kwargs) return item def __rrshift__(self, other): diff --git a/docs/reference/rest-framework.md b/docs/reference/rest-framework.md index 097a0af..e10289d 100644 --- a/docs/reference/rest-framework.md +++ b/docs/reference/rest-framework.md @@ -185,3 +185,17 @@ class SomeView(SpecMixin, RetrieveAPIView): ``` This mechanism can also be used to override the output field type for an autogenerated field (a string). + +### Overriding default behaviour + +Rather than providing a serializer field instance to `out`, you can optionally provide keyword arguments that will be used when constructing the _default_ serializer field that would otherwise be generated by model field introspection. This is particularly useful when using the generated serializers to create a schema, because schema generation libraries often use `label` and `help_text` to add metadata to fields in the schema. For example: + +```python hl_lines="5" +class SomeView(SpecMixin, RetrieveAPIView): + queryset = SomeModel.objects.all() + spec = [ + ..., + "title" >> out(label="The title of the object") + ..., + ] +``` diff --git a/tests/test_rest_framework.py b/tests/test_rest_framework.py index 546f43d..41c492d 100644 --- a/tests/test_rest_framework.py +++ b/tests/test_rest_framework.py @@ -416,6 +416,34 @@ def test_field_name_override(self): ) self.assertEqual(repr(cls()), expected) + def test_out_kwargs(self): + @out(serializers.CharField(), label="Hello label") + def produce_hello(_): + return "Hello" + + hello = qs.noop, produce_hello + + spec = [ + "name" >> out(help_text="Help for name"), + {"aliased_name": "name" >> out(label="Label for aliased name")}, + {"upper_name": out(help_text="Help for upper name")(upper_name)}, + # This is a bit redundant (kwargs could just be passed to the field + # directly) but should still work. + {"hello": hello}, + ] + + cls = serializer_class_for_spec("Category", Category, spec) + + expected = dedent( + """\ + CategorySerializer(): + name = CharField(help_text='Help for name', max_length=100, read_only=True) + aliased_name = CharField(label='Label for aliased name', max_length=100, read_only=True, source='name') + upper_name = ReadOnlyField(help_text='Help for upper name') + hello = CharField(label='Hello label', read_only=True)""" + ) + self.assertEqual(repr(cls()), expected) + def test_out_raises_with_field_class(self): with self.assertRaises(TypeError): out(serializers.CharField)