diff --git a/tests/test_fields.py b/tests/test_fields.py index 68500763..7b497f41 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -11,6 +11,7 @@ item_from_fields, item_from_fields_sync, ) +from web_poet.fields import fields_dict @attrs.define @@ -271,3 +272,115 @@ def to_item(self) -> Item: page = ExtendedPage2(response=EXAMPLE_RESPONSE) item = page.to_item() assert item == Item(name="Hello!", price="$123") + + +def test_field_meta(): + class MyPage(ItemPage): + @field(meta={"good": True}) + def field1(self): + return "foo" + + @field + def field2(self): + return "foo" + + def to_item(self): + return item_from_fields_sync(self) + + page = MyPage() + for fields in [fields_dict(MyPage), fields_dict(page)]: + assert list(fields.keys()) == ["field1", "field2"] + assert fields["field1"].name == "field1" + assert fields["field1"].meta == {"good": True} + + assert fields["field2"].name == "field2" + assert fields["field2"].meta is None + + +def test_field_extra(): + @attrs.define + class OnlyNameItem: + name: str + + @attrs.define + class OnlyPriceItem: + price: str + + class BasePage(ItemPage): + item_cls = OnlyNameItem + + @field + def name(self): # noqa: D102 + return "name" + + @field(extra=True) + def price(self): # noqa: D102 + return "price" + + def to_item(self): # noqa: D102 + return item_from_fields_sync(self, self.item_cls) + + # BasePage contains field which is not in item class, + # but the field is defined as extra, so an exception is not raised + page = BasePage() + assert page.to_item() == OnlyNameItem(name="name") + + class FullItemPage(BasePage): + item_cls = Item + + # extra field is available in an item, so it's used now + page = FullItemPage() + assert page.to_item() == Item(name="name", price="price") + + class OnlyPricePage(BasePage): + item_cls = OnlyPriceItem + + # regular fields are always passed + page = OnlyPricePage() + with pytest.raises(TypeError, match="unexpected keyword argument 'name'"): + page.to_item() + + +@pytest.mark.asyncio +async def test_field_extra_async(): + @attrs.define + class OnlyNameItem: + name: str + + @attrs.define + class OnlyPriceItem: + price: str + + class BasePage(ItemPage): + item_cls = OnlyNameItem + + @field + async def name(self): # noqa: D102 + return "name" + + @field(extra=True) + async def price(self): # noqa: D102 + return "price" + + async def to_item(self): # noqa: D102 + return await item_from_fields(self, self.item_cls) + + # BasePage contains field which is not in item class, + # but the field is defined as extra, so an exception is not raised + page = BasePage() + assert await page.to_item() == OnlyNameItem(name="name") + + class FullItemPage(BasePage): + item_cls = Item + + # extra field is available in an item, so it's used now + page = FullItemPage() + assert await page.to_item() == Item(name="name", price="price") + + class OnlyPricePage(BasePage): + item_cls = OnlyPriceItem + + # regular fields are always passed + page = OnlyPricePage() + with pytest.raises(TypeError, match="unexpected keyword argument 'name'"): + await page.to_item() diff --git a/web_poet/fields.py b/web_poet/fields.py index cf7660cf..5c726112 100644 --- a/web_poet/fields.py +++ b/web_poet/fields.py @@ -25,15 +25,24 @@ async def to_item(self): """ from functools import update_wrapper +from typing import Dict, Optional +import attrs from itemadapter import ItemAdapter from web_poet.utils import cached_method, ensure_awaitable -_FIELDS_ATTRIBUTE = "_marked_as_fields" +_FIELDS_INFO_ATTRIBUTE = "_web_poet_fields_info" -def field(method=None, *, cached=False): +@attrs.define +class FieldInfo: + name: str + meta: Optional[dict] = None + extra: bool = False + + +def field(method=None, *, cached: bool = False, meta: Optional[dict] = None, extra: bool = False): """ Page Object method decorated with ``@field`` decorator becomes a property, which is used by :func:`item_from_fields` or :func:`item_from_fields_sync` @@ -41,6 +50,14 @@ def field(method=None, *, cached=False): By default, the value is computed on each property access. Use ``@field(cached=True)`` to cache the property value. + + Fields decorated with ``@field(extra=True)`` are not passed to item + classes by :func:`item_from_fields` if items don't support them, regardless + of ``item_cls_fields`` argument. + + ``meta`` parameter allows to store arbitrary information for the + field - e.g. ``@field(meta={"expensive": True})``. This information + can be later retrieved for all fields using :func:`fields_dict` function. """ class _field: @@ -53,10 +70,11 @@ def __init__(self, method): self.unbound_method = method def __set_name__(self, owner, name): - if not hasattr(owner, _FIELDS_ATTRIBUTE): - # dict is used instead of set to preserve the insertion order - setattr(owner, _FIELDS_ATTRIBUTE, {}) - getattr(owner, _FIELDS_ATTRIBUTE)[name] = True + if not hasattr(owner, _FIELDS_INFO_ATTRIBUTE): + setattr(owner, _FIELDS_INFO_ATTRIBUTE, {}) + + field_info = FieldInfo(name=name, meta=meta, extra=extra) + getattr(owner, _FIELDS_INFO_ATTRIBUTE)[name] = field_info def __get__(self, instance, owner=None): return self.unbound_method(instance) @@ -71,30 +89,47 @@ def __get__(self, instance, owner=None): return _field +def fields_dict(cls_or_instance) -> Dict[str, FieldInfo]: + """Return a dictionary with information about the fields defined + for the class""" + return getattr(cls_or_instance, _FIELDS_INFO_ATTRIBUTE, {}) + + async def item_from_fields(obj, item_cls=dict, *, item_cls_fields=False): """Return an item of ``item_cls`` type, with its attributes populated from the ``obj`` methods decorated with :class:`field` decorator. If ``item_cls_fields`` is True, ``@fields`` whose names don't match any of the ``item_cls`` attributes are not passed to ``item_cls.__init__``. + When ``item_cls_fields`` is False (default), all ``@fields`` are passed - to ``item_cls.__init__``. + to ``item_cls.__init__``, unless they're created with ``extra=True`` + argument. """ - item_dict = item_from_fields_sync(obj, item_cls=dict, item_cls_fields=False) - field_names = item_dict.keys() - if item_cls_fields: - field_names = _without_unsupported_field_names(item_cls, field_names) + field_names = _final_field_names(obj, item_cls, item_cls_fields) + item_dict = {name: getattr(obj, name) for name in field_names} return item_cls(**{name: await ensure_awaitable(item_dict[name]) for name in field_names}) def item_from_fields_sync(obj, item_cls=dict, *, item_cls_fields=False): """Synchronous version of :func:`item_from_fields`.""" - field_names = list(getattr(obj, _FIELDS_ATTRIBUTE, {})) - if item_cls_fields: - field_names = _without_unsupported_field_names(item_cls, field_names) + field_names = _final_field_names(obj, item_cls, item_cls_fields) return item_cls(**{name: getattr(obj, name) for name in field_names}) +def _final_field_names(obj, item_cls, item_cls_fields): + fields = fields_dict(obj) + extra_field_names = _without_unsupported_field_names( + item_cls, [info.name for info in fields.values() if info.extra] + ) + + regular_field_names = [info.name for info in fields.values() if not info.extra] + if item_cls_fields: + regular_field_names = _without_unsupported_field_names(item_cls, regular_field_names) + + return regular_field_names + extra_field_names + + def _without_unsupported_field_names(item_cls, field_names): item_field_names = ItemAdapter.get_field_names_from_class(item_cls) if item_field_names is None: # item_cls doesn't define field names upfront