Skip to content

Commit

Permalink
Merge pull request #99 from dabapps/override-serializer-args
Browse files Browse the repository at this point in the history
Override keyword args for automatically generated serializer fields
  • Loading branch information
j4mie authored Apr 17, 2024
2 parents 41a1b3e + 14149f5 commit 5fdd2f1
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 49 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Added support for overriding field kwargs for autogenerated serializer fields via `out` mechanism ([#99](https://github.com/dabapps/django-readers/pull/99)).

## [2.2.0] - 2024-01-12

### Added
Expand Down
120 changes: 71 additions & 49 deletions django_readers/rest_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@
from rest_framework.utils import model_meta


def add_annotation(obj, key, value):
obj._readers_annotation = getattr(obj, "_readers_annotation", None) or {}
obj._readers_annotation[key] = value


def get_annotation(obj, key):
# Either the item itself or (if this is a pair) just the
# producer/projector function may have been decorated
if value := getattr(obj, "_readers_annotation", {}).get(key):
return value
if isinstance(obj, tuple):
return get_annotation(obj[1], key)


class ProjectionSerializer:
def __init__(self, data=None, many=False, context=None):
self.many = many
Expand Down Expand Up @@ -77,46 +91,43 @@ def __init__(self, model, name):
def _lowercase_with_underscores_to_capitalized_words(self, string):
return "".join(part.title() for part in string.split("_"))

def _prepare_field(self, field):
def _prepare_field(self, field, kwargs=None):
# We copy the field so its _creation_counter is correct and
# it appears in the right order in the resulting serializer.
# We also force it to be read_only
field = deepcopy(field)
if kwargs:
field._kwargs.update(kwargs)
field._kwargs["read_only"] = True
return field

def _get_out_value(self, item):
# Either the item itself or (if this is a pair) just the
# producer/projector function may have been decorated
if hasattr(item, "out"):
return item.out
if isinstance(item, tuple) and hasattr(item[1], "out"):
return item[1].out
return None

def visit_str(self, item):
return self.visit_dict_item_str(item, item)

def visit_dict_item_str(self, key, value):
# This is a model field name. First, check if the
# field has been explicitly overridden
if hasattr(value, "out"):
field = self._prepare_field(value.out)
self.fields[str(key)] = field
return key, field

# No explicit override, so we can use ModelSerializer
# machinery to figure out which field type to use
field_class, field_kwargs = self.field_builder.build_field(
value,
self.info,
self.model,
0,
)
if key != value:
field_kwargs["source"] = value
field_kwargs.setdefault("read_only", True)
self.fields[key] = field_class(**field_kwargs)
if out := get_annotation(value, "field"):
field = self._prepare_field(out, kwargs=get_annotation(value, "kwargs"))

else:
# No explicit override, so we can use ModelSerializer
# machinery to figure out which field type to use
field_class, field_kwargs = self.field_builder.build_field(
value,
self.info,
self.model,
0,
)
if key != value:
field_kwargs["source"] = value
field_kwargs.setdefault("read_only", True)

if kwargs := get_annotation(value, "kwargs"):
field_kwargs.update(kwargs)
field = field_class(**field_kwargs)

self.fields[str(key)] = field
return key, value

def _get_child_serializer_kwargs(self, rel_info):
Expand Down Expand Up @@ -173,24 +184,26 @@ def visit_dict_item_dict(self, key, value):

def visit_dict_item_tuple(self, key, value):
# This is a producer pair.
out = self._get_out_value(value)
out = get_annotation(value, "field")
kwargs = get_annotation(value, "kwargs") or {}
if out:
field = self._prepare_field(out)
field = self._prepare_field(out, kwargs)
self.fields[key] = field
else:
# Fallback case: we don't know what field type to use
self.fields[key] = serializers.ReadOnlyField()
self.fields[key] = serializers.ReadOnlyField(**kwargs)
return key, value

visit_dict_item_callable = visit_dict_item_tuple

def visit_tuple(self, item):
# This is a projector pair.
out = self._get_out_value(item)
out = get_annotation(item, "field")
kwargs = get_annotation(item, "kwargs") or {}
if out:
# `out` is a dictionary mapping field names to Fields
for name, field in out.items():
field = self._prepare_field(field)
field = self._prepare_field(field, kwargs)
self.fields[name] = field
# There is no fallback case because we have no way of knowing the shape
# of the returned dictionary, so the schema will be unavoidably incorrect.
Expand Down Expand Up @@ -231,22 +244,28 @@ def serializer_class_for_view(view):
return serializer_class_for_spec(name_prefix, model, view.spec)


class PairWithOutAttribute(tuple):
out = None
class PairWithAnnotation(tuple):
_readers_annotation = None


class StringWithOutAttribute(str):
out = None
class StringWithAnnotation(str):
_readers_annotation = None


def out(field_or_dict):
if isinstance(field_or_dict, dict):
if not all(
isinstance(item, serializers.Field) for item in field_or_dict.values()
):
raise TypeError("Each value must be an instance of Field")
elif not isinstance(field_or_dict, serializers.Field):
raise TypeError("Must be an instance of Field")
def out(*args, **kwargs):
if args:
if len(args) != 1:
raise TypeError("Provide a single field or dictionary of fields")
field_or_dict = args[0]
if isinstance(field_or_dict, dict):
if not all(
isinstance(item, serializers.Field) for item in field_or_dict.values()
):
raise TypeError("Each value must be an instance of Field")
elif not isinstance(field_or_dict, serializers.Field):
raise TypeError("Must be an instance of Field")
else:
field_or_dict = None

class ShiftableDecorator:
def __call__(self, item):
Expand All @@ -257,15 +276,18 @@ def wrapper(*args, **kwargs):
result = item(*args, **kwargs)
return self(result)

wrapper.out = field_or_dict
add_annotation(wrapper, "field", field_or_dict)
add_annotation(wrapper, "kwargs", kwargs)
return wrapper
else:
if isinstance(item, str):
item = StringWithOutAttribute(item)
item.out = field_or_dict
item = StringWithAnnotation(item)
add_annotation(item, "field", field_or_dict)
add_annotation(item, "kwargs", kwargs)
if isinstance(item, tuple):
item = PairWithOutAttribute(item)
item.out = field_or_dict
item = PairWithAnnotation(item)
add_annotation(item, "field", field_or_dict)
add_annotation(item, "kwargs", kwargs)
return item

def __rrshift__(self, other):
Expand Down
14 changes: 14 additions & 0 deletions docs/reference/rest-framework.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,17 @@ class SomeView(SpecMixin, RetrieveAPIView):
```

This mechanism can also be used to override the output field type for an autogenerated field (a string).

### Overriding default behaviour

Rather than providing a serializer field instance to `out`, you can optionally provide keyword arguments that will be used when constructing the _default_ serializer field that would otherwise be generated by model field introspection. This is particularly useful when using the generated serializers to create a schema, because schema generation libraries often use `label` and `help_text` to add metadata to fields in the schema. For example:

```python hl_lines="5"
class SomeView(SpecMixin, RetrieveAPIView):
queryset = SomeModel.objects.all()
spec = [
...,
"title" >> out(label="The title of the object")
...,
]
```
28 changes: 28 additions & 0 deletions tests/test_rest_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,34 @@ def test_field_name_override(self):
)
self.assertEqual(repr(cls()), expected)

def test_out_kwargs(self):
@out(serializers.CharField(), label="Hello label")
def produce_hello(_):
return "Hello"

hello = qs.noop, produce_hello

spec = [
"name" >> out(help_text="Help for name"),
{"aliased_name": "name" >> out(label="Label for aliased name")},
{"upper_name": out(help_text="Help for upper name")(upper_name)},
# This is a bit redundant (kwargs could just be passed to the field
# directly) but should still work.
{"hello": hello},
]

cls = serializer_class_for_spec("Category", Category, spec)

expected = dedent(
"""\
CategorySerializer():
name = CharField(help_text='Help for name', max_length=100, read_only=True)
aliased_name = CharField(label='Label for aliased name', max_length=100, read_only=True, source='name')
upper_name = ReadOnlyField(help_text='Help for upper name')
hello = CharField(label='Hello label', read_only=True)"""
)
self.assertEqual(repr(cls()), expected)

def test_out_raises_with_field_class(self):
with self.assertRaises(TypeError):
out(serializers.CharField)
Expand Down

0 comments on commit 5fdd2f1

Please sign in to comment.