Skip to content

Commit

Permalink
[typing] prefect.utilities (#16298)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjpieters authored Dec 10, 2024
1 parent 1e1bd57 commit e569cc5
Show file tree
Hide file tree
Showing 44 changed files with 1,544 additions and 1,203 deletions.
93 changes: 46 additions & 47 deletions src/prefect/_internal/concurrency/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,46 @@
import asyncio
import concurrent.futures
import contextlib
from typing import (
Any,
Awaitable,
Callable,
ContextManager,
Iterable,
Optional,
TypeVar,
Union,
)
from collections.abc import Awaitable, Iterable
from contextlib import AbstractContextManager
from typing import Any, Callable, Optional, Union, cast

from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, TypeAlias, TypeVar

from prefect._internal.concurrency.threads import (
WorkerThread,
get_global_loop,
in_global_loop,
)
from prefect._internal.concurrency.waiters import (
AsyncWaiter,
Call,
SyncWaiter,
)
from prefect._internal.concurrency.waiters import AsyncWaiter, Call, SyncWaiter

P = ParamSpec("P")
T = TypeVar("T")
T = TypeVar("T", infer_variance=True)
Future = Union[concurrent.futures.Future[T], asyncio.Future[T]]

_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[T, Awaitable[T]]]

def create_call(__fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Call[T]:

def create_call(
__fn: _SyncOrAsyncCallable[P, T], *args: P.args, **kwargs: P.kwargs
) -> Call[T]:
return Call[T].new(__fn, *args, **kwargs)


def _cast_to_call(call_like: Union[Callable[[], T], Call[T]]) -> Call[T]:
def cast_to_call(
call_like: Union["_SyncOrAsyncCallable[[], T]", Call[T]],
) -> Call[T]:
if isinstance(call_like, Call):
return call_like
return cast(Call[T], call_like)
else:
return create_call(call_like)


class _base(abc.ABC):
@abc.abstractstaticmethod
@staticmethod
@abc.abstractmethod
def wait_for_call_in_loop_thread(
__call: Union[Callable[[], T], Call[T]], # type: ignore[reportGeneralTypeIssues]
__call: Union["_SyncOrAsyncCallable[[], Any]", Call[T]],
timeout: Optional[float] = None,
done_callbacks: Optional[Iterable[Call[Any]]] = None,
) -> T:
Expand All @@ -60,9 +56,10 @@ def wait_for_call_in_loop_thread(
"""
raise NotImplementedError()

@abc.abstractstaticmethod
@staticmethod
@abc.abstractmethod
def wait_for_call_in_new_thread(
__call: Union[Callable[[], T], Call[T]], # type: ignore[reportGeneralTypeIssues]
__call: Union["_SyncOrAsyncCallable[[], T]", Call[T]],
timeout: Optional[float] = None,
done_callbacks: Optional[Iterable[Call[Any]]] = None,
) -> T:
Expand All @@ -75,30 +72,31 @@ def wait_for_call_in_new_thread(

@staticmethod
def call_soon_in_new_thread(
__call: Union[Callable[[], T], Call[T]], timeout: Optional[float] = None
__call: Union["_SyncOrAsyncCallable[[], T]", Call[T]],
timeout: Optional[float] = None,
) -> Call[T]:
"""
Schedule a call for execution in a new worker thread.
Returns the submitted call.
"""
call = _cast_to_call(__call)
call = cast_to_call(__call)
runner = WorkerThread(run_once=True)
call.set_timeout(timeout)
runner.submit(call)
return call

@staticmethod
def call_soon_in_loop_thread(
__call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]],
__call: Union["_SyncOrAsyncCallable[[], T]", Call[T]],
timeout: Optional[float] = None,
) -> Call[T]:
"""
Schedule a call for execution in the global event loop thread.
Returns the submitted call.
"""
call = _cast_to_call(__call)
call = cast_to_call(__call)
runner = get_global_loop()
call.set_timeout(timeout)
runner.submit(call)
Expand All @@ -117,7 +115,7 @@ def call_in_new_thread(

@staticmethod
def call_in_loop_thread(
__call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]],
__call: Union[Callable[[], Awaitable[T]], Call[T]],
timeout: Optional[float] = None,
) -> T:
"""
Expand All @@ -131,12 +129,12 @@ def call_in_loop_thread(
class from_async(_base):
@staticmethod
async def wait_for_call_in_loop_thread(
__call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]],
__call: Union[Callable[[], Awaitable[T]], Call[T]],
timeout: Optional[float] = None,
done_callbacks: Optional[Iterable[Call[Any]]] = None,
contexts: Optional[Iterable[ContextManager[Any]]] = None,
) -> Awaitable[T]:
call = _cast_to_call(__call)
contexts: Optional[Iterable[AbstractContextManager[Any]]] = None,
) -> T:
call = cast_to_call(__call)
waiter = AsyncWaiter(call)
for callback in done_callbacks or []:
waiter.add_done_callback(callback)
Expand All @@ -153,7 +151,7 @@ async def wait_for_call_in_new_thread(
timeout: Optional[float] = None,
done_callbacks: Optional[Iterable[Call[Any]]] = None,
) -> T:
call = _cast_to_call(__call)
call = cast_to_call(__call)
waiter = AsyncWaiter(call=call)
for callback in done_callbacks or []:
waiter.add_done_callback(callback)
Expand All @@ -170,7 +168,7 @@ def call_in_new_thread(

@staticmethod
def call_in_loop_thread(
__call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]],
__call: Union[Callable[[], Awaitable[T]], Call[T]],
timeout: Optional[float] = None,
) -> Awaitable[T]:
call = _base.call_soon_in_loop_thread(__call, timeout=timeout)
Expand All @@ -182,13 +180,13 @@ class from_sync(_base):
def wait_for_call_in_loop_thread(
__call: Union[
Callable[[], Awaitable[T]],
Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]],
Call[T],
],
timeout: Optional[float] = None,
done_callbacks: Optional[Iterable[Call]] = None,
contexts: Optional[Iterable[ContextManager]] = None,
) -> Awaitable[T]:
call = _cast_to_call(__call)
done_callbacks: Optional[Iterable[Call[T]]] = None,
contexts: Optional[Iterable[AbstractContextManager[Any]]] = None,
) -> T:
call = cast_to_call(__call)
waiter = SyncWaiter(call)
_base.call_soon_in_loop_thread(call, timeout=timeout)
for callback in done_callbacks or []:
Expand All @@ -203,9 +201,9 @@ def wait_for_call_in_loop_thread(
def wait_for_call_in_new_thread(
__call: Union[Callable[[], T], Call[T]],
timeout: Optional[float] = None,
done_callbacks: Optional[Iterable[Call]] = None,
) -> Call[T]:
call = _cast_to_call(__call)
done_callbacks: Optional[Iterable[Call[T]]] = None,
) -> T:
call = cast_to_call(__call)
waiter = SyncWaiter(call=call)
for callback in done_callbacks or []:
waiter.add_done_callback(callback)
Expand All @@ -215,20 +213,21 @@ def wait_for_call_in_new_thread(

@staticmethod
def call_in_new_thread(
__call: Union[Callable[[], T], Call[T]], timeout: Optional[float] = None
__call: Union["_SyncOrAsyncCallable[[], T]", Call[T]],
timeout: Optional[float] = None,
) -> T:
call = _base.call_soon_in_new_thread(__call, timeout=timeout)
return call.result()

@staticmethod
def call_in_loop_thread(
__call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]],
__call: Union["_SyncOrAsyncCallable[[], T]", Call[T]],
timeout: Optional[float] = None,
) -> T:
) -> Union[Awaitable[T], T]:
if in_global_loop():
# Avoid deadlock where the call is submitted to the loop then the loop is
# blocked waiting for the call
call = _cast_to_call(__call)
call = cast_to_call(__call)
return call()

call = _base.call_soon_in_loop_thread(__call, timeout=timeout)
Expand Down
Loading

0 comments on commit e569cc5

Please sign in to comment.