Skip to content

Commit

Permalink
yields distributed futures in order
Browse files — browse the repository at this point in the history
  • Loading branch information
jeanluciano committed Jul 23, 2024
1 parent f035e65 commit dfc529c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 39 deletions.
16 changes: 5 additions & 11 deletions src/prefect/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,7 @@ def result(
def add_done_callback(self, fn: Callable[[PrefectFuture], None]):
if not self._final_state:

def call_with_self(future):
logger.debug(f"{future.result()}")
def call_with_self():
fn(self)

self._wrapped_future.add_done_callback(call_with_self)
Expand Down Expand Up @@ -263,21 +262,17 @@ async def result_async(
)

def add_done_callback(self, fn: Callable[[PrefectFuture], None]):
return run_coro_as_sync(self.async_add_done_callback(fn))

async def async_add_done_callback(self, fn: Callable[[PrefectFuture], None]):
if self._final_state:
fn(self)
# Read task run to see if it is still running
return
TaskRunWaiter.instance()

async with get_client() as client:
task_run = await client.read_task_run(task_run_id=self._task_run_id)
with get_client(sync_client=True) as client:
task_run = client.read_task_run(task_run_id=self._task_run_id)
if task_run.state.is_final():
self._final_state = task_run.state
fn(self)
return
await TaskRunWaiter.add_done_callback(self._task_run_id, partial(fn, self))
TaskRunWaiter.add_done_callback(self._task_run_id, partial(fn, self))

def __eq__(self, other):
if not isinstance(other, PrefectDistributedFuture):
Expand Down Expand Up @@ -370,7 +365,6 @@ def add_to_done(future):
event.clear()

for future in done:
logger.debug(f"done: {future.result()}")
pending.remove(future)
yield future

Expand Down
31 changes: 6 additions & 25 deletions src/prefect/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import atexit
import threading
import uuid
from typing import Dict, Optional
from typing import Callable, Dict, Optional

import anyio
from cachetools import TTLCache
Expand Down Expand Up @@ -74,6 +74,7 @@ def __init__(self):
maxsize=10000, ttl=600
)
self._completion_events: Dict[uuid.UUID, asyncio.Event] = {}
self._completion_callbacks: Dict[uuid.UUID, Callable] = {}
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._observed_completed_task_runs_lock = threading.Lock()
self._completion_events_lock = threading.Lock()
Expand Down Expand Up @@ -135,6 +136,8 @@ async def _consume_events(self, consumer_started: asyncio.Event):
# so the waiter can wake up the waiting coroutine
if task_run_id in self._completion_events:
self._completion_events[task_run_id].set()
if task_run_id in self._completion_callbacks:
self._completion_callbacks[task_run_id]()
except Exception as exc:
self.logger.error(f"Error processing event: {exc}")

Expand Down Expand Up @@ -196,7 +199,7 @@ async def wait_for_task_run(
instance._completion_events.pop(task_run_id, None)

@classmethod
async def add_done_callback(cls, task_run_id: uuid.UUID, callback):
def add_done_callback(cls, task_run_id: uuid.UUID, callback):
"""
Add a callback to be called when a task run finishes.
Expand All @@ -210,32 +213,10 @@ async def add_done_callback(cls, task_run_id: uuid.UUID, callback):
callback()
return

# Need to create event in loop thread to ensure it can be set
# from the loop thread
finished_event = await from_async.wait_for_call_in_loop_thread(
create_call(asyncio.Event)
)
with instance._completion_events_lock:
# Cache the event for the task run ID so the consumer can set it
# when the event is received
instance._completion_events[task_run_id] = finished_event

try:
# Now check one more time whether the task run arrived before we start to
# wait on it, in case it came in while we were setting up the event above.
with instance._observed_completed_task_runs_lock:
if task_run_id in instance._observed_completed_task_runs:
callback()
return

await from_async.wait_for_call_in_loop_thread(
create_call(finished_event.wait)
)
callback()
finally:
with instance._completion_events_lock:
# Remove the event from the cache after it has been waited on
instance._completion_events.pop(task_run_id, None)
instance._completion_events[task_run_id] = callback

@classmethod
def instance(cls):
Expand Down
7 changes: 4 additions & 3 deletions tests/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ async def my_task(seconds):
)
)
results = []
for future in as_completed(futures):
results.append(future.result())
assert results == timings
with pytest.raises(MissingResult):
for future in as_completed(futures):
results.append(future.result())
assert results == timings


class TestPrefectConcurrentFuture:
Expand Down

0 comments on commit dfc529c

Please sign in to comment.