Skip to content

Commit

Permalink
Add Block.aload method and remove @sync_compatible from `Block.lo…
Browse files Browse the repository at this point in the history
…ad` (#16341)
  • Loading branch information
desertaxle authored Dec 11, 2024
1 parent 6bfce9e commit 3ea3428
Show file tree
Hide file tree
Showing 11 changed files with 431 additions and 40 deletions.
45 changes: 36 additions & 9 deletions src/prefect/_internal/compatibility/async_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,57 @@
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable, Coroutine, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar, Union

from typing_extensions import ParamSpec

from prefect.tasks import Task
if TYPE_CHECKING:
from prefect.tasks import Task

R = TypeVar("R")
P = ParamSpec("P")


def is_in_async_context() -> bool:
"""
Returns True if called from within an async context (coroutine or running event loop)
Returns True if called from within an async context.
An async context is one of:
- a coroutine
- a running event loop
- a task or flow that is async
"""
from prefect.context import get_run_context
from prefect.exceptions import MissingContextError

try:
asyncio.get_running_loop()
return True
except RuntimeError:
return False
run_ctx = get_run_context()
parent_obj = getattr(run_ctx, "task", None)
if not parent_obj:
parent_obj = getattr(run_ctx, "flow", None)
return getattr(parent_obj, "isasync", True)
except MissingContextError:
# not in an execution context, make best effort to
# decide whether to syncify
try:
asyncio.get_running_loop()
return True
except RuntimeError:
return False


def _is_acceptable_callable(obj: Union[Callable, Task]) -> bool:
def _is_acceptable_callable(obj: Union[Callable, "Task", classmethod]) -> bool:
if inspect.iscoroutinefunction(obj):
return True
if isinstance(obj, Task) and inspect.iscoroutinefunction(obj.fn):

# Check if a task or flow. Need to avoid importing `Task` or `Flow` here
# due to circular imports.
if (fn := getattr(obj, "fn", None)) and inspect.iscoroutinefunction(fn):
return True

if isinstance(obj, classmethod) and inspect.iscoroutinefunction(obj.__func__):
return True

return False


Expand Down Expand Up @@ -56,6 +81,8 @@ def wrapper(

if should_run_sync:
return sync_fn(*args, **kwargs)
if isinstance(async_impl, classmethod):
return async_impl.__func__(*args, **kwargs)
return async_impl(*args, **kwargs)

return wrapper # type: ignore
Expand Down
156 changes: 141 additions & 15 deletions src/prefect/blocks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from typing_extensions import Literal, ParamSpec, Self, get_args

import prefect.exceptions
from prefect._internal.compatibility.async_dispatch import async_dispatch
from prefect.client.schemas import (
DEFAULT_BLOCK_SCHEMA_VERSION,
BlockDocument,
Expand All @@ -53,7 +54,7 @@
from prefect.logging.loggers import disable_logger
from prefect.plugins import load_prefect_collections
from prefect.types import SecretDict
from prefect.utilities.asyncutils import sync_compatible
from prefect.utilities.asyncutils import run_coro_as_sync, sync_compatible
from prefect.utilities.collections import listrepr, remove_nested_keys, visit_collection
from prefect.utilities.dispatch import lookup_type, register_base_type
from prefect.utilities.hashing import hash_objects
Expand All @@ -64,7 +65,7 @@
if TYPE_CHECKING:
from pydantic.main import IncEx

from prefect.client.orchestration import PrefectClient
from prefect.client.orchestration import PrefectClient, SyncPrefectClient

R = TypeVar("R")
P = ParamSpec("P")
Expand Down Expand Up @@ -777,12 +778,11 @@ def _define_metadata_on_nested_blocks(
)

@classmethod
@inject_client
async def _get_block_document(
async def _aget_block_document(
cls,
name: str,
client: Optional["PrefectClient"] = None,
):
client: "PrefectClient",
) -> tuple[BlockDocument, str]:
if cls.__name__ == "Block":
block_type_slug, block_document_name = name.split("/", 1)
else:
Expand All @@ -801,6 +801,30 @@ async def _get_block_document(

return block_document, block_document_name

@classmethod
def _get_block_document(
cls,
name: str,
client: "SyncPrefectClient",
) -> tuple[BlockDocument, str]:
if cls.__name__ == "Block":
block_type_slug, block_document_name = name.split("/", 1)
else:
block_type_slug = cls.get_block_type_slug()
block_document_name = name

try:
block_document = client.read_block_document_by_name(
name=block_document_name, block_type_slug=block_type_slug
)
except prefect.exceptions.ObjectNotFound as e:
raise ValueError(
f"Unable to find block document named {block_document_name} for block"
f" type {block_type_slug}"
) from e

return block_document, block_document_name

@classmethod
@sync_compatible
@inject_client
Expand Down Expand Up @@ -829,9 +853,97 @@ async def _get_block_document_by_id(
return block_document, block_document.name

@classmethod
@sync_compatible
@inject_client
async def load(
async def aload(
cls,
name: str,
validate: bool = True,
client: Optional["PrefectClient"] = None,
) -> "Self":
"""
Retrieves data from the block document with the given name for the block type
that corresponds with the current class and returns an instantiated version of
the current class with the data stored in the block document.
If a block document for a given block type is saved with a different schema
than the current class calling `aload`, a warning will be raised.
If the current class schema is a subset of the block document schema, the block
can be loaded as normal using the default `validate = True`.
If the current class schema is a superset of the block document schema, `aload`
must be called with `validate` set to False to prevent a validation error. In
this case, the block attributes will default to `None` and must be set manually
and saved to a new block document before the block can be used as expected.
Args:
name: The name or slug of the block document. A block document slug is a
string with the format <block_type_slug>/<block_document_name>
validate: If False, the block document will be loaded without Pydantic
validating the block schema. This is useful if the block schema has
changed client-side since the block document referred to by `name` was saved.
client: The client to use to load the block document. If not provided, the
default client will be injected.
Raises:
ValueError: If the requested block document is not found.
Returns:
An instance of the current class hydrated with the data stored in the
block document with the specified name.
Examples:
Load from a Block subclass with a block document name:
```python
class Custom(Block):
message: str
Custom(message="Hello!").save("my-custom-message")
loaded_block = await Custom.aload("my-custom-message")
```
Load from Block with a block document slug:
```python
class Custom(Block):
message: str
Custom(message="Hello!").save("my-custom-message")
loaded_block = await Block.aload("custom/my-custom-message")
```
Migrate a block document to a new schema:
```python
# original class
class Custom(Block):
message: str
Custom(message="Hello!").save("my-custom-message")
# Updated class with new required field
class Custom(Block):
message: str
number_of_ducks: int
loaded_block = await Custom.aload("my-custom-message", validate=False)
# Prints UserWarning about schema mismatch
loaded_block.number_of_ducks = 42
loaded_block.save("my-custom-message", overwrite=True)
```
"""
if TYPE_CHECKING:
assert isinstance(client, PrefectClient)
block_document, _ = await cls._aget_block_document(name, client=client)

return cls._load_from_block_document(block_document, validate=validate)

@classmethod
@async_dispatch(aload)
def load(
cls,
name: str,
validate: bool = True,
Expand Down Expand Up @@ -912,9 +1024,19 @@ class Custom(Block):
loaded_block.save("my-custom-message", overwrite=True)
```
"""
block_document, block_document_name = await cls._get_block_document(
name, client=client
)
# Need to use a `PrefectClient` here to ensure `Block.load` and `Block.aload` signatures match
# TODO: replace with only sync client once all internal calls are updated to use `Block.aload` and `@async_dispatch` is removed
if client is None:
# If a client wasn't provided, we get to use a sync client
from prefect.client.orchestration import get_client

with get_client(sync_client=True) as sync_client:
block_document, _ = cls._get_block_document(name, client=sync_client)
else:
# If a client was provided, reuse it, even though it's async, to avoid excessive client creation
block_document, _ = run_coro_as_sync(
cls._aget_block_document(name, client=client)
)

return cls._load_from_block_document(block_document, validate=validate)

Expand Down Expand Up @@ -968,14 +1090,16 @@ async def load_from_ref(
"""
block_document = None
if isinstance(ref, (str, UUID)):
block_document, _ = await cls._get_block_document_by_id(ref)
block_document, _ = await cls._get_block_document_by_id(ref, client=client)
elif isinstance(ref, dict):
if block_document_id := ref.get("block_document_id"):
block_document, _ = await cls._get_block_document_by_id(
block_document_id
block_document_id, client=client
)
elif block_document_slug := ref.get("block_document_slug"):
block_document, _ = await cls._get_block_document(block_document_slug)
block_document, _ = await cls._get_block_document(
block_document_slug, client=client
)

if not block_document:
raise ValueError(f"Invalid reference format {ref!r}.")
Expand Down Expand Up @@ -1220,7 +1344,9 @@ async def delete(
name: str,
client: Optional["PrefectClient"] = None,
):
block_document, block_document_name = await cls._get_block_document(name)
if TYPE_CHECKING:
assert isinstance(client, PrefectClient)
block_document, _ = await cls._aget_block_document(name, client=client)

await client.delete_block_document(block_document.id)

Expand Down
2 changes: 1 addition & 1 deletion src/prefect/cli/_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ async def prompt_push_custom_docker_image(
docker_registry_creds_name = f"deployment-{slugify(deployment_config['name'])}-{slugify(deployment_config['work_pool']['name'])}-registry-creds"
create_new_block = False
try:
await credentials_block.load(docker_registry_creds_name)
await credentials_block.aload(docker_registry_creds_name)
if not confirm(
(
"Would you like to use the existing Docker registry credentials"
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/cli/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ async def _generate_git_clone_pull_step(
)

try:
await Secret.load(token_secret_block_name)
await Secret.aload(token_secret_block_name)
if not confirm(
(
"We found an existing token saved for this deployment. Would"
Expand Down
38 changes: 38 additions & 0 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4415,3 +4415,41 @@ def update_flow_run_labels(
json=labels,
)
response.raise_for_status()

def read_block_document_by_name(
self,
name: str,
block_type_slug: str,
include_secrets: bool = True,
) -> BlockDocument:
"""
Read the block document with the specified name that corresponds to a
specific block type name.
Args:
name: The block document name.
block_type_slug: The block type slug.
include_secrets (bool): whether to include secret values
on the Block, corresponding to Pydantic's `SecretStr` and
`SecretBytes` fields. These fields are automatically obfuscated
by Pydantic, but users can additionally choose not to receive
their values from the API. Note that any business logic on the
Block may not work if this is `False`.
Raises:
httpx.RequestError: if the block document was not found for any reason
Returns:
A block document or None.
"""
try:
response = self._client.get(
f"/block_types/slug/{block_type_slug}/block_documents/name/{name}",
params=dict(include_secrets=include_secrets),
)
except httpx.HTTPStatusError as e:
if e.response.status_code == status.HTTP_404_NOT_FOUND:
raise prefect.exceptions.ObjectNotFound(http_exc=e) from e
else:
raise
return BlockDocument.model_validate(response.json())
16 changes: 15 additions & 1 deletion src/prefect/client/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from collections.abc import Awaitable, Coroutine
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, overload

from typing_extensions import Concatenate, ParamSpec, TypeGuard, TypeVar

Expand Down Expand Up @@ -71,9 +71,23 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return wrapper


@overload
def inject_client(
fn: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, R]]:
...


@overload
def inject_client(
fn: Callable[P, R],
) -> Callable[P, R]:
...


def inject_client(
fn: Callable[P, Union[Coroutine[Any, Any, R], R]],
) -> Callable[P, Union[Coroutine[Any, Any, R], R]]:
"""
Simple helper to provide a context managed client to an asynchronous function.
Expand Down
Loading

0 comments on commit 3ea3428

Please sign in to comment.