Skip to content

Commit

Permalink
Add tests for stream
Browse files Browse the repository at this point in the history
  • Loading branch information
perklet committed Oct 6, 2023
1 parent 2149ffe commit c0e6fa0
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 3 deletions.
36 changes: 36 additions & 0 deletions tests/unittest/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ async def app(scope, receive, send):
await echo_path(scope, receive, send)
elif scope["path"].startswith("/echo_params"):
await echo_params(scope, receive, send)
elif scope["path"].startswith("/stream"):
await stream(scope, receive, send)
elif scope["path"].startswith("/empty_body"):
await empty_body(scope, receive, send)
elif scope["path"].startswith("/echo_body"):
await echo_body(scope, receive, send)
elif scope["path"].startswith("/echo_binary"):
Expand Down Expand Up @@ -229,6 +233,38 @@ async def echo_params(scope, receive, send):
await send({"type": "http.response.body", "body": json.dumps(body).encode()})


async def stream(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"application/json"]],
}
)
body = {"path": scope["path"], "params": parse_qs(scope["query_string"].decode())}
n = int(parse_qs(scope["query_string"].decode()).get("n", [10])[0])
for _ in range(n - 1):
await send(
{
"type": "http.response.body",
"body": json.dumps(body).encode() + b"\n",
"more_body": True,
}
)
await send(
{
"type": "http.response.body",
"body": json.dumps(body).encode(),
"more_body": False,
}
)


async def empty_body(scope, receive, send):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b""})


async def echo_binary(scope, receive, send):
body = b""
more_body = True
Expand Down
56 changes: 56 additions & 0 deletions tests/unittest/test_async_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import base64
import json

import pytest

Expand Down Expand Up @@ -303,3 +304,58 @@ async def test_parallel(server):
for idx, r in enumerate(rs):
assert r.status_code == 200
assert r.json()["Foo"][0] == str(idx)


async def test_stream_iter_content(server):
async with AsyncSession() as s:
url = str(server.url.copy_with(path="/stream"))
async with s.stream("GET", url, params={"n": "20"}) as r:
async for chunk in r.aiter_content():
assert b"path" in chunk


async def test_stream_iter_content_break(server):
async with AsyncSession() as s:
url = str(server.url.copy_with(path="/stream"))
async with s.stream("GET", url, params={"n": "20"}) as r:
idx = 0
async for chunk in r.aiter_content():
idx += 1
assert b"path" in chunk
if idx == 3:
break
assert r.status_code == 200


async def test_stream_iter_lines(server):
async with AsyncSession() as s:
url = str(server.url.copy_with(path="/stream"))
async with s.stream("GET", url, params={"n": "20"}) as r:
async for chunk in r.aiter_lines():
data = json.loads(chunk)
assert data["path"] == "/stream"


async def test_stream_status_code(server):
async with AsyncSession() as s:
url = str(server.url.copy_with(path="/stream"))
async with s.stream("GET", url, params={"n": "20"}) as r:
assert r.status_code == 200


async def test_stream_empty_body(server):
async with AsyncSession() as s:
url = str(server.url.copy_with(path="/empty_body"))
async with s.stream("GET", url) as r:
assert r.status_code == 200


async def test_stream_atext(server):
async with AsyncSession() as s:
url = str(server.url.copy_with(path="/stream"))
async with s.stream("GET", url, params={"n": "20"}) as r:
text = await r.atext()
chunks = text.split("\n")
assert len(chunks) == 20


44 changes: 41 additions & 3 deletions tests/unittest/test_requests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
from io import BytesIO
import json

import pytest

Expand Down Expand Up @@ -82,7 +83,7 @@ def test_options(server):
assert r.status_code == 200


def test_parms(server):
def test_params(server):
r = requests.get(
str(server.url.copy_with(path="/echo_params")), params={"foo": "bar"}
)
Expand Down Expand Up @@ -407,6 +408,43 @@ def test_session_with_headers(server):
assert r.status_code == 200


def test_stream(server):
s = requests.Session()
def test_stream_iter_content(server):
with requests.Session() as s:
url = str(server.url.copy_with(path="/stream"))
with s.stream("GET", url, params={"n": "20"}) as r:
for chunk in r.iter_content():
assert b"path" in chunk


def test_stream_iter_content_break(server):
with requests.Session() as s:
url = str(server.url.copy_with(path="/stream"))
with s.stream("GET", url, params={"n": "20"}) as r:
for idx, chunk in enumerate(r.iter_content()):
assert b"path" in chunk
if idx == 3:
break
assert r.status_code == 200


def test_stream_iter_lines(server):
with requests.Session() as s:
url = str(server.url.copy_with(path="/stream"))
with s.stream("GET", url, params={"n": "20"}) as r:
for chunk in r.iter_lines():
data = json.loads(chunk)
assert data["path"] == "/stream"


def test_stream_status_code(server):
with requests.Session() as s:
url = str(server.url.copy_with(path="/stream"))
with s.stream("GET", url, params={"n": "20"}) as r:
assert r.status_code == 200


def test_stream_empty_body(server):
with requests.Session() as s:
url = str(server.url.copy_with(path="/empty_body"))
with s.stream("GET", url) as r:
assert r.status_code == 200

0 comments on commit c0e6fa0

Please sign in to comment.