Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create an AsyncTaskRunEngine and wire it up to run async tasks #14743

Merged
merged 3 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,20 +203,18 @@ def get_client(
except RuntimeError:
loop = None

if client_ctx := prefect.context.ClientContext.get():
if (
sync_client
and client_ctx.sync_client
and client_ctx._httpx_settings == httpx_settings
):
return client_ctx.sync_client
elif (
not sync_client
and client_ctx.async_client
and client_ctx._httpx_settings == httpx_settings
and loop in (client_ctx.async_client._loop, None)
):
return client_ctx.async_client
if sync_client:
if client_ctx := prefect.context.SyncClientContext.get():
if client_ctx.client and client_ctx._httpx_settings == httpx_settings:
return client_ctx.client
else:
if client_ctx := prefect.context.AsyncClientContext.get():
if (
client_ctx.client
and client_ctx._httpx_settings == httpx_settings
and loop in (client_ctx.client._loop, None)
):
return client_ctx.client

api = PREFECT_API_URL.value()

Expand Down
8 changes: 4 additions & 4 deletions src/prefect/client/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ def get_or_create_client(
if client is not None:
return client, True
from prefect._internal.concurrency.event_loop import get_running_loop
from prefect.context import ClientContext, FlowRunContext, TaskRunContext
from prefect.context import AsyncClientContext, FlowRunContext, TaskRunContext

client_context = ClientContext.get()
async_client_context = AsyncClientContext.get()
flow_run_context = FlowRunContext.get()
task_run_context = TaskRunContext.get()

if client_context and client_context.async_client._loop == get_running_loop():
return client_context.async_client, True
if async_client_context and async_client_context.client._loop == get_running_loop():
return async_client_context.client, True
elif (
flow_run_context
and getattr(flow_run_context.client, "_loop", None) == get_running_loop()
Expand Down
96 changes: 75 additions & 21 deletions src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
import sys
import warnings
import weakref
from contextlib import ExitStack, contextmanager
from contextlib import ExitStack, asynccontextmanager, contextmanager
from contextvars import ContextVar, Token
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
Generator,
Mapping,
Expand Down Expand Up @@ -177,64 +178,117 @@ def serialize(self) -> Dict[str, Any]:
return self.model_dump(exclude_unset=True)


class ClientContext(ContextModel):
class SyncClientContext(ContextModel):
"""
A context for managing the Prefect client instances.
A context for managing the sync Prefect client instances.

Clients were formerly tracked on the TaskRunContext and FlowRunContext, but
having two separate places and the addition of both sync and async clients
made it difficult to manage. This context is intended to be the single
source for clients.
source for sync clients.

The client creates both sync and async clients, which can either be read
directly from the context object OR loaded with get_client, inject_client,
or other Prefect utilities.
The client creates a sync client, which can either be read directly from
the context object OR loaded with get_client, inject_client, or other
Prefect utilities.

with ClientContext.get_or_create() as ctx:
with SyncClientContext.get_or_create() as ctx:
c1 = get_client(sync_client=True)
c2 = get_client(sync_client=True)
assert c1 is c2
assert c1 is ctx.sync_client
assert c1 is ctx.client
"""

__var__ = ContextVar("clients")
sync_client: SyncPrefectClient
async_client: PrefectClient
__var__ = ContextVar("sync-client-context")
client: SyncPrefectClient
_httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None)
_context_stack: int = PrivateAttr(0)

def __init__(self, httpx_settings: Optional[dict[str, Any]] = None):
super().__init__(
sync_client=get_client(sync_client=True, httpx_settings=httpx_settings),
async_client=get_client(sync_client=False, httpx_settings=httpx_settings),
client=get_client(sync_client=True, httpx_settings=httpx_settings),
)
self._httpx_settings = httpx_settings
self._context_stack = 0

def __enter__(self):
self._context_stack += 1
if self._context_stack == 1:
self.sync_client.__enter__()
run_coro_as_sync(self.async_client.__aenter__())
self.client.__enter__()
return super().__enter__()
else:
return self

def __exit__(self, *exc_info):
self._context_stack -= 1
if self._context_stack == 0:
self.sync_client.__exit__(*exc_info)
run_coro_as_sync(self.async_client.__aexit__(*exc_info))
self.client.__exit__(*exc_info)
return super().__exit__(*exc_info)

@classmethod
@contextmanager
def get_or_create(cls) -> Generator["ClientContext", None, None]:
ctx = ClientContext.get()
def get_or_create(cls) -> Generator["SyncClientContext", None, None]:
ctx = SyncClientContext.get()
if ctx:
yield ctx
else:
with ClientContext() as ctx:
with SyncClientContext() as ctx:
yield ctx


class AsyncClientContext(ContextModel):
bunchesofdonald marked this conversation as resolved.
Show resolved Hide resolved
"""
A context for managing the async Prefect client instances.

Clients were formerly tracked on the TaskRunContext and FlowRunContext, but
having two separate places and the addition of both sync and async clients
made it difficult to manage. This context is intended to be the single
source for async clients.

The client creates an async client, which can either be read directly from
the context object OR loaded with get_client, inject_client, or other
Prefect utilities.

with AsyncClientContext.get_or_create() as ctx:
c1 = get_client(sync_client=False)
c2 = get_client(sync_client=False)
assert c1 is c2
assert c1 is ctx.client
"""

__var__ = ContextVar("async-client-context")
client: PrefectClient
_httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None)
_context_stack: int = PrivateAttr(0)

def __init__(self, httpx_settings: Optional[dict[str, Any]] = None):
super().__init__(
client=get_client(sync_client=False, httpx_settings=httpx_settings),
)
self._httpx_settings = httpx_settings
self._context_stack = 0

async def __aenter__(self):
self._context_stack += 1
if self._context_stack == 1:
await self.client.__aenter__()
return super().__enter__()
else:
return self

async def __aexit__(self, *exc_info):
self._context_stack -= 1
if self._context_stack == 0:
await self.client.__aexit__(*exc_info)
return super().__exit__(*exc_info)

@classmethod
@asynccontextmanager
async def get_or_create(cls) -> AsyncGenerator[Self, None]:
ctx = cls.get()
if ctx:
yield ctx
else:
with cls() as ctx:
yield ctx


Expand Down
6 changes: 3 additions & 3 deletions src/prefect/flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from prefect.client.schemas import FlowRun, TaskRun
from prefect.client.schemas.filters import FlowRunFilter
from prefect.client.schemas.sorting import FlowRunSort
from prefect.context import ClientContext, FlowRunContext, TagsContext
from prefect.context import FlowRunContext, SyncClientContext, TagsContext
from prefect.exceptions import (
Abort,
Pause,
Expand Down Expand Up @@ -529,8 +529,8 @@ def initialize_run(self):
"""
Enters a client context and creates a flow run if needed.
"""
with ClientContext.get_or_create() as client_ctx:
self._client = client_ctx.sync_client
with SyncClientContext.get_or_create() as client_ctx:
self._client = client_ctx.client
self._is_started = True

if not self.flow_run:
Expand Down
Loading
Loading