diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f72d2b00..43145b53 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,51 @@ Changelog --------- +6.2.0 (Unreleased) +****************** + +Features: + +* Add a new ``unknown`` parameter to ``Parser.parse``, ``Parser.use_args``, and + ``Parser.use_kwargs``. When set, it will be passed to the ``Schema.load`` + call. If set to ``None`` (the default), no value is passed, so the schema's + ``unknown`` behavior is used. + +This allows usages like + +.. code-block:: python + + import marshmallow as ma + + # marshmallow 3 only, for use of ``unknown`` and ``EXCLUDE`` + @parser.use_kwargs( + {"q1": ma.fields.Int(), "q2": ma.fields.Int()}, location="query", unknown=ma.EXCLUDE + ) + def foo(q1, q2): + ... + +* Add the ability to set defaults for ``unknown`` on either a Parser instance + or Parser class. Set ``Parser.DEFAULT_UNKNOWN`` on a parser class to apply a value + to any new parser instances created from that class, or set ``unknown`` during + ``Parser`` initialization. + +Usages are varied, but include + +.. code-block:: python + + import marshmallow as ma + from webargs.flaskparser import FlaskParser + + parser = FlaskParser(unknown=ma.INCLUDE) + + # as well as... + class MyParser(FlaskParser): + DEFAULT_UNKNOWN = ma.INCLUDE + + + parser = MyParser() + + 6.1.0 (2020-04-05) ****************** diff --git a/src/webargs/asyncparser.py b/src/webargs/asyncparser.py index 1ba77c70..914dd809 100644 --- a/src/webargs/asyncparser.py +++ b/src/webargs/asyncparser.py @@ -9,6 +9,7 @@ from marshmallow.fields import Field import marshmallow as ma +from webargs.compat import MARSHMALLOW_VERSION_INFO from webargs import core Request = typing.TypeVar("Request") @@ -28,6 +29,7 @@ async def parse( req: Request = None, *, location: str = None, + unknown: str = None, validate: Validate = None, error_status_code: typing.Union[int, None] = None, error_headers: typing.Union[typing.Mapping[str, str], None] = None @@ -38,6 +40,10 @@ async def parse( """ req = req if req is not None else self.get_default_request() location = location or self.location + unknown = unknown or self.unknown + load_kwargs = ( + {"unknown": unknown} if MARSHMALLOW_VERSION_INFO[0] >= 3 and unknown else {} + ) if req is None: raise ValueError("Must pass req object") data = None @@ -47,7 +53,7 @@ async def parse( location_data = await self._load_location_data( schema=schema, req=req, location=location ) - result = schema.load(location_data) + result = schema.load(location_data, **load_kwargs) data = result.data if core.MARSHMALLOW_VERSION_INFO[0] < 3 else result self._validate_arguments(data, validators) except ma.exceptions.ValidationError as error: @@ -111,6 +117,7 @@ def use_args( req: typing.Optional[Request] = None, *, location: str = None, + unknown=None, as_kwargs: bool = False, validate: Validate = None, error_status_code: typing.Optional[int] = None, @@ -143,6 +150,7 @@ async def wrapper(*args, **kwargs): argmap, req=req_obj, location=location, + unknown=unknown, validate=validate, error_status_code=error_status_code, error_headers=error_headers, @@ -165,6 +173,7 @@ def wrapper(*args, **kwargs): argmap, req=req_obj, location=location, + unknown=unknown, validate=validate, error_status_code=error_status_code, error_headers=error_headers, diff --git a/src/webargs/core.py b/src/webargs/core.py index d242a931..0dc96a06 100644 --- a/src/webargs/core.py +++ b/src/webargs/core.py @@ -101,11 +101,15 @@ class Parser: etc. :param str location: Default location to use for data + :param str unknown: Default value for ``unknown`` in ``parse``, + ``use_args``, and ``use_kwargs`` :param callable error_handler: Custom error handler function. """ #: Default location to check for data DEFAULT_LOCATION = "json" + #: Default value to use for 'unknown' on schema load + DEFAULT_UNKNOWN = None #: The marshmallow Schema class to use when creating new schemas DEFAULT_SCHEMA_CLASS = ma.Schema #: Default status code to return for validation errors @@ -125,10 +129,13 @@ class Parser: "json_or_form": "load_json_or_form", } - def __init__(self, location=None, *, error_handler=None, schema_class=None): + def __init__( + self, location=None, *, unknown=None, error_handler=None, schema_class=None + ): self.location = location or self.DEFAULT_LOCATION self.error_callback = _callable_or_raise(error_handler) self.schema_class = schema_class or self.DEFAULT_SCHEMA_CLASS + self.unknown = unknown or self.DEFAULT_UNKNOWN def _get_loader(self, location): """Get the loader function for the given location. @@ -222,6 +229,7 @@ def parse( req=None, *, location=None, + unknown=None, validate=None, error_status_code=None, error_headers=None @@ -236,6 +244,8 @@ def parse( Can be any of the values in :py:attr:`~__location_map__`. By default, that means one of ``('json', 'query', 'querystring', 'form', 'headers', 'cookies', 'files', 'json_or_form')``. + :param str unknown: A value to pass for ``unknown`` when calling the + schema's ``load`` method (marshmallow 3 only). :param callable validate: Validation function or list of validation functions that receives the dictionary of parsed arguments. Validator either returns a boolean or raises a :exc:`ValidationError`. @@ -248,6 +258,10 @@ def parse( """ req = req if req is not None else self.get_default_request() location = location or self.location + unknown = unknown or self.unknown + load_kwargs = ( + {"unknown": unknown} if MARSHMALLOW_VERSION_INFO[0] >= 3 and unknown else {} + ) if req is None: raise ValueError("Must pass req object") data = None @@ -257,7 +271,7 @@ def parse( location_data = self._load_location_data( schema=schema, req=req, location=location ) - result = schema.load(location_data) + result = schema.load(location_data, **load_kwargs) data = result.data if MARSHMALLOW_VERSION_INFO[0] < 3 else result self._validate_arguments(data, validators) except ma.exceptions.ValidationError as error: @@ -307,6 +321,7 @@ def use_args( req=None, *, location=None, + unknown=None, as_kwargs=False, validate=None, error_status_code=None, @@ -325,6 +340,8 @@ def greet(args): of argname -> `marshmallow.fields.Field` pairs, or a callable which accepts a request and returns a `marshmallow.Schema`. :param str location: Where on the request to load values. + :param str unknown: A value to pass for ``unknown`` when calling the + schema's ``load`` method (marshmallow 3 only). :param bool as_kwargs: Whether to insert arguments as keyword arguments. :param callable validate: Validation function that receives the dictionary of parsed arguments. If the function returns ``False``, the parser @@ -356,6 +373,7 @@ def wrapper(*args, **kwargs): argmap, req=req_obj, location=location, + unknown=unknown, validate=validate, error_status_code=error_status_code, error_headers=error_headers, diff --git a/src/webargs/pyramidparser.py b/src/webargs/pyramidparser.py index 7f9d5c01..91018f41 100644 --- a/src/webargs/pyramidparser.py +++ b/src/webargs/pyramidparser.py @@ -113,6 +113,7 @@ def use_args( req=None, *, location=core.Parser.DEFAULT_LOCATION, + unknown=None, as_kwargs=False, validate=None, error_status_code=None, @@ -127,6 +128,8 @@ def use_args( which accepts a request and returns a `marshmallow.Schema`. :param req: The request object to parse. Pulled off of the view by default. :param str location: Where on the request to load values. + :param str unknown: A value to pass for ``unknown`` when calling the + schema's ``load`` method (marshmallow 3 only). :param bool as_kwargs: Whether to insert arguments as keyword arguments. :param callable validate: Validation function that receives the dictionary of parsed arguments. If the function returns ``False``, the parser @@ -155,6 +158,7 @@ def wrapper(obj, *args, **kwargs): argmap, req=request, location=location, + unknown=unknown, validate=validate, error_status_code=error_status_code, error_headers=error_headers, diff --git a/tests/test_core.py b/tests/test_core.py index a9a32e98..1995636d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -108,7 +108,11 @@ def test_parse(parser, web_request): @pytest.mark.skipif( MARSHMALLOW_VERSION_INFO[0] < 3, reason="unknown=... added in marshmallow3" ) -def test_parse_with_unknown_behavior_specified(parser, web_request): +@pytest.mark.parametrize( + "set_location", + ["schema_instance", "parse_call", "parser_default", "parser_class_default"], +) +def test_parse_with_unknown_behavior_specified(parser, web_request, set_location): # This is new in webargs 6.x ; it's the way you can "get back" the behavior # of webargs 5.x in which extra args are ignored from marshmallow import EXCLUDE, INCLUDE, RAISE @@ -119,17 +123,65 @@ class CustomSchema(Schema): username = fields.Field() password = fields.Field() + def parse_with_desired_behavior(value): + if set_location == "schema_instance": + if value is not None: + return parser.parse(CustomSchema(unknown=value), web_request) + else: + return parser.parse(CustomSchema(), web_request) + elif set_location == "parse_call": + return parser.parse(CustomSchema(), web_request, unknown=value) + elif set_location == "parser_default": + parser.unknown = value + return parser.parse(CustomSchema(), web_request) + elif set_location == "parser_class_default": + + class CustomParser(MockRequestParser): + DEFAULT_UNKNOWN = value + + return CustomParser().parse(CustomSchema(), web_request) + else: + raise NotImplementedError + # with no unknown setting or unknown=RAISE, it blows up with pytest.raises(ValidationError, match="Unknown field."): - parser.parse(CustomSchema(), web_request) + parse_with_desired_behavior(None) with pytest.raises(ValidationError, match="Unknown field."): - parser.parse(CustomSchema(unknown=RAISE), web_request) + parse_with_desired_behavior(RAISE) # with unknown=EXCLUDE the data is omitted - ret = parser.parse(CustomSchema(unknown=EXCLUDE), web_request) + ret = parse_with_desired_behavior(EXCLUDE) assert {"username": 42, "password": 42} == ret # with unknown=INCLUDE it is added even though it isn't part of the schema - ret = parser.parse(CustomSchema(unknown=INCLUDE), web_request) + ret = parse_with_desired_behavior(INCLUDE) + assert {"username": 42, "password": 42, "fjords": 42} == ret + + +@pytest.mark.skipif( + MARSHMALLOW_VERSION_INFO[0] < 3, reason="unknown=... added in marshmallow3" +) +def test_parse_with_explicit_unknown_overrides_schema(parser, web_request): + # this test ensures that if you specify unknown=... in your parse call (or + # use_args) it takes precedence over a setting in the schema object + from marshmallow import EXCLUDE, INCLUDE, RAISE + + web_request.json = {"username": 42, "password": 42, "fjords": 42} + + class CustomSchema(Schema): + username = fields.Field() + password = fields.Field() + + # setting RAISE in the parse call overrides schema setting + with pytest.raises(ValidationError, match="Unknown field."): + parser.parse(CustomSchema(unknown=EXCLUDE), web_request, unknown=RAISE) + with pytest.raises(ValidationError, match="Unknown field."): + parser.parse(CustomSchema(unknown=INCLUDE), web_request, unknown=RAISE) + + # and the reverse -- setting EXCLUDE or INCLUDE in the parse call overrides + # a schema with RAISE already set + ret = parser.parse(CustomSchema(unknown=RAISE), web_request, unknown=EXCLUDE) + assert {"username": 42, "password": 42} == ret + ret = parser.parse(CustomSchema(unknown=RAISE), web_request, unknown=INCLUDE) assert {"username": 42, "password": 42, "fjords": 42} == ret @@ -756,22 +808,18 @@ def test_warning_raised_if_schema_is_not_in_strict_mode(self, web_request, parse assert "strict=True" in str(warning.message) def test_use_kwargs_stacked(self, web_request, parser): + parse_kwargs = {} if MARSHMALLOW_VERSION_INFO[0] >= 3: from marshmallow import EXCLUDE - class PageSchema(Schema): - page = fields.Int() - - pageschema = PageSchema(unknown=EXCLUDE) - userschema = self.UserSchema(unknown=EXCLUDE) - else: - pageschema = {"page": fields.Int()} - userschema = self.UserSchema(**strict_kwargs) + parse_kwargs = {"unknown": EXCLUDE} web_request.json = {"email": "foo@bar.com", "password": "bar", "page": 42} - @parser.use_kwargs(pageschema, web_request) - @parser.use_kwargs(userschema, web_request) + @parser.use_kwargs({"page": fields.Int()}, web_request, **parse_kwargs) + @parser.use_kwargs( + self.UserSchema(**strict_kwargs), web_request, **parse_kwargs + ) def viewfunc(email, password, page): return {"email": email, "password": password, "page": page}