Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[typing] misc prefect modules #16468

Merged
merged 1 commit into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pytest-benchmark
pytest-cov
pytest-env
pytest-flakefinder
pytest-mypy-plugins >= 3.1.0
pytest-mypy-plugins >= 3.2.0
pytest-timeout
pytest-xdist >= 3.6.1
pyyaml
Expand Down
135 changes: 61 additions & 74 deletions src/prefect/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,27 @@
Interface for creating and reading artifacts.
"""

from __future__ import annotations

import asyncio
import json # noqa: I001
import math
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import UUID

from typing_extensions import Self

from prefect.client.schemas.actions import ArtifactCreate as ArtifactRequest
from prefect.client.schemas.actions import ArtifactUpdate
from prefect.client.schemas.filters import ArtifactFilter, ArtifactFilterKey
from prefect.client.schemas.sorting import ArtifactSort
from prefect.client.utilities import get_or_create_client, inject_client
from prefect.client.utilities import get_or_create_client
from prefect.logging.loggers import get_logger
from prefect.utilities.asyncutils import sync_compatible
from prefect.utilities.context import get_task_and_flow_run_ids

logger = get_logger("artifacts")

if TYPE_CHECKING:
from typing_extensions import Self

from prefect.client.orchestration import PrefectClient
from prefect.client.schemas.objects import Artifact as ArtifactResponse

Expand All @@ -43,7 +42,7 @@ class Artifact(ArtifactRequest):

@sync_compatible
async def create(
self: "Self",
self: Self,
client: Optional["PrefectClient"] = None,
) -> "ArtifactResponse":
"""
Expand Down Expand Up @@ -95,27 +94,26 @@ async def get(
(ArtifactResponse, optional): The artifact (if found).
"""
client, _ = get_or_create_client(client)
return next(
iter(
await client.read_artifacts(
limit=1,
sort=ArtifactSort.UPDATED_DESC,
artifact_filter=ArtifactFilter(key=ArtifactFilterKey(any_=[key])),
)
filter_key_value = None if key is None else [key]
artifacts = await client.read_artifacts(
limit=1,
sort=ArtifactSort.UPDATED_DESC,
artifact_filter=ArtifactFilter(
key=ArtifactFilterKey(any_=filter_key_value)
),
None,
)
return None if not artifacts else artifacts[0]

@classmethod
@sync_compatible
async def get_or_create(
cls,
key: Optional[str] = None,
description: Optional[str] = None,
data: Optional[Union[Dict[str, Any], Any]] = None,
data: Optional[Union[dict[str, Any], Any]] = None,
client: Optional["PrefectClient"] = None,
**kwargs: Any,
) -> Tuple["ArtifactResponse", bool]:
) -> tuple["ArtifactResponse", bool]:
"""
A method to get or create an artifact.
Expand All @@ -128,18 +126,20 @@ async def get_or_create(
Returns:
(ArtifactResponse): The artifact, either retrieved or created.
"""
artifact = await cls.get(key, client)
artifact_coro = cls.get(key, client)
if TYPE_CHECKING:
assert asyncio.iscoroutine(artifact_coro)
artifact = await artifact_coro
if artifact:
return artifact, False
else:
return (
await cls(key=key, description=description, data=data, **kwargs).create(
client
),
True,
)

async def format(self) -> Optional[Union[Dict[str, Any], Any]]:
new_artifact = cls(key=key, description=description, data=data, **kwargs)
create_coro = new_artifact.create(client)
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
return await create_coro, True

async def format(self) -> Optional[Union[dict[str, Any], Any]]:
return json.dumps(self.data)


Expand All @@ -165,13 +165,13 @@ async def format(self) -> str:


class TableArtifact(Artifact):
table: Union[Dict[str, List[Any]], List[Dict[str, Any]], List[List[Any]]]
table: Union[dict[str, list[Any]], list[dict[str, Any]], list[list[Any]]]
type: Optional[str] = "table"

@classmethod
def _sanitize(
cls, item: Union[Dict[str, Any], List[Any], float]
) -> Union[Dict[str, Any], List[Any], int, float, None]:
cls, item: Union[dict[str, Any], list[Any], float]
) -> Union[dict[str, Any], list[Any], int, float, None]:
"""
Sanitize NaN values in a given item.
The item can be a dict, list or float.
Expand Down Expand Up @@ -230,39 +230,6 @@ async def format(self) -> str:
return self.image_url


@inject_client
async def _create_artifact(
type: str,
key: Optional[str] = None,
description: Optional[str] = None,
data: Optional[Union[Dict[str, Any], Any]] = None,
client: Optional["PrefectClient"] = None,
) -> UUID:
"""
Helper function to create an artifact.
Arguments:
type: A string identifying the type of artifact.
key: A user-provided string identifier.
The key must only contain lowercase letters, numbers, and dashes.
description: A user-specified description of the artifact.
data: A JSON payload that allows for a result to be retrieved.
client: The PrefectClient
Returns:
- The table artifact ID.
"""

artifact = await Artifact(
type=type,
key=key,
description=description,
data=data,
).create(client)

return artifact.id


Comment on lines -233 to -265
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method has been unused since the last reference to it was removed in 8577a34.

@sync_compatible
async def create_link_artifact(
link: str,
Expand All @@ -286,12 +253,16 @@ async def create_link_artifact(
Returns:
The table artifact ID.
"""
artifact = await LinkArtifact(
new_artifact = LinkArtifact(
key=key,
description=description,
link=link,
link_text=link_text,
).create(client)
)
create_coro = new_artifact.create(client)
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id

Expand All @@ -315,18 +286,22 @@ async def create_markdown_artifact(
Returns:
The table artifact ID.
"""
artifact = await MarkdownArtifact(
new_artifact = MarkdownArtifact(
key=key,
description=description,
markdown=markdown,
).create()
)
create_coro = new_artifact.create()
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id


@sync_compatible
async def create_table_artifact(
table: Union[Dict[str, List[Any]], List[Dict[str, Any]], List[List[Any]]],
table: Union[dict[str, list[Any]], list[dict[str, Any]], list[list[Any]]],
key: Optional[str] = None,
description: Optional[str] = None,
) -> UUID:
Expand All @@ -344,11 +319,15 @@ async def create_table_artifact(
The table artifact ID.
"""

artifact = await TableArtifact(
new_artifact = TableArtifact(
key=key,
description=description,
table=table,
).create()
)
create_coro = new_artifact.create()
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id

Expand All @@ -373,11 +352,15 @@ async def create_progress_artifact(
The progress artifact ID.
"""

artifact = await ProgressArtifact(
new_artifact = ProgressArtifact(
key=key,
description=description,
progress=progress,
).create()
)
create_coro = new_artifact.create()
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id

Expand All @@ -387,7 +370,7 @@ async def update_progress_artifact(
artifact_id: UUID,
progress: float,
description: Optional[str] = None,
client: Optional[PrefectClient] = None,
client: Optional["PrefectClient"] = None,
) -> UUID:
"""
Update a progress artifact.
Expand Down Expand Up @@ -444,10 +427,14 @@ async def create_image_artifact(
The image artifact ID.
"""

artifact = await ImageArtifact(
new_artifact = ImageArtifact(
key=key,
description=description,
image_url=image_url,
).create()
)
create_coro = new_artifact.create()
if TYPE_CHECKING:
assert asyncio.iscoroutine(create_coro)
artifact = await create_coro

return artifact.id
34 changes: 27 additions & 7 deletions src/prefect/automations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Type
from typing import TYPE_CHECKING, Optional, overload
from uuid import UUID

from pydantic import Field
Expand Down Expand Up @@ -112,17 +112,28 @@ async def update(self: Self):
auto.name = "new name"
auto.update()
"""
assert self.id is not None
async with get_client() as client:
automation = AutomationCore(
**self.model_dump(exclude={"id", "owner_resource"})
)
await client.update_automation(automation_id=self.id, automation=automation)

@overload
@classmethod
async def read(cls, id: UUID, name: Optional[str] = ...) -> Self:
...

@overload
@classmethod
async def read(cls, id: None = None, name: str = ...) -> Optional[Self]:
...

@classmethod
@sync_compatible
async def read(
cls: Type[Self], id: Optional[UUID] = None, name: Optional[str] = None
) -> Self:
cls, id: Optional[UUID] = None, name: Optional[str] = None
) -> Optional[Self]:
"""
Read an automation by ID or name.
automation = Automation.read(name="woodchonk")
Expand All @@ -145,13 +156,13 @@ async def read(
raise
if automation is None:
raise ValueError(f"Automation with ID {id!r} not found")
return Automation(**automation.model_dump())
return cls(**automation.model_dump())
else:
if TYPE_CHECKING:
assert name is not None
automation = await client.read_automations_by_name(name=name)
if len(automation) > 0:
return (
Automation(**automation[0].model_dump()) if automation else None
)
return cls(**automation[0].model_dump()) if automation else None
else:
raise ValueError(f"Automation with name {name!r} not found")

Expand All @@ -161,6 +172,9 @@ async def delete(self: Self) -> bool:
auto = Automation.read(id = 123)
auto.delete()
"""
if self.id is None:
raise ValueError("Can't delete an automation without an id")

async with get_client() as client:
try:
await client.delete_automation(self.id)
Expand All @@ -177,6 +191,9 @@ async def disable(self: Self) -> bool:
auto = Automation.read(id = 123)
auto.disable()
"""
if self.id is None:
raise ValueError("Can't disable an automation without an id")

async with get_client() as client:
try:
await client.pause_automation(self.id)
Expand All @@ -193,6 +210,9 @@ async def enable(self: Self) -> bool:
auto = Automation.read(id = 123)
auto.enable()
"""
if self.id is None:
raise ValueError("Can't enable an automation without an id")

async with get_client() as client:
try:
await client.resume_automation(self.id)
Expand Down
Loading
Loading