diff --git a/src/prefect/flows.py b/src/prefect/flows.py index 27dc6e7c4a8a..cc4f781b6499 100644 --- a/src/prefect/flows.py +++ b/src/prefect/flows.py @@ -10,6 +10,7 @@ import importlib.util import inspect import os +import sys import tempfile import warnings from functools import partial, update_wrapper @@ -137,6 +138,16 @@ if TYPE_CHECKING: from prefect.deployments.runner import FlexibleScheduleList, RunnerDeployment +# Handle Python 3.8 compatibility for GenericAlias +if sys.version_info >= (3, 9): + from types import GenericAlias # novermin + + GENERIC_ALIAS = (GenericAlias,) +else: + from typing import _GenericAlias + + GENERIC_ALIAS = (_GenericAlias,) + @PrefectObjectRegistry.register_instances class Flow(Generic[P, R]): @@ -530,18 +541,22 @@ def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, Any]: is_v1_type(param.annotation) for param in sig.parameters.values() ) has_v1_models = any( - issubclass(param.annotation, V1BaseModel) - if isinstance(param.annotation, type) - else False + ( + isinstance(param.annotation, type) + and not isinstance(param.annotation, GENERIC_ALIAS) + and issubclass(param.annotation, V1BaseModel) + ) for param in sig.parameters.values() ) has_v2_types = any( is_v2_type(param.annotation) for param in sig.parameters.values() ) has_v2_models = any( - issubclass(param.annotation, V2BaseModel) - if isinstance(param.annotation, type) - else False + ( + isinstance(param.annotation, type) + and not isinstance(param.annotation, GENERIC_ALIAS) + and issubclass(param.annotation, V2BaseModel) + ) for param in sig.parameters.values() ) @@ -1601,7 +1616,9 @@ def flow( def select_flow( - flows: Iterable[Flow], flow_name: str = None, from_message: str = None + flows: Iterable[Flow], + flow_name: Optional[str] = None, + from_message: Optional[str] = None, ) -> Flow: """ Select the only flow in an iterable or a flow specified by name. diff --git a/tests/test_flows.py b/tests/test_flows.py index 47bfea76ec4a..097018f97563 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -1517,6 +1517,23 @@ def my_flow(secret: SecretStr): "secret": SecretStr("my secret") } + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="Python 3.9+ required for GenericAlias" + ) + def test_flow_signature_can_contain_generic_type_hints(self): + """Test that generic type hints like dict[str, str] work correctly + + this is a regression test for https://github.com/PrefectHQ/prefect/issues/16105 + """ + + @flow + def my_flow(param: dict[str, str]): # novermin + return param + + test_data = {"foo": "bar"} + assert my_flow(test_data) == test_data + assert my_flow.validate_parameters({"param": test_data}) == {"param": test_data} + class TestSubflowTaskInputs: async def test_subflow_with_one_upstream_task_future(self, prefect_client):