Skip to content

Commit

Permalink
Add quote parameter to specify which character should be percent-en…
Browse files Browse the repository at this point in the history
…coded in path and query (#405)
  • Loading branch information
lexiforest authored Oct 8, 2024
1 parent 967fdc2 commit 93e551d
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 58 deletions.
10 changes: 8 additions & 2 deletions curl_cffi/requests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from functools import partial
from io import BytesIO
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

from ..const import CurlHttpVersion, CurlWsFlag
from ..curl import CurlMime
Expand Down Expand Up @@ -68,6 +68,7 @@ def request(
thread: Optional[ThreadType] = None,
default_headers: Optional[bool] = None,
default_encoding: Union[str, Callable[[bytes], str]] = "utf-8",
quote: Union[str, Literal[False]] = "",
curl_options: Optional[dict] = None,
http_version: Optional[CurlHttpVersion] = None,
debug: bool = False,
Expand Down Expand Up @@ -111,7 +112,11 @@ def request(
choices: eventlet, gevent.
default_headers: whether to set default browser headers when impersonating.
default_encoding: encoding for decoding response content if charset is not found in headers.
Defaults to "utf-8". Can be set to a callable for automatic detection.
Defaults to "utf-8". Can be set to a callable for automatic detection.
quote: Set characters to be quoted, i.e. percent-encoded. Default safe string
is ``!#$%&'()*+,/:;=?@[]~``. If set to a string, its characters will be removed
from the safe string, and thus quoted. If set to False, the URL will be kept as is,
without any automatic percent-encoding; you must encode the URL yourself.
curl_options: extra curl options to use.
http_version: limiting http version, defaults to http2.
debug: print extra curl debug info.
Expand Down Expand Up @@ -151,6 +156,7 @@ def request(
extra_fp=extra_fp,
default_headers=default_headers,
default_encoding=default_encoding,
quote=quote,
http_version=http_version,
interface=interface,
cert=cert,
Expand Down
73 changes: 47 additions & 26 deletions curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,30 @@ def _is_absolute_url(url: str) -> bool:
return bool(parsed_url.scheme and parsed_url.hostname)


def _update_url_params(url: str, *params_list: Union[Dict, List, Tuple, None]) -> str:
SAFE_CHARS = set("!#$%&'()*+,/:;=?@[]~")


def _quote_path_and_params(url: str, quote_str: str = ""):
safe = "".join(SAFE_CHARS - set(quote_str))
parsed_url = urlparse(url)
parsed_get_args = parse_qsl(parsed_url.query)
encoded_get_args = urlencode(parsed_get_args, doseq=True, safe=safe)
return ParseResult(
parsed_url.scheme,
parsed_url.netloc,
quote(parsed_url.path, safe=safe),
parsed_url.params,
encoded_get_args,
parsed_url.fragment,
).geturl()


def _update_url_params(url: str, params: Union[Dict, List, Tuple]) -> str:
"""Add URL query params to provided URL being aware of existing.
Parameters:
url: string of target URL
params: list of dict or list containing requested params to be added
params: dict containing requested params to be added
Returns:
string with updated URL
Expand All @@ -126,27 +144,20 @@ def _update_url_params(url: str, *params_list: Union[Dict, List, Tuple, None]) -
parsed_get_args = parse_qsl(parsed_url.query)

# Merging URL arguments dict with new params
for params in params_list:
if not params:
continue

# Check the args appearance count of keys
old_args_counter = Counter(x[0] for x in parsed_get_args)
if isinstance(params, dict):
params = list(params.items())
new_args_counter = Counter(x[0] for x in params)

for key, value in params:
# Bool and dict values should be converted to json-friendly values
if isinstance(value, (bool, dict)):
value = dumps(value)

# k:v is 1-to-1 mapping, we have to search and update it, e.g. k=v
if old_args_counter.get(key) == 1 and new_args_counter.get(key) == 1:
parsed_get_args = [(x if x[0] != key else (key, value)) for x in parsed_get_args]
# k:v is 1-to-list mapping, simply append them, e.g. k=v1&k=v2
else:
parsed_get_args.append((key, value))
old_args_counter = Counter(x[0] for x in parsed_get_args)
if isinstance(params, dict):
params = list(params.items())
new_args_counter = Counter(x[0] for x in params)
for key, value in params:
# Bool and Dict values should be converted to json-friendly values
# you may throw this part away if you don't like it :)
if isinstance(value, (bool, dict)):
value = dumps(value)
# 1 to 1 mapping, we have to search and update it.
if old_args_counter.get(key) == 1 and new_args_counter.get(key) == 1:
parsed_get_args = [(x if x[0] != key else (key, value)) for x in parsed_get_args]
else:
parsed_get_args.append((key, value))

# Converting URL argument to proper query string
encoded_get_args = urlencode(parsed_get_args, doseq=True)
Expand All @@ -156,7 +167,7 @@ def _update_url_params(url: str, *params_list: Union[Dict, List, Tuple, None]) -
new_url = ParseResult(
parsed_url.scheme,
parsed_url.netloc,
quote(parsed_url.path),
parsed_url.path,
parsed_url.params,
encoded_get_args,
parsed_url.fragment,
Expand Down Expand Up @@ -390,6 +401,7 @@ def _set_curl_options(
akamai: Optional[str] = None,
extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]] = None,
default_headers: Optional[bool] = None,
quote: Union[str, Literal[False]] = "",
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
cert: Optional[Union[str, Tuple[str, str]]] = None,
Expand All @@ -411,10 +423,15 @@ def _set_curl_options(
if method == "HEAD":
c.setopt(CurlOpt.NOBODY, 1)

# url, always unquote and re-quote
url = _update_url_params(url, self.params, params)
# url
if self.params:
url = _update_url_params(url, self.params)
if params:
url = _update_url_params(url, params)
if self.base_url:
url = urljoin(self.base_url, url)
if quote is not False:
url = _quote_path_and_params(url, quote_str=quote)
c.setopt(CurlOpt.URL, url.encode())

# data/body/json
Expand Down Expand Up @@ -939,6 +956,7 @@ def request(
extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]] = None,
default_headers: Optional[bool] = None,
default_encoding: Union[str, Callable[[bytes], str]] = "utf-8",
quote: Union[str, Literal[False]] = "",
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
cert: Optional[Union[str, Tuple[str, str]]] = None,
Expand Down Expand Up @@ -983,6 +1001,7 @@ def request(
akamai=akamai,
extra_fp=extra_fp,
default_headers=default_headers,
quote=quote,
http_version=http_version,
interface=interface,
stream=stream,
Expand Down Expand Up @@ -1233,6 +1252,7 @@ async def request(
extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]] = None,
default_headers: Optional[bool] = None,
default_encoding: Union[str, Callable[[bytes], str]] = "utf-8",
quote: Union[str, Literal[False]] = "",
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
cert: Optional[Union[str, Tuple[str, str]]] = None,
Expand Down Expand Up @@ -1270,6 +1290,7 @@ async def request(
akamai=akamai,
extra_fp=extra_fp,
default_headers=default_headers,
quote=quote,
http_version=http_version,
interface=interface,
stream=stream,
Expand Down
90 changes: 60 additions & 30 deletions tests/unittest/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,36 +118,6 @@ def test_options(server):
assert r.status_code == 200


def test_update_url_params():
    """Check ``_update_url_params`` with zero, one and two parameter mappings."""
    # Each case: (input URL, parameter mappings applied in order, expected URL)
    cases = [
        # should be quoted
        (
            "https://example.com/post.json?limit=1&tags=id:<1000&page=0",
            (),
            "https://example.com/post.json?limit=1&tags=id%3A%3C1000&page=0",
        ),
        # should not change
        (
            "https://example.com/post.json?limit=1&tags=foo&page=0",
            (),
            "https://example.com/post.json?limit=1&tags=foo&page=0",
        ),
        # update url params
        (
            "https://example.com/post.json?limit=1&tags=foo&page=0",
            ({"tags": "bar"},),
            "https://example.com/post.json?limit=1&tags=bar&page=0",
        ),
        # append url params
        (
            "https://example.com/post.json?limit=1&tags=foo&tags=a",
            ({"tags": "bar"},),
            "https://example.com/post.json?limit=1&tags=foo&tags=a&tags=bar",
        ),
        # update url params in a row (session params first, then request params)
        (
            "https://example.com/post.json?limit=1&tags=foo&page=0",
            ({"tags": "a"}, {"tags": "bar"}),
            "https://example.com/post.json?limit=1&tags=bar&page=0",
        ),
    ]
    for source_url, param_sets, expected_url in cases:
        assert _update_url_params(source_url, *param_sets) == expected_url


def test_params(server):
r = requests.get(str(server.url.copy_with(path="/echo_params")), params={"foo": "bar"})
assert r.content == b'{"params": {"foo": ["bar"]}}'
Expand Down Expand Up @@ -180,6 +150,66 @@ def test_update_params(server):
assert r.content == b'{"params": {"a": ["1", "2"], "foo": ["z", "1", "2"]}}'


def test_url_encode(server):
    """Check how request URLs are percent-encoded, including the ``quote`` option.

    See https://github.com/lexiforest/curl_cffi/issues/394
    """
    # FIXME: should use server.url, but it always encodes the URL.

    # Each case: (requested URL, extra keyword arguments, expected ``r.url``)
    cases = [
        # should not change
        ("http://127.0.0.1:8000/%2f%2f%2f", {}, "http://127.0.0.1:8000/%2f%2f%2f"),
        (
            "http://127.0.0.1:8000/imaginary-pagination:7",
            {},
            "http://127.0.0.1:8000/imaginary-pagination:7",
        ),
        (
            "http://127.0.0.1:8000/post.json?limit=1&tags=foo&page=0",
            {},
            "http://127.0.0.1:8000/post.json?limit=1&tags=foo&page=0",
        ),
        # A non-ASCII URL should be percent-encoded as a UTF-8 sequence ...
        (
            "http://127.0.0.1:8000/search?q=测试",
            {},
            "http://127.0.0.1:8000/search?q=%E6%B5%8B%E8%AF%95",
        ),
        # ... and an already encoded one should stay as it is.
        (
            "http://127.0.0.1:8000/search?q=%E6%B5%8B%E8%AF%95",
            {},
            "http://127.0.0.1:8000/search?q=%E6%B5%8B%E8%AF%95",
        ),
        # should be quoted
        (
            "http://127.0.0.1:8000/e x a m p l e",
            {},
            "http://127.0.0.1:8000/e%20x%20a%20m%20p%20l%20e",
        ),
        # There are discussions asking how to prevent requests from quoting
        # unwanted parts, like `:`. So quoting such characters is opt-in: the
        # caller states explicitly which characters should be quoted.
        #
        # See:
        # 1. https://stackoverflow.com/q/57365497/1061155
        # 2. https://stackoverflow.com/q/23496750/1061155
        (
            "http://127.0.0.1:8000/imaginary-pagination:7",
            {"quote": ":"},
            "http://127.0.0.1:8000/imaginary-pagination%3A7",
        ),
        (
            "http://127.0.0.1:8000/post.json?limit=1&tags=id:<1000&page=0",
            {"quote": ":"},
            "http://127.0.0.1:8000/post.json?limit=1&tags=id%3A%3C1000&page=0",
        ),
        # Quoted by default ...
        ("http://127.0.0.1:8000/query={}", {}, "http://127.0.0.1:8000/query=%7B%7D"),
        # ... but not quoted at all with quote=False.
        ("http://127.0.0.1:8000/query={}", {"quote": False}, "http://127.0.0.1:8000/query={}"),
    ]
    for target, options, expected in cases:
        r = requests.get(target, **options)
        assert r.url == expected


def test_headers(server):
r = requests.get(str(server.url.copy_with(path="/echo_headers")), headers={"foo": "bar"})
headers = r.json()
Expand Down

0 comments on commit 93e551d

Please sign in to comment.