diff --git a/src/prefect/client/orchestration.py b/src/prefect/client/orchestration.py index a6615f5752f3..778eb408732c 100644 --- a/src/prefect/client/orchestration.py +++ b/src/prefect/client/orchestration.py @@ -203,20 +203,18 @@ def get_client( except RuntimeError: loop = None - if client_ctx := prefect.context.ClientContext.get(): - if ( - sync_client - and client_ctx.sync_client - and client_ctx._httpx_settings == httpx_settings - ): - return client_ctx.sync_client - elif ( - not sync_client - and client_ctx.async_client - and client_ctx._httpx_settings == httpx_settings - and loop in (client_ctx.async_client._loop, None) - ): - return client_ctx.async_client + if sync_client: + if client_ctx := prefect.context.SyncClientContext.get(): + if client_ctx.client and client_ctx._httpx_settings == httpx_settings: + return client_ctx.client + else: + if client_ctx := prefect.context.AsyncClientContext.get(): + if ( + client_ctx.client + and client_ctx._httpx_settings == httpx_settings + and loop in (client_ctx.client._loop, None) + ): + return client_ctx.client api = PREFECT_API_URL.value() diff --git a/src/prefect/client/utilities.py b/src/prefect/client/utilities.py index 243067fdda8b..ffe42e63195f 100644 --- a/src/prefect/client/utilities.py +++ b/src/prefect/client/utilities.py @@ -42,14 +42,14 @@ def get_or_create_client( if client is not None: return client, True from prefect._internal.concurrency.event_loop import get_running_loop - from prefect.context import ClientContext, FlowRunContext, TaskRunContext + from prefect.context import AsyncClientContext, FlowRunContext, TaskRunContext - client_context = ClientContext.get() + async_client_context = AsyncClientContext.get() flow_run_context = FlowRunContext.get() task_run_context = TaskRunContext.get() - if client_context and client_context.async_client._loop == get_running_loop(): - return client_context.async_client, True + if async_client_context and async_client_context.client._loop == get_running_loop(): + 
return async_client_context.client, True elif ( flow_run_context and getattr(flow_run_context.client, "_loop", None) == get_running_loop() diff --git a/src/prefect/context.py b/src/prefect/context.py index 702a52428280..ecadbdfde3d8 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -10,12 +10,13 @@ import sys import warnings import weakref -from contextlib import ExitStack, contextmanager +from contextlib import ExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar, Token from pathlib import Path from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, Dict, Generator, Mapping, @@ -177,36 +178,34 @@ def serialize(self) -> Dict[str, Any]: return self.model_dump(exclude_unset=True) -class ClientContext(ContextModel): +class SyncClientContext(ContextModel): """ - A context for managing the Prefect client instances. + A context for managing the sync Prefect client instances. Clients were formerly tracked on the TaskRunContext and FlowRunContext, but having two separate places and the addition of both sync and async clients made it difficult to manage. This context is intended to be the single - source for clients. + source for sync clients. - The client creates both sync and async clients, which can either be read - directly from the context object OR loaded with get_client, inject_client, - or other Prefect utilities. + The client creates a sync client, which can either be read directly from + the context object OR loaded with get_client, inject_client, or other + Prefect utilities. 
- with ClientContext.get_or_create() as ctx: + with SyncClientContext.get_or_create() as ctx: c1 = get_client(sync_client=True) c2 = get_client(sync_client=True) assert c1 is c2 - assert c1 is ctx.sync_client + assert c1 is ctx.client """ - __var__ = ContextVar("clients") - sync_client: SyncPrefectClient - async_client: PrefectClient + __var__ = ContextVar("sync-client-context") + client: SyncPrefectClient _httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None) _context_stack: int = PrivateAttr(0) def __init__(self, httpx_settings: Optional[dict[str, Any]] = None): super().__init__( - sync_client=get_client(sync_client=True, httpx_settings=httpx_settings), - async_client=get_client(sync_client=False, httpx_settings=httpx_settings), + client=get_client(sync_client=True, httpx_settings=httpx_settings), ) self._httpx_settings = httpx_settings self._context_stack = 0 @@ -214,8 +213,7 @@ def __init__(self, httpx_settings: Optional[dict[str, Any]] = None): def __enter__(self): self._context_stack += 1 if self._context_stack == 1: - self.sync_client.__enter__() - run_coro_as_sync(self.async_client.__aenter__()) + self.client.__enter__() return super().__enter__() else: return self @@ -223,18 +221,74 @@ def __enter__(self): def __exit__(self, *exc_info): self._context_stack -= 1 if self._context_stack == 0: - self.sync_client.__exit__(*exc_info) - run_coro_as_sync(self.async_client.__aexit__(*exc_info)) + self.client.__exit__(*exc_info) return super().__exit__(*exc_info) @classmethod @contextmanager - def get_or_create(cls) -> Generator["ClientContext", None, None]: - ctx = ClientContext.get() + def get_or_create(cls) -> Generator["SyncClientContext", None, None]: + ctx = SyncClientContext.get() if ctx: yield ctx else: - with ClientContext() as ctx: + with SyncClientContext() as ctx: + yield ctx + + +class AsyncClientContext(ContextModel): + """ + A context for managing the async Prefect client instances. 
+ + Clients were formerly tracked on the TaskRunContext and FlowRunContext, but + having two separate places and the addition of both sync and async clients + made it difficult to manage. This context is intended to be the single + source for async clients. + + The client creates an async client, which can either be read directly from + the context object OR loaded with get_client, inject_client, or other + Prefect utilities. + + async with AsyncClientContext.get_or_create() as ctx: + c1 = get_client(sync_client=False) + c2 = get_client(sync_client=False) + assert c1 is c2 + assert c1 is ctx.client + """ + + __var__ = ContextVar("async-client-context") + client: PrefectClient + _httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None) + _context_stack: int = PrivateAttr(0) + + def __init__(self, httpx_settings: Optional[dict[str, Any]] = None): + super().__init__( + client=get_client(sync_client=False, httpx_settings=httpx_settings), + ) + self._httpx_settings = httpx_settings + self._context_stack = 0 + + async def __aenter__(self): + self._context_stack += 1 + if self._context_stack == 1: + await self.client.__aenter__() + return super().__enter__() + else: + return self + + async def __aexit__(self, *exc_info): + self._context_stack -= 1 + if self._context_stack == 0: + await self.client.__aexit__(*exc_info) + return super().__exit__(*exc_info) + + @classmethod + @asynccontextmanager + async def get_or_create(cls) -> AsyncGenerator[Self, None]: + ctx = cls.get() + if ctx: + yield ctx + else: + async with cls() as ctx: + yield ctx diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py index 6f84b6941ad1..c24be4e163e0 100644 --- a/src/prefect/flow_engine.py +++ b/src/prefect/flow_engine.py @@ -29,7 +29,7 @@ from prefect.client.schemas import FlowRun, TaskRun from prefect.client.schemas.filters import FlowRunFilter from prefect.client.schemas.sorting import FlowRunSort -from prefect.context import ClientContext, FlowRunContext, TagsContext +from prefect.context
import FlowRunContext, SyncClientContext, TagsContext from prefect.exceptions import ( Abort, Pause, @@ -529,8 +529,8 @@ def initialize_run(self): """ Enters a client context and creates a flow run if needed. """ - with ClientContext.get_or_create() as client_ctx: - self._client = client_ctx.sync_client + with SyncClientContext.get_or_create() as client_ctx: + self._client = client_ctx.client self._is_started = True if not self.flow_run: diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 836c041db1e3..e12ecb2a91aa 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -3,7 +3,7 @@ import threading import time from asyncio import CancelledError -from contextlib import ExitStack, contextmanager +from contextlib import ExitStack, asynccontextmanager, contextmanager from dataclasses import dataclass, field from functools import wraps from textwrap import dedent @@ -31,12 +31,13 @@ from typing_extensions import ParamSpec from prefect import Task -from prefect.client.orchestration import SyncPrefectClient +from prefect.client.orchestration import PrefectClient, SyncPrefectClient from prefect.client.schemas import TaskRun from prefect.client.schemas.objects import State, TaskRunInput from prefect.context import ( - ClientContext, + AsyncClientContext, FlowRunContext, + SyncClientContext, TaskRunContext, hydrated_context, ) @@ -77,6 +78,7 @@ _get_hook_name, emit_task_run_state_change_event, link_state_to_result, + propose_state, propose_state_sync, resolve_to_final_result, ) @@ -86,47 +88,717 @@ P = ParamSpec("P") R = TypeVar("R") +BACKOFF_MAX = 10 + class TaskRunTimeoutError(TimeoutError): """Raised when a task run exceeds its timeout.""" -@dataclass -class TaskRunEngine(Generic[P, R]): - task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]] - logger: logging.Logger = field(default_factory=lambda: get_logger("engine")) - parameters: Optional[Dict[str, Any]] = None - task_run: Optional[TaskRun] = None - retries: int = 0 - 
wait_for: Optional[Iterable[PrefectFuture]] = None - context: Optional[Dict[str, Any]] = None - # holds the return value from the user code - _return_value: Union[R, Type[NotSet]] = NotSet - # holds the exception raised by the user code, if any - _raised: Union[Exception, Type[NotSet]] = NotSet - _initial_run_context: Optional[TaskRunContext] = None - _is_started: bool = False - _client: Optional[SyncPrefectClient] = None - _task_name_set: bool = False - _last_event: Optional[PrefectEvent] = None +@dataclass +class BaseTaskRunEngine(Generic[P, R]): + task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]] + logger: logging.Logger = field(default_factory=lambda: get_logger("engine")) + parameters: Optional[Dict[str, Any]] = None + task_run: Optional[TaskRun] = None + retries: int = 0 + wait_for: Optional[Iterable[PrefectFuture]] = None + context: Optional[Dict[str, Any]] = None + # holds the return value from the user code + _return_value: Union[R, Type[NotSet]] = NotSet + # holds the exception raised by the user code, if any + _raised: Union[Exception, Type[NotSet]] = NotSet + _initial_run_context: Optional[TaskRunContext] = None + _is_started: bool = False + _task_name_set: bool = False + _last_event: Optional[PrefectEvent] = None + + def __post_init__(self): + if self.parameters is None: + self.parameters = {} + + @property + def state(self) -> State: + if not self.task_run: + raise ValueError("Task run is not set") + return self.task_run.state + + def is_cancelled(self) -> bool: + if ( + self.context + and "cancel_event" in self.context + and isinstance(self.context["cancel_event"], threading.Event) + ): + return self.context["cancel_event"].is_set() + return False + + def compute_transaction_key(self) -> Optional[str]: + key = None + if self.task.cache_policy: + flow_run_context = FlowRunContext.get() + task_run_context = TaskRunContext.get() + + if flow_run_context: + parameters = flow_run_context.parameters + else: + parameters = None + + key = 
self.task.cache_policy.compute_key( + task_ctx=task_run_context, + inputs=self.parameters, + flow_parameters=parameters, + ) + elif self.task.result_storage_key is not None: + key = _format_user_supplied_storage_key(self.task.result_storage_key) + return key + + def _resolve_parameters(self): + if not self.parameters: + return {} + + resolved_parameters = {} + for parameter, value in self.parameters.items(): + try: + resolved_parameters[parameter] = visit_collection( + value, + visit_fn=resolve_to_final_result, + return_data=True, + max_depth=-1, + remove_annotations=True, + context={}, + ) + except UpstreamTaskError: + raise + except Exception as exc: + raise PrefectException( + f"Failed to resolve inputs in parameter {parameter!r}. If your" + " parameter type is not supported, consider using the `quote`" + " annotation to skip resolution of inputs." + ) from exc + + self.parameters = resolved_parameters + + def _wait_for_dependencies(self): + if not self.wait_for: + return + + visit_collection( + self.wait_for, + visit_fn=resolve_to_final_result, + return_data=False, + max_depth=-1, + remove_annotations=True, + context={"current_task_run": self.task_run, "current_task": self.task}, + ) + + def record_terminal_state_timing(self, state: State) -> None: + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + if self.task_run.start_time and not self.task_run.end_time: + self.task_run.end_time = state.timestamp + + if self.task_run.state.is_running(): + self.task_run.total_run_time += ( + state.timestamp - self.task_run.state.timestamp + ) + + def is_running(self) -> bool: + """Whether or not the engine is currently running a task.""" + if (task_run := getattr(self, "task_run", None)) is None: + return False + return task_run.state.is_running() or task_run.state.is_scheduled() + + def log_finished_message(self): + # If debugging, use the more complete `repr` than the usual `str` description + display_state = repr(self.state) if PREFECT_DEBUG_MODE else 
str(self.state) + level = logging.INFO if self.state.is_completed() else logging.ERROR + msg = f"Finished in state {display_state}" + if self.state.is_pending(): + msg += ( + "\nPlease wait for all submitted tasks to complete" + " before exiting your flow by calling `.wait()` on the " + "`PrefectFuture` returned from your `.submit()` calls." + ) + msg += dedent( + """ + + Example: + + from prefect import flow, task + + @task + def say_hello(name): + print(f"Hello, {name}!") + + @flow + def example_flow(): + future = say_hello.submit(name="Marvin") + future.wait() + + example_flow() + """ + ) + self.logger.log( + level=level, + msg=msg, + ) + + +@dataclass +class SyncTaskRunEngine(BaseTaskRunEngine[P, R]): + _client: Optional[SyncPrefectClient] = None + + @property + def client(self) -> SyncPrefectClient: + if not self._is_started or self._client is None: + raise RuntimeError("Engine has not started.") + return self._client + + def can_retry(self, exc: Exception) -> bool: + retry_condition: Optional[ + Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool] + ] = self.task.retry_condition_fn + if not self.task_run: + raise ValueError("Task run is not set") + try: + self.logger.debug( + f"Running `retry_condition_fn` check {retry_condition!r} for task" + f" {self.task.name!r}" + ) + state = Failed( + data=exc, + message=f"Task run encountered unexpected exception: {repr(exc)}", + ) + if inspect.iscoroutinefunction(retry_condition): + should_retry = run_coro_as_sync( + retry_condition(self.task, self.task_run, state) + ) + elif inspect.isfunction(retry_condition): + should_retry = retry_condition(self.task, self.task_run, state) + else: + should_retry = not retry_condition + return should_retry + except Exception: + self.logger.error( + ( + "An error was encountered while running `retry_condition_fn` check" + f" '{retry_condition!r}' for task {self.task.name!r}" + ), + exc_info=True, + ) + return False + + def call_hooks(self, state: Optional[State] = None):
if state is None: + state = self.state + task = self.task + task_run = self.task_run + + if not task_run: + raise ValueError("Task run is not set") + + if state.is_failed() and task.on_failure_hooks: + hooks = task.on_failure_hooks + elif state.is_completed() and task.on_completion_hooks: + hooks = task.on_completion_hooks + else: + hooks = None + + for hook in hooks or []: + hook_name = _get_hook_name(hook) + + try: + self.logger.info( + f"Running hook {hook_name!r} in response to entering state" + f" {state.name!r}" + ) + result = hook(task, task_run, state) + if inspect.isawaitable(result): + run_coro_as_sync(result) + except Exception: + self.logger.error( + f"An error was encountered while running hook {hook_name!r}", + exc_info=True, + ) + else: + self.logger.info(f"Hook {hook_name!r} finished running successfully") + + def begin_run(self): + try: + self._resolve_parameters() + self._wait_for_dependencies() + except UpstreamTaskError as upstream_exc: + state = self.set_state( + Pending( + name="NotReady", + message=str(upstream_exc), + ), + # if orchestrating a run already in a pending state, force orchestration to + # update the state name + force=self.state.is_pending(), + ) + return + + new_state = Running() + + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + self.task_run.start_time = new_state.timestamp + self.task_run.run_count += 1 + + flow_run_context = FlowRunContext.get() + if flow_run_context: + # Carry forward any task run information from the flow run + flow_run = flow_run_context.flow_run + self.task_run.flow_run_run_count = flow_run.run_count + + state = self.set_state(new_state) + + # TODO: this is temporary until the API stops rejecting state transitions + # and the client / transaction store becomes the source of truth + # this is a bandaid caused by the API storing a Completed state with a bad + # result reference that no longer exists + if state.is_completed(): + try: + state.result(retry_result_failure=False, _sync=True) 
+ except Exception: + state = self.set_state(new_state, force=True) + + backoff_count = 0 + + # TODO: Could this listen for state change events instead of polling? + while state.is_pending() or state.is_paused(): + if backoff_count < BACKOFF_MAX: + backoff_count += 1 + interval = clamped_poisson_interval( + average_interval=backoff_count, clamping_factor=0.3 + ) + time.sleep(interval) + state = self.set_state(new_state) + + def set_state(self, state: State, force: bool = False) -> State: + last_state = self.state + if not self.task_run: + raise ValueError("Task run is not set") + + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + self.task_run.state = new_state = state + + # Ensure that the state_details are populated with the current run IDs + new_state.state_details.task_run_id = self.task_run.id + new_state.state_details.flow_run_id = self.task_run.flow_run_id + + # Predictively update the de-normalized task_run.state_* attributes + self.task_run.state_id = new_state.id + self.task_run.state_type = new_state.type + self.task_run.state_name = new_state.name + else: + try: + new_state = propose_state_sync( + self.client, state, task_run_id=self.task_run.id, force=force + ) + except Pause as exc: + # We shouldn't get a pause signal without a state, but if this happens, + # just use a Paused state to assume an in-process pause. + new_state = exc.state if exc.state else Paused() + if new_state.state_details.pause_reschedule: + # If we're being asked to pause and reschedule, we should exit the + # task and expect to be resumed later. 
+ raise + + # currently this is a hack to keep a reference to the state object + # that has an in-memory result attached to it; using the API state + # could result in losing that reference + self.task_run.state = new_state + + # emit a state change event + self._last_event = emit_task_run_state_change_event( + task_run=self.task_run, + initial_state=last_state, + validated_state=self.task_run.state, + follows=self._last_event, + ) + + return new_state + + def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": + if self._return_value is not NotSet: + # if the return value is a BaseResult, we need to fetch it + if isinstance(self._return_value, BaseResult): + _result = self._return_value.get() + if inspect.isawaitable(_result): + _result = run_coro_as_sync(_result) + return _result + + # otherwise, return the value as is + return self._return_value + + if self._raised is not NotSet: + # if the task raised an exception, raise it + if raise_on_failure: + raise self._raised + + # otherwise, return the exception + return self._raised + + def handle_success(self, result: R, transaction: Transaction) -> R: + result_factory = getattr(TaskRunContext.get(), "result_factory", None) + if result_factory is None: + raise ValueError("Result factory is not set") + + if self.task.cache_expiration is not None: + expiration = pendulum.now("utc") + self.task.cache_expiration + else: + expiration = None + + terminal_state = run_coro_as_sync( + return_value_to_state( + result, + result_factory=result_factory, + key=transaction.key, + expiration=expiration, + # defer persistence to transaction commit + defer_persistence=True, + ) + ) + transaction.stage( + terminal_state.data, + on_rollback_hooks=[ + _with_transaction_hook_logging(hook, "rollback", self.logger) + for hook in self.task.on_rollback_hooks + ], + on_commit_hooks=[ + _with_transaction_hook_logging(hook, "commit", self.logger) + for hook in self.task.on_commit_hooks + ], + ) + if 
transaction.is_committed(): + terminal_state.name = "Cached" + + self.record_terminal_state_timing(terminal_state) + self.set_state(terminal_state) + self._return_value = result + return result + + def handle_retry(self, exc: Exception) -> bool: + """Handle any task run retries. + + - If the task has retries left, and the retry condition is met, set the task to retrying and return True. + - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time. + - If the task has no retries left, or the retry condition is not met, return False. + """ + if self.retries < self.task.retries and self.can_retry(exc): + if self.task.retry_delay_seconds: + delay = ( + self.task.retry_delay_seconds[ + min(self.retries, len(self.task.retry_delay_seconds) - 1) + ] # repeat final delay value if attempts exceed specified delays + if isinstance(self.task.retry_delay_seconds, Sequence) + else self.task.retry_delay_seconds + ) + new_state = AwaitingRetry( + scheduled_time=pendulum.now("utc").add(seconds=delay) + ) + else: + delay = None + new_state = Retrying() + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + self.task_run.run_count += 1 + + self.logger.info( + "Task run failed with exception: %r - " "Retry %s/%s will start %s", + exc, + self.retries + 1, + self.task.retries, + str(delay) + " second(s) from now" if delay else "immediately", + ) + + self.set_state(new_state, force=True) + self.retries = self.retries + 1 + return True + elif self.retries >= self.task.retries: + self.logger.error( + "Task run failed with exception: %r - Retries are exhausted", + exc, + exc_info=True, + ) + return False + + return False + + def handle_exception(self, exc: Exception) -> None: + # If the task fails, and we have retries left, set the task to retrying. + if not self.handle_retry(exc): + # If the task has no retries left, or the retry condition is not met, set the task to failed. 
+ context = TaskRunContext.get() + state = run_coro_as_sync( + exception_to_failed_state( + exc, + message="Task run encountered an exception", + result_factory=getattr(context, "result_factory", None), + ) + ) + self.record_terminal_state_timing(state) + self.set_state(state) + self._raised = exc + + def handle_timeout(self, exc: TimeoutError) -> None: + if not self.handle_retry(exc): + if isinstance(exc, TaskRunTimeoutError): + message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)" + else: + message = f"Task run failed due to timeout: {exc!r}" + self.logger.error(message) + state = Failed( + data=exc, + message=message, + name="TimedOut", + ) + self.set_state(state) + self._raised = exc + + def handle_crash(self, exc: BaseException) -> None: + state = run_coro_as_sync(exception_to_crashed_state(exc)) + self.logger.error(f"Crash detected! {state.message}") + self.logger.debug("Crash details:", exc_info=exc) + self.record_terminal_state_timing(state) + self.set_state(state, force=True) + self._raised = exc + + @contextmanager + def setup_run_context(self, client: Optional[SyncPrefectClient] = None): + from prefect.utilities.engine import ( + _resolve_custom_task_run_name, + should_log_prints, + ) + + if client is None: + client = self.client + if not self.task_run: + raise ValueError("Task run is not set") + + if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + self.task_run = client.read_task_run(self.task_run.id) + with ExitStack() as stack: + if log_prints := should_log_prints(self.task): + stack.enter_context(patch_print()) + stack.enter_context( + TaskRunContext( + task=self.task, + log_prints=log_prints, + task_run=self.task_run, + parameters=self.parameters, + result_factory=run_coro_as_sync(ResultFactory.from_task(self.task)), # type: ignore + client=client, + ) + ) + + self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore + + if not 
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + # update the task run name if necessary + if not self._task_name_set and self.task.task_run_name: + task_run_name = _resolve_custom_task_run_name( + task=self.task, parameters=self.parameters + ) + self.client.set_task_run_name( + task_run_id=self.task_run.id, name=task_run_name + ) + self.logger.extra["task_run_name"] = task_run_name + self.logger.debug( + f"Renamed task run {self.task_run.name!r} to {task_run_name!r}" + ) + self.task_run.name = task_run_name + self._task_name_set = True + yield + + @contextmanager + def initialize_run( + self, + task_run_id: Optional[UUID] = None, + dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + ) -> Generator["SyncTaskRunEngine", Any, Any]: + """ + Enters a client context and creates a task run if needed. + """ + + with hydrated_context(self.context): + with SyncClientContext.get_or_create() as client_ctx: + self._client = client_ctx.client + self._is_started = True + try: + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + from prefect.utilities.engine import ( + _resolve_custom_task_run_name, + ) + + task_run_name = ( + _resolve_custom_task_run_name( + task=self.task, parameters=self.parameters + ) + if self.task.task_run_name + else None + ) + + if self.task_run and task_run_name: + self.task_run.name = task_run_name + + if not self.task_run: + self.task_run = run_coro_as_sync( + self.task.create_local_run( + id=task_run_id, + parameters=self.parameters, + flow_run_context=FlowRunContext.get(), + parent_task_run_context=TaskRunContext.get(), + wait_for=self.wait_for, + extra_task_inputs=dependencies, + task_run_name=task_run_name, + ) + ) + else: + if not self.task_run: + self.task_run = run_coro_as_sync( + self.task.create_run( + id=task_run_id, + parameters=self.parameters, + flow_run_context=FlowRunContext.get(), + parent_task_run_context=TaskRunContext.get(), + wait_for=self.wait_for, + extra_task_inputs=dependencies, + ) + ) + # 
Emit an event to capture that the task run was in the `PENDING` state. + self._last_event = emit_task_run_state_change_event( + task_run=self.task_run, + initial_state=None, + validated_state=self.task_run.state, + ) + + with self.setup_run_context(): + # setup_run_context might update the task run name, so log creation here + self.logger.info( + f"Created task run {self.task_run.name!r} for task {self.task.name!r}" + ) + yield self + + except TerminationSignal as exc: + # TerminationSignals are caught and handled as crashes + self.handle_crash(exc) + raise exc + + except Exception: + # regular exceptions are caught and re-raised to the user + raise + except (Pause, Abort) as exc: + # Do not capture internal signals as crashes + if isinstance(exc, Abort): + self.logger.error("Task run was aborted: %s", exc) + raise + except GeneratorExit: + # Do not capture generator exits as crashes + raise + except BaseException as exc: + # BaseExceptions are caught and handled as crashes + self.handle_crash(exc) + raise + finally: + self.log_finished_message() + self._is_started = False + self._client = None + + async def wait_until_ready(self): + """Waits until the scheduled time (if its the future), then enters Running.""" + if scheduled_time := self.state.state_details.scheduled_time: + sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds() + await anyio.sleep(sleep_time if sleep_time > 0 else 0) + self.set_state( + Retrying() if self.state.name == "AwaitingRetry" else Running(), + force=True, + ) + + # -------------------------- + # + # The following methods compose the main task run loop + # + # -------------------------- + + @contextmanager + def start( + self, + task_run_id: Optional[UUID] = None, + dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + ) -> Generator[None, None, None]: + with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies): + self.begin_run() + try: + yield + finally: + self.call_hooks() + + @contextmanager + 
def transaction_context(self) -> Generator[Transaction, None, None]: + result_factory = getattr(TaskRunContext.get(), "result_factory", None) + + # refresh cache setting is now repurposes as overwrite transaction record + overwrite = ( + self.task.refresh_cache + if self.task.refresh_cache is not None + else PREFECT_TASKS_REFRESH_CACHE.value() + ) + with transaction( + key=self.compute_transaction_key(), + store=ResultFactoryStore(result_factory=result_factory), + overwrite=overwrite, + logger=self.logger, + ) as txn: + yield txn + + @contextmanager + def run_context(self): + # reenter the run context to ensure it is up to date for every run + with self.setup_run_context(): + try: + with timeout( + seconds=self.task.timeout_seconds, + timeout_exc_type=TaskRunTimeoutError, + ): + self.logger.debug( + f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..." + ) + if self.is_cancelled(): + raise CancelledError("Task run cancelled by the task runner") + + yield self + except TimeoutError as exc: + self.handle_timeout(exc) + except Exception as exc: + self.handle_exception(exc) + + def call_task_fn( + self, transaction: Transaction + ) -> Union[R, Coroutine[Any, Any, R]]: + """ + Convenience method to call the task function. Returns a coroutine if the + task is async. 
+ """ + parameters = self.parameters or {} + if transaction.is_committed(): + result = transaction.read() + else: + result = call_with_parameters(self.task.fn, parameters) + self.handle_success(result, transaction=transaction) + return result - def __post_init__(self): - if self.parameters is None: - self.parameters = {} + +@dataclass +class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]): + _client: Optional[PrefectClient] = None @property - def client(self) -> SyncPrefectClient: + def client(self) -> PrefectClient: if not self._is_started or self._client is None: raise RuntimeError("Engine has not started.") return self._client - @property - def state(self) -> State: - if not self.task_run: - raise ValueError("Task run is not set") - return self.task_run.state - - def can_retry(self, exc: Exception) -> bool: + async def can_retry(self, exc: Exception) -> bool: retry_condition: Optional[ Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool] ] = self.task.retry_condition_fn @@ -142,14 +814,13 @@ def can_retry(self, exc: Exception) -> bool: message=f"Task run encountered unexpected exception: {repr(exc)}", ) if inspect.iscoroutinefunction(retry_condition): - should_retry = run_coro_as_sync( - retry_condition(self.task, self.task_run, state) - ) + should_retry = await retry_condition(self.task, self.task_run, state) elif inspect.isfunction(retry_condition): should_retry = retry_condition(self.task, self.task_run, state) else: should_retry = not retry_condition return should_retry + except Exception: self.logger.error( ( @@ -160,16 +831,7 @@ def can_retry(self, exc: Exception) -> bool: ) return False - def is_cancelled(self) -> bool: - if ( - self.context - and "cancel_event" in self.context - and isinstance(self.context["cancel_event"], threading.Event) - ): - return self.context["cancel_event"].is_set() - return False - - def call_hooks(self, state: Optional[State] = None): + async def call_hooks(self, state: Optional[State] = None): if state is None: state 
= self.state task = self.task @@ -195,7 +857,7 @@ def call_hooks(self, state: Optional[State] = None): ) result = hook(task, task_run, state) if inspect.isawaitable(result): - run_coro_as_sync(result) + await result except Exception: self.logger.error( f"An error was encountered while running hook {hook_name!r}", @@ -204,71 +866,12 @@ def call_hooks(self, state: Optional[State] = None): else: self.logger.info(f"Hook {hook_name!r} finished running successfully") - def compute_transaction_key(self) -> Optional[str]: - key = None - if self.task.cache_policy: - flow_run_context = FlowRunContext.get() - task_run_context = TaskRunContext.get() - - if flow_run_context: - parameters = flow_run_context.parameters - else: - parameters = None - - key = self.task.cache_policy.compute_key( - task_ctx=task_run_context, - inputs=self.parameters, - flow_parameters=parameters, - ) - elif self.task.result_storage_key is not None: - key = _format_user_supplied_storage_key(self.task.result_storage_key) - return key - - def _resolve_parameters(self): - if not self.parameters: - return {} - - resolved_parameters = {} - for parameter, value in self.parameters.items(): - try: - resolved_parameters[parameter] = visit_collection( - value, - visit_fn=resolve_to_final_result, - return_data=True, - max_depth=-1, - remove_annotations=True, - context={}, - ) - except UpstreamTaskError: - raise - except Exception as exc: - raise PrefectException( - f"Failed to resolve inputs in parameter {parameter!r}. If your" - " parameter type is not supported, consider using the `quote`" - " annotation to skip resolution of inputs." 
- ) from exc - - self.parameters = resolved_parameters - - def _wait_for_dependencies(self): - if not self.wait_for: - return - - visit_collection( - self.wait_for, - visit_fn=resolve_to_final_result, - return_data=False, - max_depth=-1, - remove_annotations=True, - context={"current_task_run": self.task_run, "current_task": self.task}, - ) - - def begin_run(self): + async def begin_run(self): try: self._resolve_parameters() self._wait_for_dependencies() except UpstreamTaskError as upstream_exc: - state = self.set_state( + state = await self.set_state( Pending( name="NotReady", message=str(upstream_exc), @@ -291,7 +894,7 @@ def begin_run(self): flow_run = flow_run_context.flow_run self.task_run.flow_run_run_count = flow_run.run_count - state = self.set_state(new_state) + state = await self.set_state(new_state) # TODO: this is temporary until the API stops rejecting state transitions # and the client / transaction store becomes the source of truth @@ -299,11 +902,10 @@ def begin_run(self): # result reference that no longer exists if state.is_completed(): try: - state.result(retry_result_failure=False, _sync=True) + await state.result(retry_result_failure=False) except Exception: - state = self.set_state(new_state, force=True) + state = await self.set_state(new_state, force=True) - BACKOFF_MAX = 10 backoff_count = 0 # TODO: Could this listen for state change events instead of polling? 
@@ -313,10 +915,10 @@ def begin_run(self): interval = clamped_poisson_interval( average_interval=backoff_count, clamping_factor=0.3 ) - time.sleep(interval) - state = self.set_state(new_state) + await anyio.sleep(interval) + state = await self.set_state(new_state) - def set_state(self, state: State, force: bool = False) -> State: + async def set_state(self, state: State, force: bool = False) -> State: last_state = self.state if not self.task_run: raise ValueError("Task run is not set") @@ -334,7 +936,7 @@ def set_state(self, state: State, force: bool = False) -> State: self.task_run.state_name = new_state.name else: try: - new_state = propose_state_sync( + new_state = await propose_state( self.client, state, task_run_id=self.task_run.id, force=force ) except Pause as exc: @@ -361,14 +963,11 @@ def set_state(self, state: State, force: bool = False) -> State: return new_state - def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": + async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": if self._return_value is not NotSet: # if the return value is a BaseResult, we need to fetch it if isinstance(self._return_value, BaseResult): - _result = self._return_value.get() - if inspect.isawaitable(_result): - _result = run_coro_as_sync(_result) - return _result + return await self._return_value.get() # otherwise, return the value as is return self._return_value @@ -381,7 +980,7 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": # otherwise, return the exception return self._raised - def handle_success(self, result: R, transaction: Transaction) -> R: + async def handle_success(self, result: R, transaction: Transaction) -> R: result_factory = getattr(TaskRunContext.get(), "result_factory", None) if result_factory is None: raise ValueError("Result factory is not set") @@ -391,15 +990,13 @@ def handle_success(self, result: R, transaction: Transaction) -> R: else: expiration = None - terminal_state = 
run_coro_as_sync( - return_value_to_state( - result, - result_factory=result_factory, - key=transaction.key, - expiration=expiration, - # defer persistence to transaction commit - defer_persistence=True, - ) + terminal_state = await return_value_to_state( + result, + result_factory=result_factory, + key=transaction.key, + expiration=expiration, + # defer persistence to transaction commit + defer_persistence=True, ) transaction.stage( terminal_state.data, @@ -416,18 +1013,18 @@ def handle_success(self, result: R, transaction: Transaction) -> R: terminal_state.name = "Cached" self.record_terminal_state_timing(terminal_state) - self.set_state(terminal_state) + await self.set_state(terminal_state) self._return_value = result return result - def handle_retry(self, exc: Exception) -> bool: + async def handle_retry(self, exc: Exception) -> bool: """Handle any task run retries. - If the task has retries left, and the retry condition is met, set the task to retrying and return True. - - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time. + - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time. - If the task has no retries left, or the retry condition is not met, return False. 
""" - if self.retries < self.task.retries and self.can_retry(exc): + if self.retries < self.task.retries and await self.can_retry(exc): if self.task.retry_delay_seconds: delay = ( self.task.retry_delay_seconds[ @@ -453,7 +1050,7 @@ def handle_retry(self, exc: Exception) -> bool: str(delay) + " second(s) from now" if delay else "immediately", ) - self.set_state(new_state, force=True) + await self.set_state(new_state, force=True) self.retries = self.retries + 1 return True elif self.retries >= self.task.retries: @@ -466,24 +1063,22 @@ def handle_retry(self, exc: Exception) -> bool: return False - def handle_exception(self, exc: Exception) -> None: + async def handle_exception(self, exc: Exception) -> None: # If the task fails, and we have retries left, set the task to retrying. - if not self.handle_retry(exc): + if not await self.handle_retry(exc): # If the task has no retries left, or the retry condition is not met, set the task to failed. context = TaskRunContext.get() - state = run_coro_as_sync( - exception_to_failed_state( - exc, - message="Task run encountered an exception", - result_factory=getattr(context, "result_factory", None), - ) + state = await exception_to_failed_state( + exc, + message="Task run encountered an exception", + result_factory=getattr(context, "result_factory", None), ) self.record_terminal_state_timing(state) - self.set_state(state) + await self.set_state(state) self._raised = exc - def handle_timeout(self, exc: TimeoutError) -> None: - if not self.handle_retry(exc): + async def handle_timeout(self, exc: TimeoutError) -> None: + if not await self.handle_retry(exc): if isinstance(exc, TaskRunTimeoutError): message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)" else: @@ -494,29 +1089,19 @@ def handle_timeout(self, exc: TimeoutError) -> None: message=message, name="TimedOut", ) - self.set_state(state) + await self.set_state(state) self._raised = exc - def handle_crash(self, exc: BaseException) -> None: - state = 
run_coro_as_sync(exception_to_crashed_state(exc)) + async def handle_crash(self, exc: BaseException) -> None: + state = await exception_to_crashed_state(exc) self.logger.error(f"Crash detected! {state.message}") self.logger.debug("Crash details:", exc_info=exc) self.record_terminal_state_timing(state) - self.set_state(state, force=True) + await self.set_state(state, force=True) self._raised = exc - def record_terminal_state_timing(self, state: State) -> None: - if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: - if self.task_run.start_time and not self.task_run.end_time: - self.task_run.end_time = state.timestamp - - if self.task_run.state.is_running(): - self.task_run.total_run_time += ( - state.timestamp - self.task_run.state.timestamp - ) - - @contextmanager - def setup_run_context(self, client: Optional[SyncPrefectClient] = None): + @asynccontextmanager + async def setup_run_context(self, client: Optional[PrefectClient] = None): from prefect.utilities.engine import ( _resolve_custom_task_run_name, should_log_prints, @@ -528,7 +1113,7 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): raise ValueError("Task run is not set") if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: - self.task_run = client.read_task_run(self.task_run.id) + self.task_run = await client.read_task_run(self.task_run.id) with ExitStack() as stack: if log_prints := should_log_prints(self.task): stack.enter_context(patch_print()) @@ -538,7 +1123,7 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): log_prints=log_prints, task_run=self.task_run, parameters=self.parameters, - result_factory=run_coro_as_sync(ResultFactory.from_task(self.task)), # type: ignore + result_factory=await ResultFactory.from_task(self.task), # type: ignore client=client, ) ) @@ -551,7 +1136,7 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): task_run_name = _resolve_custom_task_run_name( task=self.task, 
parameters=self.parameters ) - self.client.set_task_run_name( + await self.client.set_task_run_name( task_run_id=self.task_run.id, name=task_run_name ) self.logger.extra["task_run_name"] = task_run_name @@ -562,19 +1147,19 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): self._task_name_set = True yield - @contextmanager - def initialize_run( + @asynccontextmanager + async def initialize_run( self, task_run_id: Optional[UUID] = None, dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, - ) -> Generator["TaskRunEngine", Any, Any]: + ) -> AsyncGenerator["AsyncTaskRunEngine", Any]: """ Enters a client context and creates a task run if needed. """ with hydrated_context(self.context): - with ClientContext.get_or_create() as client_ctx: - self._client = client_ctx.sync_client + async with AsyncClientContext.get_or_create() as client_ctx: + self._client = client_ctx.client self._is_started = True try: if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: @@ -594,28 +1179,24 @@ def initialize_run( self.task_run.name = task_run_name if not self.task_run: - self.task_run = run_coro_as_sync( - self.task.create_local_run( - id=task_run_id, - parameters=self.parameters, - flow_run_context=FlowRunContext.get(), - parent_task_run_context=TaskRunContext.get(), - wait_for=self.wait_for, - extra_task_inputs=dependencies, - task_run_name=task_run_name, - ) + self.task_run = await self.task.create_local_run( + id=task_run_id, + parameters=self.parameters, + flow_run_context=FlowRunContext.get(), + parent_task_run_context=TaskRunContext.get(), + wait_for=self.wait_for, + extra_task_inputs=dependencies, + task_run_name=task_run_name, ) else: if not self.task_run: - self.task_run = run_coro_as_sync( - self.task.create_run( - id=task_run_id, - parameters=self.parameters, - flow_run_context=FlowRunContext.get(), - parent_task_run_context=TaskRunContext.get(), - wait_for=self.wait_for, - extra_task_inputs=dependencies, - ) + self.task_run = await 
self.task.create_run( + id=task_run_id, + parameters=self.parameters, + flow_run_context=FlowRunContext.get(), + parent_task_run_context=TaskRunContext.get(), + wait_for=self.wait_for, + extra_task_inputs=dependencies, ) # Emit an event to capture that the task run was in the `PENDING` state. self._last_event = emit_task_run_state_change_event( @@ -624,7 +1205,7 @@ def initialize_run( validated_state=self.task_run.state, ) - with self.setup_run_context(): + async with self.setup_run_context(): # setup_run_context might update the task run name, so log creation here self.logger.info( f"Created task run {self.task_run.name!r} for task {self.task.name!r}" @@ -633,7 +1214,7 @@ def initialize_run( except TerminationSignal as exc: # TerminationSignals are caught and handled as crashes - self.handle_crash(exc) + await self.handle_crash(exc) raise exc except Exception: @@ -649,60 +1230,19 @@ def initialize_run( raise except BaseException as exc: # BaseExceptions are caught and handled as crashes - self.handle_crash(exc) + await self.handle_crash(exc) raise finally: - # If debugging, use the more complete `repr` than the usual `str` description - display_state = ( - repr(self.state) if PREFECT_DEBUG_MODE else str(self.state) - ) - level = logging.INFO if self.state.is_completed() else logging.ERROR - msg = f"Finished in state {display_state}" - if self.state.is_pending(): - msg += ( - "\nPlease wait for all submitted tasks to complete" - " before exiting your flow by calling `.wait()` on the " - "`PrefectFuture` returned from your `.submit()` calls." - ) - msg += dedent( - """ - - Example: - - from prefect import flow, task - - @task - def say_hello(name): - print f"Hello, {name}!" 
- - @flow - def example_flow(): - future = say_hello.submit(name="Marvin) - future.wait() - - example_flow() - """ - ) - self.logger.log( - level=level, - msg=msg, - ) - + self.log_finished_message() self._is_started = False self._client = None - def is_running(self) -> bool: - """Whether or not the engine is currently running a task.""" - if (task_run := getattr(self, "task_run", None)) is None: - return False - return task_run.state.is_running() or task_run.state.is_scheduled() - async def wait_until_ready(self): """Waits until the scheduled time (if its the future), then enters Running.""" if scheduled_time := self.state.state_details.scheduled_time: sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds() await anyio.sleep(sleep_time if sleep_time > 0 else 0) - self.set_state( + await self.set_state( Retrying() if self.state.name == "AwaitingRetry" else Running(), force=True, ) @@ -713,21 +1253,23 @@ async def wait_until_ready(self): # # -------------------------- - @contextmanager - def start( + @asynccontextmanager + async def start( self, task_run_id: Optional[UUID] = None, dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, - ) -> Generator[None, None, None]: - with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies): - self.begin_run() + ) -> AsyncGenerator[None, None]: + async with self.initialize_run( + task_run_id=task_run_id, dependencies=dependencies + ): + await self.begin_run() try: yield finally: - self.call_hooks() + await self.call_hooks() - @contextmanager - def transaction_context(self) -> Generator[Transaction, None, None]: + @asynccontextmanager + async def transaction_context(self) -> AsyncGenerator[Transaction, None]: result_factory = getattr(TaskRunContext.get(), "result_factory", None) # refresh cache setting is now repurposes as overwrite transaction record @@ -744,13 +1286,12 @@ def transaction_context(self) -> Generator[Transaction, None, None]: ) as txn: yield txn - @contextmanager - def 
run_context(self): - timeout_context = timeout_async if self.task.isasync else timeout + @asynccontextmanager + async def run_context(self): # reenter the run context to ensure it is up to date for every run - with self.setup_run_context(): + async with self.setup_run_context(): try: - with timeout_context( + with timeout_async( seconds=self.task.timeout_seconds, timeout_exc_type=TaskRunTimeoutError, ): @@ -762,11 +1303,11 @@ def run_context(self): yield self except TimeoutError as exc: - self.handle_timeout(exc) + await self.handle_timeout(exc) except Exception as exc: - self.handle_exception(exc) + await self.handle_exception(exc) - def call_task_fn( + async def call_task_fn( self, transaction: Transaction ) -> Union[R, Coroutine[Any, Any, R]]: """ @@ -774,24 +1315,12 @@ def call_task_fn( task is async. """ parameters = self.parameters or {} - if self.task.isasync: - - async def _call_task_fn(): - if transaction.is_committed(): - result = transaction.read() - else: - result = await call_with_parameters(self.task.fn, parameters) - self.handle_success(result, transaction=transaction) - return result - - return _call_task_fn() + if transaction.is_committed(): + result = transaction.read() else: - if transaction.is_committed(): - result = transaction.read() - else: - result = call_with_parameters(self.task.fn, parameters) - self.handle_success(result, transaction=transaction) - return result + result = await call_with_parameters(self.task.fn, parameters) + await self.handle_success(result, transaction=transaction) + return result def run_task_sync( @@ -804,7 +1333,7 @@ def run_task_sync( dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, context: Optional[Dict[str, Any]] = None, ) -> Union[R, State, None]: - engine = TaskRunEngine[P, R]( + engine = SyncTaskRunEngine[P, R]( task=task, parameters=parameters, task_run=task_run, @@ -831,7 +1360,7 @@ async def run_task_async( dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, context: 
Optional[Dict[str, Any]] = None, ) -> Union[R, State, None]: - engine = TaskRunEngine[P, R]( + engine = AsyncTaskRunEngine[P, R]( task=task, parameters=parameters, task_run=task_run, @@ -839,13 +1368,13 @@ async def run_task_async( context=context, ) - with engine.start(task_run_id=task_run_id, dependencies=dependencies): + async with engine.start(task_run_id=task_run_id, dependencies=dependencies): while engine.is_running(): await engine.wait_until_ready() - with engine.run_context(), engine.transaction_context() as txn: + async with engine.run_context(), engine.transaction_context() as txn: await engine.call_task_fn(txn) - return engine.state if return_type == "state" else engine.result() + return engine.state if return_type == "state" else await engine.result() def run_generator_task_sync( @@ -861,7 +1390,7 @@ def run_generator_task_sync( if return_type != "result": raise ValueError("The return_type for a generator task must be 'result'") - engine = TaskRunEngine[P, R]( + engine = SyncTaskRunEngine[P, R]( task=task, parameters=parameters, task_run=task_run, @@ -915,7 +1444,7 @@ async def run_generator_task_async( ) -> AsyncGenerator[R, None]: if return_type != "result": raise ValueError("The return_type for a generator task must be 'result'") - engine = TaskRunEngine[P, R]( + engine = AsyncTaskRunEngine[P, R]( task=task, parameters=parameters, task_run=task_run, @@ -923,10 +1452,10 @@ async def run_generator_task_async( context=context, ) - with engine.start(task_run_id=task_run_id, dependencies=dependencies): + async with engine.start(task_run_id=task_run_id, dependencies=dependencies): while engine.is_running(): await engine.wait_until_ready() - with engine.run_context(), engine.transaction_context() as txn: + async with engine.run_context(), engine.transaction_context() as txn: # TODO: generators should default to commit_mode=OFF # because they are dynamic by definition # for now we just prevent this branch explicitly @@ -950,13 +1479,13 @@ async def 
run_generator_task_async( link_state_to_result(engine.state, gen_result) yield gen_result except (StopAsyncIteration, GeneratorExit) as exc: - engine.handle_success(None, transaction=txn) + await engine.handle_success(None, transaction=txn) if isinstance(exc, GeneratorExit): gen.throw(exc) # async generators can't return, but we can raise failures here if engine.state.is_failed(): - engine.result() + await engine.result() def run_task( diff --git a/tests/test_flow_engine.py b/tests/test_flow_engine.py index 718e3bcde1cf..808fa5aeb208 100644 --- a/tests/test_flow_engine.py +++ b/tests/test_flow_engine.py @@ -4,6 +4,7 @@ import warnings from textwrap import dedent from typing import Optional +from unittest import mock from unittest.mock import MagicMock from uuid import UUID @@ -36,6 +37,7 @@ from prefect.input.run_input import RunInput from prefect.logging import get_run_logger from prefect.server.schemas.core import FlowRun as ServerFlowRun +from prefect.testing.utilities import AsyncMock from prefect.utilities.callables import get_call_parameters from prefect.utilities.filesystem import tmpchdir @@ -1124,9 +1126,8 @@ async def flow_resumer(): ) assert schema is not None - async def test_paused_task_polling(self, monkeypatch, prefect_client): - sleeper = MagicMock(side_effect=[None, None, None, None, None]) - monkeypatch.setattr("prefect.task_engine.time.sleep", sleeper) + async def test_paused_task_polling(self, prefect_client): + sleeper = AsyncMock(side_effect=[None, None, None, None, None]) @task async def doesnt_pause(): @@ -1151,9 +1152,10 @@ async def pausing_flow(): # execution isn't blocked, so this task should enter the engine, but not begin # execution - with pytest.raises(RuntimeError): - # the sleeper mock will exhaust its side effects after 6 calls - await doesnt_run() + with mock.patch("prefect.task_engine.anyio.sleep", sleeper): + with pytest.raises(RuntimeError): + # the sleeper mock will exhaust its side effects after 6 calls + await doesnt_run() 
await pausing_flow() diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index d2a376fe0504..82061790bd97 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -33,7 +33,12 @@ temporary_settings, ) from prefect.states import Running, State -from prefect.task_engine import TaskRunEngine, run_task_async, run_task_sync +from prefect.task_engine import ( + AsyncTaskRunEngine, + SyncTaskRunEngine, + run_task_async, + run_task_sync, +) from prefect.task_runners import ThreadPoolTaskRunner from prefect.testing.utilities import exceptions_equal from prefect.utilities.callables import get_call_parameters @@ -56,20 +61,20 @@ async def foo(): return 42 -class TestTaskRunEngine: +class TestSyncTaskRunEngine: async def test_basic_init(self): - engine = TaskRunEngine(task=foo) + engine = SyncTaskRunEngine(task=foo) assert isinstance(engine.task, Task) assert engine.task.name == "foo" assert engine.parameters == {} async def test_client_attribute_raises_informative_error(self): - engine = TaskRunEngine(task=foo) + engine = SyncTaskRunEngine(task=foo) with pytest.raises(RuntimeError, match="not started"): engine.client async def test_client_attr_returns_client_after_starting(self): - engine = TaskRunEngine(task=foo) + engine = SyncTaskRunEngine(task=foo) with engine.initialize_run(): client = engine.client assert isinstance(client, SyncPrefectClient) @@ -78,6 +83,28 @@ async def test_client_attr_returns_client_after_starting(self): engine.client +class TestAsyncTaskRunEngine: + async def test_basic_init(self): + engine = AsyncTaskRunEngine(task=foo) + assert isinstance(engine.task, Task) + assert engine.task.name == "foo" + assert engine.parameters == {} + + async def test_client_attribute_raises_informative_error(self): + engine = AsyncTaskRunEngine(task=foo) + with pytest.raises(RuntimeError, match="not started"): + engine.client + + async def test_client_attr_returns_client_after_starting(self): + engine = AsyncTaskRunEngine(task=foo) + async 
with engine.initialize_run(): + client = engine.client + assert isinstance(client, PrefectClient) + + with pytest.raises(RuntimeError, match="not started"): + engine.client + + class TestRunTask: def test_run_task_with_client_provided_uuid( self, sync_prefect_client, events_pipeline @@ -1163,11 +1190,36 @@ def my_task(): await task_run.state.result() @pytest.mark.parametrize("interrupt_type", [KeyboardInterrupt, SystemExit]) - async def test_interrupt_in_task_orchestration_crashes_task_and_flow( + async def test_interrupt_in_task_orchestration_crashes_task_and_flow_sync( + self, prefect_client, events_pipeline, interrupt_type, monkeypatch + ): + monkeypatch.setattr( + SyncTaskRunEngine, "begin_run", MagicMock(side_effect=interrupt_type) + ) + + @task + def my_task(): + pass + + with pytest.raises(interrupt_type): + my_task() + + await events_pipeline.process_events() + task_runs = await prefect_client.read_task_runs() + assert len(task_runs) == 1 + task_run = task_runs[0] + assert task_run.state.is_crashed() + assert task_run.state.type == StateType.CRASHED + assert "Execution was aborted" in task_run.state.message + with pytest.raises(CrashedRun, match="Execution was aborted"): + await task_run.state.result() + + @pytest.mark.parametrize("interrupt_type", [KeyboardInterrupt, SystemExit]) + async def test_interrupt_in_task_orchestration_crashes_task_and_flow_async( self, prefect_client, events_pipeline, interrupt_type, monkeypatch ): monkeypatch.setattr( - TaskRunEngine, "begin_run", MagicMock(side_effect=interrupt_type) + AsyncTaskRunEngine, "begin_run", MagicMock(side_effect=interrupt_type) ) @task @@ -1361,7 +1413,7 @@ async def test_sync_task_does_not_set_end_time_on_crash_pre_runnning( self, monkeypatch, prefect_client, events_pipeline ): monkeypatch.setattr( - TaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit) + SyncTaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit) ) @task @@ -1382,7 +1434,7 @@ async def 
test_async_task_does_not_set_end_time_on_crash_pre_running( self, monkeypatch, prefect_client, events_pipeline ): monkeypatch.setattr( - TaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit) + AsyncTaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit) ) @task diff --git a/tests/test_task_worker.py b/tests/test_task_worker.py index cd16442ae126..0349832ba17f 100644 --- a/tests/test_task_worker.py +++ b/tests/test_task_worker.py @@ -205,7 +205,7 @@ def empty_task(): def always_error(*args, **kwargs): raise ValueError("oops") - monkeypatch.setattr("prefect.task_engine.TaskRunEngine.start", always_error) + monkeypatch.setattr("prefect.task_engine.SyncTaskRunEngine.start", always_error) task_worker = TaskWorker(empty_task) @@ -842,53 +842,58 @@ async def mock_iter(): async def test_tasks_execute_when_capacity_frees_up( self, mock_subscription, prefect_client, events_pipeline ): - event = asyncio.Event() + execution_order = [] @task - async def slow_task(): - await asyncio.sleep(1) - if event.is_set(): - raise ValueError("Something went wrong! 
This event should not be set.") - event.set() + async def slow_task(task_id: str): + execution_order.append(f"{task_id} start") + await asyncio.sleep(0.1) # Simulating some work + execution_order.append(f"{task_id} end") task_worker = TaskWorker(slow_task, limit=1) - task_run_future_1 = slow_task.apply_async() + task_run_future_1 = slow_task.apply_async(("task1",)) task_run_1 = await prefect_client.read_task_run(task_run_future_1.task_run_id) - task_run_future_2 = slow_task.apply_async() + task_run_future_2 = slow_task.apply_async(("task2",)) task_run_2 = await prefect_client.read_task_run(task_run_future_2.task_run_id) async def mock_iter(): yield task_run_1 yield task_run_2 - # sleep for a second to ensure that task execution starts - await asyncio.sleep(1) + while len(execution_order) < 4: + await asyncio.sleep(0.1) mock_subscription.return_value = mock_iter() server_task = asyncio.create_task(task_worker.start()) - await event.wait() - - await events_pipeline.process_events() - - updated_task_run_1 = await prefect_client.read_task_run(task_run_1.id) - updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id) - assert updated_task_run_1.state.is_completed() - assert not updated_task_run_2.state.is_completed() - - # clear the event to allow the second task to complete - event.clear() - - await event.wait() - await events_pipeline.process_events() - - updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id) - - assert updated_task_run_2.state.is_completed() - - server_task.cancel() - await server_task + try: + # Wait for both tasks to complete + await asyncio.sleep(2) + + await events_pipeline.process_events() + + # Verify the execution order + assert execution_order == [ + "task1 start", + "task1 end", + "task2 start", + "task2 end", + ], "Tasks should execute sequentially" + + # Verify the states of both tasks + updated_task_run_1 = await prefect_client.read_task_run(task_run_1.id) + updated_task_run_2 = await 
prefect_client.read_task_run(task_run_2.id) + + assert updated_task_run_1.state.is_completed() + assert updated_task_run_2.state.is_completed() + + finally: + server_task.cancel() + try: + await server_task + except asyncio.CancelledError: + pass async def test_execute_task_run_respects_limit( self, prefect_client, events_pipeline