From d5590323ae5e560709de0b57dc15b52e7e350279 Mon Sep 17 00:00:00 2001 From: Chris Pickett Date: Wed, 24 Jul 2024 14:43:47 -0400 Subject: [PATCH] Create an AsyncTaskRunEngine and wire it up to run async tasks --- src/prefect/client/orchestration.py | 26 +- src/prefect/client/utilities.py | 8 +- src/prefect/context.py | 96 ++- src/prefect/flow_engine.py | 6 +- src/prefect/task_engine.py | 874 ++++++++++++++++++++++------ tests/test_flow_engine.py | 6 +- tests/test_task_engine.py | 67 ++- tests/test_task_worker.py | 62 +- 8 files changed, 878 insertions(+), 267 deletions(-) diff --git a/src/prefect/client/orchestration.py b/src/prefect/client/orchestration.py index a6615f5752f37..778eb408732ce 100644 --- a/src/prefect/client/orchestration.py +++ b/src/prefect/client/orchestration.py @@ -203,20 +203,18 @@ def get_client( except RuntimeError: loop = None - if client_ctx := prefect.context.ClientContext.get(): - if ( - sync_client - and client_ctx.sync_client - and client_ctx._httpx_settings == httpx_settings - ): - return client_ctx.sync_client - elif ( - not sync_client - and client_ctx.async_client - and client_ctx._httpx_settings == httpx_settings - and loop in (client_ctx.async_client._loop, None) - ): - return client_ctx.async_client + if sync_client: + if client_ctx := prefect.context.SyncClientContext.get(): + if client_ctx.client and client_ctx._httpx_settings == httpx_settings: + return client_ctx.client + else: + if client_ctx := prefect.context.AsyncClientContext.get(): + if ( + client_ctx.client + and client_ctx._httpx_settings == httpx_settings + and loop in (client_ctx.client._loop, None) + ): + return client_ctx.client api = PREFECT_API_URL.value() diff --git a/src/prefect/client/utilities.py b/src/prefect/client/utilities.py index 243067fdda8bd..ffe42e63195f7 100644 --- a/src/prefect/client/utilities.py +++ b/src/prefect/client/utilities.py @@ -42,14 +42,14 @@ def get_or_create_client( if client is not None: return client, True from 
prefect._internal.concurrency.event_loop import get_running_loop - from prefect.context import ClientContext, FlowRunContext, TaskRunContext + from prefect.context import AsyncClientContext, FlowRunContext, TaskRunContext - client_context = ClientContext.get() + async_client_context = AsyncClientContext.get() flow_run_context = FlowRunContext.get() task_run_context = TaskRunContext.get() - if client_context and client_context.async_client._loop == get_running_loop(): - return client_context.async_client, True + if async_client_context and async_client_context.client._loop == get_running_loop(): + return async_client_context.client, True elif ( flow_run_context and getattr(flow_run_context.client, "_loop", None) == get_running_loop() diff --git a/src/prefect/context.py b/src/prefect/context.py index 702a52428280d..ecadbdfde3d8f 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -10,12 +10,13 @@ import sys import warnings import weakref -from contextlib import ExitStack, contextmanager +from contextlib import ExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar, Token from pathlib import Path from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, Dict, Generator, Mapping, @@ -177,36 +178,34 @@ def serialize(self) -> Dict[str, Any]: return self.model_dump(exclude_unset=True) -class ClientContext(ContextModel): +class SyncClientContext(ContextModel): """ - A context for managing the Prefect client instances. + A context for managing the sync Prefect client instances. Clients were formerly tracked on the TaskRunContext and FlowRunContext, but having two separate places and the addition of both sync and async clients made it difficult to manage. This context is intended to be the single - source for clients. + source for sync clients. - The client creates both sync and async clients, which can either be read - directly from the context object OR loaded with get_client, inject_client, - or other Prefect utilities. 
+ The client creates a sync client, which can either be read directly from + the context object OR loaded with get_client, inject_client, or other + Prefect utilities. - with ClientContext.get_or_create() as ctx: + with SyncClientContext.get_or_create() as ctx: c1 = get_client(sync_client=True) c2 = get_client(sync_client=True) assert c1 is c2 - assert c1 is ctx.sync_client + assert c1 is ctx.client """ - __var__ = ContextVar("clients") - sync_client: SyncPrefectClient - async_client: PrefectClient + __var__ = ContextVar("sync-client-context") + client: SyncPrefectClient _httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None) _context_stack: int = PrivateAttr(0) def __init__(self, httpx_settings: Optional[dict[str, Any]] = None): super().__init__( - sync_client=get_client(sync_client=True, httpx_settings=httpx_settings), - async_client=get_client(sync_client=False, httpx_settings=httpx_settings), + client=get_client(sync_client=True, httpx_settings=httpx_settings), ) self._httpx_settings = httpx_settings self._context_stack = 0 @@ -214,8 +213,7 @@ def __init__(self, httpx_settings: Optional[dict[str, Any]] = None): def __enter__(self): self._context_stack += 1 if self._context_stack == 1: - self.sync_client.__enter__() - run_coro_as_sync(self.async_client.__aenter__()) + self.client.__enter__() return super().__enter__() else: return self @@ -223,18 +221,74 @@ def __enter__(self): def __exit__(self, *exc_info): self._context_stack -= 1 if self._context_stack == 0: - self.sync_client.__exit__(*exc_info) - run_coro_as_sync(self.async_client.__aexit__(*exc_info)) + self.client.__exit__(*exc_info) return super().__exit__(*exc_info) @classmethod @contextmanager - def get_or_create(cls) -> Generator["ClientContext", None, None]: - ctx = ClientContext.get() + def get_or_create(cls) -> Generator["SyncClientContext", None, None]: + ctx = SyncClientContext.get() if ctx: yield ctx else: - with ClientContext() as ctx: + with SyncClientContext() as ctx: + yield ctx + + 
+class AsyncClientContext(ContextModel):
+    """
+    A context for managing the async Prefect client instances.
+
+    Clients were formerly tracked on the TaskRunContext and FlowRunContext, but
+    having two separate places and the addition of both sync and async clients
+    made it difficult to manage. This context is intended to be the single
+    source for async clients.
+
+    The client creates an async client, which can either be read directly from
+    the context object OR loaded with get_client, inject_client, or other
+    Prefect utilities.
+
+    async with AsyncClientContext.get_or_create() as ctx:
+        c1 = get_client(sync_client=False)
+        c2 = get_client(sync_client=False)
+        assert c1 is c2
+        assert c1 is ctx.client
+    """
+
+    __var__ = ContextVar("async-client-context")
+    client: PrefectClient
+    _httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None)
+    _context_stack: int = PrivateAttr(0)
+
+    def __init__(self, httpx_settings: Optional[dict[str, Any]] = None):
+        super().__init__(
+            client=get_client(sync_client=False, httpx_settings=httpx_settings),
+        )
+        self._httpx_settings = httpx_settings
+        self._context_stack = 0
+
+    async def __aenter__(self):
+        self._context_stack += 1
+        if self._context_stack == 1:  # outermost entry: enter the client too
+            await self.client.__aenter__()
+            return super().__enter__()
+        else:
+            return self
+
+    async def __aexit__(self, *exc_info):
+        self._context_stack -= 1
+        if self._context_stack == 0:  # outermost exit: exit the client too
+            await self.client.__aexit__(*exc_info)
+            return super().__exit__(*exc_info)
+
+    @classmethod
+    @asynccontextmanager
+    async def get_or_create(cls) -> AsyncGenerator[Self, None]:
+        ctx = cls.get()
+        if ctx:
+            yield ctx
+        else:
+            async with cls() as ctx:  # must go through __aenter__/__aexit__
                 yield ctx


diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py
index 6f84b6941ad18..c24be4e163e0d 100644
--- a/src/prefect/flow_engine.py
+++ b/src/prefect/flow_engine.py
@@ -29,7 +29,7 @@
 from prefect.client.schemas import FlowRun, TaskRun
 from prefect.client.schemas.filters import FlowRunFilter
 from prefect.client.schemas.sorting
import FlowRunSort -from prefect.context import ClientContext, FlowRunContext, TagsContext +from prefect.context import FlowRunContext, SyncClientContext, TagsContext from prefect.exceptions import ( Abort, Pause, @@ -529,8 +529,8 @@ def initialize_run(self): """ Enters a client context and creates a flow run if needed. """ - with ClientContext.get_or_create() as client_ctx: - self._client = client_ctx.sync_client + with SyncClientContext.get_or_create() as client_ctx: + self._client = client_ctx.client self._is_started = True if not self.flow_run: diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 70c4718e687f9..bdef28e8638ca 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -3,7 +3,7 @@ import threading import time from asyncio import CancelledError -from contextlib import ExitStack, contextmanager +from contextlib import ExitStack, asynccontextmanager, contextmanager from dataclasses import dataclass, field from functools import wraps from textwrap import dedent @@ -31,12 +31,13 @@ from typing_extensions import ParamSpec from prefect import Task -from prefect.client.orchestration import SyncPrefectClient +from prefect.client.orchestration import PrefectClient, SyncPrefectClient from prefect.client.schemas import TaskRun from prefect.client.schemas.objects import State, TaskRunInput from prefect.context import ( - ClientContext, + AsyncClientContext, FlowRunContext, + SyncClientContext, TaskRunContext, hydrated_context, ) @@ -77,6 +78,7 @@ _get_hook_name, emit_task_run_state_change_event, link_state_to_result, + propose_state, propose_state_sync, resolve_to_final_result, ) @@ -86,13 +88,15 @@ P = ParamSpec("P") R = TypeVar("R") +BACKOFF_MAX = 10 + class TaskRunTimeoutError(TimeoutError): """Raised when a task run exceeds its timeout.""" @dataclass -class TaskRunEngine(Generic[P, R]): +class BaseTaskRunEngine(Generic[P, R]): task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]] logger: logging.Logger = 
field(default_factory=lambda: get_logger("engine")) parameters: Optional[Dict[str, Any]] = None @@ -106,7 +110,6 @@ class TaskRunEngine(Generic[P, R]): _raised: Union[Exception, Type[NotSet]] = NotSet _initial_run_context: Optional[TaskRunContext] = None _is_started: bool = False - _client: Optional[SyncPrefectClient] = None _task_name_set: bool = False _last_event: Optional[PrefectEvent] = None @@ -114,12 +117,6 @@ def __post_init__(self): if self.parameters is None: self.parameters = {} - @property - def client(self) -> SyncPrefectClient: - if not self._is_started or self._client is None: - raise RuntimeError("Engine has not started.") - return self._client - @property def state(self) -> State: if not self.task_run: @@ -160,41 +157,6 @@ def is_cancelled(self) -> bool: return self.context["cancel_event"].is_set() return False - def call_hooks(self, state: Optional[State] = None): - if state is None: - state = self.state - task = self.task - task_run = self.task_run - - if not task_run: - raise ValueError("Task run is not set") - - if state.is_failed() and task.on_failure_hooks: - hooks = task.on_failure_hooks - elif state.is_completed() and task.on_completion_hooks: - hooks = task.on_completion_hooks - else: - hooks = None - - for hook in hooks or []: - hook_name = _get_hook_name(hook) - - try: - self.logger.info( - f"Running hook {hook_name!r} in response to entering state" - f" {state.name!r}" - ) - result = hook(task, task_run, state) - if inspect.isawaitable(result): - run_coro_as_sync(result) - except Exception: - self.logger.error( - f"An error was encountered while running hook {hook_name!r}", - exc_info=True, - ) - else: - self.logger.info(f"Hook {hook_name!r} finished running successfully") - def compute_transaction_key(self) -> Optional[str]: key = None if self.task.cache_policy: @@ -254,6 +216,106 @@ def _wait_for_dependencies(self): context={"current_task_run": self.task_run, "current_task": self.task}, ) + def record_terminal_state_timing(self, state: 
State) -> None: + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + if self.task_run.start_time and not self.task_run.end_time: + self.task_run.end_time = state.timestamp + + if self.task_run.state.is_running(): + self.task_run.total_run_time += ( + state.timestamp - self.task_run.state.timestamp + ) + + def is_running(self) -> bool: + """Whether or not the engine is currently running a task.""" + if (task_run := getattr(self, "task_run", None)) is None: + return False + return task_run.state.is_running() or task_run.state.is_scheduled() + + def log_finished_message(self): + # If debugging, use the more complete `repr` than the usual `str` description + display_state = repr(self.state) if PREFECT_DEBUG_MODE else str(self.state) + level = logging.INFO if self.state.is_completed() else logging.ERROR + msg = f"Finished in state {display_state}" + if self.state.is_pending(): + msg += ( + "\nPlease wait for all submitted tasks to complete" + " before exiting your flow by calling `.wait()` on the " + "`PrefectFuture` returned from your `.submit()` calls." + ) + msg += dedent( + """ + + Example: + + from prefect import flow, task + + @task + def say_hello(name): + print f"Hello, {name}!" 
+ + @flow + def example_flow(): + future = say_hello.submit(name="Marvin) + future.wait() + + example_flow() + """ + ) + self.logger.log( + level=level, + msg=msg, + ) + + +@dataclass +class SyncTaskRunEngine(BaseTaskRunEngine[P, R]): + _client: Optional[SyncPrefectClient] = None + + @property + def client(self) -> SyncPrefectClient: + if not self._is_started or self._client is None: + raise RuntimeError("Engine has not started.") + return self._client + + def sleep(self, interval: float): + time.sleep(interval) + + def call_hooks(self, state: Optional[State] = None): + if state is None: + state = self.state + task = self.task + task_run = self.task_run + + if not task_run: + raise ValueError("Task run is not set") + + if state.is_failed() and task.on_failure_hooks: + hooks = task.on_failure_hooks + elif state.is_completed() and task.on_completion_hooks: + hooks = task.on_completion_hooks + else: + hooks = None + + for hook in hooks or []: + hook_name = _get_hook_name(hook) + + try: + self.logger.info( + f"Running hook {hook_name!r} in response to entering state" + f" {state.name!r}" + ) + result = hook(task, task_run, state) + if inspect.isawaitable(result): + run_coro_as_sync(result) + except Exception: + self.logger.error( + f"An error was encountered while running hook {hook_name!r}", + exc_info=True, + ) + else: + self.logger.info(f"Hook {hook_name!r} finished running successfully") + def begin_run(self): try: self._resolve_parameters() @@ -294,7 +356,6 @@ def begin_run(self): except Exception: state = self.set_state(new_state, force=True) - BACKOFF_MAX = 10 backoff_count = 0 # TODO: Could this listen for state change events instead of polling? 
@@ -304,7 +365,7 @@ def begin_run(self): interval = clamped_poisson_interval( average_interval=backoff_count, clamping_factor=0.3 ) - time.sleep(interval) + self.sleep(interval) state = self.set_state(new_state) def set_state(self, state: State, force: bool = False) -> State: @@ -496,16 +557,6 @@ def handle_crash(self, exc: BaseException) -> None: self.set_state(state, force=True) self._raised = exc - def record_terminal_state_timing(self, state: State) -> None: - if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: - if self.task_run.start_time and not self.task_run.end_time: - self.task_run.end_time = state.timestamp - - if self.task_run.state.is_running(): - self.task_run.total_run_time += ( - state.timestamp - self.task_run.state.timestamp - ) - @contextmanager def setup_run_context(self, client: Optional[SyncPrefectClient] = None): from prefect.utilities.engine import ( @@ -558,14 +609,14 @@ def initialize_run( self, task_run_id: Optional[UUID] = None, dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, - ) -> Generator["TaskRunEngine", Any, Any]: + ) -> Generator["SyncTaskRunEngine", Any, Any]: """ Enters a client context and creates a task run if needed. 
""" with hydrated_context(self.context): - with ClientContext.get_or_create() as client_ctx: - self._client = client_ctx.sync_client + with SyncClientContext.get_or_create() as client_ctx: + self._client = client_ctx.client self._is_started = True try: if not self.task_run: @@ -638,51 +689,10 @@ def initialize_run( self.handle_crash(exc) raise finally: - # If debugging, use the more complete `repr` than the usual `str` description - display_state = ( - repr(self.state) if PREFECT_DEBUG_MODE else str(self.state) - ) - level = logging.INFO if self.state.is_completed() else logging.ERROR - msg = f"Finished in state {display_state}" - if self.state.is_pending(): - msg += ( - "\nPlease wait for all submitted tasks to complete" - " before exiting your flow by calling `.wait()` on the " - "`PrefectFuture` returned from your `.submit()` calls." - ) - msg += dedent( - """ - - Example: - - from prefect import flow, task - - @task - def say_hello(name): - print f"Hello, {name}!" - - @flow - def example_flow(): - future = say_hello.submit(name="Marvin) - future.wait() - - example_flow() - """ - ) - self.logger.log( - level=level, - msg=msg, - ) - + self.log_finished_message() self._is_started = False self._client = None - def is_running(self) -> bool: - """Whether or not the engine is currently running a task.""" - if (task_run := getattr(self, "task_run", None)) is None: - return False - return task_run.state.is_running() or task_run.state.is_scheduled() - async def wait_until_ready(self): """Waits until the scheduled time (if its the future), then enters Running.""" if scheduled_time := self.state.state_details.scheduled_time: @@ -732,11 +742,10 @@ def transaction_context(self) -> Generator[Transaction, None, None]: @contextmanager def run_context(self): - timeout_context = timeout_async if self.task.isasync else timeout # reenter the run context to ensure it is up to date for every run with self.setup_run_context(): try: - with timeout_context( + with timeout( 
seconds=self.task.timeout_seconds, timeout_exc_type=TaskRunTimeoutError, ): @@ -760,96 +769,587 @@ def call_task_fn( task is async. """ parameters = self.parameters or {} - if self.task.isasync: + if transaction.is_committed(): + result = transaction.read() + else: + result = call_with_parameters(self.task.fn, parameters) + self.handle_success(result, transaction=transaction) + return result - async def _call_task_fn(): - if transaction.is_committed(): - result = transaction.read() - else: - result = await call_with_parameters(self.task.fn, parameters) - self.handle_success(result, transaction=transaction) - return result - return _call_task_fn() - else: - if transaction.is_committed(): - result = transaction.read() - else: - result = call_with_parameters(self.task.fn, parameters) - self.handle_success(result, transaction=transaction) - return result +@dataclass +class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]): + _client: Optional[PrefectClient] = None + @property + def client(self) -> PrefectClient: + if not self._is_started or self._client is None: + raise RuntimeError("Engine has not started.") + return self._client -def run_task_sync( - task: Task[P, R], - task_run_id: Optional[UUID] = None, - task_run: Optional[TaskRun] = None, - parameters: Optional[Dict[str, Any]] = None, - wait_for: Optional[Iterable[PrefectFuture]] = None, - return_type: Literal["state", "result"] = "result", - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, - context: Optional[Dict[str, Any]] = None, -) -> Union[R, State, None]: - engine = TaskRunEngine[P, R]( - task=task, - parameters=parameters, - task_run=task_run, - wait_for=wait_for, - context=context, - ) + async def sleep(self, interval: float): + await anyio.sleep(interval) - with engine.start(task_run_id=task_run_id, dependencies=dependencies): - while engine.is_running(): - run_coro_as_sync(engine.wait_until_ready()) - with engine.run_context(), engine.transaction_context() as txn: - engine.call_task_fn(txn) + 
async def call_hooks(self, state: Optional[State] = None): + if state is None: + state = self.state + task = self.task + task_run = self.task_run - return engine.state if return_type == "state" else engine.result() + if not task_run: + raise ValueError("Task run is not set") + if state.is_failed() and task.on_failure_hooks: + hooks = task.on_failure_hooks + elif state.is_completed() and task.on_completion_hooks: + hooks = task.on_completion_hooks + else: + hooks = None -async def run_task_async( - task: Task[P, R], - task_run_id: Optional[UUID] = None, - task_run: Optional[TaskRun] = None, - parameters: Optional[Dict[str, Any]] = None, - wait_for: Optional[Iterable[PrefectFuture]] = None, - return_type: Literal["state", "result"] = "result", - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, - context: Optional[Dict[str, Any]] = None, -) -> Union[R, State, None]: - engine = TaskRunEngine[P, R]( - task=task, - parameters=parameters, - task_run=task_run, - wait_for=wait_for, - context=context, - ) + for hook in hooks or []: + hook_name = _get_hook_name(hook) - with engine.start(task_run_id=task_run_id, dependencies=dependencies): - while engine.is_running(): - await engine.wait_until_ready() - with engine.run_context(), engine.transaction_context() as txn: - await engine.call_task_fn(txn) + try: + self.logger.info( + f"Running hook {hook_name!r} in response to entering state" + f" {state.name!r}" + ) + result = hook(task, task_run, state) + if inspect.isawaitable(result): + await result + except Exception: + self.logger.error( + f"An error was encountered while running hook {hook_name!r}", + exc_info=True, + ) + else: + self.logger.info(f"Hook {hook_name!r} finished running successfully") - return engine.state if return_type == "state" else engine.result() + async def begin_run(self): + try: + self._resolve_parameters() + self._wait_for_dependencies() + except UpstreamTaskError as upstream_exc: + state = await self.set_state( + Pending( + name="NotReady", 
+ message=str(upstream_exc), + ), + # if orchestrating a run already in a pending state, force orchestration to + # update the state name + force=self.state.is_pending(), + ) + return + new_state = Running() -def run_generator_task_sync( - task: Task[P, R], - task_run_id: Optional[UUID] = None, - task_run: Optional[TaskRun] = None, - parameters: Optional[Dict[str, Any]] = None, - wait_for: Optional[Iterable[PrefectFuture]] = None, - return_type: Literal["state", "result"] = "result", - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, - context: Optional[Dict[str, Any]] = None, -) -> Generator[R, None, None]: - if return_type != "result": - raise ValueError("The return_type for a generator task must be 'result'") + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + self.task_run.start_time = new_state.timestamp + self.task_run.run_count += 1 - engine = TaskRunEngine[P, R]( - task=task, - parameters=parameters, + flow_run_context = FlowRunContext.get() + if flow_run_context: + # Carry forward any task run information from the flow run + flow_run = flow_run_context.flow_run + self.task_run.flow_run_run_count = flow_run.run_count + + state = await self.set_state(new_state) + + # TODO: this is temporary until the API stops rejecting state transitions + # and the client / transaction store becomes the source of truth + # this is a bandaid caused by the API storing a Completed state with a bad + # result reference that no longer exists + if state.is_completed(): + try: + await state.result(retry_result_failure=False) + except Exception: + state = await self.set_state(new_state, force=True) + + backoff_count = 0 + + # TODO: Could this listen for state change events instead of polling? 
+ while state.is_pending() or state.is_paused(): + if backoff_count < BACKOFF_MAX: + backoff_count += 1 + interval = clamped_poisson_interval( + average_interval=backoff_count, clamping_factor=0.3 + ) + await self.sleep(interval) + state = await self.set_state(new_state) + + async def set_state(self, state: State, force: bool = False) -> State: + last_state = self.state + if not self.task_run: + raise ValueError("Task run is not set") + + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + self.task_run.state = new_state = state + + # Ensure that the state_details are populated with the current run IDs + new_state.state_details.task_run_id = self.task_run.id + new_state.state_details.flow_run_id = self.task_run.flow_run_id + + # Predictively update the de-normalized task_run.state_* attributes + self.task_run.state_id = new_state.id + self.task_run.state_type = new_state.type + self.task_run.state_name = new_state.name + else: + try: + new_state = await propose_state( + self.client, state, task_run_id=self.task_run.id, force=force + ) + except Pause as exc: + # We shouldn't get a pause signal without a state, but if this happens, + # just use a Paused state to assume an in-process pause. + new_state = exc.state if exc.state else Paused() + if new_state.state_details.pause_reschedule: + # If we're being asked to pause and reschedule, we should exit the + # task and expect to be resumed later. 
+ raise + + # currently this is a hack to keep a reference to the state object + # that has an in-memory result attached to it; using the API state + # could result in losing that reference + self.task_run.state = new_state + + # emit a state change event + self._last_event = emit_task_run_state_change_event( + task_run=self.task_run, + initial_state=last_state, + validated_state=self.task_run.state, + follows=self._last_event, + ) + + return new_state + + async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": + if self._return_value is not NotSet: + # if the return value is a BaseResult, we need to fetch it + if isinstance(self._return_value, BaseResult): + _result = self._return_value.get() + if inspect.isawaitable(_result): + _result = await _result + return _result + + # otherwise, return the value as is + return self._return_value + + if self._raised is not NotSet: + # if the task raised an exception, raise it + if raise_on_failure: + raise self._raised + + # otherwise, return the exception + return self._raised + + async def handle_success(self, result: R, transaction: Transaction) -> R: + result_factory = getattr(TaskRunContext.get(), "result_factory", None) + if result_factory is None: + raise ValueError("Result factory is not set") + + if self.task.cache_expiration is not None: + expiration = pendulum.now("utc") + self.task.cache_expiration + else: + expiration = None + + terminal_state = await return_value_to_state( + result, + result_factory=result_factory, + key=transaction.key, + expiration=expiration, + # defer persistence to transaction commit + defer_persistence=True, + ) + transaction.stage( + terminal_state.data, + on_rollback_hooks=[ + _with_transaction_hook_logging(hook, "rollback", self.logger) + for hook in self.task.on_rollback_hooks + ], + on_commit_hooks=[ + _with_transaction_hook_logging(hook, "commit", self.logger) + for hook in self.task.on_commit_hooks + ], + ) + if transaction.is_committed(): + 
terminal_state.name = "Cached" + + self.record_terminal_state_timing(terminal_state) + await self.set_state(terminal_state) + self._return_value = result + return result + + async def handle_retry(self, exc: Exception) -> bool: + """Handle any task run retries. + + - If the task has retries left, and the retry condition is met, set the task to retrying and return True. + - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time. + - If the task has no retries left, or the retry condition is not met, return False. + """ + if self.retries < self.task.retries and self.can_retry: + if self.task.retry_delay_seconds: + delay = ( + self.task.retry_delay_seconds[ + min(self.retries, len(self.task.retry_delay_seconds) - 1) + ] # repeat final delay value if attempts exceed specified delays + if isinstance(self.task.retry_delay_seconds, Sequence) + else self.task.retry_delay_seconds + ) + new_state = AwaitingRetry( + scheduled_time=pendulum.now("utc").add(seconds=delay) + ) + else: + delay = None + new_state = Retrying() + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + self.task_run.run_count += 1 + + self.logger.info( + "Task run failed with exception: %r - " "Retry %s/%s will start %s", + exc, + self.retries + 1, + self.task.retries, + str(delay) + " second(s) from now" if delay else "immediately", + ) + + await self.set_state(new_state, force=True) + self.retries = self.retries + 1 + return True + elif self.retries >= self.task.retries: + self.logger.error( + "Task run failed with exception: %r - Retries are exhausted", + exc, + exc_info=True, + ) + return False + + return False + + async def handle_exception(self, exc: Exception) -> None: + # If the task fails, and we have retries left, set the task to retrying. + if not await self.handle_retry(exc): + # If the task has no retries left, or the retry condition is not met, set the task to failed. 
+ context = TaskRunContext.get() + state = await exception_to_failed_state( + exc, + message="Task run encountered an exception", + result_factory=getattr(context, "result_factory", None), + ) + self.record_terminal_state_timing(state) + await self.set_state(state) + self._raised = exc + + async def handle_timeout(self, exc: TimeoutError) -> None: + if not await self.handle_retry(exc): + if isinstance(exc, TaskRunTimeoutError): + message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)" + else: + message = f"Task run failed due to timeout: {exc!r}" + self.logger.error(message) + state = Failed( + data=exc, + message=message, + name="TimedOut", + ) + await self.set_state(state) + self._raised = exc + + async def handle_crash(self, exc: BaseException) -> None: + state = await exception_to_crashed_state(exc) + self.logger.error(f"Crash detected! {state.message}") + self.logger.debug("Crash details:", exc_info=exc) + self.record_terminal_state_timing(state) + await self.set_state(state, force=True) + self._raised = exc + + @asynccontextmanager + async def setup_run_context(self, client: Optional[PrefectClient] = None): + from prefect.utilities.engine import ( + _resolve_custom_task_run_name, + should_log_prints, + ) + + if client is None: + client = self.client + if not self.task_run: + raise ValueError("Task run is not set") + + if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + self.task_run = await client.read_task_run(self.task_run.id) + with ExitStack() as stack: + if log_prints := should_log_prints(self.task): + stack.enter_context(patch_print()) + stack.enter_context( + TaskRunContext( + task=self.task, + log_prints=log_prints, + task_run=self.task_run, + parameters=self.parameters, + result_factory=await ResultFactory.from_task(self.task), # type: ignore + client=client, + ) + ) + + self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore + + if not 
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + # update the task run name if necessary + if not self._task_name_set and self.task.task_run_name: + task_run_name = _resolve_custom_task_run_name( + task=self.task, parameters=self.parameters + ) + await self.client.set_task_run_name( + task_run_id=self.task_run.id, name=task_run_name + ) + self.logger.extra["task_run_name"] = task_run_name + self.logger.debug( + f"Renamed task run {self.task_run.name!r} to {task_run_name!r}" + ) + self.task_run.name = task_run_name + self._task_name_set = True + yield + + @asynccontextmanager + async def initialize_run( + self, + task_run_id: Optional[UUID] = None, + dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + ) -> AsyncGenerator["AsyncTaskRunEngine", Any]: + """ + Enters a client context and creates a task run if needed. + """ + + with hydrated_context(self.context): + async with AsyncClientContext.get_or_create() as client_ctx: + self._client = client_ctx.client + self._is_started = True + try: + if not self.task_run: + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + # TODO - this maybe should be a method on Task? 
+ from prefect.utilities.engine import ( + _resolve_custom_task_run_name, + ) + + task_run_name = None + if not self._task_name_set and self.task.task_run_name: + task_run_name = _resolve_custom_task_run_name( + task=self.task, parameters=self.parameters + ) + + self.task_run = await self.task.create_local_run( + id=task_run_id, + parameters=self.parameters, + flow_run_context=FlowRunContext.get(), + parent_task_run_context=TaskRunContext.get(), + wait_for=self.wait_for, + extra_task_inputs=dependencies, + task_run_name=task_run_name, + ) + else: + self.task_run = await self.task.create_run( + id=task_run_id, + parameters=self.parameters, + flow_run_context=FlowRunContext.get(), + parent_task_run_context=TaskRunContext.get(), + wait_for=self.wait_for, + extra_task_inputs=dependencies, + ) + # Emit an event to capture that the task run was in the `PENDING` state. + self._last_event = emit_task_run_state_change_event( + task_run=self.task_run, + initial_state=None, + validated_state=self.task_run.state, + ) + + async with self.setup_run_context(): + # setup_run_context might update the task run name, so log creation here + self.logger.info( + f"Created task run {self.task_run.name!r} for task {self.task.name!r}" + ) + yield self + + except TerminationSignal as exc: + # TerminationSignals are caught and handled as crashes + await self.handle_crash(exc) + raise exc + + except Exception: + # regular exceptions are caught and re-raised to the user + raise + except (Pause, Abort) as exc: + # Do not capture internal signals as crashes + if isinstance(exc, Abort): + self.logger.error("Task run was aborted: %s", exc) + raise + except GeneratorExit: + # Do not capture generator exits as crashes + raise + except BaseException as exc: + # BaseExceptions are caught and handled as crashes + await self.handle_crash(exc) + raise + finally: + self.log_finished_message() + self._is_started = False + self._client = None + + async def wait_until_ready(self): + """Waits until the 
scheduled time (if its the future), then enters Running.""" + if scheduled_time := self.state.state_details.scheduled_time: + sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds() + await anyio.sleep(sleep_time if sleep_time > 0 else 0) + await self.set_state( + Retrying() if self.state.name == "AwaitingRetry" else Running(), + force=True, + ) + + # -------------------------- + # + # The following methods compose the main task run loop + # + # -------------------------- + + @asynccontextmanager + async def start( + self, + task_run_id: Optional[UUID] = None, + dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + ) -> AsyncGenerator[None, None]: + async with self.initialize_run( + task_run_id=task_run_id, dependencies=dependencies + ): + await self.begin_run() + try: + yield + finally: + await self.call_hooks() + + @asynccontextmanager + async def transaction_context(self) -> AsyncGenerator[Transaction, None]: + result_factory = getattr(TaskRunContext.get(), "result_factory", None) + + # refresh cache setting is now repurposes as overwrite transaction record + overwrite = ( + self.task.refresh_cache + if self.task.refresh_cache is not None + else PREFECT_TASKS_REFRESH_CACHE.value() + ) + with transaction( + key=self.compute_transaction_key(), + store=ResultFactoryStore(result_factory=result_factory), + overwrite=overwrite, + logger=self.logger, + ) as txn: + yield txn + + @asynccontextmanager + async def run_context(self): + # reenter the run context to ensure it is up to date for every run + async with self.setup_run_context(): + try: + with timeout_async( + seconds=self.task.timeout_seconds, + timeout_exc_type=TaskRunTimeoutError, + ): + self.logger.debug( + f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..." 
+ ) + if self.is_cancelled(): + raise CancelledError("Task run cancelled by the task runner") + + yield self + except TimeoutError as exc: + await self.handle_timeout(exc) + except Exception as exc: + await self.handle_exception(exc) + + async def call_task_fn( + self, transaction: Transaction + ) -> Union[R, Coroutine[Any, Any, R]]: + """ + Convenience method to call the task function. Returns a coroutine if the + task is async. + """ + parameters = self.parameters or {} + if transaction.is_committed(): + result = transaction.read() + else: + result = await call_with_parameters(self.task.fn, parameters) + await self.handle_success(result, transaction=transaction) + return result + + +def run_task_sync( + task: Task[P, R], + task_run_id: Optional[UUID] = None, + task_run: Optional[TaskRun] = None, + parameters: Optional[Dict[str, Any]] = None, + wait_for: Optional[Iterable[PrefectFuture]] = None, + return_type: Literal["state", "result"] = "result", + dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + context: Optional[Dict[str, Any]] = None, +) -> Union[R, State, None]: + engine = SyncTaskRunEngine[P, R]( + task=task, + parameters=parameters, + task_run=task_run, + wait_for=wait_for, + context=context, + ) + + with engine.start(task_run_id=task_run_id, dependencies=dependencies): + while engine.is_running(): + run_coro_as_sync(engine.wait_until_ready()) + with engine.run_context(), engine.transaction_context() as txn: + engine.call_task_fn(txn) + + return engine.state if return_type == "state" else engine.result() + + +async def run_task_async( + task: Task[P, R], + task_run_id: Optional[UUID] = None, + task_run: Optional[TaskRun] = None, + parameters: Optional[Dict[str, Any]] = None, + wait_for: Optional[Iterable[PrefectFuture]] = None, + return_type: Literal["state", "result"] = "result", + dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + context: Optional[Dict[str, Any]] = None, +) -> Union[R, State, None]: + engine = 
AsyncTaskRunEngine[P, R]( + task=task, + parameters=parameters, + task_run=task_run, + wait_for=wait_for, + context=context, + ) + + async with engine.start(task_run_id=task_run_id, dependencies=dependencies): + while engine.is_running(): + await engine.wait_until_ready() + async with engine.run_context(), engine.transaction_context() as txn: + await engine.call_task_fn(txn) + + return engine.state if return_type == "state" else await engine.result() + + +def run_generator_task_sync( + task: Task[P, R], + task_run_id: Optional[UUID] = None, + task_run: Optional[TaskRun] = None, + parameters: Optional[Dict[str, Any]] = None, + wait_for: Optional[Iterable[PrefectFuture]] = None, + return_type: Literal["state", "result"] = "result", + dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + context: Optional[Dict[str, Any]] = None, +) -> Generator[R, None, None]: + if return_type != "result": + raise ValueError("The return_type for a generator task must be 'result'") + + engine = SyncTaskRunEngine[P, R]( + task=task, + parameters=parameters, task_run=task_run, wait_for=wait_for, context=context, @@ -901,7 +1401,7 @@ async def run_generator_task_async( ) -> AsyncGenerator[R, None]: if return_type != "result": raise ValueError("The return_type for a generator task must be 'result'") - engine = TaskRunEngine[P, R]( + engine = AsyncTaskRunEngine[P, R]( task=task, parameters=parameters, task_run=task_run, @@ -909,10 +1409,10 @@ async def run_generator_task_async( context=context, ) - with engine.start(task_run_id=task_run_id, dependencies=dependencies): + async with engine.start(task_run_id=task_run_id, dependencies=dependencies): while engine.is_running(): await engine.wait_until_ready() - with engine.run_context(), engine.transaction_context() as txn: + async with engine.run_context(), engine.transaction_context() as txn: # TODO: generators should default to commit_mode=OFF # because they are dynamic by definition # for now we just prevent this branch explicitly @@ 
-936,13 +1436,13 @@ async def run_generator_task_async( link_state_to_result(engine.state, gen_result) yield gen_result except (StopAsyncIteration, GeneratorExit) as exc: - engine.handle_success(None, transaction=txn) + await engine.handle_success(None, transaction=txn) if isinstance(exc, GeneratorExit): gen.throw(exc) # async generators can't return, but we can raise failures here if engine.state.is_failed(): - engine.result() + await engine.result() def run_task( diff --git a/tests/test_flow_engine.py b/tests/test_flow_engine.py index 718e3bcde1cff..a723e71b9a1ba 100644 --- a/tests/test_flow_engine.py +++ b/tests/test_flow_engine.py @@ -36,6 +36,8 @@ from prefect.input.run_input import RunInput from prefect.logging import get_run_logger from prefect.server.schemas.core import FlowRun as ServerFlowRun +from prefect.task_engine import AsyncTaskRunEngine +from prefect.testing.utilities import AsyncMock from prefect.utilities.callables import get_call_parameters from prefect.utilities.filesystem import tmpchdir @@ -1125,8 +1127,8 @@ async def flow_resumer(): assert schema is not None async def test_paused_task_polling(self, monkeypatch, prefect_client): - sleeper = MagicMock(side_effect=[None, None, None, None, None]) - monkeypatch.setattr("prefect.task_engine.time.sleep", sleeper) + sleeper = AsyncMock(side_effect=[None, None, None, None, None]) + monkeypatch.setattr(AsyncTaskRunEngine, "sleep", sleeper) @task async def doesnt_pause(): diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index c69bf8b8a7e2a..7f2f40f34041f 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -34,7 +34,12 @@ temporary_settings, ) from prefect.states import Running, State -from prefect.task_engine import TaskRunEngine, run_task_async, run_task_sync +from prefect.task_engine import ( + AsyncTaskRunEngine, + SyncTaskRunEngine, + run_task_async, + run_task_sync, +) from prefect.task_runners import ThreadPoolTaskRunner from prefect.testing.utilities 
import exceptions_equal from prefect.utilities.callables import get_call_parameters @@ -176,20 +181,20 @@ async def foo(): return 42 -class TestTaskRunEngine: +class TestSyncTaskRunEngine: async def test_basic_init(self): - engine = TaskRunEngine(task=foo) + engine = SyncTaskRunEngine(task=foo) assert isinstance(engine.task, Task) assert engine.task.name == "foo" assert engine.parameters == {} async def test_client_attribute_raises_informative_error(self): - engine = TaskRunEngine(task=foo) + engine = SyncTaskRunEngine(task=foo) with pytest.raises(RuntimeError, match="not started"): engine.client async def test_client_attr_returns_client_after_starting(self): - engine = TaskRunEngine(task=foo) + engine = SyncTaskRunEngine(task=foo) with engine.initialize_run(): client = engine.client assert isinstance(client, SyncPrefectClient) @@ -198,6 +203,28 @@ async def test_client_attr_returns_client_after_starting(self): engine.client +class TestAsyncTaskRunEngine: + async def test_basic_init(self): + engine = AsyncTaskRunEngine(task=foo) + assert isinstance(engine.task, Task) + assert engine.task.name == "foo" + assert engine.parameters == {} + + async def test_client_attribute_raises_informative_error(self): + engine = AsyncTaskRunEngine(task=foo) + with pytest.raises(RuntimeError, match="not started"): + engine.client + + async def test_client_attr_returns_client_after_starting(self): + engine = AsyncTaskRunEngine(task=foo) + async with engine.initialize_run(): + client = engine.client + assert isinstance(client, PrefectClient) + + with pytest.raises(RuntimeError, match="not started"): + engine.client + + class TestRunTask: def test_run_task_with_client_provided_uuid(self): @task @@ -1118,11 +1145,33 @@ def my_task(): await task_run.state.result() @pytest.mark.parametrize("interrupt_type", [KeyboardInterrupt, SystemExit]) - async def test_interrupt_in_task_orchestration_crashes_task_and_flow( + async def test_interrupt_in_task_orchestration_crashes_task_and_flow_sync( + 
self, interrupt_type, monkeypatch + ): + monkeypatch.setattr( + SyncTaskRunEngine, "begin_run", MagicMock(side_effect=interrupt_type) + ) + + @task + def my_task(): + pass + + with pytest.raises(interrupt_type): + my_task() + + task_run = await get_task_run(task_run_id=None) + assert task_run.state.is_crashed() + assert task_run.state.type == StateType.CRASHED + assert "Execution was aborted" in task_run.state.message + with pytest.raises(CrashedRun, match="Execution was aborted"): + await task_run.state.result() + + @pytest.mark.parametrize("interrupt_type", [KeyboardInterrupt, SystemExit]) + async def test_interrupt_in_task_orchestration_crashes_task_and_flow_async( self, interrupt_type, monkeypatch ): monkeypatch.setattr( - TaskRunEngine, "begin_run", MagicMock(side_effect=interrupt_type) + AsyncTaskRunEngine, "begin_run", MagicMock(side_effect=interrupt_type) ) @task @@ -1283,7 +1332,7 @@ async def test_sync_task_does_not_set_end_time_on_crash_pre_runnning( self, monkeypatch ): monkeypatch.setattr( - TaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit) + SyncTaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit) ) @task @@ -1301,7 +1350,7 @@ async def test_async_task_does_not_set_end_time_on_crash_pre_running( self, monkeypatch ): monkeypatch.setattr( - TaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit) + AsyncTaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit) ) @task diff --git a/tests/test_task_worker.py b/tests/test_task_worker.py index 99dce2562fd88..2d44b8154c768 100644 --- a/tests/test_task_worker.py +++ b/tests/test_task_worker.py @@ -178,7 +178,7 @@ def empty_task(): def always_error(*args, **kwargs): raise ValueError("oops") - monkeypatch.setattr("prefect.task_engine.TaskRunEngine.start", always_error) + monkeypatch.setattr("prefect.task_engine.SyncTaskRunEngine.start", always_error) task_worker = TaskWorker(empty_task) @@ -769,48 +769,56 @@ async def mock_iter(): async def 
test_tasks_execute_when_capacity_frees_up( self, mock_subscription, prefect_client ): - event = asyncio.Event() + execution_order = [] @task - async def slow_task(): - await asyncio.sleep(1) - if event.is_set(): - raise ValueError("Something went wrong! This event should not be set.") - event.set() + async def slow_task(task_id: str): + execution_order.append(f"{task_id} start") + await asyncio.sleep(0.1) # Simulating some work + execution_order.append(f"{task_id} end") task_worker = TaskWorker(slow_task, limit=1) - task_run_future_1 = slow_task.apply_async() + task_run_future_1 = slow_task.apply_async(("task1",)) task_run_1 = await prefect_client.read_task_run(task_run_future_1.task_run_id) - task_run_future_2 = slow_task.apply_async() + task_run_future_2 = slow_task.apply_async(("task2",)) task_run_2 = await prefect_client.read_task_run(task_run_future_2.task_run_id) async def mock_iter(): yield task_run_1 yield task_run_2 - # sleep for a second to ensure that task execution starts - await asyncio.sleep(1) + while len(execution_order) < 4: + await asyncio.sleep(0.1) mock_subscription.return_value = mock_iter() server_task = asyncio.create_task(task_worker.start()) - await event.wait() - updated_task_run_1 = await prefect_client.read_task_run(task_run_1.id) - updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id) - assert updated_task_run_1.state.is_completed() - assert not updated_task_run_2.state.is_completed() - - # clear the event to allow the second task to complete - event.clear() - - await event.wait() - updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id) - - assert updated_task_run_2.state.is_completed() - - server_task.cancel() - await server_task + try: + # Wait for both tasks to complete + await asyncio.sleep(2) + + # Verify the execution order + assert execution_order == [ + "task1 start", + "task1 end", + "task2 start", + "task2 end", + ], "Tasks should execute sequentially" + + # Verify the states of both tasks + 
updated_task_run_1 = await prefect_client.read_task_run(task_run_1.id) + updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id) + + assert updated_task_run_1.state.is_completed() + assert updated_task_run_2.state.is_completed() + + finally: + server_task.cancel() + try: + await server_task + except asyncio.CancelledError: + pass async def test_execute_task_run_respects_limit(self, prefect_client): @task