diff --git a/cheroot/test/test_conn.py b/cheroot/test/test_conn.py index 61574dafd6..c62929831f 100644 --- a/cheroot/test/test_conn.py +++ b/cheroot/test/test_conn.py @@ -875,6 +875,89 @@ def _trigger_scary_exc(_req, _resp): ) +def test_remains_alive_post_unhandled_exception( + mocker, + monkeypatch, + test_client, + testing_server, + wsgi_server_thread, +): + """Ensure worker threads are resillient to unhandled exceptions.""" + + class ScaryCrash(BaseException): # noqa: WPS418, WPS431 + """A simulated crash during HTTP parsing.""" + + _orig_read_request_line = ( + test_client.server_instance. + ConnectionClass.RequestHandlerClass. + read_request_line + ) + + def _read_request_line(self): + _orig_read_request_line(self) + raise ScaryCrash(666) + + monkeypatch.setattr( + test_client.server_instance.ConnectionClass.RequestHandlerClass, + 'read_request_line', + _read_request_line, + ) + + server_connection_close_spy = mocker.spy( + test_client.server_instance.ConnectionClass, + 'close', + ) + + # NOTE: The initial worker thread count is 10. + assert len(testing_server.requests._threads) == 10 + + test_client.get_connection().send(b'GET / HTTP/1.1') + + # NOTE: This spy ensure the log entry gets recorded before we're testing + # NOTE: them and before server shutdown, preserving their order and making + # NOTE: the log entry presence non-flaky. + while not server_connection_close_spy.called: # noqa: WPS328 + pass + + # NOTE: This checks for whether there's any crashed threads + while testing_server.requests.idle < 10: # noqa: WPS328 + pass + assert len(testing_server.requests._threads) == 10 + assert all( + worker_thread.is_alive() + for worker_thread in testing_server.requests._threads + ) + testing_server.interrupt = SystemExit('test requesting shutdown') + assert not testing_server.requests._threads + wsgi_server_thread.join() # no extra logs upon server termination + + actual_log_entries = testing_server.error_log.calls[:] + testing_server.error_log.calls.clear() # prevent post-test assertions + + expected_log_entries = ( + ( + logging.ERROR, + '^Unhandled error while processing an incoming connection ' + r'ScaryCrash\(666\)$', + ), + ( + logging.INFO, + '^SystemExit raised: shutting down$', + ), + ) + + assert len(actual_log_entries) == len(expected_log_entries) + + for ( # noqa: WPS352 + (expected_log_level, expected_msg_regex), + (actual_msg, actual_log_level, _tb), + ) in zip(expected_log_entries, actual_log_entries): + assert expected_log_level == actual_log_level + assert _matches_pattern(expected_msg_regex, actual_msg) is not None, ( + f'{actual_msg !r} does not match {expected_msg_regex !r}' + ) + + @pytest.mark.parametrize( 'timeout_before_headers', (