Skip to content

Commit

Permalink
Use curl_easy_duphandle for duplicating handles in stream
Browse files Browse the repository at this point in the history
  • Loading branch information
perklet committed Nov 8, 2023
1 parent 0563151 commit 4d4172b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 31 deletions.
6 changes: 5 additions & 1 deletion curl_cffi/curl.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self, cacert: str = DEFAULT_CACERT, debug: bool = False, handle = N
self._is_cert_set = False
self._write_handle = None
self._header_handle = None
self._body_handle = None
# TODO: use CURL_ERROR_SIZE
self._error_buffer = ffi.new("char[]", 256)
self._debug = debug
Expand Down Expand Up @@ -264,8 +265,11 @@ def clean_after_perform(self, clear_headers: bool = True):
self._headers = ffi.NULL

def duphandle(self):
"""This is not a full copy of entire curl object in python. For example, headers
handle is not copied, you have to set them again."""
new_handle = lib.curl_easy_duphandle(self._curl)
return Curl(cacert=self._cacert, debug=self._debug, handle=new_handle)
c = Curl(cacert=self._cacert, debug=self._debug, handle=new_handle)
return c

def reset(self):
"""Reset all curl options, wrapper for curl_easy_reset."""
Expand Down
37 changes: 10 additions & 27 deletions curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,27 +461,6 @@ def _parse_response(self, curl, buffer, header_buffer):
# ThreadType = Literal["eventlet", "gevent", None]


class CurlSetoptProxy:
def __init__(self, curl: Curl, debug: bool = False):
self._real_curl = curl
self._debug = debug
self._options = {}

def setopt(self, option: CurlOpt, value: Any):
self._options[option] = value

def unwrap(self, clone: bool=False) -> Curl:
if not clone:
c = self._real_curl
else:
c = Curl(debug=self._debug)

for opt, value in self._options.items():
c.setopt(opt, value)

return c


class Session(BaseSession):
"""A request session, cookies and connections will be reused. This object is thread-safe,
but it's recommended to use a seperate session for each thread."""
Expand Down Expand Up @@ -532,20 +511,20 @@ def __init__(
self._local = threading.local()
if curl:
self._is_customized_curl = True
self._local.curl = CurlSetoptProxy(curl)
self._local.curl = curl
else:
self._is_customized_curl = False
self._local.curl = CurlSetoptProxy(Curl(debug=self.debug))
self._local.curl = Curl(debug=self.debug)
else:
self._curl = CurlSetoptProxy(curl if curl else Curl(debug=self.debug))
self._curl = curl if curl else Curl(debug=self.debug)

@property
def curl(self):
if self._use_thread_local_curl:
if self._is_customized_curl:
warnings.warn("Creating fresh curl handle in different thread.")
if not getattr(self._local, "curl", None):
self._local.curl = CurlSetoptProxy(Curl(debug=self.debug))
self._local.curl = Curl(debug=self.debug)
return self._local.curl
else:
return self._curl
Expand All @@ -564,7 +543,7 @@ def __exit__(self, *args):

def close(self):
"""Close the session."""
self.curl.unwrap().close()
self.curl.close()

@contextmanager
def stream(self, *args, **kwargs):
Expand Down Expand Up @@ -602,7 +581,11 @@ def request(
"""Send the request, see [curl_cffi.requests.request](/api/curl_cffi.requests/#curl_cffi.requests.request) for details on parameters."""

# clone a new curl instance for streaming response
c = self.curl.unwrap(clone=stream)
if stream:
c = self.curl.duphandle()
self.curl.reset()
else:
c = self.curl

req, buffer, header_buffer, q, header_recved = self._set_curl_options(
c,
Expand Down
6 changes: 3 additions & 3 deletions tests/unittest/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def test_post_body_cleaned(server):
# POST with body
r = s.post(str(server.url), json={"foo": "bar"})
# GET request with echo_body
assert s.curl.unwrap()._is_cert_set is False
assert s.curl._is_cert_set is False
r = s.get(str(server.url.copy_with(path="/echo_body")))
# ensure body is empty
assert r.content == b""
Expand Down Expand Up @@ -558,12 +558,12 @@ def test_stream_options_persist(server):
s = requests.Session()

# set here instead of when requesting
s.curl.setopt(CurlOpt.HTTPHEADER, [b"Foo: bar"])
s.curl.setopt(CurlOpt.USERAGENT, b"foo/1.0")

url = str(server.url.copy_with(path="/echo_headers"))
r = s.get(url, stream=True)
buffer = []
for line in r.iter_lines():
buffer.append(line)
data = json.loads(b"".join(buffer))
assert data["Foo"][0] == "bar"
assert data["User-agent"][0] == "foo/1.0"

0 comments on commit 4d4172b

Please sign in to comment.