diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index 45f4f9582..de2e0ab9b 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -92,15 +92,24 @@ def get_unsubscripted_type(t: Any) -> Any: return t -def origin_type_issubclass(cls: Any, type_: type) -> bool: - """Return True if cls can be considered as a subclass of type_.""" - unwrapped_type = unwrap_annotation(cls) +def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]]) -> bool: + """Return True if annotation is a subtype of type_.""" + unwrapped_type = unwrap_annotation(annotation) origin_type = get_unsubscripted_type(unwrapped_type) if origin_type is Union or origin_type is UnionType: - return any(origin_type_issubclass(arg, type_) for arg in get_args(cls)) + return all(origin_type_issubtype(arg, type_) for arg in get_args(annotation)) return issubclass(origin_type, type_) +def origin_type_issupertype(annotation: Any, type_: type) -> bool: + """Return True if annotation is a supertype of type_.""" + unwrapped_type = unwrap_annotation(annotation) + origin_type = get_unsubscripted_type(unwrapped_type) + if origin_type is Union or origin_type is UnionType: + return any(origin_type_issupertype(arg, type_) for arg in get_args(annotation)) + return issubclass(type_, origin_type) + + def is_subscripted(t: Any) -> bool: """Check if given type is subscripted, i.e. a typing object of the form X[Y, Z, ...]. diff --git a/src/hera/workflows/_runner/script_annotations_util.py b/src/hera/workflows/_runner/script_annotations_util.py index fbdf162ae..d560fe4dd 100644 --- a/src/hera/workflows/_runner/script_annotations_util.py +++ b/src/hera/workflows/_runner/script_annotations_util.py @@ -3,15 +3,21 @@ import inspect import json import os +import sys from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast +if sys.version_info >= (3, 10): + from types import NoneType +else: + NoneType = type(None) + from hera.shared._pydantic import BaseModel, get_field_annotations, get_fields from hera.shared._type_util import ( get_unsubscripted_type, get_workflow_annotation, is_subscripted, - origin_type_issubclass, + origin_type_issubtype, unwrap_annotation, ) from hera.shared.serialization import serialize @@ -138,7 +144,7 @@ def map_runner_input( input_model_obj = {} def load_parameter_value(value: str, value_type: type) -> Any: - if origin_type_issubclass(value_type, str): + if origin_type_issubtype(value_type, (str, NoneType)): return value try: diff --git a/src/hera/workflows/_runner/util.py b/src/hera/workflows/_runner/util.py index 5acc32ab3..b949d4558 100644 --- a/src/hera/workflows/_runner/util.py +++ b/src/hera/workflows/_runner/util.py @@ -6,13 +6,19 @@ import inspect import json import os +import sys from pathlib import Path from typing import Any, Callable, Dict, List, Optional, cast +if sys.version_info >= (3, 10): + from types import NoneType +else: + NoneType = type(None) + from hera.shared._pydantic import _PYDANTIC_VERSION from hera.shared._type_util import ( get_workflow_annotation, - origin_type_issubclass, + origin_type_issubtype, unwrap_annotation, ) from hera.shared.serialization import serialize @@ -125,7 +131,7 @@ def _get_unannotated_type(key: str, f: Callable) -> Optional[type]: def _is_str_kwarg_of(key: str, f: Callable) -> bool: """Check if param `key` of function `f` has a type annotation that can be interpreted as a subclass of str.""" if func_param_annotation := _get_function_param_annotation(key, f): - return origin_type_issubclass(func_param_annotation, str) + return origin_type_issubtype(func_param_annotation, (str, NoneType)) return False diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index dce234510..cd9d419a4 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -47,7 +47,7 @@ _flag_enabled, ) from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator -from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issubclass +from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issupertype from hera.shared.serialization import serialize from hera.workflows._context import _context from hera.workflows._meta_mixins import CallableTemplateMixin @@ -540,7 +540,9 @@ class will be used as inputs, rather than the class itself. else: default = MISSING - if origin_type_issubclass(func_param.annotation, NoneType) and (default is MISSING or default is not None): + if origin_type_issupertype(func_param.annotation, NoneType) and ( + default is MISSING or default is not None + ): raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.") parameters.append(Parameter(name=func_param.name, default=default)) diff --git a/tests/script_runner/parameter_inputs.py b/tests/script_runner/parameter_inputs.py index 980d756a9..4710dedd8 100644 --- a/tests/script_runner/parameter_inputs.py +++ b/tests/script_runner/parameter_inputs.py @@ -1,5 +1,5 @@ import json -from typing import Any, List +from typing import Any, List, Union try: from typing import Annotated @@ -76,6 +76,11 @@ def no_type_parameter(my_anything) -> Any: return my_anything +@script() +def str_or_int_parameter(my_str_or_int: Union[str, int]) -> str: + return f"type given: {type(my_str_or_int).__name__}" + + @script() def str_parameter_expects_jsonstr_dict(my_json_str: str) -> dict: return json.loads(my_json_str) diff --git a/tests/test_runner.py b/tests/test_runner.py index f9fac159f..f61546695 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -65,6 +65,18 @@ {}, id="no-type-dict", ), + pytest.param( + "tests.script_runner.parameter_inputs:str_or_int_parameter", + [{"name": "my_str_or_int", "value": "hi there"}], + "type given: str", + id="str-or-int-given-str", + ), + pytest.param( + "tests.script_runner.parameter_inputs:str_or_int_parameter", + [{"name": "my_str_or_int", "value": "3"}], + "type given: int", + id="str-or-int-given-int", + ), pytest.param( "tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict", [{"name": "my_json_str", "value": json.dumps({"my": "dict"})}], diff --git a/tests/test_unit/test_shared_type_utils.py b/tests/test_unit/test_shared_type_utils.py index d543997d8..799c1cb7a 100644 --- a/tests/test_unit/test_shared_type_utils.py +++ b/tests/test_unit/test_shared_type_utils.py @@ -1,5 +1,15 @@ +import sys from typing import List, Optional, Union +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated +if sys.version_info >= (3, 10): + from types import NoneType +else: + NoneType = type(None) + import pytest from annotated_types import Gt @@ -8,16 +18,12 @@ get_unsubscripted_type, get_workflow_annotation, is_annotated, - origin_type_issubclass, + origin_type_issubtype, + origin_type_issupertype, unwrap_annotation, ) from hera.workflows import Artifact, Parameter -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated - @pytest.mark.parametrize("annotation, expected", [[Annotated[str, "some metadata"], True], [str, False]]) def test_is_annotated(annotation, expected): @@ -104,11 +110,29 @@ def test_get_unsubscripted_type(annotation, expected): @pytest.mark.parametrize( "annotation, target, expected", [ - [List[str], str, False], - [Optional[str], str, True], - [str, str, True], - [Union[int, str], int, True], + pytest.param(List[str], str, False, id="list-str-not-subtype-of-str"), + pytest.param(Optional[str], str, False, id="optional-str-not-subtype-of-str"), + pytest.param(str, str, True, id="str-is-subtype-of-str"), + pytest.param(Union[int, str], int, False, id="union-int-str-not-subtype-of-str"), + pytest.param(Optional[str], (str, NoneType), True, id="optional-str-is-subtype-of-optional-str"), + pytest.param(str, (str, NoneType), True, id="str-is-subtype-of-optional-str"), + pytest.param(Union[int, str], (str, NoneType), False, id="union-int-str-not-subtype-of-optional-str"), + ], +) +def test_origin_type_issubtype(annotation, target, expected): + assert origin_type_issubtype(annotation, target) is expected + + +@pytest.mark.parametrize( + "annotation, target, expected", + [ + pytest.param(List[str], str, False, id="list-str-not-supertype-of-str"), + pytest.param(Optional[str], str, True, id="optional-str-is-supertype-of-str"), + pytest.param(str, str, True, id="str-is-supertype-of-str"), + pytest.param(Union[int, str], int, True, id="union-int-str-is-supertype-of-int"), + pytest.param(Optional[str], NoneType, True, id="optional-str-is-supertype-of-nonetype"), + pytest.param(str, NoneType, False, id="str-not-supertype-of-nonetype"), ], ) -def test_origin_type_issubclass(annotation, target, expected): - assert origin_type_issubclass(annotation, target) is expected +def test_origin_type_issupertype(annotation, target, expected): + assert origin_type_issupertype(annotation, target) is expected