From 2b1ba99dbb51556f01a18d1457396cd44a402737 Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Fri, 20 Dec 2024 16:56:11 +0000 Subject: [PATCH] [typing] misc prefect modules A few more modules to pass strict typing checks. --- src/prefect/artifacts.py | 182 ++++++++++++-------------- src/prefect/automations.py | 34 ++++- src/prefect/context.py | 43 +++--- src/prefect/plugins.py | 22 ++-- tests/typesafety/test_automations.yml | 19 +++ tests/typesafety/test_flows.yml | 2 +- 6 files changed, 161 insertions(+), 141 deletions(-) create mode 100644 tests/typesafety/test_automations.yml diff --git a/src/prefect/artifacts.py b/src/prefect/artifacts.py index 785effbe92216..98249679e6c66 100644 --- a/src/prefect/artifacts.py +++ b/src/prefect/artifacts.py @@ -4,17 +4,18 @@ from __future__ import annotations +import asyncio import json # noqa: I001 import math import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any from uuid import UUID from prefect.client.schemas.actions import ArtifactCreate as ArtifactRequest from prefect.client.schemas.actions import ArtifactUpdate from prefect.client.schemas.filters import ArtifactFilter, ArtifactFilterKey from prefect.client.schemas.sorting import ArtifactSort -from prefect.client.utilities import get_or_create_client, inject_client +from prefect.client.utilities import get_or_create_client from prefect.logging.loggers import get_logger from prefect.utilities.asyncutils import sync_compatible from prefect.utilities.context import get_task_and_flow_run_ids @@ -43,9 +44,9 @@ class Artifact(ArtifactRequest): @sync_compatible async def create( - self: "Self", - client: Optional["PrefectClient"] = None, - ) -> "ArtifactResponse": + self: Self, + client: PrefectClient | None = None, + ) -> ArtifactResponse: """ A method to create an artifact. @@ -82,8 +83,8 @@ async def create( @classmethod @sync_compatible async def get( - cls, key: Optional[str] = None, client: Optional["PrefectClient"] = None - ) -> Optional["ArtifactResponse"]: + cls, key: str | None = None, client: PrefectClient | None = None + ) -> ArtifactResponse | None: """ A method to get an artifact. @@ -95,27 +96,26 @@ async def get( (ArtifactResponse, optional): The artifact (if found). """ client, _ = get_or_create_client(client) - return next( - iter( - await client.read_artifacts( - limit=1, - sort=ArtifactSort.UPDATED_DESC, - artifact_filter=ArtifactFilter(key=ArtifactFilterKey(any_=[key])), - ) + filter_key_value = None if key is None else [key] + artifacts = await client.read_artifacts( + limit=1, + sort=ArtifactSort.UPDATED_DESC, + artifact_filter=ArtifactFilter( + key=ArtifactFilterKey(any_=filter_key_value) ), - None, ) + return None if not artifacts else artifacts[0] @classmethod @sync_compatible async def get_or_create( cls, - key: Optional[str] = None, - description: Optional[str] = None, - data: Optional[Union[Dict[str, Any], Any]] = None, - client: Optional["PrefectClient"] = None, + key: str | None = None, + description: str | None = None, + data: dict[str, Any] | Any | None = None, + client: PrefectClient | None = None, **kwargs: Any, - ) -> Tuple["ArtifactResponse", bool]: + ) -> tuple[ArtifactResponse, bool]: """ A method to get or create an artifact. @@ -128,25 +128,27 @@ async def get_or_create( Returns: (ArtifactResponse): The artifact, either retrieved or created. """ - artifact = await cls.get(key, client) + artifact_coro = cls.get(key, client) + if TYPE_CHECKING: + assert asyncio.iscoroutine(artifact_coro) + artifact = await artifact_coro if artifact: return artifact, False - else: - return ( - await cls(key=key, description=description, data=data, **kwargs).create( - client - ), - True, - ) - async def format(self) -> Optional[Union[Dict[str, Any], Any]]: + new_artifact = cls(key=key, description=description, data=data, **kwargs) + create_coro = new_artifact.create(client) + if TYPE_CHECKING: + assert asyncio.iscoroutine(create_coro) + return await create_coro, True + + async def format(self) -> dict[str, Any] | Any | None: return json.dumps(self.data) class LinkArtifact(Artifact): link: str - link_text: Optional[str] = None - type: Optional[str] = "markdown" + link_text: str | None = None + type: str | None = "markdown" async def format(self) -> str: return ( @@ -158,20 +160,20 @@ async def format(self) -> str: class MarkdownArtifact(Artifact): markdown: str - type: Optional[str] = "markdown" + type: str | None = "markdown" async def format(self) -> str: return self.markdown class TableArtifact(Artifact): - table: Union[Dict[str, List[Any]], List[Dict[str, Any]], List[List[Any]]] - type: Optional[str] = "table" + table: dict[str, list[Any]] | list[dict[str, Any]] | list[list[Any]] + type: str | None = "table" @classmethod def _sanitize( - cls, item: Union[Dict[str, Any], List[Any], float] - ) -> Union[Dict[str, Any], List[Any], int, float, None]: + cls, item: dict[str, Any] | list[Any] | float + ) -> dict[str, Any] | list[Any] | int | float | None: """ Sanitize NaN values in a given item. The item can be a dict, list or float. @@ -191,7 +193,7 @@ async def format(self) -> str: class ProgressArtifact(Artifact): progress: float - type: Optional[str] = "progress" + type: str | None = "progress" async def format(self) -> float: # Ensure progress is between 0 and 100 @@ -216,7 +218,7 @@ class ImageArtifact(Artifact): """ image_url: str - type: Optional[str] = "image" + type: str | None = "image" async def format(self) -> str: """ @@ -230,46 +232,13 @@ async def format(self) -> str: return self.image_url -@inject_client -async def _create_artifact( - type: str, - key: Optional[str] = None, - description: Optional[str] = None, - data: Optional[Union[Dict[str, Any], Any]] = None, - client: Optional["PrefectClient"] = None, -) -> UUID: - """ - Helper function to create an artifact. - - Arguments: - type: A string identifying the type of artifact. - key: A user-provided string identifier. - The key must only contain lowercase letters, numbers, and dashes. - description: A user-specified description of the artifact. - data: A JSON payload that allows for a result to be retrieved. - client: The PrefectClient - - Returns: - - The table artifact ID. - """ - - artifact = await Artifact( - type=type, - key=key, - description=description, - data=data, - ).create(client) - - return artifact.id - - @sync_compatible async def create_link_artifact( link: str, - link_text: Optional[str] = None, - key: Optional[str] = None, - description: Optional[str] = None, - client: Optional["PrefectClient"] = None, + link_text: str | None = None, + key: str | None = None, + description: str | None = None, + client: PrefectClient | None = None, ) -> UUID: """ Create a link artifact. @@ -286,12 +255,16 @@ async def create_link_artifact( Returns: The table artifact ID. """ - artifact = await LinkArtifact( + new_artifact = LinkArtifact( key=key, description=description, link=link, link_text=link_text, - ).create(client) + ) + create_coro = new_artifact.create(client) + if TYPE_CHECKING: + assert asyncio.iscoroutine(create_coro) + artifact = await create_coro return artifact.id @@ -299,8 +272,9 @@ async def create_link_artifact( @sync_compatible async def create_markdown_artifact( markdown: str, - key: Optional[str] = None, - description: Optional[str] = None, + key: str | None = None, + description: str | None = None, + client: PrefectClient | None = None, ) -> UUID: """ Create a markdown artifact. @@ -315,20 +289,24 @@ async def create_markdown_artifact( Returns: The table artifact ID. """ - artifact = await MarkdownArtifact( + new_artifact = MarkdownArtifact( key=key, description=description, markdown=markdown, - ).create() + ) + create_coro = new_artifact.create() + if TYPE_CHECKING: + assert asyncio.iscoroutine(create_coro) + artifact = await create_coro return artifact.id @sync_compatible async def create_table_artifact( - table: Union[Dict[str, List[Any]], List[Dict[str, Any]], List[List[Any]]], - key: Optional[str] = None, - description: Optional[str] = None, + table: dict[str, list[Any]] | list[dict[str, Any]] | list[list[Any]], + key: str | None = None, + description: str | None = None, ) -> UUID: """ Create a table artifact. @@ -344,11 +322,15 @@ async def create_table_artifact( The table artifact ID. """ - artifact = await TableArtifact( + new_artifact = TableArtifact( key=key, description=description, table=table, - ).create() + ) + create_coro = new_artifact.create() + if TYPE_CHECKING: + assert asyncio.iscoroutine(create_coro) + artifact = await create_coro return artifact.id @@ -356,8 +338,8 @@ async def create_table_artifact( @sync_compatible async def create_progress_artifact( progress: float, - key: Optional[str] = None, - description: Optional[str] = None, + key: str | None = None, + description: str | None = None, ) -> UUID: """ Create a progress artifact. @@ -373,11 +355,15 @@ async def create_progress_artifact( The progress artifact ID. """ - artifact = await ProgressArtifact( + new_artifact = ProgressArtifact( key=key, description=description, progress=progress, - ).create() + ) + create_coro = new_artifact.create() + if TYPE_CHECKING: + assert asyncio.iscoroutine(create_coro) + artifact = await create_coro return artifact.id @@ -386,8 +372,8 @@ async def create_progress_artifact( async def update_progress_artifact( artifact_id: UUID, progress: float, - description: Optional[str] = None, - client: Optional[PrefectClient] = None, + description: str | None = None, + client: PrefectClient | None = None, ) -> UUID: """ Update a progress artifact. @@ -427,8 +413,8 @@ async def update_progress_artifact( @sync_compatible async def create_image_artifact( image_url: str, - key: Optional[str] = None, - description: Optional[str] = None, + key: str | None = None, + description: str | None = None, ) -> UUID: """ Create an image artifact. @@ -444,10 +430,14 @@ async def create_image_artifact( The image artifact ID. """ - artifact = await ImageArtifact( + new_artifact = ImageArtifact( key=key, description=description, image_url=image_url, - ).create() + ) + create_coro = new_artifact.create() + if TYPE_CHECKING: + assert asyncio.iscoroutine(create_coro) + artifact = await create_coro return artifact.id diff --git a/src/prefect/automations.py b/src/prefect/automations.py index a37c5a3a45dd0..623e4a066fe12 100644 --- a/src/prefect/automations.py +++ b/src/prefect/automations.py @@ -1,4 +1,4 @@ -from typing import Optional, Type +from typing import TYPE_CHECKING, Optional, overload from uuid import UUID from pydantic import Field @@ -112,17 +112,28 @@ async def update(self: Self): auto.name = "new name" auto.update() """ + assert self.id is not None async with get_client() as client: automation = AutomationCore( **self.model_dump(exclude={"id", "owner_resource"}) ) await client.update_automation(automation_id=self.id, automation=automation) + @overload + @classmethod + async def read(cls, id: UUID, name: Optional[str] = ...) -> Self: + ... + + @overload + @classmethod + async def read(cls, id: None = None, name: str = ...) -> Optional[Self]: + ... + @classmethod @sync_compatible async def read( - cls: Type[Self], id: Optional[UUID] = None, name: Optional[str] = None - ) -> Self: + cls, *, id: Optional[UUID] = None, name: Optional[str] = None + ) -> Optional[Self]: """ Read an automation by ID or name. automation = Automation.read(name="woodchonk") @@ -145,13 +156,13 @@ async def read( raise if automation is None: raise ValueError(f"Automation with ID {id!r} not found") - return Automation(**automation.model_dump()) + return cls(**automation.model_dump()) else: + if TYPE_CHECKING: + assert name is not None automation = await client.read_automations_by_name(name=name) if len(automation) > 0: - return ( - Automation(**automation[0].model_dump()) if automation else None - ) + return cls(**automation[0].model_dump()) if automation else None else: raise ValueError(f"Automation with name {name!r} not found") @@ -161,6 +172,9 @@ async def delete(self: Self) -> bool: auto = Automation.read(id = 123) auto.delete() """ + if self.id is None: + raise ValueError("Can't delete an automation without an id") + async with get_client() as client: try: await client.delete_automation(self.id) @@ -177,6 +191,9 @@ async def disable(self: Self) -> bool: auto = Automation.read(id = 123) auto.disable() """ + if self.id is None: + raise ValueError("Can't disable an automation without an id") + async with get_client() as client: try: await client.pause_automation(self.id) @@ -193,6 +210,9 @@ async def enable(self: Self) -> bool: auto = Automation.read(id = 123) auto.enable() """ + if self.id is None: + raise ValueError("Can't enable an automation without an id") + async with get_client() as client: try: await client.resume_automation(self.id) diff --git a/src/prefect/context.py b/src/prefect/context.py index 287b9b58e1381..7f92ad7528ca5 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -9,21 +9,10 @@ import os import sys import warnings +from collections.abc import AsyncGenerator, Generator, Mapping from contextlib import ExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar, Token -from typing import ( - TYPE_CHECKING, - Any, - AsyncGenerator, - Dict, - Generator, - Mapping, - Optional, - Set, - Type, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from typing_extensions import Self @@ -64,7 +53,7 @@ GLOBAL_SETTINGS_CONTEXT = None # type: ignore -def serialize_context() -> Dict[str, Any]: +def serialize_context() -> dict[str, Any]: """ Serialize the current context for use in a remote execution environment. """ @@ -84,7 +73,7 @@ def serialize_context() -> Dict[str, Any]: @contextmanager def hydrated_context( - serialized_context: Optional[Dict[str, Any]] = None, + serialized_context: Optional[dict[str, Any]] = None, client: Union[PrefectClient, SyncPrefectClient, None] = None, ): with ExitStack() as stack: @@ -148,7 +137,7 @@ def __exit__(self, *_): self._token = None @classmethod - def get(cls: Type[Self]) -> Optional[Self]: + def get(cls: type[Self]) -> Optional[Self]: """Get the current context instance""" return cls.__var__.get(None) @@ -173,7 +162,7 @@ def model_copy( new._token = None return new - def serialize(self, include_secrets: bool = True) -> Dict[str, Any]: + def serialize(self, include_secrets: bool = True) -> dict[str, Any]: """ Serialize the context model to a dictionary that can be pickled with cloudpickle. """ @@ -314,10 +303,10 @@ def __init__(self, *args: Any, **kwargs: Any): start_client_metrics_server() start_time: DateTime = Field(default_factory=lambda: DateTime.now("UTC")) - input_keyset: Optional[Dict[str, Dict[str, str]]] = None + input_keyset: Optional[dict[str, dict[str, str]]] = None client: Union[PrefectClient, SyncPrefectClient] - def serialize(self: Self, include_secrets: bool = True) -> Dict[str, Any]: + def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]: return self.model_dump( include={"start_time", "input_keyset"}, exclude_unset=True, @@ -344,7 +333,7 @@ class EngineContext(RunContext): flow_run: Optional[FlowRun] = None task_runner: TaskRunner[Any] log_prints: bool = False - parameters: Optional[Dict[str, Any]] = None + parameters: Optional[dict[str, Any]] = None # Flag signaling if the flow run context has been serialized and sent # to remote infrastructure. @@ -355,10 +344,10 @@ class EngineContext(RunContext): persist_result: bool = Field(default_factory=get_default_persist_setting) # Counter for task calls allowing unique - task_run_dynamic_keys: Dict[str, Union[str, int]] = Field(default_factory=dict) + task_run_dynamic_keys: dict[str, Union[str, int]] = Field(default_factory=dict) # Counter for flow pauses - observed_flow_pauses: Dict[str, int] = Field(default_factory=dict) + observed_flow_pauses: dict[str, int] = Field(default_factory=dict) # Tracking for result from task runs in this flow run for dependency tracking # Holds the ID of the object returned by the task run and task run state @@ -369,7 +358,7 @@ class EngineContext(RunContext): __var__: ContextVar[Self] = ContextVar("flow_run") - def serialize(self: Self, include_secrets: bool = True) -> Dict[str, Any]: + def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]: return self.model_dump( include={ "flow_run", @@ -403,7 +392,7 @@ class TaskRunContext(RunContext): task: "Task[Any, Any]" task_run: TaskRun log_prints: bool = False - parameters: Dict[str, Any] + parameters: dict[str, Any] # Result handling result_store: ResultStore @@ -411,7 +400,7 @@ class TaskRunContext(RunContext): __var__ = ContextVar("task_run") - def serialize(self: Self, include_secrets: bool = True) -> Dict[str, Any]: + def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]: return self.model_dump( include={ "task_run", @@ -437,7 +426,7 @@ class TagsContext(ContextModel): current_tags: A set of current tags in the context """ - current_tags: Set[str] = Field(default_factory=set) + current_tags: set[str] = Field(default_factory=set) @classmethod def get(cls) -> "TagsContext": @@ -512,7 +501,7 @@ def get_settings_context() -> SettingsContext: @contextmanager -def tags(*new_tags: str) -> Generator[Set[str], None, None]: +def tags(*new_tags: str) -> Generator[set[str], None, None]: """ Context manager to add tags to flow and task run calls. diff --git a/src/prefect/plugins.py b/src/prefect/plugins.py index 07f4144c4acd9..d55dfece3c34a 100644 --- a/src/prefect/plugins.py +++ b/src/prefect/plugins.py @@ -9,15 +9,15 @@ """ from types import ModuleType -from typing import Any, Dict, Union +from typing import Any, Union import prefect.settings from prefect.utilities.compat import EntryPoints, entry_points -COLLECTIONS: Union[None, Dict[str, Union[ModuleType, Exception]]] = None +_collections: Union[None, dict[str, Union[ModuleType, Exception]]] = None -def safe_load_entrypoints(entrypoints: EntryPoints) -> Dict[str, Union[Exception, Any]]: +def safe_load_entrypoints(entrypoints: EntryPoints) -> dict[str, Union[Exception, Any]]: """ Load entry points for a group capturing any exceptions that occur. """ @@ -26,7 +26,7 @@ def safe_load_entrypoints(entrypoints: EntryPoints) -> Dict[str, Union[Exception # also want to validate the type for the group for entrypoints that have # a specific type we expect. - results = {} + results: dict[str, Union[Exception, Any]] = {} for entrypoint in entrypoints: result = None @@ -40,18 +40,20 @@ def safe_load_entrypoints(entrypoints: EntryPoints) -> Dict[str, Union[Exception return results -def load_prefect_collections() -> Dict[str, Union[ModuleType, Exception]]: +def load_prefect_collections() -> dict[str, Union[ModuleType, Exception]]: """ Load all Prefect collections that define an entrypoint in the group `prefect.collections`. """ - global COLLECTIONS + global _collections - if COLLECTIONS is not None: - return COLLECTIONS + if _collections is not None: + return _collections collection_entrypoints: EntryPoints = entry_points(group="prefect.collections") - collections = safe_load_entrypoints(collection_entrypoints) + collections: dict[str, Union[Exception, Any]] = safe_load_entrypoints( + collection_entrypoints + ) # TODO: Consider the utility of this once we've established this pattern. # We cannot use a logger here because logging is not yet initialized. @@ -68,5 +70,5 @@ def load_prefect_collections() -> Dict[str, Union[ModuleType, Exception]]: if prefect.settings.PREFECT_DEBUG_MODE: print(f"Loaded collection {name!r}.") - COLLECTIONS = collections + _collections = collections return collections diff --git a/tests/typesafety/test_automations.yml b/tests/typesafety/test_automations.yml new file mode 100644 index 0000000000000..6bd1706f64294 --- /dev/null +++ b/tests/typesafety/test_automations.yml @@ -0,0 +1,19 @@ +# yaml-language-server: $schema=https://raw.githubusercontent.com/typeddjango/pytest-mypy-plugins/master/pytest_mypy_plugins/schema.json +- case: automations_read_by_id + main: | + from uuid import uuid4 + from prefect.automations import Automation + async def test_func() -> None: + automation = await Automation.read(id=uuid4()) + reveal_type(automation) # N: Revealed type is "prefect.automations.Automation" + automation = await Automation.read(id=uuid4(), name=None) + reveal_type(automation) # N: Revealed type is "prefect.automations.Automation" + +- case: automations_read_by_name + main: | + from prefect.automations import Automation + async def test_func() -> None: + automation = await Automation.read(name="foobar") + reveal_type(automation) # N: Revealed type is "Union[prefect.automations.Automation, None]" + automation = await Automation.read(id=None, name="foobar") + reveal_type(automation) # N: Revealed type is "Union[prefect.automations.Automation, None]" diff --git a/tests/typesafety/test_flows.yml b/tests/typesafety/test_flows.yml index 2db73d00e6c5a..4ef96d9e88d95 100644 --- a/tests/typesafety/test_flows.yml +++ b/tests/typesafety/test_flows.yml @@ -3,7 +3,7 @@ main: | from prefect import flow reveal_type(flow.from_source) - regex: yes + regex: true # this has to be a regex, because mypy randomly (!) switches between ... and [*Any, **Any] syntax here out: "main:2: note: Revealed type is \"\ def \\(\