From c0e6fa0df34ea3d0278597279bc2c0744cd70187 Mon Sep 17 00:00:00 2001 From: Yifei Kong Date: Fri, 6 Oct 2023 22:48:53 +0800 Subject: [PATCH] Add tests for stream --- tests/unittest/conftest.py | 36 ++++++++++++++++++ tests/unittest/test_async_session.py | 56 ++++++++++++++++++++++++++++ tests/unittest/test_requests.py | 44 ++++++++++++++++++++-- 3 files changed, 133 insertions(+), 3 deletions(-) diff --git a/tests/unittest/conftest.py b/tests/unittest/conftest.py index 51c762a0..64de0b9c 100644 --- a/tests/unittest/conftest.py +++ b/tests/unittest/conftest.py @@ -81,6 +81,10 @@ async def app(scope, receive, send): await echo_path(scope, receive, send) elif scope["path"].startswith("/echo_params"): await echo_params(scope, receive, send) + elif scope["path"].startswith("/stream"): + await stream(scope, receive, send) + elif scope["path"].startswith("/empty_body"): + await empty_body(scope, receive, send) elif scope["path"].startswith("/echo_body"): await echo_body(scope, receive, send) elif scope["path"].startswith("/echo_binary"): @@ -229,6 +233,38 @@ async def echo_params(scope, receive, send): await send({"type": "http.response.body", "body": json.dumps(body).encode()}) +async def stream(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"application/json"]], + } + ) + body = {"path": scope["path"], "params": parse_qs(scope["query_string"].decode())} + n = int(parse_qs(scope["query_string"].decode()).get("n", [10])[0]) + for _ in range(n - 1): + await send( + { + "type": "http.response.body", + "body": json.dumps(body).encode() + b"\n", + "more_body": True, + } + ) + await send( + { + "type": "http.response.body", + "body": json.dumps(body).encode(), + "more_body": False, + } + ) + + +async def empty_body(scope, receive, send): + await send({"type": "http.response.start", "status": 200}) + await send({"type": "http.response.body", "body": b""}) + + async def echo_binary(scope, receive, send): body = b"" more_body = True diff --git a/tests/unittest/test_async_session.py b/tests/unittest/test_async_session.py index 586f8c41..0f29d699 100644 --- a/tests/unittest/test_async_session.py +++ b/tests/unittest/test_async_session.py @@ -1,5 +1,6 @@ import asyncio import base64 +import json import pytest @@ -303,3 +304,58 @@ async def test_parallel(server): for idx, r in enumerate(rs): assert r.status_code == 200 assert r.json()["Foo"][0] == str(idx) + + +async def test_stream_iter_content(server): + async with AsyncSession() as s: + url = str(server.url.copy_with(path="/stream")) + async with s.stream("GET", url, params={"n": "20"}) as r: + async for chunk in r.aiter_content(): + assert b"path" in chunk + + +async def test_stream_iter_content_break(server): + async with AsyncSession() as s: + url = str(server.url.copy_with(path="/stream")) + async with s.stream("GET", url, params={"n": "20"}) as r: + idx = 0 + async for chunk in r.aiter_content(): + idx += 1 + assert b"path" in chunk + if idx == 3: + break + assert r.status_code == 200 + + +async def test_stream_iter_lines(server): + async with AsyncSession() as s: + url = str(server.url.copy_with(path="/stream")) + async with s.stream("GET", url, params={"n": "20"}) as r: + async for chunk in r.aiter_lines(): + data = json.loads(chunk) + assert data["path"] == "/stream" + + +async def test_stream_status_code(server): + async with AsyncSession() as s: + url = str(server.url.copy_with(path="/stream")) + async with s.stream("GET", url, params={"n": "20"}) as r: + assert r.status_code == 200 + + +async def test_stream_empty_body(server): + async with AsyncSession() as s: + url = str(server.url.copy_with(path="/empty_body")) + async with s.stream("GET", url) as r: + assert r.status_code == 200 + + +async def test_stream_atext(server): + async with AsyncSession() as s: + url = str(server.url.copy_with(path="/stream")) + async with s.stream("GET", url, params={"n": "20"}) as r: + text = await r.atext() + chunks = text.split("\n") + assert len(chunks) == 20 + + diff --git a/tests/unittest/test_requests.py b/tests/unittest/test_requests.py index 663b7bdd..e330601a 100644 --- a/tests/unittest/test_requests.py +++ b/tests/unittest/test_requests.py @@ -1,5 +1,6 @@ import base64 from io import BytesIO +import json import pytest @@ -82,7 +83,7 @@ def test_options(server): assert r.status_code == 200 -def test_parms(server): +def test_params(server): r = requests.get( str(server.url.copy_with(path="/echo_params")), params={"foo": "bar"} ) @@ -407,6 +408,43 @@ def test_session_with_headers(server): assert r.status_code == 200 -def test_stream(server): - s = requests.Session() +def test_stream_iter_content(server): + with requests.Session() as s: + url = str(server.url.copy_with(path="/stream")) + with s.stream("GET", url, params={"n": "20"}) as r: + for chunk in r.iter_content(): + assert b"path" in chunk + + +def test_stream_iter_content_break(server): + with requests.Session() as s: + url = str(server.url.copy_with(path="/stream")) + with s.stream("GET", url, params={"n": "20"}) as r: + for idx, chunk in enumerate(r.iter_content()): + assert b"path" in chunk + if idx == 3: + break + assert r.status_code == 200 + + +def test_stream_iter_lines(server): + with requests.Session() as s: + url = str(server.url.copy_with(path="/stream")) + with s.stream("GET", url, params={"n": "20"}) as r: + for chunk in r.iter_lines(): + data = json.loads(chunk) + assert data["path"] == "/stream" + + +def test_stream_status_code(server): + with requests.Session() as s: + url = str(server.url.copy_with(path="/stream")) + with s.stream("GET", url, params={"n": "20"}) as r: + assert r.status_code == 200 + +def test_stream_empty_body(server): + with requests.Session() as s: + url = str(server.url.copy_with(path="/empty_body")) + with s.stream("GET", url) as r: + assert r.status_code == 200