From a6fc0d4a860953163f2c915beb40c9d26cd44524 Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Fri, 20 Dec 2024 16:56:11 +0000 Subject: [PATCH] [typing] misc prefect modules A few more modules to pass strict typing checks. --- src/prefect/artifacts.py | 135 ++++++++++++-------------- src/prefect/automations.py | 34 +++++-- src/prefect/context.py | 43 +++----- src/prefect/locking/filesystem.py | 15 +-- src/prefect/locking/memory.py | 8 +- src/prefect/locking/protocol.py | 2 +- src/prefect/plugins.py | 22 +++-- src/prefect/telemetry/processors.py | 12 +-- src/prefect/telemetry/services.py | 63 ++++++------ src/prefect/workers/__init__.py | 2 + tests/test_plugins.py | 4 +- tests/typesafety/test_automations.yml | 19 ++++ tests/typesafety/test_flows.yml | 2 +- 13 files changed, 192 insertions(+), 169 deletions(-) create mode 100644 tests/typesafety/test_automations.yml diff --git a/src/prefect/artifacts.py b/src/prefect/artifacts.py index 785effbe92216..b8903f960c8c8 100644 --- a/src/prefect/artifacts.py +++ b/src/prefect/artifacts.py @@ -2,19 +2,20 @@ Interface for creating and reading artifacts. """ -from __future__ import annotations - +import asyncio import json # noqa: I001 import math import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, Union from uuid import UUID +from typing_extensions import Self + from prefect.client.schemas.actions import ArtifactCreate as ArtifactRequest from prefect.client.schemas.actions import ArtifactUpdate from prefect.client.schemas.filters import ArtifactFilter, ArtifactFilterKey from prefect.client.schemas.sorting import ArtifactSort -from prefect.client.utilities import get_or_create_client, inject_client +from prefect.client.utilities import get_or_create_client from prefect.logging.loggers import get_logger from prefect.utilities.asyncutils import sync_compatible from prefect.utilities.context import get_task_and_flow_run_ids @@ -22,8 +23,6 @@ logger = get_logger("artifacts") if TYPE_CHECKING: - from typing_extensions import Self - from prefect.client.orchestration import PrefectClient from prefect.client.schemas.objects import Artifact as ArtifactResponse @@ -43,7 +42,7 @@ class Artifact(ArtifactRequest): @sync_compatible async def create( - self: "Self", + self: Self, client: Optional["PrefectClient"] = None, ) -> "ArtifactResponse": """ @@ -95,16 +94,15 @@ async def get( (ArtifactResponse, optional): The artifact (if found). """ client, _ = get_or_create_client(client) - return next( - iter( - await client.read_artifacts( - limit=1, - sort=ArtifactSort.UPDATED_DESC, - artifact_filter=ArtifactFilter(key=ArtifactFilterKey(any_=[key])), - ) + filter_key_value = None if key is None else [key] + artifacts = await client.read_artifacts( + limit=1, + sort=ArtifactSort.UPDATED_DESC, + artifact_filter=ArtifactFilter( + key=ArtifactFilterKey(any_=filter_key_value) ), - None, ) + return None if not artifacts else artifacts[0] @classmethod @sync_compatible @@ -112,10 +110,10 @@ async def get_or_create( cls, key: Optional[str] = None, description: Optional[str] = None, - data: Optional[Union[Dict[str, Any], Any]] = None, + data: Optional[Union[dict[str, Any], Any]] = None, client: Optional["PrefectClient"] = None, **kwargs: Any, - ) -> Tuple["ArtifactResponse", bool]: + ) -> tuple["ArtifactResponse", bool]: """ A method to get or create an artifact. @@ -128,18 +126,20 @@ async def get_or_create( Returns: (ArtifactResponse): The artifact, either retrieved or created. """ - artifact = await cls.get(key, client) + artifact_coro = cls.get(key, client) + if TYPE_CHECKING: + assert asyncio.iscoroutine(artifact_coro) + artifact = await artifact_coro if artifact: return artifact, False - else: - return ( - await cls(key=key, description=description, data=data, **kwargs).create( - client - ), - True, - ) - async def format(self) -> Optional[Union[Dict[str, Any], Any]]: + new_artifact = cls(key=key, description=description, data=data, **kwargs) + create_coro = new_artifact.create(client) + if TYPE_CHECKING: + assert asyncio.iscoroutine(create_coro) + return await create_coro, True + + async def format(self) -> Optional[Union[dict[str, Any], Any]]: return json.dumps(self.data) @@ -165,13 +165,13 @@ async def format(self) -> str: class TableArtifact(Artifact): - table: Union[Dict[str, List[Any]], List[Dict[str, Any]], List[List[Any]]] + table: Union[dict[str, list[Any]], list[dict[str, Any]], list[list[Any]]] type: Optional[str] = "table" @classmethod def _sanitize( - cls, item: Union[Dict[str, Any], List[Any], float] - ) -> Union[Dict[str, Any], List[Any], int, float, None]: + cls, item: Union[dict[str, Any], list[Any], float] + ) -> Union[dict[str, Any], list[Any], int, float, None]: """ Sanitize NaN values in a given item. The item can be a dict, list or float. @@ -230,39 +230,6 @@ async def format(self) -> str: return self.image_url -@inject_client -async def _create_artifact( - type: str, - key: Optional[str] = None, - description: Optional[str] = None, - data: Optional[Union[Dict[str, Any], Any]] = None, - client: Optional["PrefectClient"] = None, -) -> UUID: - """ - Helper function to create an artifact. - - Arguments: - type: A string identifying the type of artifact. - key: A user-provided string identifier. - The key must only contain lowercase letters, numbers, and dashes. - description: A user-specified description of the artifact. - data: A JSON payload that allows for a result to be retrieved. - client: The PrefectClient - - Returns: - - The table artifact ID. - """ - - artifact = await Artifact( - type=type, - key=key, - description=description, - data=data, - ).create(client) - - return artifact.id - - @sync_compatible async def create_link_artifact( link: str, @@ -286,12 +253,16 @@ async def create_link_artifact( Returns: The table artifact ID. """ - artifact = await LinkArtifact( + new_artifact = LinkArtifact( key=key, description=description, link=link, link_text=link_text, - ).create(client) + ) + create_coro = new_artifact.create(client) + if TYPE_CHECKING: + assert asyncio.iscoroutine(create_coro) + artifact = await create_coro return artifact.id @@ -315,18 +286,22 @@ async def create_markdown_artifact( Returns: The table artifact ID. """ - artifact = await MarkdownArtifact( + new_artifact = MarkdownArtifact( key=key, description=description, markdown=markdown, - ).create() + ) + create_coro = new_artifact.create() + if TYPE_CHECKING: + assert asyncio.iscoroutine(create_coro) + artifact = await create_coro return artifact.id @sync_compatible async def create_table_artifact( - table: Union[Dict[str, List[Any]], List[Dict[str, Any]], List[List[Any]]], + table: Union[dict[str, list[Any]], list[dict[str, Any]], list[list[Any]]], key: Optional[str] = None, description: Optional[str] = None, ) -> UUID: @@ -344,11 +319,15 @@ async def create_table_artifact( The table artifact ID. """ - artifact = await TableArtifact( + new_artifact = TableArtifact( key=key, description=description, table=table, - ).create() + ) + create_coro = new_artifact.create() + if TYPE_CHECKING: + assert asyncio.iscoroutine(create_coro) + artifact = await create_coro return artifact.id @@ -373,11 +352,15 @@ async def create_progress_artifact( The progress artifact ID. """ - artifact = await ProgressArtifact( + new_artifact = ProgressArtifact( key=key, description=description, progress=progress, - ).create() + ) + create_coro = new_artifact.create() + if TYPE_CHECKING: + assert asyncio.iscoroutine(create_coro) + artifact = await create_coro return artifact.id @@ -387,7 +370,7 @@ async def update_progress_artifact( artifact_id: UUID, progress: float, description: Optional[str] = None, - client: Optional[PrefectClient] = None, + client: Optional["PrefectClient"] = None, ) -> UUID: """ Update a progress artifact. @@ -444,10 +427,14 @@ async def create_image_artifact( The image artifact ID. """ - artifact = await ImageArtifact( + new_artifact = ImageArtifact( key=key, description=description, image_url=image_url, - ).create() + ) + create_coro = new_artifact.create() + if TYPE_CHECKING: + assert asyncio.iscoroutine(create_coro) + artifact = await create_coro return artifact.id diff --git a/src/prefect/automations.py b/src/prefect/automations.py index a37c5a3a45dd0..76c0af19455ac 100644 --- a/src/prefect/automations.py +++ b/src/prefect/automations.py @@ -1,4 +1,4 @@ -from typing import Optional, Type +from typing import TYPE_CHECKING, Optional, overload from uuid import UUID from pydantic import Field @@ -112,17 +112,28 @@ async def update(self: Self): auto.name = "new name" auto.update() """ + assert self.id is not None async with get_client() as client: automation = AutomationCore( **self.model_dump(exclude={"id", "owner_resource"}) ) await client.update_automation(automation_id=self.id, automation=automation) + @overload + @classmethod + async def read(cls, id: UUID, name: Optional[str] = ...) -> Self: + ... + + @overload + @classmethod + async def read(cls, id: None = None, name: str = ...) -> Optional[Self]: + ... + @classmethod @sync_compatible async def read( - cls: Type[Self], id: Optional[UUID] = None, name: Optional[str] = None - ) -> Self: + cls, id: Optional[UUID] = None, name: Optional[str] = None + ) -> Optional[Self]: """ Read an automation by ID or name. automation = Automation.read(name="woodchonk") @@ -145,13 +156,13 @@ async def read( raise if automation is None: raise ValueError(f"Automation with ID {id!r} not found") - return Automation(**automation.model_dump()) + return cls(**automation.model_dump()) else: + if TYPE_CHECKING: + assert name is not None automation = await client.read_automations_by_name(name=name) if len(automation) > 0: - return ( - Automation(**automation[0].model_dump()) if automation else None - ) + return cls(**automation[0].model_dump()) if automation else None else: raise ValueError(f"Automation with name {name!r} not found") @@ -161,6 +172,9 @@ async def delete(self: Self) -> bool: auto = Automation.read(id = 123) auto.delete() """ + if self.id is None: + raise ValueError("Can't delete an automation without an id") + async with get_client() as client: try: await client.delete_automation(self.id) @@ -177,6 +191,9 @@ async def disable(self: Self) -> bool: auto = Automation.read(id = 123) auto.disable() """ + if self.id is None: + raise ValueError("Can't disable an automation without an id") + async with get_client() as client: try: await client.pause_automation(self.id) @@ -193,6 +210,9 @@ async def enable(self: Self) -> bool: auto = Automation.read(id = 123) auto.enable() """ + if self.id is None: + raise ValueError("Can't enable an automation without an id") + async with get_client() as client: try: await client.resume_automation(self.id) diff --git a/src/prefect/context.py b/src/prefect/context.py index 287b9b58e1381..7f92ad7528ca5 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -9,21 +9,10 @@ import os import sys import warnings +from collections.abc import AsyncGenerator, Generator, Mapping from contextlib import ExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar, Token -from typing import ( - TYPE_CHECKING, - Any, - AsyncGenerator, - Dict, - Generator, - Mapping, - Optional, - Set, - Type, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from typing_extensions import Self @@ -64,7 +53,7 @@ GLOBAL_SETTINGS_CONTEXT = None # type: ignore -def serialize_context() -> Dict[str, Any]: +def serialize_context() -> dict[str, Any]: """ Serialize the current context for use in a remote execution environment. """ @@ -84,7 +73,7 @@ def serialize_context() -> Dict[str, Any]: @contextmanager def hydrated_context( - serialized_context: Optional[Dict[str, Any]] = None, + serialized_context: Optional[dict[str, Any]] = None, client: Union[PrefectClient, SyncPrefectClient, None] = None, ): with ExitStack() as stack: @@ -148,7 +137,7 @@ def __exit__(self, *_): self._token = None @classmethod - def get(cls: Type[Self]) -> Optional[Self]: + def get(cls: type[Self]) -> Optional[Self]: """Get the current context instance""" return cls.__var__.get(None) @@ -173,7 +162,7 @@ def model_copy( new._token = None return new - def serialize(self, include_secrets: bool = True) -> Dict[str, Any]: + def serialize(self, include_secrets: bool = True) -> dict[str, Any]: """ Serialize the context model to a dictionary that can be pickled with cloudpickle. """ @@ -314,10 +303,10 @@ def __init__(self, *args: Any, **kwargs: Any): start_client_metrics_server() start_time: DateTime = Field(default_factory=lambda: DateTime.now("UTC")) - input_keyset: Optional[Dict[str, Dict[str, str]]] = None + input_keyset: Optional[dict[str, dict[str, str]]] = None client: Union[PrefectClient, SyncPrefectClient] - def serialize(self: Self, include_secrets: bool = True) -> Dict[str, Any]: + def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]: return self.model_dump( include={"start_time", "input_keyset"}, exclude_unset=True, @@ -344,7 +333,7 @@ class EngineContext(RunContext): flow_run: Optional[FlowRun] = None task_runner: TaskRunner[Any] log_prints: bool = False - parameters: Optional[Dict[str, Any]] = None + parameters: Optional[dict[str, Any]] = None # Flag signaling if the flow run context has been serialized and sent # to remote infrastructure. @@ -355,10 +344,10 @@ class EngineContext(RunContext): persist_result: bool = Field(default_factory=get_default_persist_setting) # Counter for task calls allowing unique - task_run_dynamic_keys: Dict[str, Union[str, int]] = Field(default_factory=dict) + task_run_dynamic_keys: dict[str, Union[str, int]] = Field(default_factory=dict) # Counter for flow pauses - observed_flow_pauses: Dict[str, int] = Field(default_factory=dict) + observed_flow_pauses: dict[str, int] = Field(default_factory=dict) # Tracking for result from task runs in this flow run for dependency tracking # Holds the ID of the object returned by the task run and task run state @@ -369,7 +358,7 @@ class EngineContext(RunContext): __var__: ContextVar[Self] = ContextVar("flow_run") - def serialize(self: Self, include_secrets: bool = True) -> Dict[str, Any]: + def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]: return self.model_dump( include={ "flow_run", @@ -403,7 +392,7 @@ class TaskRunContext(RunContext): task: "Task[Any, Any]" task_run: TaskRun log_prints: bool = False - parameters: Dict[str, Any] + parameters: dict[str, Any] # Result handling result_store: ResultStore @@ -411,7 +400,7 @@ class TaskRunContext(RunContext): __var__ = ContextVar("task_run") - def serialize(self: Self, include_secrets: bool = True) -> Dict[str, Any]: + def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]: return self.model_dump( include={ "task_run", @@ -437,7 +426,7 @@ class TagsContext(ContextModel): current_tags: A set of current tags in the context """ - current_tags: Set[str] = Field(default_factory=set) + current_tags: set[str] = Field(default_factory=set) @classmethod def get(cls) -> "TagsContext": @@ -512,7 +501,7 @@ def get_settings_context() -> SettingsContext: @contextmanager -def tags(*new_tags: str) -> Generator[Set[str], None, None]: +def tags(*new_tags: str) -> Generator[set[str], None, None]: """ Context manager to add tags to flow and task run calls. diff --git a/src/prefect/locking/filesystem.py b/src/prefect/locking/filesystem.py index d97ff4580fde4..c324f7d17aab0 100644 --- a/src/prefect/locking/filesystem.py +++ b/src/prefect/locking/filesystem.py @@ -1,6 +1,7 @@ import time +from logging import Logger from pathlib import Path -from typing import Dict, Optional +from typing import Optional import anyio import pendulum @@ -11,7 +12,7 @@ from .protocol import LockManager -logger = get_logger(__name__) +logger: Logger = get_logger(__name__) class _LockInfo(TypedDict): @@ -37,11 +38,11 @@ class FileSystemLockManager(LockManager): lock_files_directory: the directory where lock files are stored """ - def __init__(self, lock_files_directory: Path): - self.lock_files_directory = lock_files_directory.expanduser().resolve() - self._locks: Dict[str, _LockInfo] = {} + def __init__(self, lock_files_directory: Path) -> None: + self.lock_files_directory: Path = lock_files_directory.expanduser().resolve() + self._locks: dict[str, _LockInfo] = {} - def _ensure_lock_files_directory_exists(self): + def _ensure_lock_files_directory_exists(self) -> None: self.lock_files_directory.mkdir(parents=True, exist_ok=True) def _lock_path_for_key(self, key: str) -> Path: @@ -49,7 +50,7 @@ def _lock_path_for_key(self, key: str) -> Path: return lock_info["path"] return self.lock_files_directory.joinpath(key).with_suffix(".lock") - def _get_lock_info(self, key: str, use_cache=True) -> Optional[_LockInfo]: + def _get_lock_info(self, key: str, use_cache: bool = True) -> Optional[_LockInfo]: if use_cache: if (lock_info := self._locks.get(key)) is not None: return lock_info diff --git a/src/prefect/locking/memory.py b/src/prefect/locking/memory.py index b6891f844709f..554f0dbeaa525 100644 --- a/src/prefect/locking/memory.py +++ b/src/prefect/locking/memory.py @@ -1,6 +1,8 @@ import asyncio import threading -from typing import Dict, Optional, TypedDict +from typing import Any, Optional, TypedDict + +from typing_extensions import Self from .protocol import LockManager @@ -30,14 +32,14 @@ class MemoryLockManager(LockManager): _instance = None - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> Self: if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance def __init__(self): self._locks_dict_lock = threading.Lock() - self._locks: Dict[str, _LockInfo] = {} + self._locks: dict[str, _LockInfo] = {} def _expire_lock(self, key: str): """ diff --git a/src/prefect/locking/protocol.py b/src/prefect/locking/protocol.py index 369f49d4b01b4..3505eeec00ddf 100644 --- a/src/prefect/locking/protocol.py +++ b/src/prefect/locking/protocol.py @@ -57,7 +57,7 @@ async def aacquire_lock( """ ... - def release_lock(self, key: str, holder: str): + def release_lock(self, key: str, holder: str) -> None: """ Releases the lock on the corresponding transaction record. diff --git a/src/prefect/plugins.py b/src/prefect/plugins.py index 07f4144c4acd9..d55dfece3c34a 100644 --- a/src/prefect/plugins.py +++ b/src/prefect/plugins.py @@ -9,15 +9,15 @@ """ from types import ModuleType -from typing import Any, Dict, Union +from typing import Any, Union import prefect.settings from prefect.utilities.compat import EntryPoints, entry_points -COLLECTIONS: Union[None, Dict[str, Union[ModuleType, Exception]]] = None +_collections: Union[None, dict[str, Union[ModuleType, Exception]]] = None -def safe_load_entrypoints(entrypoints: EntryPoints) -> Dict[str, Union[Exception, Any]]: +def safe_load_entrypoints(entrypoints: EntryPoints) -> dict[str, Union[Exception, Any]]: """ Load entry points for a group capturing any exceptions that occur. """ @@ -26,7 +26,7 @@ def safe_load_entrypoints(entrypoints: EntryPoints) -> Dict[str, Union[Exception # also want to validate the type for the group for entrypoints that have # a specific type we expect. - results = {} + results: dict[str, Union[Exception, Any]] = {} for entrypoint in entrypoints: result = None @@ -40,18 +40,20 @@ def safe_load_entrypoints(entrypoints: EntryPoints) -> Dict[str, Union[Exception return results -def load_prefect_collections() -> Dict[str, Union[ModuleType, Exception]]: +def load_prefect_collections() -> dict[str, Union[ModuleType, Exception]]: """ Load all Prefect collections that define an entrypoint in the group `prefect.collections`. """ - global COLLECTIONS + global _collections - if COLLECTIONS is not None: - return COLLECTIONS + if _collections is not None: + return _collections collection_entrypoints: EntryPoints = entry_points(group="prefect.collections") - collections = safe_load_entrypoints(collection_entrypoints) + collections: dict[str, Union[Exception, Any]] = safe_load_entrypoints( + collection_entrypoints + ) # TODO: Consider the utility of this once we've established this pattern. # We cannot use a logger here because logging is not yet initialized. @@ -68,5 +70,5 @@ def load_prefect_collections() -> Dict[str, Union[ModuleType, Exception]]: if prefect.settings.PREFECT_DEBUG_MODE: print(f"Loaded collection {name!r}.") - COLLECTIONS = collections + _collections = collections return collections diff --git a/src/prefect/telemetry/processors.py b/src/prefect/telemetry/processors.py index 03a33ab0f2b69..64064f5f3beb3 100644 --- a/src/prefect/telemetry/processors.py +++ b/src/prefect/telemetry/processors.py @@ -1,6 +1,6 @@ import time from threading import Event, Lock, Thread -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Optional from opentelemetry.context import Context from opentelemetry.sdk.trace import Span, SpanProcessor @@ -13,7 +13,7 @@ class InFlightSpanProcessor(SpanProcessor): def __init__(self, span_exporter: "SpanExporter"): self.span_exporter = span_exporter - self._in_flight: Dict[int, Span] = {} + self._in_flight: dict[int, Span] = {} self._lock = Lock() self._stop_event = Event() self._export_thread = Thread(target=self._export_periodically, daemon=True) @@ -30,10 +30,10 @@ def _export_periodically(self) -> None: self.span_exporter.export(to_export) def _readable_span(self, span: "Span") -> "ReadableSpan": - readable = span._readable_span() - readable._end_time = time.time_ns() - readable._attributes = { - **(readable._attributes or {}), + readable = span._readable_span() # pyright: ignore[reportPrivateUsage] + readable._end_time = time.time_ns() # pyright: ignore[reportPrivateUsage] + readable._attributes = { # pyright: ignore[reportPrivateUsage] + **(readable._attributes or {}), # pyright: ignore[reportPrivateUsage] "prefect.in-flight": True, } return readable diff --git a/src/prefect/telemetry/services.py b/src/prefect/telemetry/services.py index 6d9d7e47bca6a..a9825094f0e02 100644 --- a/src/prefect/telemetry/services.py +++ b/src/prefect/telemetry/services.py @@ -1,32 +1,38 @@ -from abc import abstractmethod -from typing import Union +from collections.abc import Sequence +from typing import Any, Protocol, TypeVar from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk._logs import LogData from opentelemetry.sdk._logs.export import LogExporter from opentelemetry.sdk.trace import ReadableSpan -from opentelemetry.sdk.trace.export import SpanExporter +from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult from prefect._internal.concurrency.services import BatchedQueueService +BatchItem = TypeVar("BatchItem", ReadableSpan, LogData) +T_contra = TypeVar("T_contra", contravariant=True) -class BaseQueueingExporter(BatchedQueueService): + +class OTLPExporter(Protocol[T_contra]): + def export(self, __items: Sequence[T_contra]) -> Any: + ... + + def shutdown(self) -> Any: + ... + + +class BaseQueueingExporter(BatchedQueueService[BatchItem]): _max_batch_size = 512 _min_interval = 2.0 - _otlp_exporter: Union[SpanExporter, LogExporter] - - def export(self, batch: list[Union[ReadableSpan, LogData]]) -> None: - for item in batch: - self.send(item) - @abstractmethod - def _export_batch(self, items: list[Union[ReadableSpan, LogData]]) -> None: - pass + def __init__(self, otlp_exporter: OTLPExporter[BatchItem]) -> None: + super().__init__() + self._otlp_exporter = otlp_exporter - async def _handle_batch(self, items: list[Union[ReadableSpan, LogData]]) -> None: + async def _handle_batch(self, items: list[BatchItem]) -> None: try: - self._export_batch(items) + self._otlp_exporter.export(items) except Exception as e: self._logger.exception(f"Failed to export batch: {e}") raise @@ -39,29 +45,24 @@ def shutdown(self) -> None: self._otlp_exporter.shutdown() -class QueueingSpanExporter(BaseQueueingExporter, SpanExporter): +class QueueingSpanExporter(BaseQueueingExporter[ReadableSpan], SpanExporter): _otlp_exporter: OTLPSpanExporter def __init__(self, endpoint: str, headers: tuple[tuple[str, str]]): - super().__init__() - self._otlp_exporter = OTLPSpanExporter( - endpoint=endpoint, - headers=dict(headers), - ) + super().__init__(OTLPSpanExporter(endpoint=endpoint, headers=dict(headers))) - def _export_batch(self, items: list[ReadableSpan]) -> None: - self._otlp_exporter.export(items) + def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: + for item in spans: + self.send(item) + return SpanExportResult.SUCCESS -class QueueingLogExporter(BaseQueueingExporter, LogExporter): +class QueueingLogExporter(BaseQueueingExporter[LogData], LogExporter): _otlp_exporter: OTLPLogExporter - def __init__(self, endpoint: str, headers: tuple[tuple[str, str]]): - super().__init__() - self._otlp_exporter = OTLPLogExporter( - endpoint=endpoint, - headers=dict(headers), - ) + def __init__(self, endpoint: str, headers: tuple[tuple[str, str]]) -> None: + super().__init__(OTLPLogExporter(endpoint=endpoint, headers=dict(headers))) - def _export_batch(self, items: list[LogData]) -> None: - self._otlp_exporter.export(items) + def export(self, batch: Sequence[LogData]) -> None: + for item in batch: + self.send(item) diff --git a/src/prefect/workers/__init__.py b/src/prefect/workers/__init__.py index 3ac8dc9c76137..c08307719f71f 100644 --- a/src/prefect/workers/__init__.py +++ b/src/prefect/workers/__init__.py @@ -1 +1,3 @@ from .process import ProcessWorker + +__all__ = ["ProcessWorker"] diff --git a/tests/test_plugins.py b/tests/test_plugins.py index f3aacd1b2cf30..3ef2c9e35a36c 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -10,9 +10,9 @@ @pytest.fixture(autouse=True) def reset_collections(): - prefect.plugins.COLLECTIONS = None + prefect.plugins._collections = None yield - prefect.plugins.COLLECTIONS = None + prefect.plugins._collections = None def test_safe_load_entrypoints_returns_modules_and_exceptions(): diff --git a/tests/typesafety/test_automations.yml b/tests/typesafety/test_automations.yml new file mode 100644 index 0000000000000..6bd1706f64294 --- /dev/null +++ b/tests/typesafety/test_automations.yml @@ -0,0 +1,19 @@ +# yaml-language-server: $schema=https://raw.githubusercontent.com/typeddjango/pytest-mypy-plugins/master/pytest_mypy_plugins/schema.json +- case: automations_read_by_id + main: | + from uuid import uuid4 + from prefect.automations import Automation + async def test_func() -> None: + automation = await Automation.read(id=uuid4()) + reveal_type(automation) # N: Revealed type is "prefect.automations.Automation" + automation = await Automation.read(id=uuid4(), name=None) + reveal_type(automation) # N: Revealed type is "prefect.automations.Automation" + +- case: automations_read_by_name + main: | + from prefect.automations import Automation + async def test_func() -> None: + automation = await Automation.read(name="foobar") + reveal_type(automation) # N: Revealed type is "Union[prefect.automations.Automation, None]" + automation = await Automation.read(id=None, name="foobar") + reveal_type(automation) # N: Revealed type is "Union[prefect.automations.Automation, None]" diff --git a/tests/typesafety/test_flows.yml b/tests/typesafety/test_flows.yml index 2db73d00e6c5a..4ef96d9e88d95 100644 --- a/tests/typesafety/test_flows.yml +++ b/tests/typesafety/test_flows.yml @@ -3,7 +3,7 @@ main: | from prefect import flow reveal_type(flow.from_source) - regex: yes + regex: true # this has to be a regex, because mypy randomly (!) switches between ... and [*Any, **Any] syntax here out: "main:2: note: Revealed type is \"\ def \\(\