Skip to content

Commit

Permalink
[typing] misc prefect modules
Browse files Browse the repository at this point in the history
A few more modules to pass strict typing checks.
  • Loading branch information
mjpieters committed Dec 20, 2024
1 parent 7bf7697 commit 2b1ba99
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 141 deletions.
182 changes: 86 additions & 96 deletions src/prefect/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@

from __future__ import annotations

import asyncio
import json # noqa: I001
import math
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any
from uuid import UUID

from prefect.client.schemas.actions import ArtifactCreate as ArtifactRequest
from prefect.client.schemas.actions import ArtifactUpdate
from prefect.client.schemas.filters import ArtifactFilter, ArtifactFilterKey
from prefect.client.schemas.sorting import ArtifactSort
from prefect.client.utilities import get_or_create_client, inject_client
from prefect.client.utilities import get_or_create_client
from prefect.logging.loggers import get_logger
from prefect.utilities.asyncutils import sync_compatible
from prefect.utilities.context import get_task_and_flow_run_ids
Expand Down Expand Up @@ -43,9 +44,9 @@ class Artifact(ArtifactRequest):

@sync_compatible
async def create(
self: "Self",
client: Optional["PrefectClient"] = None,
) -> "ArtifactResponse":
self: Self,
client: PrefectClient | None = None,
) -> ArtifactResponse:
"""
A method to create an artifact.
Expand Down Expand Up @@ -82,8 +83,8 @@ async def create(
@classmethod
@sync_compatible
async def get(
cls, key: Optional[str] = None, client: Optional["PrefectClient"] = None
) -> Optional["ArtifactResponse"]:
cls, key: str | None = None, client: PrefectClient | None = None
) -> ArtifactResponse | None:
"""
A method to get an artifact.
Expand All @@ -95,27 +96,26 @@ async def get(
(ArtifactResponse, optional): The artifact (if found).
"""
client, _ = get_or_create_client(client)
return next(
iter(
await client.read_artifacts(
limit=1,
sort=ArtifactSort.UPDATED_DESC,
artifact_filter=ArtifactFilter(key=ArtifactFilterKey(any_=[key])),
)
filter_key_value = None if key is None else [key]
artifacts = await client.read_artifacts(
limit=1,
sort=ArtifactSort.UPDATED_DESC,
artifact_filter=ArtifactFilter(
key=ArtifactFilterKey(any_=filter_key_value)
),
None,
)
return None if not artifacts else artifacts[0]

@classmethod
@sync_compatible
async def get_or_create(
cls,
key: Optional[str] = None,
description: Optional[str] = None,
data: Optional[Union[Dict[str, Any], Any]] = None,
client: Optional["PrefectClient"] = None,
key: str | None = None,
description: str | None = None,
data: dict[str, Any] | Any | None = None,
client: PrefectClient | None = None,
**kwargs: Any,
) -> Tuple["ArtifactResponse", bool]:
) -> tuple[ArtifactResponse, bool]:
"""
A method to get or create an artifact.
Expand All @@ -128,25 +128,27 @@ async def get_or_create(
Returns:
(ArtifactResponse): The artifact, either retrieved or created.
"""
artifact = await cls.get(key, client)
artifact_coro = cls.get(key, client)
if TYPE_CHECKING:
assert asyncio.iscoroutine(artifact_coro)
artifact = await artifact_coro
if artifact:
return artifact, False
else:
return (
await cls(key=key, description=description, data=data, **kwargs).create(
client
),
True,
)

async def format(self) -> Optional[Union[Dict[str, Any], Any]]:
new_artifact = cls(key=key, description=description, data=data, **kwargs)
create_coro = new_artifact.create(client)
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
return await create_coro, True

async def format(self) -> dict[str, Any] | Any | None:
return json.dumps(self.data)


class LinkArtifact(Artifact):
link: str
link_text: Optional[str] = None
type: Optional[str] = "markdown"
link_text: str | None = None
type: str | None = "markdown"

async def format(self) -> str:
return (
Expand All @@ -158,20 +160,20 @@ async def format(self) -> str:

class MarkdownArtifact(Artifact):
markdown: str
type: Optional[str] = "markdown"
type: str | None = "markdown"

async def format(self) -> str:
return self.markdown


class TableArtifact(Artifact):
table: Union[Dict[str, List[Any]], List[Dict[str, Any]], List[List[Any]]]
type: Optional[str] = "table"
table: dict[str, list[Any]] | list[dict[str, Any]] | list[list[Any]]
type: str | None = "table"

@classmethod
def _sanitize(
cls, item: Union[Dict[str, Any], List[Any], float]
) -> Union[Dict[str, Any], List[Any], int, float, None]:
cls, item: dict[str, Any] | list[Any] | float
) -> dict[str, Any] | list[Any] | int | float | None:
"""
Sanitize NaN values in a given item.
The item can be a dict, list or float.
Expand All @@ -191,7 +193,7 @@ async def format(self) -> str:

class ProgressArtifact(Artifact):
progress: float
type: Optional[str] = "progress"
type: str | None = "progress"

async def format(self) -> float:
# Ensure progress is between 0 and 100
Expand All @@ -216,7 +218,7 @@ class ImageArtifact(Artifact):
"""

image_url: str
type: Optional[str] = "image"
type: str | None = "image"

async def format(self) -> str:
"""
Expand All @@ -230,46 +232,13 @@ async def format(self) -> str:
return self.image_url


@inject_client
async def _create_artifact(
type: str,
key: Optional[str] = None,
description: Optional[str] = None,
data: Optional[Union[Dict[str, Any], Any]] = None,
client: Optional["PrefectClient"] = None,
) -> UUID:
"""
Helper function to create an artifact.
Arguments:
type: A string identifying the type of artifact.
key: A user-provided string identifier.
The key must only contain lowercase letters, numbers, and dashes.
description: A user-specified description of the artifact.
data: A JSON payload that allows for a result to be retrieved.
client: The PrefectClient
Returns:
- The table artifact ID.
"""

artifact = await Artifact(
type=type,
key=key,
description=description,
data=data,
).create(client)

return artifact.id


@sync_compatible
async def create_link_artifact(
link: str,
link_text: Optional[str] = None,
key: Optional[str] = None,
description: Optional[str] = None,
client: Optional["PrefectClient"] = None,
link_text: str | None = None,
key: str | None = None,
description: str | None = None,
client: PrefectClient | None = None,
) -> UUID:
"""
Create a link artifact.
Expand All @@ -286,21 +255,26 @@ async def create_link_artifact(
Returns:
The table artifact ID.
"""
artifact = await LinkArtifact(
new_artifact = LinkArtifact(
key=key,
description=description,
link=link,
link_text=link_text,
).create(client)
)
create_coro = new_artifact.create(client)
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id


@sync_compatible
async def create_markdown_artifact(
markdown: str,
key: Optional[str] = None,
description: Optional[str] = None,
key: str | None = None,
description: str | None = None,
client: PrefectClient | None = None,
) -> UUID:
"""
Create a markdown artifact.
Expand All @@ -315,20 +289,24 @@ async def create_markdown_artifact(
Returns:
The table artifact ID.
"""
artifact = await MarkdownArtifact(
new_artifact = MarkdownArtifact(
key=key,
description=description,
markdown=markdown,
).create()
)
create_coro = new_artifact.create()
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id


@sync_compatible
async def create_table_artifact(
table: Union[Dict[str, List[Any]], List[Dict[str, Any]], List[List[Any]]],
key: Optional[str] = None,
description: Optional[str] = None,
table: dict[str, list[Any]] | list[dict[str, Any]] | list[list[Any]],
key: str | None = None,
description: str | None = None,
) -> UUID:
"""
Create a table artifact.
Expand All @@ -344,20 +322,24 @@ async def create_table_artifact(
The table artifact ID.
"""

artifact = await TableArtifact(
new_artifact = TableArtifact(
key=key,
description=description,
table=table,
).create()
)
create_coro = new_artifact.create()
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id


@sync_compatible
async def create_progress_artifact(
progress: float,
key: Optional[str] = None,
description: Optional[str] = None,
key: str | None = None,
description: str | None = None,
) -> UUID:
"""
Create a progress artifact.
Expand All @@ -373,11 +355,15 @@ async def create_progress_artifact(
The progress artifact ID.
"""

artifact = await ProgressArtifact(
new_artifact = ProgressArtifact(
key=key,
description=description,
progress=progress,
).create()
)
create_coro = new_artifact.create()
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id

Expand All @@ -386,8 +372,8 @@ async def create_progress_artifact(
async def update_progress_artifact(
artifact_id: UUID,
progress: float,
description: Optional[str] = None,
client: Optional[PrefectClient] = None,
description: str | None = None,
client: PrefectClient | None = None,
) -> UUID:
"""
Update a progress artifact.
Expand Down Expand Up @@ -427,8 +413,8 @@ async def update_progress_artifact(
@sync_compatible
async def create_image_artifact(
image_url: str,
key: Optional[str] = None,
description: Optional[str] = None,
key: str | None = None,
description: str | None = None,
) -> UUID:
"""
Create an image artifact.
Expand All @@ -444,10 +430,14 @@ async def create_image_artifact(
The image artifact ID.
"""

artifact = await ImageArtifact(
new_artifact = ImageArtifact(
key=key,
description=description,
image_url=image_url,
).create()
)
create_coro = new_artifact.create()
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id
Loading

0 comments on commit 2b1ba99

Please sign in to comment.