Skip to content

Commit

Permalink
Fix deserialization for str unions
Browse files Browse the repository at this point in the history
PR #1168 changed the logic of map_runner_input and _parse to no longer
pass incoming values to json.loads if their annotated type was a union
that included str, rather than only when given a subtype of
`Optional[str]`. Split `origin_type_issubclass` into two functions,
`origin_type_issupertype` (which matches the previous behaviour) and
`origin_type_issubtype`, and use the latter instead to restore the
original behaviour.

Add a runner check which verifies this behaviour.

Signed-off-by: Alice Purcell <[email protected]>
  • Loading branch information
alicederyn committed Oct 15, 2024
1 parent a08f744 commit 9192b57
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 23 deletions.
17 changes: 13 additions & 4 deletions src/hera/shared/_type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,24 @@ def get_unsubscripted_type(t: Any) -> Any:
return t


def origin_type_issubclass(cls: Any, type_: type) -> bool:
"""Return True if cls can be considered as a subclass of type_."""
unwrapped_type = unwrap_annotation(cls)
def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]]) -> bool:
"""Return True if annotation is a subtype of type_."""
unwrapped_type = unwrap_annotation(annotation)
origin_type = get_unsubscripted_type(unwrapped_type)
if origin_type is Union or origin_type is UnionType:
return any(origin_type_issubclass(arg, type_) for arg in get_args(cls))
return all(origin_type_issubtype(arg, type_) for arg in get_args(annotation))
return issubclass(origin_type, type_)


def origin_type_issupertype(annotation: Any, type_: type) -> bool:
"""Return True if annotation is a supertype of type_."""
unwrapped_type = unwrap_annotation(annotation)
origin_type = get_unsubscripted_type(unwrapped_type)
if origin_type is Union or origin_type is UnionType:
return any(origin_type_issupertype(arg, type_) for arg in get_args(annotation))
return issubclass(type_, origin_type)


def is_subscripted(t: Any) -> bool:
"""Check if given type is subscripted, i.e. a typing object of the form X[Y, Z, ...].
Expand Down
10 changes: 8 additions & 2 deletions src/hera/workflows/_runner/script_annotations_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,21 @@
import inspect
import json
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast

if sys.version_info >= (3, 10):
from types import NoneType
else:
NoneType = type(None)

from hera.shared._pydantic import BaseModel, get_field_annotations, get_fields
from hera.shared._type_util import (
get_unsubscripted_type,
get_workflow_annotation,
is_subscripted,
origin_type_issubclass,
origin_type_issubtype,
unwrap_annotation,
)
from hera.shared.serialization import serialize
Expand Down Expand Up @@ -138,7 +144,7 @@ def map_runner_input(
input_model_obj = {}

def load_parameter_value(value: str, value_type: type) -> Any:
if origin_type_issubclass(value_type, str):
if origin_type_issubtype(value_type, (str, NoneType)):
return value

try:
Expand Down
10 changes: 8 additions & 2 deletions src/hera/workflows/_runner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
import inspect
import json
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, cast

if sys.version_info >= (3, 10):
from types import NoneType
else:
NoneType = type(None)

from hera.shared._pydantic import _PYDANTIC_VERSION
from hera.shared._type_util import (
get_workflow_annotation,
origin_type_issubclass,
origin_type_issubtype,
unwrap_annotation,
)
from hera.shared.serialization import serialize
Expand Down Expand Up @@ -125,7 +131,7 @@ def _get_unannotated_type(key: str, f: Callable) -> Optional[type]:
def _is_str_kwarg_of(key: str, f: Callable) -> bool:
"""Check if param `key` of function `f` has a type annotation that can be interpreted as a subclass of str."""
if func_param_annotation := _get_function_param_annotation(key, f):
return origin_type_issubclass(func_param_annotation, str)
return origin_type_issubtype(func_param_annotation, (str, NoneType))
return False


Expand Down
6 changes: 4 additions & 2 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
_flag_enabled,
)
from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator
from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issubclass
from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issupertype
from hera.shared.serialization import serialize
from hera.workflows._context import _context
from hera.workflows._meta_mixins import CallableTemplateMixin
Expand Down Expand Up @@ -540,7 +540,9 @@ class will be used as inputs, rather than the class itself.
else:
default = MISSING

if origin_type_issubclass(func_param.annotation, NoneType) and (default is MISSING or default is not None):
if origin_type_issupertype(func_param.annotation, NoneType) and (
default is MISSING or default is not None
):
raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.")

parameters.append(Parameter(name=func_param.name, default=default))
Expand Down
7 changes: 6 additions & 1 deletion tests/script_runner/parameter_inputs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, List
from typing import Any, List, Union

try:
from typing import Annotated
Expand Down Expand Up @@ -76,6 +76,11 @@ def no_type_parameter(my_anything) -> Any:
return my_anything


@script()
def str_or_int_parameter(my_str_or_int: Union[str, int]) -> str:
return f"type given: {type(my_str_or_int).__name__}"


@script()
def str_parameter_expects_jsonstr_dict(my_json_str: str) -> dict:
return json.loads(my_json_str)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@
{},
id="no-type-dict",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_or_int_parameter",
[{"name": "my_str_or_int", "value": "hi there"}],
"type given: str",
id="str-or-int-given-str",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_or_int_parameter",
[{"name": "my_str_or_int", "value": "3"}],
"type given: int",
id="str-or-int-given-int",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
Expand Down
48 changes: 36 additions & 12 deletions tests/test_unit/test_shared_type_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import sys
from typing import List, Optional, Union

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated
if sys.version_info >= (3, 10):
from types import NoneType
else:
NoneType = type(None)

import pytest
from annotated_types import Gt

Expand All @@ -8,16 +18,12 @@
get_unsubscripted_type,
get_workflow_annotation,
is_annotated,
origin_type_issubclass,
origin_type_issubtype,
origin_type_issupertype,
unwrap_annotation,
)
from hera.workflows import Artifact, Parameter

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated


@pytest.mark.parametrize("annotation, expected", [[Annotated[str, "some metadata"], True], [str, False]])
def test_is_annotated(annotation, expected):
Expand Down Expand Up @@ -104,11 +110,29 @@ def test_get_unsubscripted_type(annotation, expected):
@pytest.mark.parametrize(
"annotation, target, expected",
[
[List[str], str, False],
[Optional[str], str, True],
[str, str, True],
[Union[int, str], int, True],
pytest.param(List[str], str, False, id="list-str-not-subtype-of-str"),
pytest.param(Optional[str], str, False, id="optional-str-not-subtype-of-str"),
pytest.param(str, str, True, id="str-is-subtype-of-str"),
pytest.param(Union[int, str], int, False, id="union-int-str-not-subtype-of-str"),
pytest.param(Optional[str], (str, NoneType), True, id="optional-str-is-subtype-of-optional-str"),
pytest.param(str, (str, NoneType), True, id="str-is-subtype-of-optional-str"),
pytest.param(Union[int, str], (str, NoneType), False, id="union-int-str-not-subtype-of-optional-str"),
],
)
def test_origin_type_issubtype(annotation, target, expected):
assert origin_type_issubtype(annotation, target) is expected


@pytest.mark.parametrize(
"annotation, target, expected",
[
pytest.param(List[str], str, False, id="list-str-not-supertype-of-str"),
pytest.param(Optional[str], str, True, id="optional-str-is-supertype-of-str"),
pytest.param(str, str, True, id="str-is-supertype-of-str"),
pytest.param(Union[int, str], int, True, id="union-int-str-is-supertype-of-int"),
pytest.param(Optional[str], NoneType, True, id="optional-str-is-supertype-of-nonetype"),
pytest.param(str, NoneType, False, id="str-not-supertype-of-nonetype"),
],
)
def test_origin_type_issubclass(annotation, target, expected):
assert origin_type_issubclass(annotation, target) is expected
def test_origin_type_issupertype(annotation, target, expected):
assert origin_type_issupertype(annotation, target) is expected

0 comments on commit 9192b57

Please sign in to comment.