Skip to content

Commit

Permalink
Add base_url to BaseSession (#279)
Browse files Browse the repository at this point in the history
* Add `base_url` to `BaseSession`

* Add the tests for `base_url`

* Add more tests for `base_url`
  • Loading branch information
lebr0nli authored Mar 27, 2024
1 parent 5509112 commit be0fe03
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 1 deletion.
17 changes: 16 additions & 1 deletion curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Union,
cast,
)
from urllib.parse import ParseResult, parse_qsl, unquote, urlencode, urlparse
from urllib.parse import ParseResult, parse_qsl, unquote, urlencode, urljoin, urlparse

from .. import AsyncCurl, Curl, CurlError, CurlHttpVersion, CurlInfo, CurlOpt
from ..curl import CURL_WRITEFUNC_ERROR, CurlMime
Expand Down Expand Up @@ -103,6 +103,12 @@ class BrowserSpec:
# TODO


def _is_absolute_url(url: str) -> bool:
"""Check if the provided url is an absolute url"""
parsed_url = urlparse(url)
return bool(parsed_url.scheme and parsed_url.hostname)


def _update_url_params(url: str, params: Dict) -> str:
"""Add GET params to provided URL being aware of existing.
Expand Down Expand Up @@ -190,6 +196,7 @@ def __init__(
proxies: Optional[ProxySpec] = None,
proxy: Optional[str] = None,
proxy_auth: Optional[Tuple[str, str]] = None,
base_url: Optional[str] = None,
params: Optional[dict] = None,
verify: bool = True,
timeout: Union[float, Tuple[float, float]] = 30,
Expand All @@ -208,6 +215,7 @@ def __init__(
self.headers = Headers(headers)
self.cookies = Cookies(cookies)
self.auth = auth
self.base_url = base_url
self.params = params
self.verify = verify
self.timeout = timeout
Expand All @@ -230,6 +238,9 @@ def __init__(
self.proxies: ProxySpec = proxies or {}
self.proxy_auth = proxy_auth

if self.base_url and not _is_absolute_url(self.base_url):
raise ValueError("You need to provide an absolute url for 'base_url'")

self._closed = False

def _set_curl_options(
Expand Down Expand Up @@ -278,6 +289,8 @@ def _set_curl_options(
url = _update_url_params(url, self.params)
if params:
url = _update_url_params(url, params)
if self.base_url:
url = urljoin(self.base_url, url)
c.setopt(CurlOpt.URL, url.encode())

# data/body/json
Expand Down Expand Up @@ -617,6 +630,7 @@ def __init__(
proxies: dict of proxies to use, format: {"http": proxy_url, "https": proxy_url}.
proxy: proxy to use, format: "http://proxy_url". Cannot be used with the above parameter.
proxy_auth: HTTP basic auth for proxy, a tuple of (username, password).
base_url: absolute url to use for relative urls.
params: query string for the session.
verify: whether to verify https certs.
timeout: how many seconds to wait before giving up.
Expand Down Expand Up @@ -897,6 +911,7 @@ def __init__(
proxies: dict of proxies to use, format: {"http": proxy_url, "https": proxy_url}.
proxy: proxy to use, format: "http://proxy_url". Cannot be used with the above parameter.
proxy_auth: HTTP basic auth for proxy, a tuple of (username, password).
base_url: absolute url to use for relative urls.
params: query string for the session.
verify: whether to verify https certs.
timeout: how many seconds to wait before giving up.
Expand Down
29 changes: 29 additions & 0 deletions tests/unittest/test_async_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,35 @@ async def test_options(server):
assert r.status_code == 200


async def test_base_url(server):
async with AsyncSession(
base_url=str(server.url.copy_with(path="/a/b", params={"foo": "bar"}))
) as s:
# target path is empty
r = await s.get("")
assert r.url == s.base_url

# target path only has params
r = await s.get("", params={"hello": "world"})
assert r.url == str(server.url.copy_with(path="/a/b", params={"hello": "world"}))

# target path is a relative path without starting /
r = await s.get("x")
assert r.url == str(server.url.copy_with(path="/a/x"))
r = await s.get("x", params={"hello": "world"})
assert r.url == str(server.url.copy_with(path="/a/x", params={"hello": "world"}))

# target path is a relative path with starting /
r = await s.get("/x")
assert r.url == str(server.url.copy_with(path="/x"))
r = await s.get("/x", params={"hello": "world"})
assert r.url == str(server.url.copy_with(path="/x", params={"hello": "world"}))

# target path is an absolute url
r = await s.get(str(server.url.copy_with(path="/x/y")))
assert r.url == str(server.url.copy_with(path="/x/y"))


async def test_params(server):
async with AsyncSession() as s:
r = await s.get(str(server.url.copy_with(path="/echo_params")), params={"foo": "bar"})
Expand Down
28 changes: 28 additions & 0 deletions tests/unittest/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,34 @@ def test_session_options(server):
assert r.status_code == 200


def test_session_base_url(server):
s = requests.Session(base_url=str(server.url.copy_with(path="/a/b", params={"foo": "bar"})))

# target path is empty
r = s.get("")
assert r.url == s.base_url

# target path only has params
r = s.get("", params={"hello": "world"})
assert r.url == str(server.url.copy_with(path="/a/b", params={"hello": "world"}))

# target path is a relative path without starting /
r = s.get("x")
assert r.url == str(server.url.copy_with(path="/a/x"))
r = s.get("x", params={"hello": "world"})
assert r.url == str(server.url.copy_with(path="/a/x", params={"hello": "world"}))

# target path is a relative path with starting /
r = s.get("/x")
assert r.url == str(server.url.copy_with(path="/x"))
r = s.get("/x", params={"hello": "world"})
assert r.url == str(server.url.copy_with(path="/x", params={"hello": "world"}))

# target path is an absolute url
r = s.get(str(server.url.copy_with(path="/x/y")))
assert r.url == str(server.url.copy_with(path="/x/y"))


def test_session_update_parms(server):
s = requests.Session(params={"old": "day"})
r = s.get(str(server.url.copy_with(path="/echo_params")), params={"foo": "bar"})
Expand Down

0 comments on commit be0fe03

Please sign in to comment.