Skip to content

Commit

Permalink
Add ability to pass unknown in parse calls
Browse files Browse the repository at this point in the history
This adds support for passing the `unknown` parameter in two major
locations: Parser instantiation, and Parser.parse calls.
use_args and use_kwargs are just parse wrappers, and they need to pass
it through as well.

It also adds support for a class-level default for unknown,
`Parser.DEFAULT_UNKNOWN`, which sets `unknown` for any future parser
instances.

Explicit tweaks to handle this were necessary in asyncparser and
PyramidParser, due to odd method signatures.

Support is tested in the core tests, but not the various framework
tests.

Add a 6.2.0 (Unreleased) changelog entry with detail on this change.
The changelog states that we will change the DEFAULT_UNKNOWN default
in a future major release. Presumably we'll make it `EXCLUDE`, but I'd
like to make it location-dependent if feasible, so I didn't commit to
anything in the phrasing.
  • Loading branch information
sirosen committed May 13, 2020
1 parent 1bad8f9 commit 1acf429
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 18 deletions.
48 changes: 48 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,54 @@
Changelog
---------

6.2.0 (Unreleased)
******************

Features:

* Add a new ``unknown`` parameter to ``Parser.parse``, ``Parser.use_args``, and
``Parser.use_kwargs``. When set, it will be passed to the ``Schema.load``
call. If set to ``None`` (the default), no value is passed, so the schema's
``unknown`` behavior is used.

This allows usages like

.. code-block:: python
import marshmallow as ma
# marshmallow 3 only, for use of ``unknown`` and ``EXCLUDE``
@parser.use_kwargs(
{"q1": ma.fields.Int(), "q2": ma.fields.Int()}, location="query", unknown=ma.EXCLUDE
)
def foo(q1, q2):
...
* Add the ability to set defaults for ``unknown`` on either a Parser instance
or Parser class. Set ``Parser.DEFAULT_UNKNOWN`` on a parser class to apply a value
to any new parser instances created from that class, or set ``unknown`` during
``Parser`` initialization.

Usages are varied, but include

.. code-block:: python
import marshmallow as ma
from webargs.flaskparser import FlaskParser
parser = FlaskParser(unknown=ma.INCLUDE)
# as well as...
class MyParser(FlaskParser):
DEFAULT_UNKNOWN = ma.INCLUDE
parser = MyParser()
NOTE: No default value is set for ``DEFAULT_UNKNOWN`` for now, as doing so would be
a breaking change. A future major release will change this behavior.

6.1.0 (2020-04-05)
******************

Expand Down
11 changes: 10 additions & 1 deletion src/webargs/asyncparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from marshmallow.fields import Field
import marshmallow as ma

from webargs.compat import MARSHMALLOW_VERSION_INFO
from webargs import core

Request = typing.TypeVar("Request")
Expand All @@ -28,6 +29,7 @@ async def parse(
req: Request = None,
*,
location: str = None,
unknown: str = None,
validate: Validate = None,
error_status_code: typing.Union[int, None] = None,
error_headers: typing.Union[typing.Mapping[str, str], None] = None
Expand All @@ -38,6 +40,10 @@ async def parse(
"""
req = req if req is not None else self.get_default_request()
location = location or self.location
unknown = unknown or self.unknown
load_kwargs = (
{"unknown": unknown} if MARSHMALLOW_VERSION_INFO[0] >= 3 and unknown else {}
)
if req is None:
raise ValueError("Must pass req object")
data = None
Expand All @@ -47,7 +53,7 @@ async def parse(
location_data = await self._load_location_data(
schema=schema, req=req, location=location
)
result = schema.load(location_data)
result = schema.load(location_data, **load_kwargs)
data = result.data if core.MARSHMALLOW_VERSION_INFO[0] < 3 else result
self._validate_arguments(data, validators)
except ma.exceptions.ValidationError as error:
Expand Down Expand Up @@ -111,6 +117,7 @@ def use_args(
req: typing.Optional[Request] = None,
*,
location: str = None,
unknown=None,
as_kwargs: bool = False,
validate: Validate = None,
error_status_code: typing.Optional[int] = None,
Expand Down Expand Up @@ -143,6 +150,7 @@ async def wrapper(*args, **kwargs):
argmap,
req=req_obj,
location=location,
unknown=unknown,
validate=validate,
error_status_code=error_status_code,
error_headers=error_headers,
Expand All @@ -165,6 +173,7 @@ def wrapper(*args, **kwargs):
argmap,
req=req_obj,
location=location,
unknown=unknown,
validate=validate,
error_status_code=error_status_code,
error_headers=error_headers,
Expand Down
22 changes: 20 additions & 2 deletions src/webargs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,15 @@ class Parser:
etc.
:param str location: Default location to use for data
:param str unknown: Default value for ``unknown`` in ``parse``,
``use_args``, and ``use_kwargs``
:param callable error_handler: Custom error handler function.
"""

#: Default location to check for data
DEFAULT_LOCATION = "json"
#: Default value to use for 'unknown' on schema load
DEFAULT_UNKNOWN = None
#: The marshmallow Schema class to use when creating new schemas
DEFAULT_SCHEMA_CLASS = ma.Schema
#: Default status code to return for validation errors
Expand All @@ -125,10 +129,13 @@ class Parser:
"json_or_form": "load_json_or_form",
}

def __init__(self, location=None, *, error_handler=None, schema_class=None):
def __init__(
self, location=None, *, unknown=None, error_handler=None, schema_class=None
):
self.location = location or self.DEFAULT_LOCATION
self.error_callback = _callable_or_raise(error_handler)
self.schema_class = schema_class or self.DEFAULT_SCHEMA_CLASS
self.unknown = unknown or self.DEFAULT_UNKNOWN

def _get_loader(self, location):
"""Get the loader function for the given location.
Expand Down Expand Up @@ -222,6 +229,7 @@ def parse(
req=None,
*,
location=None,
unknown=None,
validate=None,
error_status_code=None,
error_headers=None
Expand All @@ -236,6 +244,8 @@ def parse(
Can be any of the values in :py:attr:`~__location_map__`. By
default, that means one of ``('json', 'query', 'querystring',
'form', 'headers', 'cookies', 'files', 'json_or_form')``.
:param str unknown: A value to pass for ``unknown`` when calling the
schema's ``load`` method (marshmallow 3 only).
:param callable validate: Validation function or list of validation functions
that receives the dictionary of parsed arguments. Validator either returns a
boolean or raises a :exc:`ValidationError`.
Expand All @@ -248,6 +258,10 @@ def parse(
"""
req = req if req is not None else self.get_default_request()
location = location or self.location
unknown = unknown or self.unknown
load_kwargs = (
{"unknown": unknown} if MARSHMALLOW_VERSION_INFO[0] >= 3 and unknown else {}
)
if req is None:
raise ValueError("Must pass req object")
data = None
Expand All @@ -257,7 +271,7 @@ def parse(
location_data = self._load_location_data(
schema=schema, req=req, location=location
)
result = schema.load(location_data)
result = schema.load(location_data, **load_kwargs)
data = result.data if MARSHMALLOW_VERSION_INFO[0] < 3 else result
self._validate_arguments(data, validators)
except ma.exceptions.ValidationError as error:
Expand Down Expand Up @@ -307,6 +321,7 @@ def use_args(
req=None,
*,
location=None,
unknown=None,
as_kwargs=False,
validate=None,
error_status_code=None,
Expand All @@ -325,6 +340,8 @@ def greet(args):
of argname -> `marshmallow.fields.Field` pairs, or a callable
which accepts a request and returns a `marshmallow.Schema`.
:param str location: Where on the request to load values.
:param str unknown: A value to pass for ``unknown`` when calling the
schema's ``load`` method (marshmallow 3 only).
:param bool as_kwargs: Whether to insert arguments as keyword arguments.
:param callable validate: Validation function that receives the dictionary
of parsed arguments. If the function returns ``False``, the parser
Expand Down Expand Up @@ -356,6 +373,7 @@ def wrapper(*args, **kwargs):
argmap,
req=req_obj,
location=location,
unknown=unknown,
validate=validate,
error_status_code=error_status_code,
error_headers=error_headers,
Expand Down
4 changes: 4 additions & 0 deletions src/webargs/pyramidparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def use_args(
req=None,
*,
location=core.Parser.DEFAULT_LOCATION,
unknown=None,
as_kwargs=False,
validate=None,
error_status_code=None,
Expand All @@ -127,6 +128,8 @@ def use_args(
which accepts a request and returns a `marshmallow.Schema`.
:param req: The request object to parse. Pulled off of the view by default.
:param str location: Where on the request to load values.
:param str unknown: A value to pass for ``unknown`` when calling the
schema's ``load`` method (marshmallow 3 only).
:param bool as_kwargs: Whether to insert arguments as keyword arguments.
:param callable validate: Validation function that receives the dictionary
of parsed arguments. If the function returns ``False``, the parser
Expand Down Expand Up @@ -155,6 +158,7 @@ def wrapper(obj, *args, **kwargs):
argmap,
req=request,
location=location,
unknown=unknown,
validate=validate,
error_status_code=error_status_code,
error_headers=error_headers,
Expand Down
78 changes: 63 additions & 15 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ def test_parse(parser, web_request):
@pytest.mark.skipif(
MARSHMALLOW_VERSION_INFO[0] < 3, reason="unknown=... added in marshmallow3"
)
def test_parse_with_unknown_behavior_specified(parser, web_request):
@pytest.mark.parametrize(
"set_location",
["schema_instance", "parse_call", "parser_default", "parser_class_default"],
)
def test_parse_with_unknown_behavior_specified(parser, web_request, set_location):
# This is new in webargs 6.x ; it's the way you can "get back" the behavior
# of webargs 5.x in which extra args are ignored
from marshmallow import EXCLUDE, INCLUDE, RAISE
Expand All @@ -119,17 +123,65 @@ class CustomSchema(Schema):
username = fields.Field()
password = fields.Field()

def parse_with_desired_behavior(value):
if set_location == "schema_instance":
if value is not None:
return parser.parse(CustomSchema(unknown=value), web_request)
else:
return parser.parse(CustomSchema(), web_request)
elif set_location == "parse_call":
return parser.parse(CustomSchema(), web_request, unknown=value)
elif set_location == "parser_default":
parser.unknown = value
return parser.parse(CustomSchema(), web_request)
elif set_location == "parser_class_default":

class CustomParser(MockRequestParser):
DEFAULT_UNKNOWN = value

return CustomParser().parse(CustomSchema(), web_request)
else:
raise NotImplementedError

# with no unknown setting or unknown=RAISE, it blows up
with pytest.raises(ValidationError, match="Unknown field."):
parser.parse(CustomSchema(), web_request)
parse_with_desired_behavior(None)
with pytest.raises(ValidationError, match="Unknown field."):
parser.parse(CustomSchema(unknown=RAISE), web_request)
parse_with_desired_behavior(RAISE)

# with unknown=EXCLUDE the data is omitted
ret = parser.parse(CustomSchema(unknown=EXCLUDE), web_request)
ret = parse_with_desired_behavior(EXCLUDE)
assert {"username": 42, "password": 42} == ret
# with unknown=INCLUDE it is added even though it isn't part of the schema
ret = parser.parse(CustomSchema(unknown=INCLUDE), web_request)
ret = parse_with_desired_behavior(INCLUDE)
assert {"username": 42, "password": 42, "fjords": 42} == ret


@pytest.mark.skipif(
MARSHMALLOW_VERSION_INFO[0] < 3, reason="unknown=... added in marshmallow3"
)
def test_parse_with_explicit_unknown_overrides_schema(parser, web_request):
# this test ensures that if you specify unknown=... in your parse call (or
# use_args) it takes precedence over a setting in the schema object
from marshmallow import EXCLUDE, INCLUDE, RAISE

web_request.json = {"username": 42, "password": 42, "fjords": 42}

class CustomSchema(Schema):
username = fields.Field()
password = fields.Field()

# setting RAISE in the parse call overrides schema setting
with pytest.raises(ValidationError, match="Unknown field."):
parser.parse(CustomSchema(unknown=EXCLUDE), web_request, unknown=RAISE)
with pytest.raises(ValidationError, match="Unknown field."):
parser.parse(CustomSchema(unknown=INCLUDE), web_request, unknown=RAISE)

# and the reverse -- setting EXCLUDE or INCLUDE in the parse call overrides
# a schema with RAISE already set
ret = parser.parse(CustomSchema(unknown=RAISE), web_request, unknown=EXCLUDE)
assert {"username": 42, "password": 42} == ret
ret = parser.parse(CustomSchema(unknown=RAISE), web_request, unknown=INCLUDE)
assert {"username": 42, "password": 42, "fjords": 42} == ret


Expand Down Expand Up @@ -756,22 +808,18 @@ def test_warning_raised_if_schema_is_not_in_strict_mode(self, web_request, parse
assert "strict=True" in str(warning.message)

def test_use_kwargs_stacked(self, web_request, parser):
parse_kwargs = {}
if MARSHMALLOW_VERSION_INFO[0] >= 3:
from marshmallow import EXCLUDE

class PageSchema(Schema):
page = fields.Int()

pageschema = PageSchema(unknown=EXCLUDE)
userschema = self.UserSchema(unknown=EXCLUDE)
else:
pageschema = {"page": fields.Int()}
userschema = self.UserSchema(**strict_kwargs)
parse_kwargs = {"unknown": EXCLUDE}

web_request.json = {"email": "[email protected]", "password": "bar", "page": 42}

@parser.use_kwargs(pageschema, web_request)
@parser.use_kwargs(userschema, web_request)
@parser.use_kwargs({"page": fields.Int()}, web_request, **parse_kwargs)
@parser.use_kwargs(
self.UserSchema(**strict_kwargs), web_request, **parse_kwargs
)
def viewfunc(email, password, page):
return {"email": email, "password": password, "page": page}

Expand Down

0 comments on commit 1acf429

Please sign in to comment.