From 7162fb7c0380f39a2a13193bfd093f7c7c42b34b Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Tue, 24 Dec 2024 23:39:22 +0000 Subject: [PATCH] [typing] p._internal, p.server.utilities --- src/prefect/_internal/_logging.py | 18 +- .../_internal/compatibility/async_dispatch.py | 38 +- .../_internal/compatibility/deprecated.py | 58 +- .../_internal/compatibility/migration.py | 4 +- .../_internal/concurrency/inspection.py | 26 +- .../_internal/concurrency/primitives.py | 4 +- src/prefect/_internal/concurrency/waiters.py | 22 +- .../pydantic/annotations/pendulum.py | 14 +- src/prefect/_internal/pytz.py | 7 +- src/prefect/_internal/retries.py | 15 +- src/prefect/_internal/schemas/bases.py | 17 +- src/prefect/_internal/schemas/validators.py | 615 +++++++----------- src/prefect/serializers.py | 57 +- src/prefect/server/events/messaging.py | 2 +- src/prefect/server/utilities/encryption.py | 13 +- .../server/utilities/messaging/__init__.py | 87 ++- .../server/utilities/messaging/memory.py | 52 +- src/prefect/server/utilities/server.py | 15 +- src/prefect/server/utilities/subscriptions.py | 9 +- .../server/utilities/user_templates.py | 12 +- src/prefect/telemetry/run_telemetry.py | 2 +- .../schemas}/test_validation.py | 8 +- 22 files changed, 489 insertions(+), 606 deletions(-) rename tests/{utilities => _internal/schemas}/test_validation.py (91%) diff --git a/src/prefect/_internal/_logging.py b/src/prefect/_internal/_logging.py index 189c87ff5d22..f1fc922ea063 100644 --- a/src/prefect/_internal/_logging.py +++ b/src/prefect/_internal/_logging.py @@ -1,4 +1,14 @@ import logging +import sys + +from typing_extensions import Self + +if sys.version_info < (3, 11): + + def getLevelNamesMapping() -> dict[str, int]: + return getattr(logging, "_nameToLevel").copy() +else: + getLevelNamesMapping = logging.getLevelNamesMapping # novermin class SafeLogger(logging.Logger): @@ -11,11 +21,13 @@ def isEnabledFor(self, level: int): # deadlocks during complex concurrency handling from prefect.settings import PREFECT_LOGGING_INTERNAL_LEVEL - return level >= logging._nameToLevel[PREFECT_LOGGING_INTERNAL_LEVEL.value()] + internal_level = getLevelNamesMapping()[PREFECT_LOGGING_INTERNAL_LEVEL.value()] + + return level >= internal_level - def getChild(self, suffix: str): + def getChild(self, suffix: str) -> Self: logger = super().getChild(suffix) - logger.__class__ = SafeLogger + logger.__class__ = self.__class__ return logger diff --git a/src/prefect/_internal/compatibility/async_dispatch.py b/src/prefect/_internal/compatibility/async_dispatch.py index a0a3ffc67666..0e013a77d335 100644 --- a/src/prefect/_internal/compatibility/async_dispatch.py +++ b/src/prefect/_internal/compatibility/async_dispatch.py @@ -1,11 +1,13 @@ import asyncio import inspect +from collections.abc import Coroutine from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast from typing_extensions import ParamSpec if TYPE_CHECKING: + from prefect.flows import Flow from prefect.tasks import Task R = TypeVar("R") @@ -41,7 +43,9 @@ def is_in_async_context() -> bool: def _is_acceptable_callable( - obj: Union[Callable[P, R], "Task[P, R]", classmethod], + obj: Union[ + Callable[P, R], "Flow[P, R]", "Task[P, R]", "classmethod[type[Any], P, R]" + ], ) -> bool: if inspect.iscoroutinefunction(obj): return True @@ -58,7 +62,10 @@ def _is_acceptable_callable( def async_dispatch( - async_impl: Callable[P, Coroutine[Any, Any, R]], + async_impl: Union[ + Callable[P, Coroutine[Any, Any, R]], + "classmethod[type[Any], P, Coroutine[Any, Any, R]]", + ], ) -> Callable[[Callable[P, R]], Callable[P, Union[R, Coroutine[Any, Any, R]]]]: """ Decorator that dispatches to either sync or async implementation based on context. @@ -66,27 +73,26 @@ def async_dispatch( Args: async_impl: The async implementation to dispatch to when in async context """ + if not _is_acceptable_callable(async_impl): + raise TypeError("async_impl must be an async function") + if isinstance(async_impl, classmethod): + async_impl = cast(Callable[P, Coroutine[Any, Any, R]], async_impl.__func__) def decorator( sync_fn: Callable[P, R], ) -> Callable[P, Union[R, Coroutine[Any, Any, R]]]: - if not _is_acceptable_callable(async_impl): - raise TypeError("async_impl must be an async function") - @wraps(sync_fn) def wrapper( *args: P.args, - _sync: Optional[bool] = None, # type: ignore **kwargs: P.kwargs, ) -> Union[R, Coroutine[Any, Any, R]]: - should_run_sync = _sync if _sync is not None else not is_in_async_context() - - if should_run_sync: - return sync_fn(*args, **kwargs) - if isinstance(async_impl, classmethod): - return async_impl.__func__(*args, **kwargs) - return async_impl(*args, **kwargs) - - return wrapper # type: ignore + _sync = kwargs.pop("_sync", None) + should_run_sync = ( + bool(_sync) if _sync is not None else not is_in_async_context() + ) + fn = sync_fn if should_run_sync else async_impl + return fn(*args, **kwargs) + + return wrapper return decorator diff --git a/src/prefect/_internal/compatibility/deprecated.py b/src/prefect/_internal/compatibility/deprecated.py index 62344d28351a..cf11ff94687b 100644 --- a/src/prefect/_internal/compatibility/deprecated.py +++ b/src/prefect/_internal/compatibility/deprecated.py @@ -13,10 +13,11 @@ import functools import sys import warnings -from typing import Any, Callable, List, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import pendulum from pydantic import BaseModel +from typing_extensions import ParamSpec, TypeAlias, TypeVar from prefect.utilities.callables import get_call_parameters from prefect.utilities.importtools import ( @@ -25,8 +26,10 @@ to_qualified_name, ) -T = TypeVar("T", bound=Callable) +P = ParamSpec("P") +R = TypeVar("R", infer_variance=True) M = TypeVar("M", bound=BaseModel) +T = TypeVar("T") DEPRECATED_WARNING = ( @@ -38,7 +41,7 @@ "path after {end_date}. {help}" ) DEPRECATED_DATEFMT = "MMM YYYY" # e.g. Feb 2023 -DEPRECATED_MODULE_ALIASES: List[AliasedModuleDefinition] = [] +DEPRECATED_MODULE_ALIASES: list[AliasedModuleDefinition] = [] class PrefectDeprecationWarning(DeprecationWarning): @@ -61,6 +64,8 @@ def generate_deprecation_message( ) if not end_date: + if TYPE_CHECKING: + assert start_date is not None parsed_start_date = pendulum.from_format(start_date, DEPRECATED_DATEFMT) parsed_end_date = parsed_start_date.add(months=6) end_date = parsed_end_date.format(DEPRECATED_DATEFMT) @@ -83,8 +88,8 @@ def deprecated_callable( end_date: Optional[str] = None, stacklevel: int = 2, help: str = "", -) -> Callable[[T], T]: - def decorator(fn: T): +) -> Callable[[Callable[P, R]], Callable[P, R]]: + def decorator(fn: Callable[P, R]) -> Callable[P, R]: message = generate_deprecation_message( name=to_qualified_name(fn), start_date=start_date, @@ -93,7 +98,7 @@ def decorator(fn: T): ) @functools.wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: warnings.warn(message, PrefectDeprecationWarning, stacklevel=stacklevel) return fn(*args, **kwargs) @@ -108,8 +113,8 @@ def deprecated_class( end_date: Optional[str] = None, stacklevel: int = 2, help: str = "", -) -> Callable[[T], T]: - def decorator(cls: T): +) -> Callable[[type[T]], type[T]]: + def decorator(cls: type[T]) -> type[T]: message = generate_deprecation_message( name=to_qualified_name(cls), start_date=start_date, @@ -120,7 +125,7 @@ def decorator(cls: T): original_init = cls.__init__ @functools.wraps(original_init) - def new_init(self, *args, **kwargs): + def new_init(self: T, *args: Any, **kwargs: Any) -> None: warnings.warn(message, PrefectDeprecationWarning, stacklevel=stacklevel) original_init(self, *args, **kwargs) @@ -139,7 +144,7 @@ def deprecated_parameter( help: str = "", when: Optional[Callable[[Any], bool]] = None, when_message: str = "", -) -> Callable[[T], T]: +) -> Callable[[Callable[P, R]], Callable[P, R]]: """ Mark a parameter in a callable as deprecated. @@ -155,7 +160,7 @@ def foo(x, y = None): when = when or (lambda _: True) - def decorator(fn: T): + def decorator(fn: Callable[P, R]) -> Callable[P, R]: message = generate_deprecation_message( name=f"The parameter {name!r} for {fn.__name__!r}", start_date=start_date, @@ -165,7 +170,7 @@ def decorator(fn: T): ) @functools.wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: try: parameters = get_call_parameters(fn, args, kwargs, apply_defaults=False) except Exception: @@ -182,6 +187,10 @@ def wrapper(*args, **kwargs): return decorator +JsonValue: TypeAlias = Union[int, float, str, bool, None, list["JsonValue"], "JsonDict"] +JsonDict: TypeAlias = dict[str, JsonValue] + + def deprecated_field( name: str, *, @@ -191,7 +200,7 @@ def deprecated_field( help: str = "", when: Optional[Callable[[Any], bool]] = None, stacklevel: int = 2, -): +) -> Callable[[type[M]], type[M]]: """ Mark a field in a Pydantic model as deprecated. @@ -212,7 +221,7 @@ class Model(BaseModel) # Replaces the model's __init__ method with one that performs an additional warning # check - def decorator(model_cls: Type[M]) -> Type[M]: + def decorator(model_cls: type[M]) -> type[M]: message = generate_deprecation_message( name=f"The field {name!r} in {model_cls.__name__!r}", start_date=start_date, @@ -224,7 +233,7 @@ def decorator(model_cls: Type[M]) -> Type[M]: cls_init = model_cls.__init__ @functools.wraps(model_cls.__init__) - def __init__(__pydantic_self__, **data: Any) -> None: + def __init__(__pydantic_self__: M, **data: Any) -> None: if name in data.keys() and when(data[name]): warnings.warn(message, PrefectDeprecationWarning, stacklevel=stacklevel) @@ -232,8 +241,23 @@ def __init__(__pydantic_self__, **data: Any) -> None: field = __pydantic_self__.model_fields.get(name) if field is not None: - field.json_schema_extra = field.json_schema_extra or {} - field.json_schema_extra["deprecated"] = True + json_schema_extra = field.json_schema_extra or {} + + if not isinstance(json_schema_extra, dict): + # json_schema_extra is a hook function; wrap it to add the deprecated flag. + extra_func = json_schema_extra + + @functools.wraps(extra_func) + def wrapped(__json_schema: JsonDict) -> None: + extra_func(__json_schema) + __json_schema["deprecated"] = True + + json_schema_extra = wrapped + + else: + json_schema_extra["deprecated"] = True + + field.json_schema_extra = json_schema_extra # Patch the model's init method model_cls.__init__ = __init__ diff --git a/src/prefect/_internal/compatibility/migration.py b/src/prefect/_internal/compatibility/migration.py index f39739f2c9df..0478be7f4ca5 100644 --- a/src/prefect/_internal/compatibility/migration.py +++ b/src/prefect/_internal/compatibility/migration.py @@ -43,7 +43,7 @@ """ import sys -from typing import Any, Callable, Dict +from typing import Any, Callable from pydantic_core import PydanticCustomError @@ -157,7 +157,7 @@ def wrapper(name: str) -> object: f"`{import_path}` has been removed. {error_message}" ) - globals: Dict[str, Any] = sys.modules[module_name].__dict__ + globals: dict[str, Any] = sys.modules[module_name].__dict__ if name in globals: return globals[name] diff --git a/src/prefect/_internal/concurrency/inspection.py b/src/prefect/_internal/concurrency/inspection.py index 014c77d3ad89..88a6662242a2 100644 --- a/src/prefect/_internal/concurrency/inspection.py +++ b/src/prefect/_internal/concurrency/inspection.py @@ -7,7 +7,6 @@ import sys import threading from types import FrameType -from typing import List, Optional """ The following functions are derived from dask/distributed which is licensed under the @@ -72,26 +71,25 @@ def repr_frame(frame: FrameType) -> str: return text + "\n\t" + line -def call_stack(frame: FrameType) -> List[str]: +def call_stack(frame: FrameType) -> list[str]: """Create a call text stack from a frame""" - L = [] - cur_frame: Optional[FrameType] = frame + frames: list[str] = [] + cur_frame = frame while cur_frame: - L.append(repr_frame(cur_frame)) + frames.append(repr_frame(cur_frame)) cur_frame = cur_frame.f_back - return L[::-1] + return frames[::-1] -def stack_for_threads(*threads: threading.Thread) -> List[str]: - frames = sys._current_frames() +def stack_for_threads(*threads: threading.Thread) -> list[str]: + frames = sys._current_frames() # pyright: ignore[reportPrivateUsage] try: - lines = [] + lines: list[str] = [] for thread in threads: - lines.append( - f"------ Call stack of {thread.name} ({hex(thread.ident)}) -----" - ) - thread_frames = frames.get(thread.ident) - if thread_frames: + ident = thread.ident + hex_ident = hex(ident) if ident is not None else "" + lines.append(f"------ Call stack of {thread.name} ({hex_ident}) -----") + if ident is not None and (thread_frames := frames.get(ident)): lines.append("".join(call_stack(thread_frames))) else: lines.append("No stack frames found") diff --git a/src/prefect/_internal/concurrency/primitives.py b/src/prefect/_internal/concurrency/primitives.py index 4eeb13470514..c762955ffad7 100644 --- a/src/prefect/_internal/concurrency/primitives.py +++ b/src/prefect/_internal/concurrency/primitives.py @@ -27,7 +27,7 @@ class Event: """ def __init__(self) -> None: - self._waiters = collections.deque() + self._waiters: collections.deque[asyncio.Future[bool]] = collections.deque() self._value = False self._lock = threading.Lock() @@ -69,7 +69,7 @@ async def wait(self) -> Literal[True]: if self._value: return True - fut = asyncio.get_running_loop().create_future() + fut: asyncio.Future[bool] = asyncio.get_running_loop().create_future() self._waiters.append(fut) try: diff --git a/src/prefect/_internal/concurrency/waiters.py b/src/prefect/_internal/concurrency/waiters.py index 07522992100d..fc8a4c26b186 100644 --- a/src/prefect/_internal/concurrency/waiters.py +++ b/src/prefect/_internal/concurrency/waiters.py @@ -10,7 +10,7 @@ import queue import threading from collections import deque -from collections.abc import Awaitable +from collections.abc import AsyncGenerator, Awaitable, Generator from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from weakref import WeakKeyDictionary @@ -24,7 +24,7 @@ T = TypeVar("T") -# Waiters are stored in a stack for each thread +# Waiters are stored in a queue for each thread _WAITERS_BY_THREAD: "WeakKeyDictionary[threading.Thread, deque[Waiter[Any]]]" = ( WeakKeyDictionary() ) @@ -49,8 +49,9 @@ class Waiter(Portal, abc.ABC, Generic[T]): """ def __init__(self, call: Call[T]) -> None: - if not isinstance(call, Call): # Guard against common mistake - raise TypeError(f"Expected call of type `Call`; got {call!r}.") + if not TYPE_CHECKING: + if not isinstance(call, Call): # Guard against common mistake + raise TypeError(f"Expected call of type `Call`; got {call!r}.") self._call = call self._owner_thread = threading.current_thread() @@ -107,7 +108,7 @@ def submit(self, call: Call[T]) -> Call[T]: call.set_runner(self) return call - def _handle_waiting_callbacks(self): + def _handle_waiting_callbacks(self) -> None: logger.debug("Waiter %r watching for callbacks", self) while True: callback = self._queue.get() @@ -121,7 +122,7 @@ def _handle_waiting_callbacks(self): del callback @contextlib.contextmanager - def _handle_done_callbacks(self): + def _handle_done_callbacks(self) -> Generator[None, Any, None]: try: yield finally: @@ -195,10 +196,13 @@ def _resubmit_early_submissions(self) -> None: call_soon_in_loop(self._loop, self._queue.put_nowait, call) self._early_submissions = [] - async def _handle_waiting_callbacks(self): + async def _handle_waiting_callbacks(self) -> None: logger.debug("Waiter %r watching for callbacks", self) tasks: list[Awaitable[None]] = [] + if TYPE_CHECKING: + assert self._queue is not None + try: while True: callback = await self._queue.get() @@ -221,7 +225,7 @@ async def _handle_waiting_callbacks(self): self._done_waiting = True @contextlib.asynccontextmanager - async def _handle_done_callbacks(self): + async def _handle_done_callbacks(self) -> AsyncGenerator[None, Any]: try: yield finally: @@ -244,7 +248,7 @@ def add_done_callback(self, callback: Call[Any]) -> None: else: self._done_callbacks.append(callback) - def _signal_stop_waiting(self): + def _signal_stop_waiting(self) -> None: # Only send a `None` to the queue if the waiter is still blocked reading from # the queue. Otherwise, it's possible that the event loop is stopped. if not self._done_waiting: diff --git a/src/prefect/_internal/pydantic/annotations/pendulum.py b/src/prefect/_internal/pydantic/annotations/pendulum.py index b18ee1a13310..2c3c6906fe72 100644 --- a/src/prefect/_internal/pydantic/annotations/pendulum.py +++ b/src/prefect/_internal/pydantic/annotations/pendulum.py @@ -3,39 +3,39 @@ generation and validation. """ -import typing as t +from typing import Annotated, Any, Union import pendulum from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler from pydantic.json_schema import JsonSchemaValue from pydantic_core import core_schema -from typing_extensions import Annotated class _PendulumDateTimeAnnotation: - _pendulum_type: t.Type[ - t.Union[pendulum.DateTime, pendulum.Date, pendulum.Time] + _pendulum_type: type[ + Union[pendulum.DateTime, pendulum.Date, pendulum.Time, pendulum.Duration] ] = pendulum.DateTime _pendulum_types_to_schemas = { pendulum.DateTime: core_schema.datetime_schema(), pendulum.Date: core_schema.date_schema(), + pendulum.Time: core_schema.time_schema(), pendulum.Duration: core_schema.timedelta_schema(), } @classmethod def __get_pydantic_core_schema__( cls, - _source_type: t.Any, + _source_type: Any, _handler: GetCoreSchemaHandler, ) -> core_schema.CoreSchema: def validate_from_str( value: str, - ) -> t.Union[pendulum.DateTime, pendulum.Date, pendulum.Time]: + ) -> Union[pendulum.DateTime, pendulum.Date, pendulum.Time, pendulum.Duration]: return pendulum.parse(value) def to_str( - value: t.Union[pendulum.DateTime, pendulum.Date, pendulum.Time], + value: Union[pendulum.DateTime, pendulum.Date, pendulum.Time], ) -> str: return value.isoformat() diff --git a/src/prefect/_internal/pytz.py b/src/prefect/_internal/pytz.py index 2a71d73eb0ec..3f887d1ac712 100644 --- a/src/prefect/_internal/pytz.py +++ b/src/prefect/_internal/pytz.py @@ -14,11 +14,12 @@ """ try: - import pytz # noqa: F401 + import pytz as pytz except ImportError: - HAS_PYTZ = False + _has_pytz = False else: - HAS_PYTZ = True + _has_pytz = True +HAS_PYTZ = _has_pytz all_timezones_set = { diff --git a/src/prefect/_internal/retries.py b/src/prefect/_internal/retries.py index 08cc21e9a252..e7c02f4b2ef1 100644 --- a/src/prefect/_internal/retries.py +++ b/src/prefect/_internal/retries.py @@ -1,6 +1,7 @@ import asyncio +from collections.abc import Coroutine from functools import wraps -from typing import Callable, Optional, Tuple, Type, TypeVar +from typing import Any, Callable, Optional, TypeVar from typing_extensions import ParamSpec @@ -25,9 +26,11 @@ def retry_async_fn( ] = exponential_backoff_with_jitter, base_delay: float = 1, max_delay: float = 10, - retry_on_exceptions: Tuple[Type[Exception], ...] = (Exception,), + retry_on_exceptions: tuple[type[Exception], ...] = (Exception,), operation_name: Optional[str] = None, -) -> Callable[[Callable[P, R]], Callable[P, R]]: +) -> Callable[ + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, Optional[R]]] +]: """A decorator for retrying an async function. Args: @@ -43,9 +46,11 @@ def retry_async_fn( the function name. If None, uses the function name. """ - def decorator(func: Callable[P, R]) -> Callable[P, R]: + def decorator( + func: Callable[P, Coroutine[Any, Any, R]], + ) -> Callable[P, Coroutine[Any, Any, Optional[R]]]: @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Optional[R]: name = operation_name or func.__name__ for attempt in range(max_attempts): try: diff --git a/src/prefect/_internal/schemas/bases.py b/src/prefect/_internal/schemas/bases.py index 62804f8b478a..f80792f37fe6 100644 --- a/src/prefect/_internal/schemas/bases.py +++ b/src/prefect/_internal/schemas/bases.py @@ -4,15 +4,12 @@ import datetime import os -from typing import Any, ClassVar, Generator, Optional, Set, TypeVar, cast +from typing import Any, ClassVar, Optional, TypeVar, cast from uuid import UUID, uuid4 import pendulum -from pydantic import ( - BaseModel, - ConfigDict, - Field, -) +from pydantic import BaseModel, ConfigDict, Field +from rich.repr import RichReprResult from typing_extensions import Self from prefect.types import DateTime @@ -32,7 +29,7 @@ class PrefectBaseModel(BaseModel): subtle unintentional testing errors. """ - _reset_fields: ClassVar[Set[str]] = set() + _reset_fields: ClassVar[set[str]] = set() model_config: ClassVar[ConfigDict] = ConfigDict( ser_json_timedelta="float", @@ -59,7 +56,7 @@ def __eq__(self, other: Any) -> bool: else: return copy_dict == other - def __rich_repr__(self) -> Generator[tuple[str, Any, Any], None, None]: + def __rich_repr__(self) -> RichReprResult: # Display all of the fields in the model if they differ from the default value for name, field in self.model_fields.items(): value = getattr(self, name) @@ -102,7 +99,7 @@ class IDBaseModel(PrefectBaseModel): The ID is reset on copy() and not included in equality comparisons. """ - _reset_fields: ClassVar[Set[str]] = {"id"} + _reset_fields: ClassVar[set[str]] = {"id"} id: UUID = Field(default_factory=uuid4) @@ -115,7 +112,7 @@ class ObjectBaseModel(IDBaseModel): equality comparisons. """ - _reset_fields: ClassVar[Set[str]] = {"id", "created", "updated"} + _reset_fields: ClassVar[set[str]] = {"id", "created", "updated"} model_config: ClassVar[ConfigDict] = ConfigDict(from_attributes=True) created: Optional[DateTime] = Field(default=None, repr=False) diff --git a/src/prefect/_internal/schemas/validators.py b/src/prefect/_internal/schemas/validators.py index 52380cce951c..4bbc9c7de15a 100644 --- a/src/prefect/_internal/schemas/validators.py +++ b/src/prefect/_internal/schemas/validators.py @@ -6,41 +6,51 @@ This will be subject to consolidation and refactoring over the next few months. """ -import json +import os import re import urllib.parse import warnings +from collections.abc import Iterable, Mapping, MutableMapping from copy import copy from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, overload from uuid import UUID import jsonschema import pendulum -import yaml +import pendulum.tz -from prefect.exceptions import InvalidRepositoryURLError -from prefect.types import DateTime from prefect.utilities.collections import isiterable -from prefect.utilities.dockerutils import get_prefect_image_name from prefect.utilities.filesystem import relative_path_to_current_platform from prefect.utilities.importtools import from_qualified_name from prefect.utilities.names import generate_slug -from prefect.utilities.pydantic import JsonPatch + +if TYPE_CHECKING: + from prefect.serializers import Serializer + +T = TypeVar("T") +M = TypeVar("M", bound=Mapping[str, Any]) +MM = TypeVar("MM", bound=MutableMapping[str, Any]) + LOWERCASE_LETTERS_NUMBERS_AND_DASHES_ONLY_REGEX = "^[a-z0-9-]*$" LOWERCASE_LETTERS_NUMBERS_AND_UNDERSCORES_REGEX = "^[a-z0-9_]*$" -if TYPE_CHECKING: - from prefect.blocks.core import Block - from prefect.serializers import Serializer - from prefect.utilities.callables import ParameterSchema + +@overload +def raise_on_name_alphanumeric_dashes_only(value: str, field_name: str = ...) -> str: + ... + + +@overload +def raise_on_name_alphanumeric_dashes_only(value: None, field_name: str = ...) -> None: + ... def raise_on_name_alphanumeric_dashes_only( value: Optional[str], field_name: str = "value" -): - if value and not bool( +) -> Optional[str]: + if value is not None and not bool( re.match(LOWERCASE_LETTERS_NUMBERS_AND_DASHES_ONLY_REGEX, value) ): raise ValueError( @@ -49,40 +59,38 @@ def raise_on_name_alphanumeric_dashes_only( return value -def raise_on_name_alphanumeric_underscores_only(value, field_name: str = "value"): - if not bool(re.match(LOWERCASE_LETTERS_NUMBERS_AND_UNDERSCORES_REGEX, value)): - raise ValueError( - f"{field_name} must only contain lowercase letters, numbers, and" - " underscores." - ) - return value +@overload +def raise_on_name_alphanumeric_underscores_only( + value: str, field_name: str = ... +) -> str: + ... -def validate_schema(schema: dict): - """ - Validate that the provided schema is a valid json schema. +@overload +def raise_on_name_alphanumeric_underscores_only( + value: None, field_name: str = ... +) -> None: + ... - Args: - schema: The schema to validate. - Raises: - ValueError: If the provided schema is not a valid json schema. - - """ - try: - if schema is not None: - # Most closely matches the schemas generated by pydantic - jsonschema.Draft202012Validator.check_schema(schema) - except jsonschema.SchemaError as exc: +def raise_on_name_alphanumeric_underscores_only( + value: Optional[str], field_name: str = "value" +) -> Optional[str]: + if value is not None and not re.match( + LOWERCASE_LETTERS_NUMBERS_AND_UNDERSCORES_REGEX, value + ): raise ValueError( - "The provided schema is not a valid json schema. Schema error:" - f" {exc.message}" - ) from exc + f"{field_name} must only contain lowercase letters, numbers, and" + " underscores." + ) + return value def validate_values_conform_to_schema( - values: dict, schema: dict, ignore_required: bool = False -): + values: Optional[Mapping[str, Any]], + schema: Optional[Mapping[str, Any]], + ignore_required: bool = False, +) -> None: """ Validate that the provided values conform to the provided json schema. @@ -127,90 +135,46 @@ def validate_values_conform_to_schema( ### DEPLOYMENT SCHEMA VALIDATORS ### -def infrastructure_must_have_capabilities( - value: Union[Dict[str, Any], "Block", None], -) -> Optional["Block"]: - """ - Ensure that the provided value is an infrastructure block with the required capabilities. - """ - - from prefect.blocks.core import Block - - if isinstance(value, dict): - if "_block_type_slug" in value: - # Replace private attribute with public for dispatch - value["block_type_slug"] = value.pop("_block_type_slug") - block = Block(**value) - elif value is None: - return value - else: - block = value - - if "run-infrastructure" not in block.get_block_capabilities(): - raise ValueError( - "Infrastructure block must have 'run-infrastructure' capabilities." +def validate_parameters_conform_to_schema( + parameters: M, values: Mapping[str, Any] +) -> M: + """Validate that the parameters conform to the parameter schema.""" + if values.get("enforce_parameter_schema"): + validate_values_conform_to_schema( + parameters, values.get("parameter_openapi_schema"), ignore_required=True ) - return block + return parameters -def storage_must_have_capabilities( - value: Union[Dict[str, Any], "Block", None], -) -> Optional["Block"]: - """ - Ensure that the provided value is a storage block with the required capabilities. - """ - from prefect.blocks.core import Block - - if isinstance(value, dict): - block_type = Block.get_block_class_from_key(value.pop("_block_type_slug")) - block = block_type(**value) - elif value is None: - return value - else: - block = value +@overload +def validate_parameter_openapi_schema(schema: M, values: Mapping[str, Any]) -> M: + ... - capabilities = block.get_block_capabilities() - if "get-directory" not in capabilities: - raise ValueError("Remote Storage block must have 'get-directory' capabilities.") - return block - -def handle_openapi_schema(value: Optional["ParameterSchema"]) -> "ParameterSchema": - """ - This method ensures setting a value of `None` is handled gracefully. - """ - from prefect.utilities.callables import ParameterSchema - - if value is None: - return ParameterSchema() - return value - - -def validate_parameters_conform_to_schema(value: dict, values: dict) -> dict: - """Validate that the parameters conform to the parameter schema.""" - if values.get("enforce_parameter_schema"): - validate_values_conform_to_schema( - value, values.get("parameter_openapi_schema"), ignore_required=True - ) - return value +@overload +def validate_parameter_openapi_schema(schema: None, values: Mapping[str, Any]) -> None: + ... -def validate_parameter_openapi_schema(value: dict, values: dict) -> dict: +def validate_parameter_openapi_schema( + schema: Optional[M], values: Mapping[str, Any] +) -> Optional[M]: """Validate that the parameter_openapi_schema is a valid json schema.""" if values.get("enforce_parameter_schema"): - validate_schema(value) - return value - - -def return_none_schedule(v: Optional[Union[str, dict]]) -> Optional[Union[str, dict]]: - from prefect.client.schemas.schedules import NoSchedule + try: + if schema is not None: + # Most closely matches the schemas generated by pydantic + jsonschema.Draft202012Validator.check_schema(schema) + except jsonschema.SchemaError as exc: + raise ValueError( + "The provided schema is not a valid json schema. Schema error:" + f" {exc.message}" + ) from exc - if isinstance(v, NoSchedule): - return None - return v + return schema -def convert_to_strings(value: Union[Any, List[Any]]) -> Union[str, List[str]]: +def convert_to_strings(value: Union[Any, Iterable[Any]]) -> Union[str, list[str]]: if isiterable(value): return [str(item) for item in value] return str(value) @@ -219,7 +183,7 @@ def convert_to_strings(value: Union[Any, List[Any]]) -> Union[str, List[str]]: ### SCHEDULE SCHEMA VALIDATORS ### -def reconcile_schedules_runner(values: dict) -> dict: +def reconcile_schedules_runner(values: MM) -> MM: from prefect.deployments.schedules import ( normalize_to_deployment_schedule_create, ) @@ -231,13 +195,23 @@ def reconcile_schedules_runner(values: dict) -> dict: return values +@overload +def validate_schedule_max_scheduled_runs(v: int, limit: int) -> int: + ... + + +@overload +def validate_schedule_max_scheduled_runs(v: None, limit: int) -> None: + ... + + def validate_schedule_max_scheduled_runs(v: Optional[int], limit: int) -> Optional[int]: if v is not None and v > limit: raise ValueError(f"`max_scheduled_runs` must be less than or equal to {limit}.") return v -def remove_old_deployment_fields(values: dict) -> dict: +def remove_old_deployment_fields(values: MM) -> MM: # 2.7.7 removed worker_pool_queue_id in lieu of worker_pool_name and # worker_pool_queue_name. Those fields were later renamed to work_pool_name # and work_queue_name. This validator removes old fields provided @@ -270,7 +244,7 @@ def remove_old_deployment_fields(values: dict) -> dict: return values_copy -def reconcile_paused_deployment(values): +def reconcile_paused_deployment(values: MM) -> MM: paused = values.get("paused") if paused is None: @@ -279,45 +253,44 @@ def reconcile_paused_deployment(values): return values -def default_anchor_date(v: DateTime) -> DateTime: +def default_anchor_date(v: pendulum.DateTime) -> pendulum.DateTime: return pendulum.instance(v) -def get_valid_timezones(v: Optional[str]) -> Tuple[str, ...]: - # pendulum.tz.timezones is a callable in 3.0 and above - # https://github.com/PrefectHQ/prefect/issues/11619 - if callable(pendulum.tz.timezones): - return pendulum.tz.timezones() - else: - return pendulum.tz.timezones +@overload +def default_timezone(v: str, values: Optional[Mapping[str, Any]] = ...) -> str: + ... -def validate_timezone(v: str, timezones: Tuple[str, ...]) -> str: - if v and v not in timezones: - raise ValueError( - f'Invalid timezone: "{v}" (specify in IANA tzdata format, for example,' - " America/New_York)" - ) - return v +@overload +def default_timezone( + v: None, values: Optional[Mapping[str, Any]] = ... +) -> Optional[str]: + ... -def default_timezone(v: Optional[str], values: Optional[dict] = None) -> str: +def default_timezone( + v: Optional[str], values: Optional[Mapping[str, Any]] = None +) -> Optional[str]: values = values or {} - timezones = get_valid_timezones(v) + timezones = pendulum.tz.timezones() if v is not None: - return validate_timezone(v, timezones) + if v and v not in timezones: + raise ValueError( + f'Invalid timezone: "{v}" (specify in IANA tzdata format, for example,' + " America/New_York)" + ) + return v # anchor schedules - elif v is None and values and values.get("anchor_date"): - tz = getattr(values["anchor_date"].tz, "name", None) or "UTC" - if tz in timezones: - return tz + elif "anchor_date" in values: + anchor_date: pendulum.DateTime = values["anchor_date"] + tz = "UTC" if anchor_date.tz is None else anchor_date.tz.name # sometimes anchor dates have "timezones" that are UTC offsets # like "-04:00". This happens when parsing ISO8601 strings. # In this case we, the correct inferred localization is "UTC". - else: - return "UTC" + return tz if tz in timezones else "UTC" # cron schedules return v @@ -360,119 +333,18 @@ def validate_rrule_string(v: str) -> str: return v -### INFRASTRUCTURE SCHEMA VALIDATORS ### - - -def validate_k8s_job_required_components(cls, value: Dict[str, Any]): - """ - Validate that a Kubernetes job manifest has all required components. - """ - from prefect.utilities.pydantic import JsonPatch - - patch = JsonPatch.from_diff(value, cls.base_job_manifest()) - missing_paths = sorted([op["path"] for op in patch if op["op"] == "add"]) - if missing_paths: - raise ValueError( - "Job is missing required attributes at the following paths: " - f"{', '.join(missing_paths)}" - ) - return value - - -def validate_k8s_job_compatible_values(cls, value: Dict[str, Any]): - """ - Validate that the provided job values are compatible with the job type. - """ - from prefect.utilities.pydantic import JsonPatch - - patch = JsonPatch.from_diff(value, cls.base_job_manifest()) - incompatible = sorted( - [ - f"{op['path']} must have value {op['value']!r}" - for op in patch - if op["op"] == "replace" - ] - ) - if incompatible: - raise ValueError( - "Job has incompatible values for the following attributes: " - f"{', '.join(incompatible)}" - ) - return value - - -def cast_k8s_job_customizations( - cls, value: Union[JsonPatch, str, List[Dict[str, Any]]] -): - if isinstance(value, list): - return JsonPatch(value) - elif isinstance(value, str): - try: - return JsonPatch(json.loads(value)) - except json.JSONDecodeError as exc: - raise ValueError( - f"Unable to parse customizations as JSON: {value}. Please make sure" - " that the provided value is a valid JSON string." - ) from exc - return value - - -def set_default_namespace(values: dict) -> dict: - """ - Set the default namespace for a Kubernetes job if not provided. - """ - job = values.get("job") - - namespace = values.get("namespace") - job_namespace = job["metadata"].get("namespace") if job else None - - if not namespace and not job_namespace: - values["namespace"] = "default" - - return values - - -def set_default_image(values: dict) -> dict: - """ - Set the default image for a Kubernetes job if not provided. - """ - - job = values.get("job") - image = values.get("image") - job_image = ( - job["spec"]["template"]["spec"]["containers"][0].get("image") if job else None - ) - - if not image and not job_image: - values["image"] = get_prefect_image_name() - - return values - - ### STATE SCHEMA VALIDATORS ### -def get_or_create_state_name(v: str, values: dict) -> str: - """If a name is not provided, use the type""" - - # if `type` is not in `values` it means the `type` didn't pass its own - # validation check and an error will be raised after this function is called - if v is None and values.get("type"): - v = " ".join([v.capitalize() for v in values.get("type").value.split("_")]) - return v - - -def get_or_create_run_name(name): +def get_or_create_run_name(name: Optional[str]) -> str: return name or generate_slug(2) ### FILESYSTEM SCHEMA VALIDATORS ### -def stringify_path(value: Union[str, Path]) -> str: - if isinstance(value, Path): - return str(value) - return value +def stringify_path(value: Union[str, os.PathLike[str]]) -> str: + return os.fspath(value) def validate_basepath(value: str) -> str: @@ -495,25 +367,6 @@ def validate_basepath(value: str) -> str: return value -def validate_github_access_token(v: str, values: dict) -> str: - """Ensure that credentials are not provided with 'SSH' formatted GitHub URLs. - - Note: validates `access_token` specifically so that it only fires when - private repositories are used. - """ - if v is not None: - if urllib.parse.urlparse(values["repository"]).scheme != "https": - raise InvalidRepositoryURLError( - "Crendentials can only be used with GitHub repositories " - "using the 'HTTPS' format. You must either remove the " - "credential if you wish to use the 'SSH' format and are not " - "using a private repository, or you must change the repository " - "URL to the 'HTTPS' format. " - ) - - return v - - ### SERIALIZER SCHEMA VALIDATORS ### @@ -537,49 +390,7 @@ def validate_picklelib(value: str) -> str: return value -def validate_picklelib_version(values: dict) -> dict: - """ - Infers a default value for `picklelib_version` if null or ensures it matches - the version retrieved from the `pickelib`. - """ - picklelib = values.get("picklelib") - picklelib_version = values.get("picklelib_version") - - if not picklelib: - raise ValueError("Unable to check version of unrecognized picklelib module") - - pickler = from_qualified_name(picklelib) - pickler_version = getattr(pickler, "__version__", None) - - if not picklelib_version: - values["picklelib_version"] = pickler_version - elif picklelib_version != pickler_version: - warnings.warn( - ( - f"Mismatched {picklelib!r} versions. Found {pickler_version} in the" - f" environment but {picklelib_version} was requested. This may" - " cause the serializer to fail." - ), - RuntimeWarning, - stacklevel=3, - ) - - return values - - -def validate_picklelib_and_modules(values: dict) -> dict: - """ - Prevents modules from being specified if picklelib is not cloudpickle - """ - if values.get("picklelib") != "cloudpickle" and values.get("pickle_modules"): - raise ValueError( - "`pickle_modules` cannot be used without 'cloudpickle'. Got" - f" {values.get('picklelib')!r}." - ) - return values - - -def validate_dump_kwargs(value: dict[str, Any]) -> dict[str, Any]: +def validate_dump_kwargs(value: M) -> M: # `default` is set by `object_encoder`. A user provided callable would make this # class unserializable anyway. if "default" in value: @@ -587,7 +398,7 @@ def validate_dump_kwargs(value: dict[str, Any]) -> dict[str, Any]: return value -def validate_load_kwargs(value: dict[str, Any]) -> dict[str, Any]: +def validate_load_kwargs(value: M) -> M: # `object_hook` is set by `object_decoder`. A user provided callable would make # this class unserializable anyway. if "object_hook" in value: @@ -597,7 +408,19 @@ def validate_load_kwargs(value: dict[str, Any]) -> dict[str, Any]: return value -def cast_type_names_to_serializers(value: Union[str, "Serializer"]) -> "Serializer": +@overload +def cast_type_names_to_serializers(value: str) -> "Serializer[Any]": + ... + + +@overload +def cast_type_names_to_serializers(value: "Serializer[T]") -> "Serializer[T]": + ... + + +def cast_type_names_to_serializers( + value: Union[str, "Serializer[Any]"], +) -> "Serializer[Any]": from prefect.serializers import Serializer if isinstance(value, str): @@ -631,19 +454,49 @@ def validate_compressionlib(value: str) -> str: # TODO: if we use this elsewhere we can change the error message to be more generic -def list_length_50_or_less(v: Optional[List[float]]) -> Optional[List[float]]: +@overload +def list_length_50_or_less(v: list[float]) -> list[float]: + ... + + +@overload +def list_length_50_or_less(v: None) -> None: + ... + + +def list_length_50_or_less(v: Optional[list[float]]) -> Optional[list[float]]: if isinstance(v, list) and (len(v) > 50): raise ValueError("Can not configure more than 50 retry delays per task.") return v # TODO: if we use this elsewhere we can change the error message to be more generic +@overload +def validate_not_negative(v: float) -> float: + ... + + +@overload +def validate_not_negative(v: None) -> None: + ... + + def validate_not_negative(v: Optional[float]) -> Optional[float]: if v is not None and v < 0: raise ValueError("`retry_jitter_factor` must be >= 0.") return v +@overload +def validate_message_template_variables(v: str) -> str: + ... + + +@overload +def validate_message_template_variables(v: None) -> None: + ... + + def validate_message_template_variables(v: Optional[str]) -> Optional[str]: from prefect.client.schemas.objects import FLOW_RUN_NOTIFICATION_TEMPLATE_KWARGS @@ -665,11 +518,19 @@ def validate_default_queue_id_not_none(v: Optional[UUID]) -> UUID: return v -def validate_max_metadata_length( - v: Optional[Dict[str, Any]], -) -> Optional[Dict[str, Any]]: +@overload +def validate_max_metadata_length(v: MM) -> MM: + ... + + +@overload +def validate_max_metadata_length(v: None) -> None: + ... + + +def validate_max_metadata_length(v: Optional[MM]) -> Optional[MM]: max_metadata_length = 500 - if not isinstance(v, dict): + if v is None: return v for key in v.keys(): if len(str(v[key])) > max_metadata_length: @@ -677,79 +538,17 @@ def validate_max_metadata_length( return v -### DOCKER SCHEMA VALIDATORS ### - - -def validate_registry_url(value: Optional[str]) -> Optional[str]: - if isinstance(value, str): - if "://" not in value: - return "https://" + value - return value - - -def convert_labels_to_docker_format(labels: Dict[str, str]) -> Dict[str, str]: - labels = labels or {} - new_labels = {} - for name, value in labels.items(): - if "/" in name: - namespace, key = name.split("/", maxsplit=1) - new_namespace = ".".join(reversed(namespace.split("."))) - new_labels[f"{new_namespace}.{key}"] = value - else: - new_labels[name] = value - return new_labels - - -def check_volume_format(volumes: List[str]) -> List[str]: - for volume in volumes: - if ":" not in volume: - raise ValueError( - "Invalid volume specification. " - f"Expected format 'path:container_path', but got {volume!r}" - ) - - return volumes - - -def base_image_xor_dockerfile(values: Mapping[str, Any]): - if values.get("base_image") and values.get("dockerfile"): - raise ValueError( - "Either `base_image` or `dockerfile` should be provided, but not both" - ) - return values - - -### SETTINGS SCHEMA VALIDATORS ### - +### TASK RUN SCHEMA VALIDATORS ### -def validate_settings(value: dict) -> dict: - from prefect.settings import Setting, Settings - from prefect.settings.legacy import _get_settings_fields - if value is None: - return value +@overload +def validate_cache_key_length(cache_key: str) -> str: + ... - # Cast string setting names to variables - validated = {} - for setting, val in value.items(): - settings_fields = _get_settings_fields(Settings) - if isinstance(setting, str) and setting in settings_fields: - validated[settings_fields[setting]] = val - elif isinstance(setting, Setting): - validated[setting] = val - else: - warnings.warn(f"Setting {setting!r} is not recognized and will be ignored.") - return validated - - -def validate_yaml(value: Union[str, dict]) -> dict: - if isinstance(value, str): - return yaml.safe_load(value) - return value - - -### TASK RUN SCHEMA VALIDATORS ### +@overload +def validate_cache_key_length(cache_key: None) -> None: + ... def validate_cache_key_length(cache_key: Optional[str]) -> Optional[str]: @@ -765,7 +564,7 @@ def validate_cache_key_length(cache_key: Optional[str]) -> Optional[str]: return cache_key -def set_run_policy_deprecated_fields(values: dict) -> dict: +def set_run_policy_deprecated_fields(values: MM) -> MM: """ If deprecated fields are provided, populate the corresponding new fields to preserve orchestration behavior. @@ -785,6 +584,16 @@ def set_run_policy_deprecated_fields(values: dict) -> dict: ### PYTHON ENVIRONMENT SCHEMA VALIDATORS ### +@overload +def return_v_or_none(v: str) -> str: + ... + + +@overload +def return_v_or_none(v: None) -> None: + ... + + def return_v_or_none(v: Optional[str]) -> Optional[str]: """Make sure that empty strings are treated as None""" if not v: @@ -795,7 +604,7 @@ def return_v_or_none(v: Optional[str]) -> Optional[str]: ### BLOCK SCHEMA VALIDATORS ### -def validate_parent_and_ref_diff(values: dict) -> dict: +def validate_parent_and_ref_diff(values: M) -> M: parent_id = values.get("parent_block_document_id") ref_id = values.get("reference_block_document_id") if parent_id and ref_id and parent_id == ref_id: @@ -806,7 +615,7 @@ def validate_parent_and_ref_diff(values: dict) -> dict: return values -def validate_name_present_on_nonanonymous_blocks(values: dict) -> dict: +def validate_name_present_on_nonanonymous_blocks(values: M) -> M: # anonymous blocks may have no name prior to actually being # stored in the database if not values.get("is_anonymous") and not values.get("name"): @@ -817,9 +626,19 @@ def validate_name_present_on_nonanonymous_blocks(values: dict) -> dict: ### PROCESS JOB CONFIGURATION VALIDATORS ### +@overload def validate_command(v: str) -> Path: + ... + + +@overload +def validate_command(v: None) -> None: + ... + + +def validate_command(v: Optional[str]) -> Optional[Path]: """Make sure that the working directory is formatted for the current platform.""" - if v: + if v is not None: return relative_path_to_current_platform(v) return v @@ -830,23 +649,43 @@ def validate_command(v: str) -> Path: # catch-all for validators until we organize these into files -def validate_block_document_name(value): +@overload +def validate_block_document_name(value: str) -> str: + ... + + +@overload +def validate_block_document_name(value: None) -> None: + ... + + +def validate_block_document_name(value: Optional[str]) -> Optional[str]: if value is not None: raise_on_name_alphanumeric_dashes_only(value, field_name="Block document name") return value -def validate_artifact_key(value): +def validate_artifact_key(value: str) -> str: raise_on_name_alphanumeric_dashes_only(value, field_name="Artifact key") return value -def validate_variable_name(value): +@overload +def validate_variable_name(value: str) -> str: + ... + + +@overload +def validate_variable_name(value: None) -> None: + ... + + +def validate_variable_name(value: Optional[str]) -> Optional[str]: if value is not None: raise_on_name_alphanumeric_underscores_only(value, field_name="Variable name") return value -def validate_block_type_slug(value): +def validate_block_type_slug(value: str): raise_on_name_alphanumeric_dashes_only(value, field_name="Block type slug") return value diff --git a/src/prefect/serializers.py b/src/prefect/serializers.py index bbd188cd5907..02e4b6f1801d 100644 --- a/src/prefect/serializers.py +++ b/src/prefect/serializers.py @@ -11,9 +11,8 @@ bytes to an object respectively. """ -import abc import base64 -from typing import Any, Generic, Optional, Type, Union +from typing import Any, ClassVar, Generic, Optional, Union, overload from pydantic import ( BaseModel, @@ -54,7 +53,7 @@ def prefect_json_object_encoder(obj: Any) -> Any: } -def prefect_json_object_decoder(result: dict[str, Any]): +def prefect_json_object_decoder(result: dict[str, Any]) -> Any: """ `JSONDecoder.object_hook` for decoding objects from JSON when previously encoded with `prefect_json_object_encoder` @@ -70,7 +69,7 @@ def prefect_json_object_decoder(result: dict[str, Any]): @register_base_type -class Serializer(BaseModel, Generic[D], abc.ABC): +class Serializer(BaseModel, Generic[D]): """ A serializer that can encode objects of type 'D' into bytes. """ @@ -80,10 +79,18 @@ def __init__(self, **data: Any) -> None: data.setdefault("type", type_string) super().__init__(**data) - def __new__(cls: Type[Self], **kwargs: Any) -> Self: - if "type" in kwargs: + @overload + def __new__(cls, *, type: str, **kwargs: Any) -> "Serializer[Any]": + ... + + @overload + def __new__(cls, *, type: None = ..., **kwargs: Any) -> Self: + ... + + def __new__(cls, **kwargs: Any) -> Union[Self, "Serializer[Any]"]: + if type_ := kwargs.get("type"): try: - subcls = lookup_type(cls, dispatch_key=kwargs["type"]) + subcls = lookup_type(cls, dispatch_key=type_) except KeyError as exc: raise ValidationError.from_exception_data( title=cls.__name__, @@ -97,15 +104,15 @@ def __new__(cls: Type[Self], **kwargs: Any) -> Self: type: str - @abc.abstractmethod def dumps(self, obj: D) -> bytes: """Encode the object into a blob of bytes.""" + raise NotImplementedError - @abc.abstractmethod def loads(self, blob: bytes) -> D: """Decode the blob of bytes into an object.""" + raise NotImplementedError - model_config = ConfigDict(extra="forbid") + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") @classmethod def __dispatch_key__(cls) -> Optional[str]: @@ -113,7 +120,7 @@ def __dispatch_key__(cls) -> Optional[str]: return type_str if isinstance(type_str, str) else None -class PickleSerializer(Serializer): +class PickleSerializer(Serializer[D]): """ Serializes objects using the pickle protocol. @@ -132,17 +139,17 @@ class PickleSerializer(Serializer): def check_picklelib(cls, value: str) -> str: return validate_picklelib(value) - def dumps(self, obj: Any) -> bytes: + def dumps(self, obj: D) -> bytes: pickler = from_qualified_name(self.picklelib) blob = pickler.dumps(obj) return base64.encodebytes(blob) - def loads(self, blob: bytes) -> Any: + def loads(self, blob: bytes) -> D: pickler = from_qualified_name(self.picklelib) return pickler.loads(base64.decodebytes(blob)) -class JSONSerializer(Serializer): +class JSONSerializer(Serializer[D]): """ Serializes data to JSON. @@ -186,7 +193,7 @@ def loads_kwargs_cannot_contain_object_hook( ) -> dict[str, Any]: return validate_load_kwargs(value) - def dumps(self, obj: Any) -> bytes: + def dumps(self, obj: D) -> bytes: json = from_qualified_name(self.jsonlib) kwargs = self.dumps_kwargs.copy() if self.object_encoder: @@ -197,7 +204,7 @@ def dumps(self, obj: Any) -> bytes: result = result.encode() return result - def loads(self, blob: bytes) -> Any: + def loads(self, blob: bytes) -> D: json = from_qualified_name(self.jsonlib) kwargs = self.loads_kwargs.copy() if self.object_decoder: @@ -205,7 +212,7 @@ def loads(self, blob: bytes) -> Any: return json.loads(blob.decode(), **kwargs) -class CompressedSerializer(Serializer): +class CompressedSerializer(Serializer[D]): """ Wraps another serializer, compressing its output. Uses `lzma` by default. See `compressionlib` for using alternative libraries. @@ -219,43 +226,43 @@ class CompressedSerializer(Serializer): type: str = Field(default="compressed", frozen=True) - serializer: Serializer + serializer: Serializer[D] compressionlib: str = "lzma" @field_validator("serializer", mode="before") - def validate_serializer(cls, value: Union[str, Serializer]) -> Serializer: + def validate_serializer(cls, value: Union[str, Serializer[D]]) -> Serializer[D]: return cast_type_names_to_serializers(value) @field_validator("compressionlib") def check_compressionlib(cls, value: str) -> str: return validate_compressionlib(value) - def dumps(self, obj: Any) -> bytes: + def dumps(self, obj: D) -> bytes: blob = self.serializer.dumps(obj) compressor = from_qualified_name(self.compressionlib) return base64.encodebytes(compressor.compress(blob)) - def loads(self, blob: bytes) -> Any: + def loads(self, blob: bytes) -> D: compressor = from_qualified_name(self.compressionlib) uncompressed = compressor.decompress(base64.decodebytes(blob)) return self.serializer.loads(uncompressed) -class CompressedPickleSerializer(CompressedSerializer): +class CompressedPickleSerializer(CompressedSerializer[D]): """ A compressed serializer preconfigured to use the pickle serializer. """ type: str = Field(default="compressed/pickle", frozen=True) - serializer: Serializer = Field(default_factory=PickleSerializer) + serializer: Serializer[D] = Field(default_factory=PickleSerializer) -class CompressedJSONSerializer(CompressedSerializer): +class CompressedJSONSerializer(CompressedSerializer[D]): """ A compressed serializer preconfigured to use the json serializer. """ type: str = Field(default="compressed/json", frozen=True) - serializer: Serializer = Field(default_factory=JSONSerializer) + serializer: Serializer[D] = Field(default_factory=JSONSerializer) diff --git a/src/prefect/server/events/messaging.py b/src/prefect/server/events/messaging.py index 1e09fb275043..a71841b2383a 100644 --- a/src/prefect/server/events/messaging.py +++ b/src/prefect/server/events/messaging.py @@ -1,4 +1,4 @@ -from typing import Iterable +from collections.abc import Iterable from prefect.logging import get_logger from prefect.server.events.schemas.events import ReceivedEvent diff --git a/src/prefect/server/utilities/encryption.py b/src/prefect/server/utilities/encryption.py index 31a23af013d4..4296971b7e47 100644 --- a/src/prefect/server/utilities/encryption.py +++ b/src/prefect/server/utilities/encryption.py @@ -4,13 +4,16 @@ import json import os +from collections.abc import Mapping +from typing import Any from cryptography.fernet import Fernet +from sqlalchemy.ext.asyncio import AsyncSession from prefect.server import schemas -async def get_fernet_encryption(session): +async def _get_fernet_encryption(session: AsyncSession) -> Fernet: from prefect.server.models import configuration environment_key = os.getenv( @@ -34,13 +37,13 @@ async def get_fernet_encryption(session): return Fernet(encryption_key) -async def encrypt_fernet(session, data: dict): - fernet = await get_fernet_encryption(session) +async def encrypt_fernet(session: AsyncSession, data: Mapping[str, Any]) -> str: + fernet = await _get_fernet_encryption(session) byte_blob = json.dumps(data).encode() return fernet.encrypt(byte_blob).decode() -async def decrypt_fernet(session, data: dict): - fernet = await get_fernet_encryption(session) +async def decrypt_fernet(session: AsyncSession, data: str) -> dict[str, Any]: + fernet = await _get_fernet_encryption(session) byte_blob = data.encode() return json.loads(fernet.decrypt(byte_blob).decode()) diff --git a/src/prefect/server/utilities/messaging/__init__.py b/src/prefect/server/utilities/messaging/__init__.py index 476fe7a6ab08..7696a2218dbc 100644 --- a/src/prefect/server/utilities/messaging/__init__.py +++ b/src/prefect/server/utilities/messaging/__init__.py @@ -1,23 +1,9 @@ import abc -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, AbstractAsyncContextManager from dataclasses import dataclass import importlib -from typing import ( - Any, - AsyncContextManager, - AsyncGenerator, - Awaitable, - Callable, - Dict, - List, - Optional, - Protocol, - Type, - TypeVar, - Union, - runtime_checkable, -) -from typing_extensions import Self +from typing import Any, Callable, Optional, Protocol, TypeVar, Union, runtime_checkable +from collections.abc import AsyncGenerator, Awaitable, Iterable, Mapping from prefect.settings import PREFECT_MESSAGING_CACHE, PREFECT_MESSAGING_BROKER from prefect.logging import get_logger @@ -25,16 +11,21 @@ logger = get_logger(__name__) +M = TypeVar("M", bound="Message", covariant=True) + + class Message(Protocol): """ A protocol representing a message sent to a message broker. """ - data: Union[bytes, str] - attributes: Dict[str, Any] - + @property + def data(self) -> Union[str, bytes]: + ... -M = TypeVar("M", bound=Message) + @property + def attributes(self) -> Mapping[str, Any]: + ... class Cache(abc.ABC): @@ -43,36 +34,40 @@ async def clear_recently_seen_messages(self) -> None: ... @abc.abstractmethod - async def without_duplicates(self, attribute: str, messages: List[M]) -> List[M]: + async def without_duplicates( + self, attribute: str, messages: Iterable[M] + ) -> list[M]: ... @abc.abstractmethod - async def forget_duplicates(self, attribute: str, messages: List[M]) -> None: + async def forget_duplicates( + self, attribute: str, messages: Iterable[Message] + ) -> None: ... -class Publisher(abc.ABC): - @abc.abstractmethod - async def __aenter__(self) -> Self: - ... - - @abc.abstractmethod - async def __aexit__(self, exc_type, exc_value, traceback): +class Publisher(AbstractAsyncContextManager["Publisher"], abc.ABC): + def __init__( + self, + topic: str, + cache: Optional[Cache] = None, + deduplicate_by: Optional[str] = None, + ) -> None: ... @abc.abstractmethod - async def publish_data(self, data: bytes, attributes: Dict[str, str]): + async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None: ... @dataclass class CapturedMessage: data: bytes - attributes: Dict[str, str] + attributes: Mapping[str, str] class CapturingPublisher(Publisher): - messages: List[CapturedMessage] = [] + messages: list[CapturedMessage] = [] deduplicate_by: Optional[str] def __init__( @@ -85,13 +80,10 @@ def __init__( self.cache = cache or create_cache() self.deduplicate_by = deduplicate_by - async def __aenter__(self) -> Self: - return self - - async def __aexit__(self, exc_type, exc_value, traceback): + async def __aexit__(self, *args: Any) -> None: pass - async def publish_data(self, data: bytes, attributes: Dict[str, str]): + async def publish_data(self, data: bytes, attributes: Mapping[str, str]): to_publish = [CapturedMessage(data, attributes)] if self.deduplicate_by: @@ -120,6 +112,9 @@ class Consumer(abc.ABC): call a handler function for each message received. """ + def __init__(self, topic: str, **kwargs: Any) -> None: + self.topic = topic + @abc.abstractmethod async def run(self, handler: MessageHandler) -> None: """Runs the consumer (indefinitely)""" @@ -128,7 +123,7 @@ async def run(self, handler: MessageHandler) -> None: @runtime_checkable class CacheModule(Protocol): - Cache: Type[Cache] + Cache: type[Cache] def create_cache() -> Cache: @@ -144,13 +139,15 @@ def create_cache() -> Cache: @runtime_checkable class BrokerModule(Protocol): - Publisher: Type[Publisher] - Consumer: Type[Consumer] - ephemeral_subscription: Callable[[str], AsyncGenerator[Dict[str, Any], None]] + Publisher: type[Publisher] + Consumer: type[Consumer] + ephemeral_subscription: Callable[ + [str], AbstractAsyncContextManager[Mapping[str, Any]] + ] # Used for testing: a context manager that breaks the topic in a way that raises # a ValueError("oops") when attempting to publish a message. - break_topic: Callable[[], AsyncContextManager[None]] + break_topic: Callable[[], AbstractAsyncContextManager[None]] def create_publisher( @@ -171,7 +168,7 @@ def create_publisher( @asynccontextmanager -async def ephemeral_subscription(topic: str) -> AsyncGenerator[Dict[str, Any], None]: +async def ephemeral_subscription(topic: str) -> AsyncGenerator[Mapping[str, Any], Any]: """ Creates an ephemeral subscription to the given source, removing it when the context exits. @@ -182,7 +179,7 @@ async def ephemeral_subscription(topic: str) -> AsyncGenerator[Dict[str, Any], N yield consumer_create_kwargs -def create_consumer(topic: str, **kwargs) -> Consumer: +def create_consumer(topic: str, **kwargs: Any) -> Consumer: """ Creates a new consumer with the applications default settings. Args: diff --git a/src/prefect/server/utilities/messaging/memory.py b/src/prefect/server/utilities/messaging/memory.py index e971dca00b46..95934375a7d5 100644 --- a/src/prefect/server/utilities/messaging/memory.py +++ b/src/prefect/server/utilities/messaging/memory.py @@ -1,25 +1,16 @@ import asyncio import copy +from collections.abc import AsyncGenerator, Iterable, Mapping, MutableMapping from contextlib import asynccontextmanager from dataclasses import asdict, dataclass from datetime import timedelta from pathlib import Path -from typing import ( - Any, - AsyncGenerator, - Dict, - List, - MutableMapping, - Optional, - TypeVar, - Union, -) +from typing import Any, Optional, TypeVar, Union from uuid import uuid4 import anyio from cachetools import TTLCache from pydantic_core import to_json -from typing_extensions import Self from prefect.logging import get_logger from prefect.server.utilities.messaging import Cache as _Cache @@ -34,7 +25,7 @@ @dataclass class MemoryMessage: data: Union[bytes, str] - attributes: Dict[str, Any] + attributes: Mapping[str, Any] retry_count: int = 0 @@ -71,8 +62,8 @@ def __init__( if dead_letter_queue_path else get_current_settings().home / "dlq" ) - self._queue = asyncio.Queue() - self._retry = asyncio.Queue() + self._queue: asyncio.Queue[MemoryMessage] = asyncio.Queue() + self._retry: asyncio.Queue[MemoryMessage] = asyncio.Queue() async def deliver(self, message: MemoryMessage) -> None: """ @@ -108,7 +99,7 @@ async def get(self) -> MemoryMessage: """ Get a message from the subscription's queue. """ - if self._retry.qsize() > 0: + if not self._retry.empty(): return await self._retry.get() return await self._queue.get() @@ -132,17 +123,17 @@ async def send_to_dead_letter_queue(self, message: MemoryMessage) -> None: class Topic: - _topics: Dict[str, "Topic"] = {} + _topics: dict[str, "Topic"] = {} name: str - _subscriptions: List[Subscription] + _subscriptions: list[Subscription] def __init__(self, name: str) -> None: self.name = name self._subscriptions = [] @classmethod - def by_name(cls, name: str) -> Self: + def by_name(cls, name: str) -> "Topic": try: return cls._topics[name] except KeyError: @@ -151,7 +142,7 @@ def by_name(cls, name: str) -> Self: return topic @classmethod - def clear_all(cls): + def clear_all(cls) -> None: for topic in cls._topics.values(): topic.clear() cls._topics = {} @@ -200,12 +191,14 @@ class Cache(_Cache): async def clear_recently_seen_messages(self) -> None: self._recently_seen_messages.clear() - async def without_duplicates(self, attribute: str, messages: List[M]) -> List[M]: - messages_with_attribute = [] - messages_without_attribute = [] + async def without_duplicates( + self, attribute: str, messages: Iterable[M] + ) -> list[M]: + messages_with_attribute: list[M] = [] + messages_without_attribute: list[M] = [] for m in messages: - if m.attributes is None or attribute not in m.attributes: + if not m.attributes or attribute not in m.attributes: logger.warning( "Message is missing deduplication attribute %r", attribute, @@ -222,9 +215,9 @@ async def without_duplicates(self, attribute: str, messages: List[M]) -> List[M] return messages_with_attribute + messages_without_attribute - async def forget_duplicates(self, attribute: str, messages: List[M]) -> None: + async def forget_duplicates(self, attribute: str, messages: Iterable[M]) -> None: for m in messages: - if m.attributes is None or attribute not in m.attributes: + if not m.attributes or attribute not in m.attributes: logger.warning( "Message is missing deduplication attribute %r", attribute, @@ -240,13 +233,10 @@ def __init__(self, topic: str, cache: Cache, deduplicate_by: Optional[str] = Non self.deduplicate_by = deduplicate_by self._cache = cache - async def __aenter__(self) -> Self: - return self - - async def __aexit__(self, exc_type, exc_value, traceback): + async def __aexit__(self, *args: Any) -> None: return None - async def publish_data(self, data: bytes, attributes: Dict[str, str]): + async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None: to_publish = [MemoryMessage(data, attributes)] if self.deduplicate_by: to_publish = await self._cache.without_duplicates( @@ -284,7 +274,7 @@ async def run(self, handler: MessageHandler) -> None: @asynccontextmanager -async def ephemeral_subscription(topic: str) -> AsyncGenerator[Dict[str, Any], None]: +async def ephemeral_subscription(topic: str) -> AsyncGenerator[Mapping[str, Any], None]: subscription = Topic.by_name(topic).subscribe() try: yield {"topic": topic, "subscription": subscription} diff --git a/src/prefect/server/utilities/server.py b/src/prefect/server/utilities/server.py index 1b9fe107b53a..03bf81c82f8f 100644 --- a/src/prefect/server/utilities/server.py +++ b/src/prefect/server/utilities/server.py @@ -2,24 +2,26 @@ Utilities for the Prefect REST API server. """ +from collections.abc import Coroutine, Sequence from contextlib import AsyncExitStack -from typing import Any, Callable, Coroutine, Sequence, Set, get_type_hints +from typing import TYPE_CHECKING, Any, Callable, get_type_hints from fastapi import APIRouter, Request, Response, status -from fastapi.routing import APIRoute, BaseRoute +from fastapi.routing import APIRoute +from starlette.routing import BaseRoute from starlette.routing import Route as StarletteRoute -def method_paths_from_routes(routes: Sequence[BaseRoute]) -> Set[str]: +def method_paths_from_routes(routes: Sequence[BaseRoute]) -> set[str]: """ Generate a set of strings describing the given routes in the format: For example, "GET /logs/" """ - method_paths = set() + method_paths: set[str] = set() for route in routes: if isinstance(route, (APIRoute, StarletteRoute)): - for method in route.methods: + for method in route.methods or (): method_paths.add(f"{method} {route.path}") return method_paths @@ -42,10 +44,13 @@ def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response] async def handle_response_scoped_depends(request: Request) -> Response: # Create a new stack scoped to exit before the response is returned + response = None async with AsyncExitStack() as stack: request.state.response_scoped_stack = stack response = await default_handler(request) + if TYPE_CHECKING: + assert response is not None return response return handle_response_scoped_depends diff --git a/src/prefect/server/utilities/subscriptions.py b/src/prefect/server/utilities/subscriptions.py index 578b8d95a5dd..8d4616ec9730 100644 --- a/src/prefect/server/utilities/subscriptions.py +++ b/src/prefect/server/utilities/subscriptions.py @@ -1,13 +1,8 @@ import asyncio from typing import Optional -from fastapi import ( - WebSocket, -) -from starlette.status import ( - WS_1002_PROTOCOL_ERROR, - WS_1008_POLICY_VIOLATION, -) +from fastapi import WebSocket +from starlette.status import WS_1002_PROTOCOL_ERROR, WS_1008_POLICY_VIOLATION from starlette.websockets import WebSocketDisconnect from websockets.exceptions import ConnectionClosed diff --git a/src/prefect/server/utilities/user_templates.py b/src/prefect/server/utilities/user_templates.py index aec6c63ee511..8c721789f0cb 100644 --- a/src/prefect/server/utilities/user_templates.py +++ b/src/prefect/server/utilities/user_templates.py @@ -1,6 +1,6 @@ """Utilities to support safely rendering user-supplied templates""" -from typing import Any, Dict, List, Optional, Set +from typing import Any, Optional import jinja2.sandbox from jinja2 import ChainableUndefined, nodes @@ -46,7 +46,7 @@ def __init__(self, message: Optional[str] = None, line_number: int = 0) -> None: super().__init__(message) -def register_user_template_filters(filters: Dict[str, Any]): +def register_user_template_filters(filters: dict[str, Any]): """Register additional filters that will be available to user templates""" _template_environment.filters.update(filters) @@ -89,8 +89,8 @@ def _nested_loop_depth(node: nodes.Node, depth: int = 0) -> int: return max(_nested_loop_depth(child, depth) for child in children) -def matching_types_in_templates(templates: List[str], types: Set[str]) -> List[str]: - found = set() +def matching_types_in_templates(templates: list[str], types: set[str]) -> list[str]: + found: set[str] = set() for template in templates: root_node = _template_environment.parse(template) @@ -105,7 +105,7 @@ def maybe_template(possible: str) -> bool: return "{{" in possible or "{%" in possible -async def render_user_template(template: str, context: Dict[str, Any]) -> str: +async def render_user_template(template: str, context: dict[str, Any]) -> str: if not maybe_template(template): return template @@ -120,7 +120,7 @@ async def render_user_template(template: str, context: Dict[str, Any]) -> str: ) + template -def render_user_template_sync(template: str, context: Dict[str, Any]) -> str: +def render_user_template_sync(template: str, context: dict[str, Any]) -> str: if not maybe_template(template): return template diff --git a/src/prefect/telemetry/run_telemetry.py b/src/prefect/telemetry/run_telemetry.py index d76f7c57e576..fe2a7bfeff94 100644 --- a/src/prefect/telemetry/run_telemetry.py +++ b/src/prefect/telemetry/run_telemetry.py @@ -152,7 +152,7 @@ def _trace_context_from_labels( return propagate.extract(carrier) def _traceparent_from_span(self, span: Span) -> Optional[str]: - carrier = {} + carrier: dict[str, Any] = {} propagate.inject(carrier, context=trace.set_span_in_context(span)) return carrier.get(TRACEPARENT_KEY) diff --git a/tests/utilities/test_validation.py b/tests/_internal/schemas/test_validation.py similarity index 91% rename from tests/utilities/test_validation.py rename to tests/_internal/schemas/test_validation.py index 1a47a8fd5c64..8a1f5f369932 100644 --- a/tests/utilities/test_validation.py +++ b/tests/_internal/schemas/test_validation.py @@ -1,7 +1,7 @@ import pytest from prefect._internal.schemas.validators import ( - validate_schema, + validate_parameter_openapi_schema, validate_values_conform_to_schema, ) @@ -15,13 +15,13 @@ def test_validate_schema_with_valid_schema(): "required": ["name"], } # Should not raise any exception - validate_schema(schema) + validate_parameter_openapi_schema(schema, {"enforce_parameter_schema": True}) def test_validate_schema_with_invalid_schema(): schema = {"type": "object", "properties": {"name": {"type": "nonexistenttype"}}} with pytest.raises(ValueError) as excinfo: - validate_schema(schema) + validate_parameter_openapi_schema(schema, {"enforce_parameter_schema": True}) assert "The provided schema is not a valid json schema." in str(excinfo.value) assert ( "Schema error: 'nonexistenttype' is not valid under any of the given schemas" @@ -31,7 +31,7 @@ def test_validate_schema_with_invalid_schema(): def test_validate_schema_with_none_schema(): # Should not raise any exception - validate_schema(None) + validate_parameter_openapi_schema(None, {"enforce_parameter_schema": True}) # Tests for validate_values_conform_to_schema function