From c2420eecef1481d2a89a9edae608fe4c12c50684 Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Mon, 9 Dec 2024 11:20:23 -0600 Subject: [PATCH] rebuild models --- src/prefect/client/_adapters.py | 9 ++++----- src/prefect/client/orchestration.py | 8 ++++++++ src/prefect/deployments/flow_runs.py | 6 +++--- src/prefect/server/events/pipeline.py | 10 ++++------ tests/test_task_engine.py | 6 ++++-- 5 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/prefect/client/_adapters.py b/src/prefect/client/_adapters.py index 906878b80c1f..fc0597aa29ef 100644 --- a/src/prefect/client/_adapters.py +++ b/src/prefect/client/_adapters.py @@ -34,10 +34,12 @@ ) from prefect.events.schemas.automations import Automation - defer_build_cfg = ConfigDict(defer_build=True) -BlockTypeAdapter = TypeAdapter("BlockType", config=defer_build_cfg) +# Create the adapters with forward refs +BlockTypeAdapter: TypeAdapter["BlockType"] = TypeAdapter( + "BlockType", config=defer_build_cfg +) BlockSchemaAdapter = TypeAdapter(List["BlockSchema"], config=defer_build_cfg) ConcurrencyLimitAdapter = TypeAdapter("ConcurrencyLimit", config=defer_build_cfg) ConcurrencyLimitListAdapter = TypeAdapter( @@ -53,9 +55,6 @@ BlockDocumentListAdapter = TypeAdapter(List["BlockDocument"], config=defer_build_cfg) BlockSchemaListAdapter = TypeAdapter(List["BlockSchema"], config=defer_build_cfg) BlockTypeListAdapter = TypeAdapter(List["BlockType"], config=defer_build_cfg) -ConcurrencyLimitListAdapter = TypeAdapter( - List["ConcurrencyLimit"], config=defer_build_cfg -) DeploymentResponseListAdapter = TypeAdapter( List["DeploymentResponse"], config=defer_build_cfg ) diff --git a/src/prefect/client/orchestration.py b/src/prefect/client/orchestration.py index 57148ba2d3eb..d2ea4dedf18d 100644 --- a/src/prefect/client/orchestration.py +++ b/src/prefect/client/orchestration.py @@ -112,6 +112,7 @@ FlowRunPolicy, Log, Parameter, + State, TaskRunPolicy, TaskRunResult, Variable, @@ -140,6 +141,7 @@ from prefect.events import filters from prefect.events.schemas.automations import Automation, AutomationCore from prefect.logging import get_logger +from prefect.results import BaseResult, ResultRecordMetadata from prefect.settings import ( PREFECT_API_DATABASE_CONNECTION_URL, PREFECT_API_ENABLE_HTTP2, @@ -173,6 +175,12 @@ T = TypeVar("T") +BaseResult.model_rebuild() +ResultRecordMetadata.model_rebuild() + +State.model_rebuild() + + @overload def get_client( *, diff --git a/src/prefect/deployments/flow_runs.py b/src/prefect/deployments/flow_runs.py index 8c66b5d87bf9..bb77780dd85e 100644 --- a/src/prefect/deployments/flow_runs.py +++ b/src/prefect/deployments/flow_runs.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import TYPE_CHECKING, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Iterable, Optional, Union from uuid import UUID import anyio @@ -35,7 +35,7 @@ async def run_deployment( name: Union[str, UUID], client: Optional["PrefectClient"] = None, - parameters: Optional[dict] = None, + parameters: Optional[dict[str, Any]] = None, scheduled_time: Optional[datetime] = None, flow_run_name: Optional[str] = None, timeout: Optional[float] = None, @@ -44,7 +44,7 @@ async def run_deployment( idempotency_key: Optional[str] = None, work_queue_name: Optional[str] = None, as_subflow: Optional[bool] = True, - job_variables: Optional[dict] = None, + job_variables: Optional[dict[str, Any]] = None, ) -> "FlowRun": """ Create a flow run for a deployment and return it after completion or a timeout. diff --git a/src/prefect/server/events/pipeline.py b/src/prefect/server/events/pipeline.py index 0a544f79bc24..9b82fde4a2f5 100644 --- a/src/prefect/server/events/pipeline.py +++ b/src/prefect/server/events/pipeline.py @@ -1,5 +1,3 @@ -from typing import List - from prefect.server.events.schemas.events import Event, ReceivedEvent from prefect.server.events.services import event_persister from prefect.server.services import task_run_recorder @@ -8,8 +6,8 @@ class EventsPipeline: @staticmethod - def events_to_messages(events) -> List[MemoryMessage]: - messages = [] + def events_to_messages(events: list[Event]) -> list[MemoryMessage]: + messages: list[MemoryMessage] = [] for event in events: received_event = ReceivedEvent(**event.model_dump()) message = MemoryMessage( @@ -19,11 +17,11 @@ def events_to_messages(events) -> List[MemoryMessage]: messages.append(message) return messages - async def process_events(self, events: List[Event]): + async def process_events(self, events: list[Event]): messages = self.events_to_messages(events) await self.process_messages(messages) - async def process_messages(self, messages: List[MemoryMessage]): + async def process_messages(self, messages: list[MemoryMessage]): for message in messages: await self.process_message(message) diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index 185045fb9455..0fed0d887fa2 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -573,14 +573,16 @@ async def second(): async def test_task_run_states( self, - prefect_client, + prefect_client: PrefectClient, events_pipeline, ): @task async def foo(): - return TaskRunContext.get().task_run.id + assert (ctx := TaskRunContext.get()) is not None + return ctx.task_run.id task_run_id = await run_task_async(foo) + assert isinstance(task_run_id, UUID) await events_pipeline.process_events() states = await prefect_client.read_task_run_states(task_run_id)