Skip to content

Commit

Permalink
Add a graceful shutdown period to allow tasks to complete. (#7188)
Browse files Browse the repository at this point in the history
When the server is shutting down gracefully, it should wait on pending
tasks before running the application shutdown/cleanup steps and
cancelling all remaining tasks.

This helps ensure that tasks have a chance to finish writing to a DB,
handlers can finish responding to clients etc.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sviatoslav Sydorenko <[email protected]>
  • Loading branch information
3 people authored Feb 11, 2023
1 parent 82c944c commit edd49b5
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGES/7188.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a graceful shutdown period which allows pending tasks to complete before the application's cleanup is called. The period can be adjusted with the ``shutdown_timeout`` parameter -- by :user:`Dreamsorcerer`.
2 changes: 1 addition & 1 deletion aiohttp/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def start_server(self, **kwargs: Any) -> None:
if self.runner:
return
self._ssl = kwargs.pop("ssl", None)
self.runner = await self._make_runner(**kwargs)
self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
await self.runner.setup()
if not self.port:
self.port = 0
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,4 +283,4 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter
try:
return await self._sendfile(request, fobj, offset, count)
finally:
await loop.run_in_executor(None, fobj.close)
await asyncio.shield(loop.run_in_executor(None, fobj.close))
23 changes: 22 additions & 1 deletion aiohttp/web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import signal
import socket
from abc import ABC, abstractmethod
from contextlib import suppress
from typing import Any, List, Optional, Set, Type

from yarl import URL
Expand Down Expand Up @@ -80,11 +81,26 @@ async def stop(self) -> None:
# named pipes do not have wait_closed property
if hasattr(self._server, "wait_closed"):
await self._server.wait_closed()

# Wait for pending tasks for a given time limit.
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(
self._wait(asyncio.current_task()), timeout=self._shutdown_timeout
)

await self._runner.shutdown()
assert self._runner.server
await self._runner.server.shutdown(self._shutdown_timeout)
self._runner._unreg_site(self)

async def _wait(self, parent_task: Optional["asyncio.Task[object]"]) -> None:
    """Block until no tasks remain other than the excluded ones.

    The excluded set is the runner's ``starting_tasks`` plus the task
    running this coroutine and ``parent_task``.  Tasks created while we are
    waiting are picked up on the next pass, so the loop only ends once a
    fresh snapshot of ``asyncio.all_tasks()`` contains nothing but excluded
    tasks.  There is no timeout here; the caller is expected to bound it.

    :param parent_task: task to leave out of the wait (may be ``None``).
    """
    ignored = self._runner.starting_tasks | {asyncio.current_task(), parent_task}
    # TODO(PY38): collapse into `while pending := asyncio.all_tasks() - ignored:`
    while True:
        pending = asyncio.all_tasks() - ignored
        if not pending:
            return
        await asyncio.wait(pending)


class TCPSite(BaseSite):
__slots__ = ("_host", "_port", "_reuse_address", "_reuse_port")
Expand Down Expand Up @@ -247,7 +263,7 @@ async def start(self) -> None:


class BaseRunner(ABC):
__slots__ = ("_handle_signals", "_kwargs", "_server", "_sites")
__slots__ = ("starting_tasks", "_handle_signals", "_kwargs", "_server", "_sites")

def __init__(self, *, handle_signals: bool = False, **kwargs: Any) -> None:
self._handle_signals = handle_signals
Expand Down Expand Up @@ -287,6 +303,11 @@ async def setup(self) -> None:
pass

self._server = await self._make_server()
# On shutdown we want to avoid waiting on tasks which run forever.
# It's very likely that all tasks which run forever will have been created by
# the time we have completed the application startup (in self._make_server()),
# so we just record all running tasks here and exclude them later.
self.starting_tasks = asyncio.all_tasks()

@abstractmethod
async def shutdown(self) -> None:
Expand Down
10 changes: 8 additions & 2 deletions docs/web_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -927,8 +927,14 @@ Graceful shutdown
Stopping *aiohttp web server* by just closing all connections is not
always satisfactory.

The problem is: if application supports :term:`websocket`\s or *data
streaming* it most likely has open connections at server
The first thing aiohttp will do is to stop listening on the sockets,
so new connections will be rejected. It will then wait a few
seconds to allow any pending tasks to complete before continuing
with application shutdown. The timeout can be adjusted with
``shutdown_timeout`` in :func:`run_app`.

Another problem is that if the application supports :term:`websockets <websocket>` or
*data streaming*, it most likely has open connections at server
shutdown time.

The *library* has no knowledge how to close them gracefully but
Expand Down
34 changes: 21 additions & 13 deletions docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2688,9 +2688,10 @@ application on specific TCP or Unix socket, e.g.::

:param int port: PORT to listen on, ``8080`` if ``None`` (default).

:param float shutdown_timeout: a timeout for closing opened
connections on :meth:`BaseSite.stop`
call.
:param float shutdown_timeout: a timeout used for both waiting on pending
tasks before application shutdown and for
closing opened connections on
:meth:`BaseSite.stop` call.

:param ssl_context: a :class:`ssl.SSLContext` instance for serving
SSL/TLS secure server, ``None`` for plain HTTP
Expand Down Expand Up @@ -2723,9 +2724,10 @@ application on specific TCP or Unix socket, e.g.::

:param str path: PATH to UNIX socket to listen on.

:param float shutdown_timeout: a timeout for closing opened
connections on :meth:`BaseSite.stop`
call.
:param float shutdown_timeout: a timeout used for both waiting on pending
tasks before application shutdown and for
closing opened connections on
:meth:`BaseSite.stop` call.

:param ssl_context: a :class:`ssl.SSLContext` instance for serving
SSL/TLS secure server, ``None`` for plain HTTP
Expand All @@ -2745,9 +2747,10 @@ application on specific TCP or Unix socket, e.g.::

:param str path: PATH of named pipe to listen on.

:param float shutdown_timeout: a timeout for closing opened
connections on :meth:`BaseSite.stop`
call.
:param float shutdown_timeout: a timeout used for both waiting on pending
tasks before application shutdown and for
closing opened connections on
:meth:`BaseSite.stop` call.

.. class:: SockSite(runner, sock, *, \
shutdown_timeout=60.0, ssl_context=None, \
Expand All @@ -2759,9 +2762,10 @@ application on specific TCP or Unix socket, e.g.::

:param sock: A :ref:`socket instance <socket-objects>` to listen to.

:param float shutdown_timeout: a timeout for closing opened
connections on :meth:`BaseSite.stop`
call.
:param float shutdown_timeout: a timeout used for both waiting on pending
tasks before application shutdown and for
closing opened connections on
:meth:`BaseSite.stop` call.

:param ssl_context: a :class:`ssl.SSLContext` instance for serving
SSL/TLS secure server, ``None`` for plain HTTP
Expand Down Expand Up @@ -2857,9 +2861,13 @@ Utilities
shutdown before disconnecting all
open client sockets the hard way.

This is used as a delay to wait for
pending tasks to complete and then
again to close any pending connections.

A system with properly
:ref:`aiohttp-web-graceful-shutdown`
implemented never waits for this
implemented never waits for the second
timeout but closes a server in a few
milliseconds.

Expand Down
199 changes: 197 additions & 2 deletions tests/test_run_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import ssl
import subprocess
import sys
from typing import Any
import time
from typing import Any, Callable, NoReturn
from unittest import mock
from uuid import uuid4

import pytest
from conftest import needs_unix

from aiohttp import web
from aiohttp import ClientConnectorError, ClientSession, web
from aiohttp.test_utils import make_mocked_coro
from aiohttp.web_runner import BaseRunner

Expand Down Expand Up @@ -926,3 +927,197 @@ async def init():

web.run_app(init(), print=stopper(patched_loop), loop=patched_loop)
assert count == 3


# End-to-end tests for the shutdown behaviour of web.run_app(): each test
# starts a real server on an unused port, has a client hit "/" and then
# "/stop", and afterwards inspects what happened to the work that was still
# in flight (background tasks, pending handlers, new connections).
class TestShutdown:
# Callback that raises KeyboardInterrupt; scheduled on the loop by stop().
# NOTE(review): presumably this is what makes web.run_app() begin shutting
# down -- confirm against run_app's interrupt handling.
def raiser(self) -> NoReturn:
raise KeyboardInterrupt

# Handler for "/stop": schedules raiser() with call_soon() and returns an
# empty response, so the interrupt is raised from a loop callback rather
# than from inside the handler itself.
async def stop(self, request: web.Request) -> web.Response:
asyncio.get_running_loop().call_soon(self.raiser)
return web.Response()

# Shared driver for the tests below.
#   port:       port passed to web.run_app() and used in the client URLs.
#   timeout:    forwarded to web.run_app() as shutdown_timeout.
#   task:       coroutine function (called with no arguments); the "/" handler
#               wraps its result in a task without awaiting it.
#   extra_test: optional coroutine function awaited with the client session
#               after the "/stop" request has been made.
# Returns the task created by the "/" handler so callers can check whether it
# finished or was cancelled.
def run_app(self, port: int, timeout: int, task, extra_test=None) -> asyncio.Task:
# Client side of the test: GET "/" (spawns the background task), then
# GET "/stop" (triggers shutdown), then the optional extra check.
async def test() -> None:
# Initial delay before the first request -- presumably to give the
# server time to start listening.
await asyncio.sleep(1)
async with ClientSession() as sess:
async with sess.get(f"http://localhost:{port}/"):
pass
async with sess.get(f"http://localhost:{port}/stop"):
pass

if extra_test:
await extra_test(sess)

# cleanup_ctx generator: starts the client coroutine as a task before the
# yield and awaits it after the yield.
async def run_test(app: web.Application) -> None:
nonlocal test_task
test_task = asyncio.create_task(test())
yield
await test_task

# "/" handler: fire-and-forget the supplied coroutine and respond at once.
async def handler(request: web.Request) -> web.Response:
nonlocal t
t = asyncio.create_task(task())
return web.Response(text="FOO")

t = test_task = None
app = web.Application()
app.cleanup_ctx.append(run_test)
app.router.add_get("/", handler)
app.router.add_get("/stop", self.stop)

web.run_app(app, port=port, shutdown_timeout=timeout)
# Surface any assertion failure or error raised inside the client task.
assert test_task.exception() is None
return t

# A 2 second task with a 3 second shutdown timeout must run to completion
# and must not be cancelled.
def test_shutdown_wait_for_task(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
finished = False

async def task():
nonlocal finished
await asyncio.sleep(2)
finished = True

t = self.run_app(port, 3, task)

assert finished is True
assert t.done()
assert not t.cancelled()

# A 2 second task with only a 1 second shutdown timeout must end up
# cancelled without reaching its final statement.
def test_shutdown_timeout_task(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
finished = False

async def task():
nonlocal finished
await asyncio.sleep(2)
finished = True

t = self.run_app(port, 1, task)

assert finished is False
assert t.done()
assert t.cancelled()

# A task that itself creates another task 0.5 seconds in: both the original
# task and the 1.5 second sub-task must complete uncancelled within the
# 3 second shutdown timeout.
def test_shutdown_wait_for_spawned_task(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
finished = False
finished_sub = False
sub_t = None

async def sub_task():
nonlocal finished_sub
await asyncio.sleep(1.5)
finished_sub = True

async def task():
nonlocal finished, sub_t
await asyncio.sleep(0.5)
sub_t = asyncio.create_task(sub_task())
finished = True

t = self.run_app(port, 3, task)

assert finished is True
assert t.done()
assert not t.cancelled()
assert finished_sub is True
assert sub_t.done()
assert not sub_t.cancelled()

# With a 15 second timeout and a 1 second task, run_app() must return well
# before the timeout (under 10 seconds in total).
def test_shutdown_timeout_not_reached(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
finished = False

async def task():
nonlocal finished
await asyncio.sleep(1)
finished = True

start_time = time.time()
t = self.run_app(port, 15, task)

assert finished is True
assert t.done()
# Verify run_app has not waited for timeout.
assert time.time() - start_time < 10

# While a 9 second task keeps the shutdown wait open (10 second timeout),
# a connection attempt from a brand new client session must fail with
# ClientConnectorError.
def test_shutdown_new_conn_rejected(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
finished = False

async def task() -> None:
nonlocal finished
await asyncio.sleep(9)
finished = True

# Passed to run_app() as extra_test; note that the inner ``sess`` below
# shadows the ``sess`` parameter.
async def test(sess: ClientSession) -> None:
# Ensure we are in the middle of shutdown (waiting for task()).
await asyncio.sleep(1)
with pytest.raises(ClientConnectorError):
# Use a new session to try and open a new connection.
async with ClientSession() as sess:
async with sess.get(f"http://localhost:{port}/"):
pass
assert finished is False

t = self.run_app(port, 10, task, test)

assert finished is True
assert t.done()

# A request whose handler is still running (3 second sleep) when "/stop" is
# hit must still receive its "FOO" response. This test builds its own app
# instead of using self.run_app() because the slow work happens inside the
# handler rather than in a spawned task.
def test_shutdown_pending_handler_responds(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
finished = False

async def test() -> None:
# Issues the slow request and checks the response body.
async def test_resp(sess):
async with sess.get(f"http://localhost:{port}/") as resp:
assert await resp.text() == "FOO"

await asyncio.sleep(1)
async with ClientSession() as sess:
t = asyncio.create_task(test_resp(sess))
await asyncio.sleep(1)
# Handler is in-progress while we trigger server shutdown.
async with sess.get(f"http://localhost:{port}/stop"):
pass

assert finished is False
# Handler should still complete and produce a response.
await t

# cleanup_ctx generator: start the client task before the yield, await it
# after the yield.
async def run_test(app: web.Application) -> None:
nonlocal t
t = asyncio.create_task(test())
yield
await t

# Slow "/" handler: takes 3 seconds before responding.
async def handler(request: web.Request) -> web.Response:
nonlocal finished
await asyncio.sleep(3)
finished = True
return web.Response(text="FOO")

t = None
app = web.Application()
app.cleanup_ctx.append(run_test)
app.router.add_get("/", handler)
app.router.add_get("/stop", self.stop)

web.run_app(app, port=port, shutdown_timeout=5)
# Surface any assertion failure or error raised inside the client task.
assert t.exception() is None
assert finished is True

0 comments on commit edd49b5

Please sign in to comment.