Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to pass unknown in parse calls #514

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,51 @@
Changelog
---------

6.2.0 (Unreleased)
******************

Features:

* Add a new ``unknown`` parameter to ``Parser.parse``, ``Parser.use_args``, and
``Parser.use_kwargs``. When set, it will be passed to the ``Schema.load``
call. If set to ``None`` (the default), no value is passed, so the schema's
``unknown`` behavior is used.

This allows usages like

.. code-block:: python

import marshmallow as ma

# marshmallow 3 only, for use of ``unknown`` and ``EXCLUDE``
@parser.use_kwargs(
{"q1": ma.fields.Int(), "q2": ma.fields.Int()}, location="query", unknown=ma.EXCLUDE
)
def foo(q1, q2):
...

* Add the ability to set defaults for ``unknown`` on either a Parser instance
or Parser class. Set ``Parser.DEFAULT_UNKNOWN`` on a parser class to apply a value
to any new parser instances created from that class, or set ``unknown`` during
``Parser`` initialization.

Usages are varied, but include

.. code-block:: python

import marshmallow as ma
from webargs.flaskparser import FlaskParser

parser = FlaskParser(unknown=ma.INCLUDE)

# as well as...
class MyParser(FlaskParser):
DEFAULT_UNKNOWN = ma.INCLUDE


parser = MyParser()


6.1.0 (2020-04-05)
******************

Expand Down
11 changes: 10 additions & 1 deletion src/webargs/asyncparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from marshmallow.fields import Field
import marshmallow as ma

from webargs.compat import MARSHMALLOW_VERSION_INFO
from webargs import core

Request = typing.TypeVar("Request")
Expand All @@ -28,6 +29,7 @@ async def parse(
req: Request = None,
*,
location: str = None,
unknown: str = None,
validate: Validate = None,
error_status_code: typing.Union[int, None] = None,
error_headers: typing.Union[typing.Mapping[str, str], None] = None
Expand All @@ -38,6 +40,10 @@ async def parse(
"""
req = req if req is not None else self.get_default_request()
location = location or self.location
unknown = unknown or self.unknown
load_kwargs = (
{"unknown": unknown} if MARSHMALLOW_VERSION_INFO[0] >= 3 and unknown else {}
)
if req is None:
raise ValueError("Must pass req object")
data = None
Expand All @@ -47,7 +53,7 @@ async def parse(
location_data = await self._load_location_data(
schema=schema, req=req, location=location
)
result = schema.load(location_data)
result = schema.load(location_data, **load_kwargs)
data = result.data if core.MARSHMALLOW_VERSION_INFO[0] < 3 else result
self._validate_arguments(data, validators)
except ma.exceptions.ValidationError as error:
Expand Down Expand Up @@ -111,6 +117,7 @@ def use_args(
req: typing.Optional[Request] = None,
*,
location: str = None,
unknown=None,
as_kwargs: bool = False,
validate: Validate = None,
error_status_code: typing.Optional[int] = None,
Expand Down Expand Up @@ -143,6 +150,7 @@ async def wrapper(*args, **kwargs):
argmap,
req=req_obj,
location=location,
unknown=unknown,
validate=validate,
error_status_code=error_status_code,
error_headers=error_headers,
Expand All @@ -165,6 +173,7 @@ def wrapper(*args, **kwargs):
argmap,
req=req_obj,
location=location,
unknown=unknown,
validate=validate,
error_status_code=error_status_code,
error_headers=error_headers,
Expand Down
22 changes: 20 additions & 2 deletions src/webargs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,15 @@ class Parser:
etc.

:param str location: Default location to use for data
:param str unknown: Default value for ``unknown`` in ``parse``,
``use_args``, and ``use_kwargs``
:param callable error_handler: Custom error handler function.
"""

#: Default location to check for data
DEFAULT_LOCATION = "json"
#: Default value to use for 'unknown' on schema load
DEFAULT_UNKNOWN = None
#: The marshmallow Schema class to use when creating new schemas
DEFAULT_SCHEMA_CLASS = ma.Schema
#: Default status code to return for validation errors
Expand All @@ -125,10 +129,13 @@ class Parser:
"json_or_form": "load_json_or_form",
}

def __init__(self, location=None, *, error_handler=None, schema_class=None):
def __init__(
self, location=None, *, unknown=None, error_handler=None, schema_class=None
):
self.location = location or self.DEFAULT_LOCATION
self.error_callback = _callable_or_raise(error_handler)
self.schema_class = schema_class or self.DEFAULT_SCHEMA_CLASS
self.unknown = unknown or self.DEFAULT_UNKNOWN

def _get_loader(self, location):
"""Get the loader function for the given location.
Expand Down Expand Up @@ -222,6 +229,7 @@ def parse(
req=None,
*,
location=None,
unknown=None,
validate=None,
error_status_code=None,
error_headers=None
Expand All @@ -236,6 +244,8 @@ def parse(
Can be any of the values in :py:attr:`~__location_map__`. By
default, that means one of ``('json', 'query', 'querystring',
'form', 'headers', 'cookies', 'files', 'json_or_form')``.
:param str unknown: A value to pass for ``unknown`` when calling the
schema's ``load`` method (marshmallow 3 only).
:param callable validate: Validation function or list of validation functions
that receives the dictionary of parsed arguments. Validator either returns a
boolean or raises a :exc:`ValidationError`.
Expand All @@ -248,6 +258,10 @@ def parse(
"""
req = req if req is not None else self.get_default_request()
location = location or self.location
unknown = unknown or self.unknown
load_kwargs = (
{"unknown": unknown} if MARSHMALLOW_VERSION_INFO[0] >= 3 and unknown else {}
)
if req is None:
raise ValueError("Must pass req object")
data = None
Expand All @@ -257,7 +271,7 @@ def parse(
location_data = self._load_location_data(
schema=schema, req=req, location=location
)
result = schema.load(location_data)
result = schema.load(location_data, **load_kwargs)
data = result.data if MARSHMALLOW_VERSION_INFO[0] < 3 else result
self._validate_arguments(data, validators)
except ma.exceptions.ValidationError as error:
Expand Down Expand Up @@ -307,6 +321,7 @@ def use_args(
req=None,
*,
location=None,
unknown=None,
as_kwargs=False,
validate=None,
error_status_code=None,
Expand All @@ -325,6 +340,8 @@ def greet(args):
of argname -> `marshmallow.fields.Field` pairs, or a callable
which accepts a request and returns a `marshmallow.Schema`.
:param str location: Where on the request to load values.
:param str unknown: A value to pass for ``unknown`` when calling the
schema's ``load`` method (marshmallow 3 only).
:param bool as_kwargs: Whether to insert arguments as keyword arguments.
:param callable validate: Validation function that receives the dictionary
of parsed arguments. If the function returns ``False``, the parser
Expand Down Expand Up @@ -356,6 +373,7 @@ def wrapper(*args, **kwargs):
argmap,
req=req_obj,
location=location,
unknown=unknown,
validate=validate,
error_status_code=error_status_code,
error_headers=error_headers,
Expand Down
4 changes: 4 additions & 0 deletions src/webargs/pyramidparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def use_args(
req=None,
*,
location=core.Parser.DEFAULT_LOCATION,
unknown=None,
as_kwargs=False,
validate=None,
error_status_code=None,
Expand All @@ -127,6 +128,8 @@ def use_args(
which accepts a request and returns a `marshmallow.Schema`.
:param req: The request object to parse. Pulled off of the view by default.
:param str location: Where on the request to load values.
:param str unknown: A value to pass for ``unknown`` when calling the
schema's ``load`` method (marshmallow 3 only).
:param bool as_kwargs: Whether to insert arguments as keyword arguments.
:param callable validate: Validation function that receives the dictionary
of parsed arguments. If the function returns ``False``, the parser
Expand Down Expand Up @@ -155,6 +158,7 @@ def wrapper(obj, *args, **kwargs):
argmap,
req=request,
location=location,
unknown=unknown,
validate=validate,
error_status_code=error_status_code,
error_headers=error_headers,
Expand Down
78 changes: 63 additions & 15 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ def test_parse(parser, web_request):
@pytest.mark.skipif(
MARSHMALLOW_VERSION_INFO[0] < 3, reason="unknown=... added in marshmallow3"
)
def test_parse_with_unknown_behavior_specified(parser, web_request):
@pytest.mark.parametrize(
"set_location",
["schema_instance", "parse_call", "parser_default", "parser_class_default"],
)
def test_parse_with_unknown_behavior_specified(parser, web_request, set_location):
# This is new in webargs 6.x ; it's the way you can "get back" the behavior
# of webargs 5.x in which extra args are ignored
from marshmallow import EXCLUDE, INCLUDE, RAISE
Expand All @@ -119,17 +123,65 @@ class CustomSchema(Schema):
username = fields.Field()
password = fields.Field()

def parse_with_desired_behavior(value):
if set_location == "schema_instance":
if value is not None:
return parser.parse(CustomSchema(unknown=value), web_request)
else:
return parser.parse(CustomSchema(), web_request)
elif set_location == "parse_call":
return parser.parse(CustomSchema(), web_request, unknown=value)
elif set_location == "parser_default":
parser.unknown = value
return parser.parse(CustomSchema(), web_request)
elif set_location == "parser_class_default":

class CustomParser(MockRequestParser):
DEFAULT_UNKNOWN = value

return CustomParser().parse(CustomSchema(), web_request)
else:
raise NotImplementedError

# with no unknown setting or unknown=RAISE, it blows up
with pytest.raises(ValidationError, match="Unknown field."):
parser.parse(CustomSchema(), web_request)
parse_with_desired_behavior(None)
with pytest.raises(ValidationError, match="Unknown field."):
parser.parse(CustomSchema(unknown=RAISE), web_request)
parse_with_desired_behavior(RAISE)

# with unknown=EXCLUDE the data is omitted
ret = parser.parse(CustomSchema(unknown=EXCLUDE), web_request)
ret = parse_with_desired_behavior(EXCLUDE)
assert {"username": 42, "password": 42} == ret
# with unknown=INCLUDE it is added even though it isn't part of the schema
ret = parser.parse(CustomSchema(unknown=INCLUDE), web_request)
ret = parse_with_desired_behavior(INCLUDE)
assert {"username": 42, "password": 42, "fjords": 42} == ret


@pytest.mark.skipif(
MARSHMALLOW_VERSION_INFO[0] < 3, reason="unknown=... added in marshmallow3"
)
def test_parse_with_explicit_unknown_overrides_schema(parser, web_request):
# this test ensures that if you specify unknown=... in your parse call (or
# use_args) it takes precedence over a setting in the schema object
from marshmallow import EXCLUDE, INCLUDE, RAISE

web_request.json = {"username": 42, "password": 42, "fjords": 42}

class CustomSchema(Schema):
username = fields.Field()
password = fields.Field()

# setting RAISE in the parse call overrides schema setting
with pytest.raises(ValidationError, match="Unknown field."):
parser.parse(CustomSchema(unknown=EXCLUDE), web_request, unknown=RAISE)
with pytest.raises(ValidationError, match="Unknown field."):
parser.parse(CustomSchema(unknown=INCLUDE), web_request, unknown=RAISE)

# and the reverse -- setting EXCLUDE or INCLUDE in the parse call overrides
# a schema with RAISE already set
ret = parser.parse(CustomSchema(unknown=RAISE), web_request, unknown=EXCLUDE)
assert {"username": 42, "password": 42} == ret
ret = parser.parse(CustomSchema(unknown=RAISE), web_request, unknown=INCLUDE)
assert {"username": 42, "password": 42, "fjords": 42} == ret


Expand Down Expand Up @@ -756,22 +808,18 @@ def test_warning_raised_if_schema_is_not_in_strict_mode(self, web_request, parse
assert "strict=True" in str(warning.message)

def test_use_kwargs_stacked(self, web_request, parser):
parse_kwargs = {}
if MARSHMALLOW_VERSION_INFO[0] >= 3:
from marshmallow import EXCLUDE

class PageSchema(Schema):
page = fields.Int()

pageschema = PageSchema(unknown=EXCLUDE)
userschema = self.UserSchema(unknown=EXCLUDE)
else:
pageschema = {"page": fields.Int()}
userschema = self.UserSchema(**strict_kwargs)
parse_kwargs = {"unknown": EXCLUDE}

web_request.json = {"email": "[email protected]", "password": "bar", "page": 42}

@parser.use_kwargs(pageschema, web_request)
@parser.use_kwargs(userschema, web_request)
@parser.use_kwargs({"page": fields.Int()}, web_request, **parse_kwargs)
@parser.use_kwargs(
self.UserSchema(**strict_kwargs), web_request, **parse_kwargs
)
def viewfunc(email, password, page):
return {"email": email, "password": password, "page": page}

Expand Down