diff --git a/falcon/testing/client.py b/falcon/testing/client.py index 1521128e7..e4dcf5b13 100644 --- a/falcon/testing/client.py +++ b/falcon/testing/client.py @@ -20,16 +20,18 @@ import asyncio import datetime as dt +import hashlib import inspect import json as json_module import time +from urllib.parse import urlencode from typing import Dict, Optional, Sequence, Union import warnings import wsgiref.validate from falcon.asgi_spec import ScopeType from falcon.constants import COMBINED_METHODS -from falcon.constants import MEDIA_JSON +from falcon.constants import MEDIA_JSON, MEDIA_MULTIPART, MEDIA_URLENCODED from falcon.errors import CompatibilityError from falcon.testing import helpers from falcon.testing.srmock import StartResponseMock @@ -90,7 +92,7 @@ class Cookie: or ``None`` if not specified. max_age (int): The lifetime of the cookie in seconds, or ``None`` if not specified. - secure (bool): Whether or not the cookie may only only be + secure (bool): Whether or not the cookie may only be transmitted from the client via HTTPS. http_only (bool): Whether or not the cookie may only be included in unscripted requests from the client. @@ -442,6 +444,7 @@ def simulate_request( content_type=None, body=None, json=None, + form=None, file_wrapper=None, wsgierrors=None, params=None, @@ -579,6 +582,7 @@ def simulate_request( content_type=content_type, body=body, json=json, + form=form, params=params, params_csv=params_csv, protocol=protocol, @@ -602,6 +606,7 @@ def simulate_request( headers, body, json, + form, extras, ) @@ -626,7 +631,7 @@ def simulate_request( # NOTE(vytas): Even given the duct tape nature of overriding # arbitrary environ variables, changing the method can potentially # be very confusing, particularly when using specialized - # simulate_get/post/patch etc methods. + # simulate_get/post/patch etc. methods. raise ValueError( 'WSGI environ extras may not override the request method. ' 'Please use the method parameter.' @@ -655,6 +660,7 @@ async def _simulate_request_asgi( content_type=None, body=None, json=None, + form=None, params=None, params_csv=True, protocol='http', @@ -739,6 +745,9 @@ async def _simulate_request_asgi( overrides `body` and sets the Content-Type header to ``'application/json'``, overriding any value specified by either the `content_type` or `headers` arguments. + form (dict): A form to submit as the request's body + (default: ``None``). If present, overrides `body`, and sets the + Content-Type header. host(str): A string to use for the hostname part of the fully qualified request URL (default: 'falconframework.org') remote_addr (str): A string to use as the remote IP address for the @@ -777,6 +786,7 @@ async def _simulate_request_asgi( headers, body, json, + form, extras, ) @@ -2143,8 +2153,70 @@ async def __aexit__(self, exc_type, exc, tb): await self._task_req +def _encode_form(form: dict) -> tuple: + """Build the body for a URL-encoded or multipart form. + + This utility method accepts two types of forms: a simple dict mapping + string keys to values will get URL-encoded, whereas if any value is a list + of two or three items, these will be treated as (filename, content) or + (filename, content, content_type), and encoded as a multipart form. + + Returns: (encoded body bytes, Content-Type header) + """ + form_items = form.items() if isinstance(form, dict) else form + + if not any(isinstance(value, (list, tuple)) for _, value in form_items): + # URL-encoded form + return urlencode(form, doseq=True).encode(), MEDIA_URLENCODED + + # Encode multipart form + body = [b''] + + for name, value in form_items: + data = value + filename = None + content_type = 'text/plain' + + if isinstance(value, (list, tuple)): + try: + filename, data = value + content_type = 'application/octet-stream' + except ValueError: + filename, data, content_type = value + if isinstance(data, str): + data = data.encode() + elif not isinstance(data, bytes): + # Assume a file-like object + data = data.read() + + headers = f'Content-Disposition: form-data; name="{name}"' + if filename: + headers += f'; filename="{filename}"' + headers += f'\r\nContent-Type: {content_type}\r\n\r\n' + + body.append(headers.encode() + data + b'\r\n') + + checksum = hashlib.sha256() + for chunk in body: + checksum.update(chunk) + boundary = checksum.hexdigest() + + encoded = f'--{boundary}\r\n'.encode().join(body) + encoded += f'--{boundary}--\r\n'.encode() + return encoded, f'{MEDIA_MULTIPART}; boundary={boundary}' + + def _prepare_sim_args( - path, query_string, params, params_csv, content_type, headers, body, json, extras + path, + query_string, + params, + params_csv, + content_type, + headers, + body, + json, + form, + extras, ): if not path.startswith('/'): raise ValueError("path must start with '/'") @@ -2178,6 +2250,11 @@ def _prepare_sim_args( headers = headers or {} headers['Content-Type'] = MEDIA_JSON + elif form is not None: + body, content_type = _encode_form(form) + headers = headers or {} + headers['Content-Type'] = content_type + return path, query_string, headers, body, extras diff --git a/tests/test_media_multipart.py b/tests/test_media_multipart.py index 277c0a567..57a4ccf9c 100644 --- a/tests/test_media_multipart.py +++ b/tests/test_media_multipart.py @@ -845,3 +845,42 @@ async def deserialize_async(self, stream, content_type, content_length): assert resp.status_code == 200 assert resp.json == ['', '0x48'] + + +def test_simulate_form(client): + resp = client.simulate_post( + '/submit', + form={ + 'checked': 'true', + 'file': ('test.txt', b'Hello, World!\n', 'text/plain'), + 'another': ('test.dat', io.BytesIO(b'1\n2\n3\n')), + }, + ) + + assert resp.status_code == 200 + assert resp.json == [ + { + 'content_type': 'text/plain', + 'data': 'true', + 'filename': None, + 'name': 'checked', + 'secure_filename': None, + 'text': 'true', + }, + { + 'content_type': 'text/plain', + 'data': 'Hello, World!\n', + 'filename': 'test.txt', + 'name': 'file', + 'secure_filename': 'test.txt', + 'text': 'Hello, World!\n', + }, + { + 'content_type': 'application/octet-stream', + 'data': '1\n2\n3\n', + 'filename': 'test.dat', + 'name': 'another', + 'secure_filename': 'test.dat', + 'text': None, + }, + ] diff --git a/tests/test_media_urlencoded.py b/tests/test_media_urlencoded.py index 4456096fa..dc7e2d055 100644 --- a/tests/test_media_urlencoded.py +++ b/tests/test_media_urlencoded.py @@ -81,3 +81,12 @@ def test_urlencoded_form(client, body, expected): headers={'Content-Type': 'application/x-www-form-urlencoded'}, ) assert resp.json == expected + + +@pytest.mark.parametrize( + 'form', [{}, {'a': '1', 'b': '2'}, (('a', '1'), ('b', '2'), ('c', '3'))] +) +def test_simulate_form(client, form): + resp = client.simulate_post('/media', form=form) + assert resp.status_code == 200 + assert resp.json == dict(form)