Skip to content

Commit

Permalink
Fix worker capacity test
Browse files Browse the repository at this point in the history
  • Loading branch information
bunchesofdonald committed Jul 24, 2024
1 parent 103f8e1 commit 8c2c3e1
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 119 deletions.
26 changes: 12 additions & 14 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,20 +203,18 @@ def get_client(
except RuntimeError:
loop = None

if client_ctx := prefect.context.ClientContext.get():
if (
sync_client
and client_ctx.sync_client
and client_ctx._httpx_settings == httpx_settings
):
return client_ctx.sync_client
elif (
not sync_client
and client_ctx.async_client
and client_ctx._httpx_settings == httpx_settings
and loop in (client_ctx.async_client._loop, None)
):
return client_ctx.async_client
if sync_client:
if client_ctx := prefect.context.ClientContext.get():
if client_ctx.client and client_ctx._httpx_settings == httpx_settings:
return client_ctx.client
else:
if client_ctx := prefect.context.AsyncClientContext.get():
if (
client_ctx.client
and client_ctx._httpx_settings == httpx_settings
and loop in (client_ctx.client._loop, None)
):
return client_ctx.client

api = PREFECT_API_URL.value()

Expand Down
12 changes: 8 additions & 4 deletions src/prefect/client/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,18 @@ def get_or_create_client(
if client is not None:
return client, True
from prefect._internal.concurrency.event_loop import get_running_loop
from prefect.context import ClientContext, FlowRunContext, TaskRunContext
from prefect.context import (
AsyncClientContext,
FlowRunContext,
TaskRunContext,
)

client_context = ClientContext.get()
async_client_context = AsyncClientContext.get()
flow_run_context = FlowRunContext.get()
task_run_context = TaskRunContext.get()

if client_context and client_context.async_client._loop == get_running_loop():
return client_context.async_client, True
if async_client_context and async_client_context.client._loop == get_running_loop():
return async_client_context.client, True
elif (
flow_run_context
and getattr(flow_run_context.client, "_loop", None) == get_running_loop()
Expand Down
93 changes: 56 additions & 37 deletions src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,88 +180,107 @@ def serialize(self) -> Dict[str, Any]:

class ClientContext(ContextModel):
"""
A context for managing the Prefect client instances.
A context for managing the sync Prefect client instance.
Clients were formerly tracked on the TaskRunContext and FlowRunContext, but
having two separate places and the addition of both sync and async clients
made it difficult to manage. This context is intended to be the single
source for clients.
The client creates both sync and async clients, which can either be read
directly from the context object OR loaded with get_client, inject_client,
or other Prefect utilities.
source for sync clients.
with ClientContext.get_or_create() as ctx:
c1 = get_client(sync_client=True)
c2 = get_client(sync_client=True)
assert c1 is c2
assert c1 is ctx.sync_client
assert c1 is ctx.client
"""

__var__ = ContextVar("clients")
sync_client: SyncPrefectClient
async_client: PrefectClient
__var__ = ContextVar("sync-client-context")
client: SyncPrefectClient
_httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None)
_context_stack: int = PrivateAttr(0)

def __init__(self, httpx_settings: Optional[dict[str, Any]] = None):
super().__init__(
sync_client=get_client(sync_client=True, httpx_settings=httpx_settings),
async_client=get_client(sync_client=False, httpx_settings=httpx_settings),
client=get_client(sync_client=True, httpx_settings=httpx_settings),
)
self._httpx_settings = httpx_settings
self._context_stack = 0

async def __aenter__(self):
def __enter__(self):
self._context_stack += 1
if self._context_stack == 1:
self.sync_client.__enter__()
await self.async_client.__aenter__()
self.client.__enter__()
return super().__enter__()
else:
return self

async def __aexit__(self, *exc_info):
def __exit__(self, *exc_info):
self._context_stack -= 1
if self._context_stack == 0:
self.sync_client.__exit__(*exc_info)
await self.async_client.__aexit__(*exc_info)
self.client.__exit__(*exc_info)
return super().__exit__(*exc_info)

def __enter__(self):
@classmethod
@contextmanager
def get_or_create(cls) -> Generator["ClientContext", None, None]:
ctx = ClientContext.get()
if ctx:
yield ctx
else:
with ClientContext() as ctx:
yield ctx


class AsyncClientContext(ContextModel):
"""
A context for managing the async Prefect client instance.
Clients were formerly tracked on the TaskRunContext and FlowRunContext, but
having two separate places and the addition of both sync and async clients
made it difficult to manage. This context is intended to be the single
source for async clients.
async with AsyncClientContext.get_or_create() as ctx:
c1 = get_client()
c2 = get_client()
assert c1 is c2
assert c1 is ctx.client
"""

__var__ = ContextVar("async-client-context")
client: PrefectClient
_httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None)
_context_stack: int = PrivateAttr(0)

def __init__(self, httpx_settings: Optional[dict[str, Any]] = None):
super().__init__(
client=get_client(sync_client=False, httpx_settings=httpx_settings),
)
self._httpx_settings = httpx_settings
self._context_stack = 0

async def __aenter__(self):
self._context_stack += 1
if self._context_stack == 1:
self.sync_client.__enter__()
run_coro_as_sync(self.async_client.__aenter__())
await self.client.__aenter__()
return super().__enter__()
else:
return self

def __exit__(self, *exc_info):
async def __aexit__(self, *exc_info):
self._context_stack -= 1
if self._context_stack == 0:
self.sync_client.__exit__(*exc_info)
run_coro_as_sync(self.async_client.__aexit__(*exc_info))
await self.client.__aexit__(*exc_info)
return super().__exit__(*exc_info)

@classmethod
@asynccontextmanager
async def async_get_or_create(cls) -> AsyncGenerator["ClientContext", None]:
ctx = ClientContext.get()
if ctx:
yield ctx
else:
async with ClientContext() as ctx:
yield ctx

@classmethod
@contextmanager
def get_or_create(cls) -> Generator["ClientContext", None, None]:
ctx = ClientContext.get()
async def get_or_create(cls) -> AsyncGenerator["AsyncClientContext", None]:
ctx = AsyncClientContext.get()
if ctx:
yield ctx
else:
with ClientContext() as ctx:
async with AsyncClientContext() as ctx:
yield ctx


Expand Down
2 changes: 1 addition & 1 deletion src/prefect/flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def initialize_run(self):
Enters a client context and creates a flow run if needed.
"""
with ClientContext.get_or_create() as client_ctx:
self._client = client_ctx.sync_client
self._client = client_ctx.client
self._is_started = True

if not self.flow_run:
Expand Down
12 changes: 8 additions & 4 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from prefect.client.schemas import TaskRun
from prefect.client.schemas.objects import State, TaskRunInput
from prefect.context import (
AsyncClientContext,
ClientContext,
FlowRunContext,
TaskRunContext,
Expand Down Expand Up @@ -568,7 +569,7 @@ def initialize_run(

with hydrated_context(self.context):
with ClientContext.get_or_create() as client_ctx:
self._client = client_ctx.sync_client
self._client = client_ctx.client
self._is_started = True
try:
if not self.task_run:
Expand Down Expand Up @@ -821,6 +822,9 @@ async def call_hooks(self, state: Optional[State] = None):
else:
self.logger.info(f"Hook {hook_name!r} finished running successfully")

async def sleep(self, interval: float):
await anyio.sleep(interval)

async def begin_run(self):
try:
self._resolve_parameters()
Expand Down Expand Up @@ -859,7 +863,7 @@ async def begin_run(self):
interval = clamped_poisson_interval(
average_interval=backoff_count, clamping_factor=0.3
)
await anyio.sleep(interval)
await self.sleep(interval)
state = await self.set_state(new_state)

async def set_state(self, state: State, force: bool = False) -> State:
Expand Down Expand Up @@ -1091,8 +1095,8 @@ async def initialize_run(
"""

with hydrated_context(self.context):
async with ClientContext.async_get_or_create() as client_ctx:
self._client = client_ctx.async_client
async with AsyncClientContext.get_or_create() as client_ctx:
self._client = client_ctx.client
self._is_started = True
try:
if not self.task_run:
Expand Down
6 changes: 4 additions & 2 deletions tests/test_flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,8 +1125,10 @@ async def flow_resumer():
assert schema is not None

async def test_paused_task_polling(self, monkeypatch, prefect_client):
sleeper = MagicMock(side_effect=[None, None, None, None, None])
monkeypatch.setattr("prefect.task_engine.time.sleep", sleeper)
from prefect.testing.utilities import AsyncMock

sleeper = AsyncMock(side_effect=[None, None, None, None, None])
monkeypatch.setattr("prefect.task_engine.AsyncTaskRunEngine.sleep", sleeper)

@task
async def doesnt_pause():
Expand Down
90 changes: 33 additions & 57 deletions tests/test_task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,80 +769,56 @@ async def mock_iter():
async def test_tasks_execute_when_capacity_frees_up(
self, mock_subscription, prefect_client
):
event = asyncio.Event()
execution_order = []

@task
async def slow_task():
await asyncio.sleep(1)
if event.is_set():
raise ValueError("Something went wrong! This event should not be set.")
event.set()
async def slow_task(task_id: str):
execution_order.append(f"{task_id} start")
await asyncio.sleep(0.1) # Simulating some work
execution_order.append(f"{task_id} end")

task_worker = TaskWorker(slow_task, limit=1)

task_run_future_1 = slow_task.apply_async()
task_run_future_1 = slow_task.apply_async(("task1",))
task_run_1 = await prefect_client.read_task_run(task_run_future_1.task_run_id)
task_run_future_2 = slow_task.apply_async()
task_run_future_2 = slow_task.apply_async(("task2",))
task_run_2 = await prefect_client.read_task_run(task_run_future_2.task_run_id)

async def mock_iter():
yield task_run_1
yield task_run_2
# sleep for a second to ensure that task execution starts
await asyncio.sleep(1)
while len(execution_order) < 4:
await asyncio.sleep(0.1)

mock_subscription.return_value = mock_iter()

server_task = asyncio.create_task(task_worker.start())
await event.wait()
updated_task_run_1 = await prefect_client.read_task_run(task_run_1.id)
updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id)

assert updated_task_run_1.state.is_completed()
assert not updated_task_run_2.state.is_completed()

# clear the event to allow the second task to complete
event.clear()

await event.wait()
updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id)

assert updated_task_run_2.state.is_completed()

server_task.cancel()
await server_task

async def test_execute_task_run_respects_limit(self, prefect_client):
@task
def slow_task():
import time

time.sleep(1)

task_worker = TaskWorker(slow_task, limit=1)

task_run_future_1 = slow_task.apply_async()
task_run_1 = await prefect_client.read_task_run(task_run_future_1.task_run_id)
task_run_future_2 = slow_task.apply_async()
task_run_2 = await prefect_client.read_task_run(task_run_future_2.task_run_id)

try:
with anyio.move_on_after(1):
# start task worker first to avoid race condition between two execute_task_run calls
async with task_worker:
await asyncio.gather(
task_worker.execute_task_run(task_run_1),
task_worker.execute_task_run(task_run_2),
)
except asyncio.exceptions.CancelledError:
# We want to cancel the second task run, so this is expected
pass

updated_task_run_1 = await prefect_client.read_task_run(task_run_1.id)
updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id)

assert updated_task_run_1.state.is_completed()
assert updated_task_run_2.state.is_scheduled()
# Wait for both tasks to complete
await asyncio.sleep(2)

# Verify the execution order
assert execution_order == [
"task1 start",
"task1 end",
"task2 start",
"task2 end",
], "Tasks should execute sequentially"

# Verify the states of both tasks
updated_task_run_1 = await prefect_client.read_task_run(task_run_1.id)
updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id)

assert updated_task_run_1.state.is_completed()
assert updated_task_run_2.state.is_completed()

finally:
server_task.cancel()
try:
await server_task
except asyncio.CancelledError:
pass

async def test_serve_respects_limit(self, prefect_client, mock_subscription):
@task
Expand Down

0 comments on commit 8c2c3e1

Please sign in to comment.