diff --git a/curl_cffi/requests/__init__.py b/curl_cffi/requests/__init__.py index 3f66a77..804dbe0 100644 --- a/curl_cffi/requests/__init__.py +++ b/curl_cffi/requests/__init__.py @@ -28,7 +28,7 @@ from functools import partial from io import BytesIO -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union from ..const import CurlHttpVersion, CurlWsFlag from ..curl import CurlMime @@ -68,6 +68,7 @@ def request( thread: Optional[ThreadType] = None, default_headers: Optional[bool] = None, default_encoding: Union[str, Callable[[bytes], str]] = "utf-8", + quote: Union[str, Literal[False]] = "", curl_options: Optional[dict] = None, http_version: Optional[CurlHttpVersion] = None, debug: bool = False, @@ -111,7 +112,11 @@ def request( choices: eventlet, gevent. default_headers: whether to set default browser headers when impersonating. default_encoding: encoding for decoding response content if charset is not found in headers. - Defaults to "utf-8". Can be set to a callable for automatic detection. + Defaults to "utf-8". Can be set to a callable for automatic detection. + quote: Set characters to be quoted, i.e. percent-encoded. Default safe string + is ``!#$%&'()*+,/:;=?@[]~``. If set to a string, its characters will be removed + from the safe string, and thus quoted. If set to False, the url will be kept as is, + without any automatic percent-encoding; you must encode the URL yourself. curl_options: extra curl options to use. http_version: limiting http version, defaults to http2. debug: print extra curl debug info. 
@@ -151,6 +156,7 @@ def request( extra_fp=extra_fp, default_headers=default_headers, default_encoding=default_encoding, + quote=quote, http_version=http_version, interface=interface, cert=cert, diff --git a/curl_cffi/requests/session.py b/curl_cffi/requests/session.py index 2a0d578..3130220 100644 --- a/curl_cffi/requests/session.py +++ b/curl_cffi/requests/session.py @@ -103,12 +103,30 @@ def _is_absolute_url(url: str) -> bool: return bool(parsed_url.scheme and parsed_url.hostname) -def _update_url_params(url: str, *params_list: Union[Dict, List, Tuple, None]) -> str: +SAFE_CHARS = set("!#$%&'()*+,/:;=?@[]~") + + +def _quote_path_and_params(url: str, quote_str: str = ""): + safe = "".join(SAFE_CHARS - set(quote_str)) + parsed_url = urlparse(url) + parsed_get_args = parse_qsl(parsed_url.query) + encoded_get_args = urlencode(parsed_get_args, doseq=True, safe=safe) + return ParseResult( + parsed_url.scheme, + parsed_url.netloc, + quote(parsed_url.path, safe=safe), + parsed_url.params, + encoded_get_args, + parsed_url.fragment, + ).geturl() + + +def _update_url_params(url: str, params: Union[Dict, List, Tuple]) -> str: """Add URL query params to provided URL being aware of existing. 
Parameters: url: string of target URL - params: list of dict or list containing requested params to be added + params: dict containing requested params to be added Returns: string with updated URL @@ -126,27 +144,20 @@ def _update_url_params(url: str, *params_list: Union[Dict, List, Tuple, None]) - parsed_get_args = parse_qsl(parsed_url.query) # Merging URL arguments dict with new params - for params in params_list: - if not params: - continue - - # Check the args appearance count of keys - old_args_counter = Counter(x[0] for x in parsed_get_args) - if isinstance(params, dict): - params = list(params.items()) - new_args_counter = Counter(x[0] for x in params) - - for key, value in params: - # Bool and dict values should be converted to json-friendly values - if isinstance(value, (bool, dict)): - value = dumps(value) - - # k:v is 1-to-1 mapping, we have to search and update it, e.g. k=v - if old_args_counter.get(key) == 1 and new_args_counter.get(key) == 1: - parsed_get_args = [(x if x[0] != key else (key, value)) for x in parsed_get_args] - # k:v is 1-to-list mapping, simply append them, e.g. k=v1&k=v2 - else: - parsed_get_args.append((key, value)) + old_args_counter = Counter(x[0] for x in parsed_get_args) + if isinstance(params, dict): + params = list(params.items()) + new_args_counter = Counter(x[0] for x in params) + for key, value in params: + # Bool and Dict values should be converted to json-friendly values + # you may throw this part away if you don't like it :) + if isinstance(value, (bool, dict)): + value = dumps(value) + # 1 to 1 mapping, we have to search and update it. 
+ if old_args_counter.get(key) == 1 and new_args_counter.get(key) == 1: + parsed_get_args = [(x if x[0] != key else (key, value)) for x in parsed_get_args] + else: + parsed_get_args.append((key, value)) # Converting URL argument to proper query string encoded_get_args = urlencode(parsed_get_args, doseq=True) @@ -156,7 +167,7 @@ def _update_url_params(url: str, *params_list: Union[Dict, List, Tuple, None]) - new_url = ParseResult( parsed_url.scheme, parsed_url.netloc, - quote(parsed_url.path), + parsed_url.path, parsed_url.params, encoded_get_args, parsed_url.fragment, @@ -390,6 +401,7 @@ def _set_curl_options( akamai: Optional[str] = None, extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]] = None, default_headers: Optional[bool] = None, + quote: Union[str, Literal[False]] = "", http_version: Optional[CurlHttpVersion] = None, interface: Optional[str] = None, cert: Optional[Union[str, Tuple[str, str]]] = None, @@ -411,10 +423,15 @@ def _set_curl_options( if method == "HEAD": c.setopt(CurlOpt.NOBODY, 1) - # url, always unquote and re-quote - url = _update_url_params(url, self.params, params) + # url + if self.params: + url = _update_url_params(url, self.params) + if params: + url = _update_url_params(url, params) if self.base_url: url = urljoin(self.base_url, url) + if quote is not False: + url = _quote_path_and_params(url, quote_str=quote) c.setopt(CurlOpt.URL, url.encode()) # data/body/json @@ -939,6 +956,7 @@ def request( extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]] = None, default_headers: Optional[bool] = None, default_encoding: Union[str, Callable[[bytes], str]] = "utf-8", + quote: Union[str, Literal[False]] = "", http_version: Optional[CurlHttpVersion] = None, interface: Optional[str] = None, cert: Optional[Union[str, Tuple[str, str]]] = None, @@ -983,6 +1001,7 @@ def request( akamai=akamai, extra_fp=extra_fp, default_headers=default_headers, + quote=quote, http_version=http_version, interface=interface, stream=stream, @@ -1233,6 +1252,7 @@ 
async def request( extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]] = None, default_headers: Optional[bool] = None, default_encoding: Union[str, Callable[[bytes], str]] = "utf-8", + quote: Union[str, Literal[False]] = "", http_version: Optional[CurlHttpVersion] = None, interface: Optional[str] = None, cert: Optional[Union[str, Tuple[str, str]]] = None, @@ -1270,6 +1290,7 @@ async def request( akamai=akamai, extra_fp=extra_fp, default_headers=default_headers, + quote=quote, http_version=http_version, interface=interface, stream=stream, diff --git a/tests/unittest/test_requests.py b/tests/unittest/test_requests.py index 556c50a..c28fb0e 100644 --- a/tests/unittest/test_requests.py +++ b/tests/unittest/test_requests.py @@ -118,36 +118,6 @@ def test_options(server): assert r.status_code == 200 -def test_update_url_params(): - # should be quoted - url = "https://example.com/post.json?limit=1&tags=id:<1000&page=0" - quoted = "https://example.com/post.json?limit=1&tags=id%3A%3C1000&page=0" - assert _update_url_params(url) == quoted - - # should not change - url = "https://example.com/post.json?limit=1&tags=foo&page=0" - assert _update_url_params(url) == url - - # update url params - url = "https://example.com/post.json?limit=1&tags=foo&page=0" - params = {"tags": "bar"} - updated_url = "https://example.com/post.json?limit=1&tags=bar&page=0" - assert _update_url_params(url, params) == updated_url - - # append url params - url = "https://example.com/post.json?limit=1&tags=foo&tags=a" - params = {"tags": "bar"} - updated_url = "https://example.com/post.json?limit=1&tags=foo&tags=a&tags=bar" - assert _update_url_params(url, params) == updated_url - - # update url params in a row - url = "https://example.com/post.json?limit=1&tags=foo&page=0" - session_params = {"tags": "a"} - request_params = {"tags": "bar"} - updated_url = "https://example.com/post.json?limit=1&tags=bar&page=0" - assert _update_url_params(url, session_params, request_params) == updated_url - - def 
test_params(server): r = requests.get(str(server.url.copy_with(path="/echo_params")), params={"foo": "bar"}) assert r.content == b'{"params": {"foo": ["bar"]}}' @@ -180,6 +150,66 @@ def test_update_params(server): assert r.content == b'{"params": {"a": ["1", "2"], "foo": ["z", "1", "2"]}}' +def test_url_encode(server): + # https://github.com/lexiforest/curl_cffi/issues/394 + + # FIXME: should use server.url, but it always encodes + + # should not change + url = "http://127.0.0.1:8000/%2f%2f%2f" + r = requests.get(str(url)) + assert r.url == str(url) + + url = "http://127.0.0.1:8000/imaginary-pagination:7" + r = requests.get(str(url)) + assert r.url == str(url) + + url = "http://127.0.0.1:8000/post.json?limit=1&tags=foo&page=0" + r = requests.get(str(url)) + assert r.url == url + + # Non-ASCII URL should be percent encoded as UTF-8 sequence + non_ascii_url = "http://127.0.0.1:8000/search?q=测试" + encoded_non_ascii_url = "http://127.0.0.1:8000/search?q=%E6%B5%8B%E8%AF%95" + + r = requests.get(non_ascii_url) + assert r.url == encoded_non_ascii_url + + r = requests.get(encoded_non_ascii_url) + assert r.url == encoded_non_ascii_url + + # should be quoted + url = "http://127.0.0.1:8000/e x a m p l e" + quoted = "http://127.0.0.1:8000/e%20x%20a%20m%20p%20l%20e" + r = requests.get(str(url)) + assert r.url == quoted + + # I have seen discussions that ask how to prevent requests from quoting unwanted + # parts, like `:`. So, let's make it explicit that you want to quote some chars. + # + # See: + # 1. 
https://stackoverflow.com/q/23496750/1061155 + + url = "http://127.0.0.1:8000/imaginary-pagination:7" + quoted = "http://127.0.0.1:8000/imaginary-pagination%3A7" + r = requests.get(url, quote=":") + assert r.url == quoted + + url = "http://127.0.0.1:8000/post.json?limit=1&tags=id:<1000&page=0" + quoted = "http://127.0.0.1:8000/post.json?limit=1&tags=id%3A%3C1000&page=0" + r = requests.get(url, quote=":") + assert r.url == quoted + + # Do not quote at all + url = "http://127.0.0.1:8000/query={}" + quoted = "http://127.0.0.1:8000/query=%7B%7D" + r = requests.get(url) + assert r.url == quoted + r = requests.get(url, quote=False) + assert r.url == url + + def test_headers(server): r = requests.get(str(server.url.copy_with(path="/echo_headers")), headers={"foo": "bar"}) headers = r.json()