Skip to content

Commit

Permalink
Make flow/task engine release slots after a timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
bunchesofdonald committed Jul 29, 2024
1 parent 85bea69 commit 70e76c4
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 7 deletions.
37 changes: 37 additions & 0 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3025,6 +3025,19 @@ async def increment_concurrency_slots(
async def release_concurrency_slots(
self, names: List[str], slots: int, occupancy_seconds: float
) -> httpx.Response:
"""
Release concurrency slots for the specified limits.
Args:
names (List[str]): A list of limit names for which to release slots.
slots (int): The number of concurrency slots to release.
occupancy_seconds (float): The duration in seconds that the slots
were occupied.
Returns:
httpx.Response: The HTTP response from the server.
"""

return await self._client.post(
"/v2/concurrency_limits/decrement",
json={
Expand Down Expand Up @@ -4068,3 +4081,27 @@ def create_artifact(
)

return Artifact.model_validate(response.json())

def release_concurrency_slots(
self, names: List[str], slots: int, occupancy_seconds: float
) -> httpx.Response:
"""
Release concurrency slots for the specified limits.
Args:
names (List[str]): A list of limit names for which to release slots.
slots (int): The number of concurrency slots to release.
occupancy_seconds (float): The duration in seconds that the slots
were occupied.
Returns:
httpx.Response: The HTTP response from the server.
"""
return self._client.post(
"/v2/concurrency_limits/decrement",
json={
"names": names,
"slots": slots,
"occupancy_seconds": occupancy_seconds,
},
)
21 changes: 16 additions & 5 deletions src/prefect/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List, Literal, Optional, Union, cast

import anyio
import httpx
import pendulum

Expand All @@ -14,6 +15,7 @@
from prefect.client.orchestration import get_client
from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse

from .context import ConcurrencyContext
from .events import (
_emit_concurrency_acquisition_events,
_emit_concurrency_release_events,
Expand Down Expand Up @@ -137,11 +139,20 @@ async def _acquire_concurrency_slots(
async def _release_concurrency_slots(
names: List[str], slots: int, occupancy_seconds: float
) -> List[MinimalConcurrencyLimitResponse]:
async with get_client() as client:
response = await client.release_concurrency_slots(
names=names, slots=slots, occupancy_seconds=occupancy_seconds
)
return _response_to_minimal_concurrency_limit_response(response)
try:
async with get_client() as client:
response = await client.release_concurrency_slots(
names=names, slots=slots, occupancy_seconds=occupancy_seconds
)
return _response_to_minimal_concurrency_limit_response(response)
except anyio.get_cancelled_exc_class() as exc:
# The task was cancelled before it could release the slots. Add the
# slots to the cleanup list so they can be released when the
# concurrency context is exited.
if ctx := ConcurrencyContext.get():
ctx.cleanup_slots.append((names, slots, occupancy_seconds))

raise exc


def _response_to_minimal_concurrency_limit_response(
Expand Down
24 changes: 24 additions & 0 deletions src/prefect/concurrency/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from contextvars import ContextVar
from typing import List, Tuple

from prefect.client.orchestration import get_client
from prefect.context import ContextModel, Field


class ConcurrencyContext(ContextModel):
__var__: ContextVar = ContextVar("concurrency")

# Track the slots that have been acquired but were not able to be released
# due to cancellation or some other error. These slots are released when
# the context manager exits.
cleanup_slots: List[Tuple[List[str], int, float]] = Field(default_factory=list)

def __exit__(self, *exc_info):
if self.cleanup_slots:
with get_client(sync_client=True) as client:
for names, occupy, occupancy_seconds in self.cleanup_slots:
client.release_concurrency_slots(
names=names, slots=occupy, occupancy_seconds=occupancy_seconds
)

return super().__exit__(*exc_info)
2 changes: 1 addition & 1 deletion src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ class EngineContext(RunContext):
default_factory=weakref.WeakValueDictionary
)

# Events worker to emit events to Prefect Cloud
# Events worker to emit events
events: Optional[EventsWorker] = None

__var__: ContextVar = ContextVar("flow_run")
Expand Down
3 changes: 3 additions & 0 deletions src/prefect/flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from prefect.client.schemas import FlowRun, TaskRun
from prefect.client.schemas.filters import FlowRunFilter
from prefect.client.schemas.sorting import FlowRunSort
from prefect.concurrency.context import ConcurrencyContext
from prefect.context import FlowRunContext, SyncClientContext, TagsContext
from prefect.exceptions import (
Abort,
Expand Down Expand Up @@ -505,6 +506,8 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
task_runner=task_runner,
)
)
stack.enter_context(ConcurrencyContext())

# set the logger to the flow run logger
self.logger = flow_run_logger(flow_run=self.flow_run, flow=self.flow)

Expand Down
3 changes: 3 additions & 0 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from prefect.client.schemas import TaskRun
from prefect.client.schemas.objects import State, TaskRunInput
from prefect.concurrency.asyncio import concurrency as aconcurrency
from prefect.concurrency.context import ConcurrencyContext
from prefect.concurrency.sync import concurrency
from prefect.context import (
AsyncClientContext,
Expand Down Expand Up @@ -592,6 +593,7 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
client=client,
)
)
stack.enter_context(ConcurrencyContext())

self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore

Expand Down Expand Up @@ -1137,6 +1139,7 @@ async def setup_run_context(self, client: Optional[PrefectClient] = None):
client=client,
)
)
stack.enter_context(ConcurrencyContext())

self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore

Expand Down
62 changes: 62 additions & 0 deletions tests/concurrency/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import asyncio
import time

import pytest

from prefect.client.orchestration import PrefectClient
from prefect.concurrency.asyncio import concurrency as aconcurrency
from prefect.concurrency.context import ConcurrencyContext
from prefect.concurrency.sync import concurrency
from prefect.server.schemas.core import ConcurrencyLimitV2
from prefect.utilities.asyncutils import run_coro_as_sync
from prefect.utilities.timeout import timeout, timeout_async


async def test_concurrency_context_releases_slots_async(
concurrency_limit: ConcurrencyLimitV2, prefect_client: PrefectClient
):
async def expensive_task():
async with aconcurrency(concurrency_limit.name):
response = await prefect_client.read_global_concurrency_limit_by_name(
concurrency_limit.name
)
assert response.active_slots == 1

# Occupy the slot for longer than the timeout
await asyncio.sleep(1)

with pytest.raises(TimeoutError):
with timeout_async(seconds=0.5):
with ConcurrencyContext():
await expensive_task()

response = await prefect_client.read_global_concurrency_limit_by_name(
concurrency_limit.name
)
assert response.active_slots == 0


async def test_concurrency_context_releases_slots_sync(
concurrency_limit: ConcurrencyLimitV2, prefect_client: PrefectClient
):
def expensive_task():
with concurrency(concurrency_limit.name):
response = run_coro_as_sync(
prefect_client.read_global_concurrency_limit_by_name(
concurrency_limit.name
)
)
assert response and response.active_slots == 1

# Occupy the slot for longer than the timeout
time.sleep(1)

with pytest.raises(TimeoutError):
with timeout(seconds=0.5):
with ConcurrencyContext():
expensive_task()

response = await prefect_client.read_global_concurrency_limit_by_name(
concurrency_limit.name
)
assert response.active_slots == 0
43 changes: 42 additions & 1 deletion tests/test_flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
from prefect.client.schemas.filters import FlowFilter, FlowRunFilter
from prefect.client.schemas.objects import StateType
from prefect.client.schemas.sorting import FlowRunSort
from prefect.context import FlowRunContext, TaskRunContext, get_run_context
from prefect.concurrency.asyncio import concurrency as aconcurrency
from prefect.concurrency.sync import concurrency
from prefect.context import (
FlowRunContext,
TaskRunContext,
get_run_context,
)
from prefect.exceptions import (
CrashedRun,
FlowPauseTimeout,
Expand All @@ -36,6 +42,7 @@
from prefect.input.actions import read_flow_run_input
from prefect.input.run_input import RunInput
from prefect.logging import get_run_logger
from prefect.server.schemas.core import ConcurrencyLimitV2
from prefect.server.schemas.core import FlowRun as ServerFlowRun
from prefect.testing.utilities import AsyncMock
from prefect.utilities.callables import get_call_parameters
Expand Down Expand Up @@ -1764,3 +1771,37 @@ async def test_load_flow_from_script_with_module_level_sync_compatible_call(
assert flow_run.id == api_flow_run.id

assert await flow() == "bar"


class TestConcurrencyRelease:
async def test_timeout_concurrency_slot_released_sync(
self, concurrency_limit_v2: ConcurrencyLimitV2, prefect_client: PrefectClient
):
@flow(timeout_seconds=0.5)
def expensive_flow():
with concurrency(concurrency_limit_v2.name):
time.sleep(1)

with pytest.raises(TimeoutError):
expensive_flow()

response = await prefect_client.read_global_concurrency_limit_by_name(
concurrency_limit_v2.name
)
assert response.active_slots == 0

async def test_timeout_concurrency_slot_released_async(
self, concurrency_limit_v2: ConcurrencyLimitV2, prefect_client: PrefectClient
):
@flow(timeout_seconds=0.5)
async def expensive_flow():
async with aconcurrency(concurrency_limit_v2.name):
await asyncio.sleep(1)

with pytest.raises(TimeoutError):
await expensive_flow()

response = await prefect_client.read_global_concurrency_limit_by_name(
concurrency_limit_v2.name
)
assert response.active_slots == 0
35 changes: 35 additions & 0 deletions tests/test_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
_acquire_concurrency_slots,
_release_concurrency_slots,
)
from prefect.concurrency.asyncio import concurrency as aconcurrency
from prefect.concurrency.sync import concurrency
from prefect.context import (
EngineContext,
FlowRunContext,
Expand All @@ -32,6 +34,7 @@
from prefect.filesystems import LocalFileSystem
from prefect.logging import get_run_logger
from prefect.results import PersistedResult, ResultFactory, UnpersistedResult
from prefect.server.schemas.core import ConcurrencyLimitV2
from prefect.settings import (
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION,
PREFECT_TASK_DEFAULT_RETRIES,
Expand Down Expand Up @@ -1619,6 +1622,38 @@ def sync_task():
with pytest.raises(TimeoutError, match=".*timed out after 0.1 second(s)*"):
run_task_sync(sync_task)

async def test_timeout_concurrency_slot_released_sync(
self, concurrency_limit_v2: ConcurrencyLimitV2, prefect_client: PrefectClient
):
@task(timeout_seconds=0.5)
def expensive_task():
with concurrency(concurrency_limit_v2.name):
time.sleep(1)

with pytest.raises(TimeoutError):
expensive_task()

response = await prefect_client.read_global_concurrency_limit_by_name(
concurrency_limit_v2.name
)
assert response.active_slots == 0

async def test_timeout_concurrency_slot_released_async(
self, concurrency_limit_v2: ConcurrencyLimitV2, prefect_client: PrefectClient
):
@task(timeout_seconds=0.5)
async def expensive_task():
async with aconcurrency(concurrency_limit_v2.name):
await asyncio.sleep(1)

with pytest.raises(TimeoutError):
await expensive_task()

response = await prefect_client.read_global_concurrency_limit_by_name(
concurrency_limit_v2.name
)
assert response.active_slots == 0


class TestPersistence:
async def test_task_can_return_persisted_result(self, prefect_client):
Expand Down

0 comments on commit 70e76c4

Please sign in to comment.