Skip to content

Commit

Permalink
Add max_recv_speed and close stream early
Browse files Browse the repository at this point in the history
  • Loading branch information
perklet committed Nov 23, 2023
1 parent 8b0c96a commit c392c19
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 4 deletions.
7 changes: 6 additions & 1 deletion curl_cffi/curl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def __init__(self, msg, code: int = 0, *args, **kwargs):
CURLINFO_SSL_DATA_IN = 5
CURLINFO_SSL_DATA_OUT = 6

CURL_WRITEFUNC_PAUSE = 0x10000001
CURL_WRITEFUNC_ERROR = 0xFFFFFFFF


@ffi.def_extern()
def debug_function(curl, type: int, data, size, clientp) -> int:
Expand All @@ -51,7 +54,9 @@ def buffer_callback(ptr, size, nmemb, userdata):
def write_callback(ptr, size, nmemb, userdata):
# although similar enough to the function above, kept here for performance reasons
callback = ffi.from_handle(userdata)
callback(ffi.buffer(ptr, nmemb)[:])
wrote = callback(ffi.buffer(ptr, nmemb)[:])
if wrote == CURL_WRITEFUNC_PAUSE or wrote == CURL_WRITEFUNC_ERROR:
return wrote
return nmemb * size


Expand Down
2 changes: 2 additions & 0 deletions curl_cffi/requests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = Non
self.infos = {}
self.queue: Optional[queue.Queue] = None
self.stream_task = None
self.quit_now = None

def _decode(self, content: bytes) -> str:
try:
Expand Down Expand Up @@ -128,6 +129,7 @@ def json(self, **kw):
return loads(self.content, **kw)

def close(self):
self.quit_now.set() # type: ignore
self.stream_task.result() # type: ignore

async def aiter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None):
Expand Down
23 changes: 20 additions & 3 deletions curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from urllib.parse import ParseResult, parse_qsl, unquote, urlencode, urlparse
from concurrent.futures import ThreadPoolExecutor


from .. import AsyncCurl, Curl, CurlError, CurlInfo, CurlOpt, CurlHttpVersion
from ..curl import CURL_WRITEFUNC_ERROR
from .cookies import Cookies, CookieTypes, CurlMorsel
from .errors import RequestsError
from .headers import Headers, HeaderTypes
Expand Down Expand Up @@ -208,6 +210,7 @@ def _set_curl_options(
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
stream: bool = False,
max_recv_speed: int = 0,
queue_class: Any = None,
event_class: Any = None,
):
Expand Down Expand Up @@ -391,13 +394,17 @@ def _set_curl_options(
buffer = None
q = None
header_recved = None
quit_now = None
if stream:
q = queue_class() # type: ignore
header_recved = event_class()
quit_now = event_class()

def qput(chunk):
if not header_recved.is_set():
header_recved.set()
if quit_now.is_set():
return CURL_WRITEFUNC_ERROR
q.put_nowait(chunk)

c.setopt(CurlOpt.WRITEFUNCTION, qput) # type: ignore
Expand All @@ -417,7 +424,11 @@ def qput(chunk):
if interface:
c.setopt(CurlOpt.INTERFACE, interface.encode())

return req, buffer, header_buffer, q, header_recved
# max_recv_speed
# do not check, since 0 is a valid value to disable it
c.setopt(CurlOpt.MAX_RECV_SPEED_LARGE, max_recv_speed)

return req, buffer, header_buffer, q, header_recved, quit_now

def _parse_response(self, curl, buffer, header_buffer):
c = curl
Expand Down Expand Up @@ -593,6 +604,7 @@ def request(
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
stream: bool = False,
max_recv_speed: int = 0,
) -> Response:
"""Send the request, see [curl_cffi.requests.request](/api/curl_cffi.requests/#curl_cffi.requests.request) for details on parameters."""

Expand All @@ -603,7 +615,7 @@ def request(
else:
c = self.curl

req, buffer, header_buffer, q, header_recved = self._set_curl_options(
req, buffer, header_buffer, q, header_recved, quit_now = self._set_curl_options(
c,
method=method,
url=url,
Expand All @@ -627,6 +639,7 @@ def request(
http_version=http_version,
interface=interface,
stream=stream,
max_recv_speed=max_recv_speed,
queue_class=queue.Queue,
event_class=threading.Event,
)
Expand Down Expand Up @@ -667,6 +680,7 @@ def cleanup(fut):

rsp.request = req
rsp.stream_task = stream_task # type: ignore
rsp.quit_now = quit_now # type: ignore
rsp.queue = q
return rsp
else:
Expand Down Expand Up @@ -837,10 +851,11 @@ async def request(
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
stream: bool = False,
max_recv_speed: int = 0,
):
"""Send the request, see [curl_cffi.requests.request](/api/curl_cffi.requests/#curl_cffi.requests.request) for details on parameters."""
curl = await self.pop_curl()
req, buffer, header_buffer, q, header_recved = self._set_curl_options(
req, buffer, header_buffer, q, header_recved, quit_now = self._set_curl_options(
curl=curl,
method=method,
url=url,
Expand All @@ -864,6 +879,7 @@ async def request(
http_version=http_version,
interface=interface,
stream=stream,
max_recv_speed=max_recv_speed,
queue_class=asyncio.Queue,
event_class=asyncio.Event,
)
Expand Down Expand Up @@ -903,6 +919,7 @@ def cleanup(fut):

rsp.request = req
rsp.stream_task = stream_task # type: ignore
rsp.quit_now = quit_now
rsp.queue = q
return rsp
else:
Expand Down
19 changes: 19 additions & 0 deletions tests/unittest/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ async def app(scope, receive, send):
await echo_params(scope, receive, send)
elif scope["path"].startswith("/stream"):
await stream(scope, receive, send)
elif scope["path"].startswith("/large"):
await large(scope, receive, send)
elif scope["path"].startswith("/empty_body"):
await empty_body(scope, receive, send)
elif scope["path"].startswith("/echo_body"):
Expand Down Expand Up @@ -262,6 +264,23 @@ async def stream(scope, receive, send):
)


async def large(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [],
}
)
await send(
{
"type": "http.response.body",
"body": os.urandom(20 * 1024 * 1024), # 20MiB
"more_body": False,
}
)


async def empty_body(scope, receive, send):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b""})
Expand Down
30 changes: 30 additions & 0 deletions tests/unittest/test_requests.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import time
from io import BytesIO
import json

Expand Down Expand Up @@ -569,6 +570,35 @@ def test_stream_options_persist(server):
assert data["User-agent"][0] == "foo/1.0"


def test_stream_close_early(server):
s = requests.Session()
# url = str(server.url.copy_with(path="/large"))
# from http://xcal1.vodafone.co.uk/
url = "http://212.183.159.230/10MB.zip"
r = s.get(url, max_recv_speed=1024 * 1024, stream=True)
counter = 0
start = time.time()
for _ in r.iter_content():
counter += 1
if counter > 10:
break
r.close()
end = time.time()
assert end - start < 10


def test_max_recv_speed(server):
s = requests.Session()
url = str(server.url.copy_with(path="/large"))
# from http://xcal1.vodafone.co.uk/
url = "http://212.183.159.230/10MB.zip"
start = time.time()
r = s.get(url, max_recv_speed=1024 * 1024)
end = time.time()
# assert len(r.content) == 20 * 1024 * 1024
assert end - start > 10


def test_curl_infos(server):
s = requests.Session(curl_infos=[CurlInfo.PRIMARY_IP])

Expand Down

0 comments on commit c392c19

Please sign in to comment.