Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run AsyncConsumer tasks concurrently #1933

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions channels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,42 @@ async def await_many_dispatch(consumer_callables, dispatch):
"""
Given a set of consumer callables, awaits on them all and passes results
from them to the dispatch awaitable as they come in.
If a dispatch awaitable raises an exception,
this coroutine will fail with that exception.
"""
# Call all callables, and ensure all return types are Futures
tasks = [
asyncio.ensure_future(consumer_callable())
for consumer_callable in consumer_callables
]

dispatch_tasks = []
fut = asyncio.Future() # For child task to report an exception
tasks.append(fut)

def on_dispatch_task_complete(task):
dispatch_tasks.remove(task)
exc = task.exception()
if exc and not isinstance(exc, asyncio.CancelledError) and not fut.done():
fut.set_exception(exc)

try:
while True:
# Wait for any of them to complete
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
# Find the completed one(s), yield results, and replace them
for i, task in enumerate(tasks):
if task.done():
result = task.result()
await dispatch(result)
tasks[i] = asyncio.ensure_future(consumer_callables[i]())
if task == fut:
exc = fut.exception() # Child task has reported an exception
if exc:
raise exc
else:
result = task.result()
task = asyncio.create_task(dispatch(result))
dispatch_tasks.append(task)
task.add_done_callback(on_dispatch_task_complete)
tasks[i] = asyncio.ensure_future(consumer_callables[i]())
finally:
# Make sure we clean up tasks on exit
for task in tasks:
Expand All @@ -57,3 +77,15 @@ async def await_many_dispatch(consumer_callables, dispatch):
await task
except asyncio.CancelledError:
pass
if dispatch_tasks:
"""
This may be needed if the consumer task running this coroutine
is cancelled and one of the subtasks raises an exception after cancellation.
"""
done, pending = await asyncio.wait(dispatch_tasks)
for task in done:
exc = task.exception()
if exc and not isinstance(exc, asyncio.CancelledError):
raise exc
if not fut.done():
fut.set_result(None)