Skip to content

Commit

Permalink
Support HTTP response without Content-Length.
Browse files Browse the repository at this point in the history
Fix #1531.
  • Loading branch information
aaugustin committed Oct 26, 2024
1 parent 0d2e246 commit 6cea05e
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 6 deletions.
13 changes: 11 additions & 2 deletions src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,13 +1060,22 @@ def eof_received(self) -> None:
# Feed the end of the data stream to the connection.
self.protocol.receive_eof()

# This isn't expected to generate events.
assert not self.protocol.events_received()
# This isn't expected to raise an exception.
events = self.protocol.events_received()

# There is no error handling because send_data() can only write
# the end of the data stream here and it shouldn't raise errors.
self.send_data()

# This code path is triggered when receiving an HTTP response
# without a Content-Length header. This is the only case where
# reading until EOF generates an event; all other events have
# a known length. Ignore for coverage measurement because tests
# are in test_client.py rather than test_connection.py.
for event in events: # pragma: no cover
# This isn't expected to raise an exception.
self.process_event(event)

# The WebSocket protocol has its own closing handshake: endpoints close
# the TCP or TLS connection after sending and receiving a close frame.
# As a consequence, they never need to write after receiving EOF, so
Expand Down
2 changes: 1 addition & 1 deletion src/websockets/legacy/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
headers: datastructures.HeadersLike,
body: bytes = b"",
) -> None:
# If a user passes an int instead of a HTTPStatus, fix it automatically.
# If a user passes an int instead of an HTTPStatus, fix it automatically.
self.status = http.HTTPStatus(status)
self.headers = datastructures.Headers(headers)
self.body = body
Expand Down
13 changes: 11 additions & 2 deletions src/websockets/sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,13 +696,22 @@ def recv_events(self) -> None:
# Feed the end of the data stream to the protocol.
self.protocol.receive_eof()

# This isn't expected to generate events.
assert not self.protocol.events_received()
# This isn't expected to raise an exception.
events = self.protocol.events_received()

# There is no error handling because send_data() can only write
# the end of the data stream here and it handles errors itself.
self.send_data()

# This code path is triggered when receiving an HTTP response
# without a Content-Length header. This is the only case where
# reading until EOF generates an event; all other events have
# a known length. Ignore for coverage measurement because tests
# are in test_client.py rather than test_connection.py.
for event in events: # pragma: no cover
# This isn't expected to raise an exception.
self.process_event(event)

except Exception as exc:
# This branch should never run. It's a safety net in case of bugs.
self.logger.error("unexpected internal error", exc_info=True)
Expand Down
30 changes: 30 additions & 0 deletions tests/asyncio/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,36 @@ def close_connection(self, request):
"connection closed while reading HTTP status line",
)

async def test_http_response(self):
"""Client reads HTTP response."""

def http_response(connection, request):
return connection.respond(http.HTTPStatus.OK, "👌")

async with serve(*args, process_request=http_response) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")

self.assertEqual(raised.exception.response.status_code, 200)
self.assertEqual(raised.exception.response.body.decode(), "👌")

async def test_http_response_without_content_length(self):
"""Client reads HTTP response without a Content-Length header."""

def http_response(connection, request):
response = connection.respond(http.HTTPStatus.OK, "👌")
del response.headers["Content-Length"]
return response

async with serve(*args, process_request=http_response) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")

self.assertEqual(raised.exception.response.status_code, 200)
self.assertEqual(raised.exception.response.body.decode(), "👌")

async def test_junk_handshake(self):
"""Client closes the connection when receiving non-HTTP response from server."""

Expand Down
33 changes: 32 additions & 1 deletion tests/sync/test_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import http
import logging
import socket
import socketserver
Expand All @@ -6,7 +7,7 @@
import time
import unittest

from websockets.exceptions import InvalidHandshake, InvalidURI
from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI
from websockets.extensions.permessage_deflate import PerMessageDeflate
from websockets.sync.client import *

Expand Down Expand Up @@ -156,6 +157,36 @@ def close_connection(self, request):
"connection closed while reading HTTP status line",
)

def test_http_response(self):
"""Client reads HTTP response."""

def http_response(connection, request):
return connection.respond(http.HTTPStatus.OK, "👌")

with run_server(process_request=http_response) as server:
with self.assertRaises(InvalidStatus) as raised:
with connect(get_uri(server)):
self.fail("did not raise")

self.assertEqual(raised.exception.response.status_code, 200)
self.assertEqual(raised.exception.response.body.decode(), "👌")

def test_http_response_without_content_length(self):
"""Client reads HTTP response without a Content-Length header."""

def http_response(connection, request):
response = connection.respond(http.HTTPStatus.OK, "👌")
del response.headers["Content-Length"]
return response

with run_server(process_request=http_response) as server:
with self.assertRaises(InvalidStatus) as raised:
with connect(get_uri(server)):
self.fail("did not raise")

self.assertEqual(raised.exception.response.status_code, 200)
self.assertEqual(raised.exception.response.body.decode(), "👌")

def test_junk_handshake(self):
"""Client closes the connection when receiving non-HTTP response from server."""

Expand Down

0 comments on commit 6cea05e

Please sign in to comment.