Skip to content

Commit

Permalink
Add customizable response objects (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
novitae authored Oct 6, 2024
1 parent a2b6f98 commit 81a5400
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 1 deletion.
12 changes: 11 additions & 1 deletion curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TypedDict,
Union,
cast,
Type,
)
from urllib.parse import ParseResult, parse_qsl, quote, unquote, urlencode, urljoin, urlparse

Expand Down Expand Up @@ -86,6 +87,7 @@ class BaseSessionParams(TypedDict, total=False):
debug: bool
interface: Optional[str]
cert: Optional[Union[str, Tuple[str, str]]]
response_class: Optional[Type[Response]]

else:
ProxySpec = Dict[str, str]
Expand Down Expand Up @@ -225,6 +227,7 @@ def __init__(
debug: bool = False,
interface: Optional[str] = None,
cert: Optional[Union[str, Tuple[str, str]]] = None,
response_class: Optional[Type[Response]] = None,
):
self.headers = Headers(headers)
self.cookies = Cookies(cookies)
Expand All @@ -249,6 +252,13 @@ def __init__(
self.interface = interface
self.cert = cert

if response_class is None:
response_class = Response
elif not issubclass(response_class, Response):
raise TypeError( "`response_class` must be a subclass of `curl_cffi.requests.models.Response`"
f" not of type `{response_class}`" )
self.response_class = response_class

if proxy and proxies:
raise TypeError("Cannot specify both 'proxy' and 'proxies'")
if proxy:
Expand Down Expand Up @@ -703,7 +713,7 @@ def qput(chunk):

def _parse_response(self, curl, buffer, header_buffer, default_encoding):
c = curl
rsp = Response(c)
rsp = self.response_class(c)
rsp.url = cast(bytes, c.getinfo(CurlInfo.EFFECTIVE_URL)).decode()
if buffer:
rsp.content = buffer.getvalue()
Expand Down
23 changes: 23 additions & 0 deletions examples/custom_response_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from curl_cffi import requests
from curl_cffi.curl import Curl, CurlInfo
from typing import cast

class CustomResponse(requests.Response):
def __init__(self, curl: Curl | None = None, request: requests.Request | None = None):
super().__init__(curl, request)
self.local_port = cast(int, curl.getinfo(CurlInfo.LOCAL_PORT))
self.connect_time = cast(float, curl.getinfo(CurlInfo.CONNECT_TIME))

@property
def status(self):
return self.status_code

def custom_method(self):
return "this is a custom method"

session = requests.Session(response_class=CustomResponse)
response: CustomResponse = session.get("http://example.com")
print(f"{response.status=}")
print(response.custom_method())
print(f"{response.local_port=}")
print(f"{response.connect_time=}")
25 changes: 25 additions & 0 deletions tests/integration/test_response_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
from curl_cffi import requests

def test_default_response():
response = requests.get("http://example.com")
assert type(response) == requests.Response
print(response.status_code)

class CustomResponse(requests.Response):
@property
def status(self):
return self.status_code

def test_custom_response():
session = requests.Session(response_class=CustomResponse)
response = session.get("http://example.com")
assert isinstance(response, CustomResponse)
assert hasattr(response, "status")
print(response.status)

class WrongTypeResponse: pass

def test_wrong_type_custom_response():
with pytest.raises(TypeError):
requests.Session(response_class=WrongTypeResponse)

0 comments on commit 81a5400

Please sign in to comment.