Skip to content

Commit

Permalink
[typing] misc prefect modules
Browse files Browse the repository at this point in the history
A few more modules to pass strict typing checks.
  • Loading branch information
mjpieters committed Dec 21, 2024
1 parent 6f5d463 commit 8981f2a
Show file tree
Hide file tree
Showing 13 changed files with 192 additions and 169 deletions.
135 changes: 61 additions & 74 deletions src/prefect/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,27 @@
Interface for creating and reading artifacts.
"""

from __future__ import annotations

import asyncio
import json # noqa: I001
import math
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import UUID

from typing_extensions import Self

from prefect.client.schemas.actions import ArtifactCreate as ArtifactRequest
from prefect.client.schemas.actions import ArtifactUpdate
from prefect.client.schemas.filters import ArtifactFilter, ArtifactFilterKey
from prefect.client.schemas.sorting import ArtifactSort
from prefect.client.utilities import get_or_create_client, inject_client
from prefect.client.utilities import get_or_create_client
from prefect.logging.loggers import get_logger
from prefect.utilities.asyncutils import sync_compatible
from prefect.utilities.context import get_task_and_flow_run_ids

logger = get_logger("artifacts")

if TYPE_CHECKING:
from typing_extensions import Self

from prefect.client.orchestration import PrefectClient
from prefect.client.schemas.objects import Artifact as ArtifactResponse

Expand All @@ -43,7 +42,7 @@ class Artifact(ArtifactRequest):

@sync_compatible
async def create(
self: "Self",
self: Self,
client: Optional["PrefectClient"] = None,
) -> "ArtifactResponse":
"""
Expand Down Expand Up @@ -95,27 +94,26 @@ async def get(
(ArtifactResponse, optional): The artifact (if found).
"""
client, _ = get_or_create_client(client)
return next(
iter(
await client.read_artifacts(
limit=1,
sort=ArtifactSort.UPDATED_DESC,
artifact_filter=ArtifactFilter(key=ArtifactFilterKey(any_=[key])),
)
filter_key_value = None if key is None else [key]
artifacts = await client.read_artifacts(
limit=1,
sort=ArtifactSort.UPDATED_DESC,
artifact_filter=ArtifactFilter(
key=ArtifactFilterKey(any_=filter_key_value)
),
None,
)
return None if not artifacts else artifacts[0]

@classmethod
@sync_compatible
async def get_or_create(
cls,
key: Optional[str] = None,
description: Optional[str] = None,
data: Optional[Union[Dict[str, Any], Any]] = None,
data: Optional[Union[dict[str, Any], Any]] = None,
client: Optional["PrefectClient"] = None,
**kwargs: Any,
) -> Tuple["ArtifactResponse", bool]:
) -> tuple["ArtifactResponse", bool]:
"""
A method to get or create an artifact.
Expand All @@ -128,18 +126,20 @@ async def get_or_create(
Returns:
(ArtifactResponse): The artifact, either retrieved or created.
"""
artifact = await cls.get(key, client)
artifact_coro = cls.get(key, client)
if TYPE_CHECKING:
assert asyncio.iscoroutine(artifact_coro)
artifact = await artifact_coro
if artifact:
return artifact, False
else:
return (
await cls(key=key, description=description, data=data, **kwargs).create(
client
),
True,
)

async def format(self) -> Optional[Union[Dict[str, Any], Any]]:
new_artifact = cls(key=key, description=description, data=data, **kwargs)
create_coro = new_artifact.create(client)
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
return await create_coro, True

async def format(self) -> Optional[Union[dict[str, Any], Any]]:
return json.dumps(self.data)


Expand All @@ -165,13 +165,13 @@ async def format(self) -> str:


class TableArtifact(Artifact):
table: Union[Dict[str, List[Any]], List[Dict[str, Any]], List[List[Any]]]
table: Union[dict[str, list[Any]], list[dict[str, Any]], list[list[Any]]]
type: Optional[str] = "table"

@classmethod
def _sanitize(
cls, item: Union[Dict[str, Any], List[Any], float]
) -> Union[Dict[str, Any], List[Any], int, float, None]:
cls, item: Union[dict[str, Any], list[Any], float]
) -> Union[dict[str, Any], list[Any], int, float, None]:
"""
Sanitize NaN values in a given item.
The item can be a dict, list or float.
Expand Down Expand Up @@ -230,39 +230,6 @@ async def format(self) -> str:
return self.image_url


@inject_client
async def _create_artifact(
type: str,
key: Optional[str] = None,
description: Optional[str] = None,
data: Optional[Union[Dict[str, Any], Any]] = None,
client: Optional["PrefectClient"] = None,
) -> UUID:
"""
Helper function to create an artifact.
Arguments:
type: A string identifying the type of artifact.
key: A user-provided string identifier.
The key must only contain lowercase letters, numbers, and dashes.
description: A user-specified description of the artifact.
data: A JSON payload that allows for a result to be retrieved.
client: The PrefectClient
Returns:
- The table artifact ID.
"""

artifact = await Artifact(
type=type,
key=key,
description=description,
data=data,
).create(client)

return artifact.id


@sync_compatible
async def create_link_artifact(
link: str,
Expand All @@ -286,12 +253,16 @@ async def create_link_artifact(
Returns:
The table artifact ID.
"""
artifact = await LinkArtifact(
new_artifact = LinkArtifact(
key=key,
description=description,
link=link,
link_text=link_text,
).create(client)
)
create_coro = new_artifact.create(client)
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id

Expand All @@ -315,18 +286,22 @@ async def create_markdown_artifact(
Returns:
The table artifact ID.
"""
artifact = await MarkdownArtifact(
new_artifact = MarkdownArtifact(
key=key,
description=description,
markdown=markdown,
).create()
)
create_coro = new_artifact.create()
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id


@sync_compatible
async def create_table_artifact(
table: Union[Dict[str, List[Any]], List[Dict[str, Any]], List[List[Any]]],
table: Union[dict[str, list[Any]], list[dict[str, Any]], list[list[Any]]],
key: Optional[str] = None,
description: Optional[str] = None,
) -> UUID:
Expand All @@ -344,11 +319,15 @@ async def create_table_artifact(
The table artifact ID.
"""

artifact = await TableArtifact(
new_artifact = TableArtifact(
key=key,
description=description,
table=table,
).create()
)
create_coro = new_artifact.create()
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id

Expand All @@ -373,11 +352,15 @@ async def create_progress_artifact(
The progress artifact ID.
"""

artifact = await ProgressArtifact(
new_artifact = ProgressArtifact(
key=key,
description=description,
progress=progress,
).create()
)
create_coro = new_artifact.create()
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id

Expand All @@ -387,7 +370,7 @@ async def update_progress_artifact(
artifact_id: UUID,
progress: float,
description: Optional[str] = None,
client: Optional[PrefectClient] = None,
client: Optional["PrefectClient"] = None,
) -> UUID:
"""
Update a progress artifact.
Expand Down Expand Up @@ -444,10 +427,14 @@ async def create_image_artifact(
The image artifact ID.
"""

artifact = await ImageArtifact(
new_artifact = ImageArtifact(
key=key,
description=description,
image_url=image_url,
).create()
)
create_coro = new_artifact.create()
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id
34 changes: 27 additions & 7 deletions src/prefect/automations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Type
from typing import TYPE_CHECKING, Optional, overload
from uuid import UUID

from pydantic import Field
Expand Down Expand Up @@ -112,17 +112,28 @@ async def update(self: Self):
auto.name = "new name"
auto.update()
"""
assert self.id is not None
async with get_client() as client:
automation = AutomationCore(
**self.model_dump(exclude={"id", "owner_resource"})
)
await client.update_automation(automation_id=self.id, automation=automation)

@overload
@classmethod
async def read(cls, id: UUID, name: Optional[str] = ...) -> Self:
...

@overload
@classmethod
async def read(cls, id: None = None, name: str = ...) -> Optional[Self]:
...

@classmethod
@sync_compatible
async def read(
cls: Type[Self], id: Optional[UUID] = None, name: Optional[str] = None
) -> Self:
cls, id: Optional[UUID] = None, name: Optional[str] = None
) -> Optional[Self]:
"""
Read an automation by ID or name.
automation = Automation.read(name="woodchonk")
Expand All @@ -145,13 +156,13 @@ async def read(
raise
if automation is None:
raise ValueError(f"Automation with ID {id!r} not found")
return Automation(**automation.model_dump())
return cls(**automation.model_dump())
else:
if TYPE_CHECKING:
assert name is not None
automation = await client.read_automations_by_name(name=name)
if len(automation) > 0:
return (
Automation(**automation[0].model_dump()) if automation else None
)
return cls(**automation[0].model_dump()) if automation else None
else:
raise ValueError(f"Automation with name {name!r} not found")

Expand All @@ -161,6 +172,9 @@ async def delete(self: Self) -> bool:
auto = Automation.read(id = 123)
auto.delete()
"""
if self.id is None:
raise ValueError("Can't delete an automation without an id")

async with get_client() as client:
try:
await client.delete_automation(self.id)
Expand All @@ -177,6 +191,9 @@ async def disable(self: Self) -> bool:
auto = Automation.read(id = 123)
auto.disable()
"""
if self.id is None:
raise ValueError("Can't disable an automation without an id")

async with get_client() as client:
try:
await client.pause_automation(self.id)
Expand All @@ -193,6 +210,9 @@ async def enable(self: Self) -> bool:
auto = Automation.read(id = 123)
auto.enable()
"""
if self.id is None:
raise ValueError("Can't enable an automation without an id")

async with get_client() as client:
try:
await client.resume_automation(self.id)
Expand Down
Loading

0 comments on commit 8981f2a

Please sign in to comment.