diff --git a/curl_cffi/requests/session.py b/curl_cffi/requests/session.py index 998fc5d..01f3258 100644 --- a/curl_cffi/requests/session.py +++ b/curl_cffi/requests/session.py @@ -21,6 +21,7 @@ TypedDict, Union, cast, + Type, ) from urllib.parse import ParseResult, parse_qsl, quote, unquote, urlencode, urljoin, urlparse @@ -86,6 +87,7 @@ class BaseSessionParams(TypedDict, total=False): debug: bool interface: Optional[str] cert: Optional[Union[str, Tuple[str, str]]] + response_class: Optional[Type[Response]] else: ProxySpec = Dict[str, str] @@ -225,6 +227,7 @@ def __init__( debug: bool = False, interface: Optional[str] = None, cert: Optional[Union[str, Tuple[str, str]]] = None, + response_class: Optional[Type[Response]] = None, ): self.headers = Headers(headers) self.cookies = Cookies(cookies) @@ -249,6 +252,13 @@ def __init__( self.interface = interface self.cert = cert + if response_class is None: + response_class = Response + elif not issubclass(response_class, Response): + raise TypeError( "`response_class` must be a subclass of `curl_cffi.requests.models.Response`" + f" not of type `{response_class}`" ) + self.response_class = response_class + if proxy and proxies: raise TypeError("Cannot specify both 'proxy' and 'proxies'") if proxy: @@ -703,7 +713,7 @@ def qput(chunk): def _parse_response(self, curl, buffer, header_buffer, default_encoding): c = curl - rsp = Response(c) + rsp = self.response_class(c) rsp.url = cast(bytes, c.getinfo(CurlInfo.EFFECTIVE_URL)).decode() if buffer: rsp.content = buffer.getvalue() diff --git a/examples/custom_response_class.py b/examples/custom_response_class.py new file mode 100644 index 0000000..f367943 --- /dev/null +++ b/examples/custom_response_class.py @@ -0,0 +1,23 @@ +from curl_cffi import requests +from curl_cffi.curl import Curl, CurlInfo +from typing import cast + +class CustomResponse(requests.Response): + def __init__(self, curl: Curl | None = None, request: requests.Request | None = None): + super().__init__(curl, request) + self.local_port = cast(int, curl.getinfo(CurlInfo.LOCAL_PORT)) + self.connect_time = cast(float, curl.getinfo(CurlInfo.CONNECT_TIME)) + + @property + def status(self): + return self.status_code + + def custom_method(self): + return "this is a custom method" + +session = requests.Session(response_class=CustomResponse) +response: CustomResponse = session.get("http://example.com") +print(f"{response.status=}") +print(response.custom_method()) +print(f"{response.local_port=}") +print(f"{response.connect_time=}") \ No newline at end of file diff --git a/tests/integration/test_response_class.py b/tests/integration/test_response_class.py new file mode 100644 index 0000000..985e951 --- /dev/null +++ b/tests/integration/test_response_class.py @@ -0,0 +1,25 @@ +import pytest +from curl_cffi import requests + +def test_default_response(): + response = requests.get("http://example.com") + assert type(response) == requests.Response + print(response.status_code) + +class CustomResponse(requests.Response): + @property + def status(self): + return self.status_code + +def test_custom_response(): + session = requests.Session(response_class=CustomResponse) + response = session.get("http://example.com") + assert isinstance(response, CustomResponse) + assert hasattr(response, "status") + print(response.status) + +class WrongTypeResponse: pass + +def test_wrong_type_custom_response(): + with pytest.raises(TypeError): + requests.Session(response_class=WrongTypeResponse) \ No newline at end of file