Skip to content

Commit

Permalink
Create an async copy of TaskRunEngine and use it for async tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
bunchesofdonald committed Jul 19, 2024
1 parent 96140de commit c5f4a91
Show file tree
Hide file tree
Showing 5 changed files with 748 additions and 168 deletions.
29 changes: 28 additions & 1 deletion src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
import sys
import warnings
import weakref
from contextlib import ExitStack, contextmanager
from contextlib import ExitStack, asynccontextmanager, contextmanager
from contextvars import ContextVar, Token
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
Generator,
Mapping,
Expand Down Expand Up @@ -211,6 +212,22 @@ def __init__(self, httpx_settings: Optional[dict[str, Any]] = None):
self._httpx_settings = httpx_settings
self._context_stack = 0

async def __aenter__(self):
self._context_stack += 1
if self._context_stack == 1:
self.sync_client.__enter__()
await self.async_client.__aenter__()
return super().__enter__()
else:
return self

async def __aexit__(self, *exc_info):
self._context_stack -= 1
if self._context_stack == 0:
self.sync_client.__exit__(*exc_info)
await self.async_client.__aexit__(*exc_info)
return super().__exit__(*exc_info)

def __enter__(self):
self._context_stack += 1
if self._context_stack == 1:
Expand All @@ -227,6 +244,16 @@ def __exit__(self, *exc_info):
run_coro_as_sync(self.async_client.__aexit__(*exc_info))
return super().__exit__(*exc_info)

@classmethod
@asynccontextmanager
async def async_get_or_create(cls) -> AsyncGenerator["ClientContext", None]:
ctx = ClientContext.get()
if ctx:
yield ctx
else:
async with ClientContext() as ctx:
yield ctx

@classmethod
@contextmanager
def get_or_create(cls) -> Generator["ClientContext", None, None]:
Expand Down
Loading

0 comments on commit c5f4a91

Please sign in to comment.