From 698e585a50f9f81e3416979a55abb624cd474ce7 Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Wed, 18 Dec 2024 18:33:53 +0000 Subject: [PATCH] [typing] prefect.concurrency --- client/client_flow.py | 2 +- src/prefect/_internal/concurrency/services.py | 215 ++++++++++++------ src/prefect/concurrency/.ruff.toml | 3 + .../concurrency/{asyncio.py => _asyncio.py} | 87 ++++--- .../concurrency/{events.py => _events.py} | 20 +- src/prefect/concurrency/context.py | 10 +- src/prefect/concurrency/services.py | 100 +++----- src/prefect/concurrency/sync.py | 67 ++---- .../v1/{asyncio.py => _asyncio.py} | 81 ++++--- .../concurrency/v1/{events.py => _events.py} | 28 ++- src/prefect/concurrency/v1/context.py | 10 +- src/prefect/concurrency/v1/services.py | 112 +++------ src/prefect/concurrency/v1/sync.py | 55 ++--- src/prefect/context.py | 136 ++++++----- src/prefect/runner/runner.py | 2 +- src/prefect/task_engine.py | 2 +- .../test_acquire_concurrency_slots.py | 8 +- tests/concurrency/test_concurrency_asyncio.py | 34 +-- ...st_concurrency_slot_acquisition_service.py | 9 +- tests/concurrency/test_concurrency_sync.py | 2 +- tests/concurrency/test_context.py | 2 +- .../test_release_concurrency_slots.py | 8 +- .../v1/test_concurrency_asyncio.py | 18 +- ...t_concurrency_limit_acquisition_service.py | 8 +- tests/concurrency/v1/test_concurrency_sync.py | 18 +- tests/concurrency/v1/test_context.py | 2 +- .../v1/test_decrement_concurrency_slots.py | 6 +- .../v1/test_increment_concurrency_limits.py | 8 +- tests/test_flow_engine.py | 2 +- tests/test_task_engine.py | 65 +++--- 30 files changed, 528 insertions(+), 592 deletions(-) create mode 100644 src/prefect/concurrency/.ruff.toml rename src/prefect/concurrency/{asyncio.py => _asyncio.py} (76%) rename src/prefect/concurrency/{events.py => _events.py} (76%) rename src/prefect/concurrency/v1/{asyncio.py => _asyncio.py} (65%) rename src/prefect/concurrency/v1/{events.py => _events.py} (65%) diff --git a/client/client_flow.py b/client/client_flow.py index 429a4f9507662..39c73a80e805e 100644 --- a/client/client_flow.py +++ b/client/client_flow.py @@ -1,5 +1,5 @@ from prefect import flow, task -from prefect.concurrency import asyncio, events, services, sync # noqa: F401 +from prefect.concurrency import events, services, sync # noqa: F401 def skip_remote_run(): diff --git a/src/prefect/_internal/concurrency/services.py b/src/prefect/_internal/concurrency/services.py index 6e33aca96be47..15498a29c8d75 100644 --- a/src/prefect/_internal/concurrency/services.py +++ b/src/prefect/_internal/concurrency/services.py @@ -7,9 +7,10 @@ import queue import sys import threading -from typing import Awaitable, Dict, Generic, List, Optional, Type, TypeVar, Union +from collections.abc import AsyncGenerator, Awaitable, Coroutine, Generator, Hashable +from typing import TYPE_CHECKING, Any, Generic, NoReturn, Optional, Union, cast -from typing_extensions import Self +from typing_extensions import Self, TypeVar, TypeVarTuple, Unpack from prefect._internal.concurrency import logger from prefect._internal.concurrency.api import create_call, from_sync @@ -18,17 +19,19 @@ from prefect._internal.concurrency.threads import WorkerThread, get_global_loop T = TypeVar("T") +Ts = TypeVarTuple("Ts") +R = TypeVar("R", infer_variance=True) -class QueueService(abc.ABC, Generic[T]): - _instances: Dict[int, Self] = {} +class _QueueServiceBase(abc.ABC, Generic[T]): + _instances: dict[int, Self] = {} _instance_lock = threading.Lock() - def __init__(self, *args) -> None: - self._queue: queue.Queue = queue.Queue() + def __init__(self, *args: Hashable) -> None: + self._queue: queue.Queue[Optional[T]] = queue.Queue() self._loop: Optional[asyncio.AbstractEventLoop] = None self._done_event: Optional[asyncio.Event] = None - self._task: Optional[asyncio.Task] = None + self._task: Optional[asyncio.Task[None]] = None self._stopped: bool = False self._started: bool = False self._key = hash(args) @@ -41,14 +44,14 @@ def __init__(self, *args) -> None: ) self._logger = logging.getLogger(f"{type(self).__name__}") - def start(self): + def start(self) -> None: logger.debug("Starting service %r", self) loop_thread = get_global_loop() - if not asyncio.get_running_loop() == loop_thread._loop: + if not asyncio.get_running_loop() == getattr(loop_thread, "_loop"): raise RuntimeError("Services must run on the global loop thread.") - self._loop = loop_thread._loop + self._loop = asyncio.get_running_loop() self._done_event = asyncio.Event() self._task = self._loop.create_task(self._run()) self._queue_get_thread.start() @@ -67,14 +70,16 @@ def start(self): # _before_ the normal `atexit` hook is called resulting in failure to # process items. This is particularly relevant for services which use an # httpx client. - from threading import _register_atexit + from threading import ( + _register_atexit, # pyright: ignore[reportUnknownVariableType, reportAttributeAccessIssue] + ) _register_atexit(self._at_exit) - def _at_exit(self): + def _at_exit(self) -> None: self.drain(at_exit=True) - def _stop(self, at_exit: bool = False): + def _stop(self, at_exit: bool = False) -> None: """ Stop running this instance. @@ -100,27 +105,11 @@ def _stop(self, at_exit: bool = False): # Signal completion to the loop self._queue.put_nowait(None) - def send(self, item: T): - """ - Send an item to this instance of the service. - """ - with self._lock: - if self._stopped: - raise RuntimeError("Cannot put items in a stopped service instance.") - - logger.debug("Service %r enqueuing item %r", self, item) - self._queue.put_nowait(self._prepare_item(item)) - - def _prepare_item(self, item: T) -> T: - """ - Prepare an item for submission to the service. This is called before - the item is sent to the service. - - The default implementation returns the item unchanged. - """ - return item + @abc.abstractmethod + def send(self, item: Any) -> Any: + raise NotImplementedError - async def _run(self): + async def _run(self) -> None: try: async with self._lifespan(): await self._main_loop() @@ -142,14 +131,15 @@ async def _run(self): self._queue_get_thread.shutdown() self._stopped = True + assert self._done_event is not None self._done_event.set() - async def _main_loop(self): + async def _main_loop(self) -> None: last_log_time = 0 log_interval = 4 # log every 4 seconds while True: - item: T = await self._queue_get_thread.submit( + item: Optional[T] = await self._queue_get_thread.submit( create_call(self._queue.get) ).aresult() @@ -183,19 +173,17 @@ async def _main_loop(self): self._queue.task_done() @abc.abstractmethod - async def _handle(self, item: T): - """ - Process an item sent to the service. - """ + def _handle(self, item: Any) -> Any: + raise NotImplementedError @contextlib.asynccontextmanager - async def _lifespan(self): + async def _lifespan(self) -> AsyncGenerator[None, Any]: """ Perform any setup and teardown for the service. """ yield - def _drain(self, at_exit: bool = False) -> concurrent.futures.Future: + def _drain(self, at_exit: bool = False) -> concurrent.futures.Future[bool]: """ Internal implementation for `drain`. Returns a future for sync/async interfaces. """ @@ -204,15 +192,17 @@ def _drain(self, at_exit: bool = False) -> concurrent.futures.Future: self._stop(at_exit=at_exit) + assert self._done_event is not None if self._done_event.is_set(): - future = concurrent.futures.Future() - future.set_result(None) + future: concurrent.futures.Future[bool] = concurrent.futures.Future() + future.set_result(False) return future - future = asyncio.run_coroutine_threadsafe(self._done_event.wait(), self._loop) - return future + assert self._loop is not None + task = cast(Coroutine[Any, Any, bool], self._done_event.wait()) + return asyncio.run_coroutine_threadsafe(task, self._loop) - def drain(self, at_exit: bool = False) -> None: + def drain(self, at_exit: bool = False) -> Union[bool, Awaitable[bool]]: """ Stop this instance of the service and wait for remaining work to be completed. @@ -225,14 +215,25 @@ def drain(self, at_exit: bool = False) -> None: return future.result() @classmethod - def drain_all(cls, timeout: Optional[float] = None) -> Union[Awaitable, None]: + def drain_all( + cls, timeout: Optional[float] = None + ) -> Union[ + tuple[ + set[concurrent.futures.Future[bool]], set[concurrent.futures.Future[bool]] + ], + Coroutine[ + Any, + Any, + Optional[tuple[set[asyncio.Future[bool]], set[asyncio.Future[bool]]]], + ], + ]: """ Stop all instances of the service and wait for all remaining work to be completed. Returns an awaitable if called from an async context. """ - futures = [] + futures: list[concurrent.futures.Future[bool]] = [] with cls._instance_lock: instances = tuple(cls._instances.values()) @@ -240,26 +241,24 @@ def drain_all(cls, timeout: Optional[float] = None) -> Union[Awaitable, None]: futures.append(instance._drain()) if get_running_loop() is not None: - return ( - asyncio.wait( + if futures: + return asyncio.wait( [asyncio.wrap_future(fut) for fut in futures], timeout=timeout ) - if futures - # `wait` errors if it receives an empty list but we need to return a - # coroutine still - else asyncio.sleep(0) - ) + # `wait` errors if it receives an empty list but we need to return a + # coroutine still + return asyncio.sleep(0) else: return concurrent.futures.wait(futures, timeout=timeout) - def wait_until_empty(self): + def wait_until_empty(self) -> None: """ Wait until the queue is empty and all items have been processed. """ self._queue.join() @classmethod - def instance(cls: Type[Self], *args) -> Self: + def instance(cls, *args: Hashable) -> Self: """ Get an instance of the service. @@ -276,7 +275,7 @@ def _remove_instance(self): self._instances.pop(self._key, None) @classmethod - def _new_instance(cls, *args): + def _new_instance(cls, *args: Hashable) -> Self: """ Create and start a new instance of the service. """ @@ -293,6 +292,87 @@ def _new_instance(cls, *args): return instance +class QueueService(_QueueServiceBase[T]): + def send(self, item: T) -> None: + """ + Send an item to this instance of the service. + """ + with self._lock: + if self._stopped: + raise RuntimeError("Cannot put items in a stopped service instance.") + + logger.debug("Service %r enqueuing item %r", self, item) + self._queue.put_nowait(self._prepare_item(item)) + + def _prepare_item(self, item: T) -> T: + """ + Prepare an item for submission to the service. This is called before + the item is sent to the service. + + The default implementation returns the item unchanged. + """ + return item + + @abc.abstractmethod + async def _handle(self, item: T) -> None: + """ + Process an item sent to the service. + """ + + +class FutureQueueService( + _QueueServiceBase[tuple[Unpack[Ts], concurrent.futures.Future[R]]] +): + """Queued service that provides a future that is signalled with the acquired result for each item + + If there was a failure acquiring, the future result is set to the exception. + + Type Parameters: + Ts: the tuple of types that make up sent arguments + R: the type returned for each item once acquired + + """ + + async def _handle( + self, item: tuple[Unpack[Ts], concurrent.futures.Future[R]] + ) -> None: + send_item, future = item[:-1], item[-1] + try: + response = await self.acquire(*send_item) + except Exception as exc: + # If the request to the increment endpoint fails in a non-standard + # way, we need to set the future's result so it'll be re-raised in + # the context of the caller. + future.set_exception(exc) + raise exc + else: + future.set_result(response) + + @abc.abstractmethod + async def acquire(self, *args: Unpack[Ts]) -> R: + raise NotImplementedError + + def send(self, item: tuple[Unpack[Ts]]) -> concurrent.futures.Future[R]: + with self._lock: + if self._stopped: + raise RuntimeError("Cannot put items in a stopped service instance.") + + logger.debug("Service %r enqueuing item %r", self, item) + future: concurrent.futures.Future[R] = concurrent.futures.Future() + self._queue.put_nowait((*self._prepare_item(item), future)) + + return future + + def _prepare_item(self, item: tuple[Unpack[Ts]]) -> tuple[Unpack[Ts]]: + """ + Prepare an item for submission to the service. This is called before + the item is sent to the service. + + The default implementation returns the item unchanged. + """ + return item + + class BatchedQueueService(QueueService[T]): """ A queue service that handles a batch of items instead of a single item at a time. @@ -308,7 +388,7 @@ async def _main_loop(self): done = False while not done: - batch = [] + batch: list[T] = [] batch_size = 0 # Pull items from the queue until we reach the batch size @@ -357,13 +437,15 @@ async def _main_loop(self): ) @abc.abstractmethod - async def _handle_batch(self, items: List[T]): + async def _handle_batch(self, items: list[T]) -> None: """ Process a batch of items sent to the service. """ - async def _handle(self, item: T): - assert False, "`_handle` should never be called for batched queue services" + async def _handle(self, item: T) -> NoReturn: + raise AssertionError( + "`_handle` should never be called for batched queue services" + ) def _get_size(self, item: T) -> int: """ @@ -374,12 +456,15 @@ def _get_size(self, item: T) -> int: @contextlib.contextmanager -def drain_on_exit(service: QueueService): +def drain_on_exit(service: QueueService[Any]) -> Generator[None, Any, None]: yield service.drain_all() @contextlib.asynccontextmanager -async def drain_on_exit_async(service: QueueService): +async def drain_on_exit_async(service: QueueService[Any]) -> AsyncGenerator[None, Any]: yield - await service.drain_all() + drain_all = service.drain_all() + if TYPE_CHECKING: + assert not isinstance(drain_all, tuple) + await drain_all diff --git a/src/prefect/concurrency/.ruff.toml b/src/prefect/concurrency/.ruff.toml new file mode 100644 index 0000000000000..f8a73dcd5f2fd --- /dev/null +++ b/src/prefect/concurrency/.ruff.toml @@ -0,0 +1,3 @@ +extend-select = ["UP"] +target-version = "py39" + diff --git a/src/prefect/concurrency/asyncio.py b/src/prefect/concurrency/_asyncio.py similarity index 76% rename from src/prefect/concurrency/asyncio.py rename to src/prefect/concurrency/_asyncio.py index 5d419a6c079f5..b5e4d6aa5e771 100644 --- a/src/prefect/concurrency/asyncio.py +++ b/src/prefect/concurrency/_asyncio.py @@ -1,28 +1,22 @@ import asyncio +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import AsyncGenerator, List, Literal, Optional, Union, cast +from typing import Literal, Optional, Union import anyio import httpx import pendulum from prefect._internal.compatibility.deprecated import deprecated_parameter - -try: - from pendulum import Interval -except ImportError: - # pendulum < 3 - from pendulum.period import Period as Interval # type: ignore - from prefect.client.orchestration import get_client from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse from prefect.logging.loggers import get_run_logger -from .context import ConcurrencyContext -from .events import ( - _emit_concurrency_acquisition_events, - _emit_concurrency_release_events, +from ._events import ( + emit_concurrency_acquisition_events, + emit_concurrency_release_events, ) +from .context import ConcurrencyContext from .services import ConcurrencySlotAcquisitionService @@ -36,7 +30,7 @@ class AcquireConcurrencySlotTimeoutError(TimeoutError): @asynccontextmanager async def concurrency( - names: Union[str, List[str]], + names: Union[str, list[str]], occupy: int = 1, timeout_seconds: Optional[float] = None, max_retries: Optional[int] = None, @@ -62,7 +56,7 @@ async def concurrency( Example: A simple example of using the async `concurrency` context manager: ```python - from prefect.concurrency.asyncio import concurrency + from prefect.concurrency._asyncio import concurrency async def resource_heavy(): async with concurrency("test", occupy=1): @@ -78,7 +72,7 @@ async def main(): names = names if isinstance(names, list) else [names] - limits = await _aacquire_concurrency_slots( + limits = await aacquire_concurrency_slots( names, occupy, timeout_seconds=timeout_seconds, @@ -87,14 +81,14 @@ async def main(): strict=strict, ) acquisition_time = pendulum.now("UTC") - emitted_events = _emit_concurrency_acquisition_events(limits, occupy) + emitted_events = emit_concurrency_acquisition_events(limits, occupy) try: yield finally: - occupancy_period = cast(Interval, (pendulum.now("UTC") - acquisition_time)) + occupancy_period = pendulum.now("UTC") - acquisition_time try: - await _arelease_concurrency_slots( + await arelease_concurrency_slots( names, occupy, occupancy_period.total_seconds() ) except anyio.get_cancelled_exc_class(): @@ -106,11 +100,11 @@ async def main(): (names, occupy, occupancy_period.total_seconds()) ) - _emit_concurrency_release_events(limits, occupy, emitted_events) + emit_concurrency_release_events(limits, occupy, emitted_events) async def rate_limit( - names: Union[str, List[str]], + names: Union[str, list[str]], occupy: int = 1, timeout_seconds: Optional[float] = None, create_if_missing: Optional[bool] = None, @@ -137,7 +131,7 @@ async def rate_limit( names = names if isinstance(names, list) else [names] - limits = await _aacquire_concurrency_slots( + limits = await aacquire_concurrency_slots( names, occupy, mode="rate_limit", @@ -145,7 +139,7 @@ async def rate_limit( create_if_missing=create_if_missing, strict=strict, ) - _emit_concurrency_acquisition_events(limits, occupy) + emit_concurrency_acquisition_events(limits, occupy) @deprecated_parameter( @@ -155,51 +149,52 @@ async def rate_limit( when=lambda x: x is not None, help="Limits must be explicitly created before acquiring concurrency slots; see `strict` if you want to enforce this behavior.", ) -async def _aacquire_concurrency_slots( - names: List[str], +async def aacquire_concurrency_slots( + names: list[str], slots: int, mode: Literal["concurrency", "rate_limit"] = "concurrency", timeout_seconds: Optional[float] = None, create_if_missing: Optional[bool] = None, max_retries: Optional[int] = None, strict: bool = False, -) -> List[MinimalConcurrencyLimitResponse]: +) -> list[MinimalConcurrencyLimitResponse]: service = ConcurrencySlotAcquisitionService.instance(frozenset(names)) future = service.send( (slots, mode, timeout_seconds, create_if_missing, max_retries) ) - response_or_exception = await asyncio.wrap_future(future) - - if isinstance(response_or_exception, Exception): - if isinstance(response_or_exception, TimeoutError): - raise AcquireConcurrencySlotTimeoutError( - f"Attempt to acquire concurrency slots timed out after {timeout_seconds} second(s)" - ) from response_or_exception - + try: + response = await asyncio.wrap_future(future) + except TimeoutError as timeout: + raise AcquireConcurrencySlotTimeoutError( + f"Attempt to acquire concurrency slots timed out after {timeout_seconds} second(s)" + ) from timeout + except Exception as exc: raise ConcurrencySlotAcquisitionError( f"Unable to acquire concurrency slots on {names!r}" - ) from response_or_exception + ) from exc - retval = _response_to_minimal_concurrency_limit_response(response_or_exception) + retval = _response_to_minimal_concurrency_limit_response(response) - if strict and not retval: - raise ConcurrencySlotAcquisitionError( - f"Concurrency limits {names!r} must be created before acquiring slots" - ) - elif not retval: + if not retval: + if strict: + raise ConcurrencySlotAcquisitionError( + f"Concurrency limits {names!r} must be created before acquiring slots" + ) try: logger = get_run_logger() + except Exception: + pass + else: logger.warning( f"Concurrency limits {names!r} do not exist - skipping acquisition." ) - except Exception: - pass + return retval -async def _arelease_concurrency_slots( - names: List[str], slots: int, occupancy_seconds: float -) -> List[MinimalConcurrencyLimitResponse]: +async def arelease_concurrency_slots( + names: list[str], slots: int, occupancy_seconds: float +) -> list[MinimalConcurrencyLimitResponse]: async with get_client() as client: response = await client.release_concurrency_slots( names=names, slots=slots, occupancy_seconds=occupancy_seconds @@ -209,7 +204,7 @@ async def _arelease_concurrency_slots( def _response_to_minimal_concurrency_limit_response( response: httpx.Response, -) -> List[MinimalConcurrencyLimitResponse]: +) -> list[MinimalConcurrencyLimitResponse]: return [ MinimalConcurrencyLimitResponse.model_validate(obj_) for obj_ in response.json() ] diff --git a/src/prefect/concurrency/events.py b/src/prefect/concurrency/_events.py similarity index 76% rename from src/prefect/concurrency/events.py rename to src/prefect/concurrency/_events.py index c5a7598c7f47d..acd49b156dd35 100644 --- a/src/prefect/concurrency/events.py +++ b/src/prefect/concurrency/_events.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Literal, Optional, Union +from typing import Literal, Optional, Union from uuid import UUID from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse @@ -8,11 +8,11 @@ def _emit_concurrency_event( phase: Union[Literal["acquired"], Literal["released"]], primary_limit: MinimalConcurrencyLimitResponse, - related_limits: List[MinimalConcurrencyLimitResponse], + related_limits: list[MinimalConcurrencyLimitResponse], slots: int, follows: Union[Event, None] = None, ) -> Union[Event, None]: - resource: Dict[str, str] = { + resource: dict[str, str] = { "prefect.resource.id": f"prefect.concurrency-limit.{primary_limit.id}", "prefect.resource.name": primary_limit.name, "slots-acquired": str(slots), @@ -38,11 +38,11 @@ def _emit_concurrency_event( ) -def _emit_concurrency_acquisition_events( - limits: List[MinimalConcurrencyLimitResponse], +def emit_concurrency_acquisition_events( + limits: list[MinimalConcurrencyLimitResponse], occupy: int, -) -> Dict[UUID, Optional[Event]]: - events = {} +) -> dict[UUID, Optional[Event]]: + events: dict[UUID, Optional[Event]] = {} for limit in limits: event = _emit_concurrency_event("acquired", limit, limits, occupy) events[limit.id] = event @@ -50,10 +50,10 @@ def _emit_concurrency_acquisition_events( return events -def _emit_concurrency_release_events( - limits: List[MinimalConcurrencyLimitResponse], +def emit_concurrency_release_events( + limits: list[MinimalConcurrencyLimitResponse], occupy: int, - events: Dict[UUID, Optional[Event]], + events: dict[UUID, Optional[Event]], ) -> None: for limit in limits: _emit_concurrency_event("released", limit, limits, occupy, events[limit.id]) diff --git a/src/prefect/concurrency/context.py b/src/prefect/concurrency/context.py index 9fc3b40ddb809..986f36281c3a1 100644 --- a/src/prefect/concurrency/context.py +++ b/src/prefect/concurrency/context.py @@ -1,19 +1,21 @@ from contextvars import ContextVar -from typing import List, Tuple +from typing import Any, ClassVar + +from typing_extensions import Self from prefect.client.orchestration import get_client from prefect.context import ContextModel, Field class ConcurrencyContext(ContextModel): - __var__: ContextVar = ContextVar("concurrency") + __var__: ClassVar[ContextVar[Self]] = ContextVar("concurrency") # Track the slots that have been acquired but were not able to be released # due to cancellation or some other error. These slots are released when # the context manager exits. - cleanup_slots: List[Tuple[List[str], int, float]] = Field(default_factory=list) + cleanup_slots: list[tuple[list[str], int, float]] = Field(default_factory=list) - def __exit__(self, *exc_info): + def __exit__(self, *exc_info: Any) -> None: if self.cleanup_slots: with get_client(sync_client=True) as client: for names, occupy, occupancy_seconds in self.cleanup_slots: diff --git a/src/prefect/concurrency/services.py b/src/prefect/concurrency/services.py index 64e847cad582d..530ea7ceb3030 100644 --- a/src/prefect/concurrency/services.py +++ b/src/prefect/concurrency/services.py @@ -1,31 +1,30 @@ import asyncio -import concurrent.futures +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import ( - TYPE_CHECKING, - AsyncGenerator, - FrozenSet, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Optional import httpx from starlette import status +from typing_extensions import TypeAlias, Unpack from prefect._internal.concurrency import logger -from prefect._internal.concurrency.services import QueueService +from prefect._internal.concurrency.services import FutureQueueService from prefect.client.orchestration import get_client from prefect.utilities.timeout import timeout_async if TYPE_CHECKING: from prefect.client.orchestration import PrefectClient +_Item: TypeAlias = tuple[int, str, Optional[float], Optional[bool], Optional[int]] -class ConcurrencySlotAcquisitionService(QueueService): - def __init__(self, concurrency_limit_names: FrozenSet[str]): + +class ConcurrencySlotAcquisitionService( + FutureQueueService[Unpack[_Item], httpx.Response] +): + def __init__(self, concurrency_limit_names: frozenset[str]): super().__init__(concurrency_limit_names) - self._client: "PrefectClient" - self.concurrency_limit_names = sorted(list(concurrency_limit_names)) + self._client: PrefectClient + self.concurrency_limit_names: list[str] = sorted(list(concurrency_limit_names)) @asynccontextmanager async def _lifespan(self) -> AsyncGenerator[None, None]: @@ -33,32 +32,7 @@ async def _lifespan(self) -> AsyncGenerator[None, None]: self._client = client yield - async def _handle( - self, - item: Tuple[ - int, - str, - Optional[float], - concurrent.futures.Future, - Optional[bool], - Optional[int], - ], - ) -> None: - occupy, mode, timeout_seconds, future, create_if_missing, max_retries = item - try: - response = await self.acquire_slots( - occupy, mode, timeout_seconds, create_if_missing, max_retries - ) - except Exception as exc: - # If the request to the increment endpoint fails in a non-standard - # way, we need to set the future's result so that the caller can - # handle the exception and then re-raise. - future.set_result(exc) - raise exc - else: - future.set_result(response) - - async def acquire_slots( + async def acquire( self, slots: int, mode: str, @@ -69,44 +43,22 @@ async def acquire_slots( with timeout_async(seconds=timeout_seconds): while True: try: - response = await self._client.increment_concurrency_slots( + return await self._client.increment_concurrency_slots( names=self.concurrency_limit_names, slots=slots, mode=mode, create_if_missing=create_if_missing, ) - except Exception as exc: - if ( - isinstance(exc, httpx.HTTPStatusError) - and exc.response.status_code == status.HTTP_423_LOCKED - ): - if max_retries is not None and max_retries <= 0: - raise exc - retry_after = float(exc.response.headers["Retry-After"]) - logger.debug( - f"Unable to acquire concurrency slot. Retrying in {retry_after} second(s)." - ) - await asyncio.sleep(retry_after) - if max_retries is not None: - max_retries -= 1 - else: - raise exc - else: - return response - - def send( - self, item: Tuple[int, str, Optional[float], Optional[bool], Optional[int]] - ) -> concurrent.futures.Future: - with self._lock: - if self._stopped: - raise RuntimeError("Cannot put items in a stopped service instance.") + except httpx.HTTPStatusError as exc: + if not exc.response.status_code == status.HTTP_423_LOCKED: + raise - logger.debug("Service %r enqueuing item %r", self, item) - future: concurrent.futures.Future = concurrent.futures.Future() - - occupy, mode, timeout_seconds, create_if_missing, max_retries = item - self._queue.put_nowait( - (occupy, mode, timeout_seconds, future, create_if_missing, max_retries) - ) - - return future + if max_retries is not None and max_retries <= 0: + raise exc + retry_after = float(exc.response.headers["Retry-After"]) + logger.debug( + f"Unable to acquire concurrency slot. Retrying in {retry_after} second(s)." + ) + await asyncio.sleep(retry_after) + if max_retries is not None: + max_retries -= 1 diff --git a/src/prefect/concurrency/sync.py b/src/prefect/concurrency/sync.py index 2f6bf47a3df64..88aa69f47c451 100644 --- a/src/prefect/concurrency/sync.py +++ b/src/prefect/concurrency/sync.py @@ -1,71 +1,54 @@ +from collections.abc import Generator from contextlib import contextmanager -from typing import ( - Generator, - List, - Optional, - TypeVar, - Union, - cast, -) +from typing import Optional, TypeVar, Union import pendulum from typing_extensions import Literal -from prefect.utilities.asyncutils import run_coro_as_sync - -try: - from pendulum import Interval -except ImportError: - # pendulum < 3 - from pendulum.period import Period as Interval # type: ignore - from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse +from prefect.utilities.asyncutils import run_coro_as_sync -from .asyncio import ( - _aacquire_concurrency_slots, - _arelease_concurrency_slots, +from ._asyncio import ( + aacquire_concurrency_slots, + arelease_concurrency_slots, ) -from .events import ( - _emit_concurrency_acquisition_events, - _emit_concurrency_release_events, +from ._events import ( + emit_concurrency_acquisition_events, + emit_concurrency_release_events, ) T = TypeVar("T") def _release_concurrency_slots( - names: List[str], slots: int, occupancy_seconds: float -) -> List[MinimalConcurrencyLimitResponse]: + names: list[str], slots: int, occupancy_seconds: float +) -> list[MinimalConcurrencyLimitResponse]: result = run_coro_as_sync( - _arelease_concurrency_slots(names, slots, occupancy_seconds) + arelease_concurrency_slots(names, slots, occupancy_seconds) ) - if result is None: - raise RuntimeError("Failed to release concurrency slots") return result def _acquire_concurrency_slots( - names: List[str], + names: list[str], slots: int, mode: Literal["concurrency", "rate_limit"] = "concurrency", timeout_seconds: Optional[float] = None, create_if_missing: Optional[bool] = None, max_retries: Optional[int] = None, strict: bool = False, -) -> List[MinimalConcurrencyLimitResponse]: +) -> list[MinimalConcurrencyLimitResponse]: result = run_coro_as_sync( - _aacquire_concurrency_slots( + aacquire_concurrency_slots( names, slots, mode, timeout_seconds, create_if_missing, max_retries, strict ) ) - if result is None: - raise RuntimeError("Failed to acquire concurrency slots") return result @contextmanager def concurrency( - names: Union[str, List[str]], + names: Union[str, list[str]], occupy: int = 1, timeout_seconds: Optional[float] = None, max_retries: Optional[int] = None, @@ -107,7 +90,7 @@ def main(): names = names if isinstance(names, list) else [names] - limits: List[MinimalConcurrencyLimitResponse] = _acquire_concurrency_slots( + limits: list[MinimalConcurrencyLimitResponse] = _acquire_concurrency_slots( names, occupy, timeout_seconds=timeout_seconds, @@ -116,22 +99,18 @@ def main(): max_retries=max_retries, ) acquisition_time = pendulum.now("UTC") - emitted_events = _emit_concurrency_acquisition_events(limits, occupy) + emitted_events = emit_concurrency_acquisition_events(limits, occupy) try: yield finally: - occupancy_period = cast(Interval, pendulum.now("UTC") - acquisition_time) - _release_concurrency_slots( - names, - occupy, - occupancy_period.total_seconds(), - ) - _emit_concurrency_release_events(limits, occupy, emitted_events) + occupancy_period = pendulum.now("UTC") - acquisition_time + _release_concurrency_slots(names, occupy, occupancy_period.total_seconds()) + emit_concurrency_release_events(limits, occupy, emitted_events) def rate_limit( - names: Union[str, List[str]], + names: Union[str, list[str]], occupy: int = 1, timeout_seconds: Optional[float] = None, create_if_missing: Optional[bool] = None, @@ -166,4 +145,4 @@ def rate_limit( create_if_missing=create_if_missing, strict=strict, ) - _emit_concurrency_acquisition_events(limits, occupy) + emit_concurrency_acquisition_events(limits, occupy) diff --git a/src/prefect/concurrency/v1/asyncio.py b/src/prefect/concurrency/v1/_asyncio.py similarity index 65% rename from src/prefect/concurrency/v1/asyncio.py rename to src/prefect/concurrency/v1/_asyncio.py index 7f888adc71728..a1ef1731f54b5 100644 --- a/src/prefect/concurrency/v1/asyncio.py +++ b/src/prefect/concurrency/v1/_asyncio.py @@ -1,28 +1,22 @@ import asyncio +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import AsyncGenerator, List, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union from uuid import UUID import anyio import httpx import pendulum -from ...client.schemas.responses import MinimalConcurrencyLimitResponse - -try: - from pendulum import Interval -except ImportError: - # pendulum < 3 - from pendulum.period import Period as Interval # type: ignore - from prefect.client.orchestration import get_client +from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse from prefect.utilities.asyncutils import sync_compatible -from .context import ConcurrencyContext -from .events import ( - _emit_concurrency_acquisition_events, - _emit_concurrency_release_events, +from ._events import ( + emit_concurrency_acquisition_events, + emit_concurrency_release_events, ) +from .context import ConcurrencyContext from .services import ConcurrencySlotAcquisitionService @@ -36,7 +30,7 @@ class AcquireConcurrencySlotTimeoutError(TimeoutError): @asynccontextmanager async def concurrency( - names: Union[str, List[str]], + names: Union[str, list[str]], task_run_id: UUID, timeout_seconds: Optional[float] = None, ) -> AsyncGenerator[None, None]: @@ -55,7 +49,7 @@ async def concurrency( Example: A simple example of using the async `concurrency` context manager: ```python - from prefect.concurrency.v1.asyncio import concurrency + from prefect.concurrency.v1._asyncio import concurrency async def resource_heavy(): async with concurrency("test", task_run_id): @@ -69,24 +63,30 @@ async def main(): yield return - names_normalized: List[str] = names if isinstance(names, list) else [names] + names_normalized: list[str] = names if isinstance(names, list) else [names] - limits = await _acquire_concurrency_slots( + acquire_slots = acquire_concurrency_slots( names_normalized, task_run_id=task_run_id, timeout_seconds=timeout_seconds, ) + if TYPE_CHECKING: + assert not isinstance(acquire_slots, list) + limits = await acquire_slots acquisition_time = pendulum.now("UTC") - emitted_events = _emit_concurrency_acquisition_events(limits, task_run_id) + emitted_events = emit_concurrency_acquisition_events(limits, task_run_id) try: yield finally: - occupancy_period = cast(Interval, (pendulum.now("UTC") - acquisition_time)) + occupancy_period = pendulum.now("UTC") - acquisition_time try: - await _release_concurrency_slots( + release_slots = release_concurrency_slots( names_normalized, task_run_id, occupancy_period.total_seconds() ) + if TYPE_CHECKING: + assert not isinstance(release_slots, list) + await release_slots except anyio.get_cancelled_exc_class(): # The task was cancelled before it could release the slots. Add the # slots to the cleanup list so they can be released when the @@ -96,38 +96,35 @@ async def main(): (names_normalized, occupancy_period.total_seconds(), task_run_id) ) - _emit_concurrency_release_events(limits, emitted_events, task_run_id) + emit_concurrency_release_events(limits, emitted_events, task_run_id) @sync_compatible -async def _acquire_concurrency_slots( - names: List[str], +async def acquire_concurrency_slots( + names: list[str], task_run_id: UUID, timeout_seconds: Optional[float] = None, -) -> List[MinimalConcurrencyLimitResponse]: +) -> list[MinimalConcurrencyLimitResponse]: service = ConcurrencySlotAcquisitionService.instance(frozenset(names)) future = service.send((task_run_id, timeout_seconds)) - response_or_exception = await asyncio.wrap_future(future) - - if isinstance(response_or_exception, Exception): - if isinstance(response_or_exception, TimeoutError): - raise AcquireConcurrencySlotTimeoutError( - f"Attempt to acquire concurrency limits timed out after {timeout_seconds} second(s)" - ) from response_or_exception - + try: + response = await asyncio.wrap_future(future) + except TimeoutError as timeout: + raise AcquireConcurrencySlotTimeoutError( + f"Attempt to acquire concurrency limits timed out after {timeout_seconds} second(s)" + ) from timeout + except Exception as exc: raise ConcurrencySlotAcquisitionError( f"Unable to acquire concurrency limits {names!r}" - ) from response_or_exception - - return _response_to_concurrency_limit_response(response_or_exception) + ) from exc + else: + return _response_to_concurrency_limit_response(response) @sync_compatible -async def _release_concurrency_slots( - names: List[str], - task_run_id: UUID, - occupancy_seconds: float, -) -> List[MinimalConcurrencyLimitResponse]: +async def release_concurrency_slots( + names: list[str], task_run_id: UUID, occupancy_seconds: float +) -> list[MinimalConcurrencyLimitResponse]: async with get_client() as client: response = await client.decrement_v1_concurrency_slots( names=names, @@ -139,8 +136,8 @@ async def _release_concurrency_slots( def _response_to_concurrency_limit_response( response: httpx.Response, -) -> List[MinimalConcurrencyLimitResponse]: - data = response.json() or [] +) -> list[MinimalConcurrencyLimitResponse]: + data: list[MinimalConcurrencyLimitResponse] = response.json() or [] return [ MinimalConcurrencyLimitResponse.model_validate(limit) for limit in data if data ] diff --git a/src/prefect/concurrency/v1/events.py b/src/prefect/concurrency/v1/_events.py similarity index 65% rename from src/prefect/concurrency/v1/events.py rename to src/prefect/concurrency/v1/_events.py index 3fa5193e6fea8..f3924cb1a5d1e 100644 --- a/src/prefect/concurrency/v1/events.py +++ b/src/prefect/concurrency/v1/_events.py @@ -1,18 +1,18 @@ -from typing import Dict, List, Literal, Optional, Union +from typing import Literal, Optional, Union from uuid import UUID from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse from prefect.events import Event, RelatedResource, emit_event -def _emit_concurrency_event( +def emit_concurrency_event( phase: Union[Literal["acquired"], Literal["released"]], primary_limit: MinimalConcurrencyLimitResponse, - related_limits: List[MinimalConcurrencyLimitResponse], + related_limits: list[MinimalConcurrencyLimitResponse], task_run_id: UUID, follows: Union[Event, None] = None, ) -> Union[Event, None]: - resource: Dict[str, str] = { + resource: dict[str, str] = { "prefect.resource.id": f"prefect.concurrency-limit.v1.{primary_limit.id}", "prefect.resource.name": primary_limit.name, "limit": str(primary_limit.limit), @@ -38,24 +38,22 @@ def _emit_concurrency_event( ) -def _emit_concurrency_acquisition_events( - limits: List[MinimalConcurrencyLimitResponse], +def emit_concurrency_acquisition_events( + limits: list[MinimalConcurrencyLimitResponse], task_run_id: UUID, -) -> Dict[UUID, Optional[Event]]: - events = {} +) -> dict[UUID, Optional[Event]]: + events: dict[UUID, Optional[Event]] = {} for limit in limits: - event = _emit_concurrency_event("acquired", limit, limits, task_run_id) + event = emit_concurrency_event("acquired", limit, limits, task_run_id) events[limit.id] = event return events -def _emit_concurrency_release_events( - limits: List[MinimalConcurrencyLimitResponse], - events: Dict[UUID, Optional[Event]], +def emit_concurrency_release_events( + limits: list[MinimalConcurrencyLimitResponse], + events: dict[UUID, Optional[Event]], task_run_id: UUID, ) -> None: for limit in limits: - _emit_concurrency_event( - "released", limit, limits, task_run_id, events[limit.id] - ) + emit_concurrency_event("released", limit, limits, task_run_id, events[limit.id]) diff --git a/src/prefect/concurrency/v1/context.py b/src/prefect/concurrency/v1/context.py index f413c84ed1f46..faaac13a4523f 100644 --- a/src/prefect/concurrency/v1/context.py +++ b/src/prefect/concurrency/v1/context.py @@ -1,20 +1,22 @@ from contextvars import ContextVar -from typing import List, Tuple +from typing import Any, ClassVar from uuid import UUID +from typing_extensions import Self + from prefect.client.orchestration import get_client from prefect.context import ContextModel, Field class ConcurrencyContext(ContextModel): - __var__: ContextVar = ContextVar("concurrency_v1") + __var__: ClassVar[ContextVar[Self]] = ContextVar("concurrency_v1") # Track the limits that have been acquired but were not able to be released # due to cancellation or some other error. These limits are released when # the context manager exits. - cleanup_slots: List[Tuple[List[str], float, UUID]] = Field(default_factory=list) + cleanup_slots: list[tuple[list[str], float, UUID]] = Field(default_factory=list) - def __exit__(self, *exc_info): + def __exit__(self, *exc_info: Any) -> None: if self.cleanup_slots: with get_client(sync_client=True) as client: for names, occupancy_seconds, task_run_id in self.cleanup_slots: diff --git a/src/prefect/concurrency/v1/services.py b/src/prefect/concurrency/v1/services.py index 1199c7ef33733..ad8d4b742b45d 100644 --- a/src/prefect/concurrency/v1/services.py +++ b/src/prefect/concurrency/v1/services.py @@ -1,21 +1,16 @@ import asyncio -import concurrent.futures +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from json import JSONDecodeError -from typing import ( - TYPE_CHECKING, - AsyncGenerator, - FrozenSet, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Optional from uuid import UUID import httpx from starlette import status +from typing_extensions import Unpack from prefect._internal.concurrency import logger -from prefect._internal.concurrency.services import QueueService +from prefect._internal.concurrency.services import FutureQueueService from prefect.client.orchestration import get_client from prefect.utilities.timeout import timeout_async @@ -27,11 +22,13 @@ class ConcurrencySlotAcquisitionServiceError(Exception): """Raised when an error occurs while acquiring concurrency slots.""" -class ConcurrencySlotAcquisitionService(QueueService): - def __init__(self, concurrency_limit_names: FrozenSet[str]): +class ConcurrencySlotAcquisitionService( + FutureQueueService[Unpack[tuple[UUID, Optional[float]]], httpx.Response] +): + def __init__(self, concurrency_limit_names: frozenset[str]) -> None: super().__init__(concurrency_limit_names) - self._client: "PrefectClient" - self.concurrency_limit_names = sorted(list(concurrency_limit_names)) + self._client: PrefectClient + self.concurrency_limit_names: list[str] = sorted(list(concurrency_limit_names)) @asynccontextmanager async def _lifespan(self) -> AsyncGenerator[None, None]: @@ -39,78 +36,35 @@ async def _lifespan(self) -> AsyncGenerator[None, None]: self._client = client yield - async def _handle( - self, - item: Tuple[ - UUID, - concurrent.futures.Future, - Optional[float], - ], - ) -> None: - task_run_id, future, timeout_seconds = item - try: - response = await self.acquire_slots(task_run_id, timeout_seconds) - except Exception as exc: - # If the request to the increment endpoint fails in a non-standard - # way, we need to set the future's result so that the caller can - # handle the exception and then re-raise. - future.set_result(exc) - raise exc - else: - future.set_result(response) - - async def acquire_slots( - self, - task_run_id: UUID, - timeout_seconds: Optional[float] = None, + async def acquire( + self, task_run_id: UUID, timeout_seconds: Optional[float] = None ) -> httpx.Response: with timeout_async(seconds=timeout_seconds): while True: try: - response = await self._client.increment_v1_concurrency_slots( + return await self._client.increment_v1_concurrency_slots( task_run_id=task_run_id, names=self.concurrency_limit_names, ) - except Exception as exc: - if ( - isinstance(exc, httpx.HTTPStatusError) - and exc.response.status_code == status.HTTP_423_LOCKED - ): - retry_after = exc.response.headers.get("Retry-After") - if retry_after: - retry_after = float(retry_after) - await asyncio.sleep(retry_after) - else: - # We received a 423 but no Retry-After header. This - # should indicate that the server told us to abort - # because the concurrency limit is set to 0, i.e. - # effectively disabled. - try: - reason = exc.response.json()["detail"] - except (JSONDecodeError, KeyError): - logger.error( - "Failed to parse response from concurrency limit 423 Locked response: %s", - exc.response.content, - ) - reason = "Concurrency limit is locked (server did not specify the reason)" - raise ConcurrencySlotAcquisitionServiceError( - reason - ) from exc + except httpx.HTTPStatusError as exc: + if not exc.response.status_code == status.HTTP_423_LOCKED: + raise + retry_after = exc.response.headers.get("Retry-After") + if retry_after: + retry_after = float(retry_after) + await asyncio.sleep(retry_after) else: - raise exc # type: ignore - else: - return response - - def send(self, item: Tuple[UUID, Optional[float]]) -> concurrent.futures.Future: - with self._lock: - if self._stopped: - raise RuntimeError("Cannot put items in a stopped service instance.") - - logger.debug("Service %r enqueuing item %r", self, item) - future: concurrent.futures.Future = concurrent.futures.Future() - - task_run_id, timeout_seconds = item - self._queue.put_nowait((task_run_id, future, timeout_seconds)) - - return future + # We received a 423 but no Retry-After header. This + # should indicate that the server told us to abort + # because the concurrency limit is set to 0, i.e. + # effectively disabled. + try: + reason = exc.response.json()["detail"] + except (JSONDecodeError, KeyError): + logger.error( + "Failed to parse response from concurrency limit 423 Locked response: %s", + exc.response.content, + ) + reason = "Concurrency limit is locked (server did not specify the reason)" + raise ConcurrencySlotAcquisitionServiceError(reason) from exc diff --git a/src/prefect/concurrency/v1/sync.py b/src/prefect/concurrency/v1/sync.py index 6e557b344502b..287de878e4be8 100644 --- a/src/prefect/concurrency/v1/sync.py +++ b/src/prefect/concurrency/v1/sync.py @@ -1,31 +1,15 @@ +import asyncio +from collections.abc import Generator from contextlib import contextmanager -from typing import ( - Generator, - List, - Optional, - TypeVar, - Union, - cast, -) +from typing import Optional, TypeVar, Union from uuid import UUID import pendulum -from ...client.schemas.responses import MinimalConcurrencyLimitResponse - -try: - from pendulum import Interval -except ImportError: - # pendulum < 3 - from pendulum.period import Period as Interval # type: ignore - -from .asyncio import ( - _acquire_concurrency_slots, - _release_concurrency_slots, -) -from .events import ( - _emit_concurrency_acquisition_events, - _emit_concurrency_release_events, +from ._asyncio import acquire_concurrency_slots, release_concurrency_slots +from ._events import ( + emit_concurrency_acquisition_events, + emit_concurrency_release_events, ) T = TypeVar("T") @@ -33,7 +17,7 @@ @contextmanager def concurrency( - names: Union[str, List[str]], + names: Union[str, list[str]], task_run_id: UUID, timeout_seconds: Optional[float] = None, ) -> Generator[None, None, None]: @@ -69,23 +53,20 @@ def main(): names = names if isinstance(names, list) else [names] - limits: List[MinimalConcurrencyLimitResponse] = _acquire_concurrency_slots( - names, - timeout_seconds=timeout_seconds, - task_run_id=task_run_id, - _sync=True, + force = {"_sync": True} + result = acquire_concurrency_slots( + names, timeout_seconds=timeout_seconds, task_run_id=task_run_id, **force ) + assert not asyncio.iscoroutine(result) + limits = result acquisition_time = pendulum.now("UTC") - emitted_events = _emit_concurrency_acquisition_events(limits, task_run_id) + emitted_events = emit_concurrency_acquisition_events(limits, task_run_id) try: yield finally: - occupancy_period = cast(Interval, pendulum.now("UTC") - acquisition_time) - _release_concurrency_slots( - names, - task_run_id, - occupancy_period.total_seconds(), - _sync=True, + occupancy_period = pendulum.now("UTC") - acquisition_time + release_concurrency_slots( + names, task_run_id, occupancy_period.total_seconds(), **force ) - _emit_concurrency_release_events(limits, emitted_events, task_run_id) + emit_concurrency_release_events(limits, emitted_events, task_run_id) diff --git a/src/prefect/context.py b/src/prefect/context.py index 287b9b58e1381..31298b0b1b438 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -9,21 +9,10 @@ import os import sys import warnings +from collections.abc import AsyncGenerator, Generator, Mapping from contextlib import ExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar, Token -from typing import ( - TYPE_CHECKING, - Any, - AsyncGenerator, - Dict, - Generator, - Mapping, - Optional, - Set, - Type, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar, Union from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from typing_extensions import Self @@ -58,13 +47,8 @@ from prefect.flows import Flow from prefect.tasks import Task -# Define the global settings context variable -# This will be populated downstream but must be null here to facilitate loading the -# default settings. -GLOBAL_SETTINGS_CONTEXT = None # type: ignore - -def serialize_context() -> Dict[str, Any]: +def serialize_context() -> dict[str, Any]: """ Serialize the current context for use in a remote execution environment. """ @@ -84,9 +68,9 @@ def serialize_context() -> Dict[str, Any]: @contextmanager def hydrated_context( - serialized_context: Optional[Dict[str, Any]] = None, + serialized_context: Optional[dict[str, Any]] = None, client: Union[PrefectClient, SyncPrefectClient, None] = None, -): +) -> Generator[None, Any, None]: with ExitStack() as stack: if serialized_context: # Set up settings context @@ -123,10 +107,15 @@ class ContextModel(BaseModel): a context manager """ + if TYPE_CHECKING: + # subclasses can pass through keyword arguments to the pydantic base model + def __init__(self, **kwargs: Any) -> None: + ... + # The context variable for storing data must be defined by the child class - __var__: ContextVar[Self] + __var__: ClassVar[ContextVar[Self]] _token: Optional[Token[Self]] = PrivateAttr(None) - model_config = ConfigDict( + model_config: ClassVar[ConfigDict] = ConfigDict( arbitrary_types_allowed=True, extra="forbid", ) @@ -139,7 +128,7 @@ def __enter__(self) -> Self: self._token = self.__var__.set(self) return self - def __exit__(self, *_): + def __exit__(self, *_: Any) -> None: if not self._token: raise RuntimeError( "Asymmetric use of context. Context exit called without an enter." @@ -148,13 +137,13 @@ def __exit__(self, *_): self._token = None @classmethod - def get(cls: Type[Self]) -> Optional[Self]: + def get(cls: type[Self]) -> Optional[Self]: """Get the current context instance""" return cls.__var__.get(None) def model_copy( self: Self, *, update: Optional[Mapping[str, Any]] = None, deep: bool = False - ): + ) -> Self: """ Duplicate the context model, optionally choosing which fields to include, exclude, or change. @@ -173,7 +162,7 @@ def model_copy( new._token = None return new - def serialize(self, include_secrets: bool = True) -> Dict[str, Any]: + def serialize(self, include_secrets: bool = True) -> dict[str, Any]: """ Serialize the context model to a dictionary that can be pickled with cloudpickle. """ @@ -202,19 +191,19 @@ class SyncClientContext(ContextModel): assert c1 is ctx.client """ - __var__: ContextVar[Self] = ContextVar("sync-client-context") + __var__: ClassVar[ContextVar[Self]] = ContextVar("sync-client-context") client: SyncPrefectClient _httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None) _context_stack: int = PrivateAttr(0) - def __init__(self, httpx_settings: Optional[dict[str, Any]] = None): + def __init__(self, httpx_settings: Optional[dict[str, Any]] = None) -> None: super().__init__( - client=get_client(sync_client=True, httpx_settings=httpx_settings), # type: ignore[reportCallIssue] + client=get_client(sync_client=True, httpx_settings=httpx_settings), ) self._httpx_settings = httpx_settings self._context_stack = 0 - def __enter__(self): + def __enter__(self) -> Self: self._context_stack += 1 if self._context_stack == 1: self.client.__enter__() @@ -223,20 +212,20 @@ def __enter__(self): else: return self - def __exit__(self, *exc_info: Any): + def __exit__(self, *exc_info: Any) -> None: self._context_stack -= 1 if self._context_stack == 0: - self.client.__exit__(*exc_info) # type: ignore[reportUnknownMemberType] - return super().__exit__(*exc_info) # type: ignore[reportUnknownMemberType] + self.client.__exit__(*exc_info) + return super().__exit__(*exc_info) @classmethod @contextmanager - def get_or_create(cls) -> Generator["SyncClientContext", None, None]: - ctx = SyncClientContext.get() + def get_or_create(cls) -> Generator[Self, None, None]: + ctx = cls.get() if ctx: yield ctx else: - with SyncClientContext() as ctx: + with cls() as ctx: yield ctx @@ -260,14 +249,14 @@ class AsyncClientContext(ContextModel): assert c1 is ctx.client """ - __var__ = ContextVar("async-client-context") + __var__: ClassVar[ContextVar[Self]] = ContextVar("async-client-context") client: PrefectClient _httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None) _context_stack: int = PrivateAttr(0) def __init__(self, httpx_settings: Optional[dict[str, Any]] = None): super().__init__( - client=get_client(sync_client=False, httpx_settings=httpx_settings), # type: ignore[reportCallIssue] + client=get_client(sync_client=False, httpx_settings=httpx_settings) ) self._httpx_settings = httpx_settings self._context_stack = 0 @@ -284,8 +273,8 @@ async def __aenter__(self: Self) -> Self: async def __aexit__(self: Self, *exc_info: Any) -> None: self._context_stack -= 1 if self._context_stack == 0: - await self.client.__aexit__(*exc_info) # type: ignore[reportUnknownMemberType] - return super().__exit__(*exc_info) # type: ignore[reportUnknownMemberType] + await self.client.__aexit__(*exc_info) + return super().__exit__(*exc_info) @classmethod @asynccontextmanager @@ -308,16 +297,16 @@ class RunContext(ContextModel): client: The Prefect client instance being used for API communication """ - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) start_client_metrics_server() start_time: DateTime = Field(default_factory=lambda: DateTime.now("UTC")) - input_keyset: Optional[Dict[str, Dict[str, str]]] = None + input_keyset: Optional[dict[str, dict[str, str]]] = None client: Union[PrefectClient, SyncPrefectClient] - def serialize(self: Self, include_secrets: bool = True) -> Dict[str, Any]: + def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]: return self.model_dump( include={"start_time", "input_keyset"}, exclude_unset=True, @@ -344,7 +333,7 @@ class EngineContext(RunContext): flow_run: Optional[FlowRun] = None task_runner: TaskRunner[Any] log_prints: bool = False - parameters: Optional[Dict[str, Any]] = None + parameters: Optional[dict[str, Any]] = None # Flag signaling if the flow run context has been serialized and sent # to remote infrastructure. @@ -355,10 +344,10 @@ class EngineContext(RunContext): persist_result: bool = Field(default_factory=get_default_persist_setting) # Counter for task calls allowing unique - task_run_dynamic_keys: Dict[str, Union[str, int]] = Field(default_factory=dict) + task_run_dynamic_keys: dict[str, Union[str, int]] = Field(default_factory=dict) # Counter for flow pauses - observed_flow_pauses: Dict[str, int] = Field(default_factory=dict) + observed_flow_pauses: dict[str, int] = Field(default_factory=dict) # Tracking for result from task runs in this flow run for dependency tracking # Holds the ID of the object returned by the task run and task run state @@ -367,9 +356,9 @@ class EngineContext(RunContext): # Events worker to emit events events: Optional[EventsWorker] = None - __var__: ContextVar[Self] = ContextVar("flow_run") + __var__: ClassVar[ContextVar[Self]] = ContextVar("flow_run") - def serialize(self: Self, include_secrets: bool = True) -> Dict[str, Any]: + def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]: return self.model_dump( include={ "flow_run", @@ -403,15 +392,15 @@ class TaskRunContext(RunContext): task: "Task[Any, Any]" task_run: TaskRun log_prints: bool = False - parameters: Dict[str, Any] + parameters: dict[str, Any] # Result handling result_store: ResultStore persist_result: bool = Field(default_factory=get_default_persist_setting_for_tasks) - __var__ = ContextVar("task_run") + __var__: ClassVar[ContextVar[Self]] = ContextVar("task_run") - def serialize(self: Self, include_secrets: bool = True) -> Dict[str, Any]: + def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]: return self.model_dump( include={ "task_run", @@ -437,14 +426,14 @@ class TagsContext(ContextModel): current_tags: A set of current tags in the context """ - current_tags: Set[str] = Field(default_factory=set) + current_tags: set[str] = Field(default_factory=set) @classmethod - def get(cls) -> "TagsContext": + def get(cls) -> Self: # Return an empty `TagsContext` instead of `None` if no context exists - return cls.__var__.get(TagsContext()) + return cls.__var__.get(cls()) - __var__: ContextVar[Self] = ContextVar("tags") + __var__: ClassVar[ContextVar[Self]] = ContextVar("tags") class SettingsContext(ContextModel): @@ -461,15 +450,21 @@ class SettingsContext(ContextModel): profile: Profile settings: Settings - __var__: ContextVar[Self] = ContextVar("settings") + __var__: ClassVar[ContextVar[Self]] = ContextVar("settings") def __hash__(self: Self) -> int: return hash(self.settings) @classmethod - def get(cls) -> "SettingsContext": + def get(cls) -> Optional["SettingsContext"]: # Return the global context instead of `None` if no context exists - return super().get() or GLOBAL_SETTINGS_CONTEXT + try: + return super().get() or GLOBAL_SETTINGS_CONTEXT + except NameError: + # GLOBAL_SETTINGS_CONTEXT has not yet been set; in order to create + # it profiles need to be loaded, and that process calls + # SettingsContext.get(). + return None def get_run_context() -> Union[FlowRunContext, TaskRunContext]: @@ -512,7 +507,7 @@ def get_settings_context() -> SettingsContext: @contextmanager -def tags(*new_tags: str) -> Generator[Set[str], None, None]: +def tags(*new_tags: str) -> Generator[set[str], None, None]: """ Context manager to add tags to flow and task run calls. @@ -570,10 +565,10 @@ def tags(*new_tags: str) -> Generator[Set[str], None, None]: @contextmanager def use_profile( - profile: Union[Profile, str, Any], + profile: Union[Profile, str], override_environment_variables: bool = False, include_current_context: bool = True, -): +) -> Generator[SettingsContext, Any, None]: """ Switch to a profile for the duration of this context. @@ -595,11 +590,12 @@ def use_profile( profiles = prefect.settings.load_profiles() profile = profiles[profile] - if not isinstance(profile, Profile): - raise TypeError( - f"Unexpected type {type(profile).__name__!r} for `profile`. " - "Expected 'str' or 'Profile'." - ) + if not TYPE_CHECKING: + if not isinstance(profile, Profile): + raise TypeError( + f"Unexpected type {type(profile).__name__!r} for `profile`. " + "Expected 'str' or 'Profile'." + ) # Create a copy of the profiles settings as we will mutate it profile_settings = profile.settings.copy() @@ -620,7 +616,7 @@ def use_profile( yield ctx -def root_settings_context(): +def root_settings_context() -> SettingsContext: """ Return the settings context that will exist as the root context for the module. @@ -670,9 +666,9 @@ def root_settings_context(): # an override in the `SettingsContext.get` method. -GLOBAL_SETTINGS_CONTEXT: SettingsContext = root_settings_context() # type: ignore[reportConstantRedefinition] +GLOBAL_SETTINGS_CONTEXT: SettingsContext = root_settings_context() # 2024-07-02: This surfaces an actionable error message for removed objects # in Prefect 3.0 upgrade. -__getattr__ = getattr_migration(__name__) +__getattr__: Callable[[str], Any] = getattr_migration(__name__) diff --git a/src/prefect/runner/runner.py b/src/prefect/runner/runner.py index 28f1893260c85..8af7fcc314c09 100644 --- a/src/prefect/runner/runner.py +++ b/src/prefect/runner/runner.py @@ -82,7 +82,7 @@ def fast_flow(): StateType, ) from prefect.client.schemas.objects import Flow as APIFlow -from prefect.concurrency.asyncio import ( +from prefect.concurrency._asyncio import ( AcquireConcurrencySlotTimeoutError, ConcurrencySlotAcquisitionError, ) diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index e20e6a71733c1..d22d4ce2147db 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -37,7 +37,7 @@ from prefect.client.schemas import TaskRun from prefect.client.schemas.objects import State, TaskRunInput from prefect.concurrency.context import ConcurrencyContext -from prefect.concurrency.v1.asyncio import concurrency as aconcurrency +from prefect.concurrency.v1._asyncio import concurrency as aconcurrency from prefect.concurrency.v1.context import ConcurrencyContext as ConcurrencyContextV1 from prefect.concurrency.v1.sync import concurrency from prefect.context import ( diff --git a/tests/concurrency/test_acquire_concurrency_slots.py b/tests/concurrency/test_acquire_concurrency_slots.py index a6d1817051a71..1848d67384f22 100644 --- a/tests/concurrency/test_acquire_concurrency_slots.py +++ b/tests/concurrency/test_acquire_concurrency_slots.py @@ -4,9 +4,7 @@ from httpx import Response from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse -from prefect.concurrency.asyncio import ( - _aacquire_concurrency_slots, -) +from prefect.concurrency._asyncio import aacquire_concurrency_slots async def test_calls_increment_client_method(): @@ -23,7 +21,7 @@ async def test_calls_increment_client_method(): ) increment_concurrency_slots.return_value = response - await _aacquire_concurrency_slots( + await aacquire_concurrency_slots( names=["test-1", "test-2"], slots=1, mode="concurrency" ) increment_concurrency_slots.assert_called_once_with( @@ -48,5 +46,5 @@ async def test_returns_minimal_concurrency_limit(): ) increment_concurrency_slots.return_value = response - result = await _aacquire_concurrency_slots(["test-1", "test-2"], 1) + result = await aacquire_concurrency_slots(["test-1", "test-2"], 1) assert result == limits diff --git a/tests/concurrency/test_concurrency_asyncio.py b/tests/concurrency/test_concurrency_asyncio.py index ff1306e95e014..aa4ed2fc8cec6 100644 --- a/tests/concurrency/test_concurrency_asyncio.py +++ b/tests/concurrency/test_concurrency_asyncio.py @@ -5,10 +5,10 @@ from starlette import status from prefect import flow, task -from prefect.concurrency.asyncio import ( +from prefect.concurrency._asyncio import ( ConcurrencySlotAcquisitionError, - _aacquire_concurrency_slots, - _arelease_concurrency_slots, + aacquire_concurrency_slots, + arelease_concurrency_slots, concurrency, rate_limit, ) @@ -28,12 +28,12 @@ async def resource_heavy(): assert not executed with mock.patch( - "prefect.concurrency.asyncio._aacquire_concurrency_slots", - wraps=_aacquire_concurrency_slots, + "prefect.concurrency._asyncio.aacquire_concurrency_slots", + wraps=aacquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.asyncio._arelease_concurrency_slots", - wraps=_arelease_concurrency_slots, + "prefect.concurrency._asyncio.arelease_concurrency_slots", + wraps=arelease_concurrency_slots, ) as release_spy: await resource_heavy() @@ -221,12 +221,12 @@ async def resource_heavy(): assert not executed with mock.patch( - "prefect.concurrency.asyncio._aacquire_concurrency_slots", - wraps=_aacquire_concurrency_slots, + "prefect.concurrency._asyncio.aacquire_concurrency_slots", + wraps=aacquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.asyncio._arelease_concurrency_slots", - wraps=_arelease_concurrency_slots, + "prefect.concurrency._asyncio.arelease_concurrency_slots", + wraps=arelease_concurrency_slots, ) as release_spy: await resource_heavy() @@ -377,7 +377,7 @@ async def resource_heavy(): wraps=lambda *args, **kwargs: None, ) as acquire_spy: with mock.patch( - "prefect.concurrency.sync._arelease_concurrency_slots", + "prefect.concurrency.sync.arelease_concurrency_slots", wraps=lambda *args, **kwargs: None, ) as release_spy: await resource_heavy() @@ -401,12 +401,12 @@ async def resource_heavy(): assert not executed with mock.patch( - "prefect.concurrency.asyncio._aacquire_concurrency_slots", - wraps=_aacquire_concurrency_slots, + "prefect.concurrency._asyncio.aacquire_concurrency_slots", + wraps=aacquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.asyncio._arelease_concurrency_slots", - wraps=_arelease_concurrency_slots, + "prefect.concurrency._asyncio.arelease_concurrency_slots", + wraps=arelease_concurrency_slots, ) as release_spy: await resource_heavy() @@ -447,7 +447,7 @@ async def resource_heavy(): wraps=lambda *args, **kwargs: None, ) as acquire_spy: with mock.patch( - "prefect.concurrency.sync._arelease_concurrency_slots", + "prefect.concurrency.sync.arelease_concurrency_slots", wraps=lambda *args, **kwargs: None, ) as release_spy: await resource_heavy() diff --git a/tests/concurrency/test_concurrency_slot_acquisition_service.py b/tests/concurrency/test_concurrency_slot_acquisition_service.py index 668099cf64a42..1e17b5c389389 100644 --- a/tests/concurrency/test_concurrency_slot_acquisition_service.py +++ b/tests/concurrency/test_concurrency_slot_acquisition_service.py @@ -69,7 +69,7 @@ async def test_retries_failed_call_respects_retry_after_header(mocked_client): limit_names = sorted(["api", "database"]) service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names)) - with mock.patch("prefect.concurrency.asyncio.asyncio.sleep") as sleep: + with mock.patch("prefect.concurrency._asyncio.asyncio.sleep") as sleep: future = service.send((1, "concurrency", None, True, None)) await service.drain() returned_response = await asyncio.wrap_future(future) @@ -111,7 +111,8 @@ async def test_basic_exception_returns_exception(mocked_client): future = service.send((1, "concurrency", None, True, None)) await service.drain() - exception = await asyncio.wrap_future(future) - assert isinstance(exception, Exception) - assert exception == exc + with pytest.raises(Exception) as info: + await asyncio.wrap_future(future) + + assert info.value == exc diff --git a/tests/concurrency/test_concurrency_sync.py b/tests/concurrency/test_concurrency_sync.py index 13724a0972147..181d3cb68f3fe 100644 --- a/tests/concurrency/test_concurrency_sync.py +++ b/tests/concurrency/test_concurrency_sync.py @@ -5,7 +5,7 @@ from starlette import status from prefect import flow, task -from prefect.concurrency.asyncio import ConcurrencySlotAcquisitionError +from prefect.concurrency._asyncio import ConcurrencySlotAcquisitionError from prefect.concurrency.sync import ( _acquire_concurrency_slots, _release_concurrency_slots, diff --git a/tests/concurrency/test_context.py b/tests/concurrency/test_context.py index c1a0dc8262782..45ce531da4800 100644 --- a/tests/concurrency/test_context.py +++ b/tests/concurrency/test_context.py @@ -4,7 +4,7 @@ import pytest from prefect.client.orchestration import PrefectClient, get_client -from prefect.concurrency.asyncio import concurrency as aconcurrency +from prefect.concurrency._asyncio import concurrency as aconcurrency from prefect.concurrency.context import ConcurrencyContext from prefect.concurrency.sync import concurrency from prefect.server.schemas.core import ConcurrencyLimitV2 diff --git a/tests/concurrency/test_release_concurrency_slots.py b/tests/concurrency/test_release_concurrency_slots.py index 98d477f724f31..ee9225bb4700c 100644 --- a/tests/concurrency/test_release_concurrency_slots.py +++ b/tests/concurrency/test_release_concurrency_slots.py @@ -4,9 +4,7 @@ from httpx import Response from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse -from prefect.concurrency.asyncio import ( - _arelease_concurrency_slots, -) +from prefect.concurrency._asyncio import arelease_concurrency_slots async def test_calls_release_client_method(): @@ -23,7 +21,7 @@ async def test_calls_release_client_method(): ) client_release_concurrency_slots.return_value = response - await _arelease_concurrency_slots( + await arelease_concurrency_slots( names=["test-1", "test-2"], slots=1, occupancy_seconds=1.0 ) client_release_concurrency_slots.assert_called_once_with( @@ -47,5 +45,5 @@ async def test_returns_minimal_concurrency_limit(): ) client_release_concurrency_slots.return_value = response - result = await _arelease_concurrency_slots(["test-1", "test-2"], 1, 1.0) + result = await arelease_concurrency_slots(["test-1", "test-2"], 1, 1.0) assert result == limits diff --git a/tests/concurrency/v1/test_concurrency_asyncio.py b/tests/concurrency/v1/test_concurrency_asyncio.py index 2a39415226227..cf149a830cd81 100644 --- a/tests/concurrency/v1/test_concurrency_asyncio.py +++ b/tests/concurrency/v1/test_concurrency_asyncio.py @@ -6,11 +6,11 @@ from starlette import status from prefect import flow, task -from prefect.concurrency.v1.asyncio import ( +from prefect.concurrency.v1._asyncio import ( ConcurrencySlotAcquisitionError, - _acquire_concurrency_slots, - _release_concurrency_slots, + acquire_concurrency_slots, concurrency, + release_concurrency_slots, ) from prefect.events.clients import AssertingEventsClient from prefect.events.worker import EventsWorker @@ -29,12 +29,12 @@ async def resource_heavy(): assert not executed with mock.patch( - "prefect.concurrency.v1.asyncio._acquire_concurrency_slots", - wraps=_acquire_concurrency_slots, + "prefect.concurrency.v1._asyncio.acquire_concurrency_slots", + wraps=acquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.v1.asyncio._release_concurrency_slots", - wraps=_release_concurrency_slots, + "prefect.concurrency.v1._asyncio.release_concurrency_slots", + wraps=release_concurrency_slots, ) as release_spy: await resource_heavy() @@ -262,11 +262,11 @@ async def resource_heavy(): assert not executed with mock.patch( - "prefect.concurrency.v1.asyncio._acquire_concurrency_slots", + "prefect.concurrency.v1._asyncio.acquire_concurrency_slots", wraps=lambda *args, **kwargs: None, ) as acquire_spy: with mock.patch( - "prefect.concurrency.v1.asyncio._release_concurrency_slots", + "prefect.concurrency.v1._asyncio.release_concurrency_slots", wraps=lambda *args, **kwargs: None, ) as release_spy: await resource_heavy() diff --git a/tests/concurrency/v1/test_concurrency_limit_acquisition_service.py b/tests/concurrency/v1/test_concurrency_limit_acquisition_service.py index bcadee6bdd301..8b85475345df6 100644 --- a/tests/concurrency/v1/test_concurrency_limit_acquisition_service.py +++ b/tests/concurrency/v1/test_concurrency_limit_acquisition_service.py @@ -68,7 +68,7 @@ async def test_retries_failed_call_respects_retry_after_header(mocked_client): limit_names = sorted(["api", "database"]) service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names)) - with mock.patch("prefect.concurrency.v1.asyncio.asyncio.sleep") as sleep: + with mock.patch("prefect.concurrency.v1._asyncio.asyncio.sleep") as sleep: future = service.send((task_run_id, None)) service.drain() returned_response = await asyncio.wrap_future(future) @@ -112,7 +112,7 @@ async def test_basic_exception_returns_exception(mocked_client): future = service.send((task_run_id, None)) await service.drain() - exception = await asyncio.wrap_future(future) + with pytest.raises(Exception) as info: + await asyncio.wrap_future(future) - assert isinstance(exception, Exception) - assert exception == exc + assert info.value == exc diff --git a/tests/concurrency/v1/test_concurrency_sync.py b/tests/concurrency/v1/test_concurrency_sync.py index d4be641589563..bf16f1bb0d8ba 100644 --- a/tests/concurrency/v1/test_concurrency_sync.py +++ b/tests/concurrency/v1/test_concurrency_sync.py @@ -6,9 +6,9 @@ from starlette import status from prefect import flow, task -from prefect.concurrency.v1.asyncio import ( - _acquire_concurrency_slots, - _release_concurrency_slots, +from prefect.concurrency.v1._asyncio import ( + acquire_concurrency_slots, + release_concurrency_slots, ) from prefect.concurrency.v1.sync import concurrency from prefect.events.clients import AssertingEventsClient @@ -28,12 +28,12 @@ def resource_heavy(): assert not executed with mock.patch( - "prefect.concurrency.v1.sync._acquire_concurrency_slots", - wraps=_acquire_concurrency_slots, + "prefect.concurrency.v1.sync.acquire_concurrency_slots", + wraps=acquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.v1.sync._release_concurrency_slots", - wraps=_release_concurrency_slots, + "prefect.concurrency.v1.sync.release_concurrency_slots", + wraps=release_concurrency_slots, ) as release_spy: resource_heavy() @@ -201,11 +201,11 @@ def resource_heavy(): assert not executed with mock.patch( - "prefect.concurrency.v1.sync._acquire_concurrency_slots", + "prefect.concurrency.v1.sync.acquire_concurrency_slots", wraps=lambda *args, **kwargs: None, ) as acquire_spy: with mock.patch( - "prefect.concurrency.v1.sync._release_concurrency_slots", + "prefect.concurrency.v1.sync.release_concurrency_slots", wraps=lambda *args, **kwargs: None, ) as release_spy: resource_heavy() diff --git a/tests/concurrency/v1/test_context.py b/tests/concurrency/v1/test_context.py index 50a013a4680d9..27734e984ff63 100644 --- a/tests/concurrency/v1/test_context.py +++ b/tests/concurrency/v1/test_context.py @@ -5,7 +5,7 @@ import pytest from prefect.client.orchestration import PrefectClient, get_client -from prefect.concurrency.v1.asyncio import concurrency as aconcurrency +from prefect.concurrency.v1._asyncio import concurrency as aconcurrency from prefect.concurrency.v1.context import ConcurrencyContext from prefect.concurrency.v1.sync import concurrency from prefect.server.schemas.core import ConcurrencyLimit diff --git a/tests/concurrency/v1/test_decrement_concurrency_slots.py b/tests/concurrency/v1/test_decrement_concurrency_slots.py index 697c12b081a22..0a214f5e4d044 100644 --- a/tests/concurrency/v1/test_decrement_concurrency_slots.py +++ b/tests/concurrency/v1/test_decrement_concurrency_slots.py @@ -4,7 +4,7 @@ from httpx import Response from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse -from prefect.concurrency.v1.asyncio import _release_concurrency_slots +from prefect.concurrency.v1._asyncio import release_concurrency_slots async def test_calls_release_client_method(): @@ -23,7 +23,7 @@ async def test_calls_release_client_method(): ) client_decrement_v1_concurrency_slots.return_value = response - await _release_concurrency_slots( + await release_concurrency_slots( names=["test-1", "test-2"], task_run_id=task_run_id, occupancy_seconds=1.0 ) client_decrement_v1_concurrency_slots.assert_called_once_with( @@ -49,7 +49,7 @@ async def test_returns_minimal_concurrency_limit(): ) client_decrement_v1_concurrency_slots.return_value = response - result = await _release_concurrency_slots( + result = await release_concurrency_slots( ["test-1", "test-2"], task_run_id, 1.0, diff --git a/tests/concurrency/v1/test_increment_concurrency_limits.py b/tests/concurrency/v1/test_increment_concurrency_limits.py index 1ee3f08943493..b856cad8df7a4 100644 --- a/tests/concurrency/v1/test_increment_concurrency_limits.py +++ b/tests/concurrency/v1/test_increment_concurrency_limits.py @@ -4,9 +4,7 @@ from httpx import Response from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse -from prefect.concurrency.asyncio import ( - _aacquire_concurrency_slots, -) +from prefect.concurrency._asyncio import aacquire_concurrency_slots async def test_calls_increment_client_method(): @@ -27,7 +25,7 @@ async def test_calls_increment_client_method(): ) increment_concurrency_slots.return_value = response - await _aacquire_concurrency_slots( + await aacquire_concurrency_slots( names=["test-1", "test-2"], slots=1, mode="concurrency" ) increment_concurrency_slots.assert_called_once_with( @@ -56,5 +54,5 @@ async def test_returns_minimal_concurrency_limit(): ) increment_concurrency_slots.return_value = response - result = await _aacquire_concurrency_slots(["test-1", "test-2"], 1) + result = await aacquire_concurrency_slots(["test-1", "test-2"], 1) assert result == limits diff --git a/tests/test_flow_engine.py b/tests/test_flow_engine.py index 9ceeae8241bd9..267d92ef16656 100644 --- a/tests/test_flow_engine.py +++ b/tests/test_flow_engine.py @@ -16,7 +16,7 @@ from prefect.client.schemas.filters import FlowFilter, FlowRunFilter from prefect.client.schemas.objects import StateType from prefect.client.schemas.sorting import FlowRunSort -from prefect.concurrency.asyncio import concurrency as aconcurrency +from prefect.concurrency._asyncio import concurrency as aconcurrency from prefect.concurrency.sync import concurrency from prefect.context import ( FlowRunContext, diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index 711b19f0ae792..58cd9182e648e 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -17,11 +17,11 @@ from prefect.cache_policies import FLOW_PARAMETERS, INPUTS, TASK_SOURCE from prefect.client.orchestration import PrefectClient, SyncPrefectClient from prefect.client.schemas.objects import StateType -from prefect.concurrency.asyncio import concurrency as aconcurrency +from prefect.concurrency._asyncio import concurrency as aconcurrency from prefect.concurrency.sync import concurrency -from prefect.concurrency.v1.asyncio import ( - _acquire_concurrency_slots, - _release_concurrency_slots, +from prefect.concurrency.v1._asyncio import ( + acquire_concurrency_slots, + release_concurrency_slots, ) from prefect.context import ( EngineContext, @@ -34,10 +34,7 @@ from prefect.logging import get_run_logger from prefect.results import ResultRecord, ResultStore from prefect.server.schemas.core import ConcurrencyLimitV2 -from prefect.settings import ( - PREFECT_TASK_DEFAULT_RETRIES, - temporary_settings, -) +from prefect.settings import PREFECT_TASK_DEFAULT_RETRIES, temporary_settings from prefect.states import Completed, Running, State from prefect.task_engine import ( AsyncTaskRunEngine, @@ -2512,12 +2509,12 @@ async def bar(): return 42 with mock.patch( - "prefect.concurrency.v1.asyncio._acquire_concurrency_slots", - wraps=_acquire_concurrency_slots, + "prefect.concurrency.v1._asyncio.acquire_concurrency_slots", + wraps=acquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.v1.asyncio._release_concurrency_slots", - wraps=_release_concurrency_slots, + "prefect.concurrency.v1._asyncio.release_concurrency_slots", + wraps=release_concurrency_slots, ) as release_spy: await bar() @@ -2540,12 +2537,12 @@ def bar(): return 42 with mock.patch( - "prefect.concurrency.v1.sync._acquire_concurrency_slots", - wraps=_acquire_concurrency_slots, + "prefect.concurrency.v1.sync.acquire_concurrency_slots", + wraps=acquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.v1.sync._release_concurrency_slots", - wraps=_release_concurrency_slots, + "prefect.concurrency.v1.sync.release_concurrency_slots", + wraps=release_concurrency_slots, ) as release_spy: bar() @@ -2571,12 +2568,12 @@ def bar(): return 42 with mock.patch( - "prefect.concurrency.v1.sync._acquire_concurrency_slots", - wraps=_acquire_concurrency_slots, + "prefect.concurrency.v1.sync.acquire_concurrency_slots", + wraps=acquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.v1.sync._release_concurrency_slots", - wraps=_release_concurrency_slots, + "prefect.concurrency.v1.sync.release_concurrency_slots", + wraps=release_concurrency_slots, ) as release_spy: with tags("limit-tag"): bar() @@ -2603,12 +2600,12 @@ async def bar(): return 42 with mock.patch( - "prefect.concurrency.v1.asyncio._acquire_concurrency_slots", - wraps=_acquire_concurrency_slots, + "prefect.concurrency.v1._asyncio.acquire_concurrency_slots", + wraps=acquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.v1.asyncio._release_concurrency_slots", - wraps=_release_concurrency_slots, + "prefect.concurrency.v1._asyncio.release_concurrency_slots", + wraps=release_concurrency_slots, ) as release_spy: with tags("limit-tag"): await bar() @@ -2628,12 +2625,12 @@ async def bar(): return 42 with mock.patch( - "prefect.concurrency.v1.asyncio._acquire_concurrency_slots", - wraps=_acquire_concurrency_slots, + "prefect.concurrency.v1._asyncio.acquire_concurrency_slots", + wraps=acquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.v1.asyncio._release_concurrency_slots", - wraps=_release_concurrency_slots, + "prefect.concurrency.v1._asyncio.release_concurrency_slots", + wraps=release_concurrency_slots, ) as release_spy: await bar() @@ -2646,12 +2643,12 @@ def bar(): return 42 with mock.patch( - "prefect.concurrency.v1.sync._acquire_concurrency_slots", - wraps=_acquire_concurrency_slots, + "prefect.concurrency.v1.sync.acquire_concurrency_slots", + wraps=acquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.v1.sync._release_concurrency_slots", - wraps=_release_concurrency_slots, + "prefect.concurrency.v1.sync.release_concurrency_slots", + wraps=release_concurrency_slots, ) as release_spy: bar() @@ -2668,8 +2665,8 @@ async def bar(): return 42 with mock.patch( - "prefect.concurrency.v1.asyncio._acquire_concurrency_slots", - wraps=_acquire_concurrency_slots, + "prefect.concurrency.v1._asyncio.acquire_concurrency_slots", + wraps=acquire_concurrency_slots, ) as acquire_spy: await bar()