Skip to content

Commit

Permalink
[typing] prefect.tasks and prefect.task_worker (#16332)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Dec 12, 2024
1 parent e791f82 commit a0f3219
Show file tree
Hide file tree
Showing 14 changed files with 160 additions and 117 deletions.
4 changes: 3 additions & 1 deletion src/prefect/_internal/compatibility/async_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def is_in_async_context() -> bool:
return False


def _is_acceptable_callable(obj: Union[Callable, "Task", classmethod]) -> bool:
def _is_acceptable_callable(
obj: Union[Callable[P, R], "Task[P, R]", classmethod],
) -> bool:
if inspect.iscoroutinefunction(obj):
return True

Expand Down
25 changes: 15 additions & 10 deletions src/prefect/_internal/pydantic/v2_validated_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
arguments.
"""

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union

# importing directly from v2 to be able to create a v2 model
from pydantic import BaseModel, ConfigDict, create_model, field_validator
from pydantic.v1.decorator import ValidatedFunction
from pydantic.v1.errors import ConfigError
from pydantic.v1.utils import to_camel
from typing_extensions import TypeAlias

if TYPE_CHECKING:
ConfigType = Union[None, Type[Any], Dict[str, Any]]
ConfigType: TypeAlias = Union[None, type[Any], dict[str, Any]]

V_POSITIONAL_ONLY_NAME = "v__positional_only"
V_DUPLICATE_KWARGS = "v__duplicate_kwargs"
Expand All @@ -24,13 +25,17 @@
class V2ValidatedFunction(ValidatedFunction):
def create_model(
self,
fields: Dict[str, Any],
fields: dict[str, Any],
takes_args: bool,
takes_kwargs: bool,
config: ConfigDict,
config: "ConfigType",
) -> None:
pos_args = len(self.arg_mapping)

config = {} if config is None else config
if not isinstance(config, dict):
raise TypeError(f"config must be None or a dict, got {type(config)}")

if config.get("fields") or config.get("alias_generator"):
raise ConfigError(
'Setting the "fields" and "alias_generator" property on custom Config'
Expand All @@ -42,11 +47,11 @@ def create_model(

# This is the key change -- inheriting the BaseModel class from v2
class DecoratorBaseModel(BaseModel):
model_config = config
model_config: ClassVar[ConfigDict] = ConfigDict(**config)

@field_validator(self.v_args_name, check_fields=False)
@classmethod
def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
def check_args(cls, v: Optional[list[Any]]) -> Optional[list[Any]]:
if takes_args or v is None:
return v

Expand All @@ -58,8 +63,8 @@ def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
@field_validator(self.v_kwargs_name, check_fields=False)
@classmethod
def check_kwargs(
cls, v: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
cls, v: Optional[dict[str, Any]]
) -> Optional[dict[str, Any]]:
if takes_kwargs or v is None:
return v

Expand All @@ -69,7 +74,7 @@ def check_kwargs(

@field_validator(V_POSITIONAL_ONLY_NAME, check_fields=False)
@classmethod
def check_positional_only(cls, v: Optional[List[str]]) -> None:
def check_positional_only(cls, v: Optional[list[str]]) -> None:
if v is None:
return

Expand All @@ -82,7 +87,7 @@ def check_positional_only(cls, v: Optional[List[str]]) -> None:

@field_validator(V_DUPLICATE_KWARGS, check_fields=False)
@classmethod
def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
def check_duplicate_kwargs(cls, v: Optional[list[str]]) -> None:
if v is None:
return

Expand Down
5 changes: 3 additions & 2 deletions src/prefect/blocks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class Block(BaseModel, ABC):
json_schema_extra=schema_extra,
)

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.block_initialization()

Expand Down Expand Up @@ -629,7 +629,8 @@ def _generate_code_example(cls) -> str:
"""Generates a default code example for the current class"""
qualified_name = to_qualified_name(cls)
module_str = ".".join(qualified_name.split(".")[:-1])
class_name = cls.__name__
origin = cls.__pydantic_generic_metadata__.get("origin") or cls
class_name = origin.__name__
block_variable_name = f'{cls.get_block_type_slug().replace("-", "_")}_block'

return dedent(
Expand Down
1 change: 1 addition & 0 deletions src/prefect/blocks/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class Secret(Block, Generic[T]):

_logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/c6f20e556dd16effda9df16551feecfb5822092b-48x48.png"
_documentation_url = "https://docs.prefect.io/latest/develop/blocks"
_description = "A block that represents a secret value. The value stored in this block will be obfuscated when this block is viewed or edited in the UI."

value: Union[SecretStr, PydanticSecret[T]] = Field(
default=...,
Expand Down
28 changes: 14 additions & 14 deletions src/prefect/deployments/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def fast_flow():
import tempfile
from datetime import datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
from uuid import UUID

from pydantic import (
Expand Down Expand Up @@ -160,7 +160,7 @@ class RunnerDeployment(BaseModel):
paused: Optional[bool] = Field(
default=None, description="Whether or not the deployment is paused."
)
parameters: Dict[str, Any] = Field(default_factory=dict)
parameters: dict[str, Any] = Field(default_factory=dict)
entrypoint: Optional[str] = Field(
default=None,
description=(
Expand Down Expand Up @@ -198,7 +198,7 @@ class RunnerDeployment(BaseModel):
" the deployment is registered with a built runner."
),
)
job_variables: Dict[str, Any] = Field(
job_variables: dict[str, Any] = Field(
default_factory=dict,
description=(
"Job variables used to override the default values of a work pool"
Expand Down Expand Up @@ -280,7 +280,7 @@ async def apply(
async with get_client() as client:
flow_id = await client.create_flow_from_name(self.flow_name)

create_payload = dict(
create_payload: dict[str, Any] = dict(
flow_id=flow_id,
name=self.name,
work_queue_name=self.work_queue_name,
Expand Down Expand Up @@ -428,7 +428,7 @@ def _construct_deployment_schedules(
else:
return [create_deployment_schedule_create(schedule)]

def _set_defaults_from_flow(self, flow: "Flow"):
def _set_defaults_from_flow(self, flow: "Flow[..., Any]"):
self._parameter_openapi_schema = parameter_schema(flow)

if not self.version:
Expand All @@ -439,7 +439,7 @@ def _set_defaults_from_flow(self, flow: "Flow"):
@classmethod
def from_flow(
cls,
flow: "Flow",
flow: "Flow[..., Any]",
name: str,
interval: Optional[
Union[Iterable[Union[int, float, timedelta]], int, float, timedelta]
Expand All @@ -449,15 +449,15 @@ def from_flow(
paused: Optional[bool] = None,
schedules: Optional["FlexibleScheduleList"] = None,
concurrency_limit: Optional[Union[int, ConcurrencyLimitConfig, None]] = None,
parameters: Optional[dict] = None,
parameters: Optional[dict[str, Any]] = None,
triggers: Optional[List[Union[DeploymentTriggerTypes, TriggerTypes]]] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
version: Optional[str] = None,
enforce_parameter_schema: bool = True,
work_pool_name: Optional[str] = None,
work_queue_name: Optional[str] = None,
job_variables: Optional[Dict[str, Any]] = None,
job_variables: Optional[dict[str, Any]] = None,
entrypoint_type: EntrypointType = EntrypointType.FILE_PATH,
) -> "RunnerDeployment":
"""
Expand Down Expand Up @@ -588,15 +588,15 @@ def from_entrypoint(
paused: Optional[bool] = None,
schedules: Optional["FlexibleScheduleList"] = None,
concurrency_limit: Optional[Union[int, ConcurrencyLimitConfig, None]] = None,
parameters: Optional[dict] = None,
parameters: Optional[dict[str, Any]] = None,
triggers: Optional[List[Union[DeploymentTriggerTypes, TriggerTypes]]] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
version: Optional[str] = None,
enforce_parameter_schema: bool = True,
work_pool_name: Optional[str] = None,
work_queue_name: Optional[str] = None,
job_variables: Optional[Dict[str, Any]] = None,
job_variables: Optional[dict[str, Any]] = None,
) -> "RunnerDeployment":
"""
Configure a deployment for a given flow located at a given entrypoint.
Expand Down Expand Up @@ -689,15 +689,15 @@ async def from_storage(
paused: Optional[bool] = None,
schedules: Optional["FlexibleScheduleList"] = None,
concurrency_limit: Optional[Union[int, ConcurrencyLimitConfig, None]] = None,
parameters: Optional[dict] = None,
parameters: Optional[dict[str, Any]] = None,
triggers: Optional[List[Union[DeploymentTriggerTypes, TriggerTypes]]] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
version: Optional[str] = None,
enforce_parameter_schema: bool = True,
work_pool_name: Optional[str] = None,
work_queue_name: Optional[str] = None,
job_variables: Optional[Dict[str, Any]] = None,
job_variables: Optional[dict[str, Any]] = None,
):
"""
Create a RunnerDeployment from a flow located at a given entrypoint and stored in a
Expand Down Expand Up @@ -945,8 +945,8 @@ def local_flow():

console.print(f"Successfully pushed image {image.reference!r}", style="green")

deployment_exceptions = []
deployment_ids = []
deployment_exceptions: list[dict[str, Any]] = []
deployment_ids: list[UUID] = []
image_ref = image.reference if image else None
for deployment in track(
deployments,
Expand Down
6 changes: 4 additions & 2 deletions src/prefect/flow_runs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import (
TYPE_CHECKING,
Dict,
Any,
Optional,
Type,
TypeVar,
Expand Down Expand Up @@ -430,7 +430,9 @@ async def suspend_flow_run(


@sync_compatible
async def resume_flow_run(flow_run_id, run_input: Optional[Dict] = None):
async def resume_flow_run(
flow_run_id: UUID, run_input: Optional[dict[str, Any]] = None
) -> None:
"""
Resumes a paused flow.
Expand Down
10 changes: 4 additions & 6 deletions src/prefect/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,14 +564,12 @@ def resolve_block_reference(data: Any) -> Any:
"Cannot mix Pydantic v1 and v2 types as arguments to a flow."
)

validated_fn_kwargs = dict(arbitrary_types_allowed=True)

if has_v1_models:
validated_fn = V1ValidatedFunction(
self.fn, config={"arbitrary_types_allowed": True}
)
validated_fn = V1ValidatedFunction(self.fn, config=validated_fn_kwargs)
else:
validated_fn = V2ValidatedFunction(
self.fn, config=pydantic.ConfigDict(arbitrary_types_allowed=True)
)
validated_fn = V2ValidatedFunction(self.fn, config=validated_fn_kwargs)

try:
with warnings.catch_warnings():
Expand Down
8 changes: 6 additions & 2 deletions src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ async def update_for_flow(self, flow: "Flow") -> Self:
return self.model_copy(update=update)

@sync_compatible
async def update_for_task(self: Self, task: "Task") -> Self:
async def update_for_task(self: Self, task: "Task[P, R]") -> Self:
"""
Create a new result store for a task.
Expand Down Expand Up @@ -915,7 +915,11 @@ async def store_parameters(self, identifier: UUID, parameters: Dict[str, Any]):
)

@sync_compatible
async def read_parameters(self, identifier: UUID) -> Dict[str, Any]:
async def read_parameters(self, identifier: UUID) -> dict[str, Any]:
if self.result_storage is None:
raise ValueError(
"Result store is not configured - must have a result storage block to read parameters"
)
record = ResultRecord.deserialize(
await self.result_storage.read_path(f"parameters/{identifier}")
)
Expand Down
6 changes: 3 additions & 3 deletions src/prefect/runner/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ async def pull_code(self):
"""
...

def to_pull_step(self) -> dict:
def to_pull_step(self) -> dict[str, Any]:
"""
Returns a dictionary representation of the storage object that can be
used as a deployment pull step.
"""
...

def __eq__(self, __value) -> bool:
def __eq__(self, __value: Any) -> bool:
"""
Equality check for runner storage objects.
"""
Expand All @@ -69,7 +69,7 @@ def __eq__(self, __value) -> bool:

class GitCredentials(TypedDict, total=False):
username: str
access_token: Union[str, Secret]
access_token: Union[str, Secret[str]]


class GitRepository:
Expand Down
10 changes: 5 additions & 5 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,7 +1386,7 @@ def run_task_sync(
task_run_id: Optional[UUID] = None,
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture]] = None,
wait_for: Optional[Iterable[PrefectFuture[R]]] = None,
return_type: Literal["state", "result"] = "result",
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
context: Optional[Dict[str, Any]] = None,
Expand All @@ -1413,7 +1413,7 @@ async def run_task_async(
task_run_id: Optional[UUID] = None,
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture]] = None,
wait_for: Optional[Iterable[PrefectFuture[R]]] = None,
return_type: Literal["state", "result"] = "result",
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
context: Optional[Dict[str, Any]] = None,
Expand All @@ -1440,7 +1440,7 @@ def run_generator_task_sync(
task_run_id: Optional[UUID] = None,
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture]] = None,
wait_for: Optional[Iterable[PrefectFuture[R]]] = None,
return_type: Literal["state", "result"] = "result",
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
context: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -1495,7 +1495,7 @@ async def run_generator_task_async(
task_run_id: Optional[UUID] = None,
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture]] = None,
wait_for: Optional[Iterable[PrefectFuture[R]]] = None,
return_type: Literal["state", "result"] = "result",
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
context: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -1551,7 +1551,7 @@ def run_task(
task_run_id: Optional[UUID] = None,
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture]] = None,
wait_for: Optional[Iterable[PrefectFuture[R]]] = None,
return_type: Literal["state", "result"] = "result",
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
context: Optional[Dict[str, Any]] = None,
Expand Down
Loading

0 comments on commit a0f3219

Please sign in to comment.