diff --git a/flytekit/core/constants.py b/flytekit/core/constants.py index b38bf46d23..ffedfedfe5 100644 --- a/flytekit/core/constants.py +++ b/flytekit/core/constants.py @@ -17,5 +17,8 @@ # Binary IDL Serialization Format MESSAGEPACK = "msgpack" +# Set this environment variable to true to use the old protobuf Struct (JSON) format for dict, dataclass, and Pydantic BaseModel values. +FLYTE_USE_OLD_DC_FORMAT = "FLYTE_USE_OLD_DC_FORMAT" + # Set this environment variable to true to force the task to return non-zero exit code on failure. FLYTE_FAIL_ON_ERROR = "FLYTE_FAIL_ON_ERROR" diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 2394a9b1af..471cda44ef 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -162,6 +162,12 @@ async def resolve_attr_path_in_promise(p: Promise) -> Promise: if len(p.attr_path) > 0 and type(curr_val.value) is _literals_models.Scalar: # We keep it for reference task local execution in the future. if type(curr_val.value.value) is _struct.Struct: + """ + Local execution currently has issues with struct attribute resolution. + This works correctly in remote execution. + Issue Link: https://github.com/flyteorg/flyte/issues/5959 + """ + st = curr_val.value.value new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:]) literal_type = TypeEngine.to_literal_type(type(new_st)) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index bb498070b5..6e3e652307 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -9,6 +9,7 @@ import inspect import json import mimetypes +import os import sys import textwrap import threading @@ -29,17 +30,17 @@ from google.protobuf.json_format import ParseDict as _ParseDict from google.protobuf.message import Message from google.protobuf.struct_pb2 import Struct -from mashumaro.codecs.json import JSONDecoder +from mashumaro.codecs.json import JSONDecoder, JSONEncoder from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Annotated, get_args, get_origin from flytekit.core.annotation import FlyteAnnotation -from flytekit.core.constants import MESSAGEPACK +from flytekit.core.constants import FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK from flytekit.core.context_manager import FlyteContext from flytekit.core.hash import HashMethod from flytekit.core.type_helpers import load_type_from_tag -from flytekit.core.utils import load_proto_from_file, timeit +from flytekit.core.utils import load_proto_from_file, str2bool, timeit from flytekit.exceptions import user as user_exceptions from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.lazy_import.lazy_module import is_imported @@ -498,7 +499,8 @@ class Test(DataClassJsonMixin): def __init__(self) -> None: super().__init__("Object-Dataclass-Transformer", object) - self._decoder: Dict[Type, JSONDecoder] = dict() + self._json_encoder: Dict[Type, JSONEncoder] = dict() + self._json_decoder: Dict[Type, JSONDecoder] = dict() def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): # Skip iterating all attributes in the dataclass if the type of v already matches the expected_type @@ -655,14 +657,58 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: ) ) + # This is for attribute access in FlytePropeller.
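+ # FlytePropeller reads the per-field literal types recorded in `dataclass_type` to resolve attribute paths (e.g. `wf_output.inner.field`) on dataclass outputs without deserializing them in Python.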
ts = TypeStructure(tag="", dataclass_type=literal_type) return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema, structure=ts) + def to_generic_literal( + self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType + ) -> Literal: + """ + Serializes a dataclass or dictionary to a Flyte literal as a protobuf Struct (the old JSON-based format). + Used when the environment variable `FLYTE_USE_OLD_DC_FORMAT` is set to true. + Note: This is deprecated and will be removed in the future. + """ + if isinstance(python_val, dict): + json_str = json.dumps(python_val) + return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) + + if not dataclasses.is_dataclass(python_val): + raise TypeTransformerFailedError( + f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for " + f"user defined datatypes in Flytekit" + ) + + self._make_dataclass_serializable(python_val, python_type) + + # JSON serialization using mashumaro's DataClassJSONMixin + if isinstance(python_val, DataClassJSONMixin): + json_str = python_val.to_json() + else: + try: + encoder = self._json_encoder[python_type] + except KeyError: + encoder = JSONEncoder(python_type) + self._json_encoder[python_type] = encoder + + try: + json_str = encoder.encode(python_val) + except NotImplementedError: + raise NotImplementedError( + f"{python_type} should inherit from mashumaro.types.SerializableType" + f" and implement _serialize and _deserialize methods." + ) + + return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + if str2bool(os.getenv(FLYTE_USE_OLD_DC_FORMAT)): + return self.to_generic_literal(ctx, python_val, python_type, expected) + if isinstance(python_val, dict): msgpack_bytes = msgpack.dumps(python_val) - return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) + return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK))) if not dataclasses.is_dataclass(python_val): raise TypeTransformerFailedError( @@ -697,7 +743,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp f" and implement _serialize and _deserialize methods." ) - return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) + return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK))) def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: # dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is @@ -863,10 +909,10 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: # The function looks up or creates a JSONDecoder specifically designed for the object's type. # This decoder is then used to convert a JSON string into a data class.
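+ # Decoders are cached per expected type in `self._json_decoder`, so repeated conversions of the same dataclass type reuse one codec instead of rebuilding it.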
try: - decoder = self._decoder[expected_python_type] + decoder = self._json_decoder[expected_python_type] except KeyError: decoder = JSONDecoder(expected_python_type) - self._decoder[expected_python_type] = decoder + self._json_decoder[expected_python_type] = decoder dc = decoder.decode(json_str) @@ -1929,6 +1975,43 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple: return _args # type: ignore return None, None + @staticmethod + async def dict_to_generic_literal( + ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool + ) -> Literal: + """ + Creates a flyte-specific ``Literal`` value from a native Python dictionary + using the old protobuf Struct (JSON) format. + Note: This is deprecated as of flytekit 1.14.0 and will be removed in the future. + """ + from flytekit.types.pickle import FlytePickle + + try: + try: + # JSONEncoder is mashumaro's codec; encoding with it can trigger the customized serialization and deserialization of Flyte types. + encoder = JSONEncoder(python_type) + json_str = encoder.encode(v) + except NotImplementedError: + raise NotImplementedError( + f"{python_type} should inherit from mashumaro.types.SerializableType" + f" and implement _serialize and _deserialize methods." + ) + + return Literal( + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), + metadata={"format": "json"}, + ) + except TypeError as e: + if allow_pickle: + remote_path = await FlytePickle.to_pickle(ctx, v) + return Literal( + scalar=Scalar( + generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), _struct.Struct()) + ), + metadata={"format": "pickle"}, + ) + raise TypeTransformerFailedError(f"Cannot convert `{v}` to Flyte Literal.\n" f"Error Message: {e}") + @staticmethod + async def dict_to_binary_literal( + ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool @@ -1943,7 +2026,7 @@ async def dict_to_binary_literal( # Handle dictionaries with non-string keys (e.g., Dict[int, Type]) encoder = MessagePackEncoder(python_type) msgpack_bytes = encoder.encode(v) - return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) + return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK))) except TypeError as e: if allow_pickle: remote_path = await FlytePickle.to_pickle(ctx, v) return Literal( @@ -2004,6 +2087,8 @@ async def async_to_literal( allow_pickle, base_type = DictTransformer.is_pickle(python_type) if expected and expected.simple and expected.simple == SimpleType.STRUCT: + if str2bool(os.getenv(FLYTE_USE_OLD_DC_FORMAT)): + return await self.dict_to_generic_literal(ctx, python_val, python_type, allow_pickle) return await self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle) lit_map = {} diff --git a/flytekit/extras/pydantic_transformer/__init__.py b/flytekit/extras/pydantic_transformer/__init__.py index 3f7744fe2f..ec1ed0ff16 100644 --- a/flytekit/extras/pydantic_transformer/__init__.py +++ b/flytekit/extras/pydantic_transformer/__init__.py @@ -7,5 +7,5 @@ from .
import transformer except (ImportError, OSError) as e: - logger.warning(f"Meet error when importing pydantic: `{e}`") - logger.warning("Flytekit only support pydantic version > 2.") + logger.debug(f"Encountered an error while importing pydantic: `{e}`") + logger.debug("Flytekit only supports pydantic versions > 2.") diff --git a/flytekit/extras/pydantic_transformer/decorator.py index 9db567739a..39a878a2da 100644 --- a/flytekit/extras/pydantic_transformer/decorator.py +++ b/flytekit/extras/pydantic_transformer/decorator.py @@ -14,7 +14,7 @@ It looks nicer in the real Flyte File/Directory class, but we also want it to not fail. """ - logger.warning( + logger.debug( "Pydantic is not installed.\n" "Please install Pydantic version > 2 to use FlyteTypes in pydantic BaseModel." ) diff --git a/flytekit/extras/pydantic_transformer/transformer.py index 4abefcc298..dc6751218b 100644 --- a/flytekit/extras/pydantic_transformer/transformer.py +++ b/flytekit/extras/pydantic_transformer/transformer.py @@ -1,13 +1,16 @@ import json +import os from typing import Type import msgpack from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct from pydantic import BaseModel from flytekit import FlyteContext -from flytekit.core.constants import MESSAGEPACK +from flytekit.core.constants import FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.core.utils import str2bool from flytekit.loggers import logger from flytekit.models import types from flytekit.models.literals import Binary, Literal, Scalar @@ -31,10 +34,24 @@ def get_literal_type(self, t: Type[BaseModel]) -> LiteralType: "Field {} of type {} cannot be converted to a literal type. Error: {}".format(name, python_type, e) ) + # This is for attribute access in FlytePropeller. ts = TypeStructure(tag="", dataclass_type=literal_type) return types.LiteralType(simple=types.SimpleType.STRUCT, metadata=schema, structure=ts) + def to_generic_literal( + self, + ctx: FlyteContext, + python_val: BaseModel, + python_type: Type[BaseModel], + expected: types.LiteralType, + ) -> Literal: + """ + Serializes a Pydantic BaseModel to the old protobuf Struct (JSON) format; deprecated and will be removed in the future. + """ + json_str = python_val.model_dump_json() + return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) + def to_literal( self, ctx: FlyteContext, @@ -47,6 +64,9 @@ def to_literal( This is for handling enum in basemodel. More details: https://github.com/flyteorg/flytekit/pull/2792 """ + if str2bool(os.getenv(FLYTE_USE_OLD_DC_FORMAT)): + return self.to_generic_literal(ctx, python_val, python_type, expected) + json_str = python_val.model_dump_json() dict_obj = json.loads(json_str) msgpack_bytes = msgpack.dumps(dict_obj) diff --git a/flytekit/types/structured/structured_dataset.py index 39843668cb..89d088c264 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -701,6 +701,7 @@ async def async_to_literal( # that we will need to invoke an encoder for. Figure out which encoder to call and invoke it.
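+ # For example, a pandas DataFrame with an s3 uri resolves to the "s3" protocol; that protocol together with the dataframe type selects the encoder invoked below.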
df_type = type(python_val.dataframe) protocol = self._protocol_from_type_or_prefix(ctx, df_type, python_val.uri) + return self.encode( ctx, python_val, diff --git a/tests/flytekit/unit/core/test_generic_idl_dataclass.py b/tests/flytekit/unit/core/test_generic_idl_dataclass.py new file mode 100644 index 0000000000..0e9974782a --- /dev/null +++ b/tests/flytekit/unit/core/test_generic_idl_dataclass.py @@ -0,0 +1,1095 @@ +import copy + +import pytest +from enum import Enum +from dataclasses_json import DataClassJsonMixin +from mashumaro.mixins.json import DataClassJSONMixin +import os +import sys +import tempfile +from dataclasses import dataclass, fields, field +from typing import List, Dict, Optional, Union, Any +from typing_extensions import Annotated +from flytekit.types.schema import FlyteSchema +from flytekit.core.type_engine import TypeEngine +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.type_engine import DataclassTransformer +from flytekit import task, workflow +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile +from flytekit.types.structured import StructuredDataset + +@pytest.fixture(autouse=True) +def prepare_env_variable(): + try: + origin_env = copy.deepcopy(os.environ.copy()) + os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "True" + yield + finally: + os.environ = origin_env + +@pytest.fixture +def local_dummy_txt_file(): + fd, path = tempfile.mkstemp(suffix=".txt") + try: + with os.fdopen(fd, "w") as tmp: + tmp.write("Hello World") + yield path + finally: + os.remove(path) + +@pytest.fixture +def local_dummy_directory(): + temp_dir = tempfile.TemporaryDirectory() + try: + with open(os.path.join(temp_dir.name, "file"), "w") as tmp: + tmp.write("Hello world") + yield temp_dir.name + finally: + temp_dir.cleanup() + +def test_dataclass(): + @dataclass + class AppParams(DataClassJsonMixin): + snapshotDate: str + region: str + preprocess: bool + listKeys: List[str] + + @task + def t1() -> AppParams: + ap = AppParams(snapshotDate="4/5/2063", region="us-west-3", preprocess=False, listKeys=["a", "b"]) + return ap + + @workflow + def wf() -> AppParams: + return t1() + + res = wf() + assert res.region == "us-west-3" + + +def test_dataclass_assert_works_for_annotated(): + @dataclass + class MyDC(DataClassJSONMixin): + my_str: str + + d = Annotated[MyDC, "tag"] + DataclassTransformer().assert_type(d, MyDC(my_str="hi")) + +def test_pure_dataclasses_with_python_types(): + @dataclass + class DC: + string: Optional[str] = None + + @dataclass + class DCWithOptional: + string: Optional[str] = None + dc: Optional[DC] = None + list_dc: Optional[List[DC]] = None + list_list_dc: Optional[List[List[DC]]] = None + dict_dc: Optional[Dict[str, DC]] = None + dict_dict_dc: Optional[Dict[str, Dict[str, DC]]] = None + dict_list_dc: Optional[Dict[str, List[DC]]] = None + list_dict_dc: Optional[List[Dict[str, DC]]] = None + + @task + def t1() -> DCWithOptional: + return DCWithOptional(string="a", dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]},) + + @task + def t2() -> DCWithOptional: + return DCWithOptional() + + output = DCWithOptional(string="a", 
dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]}, ) + + dc1 = t1() + dc2 = t2() + + assert dc1 == output + assert dc2.string is None + assert dc2.dc is None + + DataclassTransformer().assert_type(DCWithOptional, dc1) + DataclassTransformer().assert_type(DCWithOptional, dc2) + + +def test_pure_dataclasses_with_python_types_get_literal_type_and_to_python_value(): + @dataclass + class DC: + string: Optional[str] = None + + @dataclass + class DCWithOptional: + string: Optional[str] = None + dc: Optional[DC] = None + list_dc: Optional[List[DC]] = None + list_list_dc: Optional[List[List[DC]]] = None + dict_dc: Optional[Dict[str, DC]] = None + dict_dict_dc: Optional[Dict[str, Dict[str, DC]]] = None + dict_list_dc: Optional[Dict[str, List[DC]]] = None + list_dict_dc: Optional[List[Dict[str, DC]]] = None + + ctx = FlyteContextManager.current_context() + + + o = DCWithOptional() + lt = TypeEngine.to_literal_type(DCWithOptional) + lv = TypeEngine.to_literal(ctx, o, DCWithOptional, lt) + assert lv is not None + pv = TypeEngine.to_python_value(ctx, lv, DCWithOptional) + assert isinstance(pv, DCWithOptional) + DataclassTransformer().assert_type(DCWithOptional, pv) + + o = DCWithOptional(string="a", dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]}) + lt = TypeEngine.to_literal_type(DCWithOptional) + lv = TypeEngine.to_literal(ctx, o, DCWithOptional, lt) + assert lv is not None + pv = TypeEngine.to_python_value(ctx, lv, DCWithOptional) + assert isinstance(pv, DCWithOptional) + DataclassTransformer().assert_type(DCWithOptional, pv) + + +def test_pure_dataclasses_with_flyte_types(local_dummy_txt_file, local_dummy_directory): + @dataclass + class FlyteTypes: + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + + @dataclass + class NestedFlyteTypes: + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + flyte_types: Optional[FlyteTypes] = None + list_flyte_types: Optional[List[FlyteTypes]] = None + dict_flyte_types: Optional[Dict[str, FlyteTypes]] = None + optional_flyte_types: Optional[FlyteTypes] = None + + @task + def pass_and_return_flyte_types(nested_flyte_types: NestedFlyteTypes) -> NestedFlyteTypes: + return nested_flyte_types + + @task + def generate_sd() -> StructuredDataset: + return StructuredDataset( + uri="s3://my-s3-bucket/data/test_sd", + file_format="parquet") + + @task + def create_local_dir(path: str) -> FlyteDirectory: + return FlyteDirectory(path=path) + + @task + def create_local_dir_by_str(path: str) -> FlyteDirectory: + return path + + @task + def create_local_file(path: str) -> FlyteFile: + return FlyteFile(path=path) + + @task + def 
create_local_file_with_str(path: str) -> FlyteFile: + return path + + @task + def generate_nested_flyte_types(local_file: FlyteFile, local_dir: FlyteDirectory, sd: StructuredDataset, + local_file_by_str: FlyteFile, + local_dir_by_str: FlyteDirectory, ) -> NestedFlyteTypes: + ft = FlyteTypes( + flytefile=local_file, + flytedir=local_dir, + structured_dataset=sd, + ) + + return NestedFlyteTypes( + flytefile=local_file, + flytedir=local_dir, + structured_dataset=sd, + flyte_types=FlyteTypes( + flytefile=local_file_by_str, + flytedir=local_dir_by_str, + structured_dataset=sd, + ), + list_flyte_types=[ft, ft, ft], + dict_flyte_types={"a": ft, "b": ft, "c": ft}, + ) + + @workflow + def nested_dc_wf(txt_path: str, dir_path: str) -> NestedFlyteTypes: + local_file = create_local_file(path=txt_path) + local_dir = create_local_dir(path=dir_path) + local_file_by_str = create_local_file_with_str(path=txt_path) + local_dir_by_str = create_local_dir_by_str(path=dir_path) + sd = generate_sd() + nested_flyte_types = generate_nested_flyte_types( + local_file=local_file, + local_dir=local_dir, + local_file_by_str=local_file_by_str, + local_dir_by_str=local_dir_by_str, + sd=sd + ) + old_flyte_types = pass_and_return_flyte_types(nested_flyte_types=nested_flyte_types) + return pass_and_return_flyte_types(nested_flyte_types=old_flyte_types) + + @task + def get_empty_nested_type() -> NestedFlyteTypes: + return NestedFlyteTypes() + + @workflow + def empty_nested_dc_wf() -> NestedFlyteTypes: + return get_empty_nested_type() + + nested_flyte_types = nested_dc_wf(txt_path=local_dummy_txt_file, dir_path=local_dummy_directory) + DataclassTransformer().assert_type(NestedFlyteTypes, nested_flyte_types) + + empty_nested_flyte_types = empty_nested_dc_wf() + DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types) + + +def test_pure_dataclasses_with_flyte_types_get_literal_type_and_to_python_value(local_dummy_txt_file, local_dummy_directory): + @dataclass + class FlyteTypes: + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + + @dataclass + class NestedFlyteTypes: + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + flyte_types: Optional[FlyteTypes] = None + list_flyte_types: Optional[List[FlyteTypes]] = None + dict_flyte_types: Optional[Dict[str, FlyteTypes]] = None + optional_flyte_types: Optional[FlyteTypes] = None + + ctx = FlyteContextManager.current_context() + + o = NestedFlyteTypes() + + lt = TypeEngine.to_literal_type(NestedFlyteTypes) + lv = TypeEngine.to_literal(ctx, o, NestedFlyteTypes, lt) + assert lv is not None + pv = TypeEngine.to_python_value(ctx, lv, NestedFlyteTypes) + assert isinstance(pv, NestedFlyteTypes) + DataclassTransformer().assert_type(NestedFlyteTypes, pv) + + ff = FlyteFile(path=local_dummy_txt_file) + fd = FlyteDirectory(path=local_dummy_directory) + sd = StructuredDataset(uri="s3://my-s3-bucket/data/test_sd", file_format="parquet") + o = NestedFlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + flyte_types=FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + ), + list_flyte_types=[FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + )], + dict_flyte_types={ + "a": FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + ), + "b": FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd)}, + 
optional_flyte_types=FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + ), + ) + + lt = TypeEngine.to_literal_type(NestedFlyteTypes) + lv = TypeEngine.to_literal(ctx, o, NestedFlyteTypes, lt) + assert lv is not None + pv = TypeEngine.to_python_value(ctx, lv, NestedFlyteTypes) + assert isinstance(pv, NestedFlyteTypes) + DataclassTransformer().assert_type(NestedFlyteTypes, pv) + +## For the dataclasses-json mixin, this is equivalent to using @dataclass_json +def test_dataclasses_json_mixin_with_python_types(): + @dataclass + class DC(DataClassJsonMixin): + string: Optional[str] = None + + @dataclass + class DCWithOptional(DataClassJsonMixin): + string: Optional[str] = None + dc: Optional[DC] = None + list_dc: Optional[List[DC]] = None + list_list_dc: Optional[List[List[DC]]] = None + dict_dc: Optional[Dict[str, DC]] = None + dict_dict_dc: Optional[Dict[str, Dict[str, DC]]] = None + dict_list_dc: Optional[Dict[str, List[DC]]] = None + list_dict_dc: Optional[List[Dict[str, DC]]] = None + + @task + def t1() -> DCWithOptional: + return DCWithOptional(string="a", dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]},) + + @task + def t2() -> DCWithOptional: + return DCWithOptional() + + output = DCWithOptional(string="a", dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]}, ) + + dc1 = t1() + dc2 = t2() + + assert dc1 == output + assert dc2.string is None + assert dc2.dc is None + + DataclassTransformer().assert_type(DCWithOptional, dc1) + DataclassTransformer().assert_type(DCWithOptional, dc2) + + +def test_dataclasses_json_mixin_with_python_types_get_literal_type_and_to_python_value(): + @dataclass + class DC(DataClassJsonMixin): + string: Optional[str] = None + + @dataclass + class DCWithOptional(DataClassJsonMixin): + string: Optional[str] = None + dc: Optional[DC] = None + list_dc: Optional[List[DC]] = None + list_list_dc: Optional[List[List[DC]]] = None + dict_dc: Optional[Dict[str, DC]] = None + dict_dict_dc: Optional[Dict[str, Dict[str, DC]]] = None + dict_list_dc: Optional[Dict[str, List[DC]]] = None + list_dict_dc: Optional[List[Dict[str, DC]]] = None + + ctx = FlyteContextManager.current_context() + + + o = DCWithOptional() + lt = TypeEngine.to_literal_type(DCWithOptional) + lv = TypeEngine.to_literal(ctx, o, DCWithOptional, lt) + assert lv is not None + pv = TypeEngine.to_python_value(ctx, lv, DCWithOptional) + assert isinstance(pv, DCWithOptional) + DataclassTransformer().assert_type(DCWithOptional, pv) + + o = DCWithOptional(string="a", dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + 
dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]}) + lt = TypeEngine.to_literal_type(DCWithOptional) + lv = TypeEngine.to_literal(ctx, o, DCWithOptional, lt) + assert lv is not None + pv = TypeEngine.to_python_value(ctx, lv, DCWithOptional) + assert isinstance(pv, DCWithOptional) + DataclassTransformer().assert_type(DCWithOptional, pv) + + +def test_dataclasses_json_mixin_with_flyte_types(local_dummy_txt_file, local_dummy_directory): + @dataclass + class FlyteTypes(DataClassJsonMixin): + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + + @dataclass + class NestedFlyteTypes(DataClassJsonMixin): + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + flyte_types: Optional[FlyteTypes] = None + list_flyte_types: Optional[List[FlyteTypes]] = None + dict_flyte_types: Optional[Dict[str, FlyteTypes]] = None + optional_flyte_types: Optional[FlyteTypes] = None + + @task + def pass_and_return_flyte_types(nested_flyte_types: NestedFlyteTypes) -> NestedFlyteTypes: + return nested_flyte_types + + @task + def generate_sd() -> StructuredDataset: + return StructuredDataset( + uri="s3://my-s3-bucket/data/test_sd", + file_format="parquet") + + @task + def create_local_dir(path: str) -> FlyteDirectory: + return FlyteDirectory(path=path) + + @task + def create_local_dir_by_str(path: str) -> FlyteDirectory: + return path + + @task + def create_local_file(path: str) -> FlyteFile: + return FlyteFile(path=path) + + @task + def create_local_file_with_str(path: str) -> FlyteFile: + return path + + @task + def generate_nested_flyte_types(local_file: FlyteFile, local_dir: FlyteDirectory, sd: StructuredDataset, + local_file_by_str: FlyteFile, + local_dir_by_str: FlyteDirectory, ) -> NestedFlyteTypes: + ft = FlyteTypes( + flytefile=local_file, + flytedir=local_dir, + structured_dataset=sd, + ) + + return NestedFlyteTypes( + flytefile=local_file, + flytedir=local_dir, + structured_dataset=sd, + flyte_types=FlyteTypes( + flytefile=local_file_by_str, + flytedir=local_dir_by_str, + structured_dataset=sd, + ), + list_flyte_types=[ft, ft, ft], + dict_flyte_types={"a": ft, "b": ft, "c": ft}, + ) + + @workflow + def nested_dc_wf(txt_path: str, dir_path: str) -> NestedFlyteTypes: + local_file = create_local_file(path=txt_path) + local_dir = create_local_dir(path=dir_path) + local_file_by_str = create_local_file_with_str(path=txt_path) + local_dir_by_str = create_local_dir_by_str(path=dir_path) + sd = generate_sd() + # current branch -> current branch + nested_flyte_types = generate_nested_flyte_types( + local_file=local_file, + local_dir=local_dir, + local_file_by_str=local_file_by_str, + local_dir_by_str=local_dir_by_str, + sd=sd + ) + old_flyte_types = pass_and_return_flyte_types(nested_flyte_types=nested_flyte_types) + return pass_and_return_flyte_types(nested_flyte_types=old_flyte_types) + + @task + def get_empty_nested_type() -> NestedFlyteTypes: + return NestedFlyteTypes() + + @workflow + def empty_nested_dc_wf() -> NestedFlyteTypes: + return get_empty_nested_type() + + nested_flyte_types = nested_dc_wf(txt_path=local_dummy_txt_file, dir_path=local_dummy_directory) + DataclassTransformer().assert_type(NestedFlyteTypes, nested_flyte_types) + + empty_nested_flyte_types = empty_nested_dc_wf() + 
DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types) + + +def test_dataclasses_json_mixin_with_flyte_types_get_literal_type_and_to_python_value(local_dummy_txt_file, local_dummy_directory): + @dataclass + class FlyteTypes(DataClassJsonMixin): + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + + @dataclass + class NestedFlyteTypes(DataClassJsonMixin): + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + flyte_types: Optional[FlyteTypes] = None + list_flyte_types: Optional[List[FlyteTypes]] = None + dict_flyte_types: Optional[Dict[str, FlyteTypes]] = None + optional_flyte_types: Optional[FlyteTypes] = None + + ctx = FlyteContextManager.current_context() + + o = NestedFlyteTypes() + + lt = TypeEngine.to_literal_type(NestedFlyteTypes) + lv = TypeEngine.to_literal(ctx, o, NestedFlyteTypes, lt) + assert lv is not None + pv = TypeEngine.to_python_value(ctx, lv, NestedFlyteTypes) + assert isinstance(pv, NestedFlyteTypes) + DataclassTransformer().assert_type(NestedFlyteTypes, pv) + + ff = FlyteFile(path=local_dummy_txt_file) + fd = FlyteDirectory(path=local_dummy_directory) + sd = StructuredDataset(uri="s3://my-s3-bucket/data/test_sd", file_format="parquet") + o = NestedFlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + flyte_types=FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + ), + list_flyte_types=[FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + )], + dict_flyte_types={ + "a": FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + ), + "b": FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd)}, + optional_flyte_types=FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + ), + ) + + lt = TypeEngine.to_literal_type(NestedFlyteTypes) + lv = TypeEngine.to_literal(ctx, o, NestedFlyteTypes, lt) + assert lv is not None + pv = TypeEngine.to_python_value(ctx, lv, NestedFlyteTypes) + assert isinstance(pv, NestedFlyteTypes) + DataclassTransformer().assert_type(NestedFlyteTypes, pv) + +# For the mashumaro DataClassJSONMixin, this is equivalent to using a plain @dataclass +def test_mashumaro_dataclasses_json_mixin_with_python_types(): + @dataclass + class DC(DataClassJSONMixin): + string: Optional[str] = None + + @dataclass + class DCWithOptional(DataClassJSONMixin): + string: Optional[str] = None + dc: Optional[DC] = None + list_dc: Optional[List[DC]] = None + list_list_dc: Optional[List[List[DC]]] = None + dict_dc: Optional[Dict[str, DC]] = None + dict_dict_dc: Optional[Dict[str, Dict[str, DC]]] = None + dict_list_dc: Optional[Dict[str, List[DC]]] = None + list_dict_dc: Optional[List[Dict[str, DC]]] = None + + @task + def t1() -> DCWithOptional: + return DCWithOptional(string="a", dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]},) + + @task + def t2() -> DCWithOptional: + return DCWithOptional() + + output = DCWithOptional(string="a", dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + 
list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]}, ) + + dc1 = t1() + dc2 = t2() + + assert dc1 == output + assert dc2.string is None + assert dc2.dc is None + + DataclassTransformer().assert_type(DCWithOptional, dc1) + DataclassTransformer().assert_type(DCWithOptional, dc2) + + +def test_mashumaro_dataclasses_json_mixin_with_python_types_get_literal_type_and_to_python_value(): + @dataclass + class DC(DataClassJSONMixin): + string: Optional[str] = None + + @dataclass + class DCWithOptional(DataClassJSONMixin): + string: Optional[str] = None + dc: Optional[DC] = None + list_dc: Optional[List[DC]] = None + list_list_dc: Optional[List[List[DC]]] = None + dict_dc: Optional[Dict[str, DC]] = None + dict_dict_dc: Optional[Dict[str, Dict[str, DC]]] = None + dict_list_dc: Optional[Dict[str, List[DC]]] = None + list_dict_dc: Optional[List[Dict[str, DC]]] = None + + ctx = FlyteContextManager.current_context() + + + o = DCWithOptional() + lt = TypeEngine.to_literal_type(DCWithOptional) + lv = TypeEngine.to_literal(ctx, o, DCWithOptional, lt) + assert lv is not None + pv = TypeEngine.to_python_value(ctx, lv, DCWithOptional) + assert isinstance(pv, DCWithOptional) + DataclassTransformer().assert_type(DCWithOptional, pv) + + o = DCWithOptional(string="a", dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]}) + lt = TypeEngine.to_literal_type(DCWithOptional) + lv = TypeEngine.to_literal(ctx, o, DCWithOptional, lt) + assert lv is not None + pv = TypeEngine.to_python_value(ctx, lv, DCWithOptional) + assert isinstance(pv, DCWithOptional) + DataclassTransformer().assert_type(DCWithOptional, pv) + + +def test_mashumaro_dataclasses_json_mixin_with_flyte_types(local_dummy_txt_file, local_dummy_directory): + @dataclass + class FlyteTypes(DataClassJSONMixin): + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + + @dataclass + class NestedFlyteTypes(DataClassJSONMixin): + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + flyte_types: Optional[FlyteTypes] = None + list_flyte_types: Optional[List[FlyteTypes]] = None + dict_flyte_types: Optional[Dict[str, FlyteTypes]] = None + optional_flyte_types: Optional[FlyteTypes] = None + + @task + def pass_and_return_flyte_types(nested_flyte_types: NestedFlyteTypes) -> NestedFlyteTypes: + return nested_flyte_types + + @task + def generate_sd() -> StructuredDataset: + return StructuredDataset( + uri="s3://my-s3-bucket/data/test_sd", + file_format="parquet") + + @task + def create_local_dir(path: str) -> FlyteDirectory: + return FlyteDirectory(path=path) + + @task + def create_local_dir_by_str(path: str) -> FlyteDirectory: + return path + + @task + def create_local_file(path: str) -> FlyteFile: + return 
FlyteFile(path=path) + + @task + def create_local_file_with_str(path: str) -> FlyteFile: + return path + + @task + def generate_nested_flyte_types(local_file: FlyteFile, local_dir: FlyteDirectory, sd: StructuredDataset, + local_file_by_str: FlyteFile, + local_dir_by_str: FlyteDirectory, ) -> NestedFlyteTypes: + ft = FlyteTypes( + flytefile=local_file, + flytedir=local_dir, + structured_dataset=sd, + ) + + return NestedFlyteTypes( + flytefile=local_file, + flytedir=local_dir, + structured_dataset=sd, + flyte_types=FlyteTypes( + flytefile=local_file_by_str, + flytedir=local_dir_by_str, + structured_dataset=sd, + ), + list_flyte_types=[ft, ft, ft], + dict_flyte_types={"a": ft, "b": ft, "c": ft}, + ) + + @workflow + def nested_dc_wf(txt_path: str, dir_path: str) -> NestedFlyteTypes: + local_file = create_local_file(path=txt_path) + local_dir = create_local_dir(path=dir_path) + local_file_by_str = create_local_file_with_str(path=txt_path) + local_dir_by_str = create_local_dir_by_str(path=dir_path) + sd = generate_sd() + nested_flyte_types = generate_nested_flyte_types( + local_file=local_file, + local_dir=local_dir, + local_file_by_str=local_file_by_str, + local_dir_by_str=local_dir_by_str, + sd=sd + ) + old_flyte_types = pass_and_return_flyte_types(nested_flyte_types=nested_flyte_types) + return pass_and_return_flyte_types(nested_flyte_types=old_flyte_types) + + @task + def get_empty_nested_type() -> NestedFlyteTypes: + return NestedFlyteTypes() + + @workflow + def empty_nested_dc_wf() -> NestedFlyteTypes: + return get_empty_nested_type() + + nested_flyte_types = nested_dc_wf(txt_path=local_dummy_txt_file, dir_path=local_dummy_directory) + DataclassTransformer().assert_type(NestedFlyteTypes, nested_flyte_types) + + empty_nested_flyte_types = empty_nested_dc_wf() + DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types) + + +def test_mashumaro_dataclasses_json_mixin_with_flyte_types_get_literal_type_and_to_python_value(local_dummy_txt_file, local_dummy_directory): + @dataclass + class FlyteTypes(DataClassJSONMixin): + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + + @dataclass + class NestedFlyteTypes(DataClassJSONMixin): + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + flyte_types: Optional[FlyteTypes] = None + list_flyte_types: Optional[List[FlyteTypes]] = None + dict_flyte_types: Optional[Dict[str, FlyteTypes]] = None + optional_flyte_types: Optional[FlyteTypes] = None + + ctx = FlyteContextManager.current_context() + + o = NestedFlyteTypes() + + lt = TypeEngine.to_literal_type(NestedFlyteTypes) + lv = TypeEngine.to_literal(ctx, o, NestedFlyteTypes, lt) + assert lv is not None + pv = TypeEngine.to_python_value(ctx, lv, NestedFlyteTypes) + assert isinstance(pv, NestedFlyteTypes) + DataclassTransformer().assert_type(NestedFlyteTypes, pv) + + ff = FlyteFile(path=local_dummy_txt_file) + fd = FlyteDirectory(path=local_dummy_directory) + sd = StructuredDataset(uri="s3://my-s3-bucket/data/test_sd", file_format="parquet") + o = NestedFlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + flyte_types=FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + ), + list_flyte_types=[FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + )], + dict_flyte_types={ + "a": FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + ), + "b": 
FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd)}, + optional_flyte_types=FlyteTypes( + flytefile=ff, + flytedir=fd, + structured_dataset=sd, + ), + ) + + lt = TypeEngine.to_literal_type(NestedFlyteTypes) + lv = TypeEngine.to_literal(ctx, o, NestedFlyteTypes, lt) + assert lv is not None + pv = TypeEngine.to_python_value(ctx, lv, NestedFlyteTypes) + assert isinstance(pv, NestedFlyteTypes) + DataclassTransformer().assert_type(NestedFlyteTypes, pv) + +def test_get_literal_type_data_class_json_fail_but_mashumaro_works(): + @dataclass + class FlyteTypesWithDataClassJson(DataClassJsonMixin): + flytefile: FlyteFile + flytedir: FlyteDirectory + structured_dataset: StructuredDataset + fs: FlyteSchema + + @dataclass + class NestedFlyteTypesWithDataClassJson(DataClassJsonMixin): + flytefile: FlyteFile + flytedir: FlyteDirectory + structured_dataset: StructuredDataset + flyte_types: FlyteTypesWithDataClassJson + fs: FlyteSchema + list_flyte_types: List[FlyteTypesWithDataClassJson] + dict_flyte_types: Dict[str, FlyteTypesWithDataClassJson] + optional_flyte_types: Optional[FlyteTypesWithDataClassJson] = None + + transformer = DataclassTransformer() + lt = transformer.get_literal_type(NestedFlyteTypesWithDataClassJson) + assert lt.metadata is not None + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or higher") +def test_numpy_import_issue_from_flyte_schema_in_dataclass(): + from dataclasses import dataclass + + from flytekit import task, workflow + from flytekit.types.directory import FlyteDirectory + from flytekit.types.file import FlyteFile + + @dataclass + class MyDataClass: + output_file: FlyteFile + output_directory: FlyteDirectory + + @task + def my_flyte_workflow(b: bool) -> list[MyDataClass | None]: + if b: + return [MyDataClass(__file__, ".")] + return [None] + + @task + def my_flyte_task(inputs: list[MyDataClass | None]) -> bool: + return inputs and (inputs[0] is not None) # type: ignore + + @workflow + def main_flyte_workflow(b: bool = False) -> bool: + inputs = my_flyte_workflow(b=b) + return my_flyte_task(inputs=inputs) + + assert main_flyte_workflow(b=True) == True + assert main_flyte_workflow(b=False) == False + +def test_frozen_dataclass(): + @dataclass(frozen=True) + class FrozenDataclass: + a: int = 1 + b: float = 2.0 + c: bool = True + d: str = "hello" + + @task + def t1(dc: FrozenDataclass) -> (int, float, bool, str): + return dc.a, dc.b, dc.c, dc.d + + a, b, c, d = t1(dc=FrozenDataclass()) + assert a == 1 + assert b == 2.0 + assert c == True + assert d == "hello" + +def test_pure_frozen_dataclasses_with_python_types(): + @dataclass(frozen=True) + class DC: + string: Optional[str] = None + + @dataclass(frozen=True) + class DCWithOptional: + string: Optional[str] = None + dc: Optional[DC] = None + list_dc: Optional[List[DC]] = None + list_list_dc: Optional[List[List[DC]]] = None + dict_dc: Optional[Dict[str, DC]] = None + dict_dict_dc: Optional[Dict[str, Dict[str, DC]]] = None + dict_list_dc: Optional[Dict[str, List[DC]]] = None + list_dict_dc: Optional[List[Dict[str, DC]]] = None + + @task + def t1() -> DCWithOptional: + return DCWithOptional(string="a", dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + dict_dict_dc={"s": {"t": 
DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]},) + + @task + def t2() -> DCWithOptional: + return DCWithOptional() + + output = DCWithOptional(string="a", dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]}, ) + + dc1 = t1() + dc2 = t2() + + assert dc1 == output + assert dc2.string is None + assert dc2.dc is None + + DataclassTransformer().assert_type(DCWithOptional, dc1) + DataclassTransformer().assert_type(DCWithOptional, dc2) + +def test_pure_frozen_dataclasses_with_flyte_types(local_dummy_txt_file, local_dummy_directory): + @dataclass(frozen=True) + class FlyteTypes: + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + + @dataclass(frozen=True) + class NestedFlyteTypes: + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + flyte_types: Optional[FlyteTypes] = None + list_flyte_types: Optional[List[FlyteTypes]] = None + dict_flyte_types: Optional[Dict[str, FlyteTypes]] = None + optional_flyte_types: Optional[FlyteTypes] = None + + @task + def pass_and_return_flyte_types(nested_flyte_types: NestedFlyteTypes) -> NestedFlyteTypes: + return nested_flyte_types + + @task + def generate_sd() -> StructuredDataset: + return StructuredDataset( + uri="s3://my-s3-bucket/data/test_sd", + file_format="parquet") + + @task + def create_local_dir(path: str) -> FlyteDirectory: + return FlyteDirectory(path=path) + + @task + def create_local_dir_by_str(path: str) -> FlyteDirectory: + return path + + @task + def create_local_file(path: str) -> FlyteFile: + return FlyteFile(path=path) + + @task + def create_local_file_with_str(path: str) -> FlyteFile: + return path + + @task + def generate_nested_flyte_types(local_file: FlyteFile, local_dir: FlyteDirectory, sd: StructuredDataset, + local_file_by_str: FlyteFile, + local_dir_by_str: FlyteDirectory, ) -> NestedFlyteTypes: + ft = FlyteTypes( + flytefile=local_file, + flytedir=local_dir, + structured_dataset=sd, + ) + + return NestedFlyteTypes( + flytefile=local_file, + flytedir=local_dir, + structured_dataset=sd, + flyte_types=FlyteTypes( + flytefile=local_file_by_str, + flytedir=local_dir_by_str, + structured_dataset=sd, + ), + list_flyte_types=[ft, ft, ft], + dict_flyte_types={"a": ft, "b": ft, "c": ft}, + ) + + @workflow + def nested_dc_wf(txt_path: str, dir_path: str) -> NestedFlyteTypes: + local_file = create_local_file(path=txt_path) + local_dir = create_local_dir(path=dir_path) + local_file_by_str = create_local_file_with_str(path=txt_path) + local_dir_by_str = create_local_dir_by_str(path=dir_path) + sd = generate_sd() + nested_flyte_types = generate_nested_flyte_types( + local_file=local_file, + local_dir=local_dir, + local_file_by_str=local_file_by_str, + local_dir_by_str=local_dir_by_str, + sd=sd + ) + old_flyte_types = pass_and_return_flyte_types(nested_flyte_types=nested_flyte_types) + return pass_and_return_flyte_types(nested_flyte_types=old_flyte_types) + + @task + def get_empty_nested_type() -> 
NestedFlyteTypes: + return NestedFlyteTypes() + + @workflow + def empty_nested_dc_wf() -> NestedFlyteTypes: + return get_empty_nested_type() + + nested_flyte_types = nested_dc_wf(txt_path=local_dummy_txt_file, dir_path=local_dummy_directory) + DataclassTransformer().assert_type(NestedFlyteTypes, nested_flyte_types) + + empty_nested_flyte_types = empty_nested_dc_wf() + DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types) diff --git a/tests/flytekit/unit/core/test_generice_idl_type_engine.py b/tests/flytekit/unit/core/test_generice_idl_type_engine.py new file mode 100644 index 0000000000..f23c9557f8 --- /dev/null +++ b/tests/flytekit/unit/core/test_generice_idl_type_engine.py @@ -0,0 +1,3621 @@ +import copy +import dataclasses +import datetime +import json +import os +import re +import sys +import tempfile +import typing +from dataclasses import asdict, dataclass, field +from datetime import timedelta +from enum import Enum, auto +from typing import List, Optional, Type + +import mock +import pytest +import typing_extensions +from concurrent.futures import ThreadPoolExecutor +from dataclasses_json import DataClassJsonMixin, dataclass_json +from flyteidl.core import errors_pb2 +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct +from marshmallow_enum import LoadDumpOptions +from marshmallow_jsonschema import JSONSchema +from mashumaro.config import BaseConfig +from mashumaro.mixins.json import DataClassJSONMixin +from mashumaro.mixins.orjson import DataClassORJSONMixin +from mashumaro.types import Discriminator +from typing_extensions import Annotated, get_args, get_origin + +from flytekit import dynamic, kwtypes, task, workflow +from flytekit.core.annotation import FlyteAnnotation +from flytekit.core.context_manager import FlyteContext, FlyteContextManager +from flytekit.core.data_persistence import flyte_tmp_dir +from flytekit.core.hash import HashMethod +from flytekit.core.type_engine import ( + DataclassTransformer, + DictTransformer, + EnumTransformer, + ListTransformer, + LiteralsResolver, + SimpleTransformer, + TypeEngine, + TypeTransformer, + TypeTransformerFailedError, + UnionTransformer, + convert_marshmallow_json_schema_to_python_class, + convert_mashumaro_json_schema_to_python_class, + dataclass_from_dict, + get_underlying_type, + is_annotated, IntTransformer, +) +from flytekit.core.type_engine import * +from flytekit.exceptions import user as user_exceptions +from flytekit.models import types as model_types +from flytekit.models.annotation import TypeAnnotation +from flytekit.models.core.types import BlobType +from flytekit.models.literals import ( + Blob, + BlobMetadata, + Literal, + LiteralCollection, + LiteralMap, + LiteralOffloadedMetadata, + Primitive, + Scalar, + Void, Binary, +) +from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType +from flytekit.types.directory import TensorboardLogs +from flytekit.types.directory.types import ( + FlyteDirectory, + FlyteDirToMultipartBlobTransformer, +) +from flytekit.types.file import FileExt, JPEGImageFile +from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer, noop +from flytekit.types.iterator.iterator import IteratorTransformer +from flytekit.types.iterator.json_iterator import JSONIterator, JSONIteratorTransformer, JSON +from flytekit.types.pickle import FlytePickle +from flytekit.types.pickle.pickle import FlytePickleTransformer +from flytekit.types.schema import FlyteSchema +from 
flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine, PARQUET + +@pytest.fixture(autouse=True) +def prepare_env_variable(): + try: + origin_env = copy.deepcopy(os.environ.copy()) + os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "True" + yield + finally: + os.environ = origin_env + + + +T = typing.TypeVar("T") + + +def test_type_engine(): + t = int + lt = TypeEngine.to_literal_type(t) + assert lt.simple == model_types.SimpleType.INTEGER + + t = typing.Dict[str, typing.List[typing.Dict[str, timedelta]]] + lt = TypeEngine.to_literal_type(t) + assert lt.map_value_type.collection_type.map_value_type.simple == model_types.SimpleType.DURATION + + +def test_named_tuple(): + t = typing.NamedTuple("Outputs", [("x_str", str), ("y_int", int)]) + var_map = TypeEngine.named_tuple_to_variable_map(t) + assert var_map.variables["x_str"].type.simple == model_types.SimpleType.STRING + assert var_map.variables["y_int"].type.simple == model_types.SimpleType.INTEGER + + +def test_type_resolution(): + assert type(TypeEngine.get_transformer(typing.List[int])) == ListTransformer + assert type(TypeEngine.get_transformer(typing.List)) == ListTransformer + assert type(TypeEngine.get_transformer(list)) == ListTransformer + + assert type(TypeEngine.get_transformer(typing.Dict[str, int])) == DictTransformer + assert type(TypeEngine.get_transformer(typing.Dict)) == DictTransformer + assert type(TypeEngine.get_transformer(dict)) == DictTransformer + assert type(TypeEngine.get_transformer(Annotated[dict, kwtypes(allow_pickle=True)])) == DictTransformer + + assert type(TypeEngine.get_transformer(int)) == SimpleTransformer + assert type(TypeEngine.get_transformer(datetime.date)) == SimpleTransformer + + assert type(TypeEngine.get_transformer(os.PathLike)) == FlyteFilePathTransformer + assert type(TypeEngine.get_transformer(FlytePickle)) == FlytePickleTransformer + assert type(TypeEngine.get_transformer(typing.Any)) == FlytePickleTransformer + + +def test_file_formats_getting_literal_type(): + transformer = TypeEngine.get_transformer(FlyteFile) + + lt = transformer.get_literal_type(FlyteFile) + assert lt.blob.format == "" + + # Works with formats that we define + lt = transformer.get_literal_type(FlyteFile["txt"]) + assert lt.blob.format == "txt" + + lt = transformer.get_literal_type(FlyteFile[typing.TypeVar("jpg")]) + assert lt.blob.format == "jpg" + + # An empty format falls back to the default + lt = transformer.get_literal_type(FlyteFile) + assert lt.blob.format == "" + + lt = transformer.get_literal_type(FlyteFile[typing.TypeVar(".png")]) + assert lt.blob.format == "png" + + +def test_file_format_getting_python_value(): + transformer = TypeEngine.get_transformer(FlyteFile) + + ctx = FlyteContext.current_context() + + temp_dir = tempfile.mkdtemp(prefix="temp_example_") + file_path = os.path.join(temp_dir, "file.txt") + with open(file_path, "w") as file1: + file1.write("hello world") + lv = Literal( + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata(type=BlobType(format="txt", dimensionality=0)), + uri=file_path, + ) + ) + ) + + pv = transformer.to_python_value(ctx, lv, expected_python_type=FlyteFile["txt"]) + assert isinstance(pv, FlyteFile) + assert pv.extension() == "txt" + + +def test_list_of_dict_getting_python_value(): + transformer = TypeEngine.get_transformer(typing.List) + ctx = FlyteContext.current_context() + lv = Literal( + collection=LiteralCollection( + literals=[Literal(map=LiteralMap({"foo": Literal(scalar=Scalar(primitive=Primitive(integer=1)))}))] + ) + ) + + pv = 
transformer.to_python_value(ctx, lv, expected_python_type=typing.List[typing.Dict[str, int]])
+    assert isinstance(pv, list)
+
+
+def test_list_of_single_dataclass():
+    @dataclass
+    class Bar(DataClassJsonMixin):
+        v: typing.Optional[typing.List[int]]
+        w: typing.Optional[typing.List[float]]
+
+    @dataclass
+    class Foo(DataClassJsonMixin):
+        a: typing.Optional[typing.List[str]]
+        b: Bar
+
+    foo = Foo(a=["abc", "def"], b=Bar(v=[1, 2, 99], w=[3.1415, 2.7182]))
+    generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct())
+    lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))]))
+
+    transformer = TypeEngine.get_transformer(typing.List)
+    ctx = FlyteContext.current_context()
+
+    pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo])
+    assert pv[0].a == ["abc", "def"]
+    assert pv[0].b == Bar(v=[1, 2, 99], w=[3.1415, 2.7182])
+
+
+@dataclass
+class Bar(DataClassJSONMixin):
+    v: typing.Optional[typing.List[int]]
+    w: typing.Optional[typing.List[float]]
+
+
+@dataclass
+class Foo(DataClassJSONMixin):
+    a: typing.Optional[typing.List[str]]
+    b: Bar
+
+
+def test_list_of_single_dataclassjsonmixin():
+    foo = Foo(a=["abc", "def"], b=Bar(v=[1, 2, 99], w=[3.1415, 2.7182]))
+    generic = _json_format.Parse(typing.cast(DataClassJSONMixin, foo).to_json(), _struct.Struct())
+    lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))]))
+
+    transformer = TypeEngine.get_transformer(typing.List)
+    ctx = FlyteContext.current_context()
+
+    pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo])
+    assert pv[0].a == ["abc", "def"]
+    assert pv[0].b == Bar(v=[1, 2, 99], w=[3.1415, 2.7182])
+
+
+def test_annotated_type():
+    class JsonTypeTransformer(TypeTransformer[T]):
+        LiteralType = LiteralType(
+            simple=SimpleType.STRING,
+            annotation=TypeAnnotation(annotations=dict(protocol="json")),
+        )
+
+        def get_literal_type(self, t: Type[T]) -> LiteralType:
+            return self.LiteralType
+
+        def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]:
+            return json.loads(lv.scalar.primitive.string_value)
+
+        def to_literal(
+            self,
+            ctx: FlyteContext,
+            python_val: T,
+            python_type: typing.Type[T],
+            expected: LiteralType,
+        ) -> Literal:
+            return Literal(scalar=Scalar(primitive=Primitive(string_value=json.dumps(python_val))))
+
+    class JSONSerialized:
+        def __class_getitem__(cls, item: Type[T]):
+            return Annotated[item, JsonTypeTransformer(name=f"json[{item}]", t=item)]
+
+    MyJsonDict = JSONSerialized[typing.Dict[str, int]]
+    _, test_transformer = get_args(MyJsonDict)
+
+    assert TypeEngine.get_transformer(MyJsonDict) is test_transformer
+    assert TypeEngine.to_literal_type(MyJsonDict) == JsonTypeTransformer.LiteralType
+
+    test_dict = {"foo": 1}
+    test_literal = Literal(scalar=Scalar(primitive=Primitive(string_value=json.dumps(test_dict))))
+
+    assert (
+        TypeEngine.to_python_value(
+            FlyteContext.current_context(),
+            test_literal,
+            MyJsonDict,
+        )
+        == test_dict
+    )
+
+    assert (
+        TypeEngine.to_literal(
+            FlyteContext.current_context(),
+            test_dict,
+            MyJsonDict,
+            JsonTypeTransformer.LiteralType,
+        )
+        == test_literal
+    )
+
+
+def test_list_of_dataclass_getting_python_value():
+    @dataclass
+    class Bar(DataClassJsonMixin):
+        v: typing.Union[int, None]
+        w: typing.Optional[str]
+        x: float
+        y: str
+        z: typing.Dict[str, bool]
+
+    @dataclass
+    class Foo(DataClassJsonMixin):
+        u: typing.Optional[int]
+        v: typing.Optional[int]
+        w: int
+        x: typing.List[int]
+        y: typing.Dict[str, str]
+        z: Bar
+
+    foo = Foo(
+        u=5,
+        v=None,
+        w=1,
+        x=[1],
+        y={"hello": "10"},
+        z=Bar(v=3, w=None, x=1.0, y="hello", z={"world": False}),
+    )
+    generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct())
+    lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))]))
+
+    transformer = TypeEngine.get_transformer(typing.List)
+    ctx = FlyteContext.current_context()
+
+    schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema())
+    foo_class = convert_marshmallow_json_schema_to_python_class(schema["definitions"], "FooSchema")
+
+    guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class])
+    pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo])
+    assert isinstance(guessed_pv, list)
+    assert guessed_pv[0].u == pv[0].u
+    assert guessed_pv[0].v == pv[0].v
+    assert guessed_pv[0].w == pv[0].w
+    assert guessed_pv[0].x == pv[0].x
+    assert guessed_pv[0].y == pv[0].y
+    assert guessed_pv[0].z.x == pv[0].z.x
+    assert type(guessed_pv[0].u) == int
+    assert guessed_pv[0].v is None
+    assert type(guessed_pv[0].w) == int
+    assert type(guessed_pv[0].z.v) == int
+    assert type(guessed_pv[0].z.x) == float
+    assert guessed_pv[0].z.v == pv[0].z.v
+    assert guessed_pv[0].z.y == pv[0].z.y
+    assert guessed_pv[0].z.z == pv[0].z.z
+    assert pv[0] == dataclass_from_dict(Foo, asdict(guessed_pv[0]))
+    assert dataclasses.is_dataclass(foo_class)
+
+
+@dataclass
+class Bar_getting_python_value(DataClassJSONMixin):
+    v: typing.Union[int, None]
+    w: typing.Optional[str]
+    x: float
+    y: str
+    z: typing.Dict[str, bool]
+
+
+@dataclass
+class Foo_getting_python_value(DataClassJSONMixin):
+    u: typing.Optional[int]
+    v: typing.Optional[int]
+    w: int
+    x: typing.List[int]
+    y: typing.Dict[str, str]
+    z: Bar_getting_python_value
+
+
+def test_list_of_dataclassjsonmixin_getting_python_value():
+    foo = Foo_getting_python_value(
+        u=5,
+        v=None,
+        w=1,
+        x=[1],
+        y={"hello": "10"},
+        z=Bar_getting_python_value(v=3, w=None, x=1.0, y="hello", z={"world": False}),
+    )
+    generic = _json_format.Parse(typing.cast(DataClassJSONMixin, foo).to_json(), _struct.Struct())
+    lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))]))
+
+    transformer = TypeEngine.get_transformer(typing.List)
+    ctx = FlyteContext.current_context()
+
+    from mashumaro.jsonschema import build_json_schema
+
+    schema = build_json_schema(typing.cast(DataClassJSONMixin, Foo_getting_python_value)).to_dict()
+    foo_class = convert_mashumaro_json_schema_to_python_class(schema, "FooSchema")
+
+    guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class])
+    pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo_getting_python_value])
+    assert isinstance(guessed_pv, list)
+    assert guessed_pv[0].u == pv[0].u
+    assert guessed_pv[0].v == pv[0].v
+    assert guessed_pv[0].w == pv[0].w
+    assert guessed_pv[0].x == pv[0].x
+    assert guessed_pv[0].y == pv[0].y
+    assert guessed_pv[0].z.x == pv[0].z.x
+    assert type(guessed_pv[0].u) == int
+    assert guessed_pv[0].v is None
+    assert type(guessed_pv[0].w) == int
+    assert type(guessed_pv[0].z.v) == int
+    assert type(guessed_pv[0].z.x) == float
+    assert guessed_pv[0].z.v == pv[0].z.v
+    assert guessed_pv[0].z.y == pv[0].z.y
+    assert guessed_pv[0].z.z == pv[0].z.z
+    assert pv[0] == dataclass_from_dict(Foo_getting_python_value, asdict(guessed_pv[0]))
+    assert dataclasses.is_dataclass(foo_class)
+
+
+def test_file_no_downloader_default():
+    # The idea of this test is to assert that if a FlyteFile is created with no downloader specified,
+    # it should return the set path itself. This matches the behavior of the open() method.
+    transformer = TypeEngine.get_transformer(FlyteFile)
+
+    ctx = FlyteContext.current_context()
+    temp_dir = tempfile.mkdtemp(prefix="temp_example_")
+    local_file = os.path.join(temp_dir, "file.txt")
+    with open(local_file, "w") as file:
+        file.write("hello world")
+
+    lv = Literal(
+        scalar=Scalar(
+            blob=Blob(
+                metadata=BlobMetadata(type=BlobType(format="", dimensionality=0)),
+                uri=local_file,
+            )
+        )
+    )
+
+    pv = transformer.to_python_value(ctx, lv, expected_python_type=FlyteFile)
+    assert isinstance(pv, FlyteFile)
+    assert pv.download() == local_file
+
+
+def test_dir_no_downloader_default():
+    # The idea of this test is to assert that if a FlyteDirectory is created with no downloader specified,
+    # it should return the set path itself. This matches the behavior of the open() method.
+    transformer = TypeEngine.get_transformer(FlyteDirectory)
+
+    ctx = FlyteContext.current_context()
+
+    local_dir = tempfile.mkdtemp(prefix="temp_example_")
+
+    lv = Literal(
+        scalar=Scalar(
+            blob=Blob(
+                metadata=BlobMetadata(type=BlobType(format="", dimensionality=1)),
+                uri=local_dir,
+            )
+        )
+    )
+
+    pv = transformer.to_python_value(ctx, lv, expected_python_type=FlyteDirectory)
+    assert isinstance(pv, FlyteDirectory)
+    assert pv.download() == local_dir
+
+
+def test_dict_transformer():
+    d = DictTransformer()
+
+    def assert_struct(lit: LiteralType):
+        assert lit is not None
+        assert lit.simple == SimpleType.STRUCT
+
+    def recursive_assert(
+        lit: LiteralType,
+        expected: LiteralType,
+        expected_depth: int = 1,
+        curr_depth: int = 0,
+    ):
+        assert curr_depth <= expected_depth
+        assert lit is not None
+        if lit.map_value_type is None:
+            assert lit == expected
+            return
+        recursive_assert(lit.map_value_type, expected, expected_depth, curr_depth + 1)
+
+    # Type inference
+    assert_struct(d.get_literal_type(dict))
+    assert_struct(d.get_literal_type(Annotated[dict, kwtypes(allow_pickle=True)]))
+    assert_struct(d.get_literal_type(typing.Dict[int, int]))
+    recursive_assert(d.get_literal_type(typing.Dict[str, str]), LiteralType(simple=SimpleType.STRING))
+    recursive_assert(
+        d.get_literal_type(typing.Dict[str, int]),
+        LiteralType(simple=SimpleType.INTEGER),
+    )
+    recursive_assert(
+        d.get_literal_type(typing.Dict[str, datetime.datetime]),
+        LiteralType(simple=SimpleType.DATETIME),
+    )
+    recursive_assert(
+        d.get_literal_type(typing.Dict[str, datetime.timedelta]),
+        LiteralType(simple=SimpleType.DURATION),
+    )
+    recursive_assert(
+        d.get_literal_type(typing.Dict[str, datetime.date]),
+        LiteralType(simple=SimpleType.DATETIME),
+    )
+    recursive_assert(
+        d.get_literal_type(typing.Dict[str, dict]),
+        LiteralType(simple=SimpleType.STRUCT),
+    )
+    recursive_assert(
+        d.get_literal_type(typing.Dict[str, typing.Dict[str, str]]),
+        LiteralType(simple=SimpleType.STRING),
+        expected_depth=2,
+    )
+    recursive_assert(
+        d.get_literal_type(typing.Dict[str, typing.Dict[int, str]]),
+        LiteralType(simple=SimpleType.STRUCT),
+        expected_depth=2,
+    )
+    recursive_assert(
+        d.get_literal_type(typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]]),
+        LiteralType(simple=SimpleType.STRING),
+        expected_depth=3,
+    )
+    recursive_assert(
+        d.get_literal_type(typing.Dict[str, typing.Dict[str, typing.Dict[str, dict]]]),
+        LiteralType(simple=SimpleType.STRUCT),
+        expected_depth=3,
+    )
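+    # A non-string key type at any nesting level collapses that level to a generic STRUCT,
+    # which is why recursion is expected to stop at depth 2 below.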
+    recursive_assert(
+        d.get_literal_type(typing.Dict[str, typing.Dict[str, typing.Dict[int, dict]]]),
+        LiteralType(simple=SimpleType.STRUCT),
+        expected_depth=2,
+    )
+
+    ctx = FlyteContext.current_context()
+
+    lit = d.to_literal(ctx, {}, typing.Dict, LiteralType(SimpleType.STRUCT))
+    pv = d.to_python_value(ctx, lit, typing.Dict)
+    assert pv == {}
+
+    lit_empty = Literal(map=LiteralMap(literals={}))
+    pv_empty = d.to_python_value(ctx, lit_empty, typing.Dict[str, str])
+    assert pv_empty == {}
+
+    # Literal to python
+    with pytest.raises(TypeError):
+        d.to_python_value(ctx, Literal(scalar=Scalar(primitive=Primitive(integer=10))), dict)
+    with pytest.raises(TypeError):
+        d.to_python_value(ctx, Literal(), dict)
+    with pytest.raises(TypeError):
+        d.to_python_value(ctx, Literal(map=LiteralMap(literals={"x": None})), dict)
+    with pytest.raises(TypeError):
+        d.to_python_value(ctx, Literal(map=LiteralMap(literals={"x": None})), typing.Dict[int, str])
+
+    with pytest.raises(TypeError):
+        d.to_literal(
+            ctx,
+            {"x": datetime.datetime(2024, 5, 5)},
+            dict,
+            LiteralType(simple=SimpleType.STRUCT),
+        )
+
+    lv = d.to_literal(
+        ctx,
+        {"x": datetime.datetime(2024, 5, 5)},
+        Annotated[dict, kwtypes(allow_pickle=True)],
+        LiteralType(simple=SimpleType.STRUCT),
+    )
+    assert lv.metadata["format"] == "pickle"
+    assert d.to_python_value(ctx, lv, dict) == {"x": datetime.datetime(2024, 5, 5)}
+
+    d.to_python_value(
+        ctx,
+        Literal(map=LiteralMap(literals={"x": Literal(scalar=Scalar(primitive=Primitive(integer=1)))})),
+        typing.Dict[str, int],
+    )
+
+    lv = d.to_literal(
+        ctx,
+        {"x": "hello"},
+        dict,
+        LiteralType(simple=SimpleType.STRUCT),
+    )
+
+    lv._metadata = None
+    assert d.to_python_value(ctx, lv, dict) == {"x": "hello"}
+
+
+def test_convert_marshmallow_json_schema_to_python_class():
+    @dataclass
+    class Foo(DataClassJsonMixin):
+        x: int
+        y: str
+
+    schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema())
+    foo_class = convert_marshmallow_json_schema_to_python_class(schema["definitions"], "FooSchema")
+    foo = foo_class(x=1, y="hello")
+    foo.x = 2
+    assert foo.x == 2
+    assert foo.y == "hello"
+    with pytest.raises(AttributeError):
+        _ = foo.c
+    assert dataclasses.is_dataclass(foo_class)
+
+
+def test_convert_mashumaro_json_schema_to_python_class():
+    @dataclass
+    class Foo(DataClassJSONMixin):
+        x: int
+        y: str
+
+    # schema = JSONSchema().dump(typing.cast(DataClassJSONMixin, Foo).schema())
+    from mashumaro.jsonschema import build_json_schema
+
+    schema = build_json_schema(typing.cast(DataClassJSONMixin, Foo)).to_dict()
+    foo_class = convert_mashumaro_json_schema_to_python_class(schema, "FooSchema")
+    foo = foo_class(x=1, y="hello")
+    foo.x = 2
+    assert foo.x == 2
+    assert foo.y == "hello"
+    with pytest.raises(AttributeError):
+        _ = foo.c
+    assert dataclasses.is_dataclass(foo_class)
+
+
+def test_list_transformer():
+    l0 = Literal(scalar=Scalar(primitive=Primitive(integer=3)))
+    l1 = Literal(scalar=Scalar(primitive=Primitive(integer=4)))
+    lc = LiteralCollection(literals=[l0, l1])
+    lit = Literal(collection=lc)
+
+    ctx = FlyteContext.current_context()
+    xx = TypeEngine.to_python_value(ctx, lit, typing.List[int])
+    assert xx == [3, 4]
+
+
+def test_protos():
+    ctx = FlyteContext.current_context()
+
+    pb = errors_pb2.ContainerError(code="code", message="message")
+    lt = TypeEngine.to_literal_type(errors_pb2.ContainerError)
+    assert lt.simple == SimpleType.STRUCT
+    assert lt.metadata["pb_type"] == "flyteidl.core.errors_pb2.ContainerError"
+
+    lit = TypeEngine.to_literal(ctx, pb, errors_pb2.ContainerError, lt)
+    new_python_val = TypeEngine.to_python_value(ctx, lit, errors_pb2.ContainerError)
+    assert new_python_val == pb
+
+    # Test error
+    l0 = Literal(scalar=Scalar(primitive=Primitive(integer=4)))
+    with pytest.raises(AssertionError):
+        TypeEngine.to_python_value(ctx, l0, errors_pb2.ContainerError)
+
+    default_proto = errors_pb2.ContainerError()
+    lit = TypeEngine.to_literal(ctx, default_proto, errors_pb2.ContainerError, lt)
+    assert lit.scalar
+    assert lit.scalar.generic is not None
+    new_python_val = TypeEngine.to_python_value(ctx, lit, errors_pb2.ContainerError)
+    assert new_python_val == default_proto
+
+
+def test_guessing_basic():
+    b = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN)
+    pt = TypeEngine.guess_python_type(b)
+    assert pt is bool
+
+    lt = model_types.LiteralType(simple=model_types.SimpleType.INTEGER)
+    pt = TypeEngine.guess_python_type(lt)
+    assert pt is int
+
+    lt = model_types.LiteralType(simple=model_types.SimpleType.STRING)
+    pt = TypeEngine.guess_python_type(lt)
+    assert pt is str
+
+    lt = model_types.LiteralType(simple=model_types.SimpleType.DURATION)
+    pt = TypeEngine.guess_python_type(lt)
+    assert pt is timedelta
+
+    lt = model_types.LiteralType(simple=model_types.SimpleType.DATETIME)
+    pt = TypeEngine.guess_python_type(lt)
+    assert pt is datetime.datetime
+
+    lt = model_types.LiteralType(simple=model_types.SimpleType.FLOAT)
+    pt = TypeEngine.guess_python_type(lt)
+    assert pt is float
+
+    lt = model_types.LiteralType(simple=model_types.SimpleType.NONE)
+    pt = TypeEngine.guess_python_type(lt)
+    assert pt is type(None)  # noqa: E721
+
+    lt = model_types.LiteralType(
+        blob=BlobType(
+            format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT,
+            dimensionality=BlobType.BlobDimensionality.SINGLE,
+        )
+    )
+    pt = TypeEngine.guess_python_type(lt)
+    assert pt is FlytePickle
+
+
+def test_guessing_containers():
+    b = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN)
+    lt = model_types.LiteralType(collection_type=b)
+    pt = TypeEngine.guess_python_type(lt)
+    assert pt == typing.List[bool]
+
+    dur = model_types.LiteralType(simple=model_types.SimpleType.DURATION)
+    lt = model_types.LiteralType(map_value_type=dur)
+    pt = TypeEngine.guess_python_type(lt)
+    assert pt == typing.Dict[str, timedelta]
+
+
+def test_zero_floats():
+    ctx = FlyteContext.current_context()
+
+    l0 = Literal(scalar=Scalar(primitive=Primitive(integer=0)))
+    l1 = Literal(scalar=Scalar(primitive=Primitive(float_value=0.0)))
+
+    assert TypeEngine.to_python_value(ctx, l0, float) == 0
+    assert TypeEngine.to_python_value(ctx, l1, float) == 0
+
+
+def test_dataclass_transformer():
+    @dataclass
+    class InnerStruct(DataClassJsonMixin):
+        a: int
+        b: typing.Optional[str]
+        c: typing.List[int]
+
+    @dataclass
+    class TestStruct(DataClassJsonMixin):
+        s: InnerStruct
+        m: typing.Dict[str, str]
+
+    @dataclass
+    class TestStructB(DataClassJsonMixin):
+        s: InnerStruct
+        m: typing.Dict[int, str]
+        n: typing.Optional[typing.List[typing.List[int]]] = None
+        o: typing.Optional[typing.Dict[int, typing.Dict[int, int]]] = None
+
+    @dataclass
+    class TestStructC(DataClassJsonMixin):
+        s: InnerStruct
+        m: typing.Dict[str, int]
+
+    @dataclass
+    class TestStructD(DataClassJsonMixin):
+        s: InnerStruct
+        m: typing.Dict[str, typing.List[int]]
+
+    class UnsupportedSchemaType:
+        def __init__(self):
+            self._a = "Hello"
+
+    @dataclass
+    class UnsupportedNestedStruct(DataClassJsonMixin):
+        a: int
+        s: UnsupportedSchemaType
+
+    schema = {
+        "$ref": "#/definitions/TeststructSchema",
+        "$schema": "http://json-schema.org/draft-07/schema#",
+        "definitions": {
+            "InnerstructSchema": {
+                "additionalProperties": False,
+                "properties": {
+                    "a": {"title": "a", "type": "integer"},
+                    "b": {"default": None, "title": "b", "type": ["string", "null"]},
+                    "c": {
+                        "items": {"title": "c", "type": "integer"},
+                        "title": "c",
+                        "type": "array",
+                    },
+                },
+                "type": "object",
+            },
+            "TeststructSchema": {
+                "additionalProperties": False,
+                "properties": {
+                    "m": {
+                        "additionalProperties": {"title": "m", "type": "string"},
+                        "title": "m",
+                        "type": "object",
+                    },
+                    "s": {
+                        "$ref": "#/definitions/InnerstructSchema",
+                        "field_many": False,
+                        "type": "object",
+                    },
+                },
+                "type": "object",
+            },
+        },
+    }
+    tf = DataclassTransformer()
+    t = tf.get_literal_type(TestStruct)
+    assert t is not None
+    assert t.simple is not None
+    assert t.simple == SimpleType.STRUCT
+    assert t.metadata is not None
+    assert t.metadata == schema
+
+    t = TypeEngine.to_literal_type(TestStruct)
+    assert t is not None
+    assert t.simple is not None
+    assert t.simple == SimpleType.STRUCT
+    assert t.metadata is not None
+    assert t.metadata == schema
+
+    t = tf.get_literal_type(UnsupportedNestedStruct)
+    assert t is not None
+    assert t.simple is not None
+    assert t.simple == SimpleType.STRUCT
+    assert t.metadata is None
+
+
+def test_dataclass_transformer_with_dataclassjsonmixin():
+    @dataclass
+    class InnerStruct_transformer(DataClassJSONMixin):
+        a: int
+        b: typing.Optional[str]
+        c: typing.List[int]
+
+    @dataclass
+    class TestStruct_transformer(DataClassJSONMixin):
+        s: InnerStruct_transformer
+        m: typing.Dict[str, str]
+
+    class UnsupportedSchemaType:
+        def __init__(self):
+            self._a = "Hello"
+
+    @dataclass
+    class UnsupportedNestedStruct(DataClassJsonMixin):
+        a: int
+        s: UnsupportedSchemaType
+
+    schema = {
+        "type": "object",
+        "title": "TestStruct_transformer",
+        "properties": {
+            "s": {
+                "type": "object",
+                "title": "InnerStruct_transformer",
+                "properties": {
+                    "a": {"type": "integer"},
+                    "b": {"anyOf": [{"type": "string"}, {"type": "null"}]},
+                    "c": {"type": "array", "items": {"type": "integer"}},
+                },
+                "additionalProperties": False,
+                "required": ["a", "b", "c"],
+            },
+            "m": {
+                "type": "object",
+                "additionalProperties": {"type": "string"},
+                "propertyNames": {"type": "string"},
+            },
+        },
+        "additionalProperties": False,
+        "required": ["s", "m"],
+    }
+
+    tf = DataclassTransformer()
+    t = tf.get_literal_type(TestStruct_transformer)
+    assert t is not None
+    assert t.simple is not None
+    assert t.simple == SimpleType.STRUCT
+    assert t.metadata is not None
+    assert t.metadata == schema
+
+    t = TypeEngine.to_literal_type(TestStruct_transformer)
+    assert t is not None
+    assert t.simple is not None
+    assert t.simple == SimpleType.STRUCT
+    assert t.metadata is not None
+    assert t.metadata == schema
+
+    t = tf.get_literal_type(UnsupportedNestedStruct)
+    assert t is not None
+    assert t.simple is not None
+    assert t.simple == SimpleType.STRUCT
+    assert t.metadata is None
+
+
+def test_dataclass_int_preserving():
+    @dataclass
+    class InnerStruct(DataClassJsonMixin):
+        a: int
+        b: typing.Optional[str]
+        c: typing.List[int]
+
+    @dataclass
+    class TestStructB(DataClassJsonMixin):
+        s: InnerStruct
+        m: typing.Dict[int, str]
+        n: typing.Optional[typing.List[typing.List[int]]] = None
+        o: typing.Optional[typing.Dict[int, typing.Dict[int, int]]] = None
+
+    @dataclass
+    class TestStructC(DataClassJsonMixin):
+        s: InnerStruct
+        m: typing.Dict[str, int]
+
+    @dataclass
+    class TestStructD(DataClassJsonMixin):
+        s: InnerStruct
+        m: typing.Dict[str, typing.List[int]]
+
+    ctx = FlyteContext.current_context()
+    o = InnerStruct(a=5, b=None, c=[1, 2, 3])
+    tf = DataclassTransformer()
+    lv = tf.to_literal(ctx, o, InnerStruct, tf.get_literal_type(InnerStruct))
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=InnerStruct)
+    assert ot == o
+
+    o = TestStructB(
+        s=InnerStruct(a=5, b=None, c=[1, 2, 3]),
+        m={5: "b"},
+        n=[[1, 2, 3], [4, 5, 6]],
+        o={1: {2: 3}, 4: {5: 6}},
+    )
+    lv = tf.to_literal(ctx, o, TestStructB, tf.get_literal_type(TestStructB))
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestStructB)
+    assert ot == o
+
+    o = TestStructC(s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={"a": 5})
+    lv = tf.to_literal(ctx, o, TestStructC, tf.get_literal_type(TestStructC))
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestStructC)
+    assert ot == o
+
+    o = TestStructD(s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={"a": [5]})
+    lv = tf.to_literal(ctx, o, TestStructD, tf.get_literal_type(TestStructD))
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestStructD)
+    assert ot == o
+
+
+@mock.patch("flytekit.core.data_persistence.FileAccessProvider.async_put_data")
+def test_dataclass_with_postponed_annotation(mock_put_data):
+    remote_path = "s3://tmp/file"
+    mock_put_data.return_value = remote_path
+
+    @dataclass
+    class Data:
+        a: int
+        f: "FlyteFile"
+
+    ctx = FlyteContext.current_context()
+    tf = DataclassTransformer()
+    t = tf.get_literal_type(Data)
+    assert t.simple == SimpleType.STRUCT
+    with tempfile.TemporaryDirectory() as tmp:
+        test_file = os.path.join(tmp, "abc.txt")
+        with open(test_file, "w") as f:
+            f.write("123")
+
+        pv = Data(a=1, f=FlyteFile(test_file, remote_path=remote_path))
+        lt = tf.to_literal(ctx, pv, Data, t)
+        assert lt.scalar.generic.fields["f"].struct_value.fields["path"].string_value == remote_path
+
+
+@mock.patch("flytekit.core.data_persistence.FileAccessProvider.async_put_data")
+def test_optional_flytefile_in_dataclass(mock_upload_dir):
+    mock_upload_dir.return_value = True
+
+    @dataclass
+    class A(DataClassJsonMixin):
+        a: int
+
+    @dataclass
+    class TestFileStruct(DataClassJsonMixin):
+        a: FlyteFile
+        b: typing.Optional[FlyteFile]
+        b_prime: typing.Optional[FlyteFile]
+        c: typing.Union[FlyteFile, None]
+        d: typing.List[FlyteFile]
+        e: typing.List[typing.Optional[FlyteFile]]
+        e_prime: typing.List[typing.Optional[FlyteFile]]
+        f: typing.Dict[str, FlyteFile]
+        g: typing.Dict[str, typing.Optional[FlyteFile]]
+        g_prime: typing.Dict[str, typing.Optional[FlyteFile]]
+        h: typing.Optional[FlyteFile] = None
+        h_prime: typing.Optional[FlyteFile] = None
+        i: typing.Optional[A] = None
+        i_prime: typing.Optional[A] = field(default_factory=lambda: A(a=99))
+
+    remote_path = "s3://tmp/file"
+    # set the return value to the remote path since that's what put_data does
+    mock_upload_dir.return_value = remote_path
+    with tempfile.TemporaryFile() as f:
+        f.write(b"abc")
+        f1 = FlyteFile("f1", remote_path=remote_path)
+        o = TestFileStruct(
+            a=f1,
+            b=f1,
+            b_prime=None,
+            c=f1,
+            d=[f1],
+            e=[f1],
+            e_prime=[None],
+            f={"a": f1},
+            g={"a": f1},
+            g_prime={"a": None},
+            h=f1,
+            i=A(a=42),
+        )
+
+        ctx = FlyteContext.current_context()
+        tf = DataclassTransformer()
+        lt = tf.get_literal_type(TestFileStruct)
+        lv = tf.to_literal(ctx, o, TestFileStruct, lt)
+
+        assert lv.scalar.generic["a"].fields["path"].string_value == remote_path
+        assert lv.scalar.generic["b"].fields["path"].string_value == remote_path
+        assert lv.scalar.generic["b_prime"] is None
+        assert lv.scalar.generic["c"].fields["path"].string_value == remote_path
+        assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path
+        assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path
+        assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value"
+        assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path
+        assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path
+        assert lv.scalar.generic["g_prime"]["a"] is None
+        assert lv.scalar.generic["h"].fields["path"].string_value == remote_path
+        assert lv.scalar.generic["h_prime"] is None
+        assert lv.scalar.generic["i"].fields["a"].number_value == 42
+        assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99
+
+        ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct)
+
+        assert o.a.remote_path == ot.a.remote_source
+        assert o.b.remote_path == ot.b.remote_source
+        assert ot.b_prime is None
+        assert o.c.remote_path == ot.c.remote_source
+        assert o.d[0].remote_path == ot.d[0].remote_source
+        assert o.e[0].remote_path == ot.e[0].remote_source
+        assert o.e_prime == [None]
+        assert o.f["a"].remote_path == ot.f["a"].remote_source
+        assert o.g["a"].remote_path == ot.g["a"].remote_source
+        assert o.g_prime == {"a": None}
+        assert o.h.remote_path == ot.h.remote_source
+        assert ot.h_prime is None
+        assert o.i == ot.i
+        assert o.i_prime == A(a=99)
+
+
+@mock.patch("flytekit.core.data_persistence.FileAccessProvider.async_put_data")
+def test_optional_flytefile_in_dataclassjsonmixin(mock_upload_dir):
+    @dataclass
+    class A_optional_flytefile(DataClassJSONMixin):
+        a: int
+
+    @dataclass
+    class TestFileStruct_optional_flytefile(DataClassJSONMixin):
+        a: FlyteFile
+        b: typing.Optional[FlyteFile]
+        b_prime: typing.Optional[FlyteFile]
+        c: typing.Union[FlyteFile, None]
+        d: typing.List[FlyteFile]
+        e: typing.List[typing.Optional[FlyteFile]]
+        e_prime: typing.List[typing.Optional[FlyteFile]]
+        f: typing.Dict[str, FlyteFile]
+        g: typing.Dict[str, typing.Optional[FlyteFile]]
+        g_prime: typing.Dict[str, typing.Optional[FlyteFile]]
+        h: typing.Optional[FlyteFile] = None
+        h_prime: typing.Optional[FlyteFile] = None
+        i: typing.Optional[A_optional_flytefile] = None
+        i_prime: typing.Optional[A_optional_flytefile] = field(default_factory=lambda: A_optional_flytefile(a=99))
+
+    remote_path = "s3://tmp/file"
+    mock_upload_dir.return_value = remote_path
+
+    with tempfile.TemporaryFile() as f:
+        f.write(b"abc")
+        f1 = FlyteFile("f1", remote_path=remote_path)
+        o = TestFileStruct_optional_flytefile(
+            a=f1,
+            b=f1,
+            b_prime=None,
+            c=f1,
+            d=[f1],
+            e=[f1],
+            e_prime=[None],
+            f={"a": f1},
+            g={"a": f1},
+            g_prime={"a": None},
+            h=f1,
+            i=A_optional_flytefile(a=42),
+        )
+
+        ctx = FlyteContext.current_context()
+        tf = DataclassTransformer()
+        lt = tf.get_literal_type(TestFileStruct_optional_flytefile)
+        lv = tf.to_literal(ctx, o, TestFileStruct_optional_flytefile, lt)
+
+        assert lv.scalar.generic["a"].fields["path"].string_value == remote_path
+        assert lv.scalar.generic["b"].fields["path"].string_value == remote_path
+        assert lv.scalar.generic["b_prime"] is None
+        assert lv.scalar.generic["c"].fields["path"].string_value == remote_path
+        assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path
+        assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path
+        assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value"
+        assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path
+        assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path
+        assert lv.scalar.generic["g_prime"]["a"] is None
+        assert lv.scalar.generic["h"].fields["path"].string_value == remote_path
+        assert lv.scalar.generic["h_prime"] is None
+        assert lv.scalar.generic["i"].fields["a"].number_value == 42
+        assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99
+
+        ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_optional_flytefile)
+
+        assert o.a.remote_path == ot.a.remote_source
+        assert o.b.remote_path == ot.b.remote_source
+        assert ot.b_prime is None
+        assert o.c.remote_path == ot.c.remote_source
+        assert o.d[0].remote_path == ot.d[0].remote_source
+        assert o.e[0].remote_path == ot.e[0].remote_source
+        assert o.e_prime == [None]
+        assert o.f["a"].remote_path == ot.f["a"].remote_source
+        assert o.g["a"].remote_path == ot.g["a"].remote_source
+        assert o.g_prime == {"a": None}
+        assert o.h.remote_path == ot.h.remote_source
+        assert ot.h_prime is None
+        assert o.i == ot.i
+        assert o.i_prime == A_optional_flytefile(a=99)
+
+
+def test_flyte_file_in_dataclass():
+    @dataclass
+    class TestInnerFileStruct(DataClassJsonMixin):
+        a: JPEGImageFile
+        b: typing.List[FlyteFile]
+        c: typing.Dict[str, FlyteFile]
+        d: typing.List[FlyteFile]
+        e: typing.Dict[str, FlyteFile]
+
+    @dataclass
+    class TestFileStruct(DataClassJsonMixin):
+        a: FlyteFile
+        b: TestInnerFileStruct
+
+    remote_path = "s3://tmp/file"
+    f1 = FlyteFile(remote_path)
+    f2 = FlyteFile("/tmp/file")
+    f2._remote_source = remote_path
+    o = TestFileStruct(
+        a=f1,
+        b=TestInnerFileStruct(
+            a=JPEGImageFile("s3://tmp/file.jpeg"),
+            b=[f1],
+            c={"hello": f1},
+            d=[f2],
+            e={"hello": f2},
+        ),
+    )
+
+    ctx = FlyteContext.current_context()
+    tf = DataclassTransformer()
+    lt = tf.get_literal_type(TestFileStruct)
+    lv = tf.to_literal(ctx, o, TestFileStruct, lt)
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct)
+    assert ot.a._downloader is not noop
+    assert ot.b.a._downloader is not noop
+    assert ot.b.b[0]._downloader is not noop
+    assert ot.b.c["hello"]._downloader is not noop
+
+    assert o.a.path == ot.a.remote_source
+    assert o.b.a.path == ot.b.a.remote_source
+    assert o.b.b[0].path == ot.b.b[0].remote_source
+    assert o.b.c["hello"].path == ot.b.c["hello"].remote_source
+    assert ot.b.d[0].remote_source == remote_path
+    assert not ctx.file_access.is_remote(ot.b.d[0].path)
+    assert ot.b.e["hello"].remote_source == remote_path
+    assert not ctx.file_access.is_remote(ot.b.e["hello"].path)
+
+
+def test_flyte_file_in_dataclassjsonmixin():
+    @dataclass
+    class TestInnerFileStruct_flyte_file(DataClassJSONMixin):
+        a: JPEGImageFile
+        b: typing.List[FlyteFile]
+        c: typing.Dict[str, FlyteFile]
+        d: typing.List[FlyteFile]
+        e: typing.Dict[str, FlyteFile]
+
+    @dataclass
+    class TestFileStruct_flyte_file(DataClassJSONMixin):
+        a: FlyteFile
+        b: TestInnerFileStruct_flyte_file
+
+    remote_path = "s3://tmp/file"
+    f1 = FlyteFile(remote_path)
+    f2 = FlyteFile("/tmp/file")
+    f2._remote_source = remote_path
+    o = TestFileStruct_flyte_file(
+        a=f1,
+        b=TestInnerFileStruct_flyte_file(
+            a=JPEGImageFile("s3://tmp/file.jpeg"),
+            b=[f1],
+            c={"hello": f1},
+            d=[f2],
+            e={"hello": f2},
+        ),
+    )
+
+    ctx = FlyteContext.current_context()
+    tf = DataclassTransformer()
+    lt = tf.get_literal_type(TestFileStruct_flyte_file)
+    lv = tf.to_literal(ctx, o, TestFileStruct_flyte_file, lt)
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_flyte_file)
+    assert ot.a._downloader is not noop
+    assert ot.b.a._downloader is not noop
+    assert ot.b.b[0]._downloader is not noop
+    assert ot.b.c["hello"]._downloader is not noop
+
+    assert o.a.path == ot.a.remote_source
+    assert o.b.a.path == ot.b.a.remote_source
+    assert o.b.b[0].path == ot.b.b[0].remote_source
+    assert o.b.c["hello"].path == ot.b.c["hello"].remote_source
+    assert ot.b.d[0].remote_source == remote_path
+    assert not ctx.file_access.is_remote(ot.b.d[0].path)
+    assert ot.b.e["hello"].remote_source == remote_path
+    assert not ctx.file_access.is_remote(ot.b.e["hello"].path)
+
+
+def test_flyte_directory_in_dataclass():
+    @dataclass
+    class TestInnerFileStruct(DataClassJsonMixin):
+        a: TensorboardLogs
+        b: typing.List[FlyteDirectory]
+        c: typing.Dict[str, FlyteDirectory]
+        d: typing.List[FlyteDirectory]
+        e: typing.Dict[str, FlyteDirectory]
+
+    @dataclass
+    class TestFileStruct(DataClassJsonMixin):
+        a: FlyteDirectory
+        b: TestInnerFileStruct
+
+    remote_path = "s3://tmp/file"
+    tempdir = tempfile.mkdtemp(prefix="flyte-")
+    f1 = FlyteDirectory(tempdir)
+    f1._remote_source = remote_path
+    f2 = FlyteDirectory(remote_path)
+    o = TestFileStruct(
+        a=f1,
+        b=TestInnerFileStruct(
+            a=TensorboardLogs("s3://tensorboard"),
+            b=[f1],
+            c={"hello": f1},
+            d=[f2],
+            e={"hello": f2},
+        ),
+    )
+
+    ctx = FlyteContext.current_context()
+    tf = DataclassTransformer()
+    lt = tf.get_literal_type(TestFileStruct)
+    lv = tf.to_literal(ctx, o, TestFileStruct, lt)
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct)
+
+    assert ot.a._downloader is not noop
+    assert ot.b.a._downloader is not noop
+    assert ot.b.b[0]._downloader is not noop
+    assert ot.b.c["hello"]._downloader is not noop
+
+    assert o.a.remote_directory == ot.a.remote_directory
+    assert not ctx.file_access.is_remote(ot.a.path)
+    assert o.b.a.path == ot.b.a.remote_source
+    assert o.b.b[0].remote_directory == ot.b.b[0].remote_directory
+    assert not ctx.file_access.is_remote(ot.b.b[0].path)
+    assert o.b.c["hello"].remote_directory == ot.b.c["hello"].remote_directory
+    assert not ctx.file_access.is_remote(ot.b.c["hello"].path)
+    assert o.b.d[0].path == ot.b.d[0].remote_source
+    assert o.b.e["hello"].path == ot.b.e["hello"].remote_source
+
+
+def test_flyte_directory_in_dataclassjsonmixin():
+    @dataclass
+    class TestInnerFileStruct_flyte_directory(DataClassJSONMixin):
+        a: TensorboardLogs
+        b: typing.List[FlyteDirectory]
+        c: typing.Dict[str, FlyteDirectory]
+        d: typing.List[FlyteDirectory]
+        e: typing.Dict[str, FlyteDirectory]
+
+    @dataclass
+    class TestFileStruct_flyte_directory(DataClassJSONMixin):
+        a: FlyteDirectory
+        b: TestInnerFileStruct_flyte_directory
+
+    remote_path = "s3://tmp/file"
+    tempdir = tempfile.mkdtemp(prefix="flyte-")
+    f1 = FlyteDirectory(tempdir)
+    f1._remote_source = remote_path
+    f2 = FlyteDirectory(remote_path)
+    o = TestFileStruct_flyte_directory(
+        a=f1,
+        b=TestInnerFileStruct_flyte_directory(
+            a=TensorboardLogs("s3://tensorboard"),
+            b=[f1],
+            c={"hello": f1},
+            d=[f2],
+            e={"hello": f2},
+        ),
+    )
+
+    ctx = FlyteContext.current_context()
+    tf = DataclassTransformer()
+    lt = tf.get_literal_type(TestFileStruct_flyte_directory)
+    lv = tf.to_literal(ctx, o, TestFileStruct_flyte_directory, lt)
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_flyte_directory)
+
+    assert ot.a._downloader is not noop
+    assert ot.b.a._downloader is not noop
+    assert ot.b.b[0]._downloader is not noop
+    assert ot.b.c["hello"]._downloader is not noop
+
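+    # As in the DataClassJsonMixin variant above: locally created directories keep a local
+    # path after the round trip, while remote URIs resolve to their remote source.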
+    assert o.a.remote_directory == ot.a.remote_directory
+    assert not ctx.file_access.is_remote(ot.a.path)
+    assert o.b.a.path == ot.b.a.remote_source
+    assert o.b.b[0].remote_directory == ot.b.b[0].remote_directory
+    assert not ctx.file_access.is_remote(ot.b.b[0].path)
+    assert o.b.c["hello"].remote_directory == ot.b.c["hello"].remote_directory
+    assert not ctx.file_access.is_remote(ot.b.c["hello"].path)
+    assert o.b.d[0].path == ot.b.d[0].remote_source
+    assert o.b.e["hello"].path == ot.b.e["hello"].remote_source
+
+
+@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.")
+def test_structured_dataset_in_dataclass():
+    import pandas as pd
+    from pandas._testing import assert_frame_equal
+
+    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
+    People = Annotated[StructuredDataset, "parquet", kwtypes(Name=str, Age=int)]
+
+    @dataclass
+    class InnerDatasetStruct(DataClassJsonMixin):
+        a: StructuredDataset
+        b: typing.List[Annotated[StructuredDataset, "parquet"]]
+        c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]]
+
+    @dataclass
+    class DatasetStruct(DataClassJsonMixin):
+        a: People
+        b: InnerDatasetStruct
+
+    sd = StructuredDataset(dataframe=df, file_format="parquet")
+    o = DatasetStruct(a=sd, b=InnerDatasetStruct(a=sd, b=[sd], c={"hello": sd}))
+
+    ctx = FlyteContext.current_context()
+    tf = DataclassTransformer()
+    lt = tf.get_literal_type(DatasetStruct)
+    lv = tf.to_literal(ctx, o, DatasetStruct, lt)
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=DatasetStruct)
+
+    assert_frame_equal(df, ot.a.open(pd.DataFrame).all())
+    assert_frame_equal(df, ot.b.a.open(pd.DataFrame).all())
+    assert_frame_equal(df, ot.b.b[0].open(pd.DataFrame).all())
+    assert_frame_equal(df, ot.b.c["hello"].open(pd.DataFrame).all())
+    assert "parquet" == ot.a.file_format
+    assert "parquet" == ot.b.a.file_format
+    assert "parquet" == ot.b.b[0].file_format
+    assert "parquet" == ot.b.c["hello"].file_format
+
+
+@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.")
+def test_structured_dataset_in_dataclassjsonmixin():
+    @dataclass
+    class InnerDatasetStructDataclassJsonMixin(DataClassJSONMixin):
+        a: StructuredDataset
+        b: typing.List[Annotated[StructuredDataset, "parquet"]]
+        c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]]
+
+    import pandas as pd
+    from pandas._testing import assert_frame_equal
+
+    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
+    People = Annotated[StructuredDataset, "parquet"]
+
+    @dataclass
+    class DatasetStruct_dataclassjsonmixin(DataClassJSONMixin):
+        a: People
+        b: InnerDatasetStructDataclassJsonMixin
+
+    sd = StructuredDataset(dataframe=df, file_format="parquet")
+    o = DatasetStruct_dataclassjsonmixin(a=sd, b=InnerDatasetStructDataclassJsonMixin(a=sd, b=[sd], c={"hello": sd}))
+
+    ctx = FlyteContext.current_context()
+    tf = DataclassTransformer()
+    lt = tf.get_literal_type(DatasetStruct_dataclassjsonmixin)
+    lv = tf.to_literal(ctx, o, DatasetStruct_dataclassjsonmixin, lt)
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=DatasetStruct_dataclassjsonmixin)
+
+    assert_frame_equal(df, ot.a.open(pd.DataFrame).all())
+    assert_frame_equal(df, ot.b.a.open(pd.DataFrame).all())
+    assert_frame_equal(df, ot.b.b[0].open(pd.DataFrame).all())
+    assert_frame_equal(df, ot.b.c["hello"].open(pd.DataFrame).all())
+    assert "parquet" == ot.a.file_format
+    assert "parquet" == ot.b.a.file_format
+    assert "parquet" == ot.b.b[0].file_format
+    assert "parquet" == ot.b.c["hello"].file_format
ot.b.c["hello"].file_format + + +# Enums should have string values +class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + +class MultiInheritanceColor(str, Enum): + RED = auto() + GREEN = auto() + BLUE = auto() + + +# Enums with integer values are not supported +class UnsupportedEnumValues(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + +@pytest.mark.skipif("polars" not in sys.modules, reason="pyarrow is not installed.") +@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") +def test_structured_dataset_type(): + import pandas as pd + import pyarrow as pa + from pandas._testing import assert_frame_equal + + name = "Name" + age = "Age" + data = {name: ["Tom", "Joseph"], age: [20, 22]} + superset_cols = kwtypes(Name=str, Age=int) + subset_cols = kwtypes(Name=str) + df = pd.DataFrame(data) + + tf = TypeEngine.get_transformer(StructuredDataset) + lt = tf.get_literal_type(Annotated[StructuredDataset, superset_cols, "parquet"]) + assert lt.structured_dataset_type is not None + + ctx = FlyteContextManager.current_context() + lv = tf.to_literal(ctx, df, pd.DataFrame, lt) + assert flyte_tmp_dir in lv.scalar.structured_dataset.uri + metadata = lv.scalar.structured_dataset.metadata + assert metadata.structured_dataset_type.format == "parquet" + v1 = tf.to_python_value(ctx, lv, pd.DataFrame) + v2 = tf.to_python_value(ctx, lv, pa.Table) + assert_frame_equal(df, v1) + assert_frame_equal(df, v2.to_pandas()) + + subset_lt = tf.get_literal_type(Annotated[StructuredDataset, subset_cols, "parquet"]) + assert subset_lt.structured_dataset_type is not None + + subset_lv = tf.to_literal(ctx, df, pd.DataFrame, subset_lt) + assert flyte_tmp_dir in subset_lv.scalar.structured_dataset.uri + v1 = tf.to_python_value(ctx, subset_lv, pd.DataFrame) + v2 = tf.to_python_value(ctx, subset_lv, pa.Table) + subset_data = pd.DataFrame({name: ["Tom", "Joseph"]}) + assert_frame_equal(subset_data, v1) + assert_frame_equal(subset_data, v2.to_pandas()) + + empty_lt = tf.get_literal_type(Annotated[StructuredDataset, "parquet"]) + assert empty_lt.structured_dataset_type is not None + empty_lv = tf.to_literal(ctx, df, pd.DataFrame, empty_lt) + v1 = tf.to_python_value(ctx, empty_lv, pd.DataFrame) + v2 = tf.to_python_value(ctx, empty_lv, pa.Table) + assert_frame_equal(df, v1) + assert_frame_equal(df, v2.to_pandas()) + + +def test_enum_type(): + t = TypeEngine.to_literal_type(Color) + assert t is not None + assert t.enum_type is not None + assert t.enum_type.values + assert t.enum_type.values == [c.value for c in Color] + + g = TypeEngine.guess_python_type(t) + assert [e.value for e in g] == [e.value for e in Color] + + ctx = FlyteContextManager.current_context() + lv = TypeEngine.to_literal(ctx, Color.RED, Color, TypeEngine.to_literal_type(Color)) + assert lv + assert lv.scalar + assert lv.scalar.primitive.string_value == "red" + + v = TypeEngine.to_python_value(ctx, lv, Color) + assert v + assert v == Color.RED + + v = TypeEngine.to_python_value(ctx, lv, str) + assert v + assert v == "red" + + with pytest.raises(ValueError): + TypeEngine.to_python_value( + ctx, + Literal(scalar=Scalar(primitive=Primitive(string_value=str(Color.RED)))), + Color, + ) + + with pytest.raises(ValueError): + TypeEngine.to_python_value(ctx, Literal(scalar=Scalar(primitive=Primitive(string_value="bad"))), Color) + + with pytest.raises(AssertionError): + TypeEngine.to_literal_type(UnsupportedEnumValues) + + +def test_multi_inheritance_enum_type(): + tfm = TypeEngine.get_transformer(MultiInheritanceColor) + assert 
+
+
+def union_type_tags_unique(t: LiteralType):
+    seen = set()
+    for x in t.union_type.variants:
+        if x.structure.tag in seen:
+            return False
+        seen.add(x.structure.tag)
+
+    return True
+
+
+@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
+def test_union_type():
+    pt = typing.Union[str, int]
+    lt = TypeEngine.to_literal_type(pt)
+    pt_604 = str | int
+    lt_604 = TypeEngine.to_literal_type(pt_604)
+    assert lt == lt_604
+    assert lt.union_type.variants == [
+        LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")),
+        LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
+    ]
+    assert union_type_tags_unique(lt)
+
+    ctx = FlyteContextManager.current_context()
+    lv = TypeEngine.to_literal(ctx, 3, pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert lv.scalar.union.stored_type.structure.tag == "int"
+    assert lv.scalar.union.value.scalar.primitive.integer == 3
+    assert v == 3
+
+    lv = TypeEngine.to_literal(ctx, "hello", pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert lv.scalar.union.stored_type.structure.tag == "str"
+    assert lv.scalar.union.value.scalar.primitive.string_value == "hello"
+    assert v == "hello"
+
+
+def test_assert_dataclass_type():
+    @dataclass
+    class Args(DataClassJsonMixin):
+        x: int
+        y: typing.Optional[str]
+
+    @dataclass
+    class Schema(DataClassJsonMixin):
+        x: typing.Optional[Args] = None
+
+    pt = Schema
+    lt = TypeEngine.to_literal_type(pt)
+    gt = TypeEngine.guess_python_type(lt)
+    pv = Schema(x=Args(x=3, y="hello"))
+    DataclassTransformer().assert_type(gt, pv)
+    DataclassTransformer().assert_type(Schema, pv)
+
+    @dataclass
+    class Bar(DataClassJsonMixin):
+        x: int
+
+    pv = Bar(x=3)
+    with pytest.raises(
+        TypeTransformerFailedError,
+        match="Type of Val '' is not an instance of ",
+    ):
+        DataclassTransformer().assert_type(gt, pv)
+
+
+@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.")
+def test_assert_dict_type():
+    import pandas as pd
+
+    @dataclass
+    class AnotherDataClass(DataClassJsonMixin):
+        z: int
+
+    @dataclass
+    class Args(DataClassJsonMixin):
+        x: int
+        y: typing.Optional[str]
+        file: FlyteFile
+        dataset: StructuredDataset
+        another_dataclass: AnotherDataClass
+
+    pv = tempfile.mkdtemp(prefix="flyte-")
+    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
+    sd = StructuredDataset(dataframe=df, file_format="parquet")
+    # Test when v is a dict
+    vd = {
+        "x": 3,
+        "y": "hello",
+        "file": FlyteFile(pv),
+        "dataset": sd,
+        "another_dataclass": {"z": 4},
+    }
+    DataclassTransformer().assert_type(Args, vd)
+
+    # Test when v is a dict but missing Optional keys and other keys from dataclass
+    md = {"x": 3, "file": FlyteFile(pv), "dataset": sd, "another_dataclass": {"z": 4}}
+    DataclassTransformer().assert_type(Args, md)
+
+    # Test when v is a dict but missing non-Optional keys from dataclass
+    md = {
+        "y": "hello",
+        "file": FlyteFile(pv),
+        "dataset": sd,
+        "another_dataclass": {"z": 4},
+    }
+    with pytest.raises(
+        TypeTransformerFailedError,
+        match=re.escape("The original fields are missing the following keys from the dataclass fields: ['x']"),
+    ):
+        DataclassTransformer().assert_type(Args, md)
+
+    # Test when v is a dict but has extra keys that are not in dataclass
+    ed = {
+        "x": 3,
+        "y": "hello",
+        "file": FlyteFile(pv),
+        "dataset": sd,
+        "another_dataclass": {"z": 4},
+        "z": "extra",
+    }
+    with pytest.raises(
+        TypeTransformerFailedError,
+        match=re.escape(
+            "The original fields have the following extra keys that are not in dataclass fields: ['z']"),
+    ):
+        DataclassTransformer().assert_type(Args, ed)
+
+    # Test when the type of value in the dict does not match the expected_type in the dataclass
+    td = {
+        "x": "3",
+        "y": "hello",
+        "file": FlyteFile(pv),
+        "dataset": sd,
+        "another_dataclass": {"z": 4},
+    }
+    with pytest.raises(
+        TypeTransformerFailedError,
+        match="Type of Val '' is not an instance of ",
+    ):
+        DataclassTransformer().assert_type(Args, td)
+
+
+def test_to_literal_dict():
+    @dataclass
+    class Args(DataClassJsonMixin):
+        x: int
+        y: typing.Optional[str]
+
+    ctx = FlyteContext.current_context()
+    python_type = Args
+    expected = TypeEngine.to_literal_type(python_type)
+
+    # Test when python_val is a dict
+    python_val = {"x": 3, "y": "hello"}
+    literal = DataclassTransformer().to_literal(ctx, python_val, python_type, expected)
+    literal_json = _json_format.MessageToJson(literal.scalar.generic)
+    assert json.loads(literal_json) == python_val
+
+    # Test when python_val is not a dict and not a dataclass
+    python_val = "not a dict or dataclass"
+    with pytest.raises(
+        TypeTransformerFailedError,
+        match="not of type @dataclass, only Dataclasses are supported for user defined datatypes in Flytekit",
+    ):
+        DataclassTransformer().to_literal(ctx, python_val, python_type, expected)
+
+
+@dataclass
+class ArgsAssert(DataClassJSONMixin):
+    x: int
+    y: typing.Optional[str]
+
+
+@dataclass
+class SchemaArgsAssert(DataClassJSONMixin):
+    x: typing.Optional[ArgsAssert]
+
+
+def test_assert_dataclassjsonmixin_type():
+    pt = SchemaArgsAssert
+    lt = TypeEngine.to_literal_type(pt)
+    gt = TypeEngine.guess_python_type(lt)
+    pv = SchemaArgsAssert(x=ArgsAssert(x=3, y="hello"))
+    DataclassTransformer().assert_type(gt, pv)
+    DataclassTransformer().assert_type(SchemaArgsAssert, pv)
+
+    @dataclass
+    class Bar(DataClassJSONMixin):
+        x: int
+
+    pv = Bar(x=3)
+    with pytest.raises(
+        TypeTransformerFailedError,
+        match="Type of Val '' is not an instance of ",
+    ):
+        DataclassTransformer().assert_type(gt, pv)
+
+
+@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
+def test_union_transformer():
+    assert UnionTransformer.is_optional_type(typing.Optional[int])
+    assert UnionTransformer.is_optional_type(int | None)
+    assert not UnionTransformer.is_optional_type(str)
+    assert UnionTransformer.get_sub_type_in_optional(typing.Optional[int]) == int
+    assert UnionTransformer.get_sub_type_in_optional(int | None) == int
+    assert not UnionTransformer.is_optional_type(typing.Union[int, str])
+    assert UnionTransformer.is_optional_type(typing.Union[int, None])
+
+
+def test_union_guess_type():
+    ut = UnionTransformer()
+    t = ut.guess_python_type(
+        LiteralType(
+            union_type=UnionType(
+                variants=[
+                    LiteralType(simple=SimpleType.STRING),
+                    LiteralType(simple=SimpleType.INTEGER),
+                ]
+            )
+        )
+    )
+    assert t == typing.Union[str, int]
+
+
+def test_union_type_with_annotated():
+    pt = typing.Union[
+        Annotated[str, FlyteAnnotation({"hello": "world"})],
+        Annotated[int, FlyteAnnotation({"test": 123})],
+    ]
+    lt = TypeEngine.to_literal_type(pt)
+    assert lt.union_type.variants == [
+        LiteralType(
+            simple=SimpleType.STRING,
+            structure=TypeStructure(tag="str"),
+            annotation=TypeAnnotation({"hello": "world"}),
+        ),
+        LiteralType(
+            simple=SimpleType.INTEGER,
+            structure=TypeStructure(tag="int"),
+            annotation=TypeAnnotation({"test": 123}),
+        ),
+    ]
+    assert union_type_tags_unique(lt)
+
+    ctx = FlyteContextManager.current_context()
+    lv = TypeEngine.to_literal(ctx, 3, pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert lv.scalar.union.stored_type.structure.tag == "int"
+    assert lv.scalar.union.value.scalar.primitive.integer == 3
+    assert v == 3
+
+    lv = TypeEngine.to_literal(ctx, "hello", pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert lv.scalar.union.stored_type.structure.tag == "str"
+    assert lv.scalar.union.value.scalar.primitive.string_value == "hello"
+    assert v == "hello"
+
+
+def test_annotated_union_type():
+    pt = Annotated[typing.Union[str, int], FlyteAnnotation({"hello": "world"})]
+    lt = TypeEngine.to_literal_type(pt)
+    assert lt.union_type.variants == [
+        LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")),
+        LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
+    ]
+    assert lt.annotation == TypeAnnotation({"hello": "world"})
+    assert union_type_tags_unique(lt)
+
+    ctx = FlyteContextManager.current_context()
+    lv = TypeEngine.to_literal(ctx, 3, pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert lv.scalar.union.stored_type.structure.tag == "int"
+    assert lv.scalar.union.value.scalar.primitive.integer == 3
+    assert v == 3
+
+    lv = TypeEngine.to_literal(ctx, "hello", pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert lv.scalar.union.stored_type.structure.tag == "str"
+    assert lv.scalar.union.value.scalar.primitive.string_value == "hello"
+    assert v == "hello"
+
+
+def test_union_type_simple():
+    pt = typing.Union[str, int]
+    lt = TypeEngine.to_literal_type(pt)
+    assert lt.union_type.variants == [
+        LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")),
+        LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
+    ]
+    ctx = FlyteContextManager.current_context()
+    lv = TypeEngine.to_literal(ctx, 3, pt, lt)
+    assert lv.scalar.union is not None
+    assert lv.scalar.union.stored_type.structure.tag == "int"
+    assert lv.scalar.union.stored_type.structure.dataclass_type is None
+
+
+def test_union_containers():
+    pt = typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]], int]
+    lt = TypeEngine.to_literal_type(pt)
+
+    list_of_maps_of_list_ints = [
+        {"first_map_a": [42], "first_map_b": [42, 2]},
+        {
+            "second_map_c": [33],
+            "second_map_d": [9, 99],
+        },
+    ]
+    map_of_list_ints = {
+        "ll_1": [1, 23, 3],
+        "ll_2": [4, 5, 6],
+    }
+    ctx = FlyteContextManager.current_context()
+    lv = TypeEngine.to_literal(ctx, list_of_maps_of_list_ints, pt, lt)
+    assert lv.scalar.union.stored_type.structure.tag == "Typed List"
+    lv = TypeEngine.to_literal(ctx, map_of_list_ints, pt, lt)
+    assert lv.scalar.union.stored_type.structure.tag == "Typed Dict"
+
+
+@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
+def test_optional_type():
+    pt = typing.Optional[int]
+    lt = TypeEngine.to_literal_type(pt)
+    pt_604 = int | None
+    lt_604 = TypeEngine.to_literal_type(pt_604)
+    assert lt == lt_604
+    assert lt.union_type.variants == [
+        LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
+        LiteralType(simple=SimpleType.NONE, structure=TypeStructure(tag="none")),
+    ]
+    assert union_type_tags_unique(lt)
+
+    ctx = FlyteContextManager.current_context()
+    lv = TypeEngine.to_literal(ctx, 3, pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert lv.scalar.union.stored_type.structure.tag == "int"
+    assert lv.scalar.union.value.scalar.primitive.integer == 3
+    assert v == 3
+
+    lv = TypeEngine.to_literal(ctx, None, pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert lv.scalar.union.stored_type.structure.tag == "none"
+    assert lv.scalar.union.value.scalar.none_type == Void()
+    assert v is None
+
+
+def test_union_from_unambiguous_literal():
+    pt = typing.Union[str, int]
+    lt = TypeEngine.to_literal_type(pt)
+    assert lt.union_type.variants == [
+        LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")),
+        LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
+    ]
+    assert union_type_tags_unique(lt)
+
+    ctx = FlyteContextManager.current_context()
+    lv = TypeEngine.to_literal(ctx, 3, int, lt)
+    assert lv.scalar.primitive.integer == 3
+
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert v == 3
+
+    pt = typing.Union[FlyteFile, FlyteDirectory]
+    temp_dir = tempfile.mkdtemp(prefix="temp_example_")
+    file_path = os.path.join(temp_dir, "file.txt")
+    with open(file_path, "w") as file1:
+        file1.write("hello world")
+
+    lt = TypeEngine.to_literal_type(FlyteFile)
+    lv = TypeEngine.to_literal(ctx, file_path, FlyteFile, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert isinstance(v, FlyteFile)
+    lv = TypeEngine.to_literal(ctx, v, FlyteFile, lt)
+    assert os.path.isfile(lv.scalar.blob.uri)
+
+    lt = TypeEngine.to_literal_type(FlyteDirectory)
+    lv = TypeEngine.to_literal(ctx, temp_dir, FlyteDirectory, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert isinstance(v, FlyteDirectory)
+    lv = TypeEngine.to_literal(ctx, v, FlyteDirectory, lt)
+    assert os.path.isdir(lv.scalar.blob.uri)
+
+
+def test_union_custom_transformer():
+    class MyInt:
+        def __init__(self, x: int):
+            self.val = x
+
+        def __eq__(self, other):
+            if not isinstance(other, MyInt):
+                return False
+            return other.val == self.val
+
+    TypeEngine.register(
+        SimpleTransformer(
+            "MyInt",
+            MyInt,
+            LiteralType(simple=SimpleType.INTEGER),
+            lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x.val))),
+            lambda x: MyInt(x.scalar.primitive.integer),
+        )
+    )
+
+    pt = typing.Union[int, MyInt]
+    lt = TypeEngine.to_literal_type(pt)
+    assert lt.union_type.variants == [
+        LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
+        LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="MyInt")),
+    ]
+    assert union_type_tags_unique(lt)
+
+    ctx = FlyteContextManager.current_context()
+    lv = TypeEngine.to_literal(ctx, 3, pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert lv.scalar.union.stored_type.structure.tag == "int"
+    assert lv.scalar.union.value.scalar.primitive.integer == 3
+    assert v == 3
+
+    lv = TypeEngine.to_literal(ctx, MyInt(10), pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert lv.scalar.union.stored_type.structure.tag == "MyInt"
+    assert lv.scalar.union.value.scalar.primitive.integer == 10
+    assert v == MyInt(10)
+
+    lv = TypeEngine.to_literal(ctx, 4, int, LiteralType(simple=SimpleType.INTEGER))
+    assert lv.scalar.primitive.integer == 4
+    try:
+        TypeEngine.to_python_value(ctx, lv, pt)
+    except TypeError as e:
+        assert "Ambiguous choice of variant" in str(e)
+
+    del TypeEngine._REGISTRY[MyInt]
+
+
+def test_union_custom_transformer_sanity_check():
+    class UnsignedInt:
+        def __init__(self, x: int):
+            self.val = x
+
+        def __eq__(self, other):
+            if not isinstance(other, UnsignedInt):
+                return False
+            return other.val == self.val
+
+    # This transformer will not work in the implicit wrapping case
+    class UnsignedIntTransformer(TypeTransformer[UnsignedInt]):
+        def __init__(self):
+            super().__init__("UnsignedInt", UnsignedInt)
+
+        def get_literal_type(self, t: typing.Type[T]) -> LiteralType:
+            return LiteralType(simple=SimpleType.INTEGER)
+
+        def to_literal(
+            self,
+            ctx: FlyteContext,
+            python_val: T,
+            python_type: typing.Type[T],
+            expected: LiteralType,
+        ) -> Literal:
+            if type(python_val) != int:
+                raise TypeTransformerFailedError("Expected an integer")
+
+            if python_val < 0:
+                raise TypeTransformerFailedError("Expected a non-negative integer")
+
+            return Literal(scalar=Scalar(primitive=Primitive(integer=python_val)))
+
+        def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[T]) -> Literal:
+            val = lv.scalar.primitive.integer
+            return UnsignedInt(0 if val < 0 else val)  # type: ignore
+
+    TypeEngine.register(UnsignedIntTransformer())
+
+    pt = typing.Union[int, UnsignedInt]
+    lt = TypeEngine.to_literal_type(pt)
+    assert lt.union_type.variants == [
+        LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
+        LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="UnsignedInt")),
+    ]
+    assert union_type_tags_unique(lt)
+
+    ctx = FlyteContextManager.current_context()
+    with pytest.raises(TypeError, match="Ambiguous choice of variant for union type"):
+        TypeEngine.to_literal(ctx, 3, pt, lt)
+
+    del TypeEngine._REGISTRY[UnsignedInt]
+
+
+def test_union_of_lists():
+    pt = typing.Union[typing.List[int], typing.List[str]]
+    lt = TypeEngine.to_literal_type(pt)
+    assert lt.union_type.variants == [
+        LiteralType(
+            collection_type=LiteralType(simple=SimpleType.INTEGER),
+            structure=TypeStructure(tag="Typed List"),
+        ),
+        LiteralType(
+            collection_type=LiteralType(simple=SimpleType.STRING),
+            structure=TypeStructure(tag="Typed List"),
+        ),
+    ]
+    # Tags are deliberately NOT unique because they are not required to encode the deep type structure,
+    # only the top-level type transformer choice
+    #
+    # The stored type will be used to differentiate union variants and must produce a unique choice.
+    assert not union_type_tags_unique(lt)
+
+    ctx = FlyteContextManager.current_context()
+    lv = TypeEngine.to_literal(ctx, ["hello", "world"], pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert lv.scalar.union.stored_type.structure.tag == "Typed List"
+    assert [x.scalar.primitive.string_value for x in lv.scalar.union.value.collection.literals] == ["hello", "world"]
+    assert v == ["hello", "world"]
+
+    lv = TypeEngine.to_literal(ctx, [1, 3], pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    assert lv.scalar.union.stored_type.structure.tag == "Typed List"
+    assert [x.scalar.primitive.integer for x in lv.scalar.union.value.collection.literals] == [1, 3]
+    assert v == [1, 3]
+
+
+@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
+def test_list_of_unions():
+    pt = typing.List[typing.Union[str, int]]
+    lt = TypeEngine.to_literal_type(pt)
+    pt_604 = typing.List[str | int]
+    lt_604 = TypeEngine.to_literal_type(pt_604)
+    assert lt == lt_604
+    # todo(maximsmol): seems like the order here is non-deterministic
+    assert lt.collection_type.union_type.variants == [
+        LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")),
+        LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
+    ]
+    assert union_type_tags_unique(lt.collection_type)  # tags are unique at this level
+
+    ctx = FlyteContextManager.current_context()
+    lv = TypeEngine.to_literal(ctx, ["hello", 123, "world"], pt, lt)
+    v = TypeEngine.to_python_value(ctx, lv, pt)
+    lv_604 = TypeEngine.to_literal(ctx, ["hello", 123, "world"], pt_604, lt_604)
+    v_604 = TypeEngine.to_python_value(ctx, lv_604, pt_604)
+    assert [x.scalar.union.stored_type.structure.tag for x in lv.collection.literals] == ["str", "int", "str"]
+    assert v == v_604 == ["hello", 123, "world"]
+
+
+def test_pickle_type():
+    class Foo(object):
+        def __init__(self, number: int):
+            self.number = number
+
+    lt = TypeEngine.to_literal_type(FlytePickle)
+    assert lt.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
+    assert lt.blob.dimensionality == BlobType.BlobDimensionality.SINGLE
+
+    ctx = FlyteContextManager.current_context()
+    lv = TypeEngine.to_literal(ctx, Foo(1), FlytePickle, lt)
+    assert flyte_tmp_dir in lv.scalar.blob.uri
+
+    transformer = FlytePickleTransformer()
+    gt = transformer.guess_python_type(lt)
+    pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
+    assert Foo(1).number == pv.number
+
+    with pytest.raises(AssertionError, match="Cannot pickle None Value"):
+        lt = TypeEngine.to_literal_type(typing.Optional[typing.Any])
+        TypeEngine.to_literal(ctx, None, FlytePickle, lt)
+
+    with pytest.raises(
+        AssertionError,
+        match="Expected value of type but got '1' of type ",
+    ):
+        lt = TypeEngine.to_literal_type(typing.Optional[typing.Any])
+        TypeEngine.to_literal(ctx, 1, type(None), lt)
+
+    lt = TypeEngine.to_literal_type(typing.Optional[typing.Any])
+    TypeEngine.to_literal(ctx, 1, typing.Optional[typing.Any], lt)
+
+
+def test_enum_in_dataclass():
+    @dataclass
+    class Datum(DataClassJsonMixin):
+        x: int
+        y: Color
+
+    lt = TypeEngine.to_literal_type(Datum)
+    schema = Datum.schema()
+    schema.fields["y"].load_by = LoadDumpOptions.name
+    assert lt.metadata == JSONSchema().dump(schema)
+
+    transformer = DataclassTransformer()
+    ctx = FlyteContext.current_context()
+    datum = Datum(5, Color.RED)
+    lv = transformer.to_literal(ctx, datum, Datum, lt)
+    gt = transformer.guess_python_type(lt)
+    pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
+    assert datum.x == pv.x
pv.x + assert datum.y.value == pv.y + + +def test_enum_in_dataclassjsonmixin(): + @dataclass + class Datum(DataClassJSONMixin): + x: int + y: Color + + lt = TypeEngine.to_literal_type(Datum) + from mashumaro.jsonschema import build_json_schema + + schema = build_json_schema(typing.cast(DataClassJSONMixin, Datum)).to_dict() + assert lt.metadata == schema + + transformer = DataclassTransformer() + ctx = FlyteContext.current_context() + datum = Datum(5, Color.RED) + lv = transformer.to_literal(ctx, datum, Datum, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum.x == pv.x + assert datum.y.value == pv.y + + +@pytest.mark.parametrize( + "python_value,python_types,expected_literal_map", + [ + ( + {"a": [1, 2, 3]}, + {"a": typing.List[int]}, + LiteralMap( + literals={ + "a": Literal( + collection=LiteralCollection( + literals=[ + Literal(scalar=Scalar(primitive=Primitive(integer=1))), + Literal(scalar=Scalar(primitive=Primitive(integer=2))), + Literal(scalar=Scalar(primitive=Primitive(integer=3))), + ] + ) + ) + } + ), + ), + ( + {"p1": {"k1": "v1", "k2": "2"}}, + {"p1": typing.Dict[str, str]}, + LiteralMap( + literals={ + "p1": Literal( + map=LiteralMap( + literals={ + "k1": Literal(scalar=Scalar(primitive=Primitive(string_value="v1"))), + "k2": Literal(scalar=Scalar(primitive=Primitive(string_value="2"))), + }, + ) + ) + } + ), + ), + ( + {"p1": "s3://tmp/file.jpeg"}, + {"p1": JPEGImageFile}, + LiteralMap( + literals={ + "p1": Literal( + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata( + type=BlobType( + format="jpeg", + dimensionality=BlobType.BlobDimensionality.SINGLE, + ) + ), + uri="s3://tmp/file.jpeg", + ) + ) + ) + } + ), + ), + ], +) +def test_dict_to_literal_map(python_value, python_types, expected_literal_map): + ctx = FlyteContext.current_context() + + assert TypeEngine.dict_to_literal_map(ctx, python_value, python_types) == expected_literal_map + + +def test_dict_to_literal_map_with_dataclass(): + @dataclass + class InnerStruct(DataClassJsonMixin): + a: int + b: typing.Optional[str] + c: typing.List[int] + + @dataclass + class TestStructD(DataClassJsonMixin): + s: InnerStruct + m: typing.Dict[str, typing.List[int]] + + ctx = FlyteContext.current_context() + python_value = {"p1": TestStructD(s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={"a": [5]})} + python_types = {"p1": TestStructD} + + literal = TypeEngine.to_literal(ctx, python_value["p1"], TestStructD, TypeEngine.to_literal_type(TestStructD)) + expected_literal_map = LiteralMap( + literals={ + "p1": literal + } + ) + assert TypeEngine.dict_to_literal_map(ctx, python_value, python_types) == expected_literal_map + + +def test_dict_to_literal_map_with_wrong_input_type(): + ctx = FlyteContext.current_context() + input = {"a": 1} + guessed_python_types = {"a": str} + with pytest.raises(user_exceptions.FlyteTypeException): + TypeEngine.dict_to_literal_map(ctx, input, guessed_python_types) + + +def test_nested_annotated(): + """ + Test to show that nested Annotated types are flattened. 
+ """ + pt = Annotated[Annotated[int, "inner-annotation"], "outer-annotation"] + lt = TypeEngine.to_literal_type(pt) + assert lt.simple == model_types.SimpleType.INTEGER + + ctx = FlyteContextManager.current_context() + lv = TypeEngine.to_literal(ctx, 42, pt, lt) + v = TypeEngine.to_python_value(ctx, lv, pt) + assert v == 42 + + +@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") +def test_pass_annotated_to_downstream_tasks(): + """ + Test to confirm that the loaded dataframe is not affected and can be used in @dynamic. + """ + import pandas as pd + + # pandas dataframe hash function + def hash_pandas_dataframe(df: pd.DataFrame) -> str: + return str(pd.util.hash_pandas_object(df)) + + @task + def t0(a: int) -> Annotated[int, HashMethod(function=str)]: + return a + 1 + + @task + def annotated_return_task() -> Annotated[pd.DataFrame, HashMethod(hash_pandas_dataframe)]: + return pd.DataFrame({"column_1": [1, 2, 3]}) + + @task(cache=True, cache_version="42") + def downstream_t(a: int, df: pd.DataFrame) -> int: + return a + 2 + len(df) + + @dynamic + def t1(a: int) -> int: + v = t0(a=a) + df = annotated_return_task() + + # We should have a cache miss in the first call to downstream_t + v_1 = downstream_t(a=v, df=df) + downstream_t(a=v, df=df) + + return v_1 + + assert t1(a=3) == 9 + + +def test_literal_hash_int_can_be_set(): + """ + Test to confirm that annotating an integer with `HashMethod` is allowed. + """ + ctx = FlyteContext.current_context() + lv = TypeEngine.to_literal( + ctx, + 42, + Annotated[int, HashMethod(str)], + LiteralType(simple=model_types.SimpleType.INTEGER), + ) + assert lv.scalar.primitive.integer == 42 + assert lv.hash == "42" + + +@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") +def test_literal_hash_to_python_value(): + """ + Test to confirm that literals can be converted to python values, regardless of the hash value set in the literal. 
+ """ + import pandas as pd + + from flytekit.types.schema.types_pandas import PandasDataFrameTransformer + + ctx = FlyteContext.current_context() + + def constant_hash(df: pd.DataFrame) -> str: + return "h4Sh" + + df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}) + pandas_df_transformer = PandasDataFrameTransformer() + literal_with_hash_set = TypeEngine.to_literal( + ctx, + df, + Annotated[pd.DataFrame, HashMethod(constant_hash)], + pandas_df_transformer.get_literal_type(pd.DataFrame), + ) + assert literal_with_hash_set.hash == "h4Sh" + # Confirm that the loaded dataframe is not affected + python_df = TypeEngine.to_python_value(ctx, literal_with_hash_set, pd.DataFrame) + expected_df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}) + assert expected_df.equals(python_df) + + +def test_annotated_simple_types(): + @dataclass + class InnerStruct(DataClassJsonMixin): + a: int + b: typing.Optional[str] + c: typing.List[int] + + def _check_annotation(t, annotation): + lt = TypeEngine.to_literal_type(t) + assert isinstance(lt.annotation, TypeAnnotation) + assert lt.annotation.annotations == annotation + + _check_annotation( + typing_extensions.Annotated[int, FlyteAnnotation({"foo": "bar"})], + {"foo": "bar"}, + ) + _check_annotation( + typing_extensions.Annotated[int, FlyteAnnotation(["foo", "bar"])], + ["foo", "bar"], + ) + _check_annotation( + typing_extensions.Annotated[int, FlyteAnnotation({"d": {"test": "data"}, "l": ["nested", ["list"]]})], + {"d": {"test": "data"}, "l": ["nested", ["list"]]}, + ) + _check_annotation( + typing_extensions.Annotated[int, FlyteAnnotation(InnerStruct(a=1, b="fizz", c=[1]))], + InnerStruct(a=1, b="fizz", c=[1]), + ) + + +def test_annotated_list(): + t = typing_extensions.Annotated[typing.List[int], FlyteAnnotation({"foo": "bar"})] + lt = TypeEngine.to_literal_type(t) + assert isinstance(lt.annotation, TypeAnnotation) + assert lt.annotation.annotations == {"foo": "bar"} + + t = typing.List[typing_extensions.Annotated[int, FlyteAnnotation({"foo": "bar"})]] + lt = TypeEngine.to_literal_type(t) + assert isinstance(lt.collection_type.annotation, TypeAnnotation) + assert lt.collection_type.annotation.annotations == {"foo": "bar"} + + +def test_type_alias(): + inner_t = typing_extensions.Annotated[int, FlyteAnnotation("foo")] + t = typing_extensions.Annotated[inner_t, FlyteAnnotation("bar")] + with pytest.raises(ValueError): + TypeEngine.to_literal_type(t) + + +def test_multiple_annotations(): + t = typing_extensions.Annotated[int, FlyteAnnotation({"foo": "bar"}), FlyteAnnotation({"anotha": "one"})] + with pytest.raises(Exception): + TypeEngine.to_literal_type(t) + + +TestSchema = FlyteSchema[kwtypes(some_str=str)] # type: ignore + + +@dataclass +class InnerResult(DataClassJsonMixin): + number: int + schema: TestSchema # type: ignore + + +@dataclass +class Result(DataClassJsonMixin): + result: InnerResult + schema: TestSchema # type: ignore + + +def get_unsupported_complex_literals_tests(): + if sys.version_info < (3, 9): + return [ + typing_extensions.Annotated[typing.Dict[int, str], FlyteAnnotation({"foo": "bar"})], + typing_extensions.Annotated[typing.Dict[str, str], FlyteAnnotation({"foo": "bar"})], + typing_extensions.Annotated[Color, FlyteAnnotation({"foo": "bar"})], + typing_extensions.Annotated[Result, FlyteAnnotation({"foo": "bar"})], + ] + return [ + typing_extensions.Annotated[dict, FlyteAnnotation({"foo": "bar"})], + typing_extensions.Annotated[dict[int, str], FlyteAnnotation({"foo": "bar"})], + typing_extensions.Annotated[typing.Dict[int, str], 
FlyteAnnotation({"foo": "bar"})], + typing_extensions.Annotated[dict[str, str], FlyteAnnotation({"foo": "bar"})], + typing_extensions.Annotated[typing.Dict[str, str], FlyteAnnotation({"foo": "bar"})], + typing_extensions.Annotated[Color, FlyteAnnotation({"foo": "bar"})], + typing_extensions.Annotated[Result, FlyteAnnotation({"foo": "bar"})], + ] + + +@pytest.mark.parametrize( + "t", + get_unsupported_complex_literals_tests(), +) +def test_unsupported_complex_literals(t): + with pytest.raises(ValueError): + TypeEngine.to_literal_type(t) + + +@dataclass +class DataclassTest(DataClassJsonMixin): + a: int + b: str + + +@dataclass +class AnnotatedDataclassTest(DataClassJsonMixin): + a: int + b: Annotated[str, "str tag"] + + +@pytest.mark.parametrize( + "t,expected_type", + [ + (dict, LiteralType(simple=SimpleType.STRUCT)), + # Annotations are not being copied over to the LiteralType + ( + typing_extensions.Annotated[dict, "a-tag"], + LiteralType(simple=SimpleType.STRUCT), + ), + (typing.Dict[int, str], LiteralType(simple=SimpleType.STRUCT)), + ( + typing.Dict[str, int], + LiteralType(map_value_type=LiteralType(simple=SimpleType.INTEGER)), + ), + ( + typing.Dict[str, str], + LiteralType(map_value_type=LiteralType(simple=SimpleType.STRING)), + ), + ( + typing.Dict[str, typing.List[int]], + LiteralType(map_value_type=LiteralType(collection_type=LiteralType(simple=SimpleType.INTEGER))), + ), + (typing.Dict[int, typing.List[int]], LiteralType(simple=SimpleType.STRUCT)), + ( + typing.Dict[int, typing.Dict[int, int]], + LiteralType(simple=SimpleType.STRUCT), + ), + ( + typing.Dict[str, typing.Dict[int, int]], + LiteralType(map_value_type=LiteralType(simple=SimpleType.STRUCT)), + ), + ( + typing.Dict[str, typing.Dict[str, int]], + LiteralType(map_value_type=LiteralType(map_value_type=LiteralType(simple=SimpleType.INTEGER))), + ), + ( + DataclassTest, + LiteralType( + simple=SimpleType.STRUCT, + metadata={ + "$schema": "http://json-schema.org/draft-07/schema#", + "definitions": { + "DataclasstestSchema": { + "properties": { + "a": {"title": "a", "type": "integer"}, + "b": {"title": "b", "type": "string"}, + }, + "type": "object", + "additionalProperties": False, + } + }, + "$ref": "#/definitions/DataclasstestSchema", + }, + structure=TypeStructure( + tag="", + dataclass_type={ + "a": LiteralType(simple=SimpleType.INTEGER), + "b": LiteralType(simple=SimpleType.STRING), + }, + ), + ), + ), + # Similar to the dict[int, str] case, the annotation is not being copied over to the LiteralType + ( + Annotated[DataclassTest, "another-tag"], + LiteralType( + simple=SimpleType.STRUCT, + metadata={ + "$schema": "http://json-schema.org/draft-07/schema#", + "definitions": { + "DataclasstestSchema": { + "properties": { + "a": {"title": "a", "type": "integer"}, + "b": {"title": "b", "type": "string"}, + }, + "type": "object", + "additionalProperties": False, + } + }, + "$ref": "#/definitions/DataclasstestSchema", + }, + structure=TypeStructure( + tag="", + dataclass_type={ + "a": LiteralType(simple=SimpleType.INTEGER), + "b": LiteralType(simple=SimpleType.STRING), + }, + ), + ), + ), + # Notice how the annotation in the field is not carried over either + ( + Annotated[AnnotatedDataclassTest, "another-tag"], + LiteralType( + simple=SimpleType.STRUCT, + metadata={ + "$schema": "http://json-schema.org/draft-07/schema#", + "definitions": { + "AnnotateddataclasstestSchema": { + "properties": { + "a": {"title": "a", "type": "integer"}, + "b": {"title": "b", "type": "string"}, + }, + "type": "object", + "additionalProperties": 
False, + } + }, + "$ref": "#/definitions/AnnotateddataclasstestSchema", + }, + structure=TypeStructure( + tag="", + dataclass_type={ + "a": LiteralType(simple=SimpleType.INTEGER), + "b": LiteralType(simple=SimpleType.STRING), + }, + ), + ), + ), + ], +) +def test_annotated_dicts(t, expected_type): + assert TypeEngine.to_literal_type(t) == expected_type + + +@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") +def test_schema_in_dataclass(): + import pandas as pd + + schema = TestSchema() + df = pd.DataFrame(data={"some_str": ["a", "b", "c"]}) + schema.open().write(df) + o = Result(result=InnerResult(number=1, schema=schema), schema=schema) + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(Result) + lv = tf.to_literal(ctx, o, Result, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=Result) + + assert o == ot + assert o.result.schema.remote_path == ot.result.schema.remote_path + assert o.result.number == ot.result.number + assert o.schema.remote_path == ot.schema.remote_path + + +@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") +def test_union_in_dataclass(): + import pandas as pd + + schema = TestSchema() + df = pd.DataFrame(data={"some_str": ["a", "b", "c"]}) + schema.open().write(df) + o = Result(result=InnerResult(number=1, schema=schema), schema=schema) + ctx = FlyteContext.current_context() + tf = UnionTransformer() + pt = typing.Union[Result, InnerResult] + lt = tf.get_literal_type(pt) + lv = tf.to_literal(ctx, o, pt, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=pt) + + assert o == ot + assert o.result.schema.remote_path == ot.result.schema.remote_path + assert o.result.number == ot.result.number + assert o.schema.remote_path == ot.schema.remote_path + + +@dataclass +class InnerResult_dataclassjsonmixin(DataClassJSONMixin): + number: int + schema: TestSchema # type: ignore + + +@dataclass +class Result_dataclassjsonmixin(DataClassJSONMixin): + result: InnerResult_dataclassjsonmixin + schema: TestSchema # type: ignore + + +@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") +def test_schema_in_dataclassjsonmixin(): + import pandas as pd + + schema = TestSchema() + df = pd.DataFrame(data={"some_str": ["a", "b", "c"]}) + schema.open().write(df) + o = Result_dataclassjsonmixin(result=InnerResult_dataclassjsonmixin(number=1, schema=schema), schema=schema) + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(Result_dataclassjsonmixin) + lv = tf.to_literal(ctx, o, Result_dataclassjsonmixin, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=Result_dataclassjsonmixin) + + assert o == ot + assert o.result.schema.remote_path == ot.result.schema.remote_path + assert o.result.number == ot.result.number + assert o.schema.remote_path == ot.schema.remote_path + + +def test_guess_of_dataclass(): + @dataclass + class Foo(DataClassJsonMixin): + x: int + y: str + z: typing.Dict[str, int] + + def hello(self): + ... + + lt = TypeEngine.to_literal_type(Foo) + foo = Foo(1, "hello", {"world": 3}) + lv = TypeEngine.to_literal(FlyteContext.current_context(), foo, Foo, lt) + lit_dict = {"a": lv} + lr = LiteralsResolver(lit_dict) + assert lr.get("a", Foo) == foo + assert hasattr(lr.get("a", Foo), "hello") is True + + +def test_guess_of_dataclassjsonmixin(): + @dataclass + class Foo(DataClassJSONMixin): + x: int + y: str + z: typing.Dict[str, int] + + def hello(self): + ...
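+ # hello() exists so the test below can check that methods survive the
+ # round-trip through LiteralsResolver.get().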
+ + lt = TypeEngine.to_literal_type(Foo) + foo = Foo(1, "hello", {"world": 3}) + lv = TypeEngine.to_literal(FlyteContext.current_context(), foo, Foo, lt) + lit_dict = {"a": lv} + lr = LiteralsResolver(lit_dict) + assert lr.get("a", Foo) == foo + assert hasattr(lr.get("a", Foo), "hello") is True + + +def test_flyte_dir_in_union(): + pt = typing.Union[str, FlyteDirectory, FlyteFile] + lt = TypeEngine.to_literal_type(pt) + ctx = FlyteContext.current_context() + tf = UnionTransformer() + + pv = tempfile.mkdtemp(prefix="flyte-") + lv = tf.to_literal(ctx, FlyteDirectory(pv), pt, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=pt) + assert ot is not None + + pv = "s3://bucket/key" + lv = tf.to_literal(ctx, FlyteFile(pv), pt, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=pt) + assert ot is not None + + pv = "hello" + lv = tf.to_literal(ctx, pv, pt, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=pt) + assert ot == "hello" + + +def test_file_ext_with_flyte_file_existing_file(): + assert JPEGImageFile.extension() == "jpeg" + + +def test_file_ext_convert_static_method(): + TAR_GZ = Annotated[str, FileExt("tar.gz")] + item = FileExt.check_and_convert_to_str(TAR_GZ) + assert item == "tar.gz" + + str_item = FileExt.check_and_convert_to_str("csv") + assert str_item == "csv" + + +def test_file_ext_with_flyte_file_new_file(): + TAR_GZ = Annotated[str, FileExt("tar.gz")] + flyte_file = FlyteFile[TAR_GZ] + assert flyte_file.extension() == "tar.gz" + + +class WrongType: + def __init__(self, num: int): + self.num = num + + +def test_file_ext_with_flyte_file_wrong_type(): + WRONG_TYPE = Annotated[int, WrongType(2)] + with pytest.raises(ValueError) as e: + FlyteFile[WRONG_TYPE] + assert str(e.value) == "Underlying type of File Extension must be of type <str>" + + +@pytest.mark.parametrize( + "t,expected", + [ + (list, False), + (Annotated[int, "tag"], True), + (Annotated[typing.List[str], "a", "b"], True), + (Annotated[typing.Dict[int, str], FlyteAnnotation({"foo": "bar"})], True), + ], +) +def test_is_annotated(t, expected): + assert is_annotated(t) == expected + + +@pytest.mark.parametrize( + "t,expected", + [ + (typing.List, typing.List), + (Annotated[int, "tag"], int), + (Annotated[typing.List[str], "a", "b"], typing.List[str]), + ], +) +def test_get_underlying_type(t, expected): + assert get_underlying_type(t) == expected + + +@pytest.mark.parametrize( + "t,expected", + [ + (None, (None, None)), + (typing.Dict, ()), + (typing.Dict[str, str], (str, str)), + ( + Annotated[typing.Dict[str, str], kwtypes(allow_pickle=True)], + (typing.Dict[str, str], kwtypes(allow_pickle=True)), + ), + (typing.Dict[Annotated[str, "a-tag"], int], (Annotated[str, "a-tag"], int)), + ], +) +def test_dict_get(t, expected): + assert DictTransformer.extract_types_or_metadata(t) == expected + + +def test_DataclassTransformer_get_literal_type(): + @dataclass + class MyDataClassMashumaro(DataClassJsonMixin): + x: int + + @dataclass + class MyDataClassMashumaroORJSON(DataClassORJSONMixin): + x: int + + @dataclass_json + @dataclass + class MyDataClass: + x: int + + de = DataclassTransformer() + + literal_type = de.get_literal_type(MyDataClass) + assert literal_type is not None + + literal_type = de.get_literal_type(MyDataClassMashumaro) + assert literal_type is not None + + literal_type = de.get_literal_type(MyDataClassMashumaroORJSON) + assert literal_type is not None + + invalid_json_str = "{ unbalanced_braces" + + with pytest.raises(Exception): +
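# google.protobuf's json_format.Parse raises ParseError on malformed JSON,
+ # so constructing the Struct-backed Literal below is expected to fail.
+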
Literal(scalar=Scalar(generic=_json_format.Parse(invalid_json_str, _struct.Struct()))) + + @dataclass + class Fruit(DataClassJSONMixin): + name: str + + @dataclass + class NestedFruit(DataClassJSONMixin): + sub_fruit: Fruit + name: str + + literal_type = de.get_literal_type(NestedFruit) + dataclass_type = literal_type.structure.dataclass_type + assert dataclass_type["sub_fruit"].simple == SimpleType.STRUCT + assert dataclass_type["sub_fruit"].structure.dataclass_type["name"].simple == SimpleType.STRING + assert dataclass_type["name"].simple == SimpleType.STRING + + +def test_DataclassTransformer_to_literal(): + @dataclass + class MyDataClassMashumaro(DataClassJsonMixin): + x: int + + @dataclass + class MyDataClassMashumaroORJSON(DataClassORJSONMixin): + x: int + + @dataclass_json + @dataclass + class MyDataClass: + x: int + + transformer = DataclassTransformer() + ctx = FlyteContext.current_context() + + my_dat_class_mashumaro = MyDataClassMashumaro(5) + my_dat_class_mashumaro_orjson = MyDataClassMashumaroORJSON(5) + my_data_class = MyDataClass(5) + + lv_mashumaro = transformer.to_literal(ctx, my_dat_class_mashumaro, MyDataClassMashumaro, MyDataClassMashumaro) + assert lv_mashumaro is not None + assert lv_mashumaro.scalar.generic["x"] == 5 + + lv_mashumaro_orjson = transformer.to_literal( + ctx, + my_dat_class_mashumaro_orjson, + MyDataClassMashumaroORJSON, + MyDataClassMashumaroORJSON, + ) + assert lv_mashumaro_orjson is not None + assert lv_mashumaro_orjson.scalar.generic["x"] == 5 + + lv = transformer.to_literal(ctx, my_data_class, MyDataClass, MyDataClass) + assert lv is not None + assert lv.scalar.generic["x"] == 5 + + + +def test_DataclassTransformer_to_python_value(): + @dataclass + class MyDataClassMashumaro(DataClassJsonMixin): + x: int + + @dataclass + class MyDataClassMashumaroORJSON(DataClassORJSONMixin): + x: int + + @dataclass_json + @dataclass + class MyDataClass: + x: int + + de = DataclassTransformer() + + json_str = '{ "x" : 5 }' + mock_literal = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) + + result = de.to_python_value(FlyteContext.current_context(), mock_literal, MyDataClass) + assert isinstance(result, MyDataClass) + assert result.x == 5 + + result = de.to_python_value(FlyteContext.current_context(), mock_literal, MyDataClassMashumaro) + assert isinstance(result, MyDataClassMashumaro) + assert result.x == 5 + + result = de.to_python_value(FlyteContext.current_context(), mock_literal, MyDataClassMashumaroORJSON) + assert isinstance(result, MyDataClassMashumaroORJSON) + assert result.x == 5 + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="dataclass(kw_only=True) requires >=3.10.") +def test_DataclassTransformer_with_discriminated_subtypes(): + class SubclassTypes(str, Enum): + BASE = auto() + CLASS_A = auto() + CLASS_B = auto() + + @dataclass(kw_only=True) + class BaseClass(DataClassJSONMixin): + class Config(BaseConfig): + discriminator = Discriminator( + field="subclass_type", + include_subtypes=True, + ) + + subclass_type: SubclassTypes = SubclassTypes.BASE + base_attribute: int + + @dataclass(kw_only=True) + class ClassA(BaseClass): + subclass_type: SubclassTypes = SubclassTypes.CLASS_A + class_a_attribute: str + + @dataclass(kw_only=True) + class ClassB(BaseClass): + subclass_type: SubclassTypes = SubclassTypes.CLASS_B + class_b_attribute: float + + @task + def assert_class_and_return(instance: BaseClass) -> BaseClass: + assert hasattr(instance, "class_a_attribute") or hasattr(instance, "class_b_attribute") + return 
instance + + class_a = ClassA(base_attribute=4, class_a_attribute="hello") + assert "class_a_attribute" in class_a.to_json() + res_1 = assert_class_and_return(class_a) + assert res_1.base_attribute == 4 + assert isinstance(res_1, ClassA) + assert res_1.class_a_attribute == "hello" + + class_b = ClassB(base_attribute=4, class_b_attribute=-2.5) + assert "class_b_attribute" in class_b.to_json() + res_2 = assert_class_and_return(class_b) + assert res_2.base_attribute == 4 + assert isinstance(res_2, ClassB) + assert res_2.class_b_attribute == -2.5 + + +def test_DataclassTransformer_with_sub_dataclasses(): + @dataclass + class Base: + a: int + + @dataclass + class Child1(Base): + b: int + + @task + def get_data() -> Child1: + return Child1(a=10, b=12) + + @task + def read_data(base: Base) -> int: + return base.a + + @task + def read_child(child: Child1) -> int: + return child.b + + @workflow + def wf1() -> Child1: + data = get_data() + read_data(base=data) + read_child(child=data) + return data + + @workflow + def wf2() -> Base: + data = Base(a=10) + read_data(base=data) + read_child(child=data) + return data + + @workflow + def wf3() -> Base: + data = Base(a=10) + read_data(base=data) + return data + + child_data = wf1() + assert child_data.a == 10 + assert child_data.b == 12 + assert isinstance(child_data, Child1) + + with pytest.raises(AttributeError): + wf2() + + base_data = wf3() + assert base_data.a == 10 + + +def test_DataclassTransformer_guess_python_type(): + @dataclass + class DatumMashumaroORJSON(DataClassORJSONMixin): + x: int + y: Color + z: datetime.datetime + + @dataclass + class DatumMashumaro(DataClassJSONMixin): + x: int + y: Color + + @dataclass_json + @dataclass + class DatumDataclassJson(DataClassJSONMixin): + x: int + y: Color + + @dataclass + class DatumDataclass: + x: int + y: Color + + @dataclass + class DatumDataUnion: + data: typing.Union[str, float] + + transformer = TypeEngine.get_transformer(DatumDataUnion) + ctx = FlyteContext.current_context() + + lt = TypeEngine.to_literal_type(DatumDataUnion) + datum_dataunion = DatumDataUnion(data="s3://my-file") + lv = transformer.to_literal(ctx, datum_dataunion, DatumDataUnion, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=DatumDataUnion) + assert datum_dataunion.data == pv.data + + datum_dataunion = DatumDataUnion(data="0.123") + lv = transformer.to_literal(ctx, datum_dataunion, DatumDataUnion, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum_dataunion.data == pv.data + + lt = TypeEngine.to_literal_type(DatumDataclass) + datum_dataclass = DatumDataclass(5, Color.RED) + lv = transformer.to_literal(ctx, datum_dataclass, DatumDataclass, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum_dataclass.x == pv.x + assert datum_dataclass.y.value == pv.y + + lt = TypeEngine.to_literal_type(DatumDataclassJson) + datum = DatumDataclassJson(5, Color.RED) + lv = transformer.to_literal(ctx, datum, DatumDataclassJson, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum.x == pv.x + assert datum.y.value == pv.y + + lt = TypeEngine.to_literal_type(DatumMashumaro) + datum_mashumaro = DatumMashumaro(5, Color.RED) + lv = transformer.to_literal(ctx, datum_mashumaro, DatumMashumaro, lt) + gt = transformer.guess_python_type(lt) + pv = 
transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum_mashumaro.x == pv.x + assert datum_mashumaro.y.value == pv.y + + lt = TypeEngine.to_literal_type(DatumMashumaroORJSON) + now = datetime.datetime.now() + datum_mashumaro_orjson = DatumMashumaroORJSON(5, Color.RED, now) + lv = transformer.to_literal(ctx, datum_mashumaro_orjson, DatumMashumaroORJSON, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum_mashumaro_orjson.x == pv.x + assert datum_mashumaro_orjson.y.value == pv.y + assert datum_mashumaro_orjson.z.isoformat() == pv.z + + +def test_dataclass_encoder_and_decoder_registry(): + iterations = 10 + + @dataclass + class Datum: + x: int + y: str + z: typing.Dict[int, int] + w: List[int] + + @task + def create_dataclasses() -> List[Datum]: + return [Datum(x=1, y="1", z={1: 1}, w=[1, 1, 1, 1])] + + @task + def concat_dataclasses(x: List[Datum], y: List[Datum]) -> List[Datum]: + return x + y + + @dynamic + def dynamic_wf() -> List[Datum]: + all_dataclasses: List[Datum] = [] + for _ in range(iterations): + data = create_dataclasses() + all_dataclasses = concat_dataclasses(x=all_dataclasses, y=data) + return all_dataclasses + + @workflow + def wf() -> List[Datum]: + return dynamic_wf() + + datum_list = wf() + assert len(datum_list) == iterations + + transformer = TypeEngine.get_transformer(Datum) + assert transformer._json_encoder.get(Datum) + assert transformer._json_decoder.get(Datum) + + +def test_ListTransformer_get_sub_type(): + assert ListTransformer.get_sub_type_or_none(typing.List[str]) is str + + +def test_ListTransformer_get_sub_type_as_none(): + assert ListTransformer.get_sub_type_or_none(type([])) is None + + +def test_union_file_directory(): + lt = TypeEngine.to_literal_type(FlyteFile) + s3_file = "s3://my-file" + + transformer = FlyteFilePathTransformer() + ctx = FlyteContext.current_context() + lv = transformer.to_literal(ctx, s3_file, FlyteFile, lt) + + union_trans = UnionTransformer() + pv = union_trans.to_python_value(ctx, lv, typing.Union[FlyteFile, FlyteDirectory]) + assert pv._remote_source == s3_file + + s3_dir = "s3://my-dir" + transformer = FlyteDirToMultipartBlobTransformer() + ctx = FlyteContext.current_context() + lv = transformer.to_literal(ctx, s3_dir, FlyteFile, lt) + + pv = union_trans.to_python_value(ctx, lv, typing.Union[FlyteFile, FlyteDirectory]) + assert pv._remote_source == s3_dir + + +@pytest.mark.parametrize( + "pt,pv", + [ + (bool, True), + (bool, False), + (int, 42), + (str, "hello"), + (Annotated[int, "tag"], 42), + (typing.List[int], [1, 2, 3]), + (typing.List[str], ["a", "b", "c"]), + (typing.List[Color], [Color.RED, Color.GREEN, Color.BLUE]), + (typing.List[Annotated[int, "tag"]], [1, 2, 3]), + (typing.List[Annotated[str, "tag"]], ["a", "b", "c"]), + (typing.Dict[int, str], {1: "a", 2: "b", 3: "c"}), + (typing.Dict[str, int], {"a": 1, "b": 2, "c": 3}), + (typing.Dict[str, typing.List[int]], {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), + (typing.Dict[str, typing.Dict[int, str]], {"a": {1: "a", 2: "b", 3: "c"}, "b": {4: "d", 5: "e", 6: "f"}}), + (typing.Union[int, str], 42), + (typing.Union[int, str], "hello"), + (typing.Union[typing.List[int], typing.List[str]], [1, 2, 3]), + (typing.Union[typing.List[int], typing.List[str]], ["a", "b", "c"]), + (typing.Union[typing.List[int], str], [1, 2, 3]), + (typing.Union[typing.List[int], str], "hello"), + ((typing.Union[dict, str]), {"a": 1, "b": 2, "c": 3}), + ((typing.Union[dict, str]), "hello"), + ], 
+) +def test_offloaded_literal(tmp_path, pt, pv): + ctx = FlyteContext.current_context() + + lt = TypeEngine.to_literal_type(pt) + to_be_offloaded_lv = TypeEngine.to_literal(ctx, pv, pt, lt) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(to_be_offloaded_lv.to_flyte_idl().SerializeToString()) + + literal = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ), + ) + + loaded_pv = TypeEngine.to_python_value(ctx, literal, pt) + assert loaded_pv == pv + + +def test_offloaded_literal_with_inferred_type(): + ctx = FlyteContext.current_context() + lt = TypeEngine.to_literal_type(str) + offloaded_literal_missing_uri = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + inferred_type=lt, + ), + ) + with pytest.raises(AssertionError): + TypeEngine.to_python_value(ctx, offloaded_literal_missing_uri, str) + + +def test_offloaded_literal_dataclass(tmp_path): + @dataclass + class InnerDatum(DataClassJsonMixin): + x: int + y: str + + @dataclass + class Datum(DataClassJsonMixin): + inner: InnerDatum + x: int + y: str + z: typing.Dict[int, int] + w: List[int] + + datum = Datum( + inner=InnerDatum(x=1, y="1"), + x=1, + y="1", + z={1: 1}, + w=[1, 1, 1, 1], + ) + + ctx = FlyteContext.current_context() + lt = TypeEngine.to_literal_type(Datum) + to_be_offloaded_lv = TypeEngine.to_literal(ctx, datum, Datum, lt) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(to_be_offloaded_lv.to_flyte_idl().SerializeToString()) + + literal = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ), + ) + + loaded_datum = TypeEngine.to_python_value(ctx, literal, Datum) + assert loaded_datum == datum + + +def test_offloaded_literal_flytefile(tmp_path): + ctx = FlyteContext.current_context() + lt = TypeEngine.to_literal_type(FlyteFile) + to_be_offloaded_lv = TypeEngine.to_literal(ctx, "s3://my-file", FlyteFile, lt) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(to_be_offloaded_lv.to_flyte_idl().SerializeToString()) + + literal = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ), + ) + + loaded_pv = TypeEngine.to_python_value(ctx, literal, FlyteFile) + assert loaded_pv._remote_source == "s3://my-file" + + +def test_offloaded_literal_flytedirectory(tmp_path): + ctx = FlyteContext.current_context() + lt = TypeEngine.to_literal_type(FlyteDirectory) + to_be_offloaded_lv = TypeEngine.to_literal(ctx, "s3://my-dir", FlyteDirectory, lt) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(to_be_offloaded_lv.to_flyte_idl().SerializeToString()) + + literal = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ), + ) + + loaded_pv: FlyteDirectory = TypeEngine.to_python_value(ctx, literal, FlyteDirectory) + assert loaded_pv._remote_source == "s3://my-dir" + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") +def test_dataclass_none_output_input_deserialization(): + @dataclass + class OuterWorkflowInput(DataClassJSONMixin): + input: float + + @dataclass + class OuterWorkflowOutput(DataClassJSONMixin): + nullable_output: float | None = None + + @dataclass + class 
InnerWorkflowInput(DataClassJSONMixin): + input: float + + @dataclass + class InnerWorkflowOutput(DataClassJSONMixin): + nullable_output: float | None = None + + @task + def inner_task(input: float) -> float | None: + if input == 0.0: + return None + return input + + @task + def wrap_inner_inputs(input: float) -> InnerWorkflowInput: + return InnerWorkflowInput(input=input) + + @task + def wrap_inner_outputs(output: float | None) -> InnerWorkflowOutput: + return InnerWorkflowOutput(nullable_output=output) + + @task + def wrap_outer_outputs(output: float | None) -> OuterWorkflowOutput: + return OuterWorkflowOutput(nullable_output=output) + + @workflow + def inner_workflow(input: InnerWorkflowInput) -> InnerWorkflowOutput: + return wrap_inner_outputs( + output=inner_task( + input=input.input + ) + ) + + @workflow + def outer_workflow(input: OuterWorkflowInput) -> OuterWorkflowOutput: + inner_outputs = inner_workflow( + input=wrap_inner_inputs(input=input.input) + ) + return wrap_outer_outputs( + output=inner_outputs.nullable_output + ) + + float_value_output = outer_workflow(OuterWorkflowInput(input=1.0)).nullable_output + assert float_value_output == 1.0, f"Float value was {float_value_output}, not 1.0 as expected" + none_value_output = outer_workflow(OuterWorkflowInput(input=0.0)).nullable_output + assert none_value_output is None, f"None value was {none_value_output}, not None as expected" + + +@pytest.mark.serial +def test_lazy_import_transformers_concurrently(): + # Ensure that next call to TypeEngine.lazy_import_transformers doesn't skip the import. Mark as serial to ensure + # this achieves what we expect. + TypeEngine.has_lazy_import = False + + # Configure the mocks similar to https://stackoverflow.com/questions/29749193/python-unit-testing-with-two-mock-objects-how-to-verify-call-order + after_import_mock, mock_register = mock.Mock(), mock.Mock() + mock_wrapper = mock.Mock() + mock_wrapper.mock_register = mock_register + mock_wrapper.after_import_mock = after_import_mock + + with mock.patch.object(StructuredDatasetTransformerEngine, "register", new=mock_register): + def run(): + TypeEngine.lazy_import_transformers() + after_import_mock() + + N = 5 + with ThreadPoolExecutor(max_workers=N) as executor: + futures = [executor.submit(run) for _ in range(N)] + [f.result() for f in futures] + + # Assert that all the register calls come before anything else. 
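+ # mock_calls records every child call of the wrapper in order: the last N entries are
+ # the after_import_mock() markers (one per thread), and everything before them must be
+ # register(...) calls, proving the lazy imports completed before any thread returned.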
+ assert mock_wrapper.mock_calls[-N:] == [mock.call.after_import_mock()] * N + expected_number_of_register_calls = len(mock_wrapper.mock_calls) - N + assert all([mock_call[0] == "mock_register" for mock_call in + mock_wrapper.mock_calls[:expected_number_of_register_calls]]) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9") +def test_option_list_with_pipe(): + pt = list[int] | None + lt = TypeEngine.to_literal_type(pt) + + ctx = FlyteContextManager.current_context() + lit = TypeEngine.to_literal(ctx, [1, 2, 3], pt, lt) + assert lit.scalar.union.value.collection.literals[2].scalar.primitive.integer == 3 + + TypeEngine.to_literal(ctx, None, pt, lt) + + with pytest.raises(TypeTransformerFailedError): + TypeEngine.to_literal(ctx, [1, 2, "3"], pt, lt) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9") +def test_option_list_with_pipe_2(): + pt = list[list[dict[str, str]] | None] | None + lt = TypeEngine.to_literal_type(pt) + + ctx = FlyteContextManager.current_context() + lit = TypeEngine.to_literal(ctx, [[{"a": "one"}], None, [{"b": "two"}]], pt, lt) + uv = lit.scalar.union.value + assert uv is not None + assert len(uv.collection.literals) == 3 + first = uv.collection.literals[0] + assert first.scalar.union.value.collection.literals[0].map.literals["a"].scalar.primitive.string_value == "one" + + assert len(lt.union_type.variants) == 2 + v1 = lt.union_type.variants[0] + assert len(v1.collection_type.union_type.variants) == 2 + assert v1.collection_type.union_type.variants[0].collection_type.map_value_type.simple == SimpleType.STRING + + with pytest.raises(TypeTransformerFailedError): + TypeEngine.to_literal(ctx, [[{"a": "one"}], None, [{"b": 3}]], pt, lt) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9") +def test_generic_errors_and_empty(): + # Test dictionaries + pt = dict[str, str] + lt = TypeEngine.to_literal_type(pt) + + ctx = FlyteContextManager.current_context() + lit = TypeEngine.to_literal(ctx, {}, pt, lt) + lit = TypeEngine.to_literal(ctx, {"a": "b"}, pt, lt) + + with pytest.raises(TypeTransformerFailedError): + TypeEngine.to_literal(ctx, {"a": 3}, pt, lt) + + with pytest.raises(ValueError): + TypeEngine.to_literal(ctx, {3: "a"}, pt, lt) + + # Test lists + pt = list[str] + lt = TypeEngine.to_literal_type(pt) + lit = TypeEngine.to_literal(ctx, [], pt, lt) + lit = TypeEngine.to_literal(ctx, ["a"], pt, lt) + + with pytest.raises(TypeTransformerFailedError): + TypeEngine.to_literal(ctx, {"a": 3}, pt, lt) + + with pytest.raises(TypeTransformerFailedError): + TypeEngine.to_literal(ctx, [3], pt, lt) + + +def generate_type_engine_transformer_comprehensive_tests(): + # Test dataclasses + @dataclass + class DataClass(DataClassJsonMixin): + a: int + b: str + + class Test: + a: str + b: int + + T = typing.TypeVar("T") + + class TestGeneric(typing.Generic[T]): + a: str + b: int + + # Test annotated types + AnnotatedInt = Annotated[int, "tag"] + AnnotatedFloat = Annotated[float, "tag"] + AnnotatedStr = Annotated[str, "tag"] + AnnotatedBool = Annotated[bool, "tag"] + AnnotatedList = Annotated[List[str], "tag"] + AnnotatedDict = Annotated[Dict[str, str], "tag"] + Annotatedx3Int = Annotated[Annotated[Annotated[int, "tag"], "tag2"], "tag3"] + + # Test generics + ListInt = List[int] + ListStr = List[str] + DictIntStr = Dict[str, str] + ListAnnotatedInt = List[AnnotatedInt] + DictAnnotatedIntStr = Dict[str, AnnotatedStr] + + # Test regular 
types + Int = int + Str = str + + CallableType = typing.Callable[[int, str], int] + CallableTypeAnnotated = Annotated[CallableType, "tag"] + CallableTypeList = List[CallableType] + + IteratorType = typing.Iterator[int] + IteratorTypeAnnotated = Annotated[IteratorType, "tag"] + IteratorTypeList = List[IteratorType] + + People = Annotated[StructuredDataset, "parquet", kwtypes(Name=str, Age=int)] + PeopleDeepAnnotated = Annotated[Annotated[StructuredDataset, "parquet", kwtypes(Name=str, Age=int)], "tag"] + + AnyType = typing.Any + AnyTypeAnnotated = Annotated[AnyType, "tag"] + AnyTypeAnnotatedList = List[AnyTypeAnnotated] + + UnionType = typing.Union[int, str] + UnionTypeAnnotated = Annotated[UnionType, "tag"] + + OptionalType = typing.Optional[int] + OptionalTypeAnnotated = Annotated[OptionalType, "tag"] + + WineType = Annotated[StructuredDataset, kwtypes(alcohol=float, malic_acid=float)] + WineTypeList = List[WineType] + WineTypeListList = List[WineTypeList] + WineTypeDict = Dict[str, WineType] + + IntPickle = Annotated[int, FlytePickleTransformer()] + AnnotatedIntPickle = Annotated[Annotated[int, "tag"], FlytePickleTransformer()] + + # Test combinations + return [ + (DataClass, DataclassTransformer), + (AnnotatedInt, IntTransformer), + (AnnotatedFloat, FloatTransformer), + (AnnotatedStr, StrTransformer), + (Annotatedx3Int, IntTransformer), + (ListInt, ListTransformer), + (ListStr, ListTransformer), + (DictIntStr, DictTransformer), + (Int, IntTransformer), + (Str, StrTransformer), + (AnnotatedBool, BoolTransformer), + (AnnotatedList, ListTransformer), + (AnnotatedDict, DictTransformer), + (ListAnnotatedInt, ListTransformer), + (DictAnnotatedIntStr, DictTransformer), + (CallableType, FlytePickleTransformer), + (CallableTypeAnnotated, FlytePickleTransformer), + (CallableTypeList, ListTransformer), + (IteratorType, IteratorTransformer), + (IteratorTypeAnnotated, IteratorTransformer), + (IteratorTypeList, ListTransformer), + (People, StructuredDatasetTransformerEngine), + (PeopleDeepAnnotated, StructuredDatasetTransformerEngine), + (WineType, StructuredDatasetTransformerEngine), + (WineTypeList, ListTransformer), + (AnyType, FlytePickleTransformer), + (AnyTypeAnnotated, FlytePickleTransformer), + (UnionType, UnionTransformer), + (UnionTypeAnnotated, UnionTransformer), + (OptionalType, UnionTransformer), + (OptionalTypeAnnotated, UnionTransformer), + (Test, FlytePickleTransformer), + (TestGeneric, FlytePickleTransformer), + (typing.Iterable[int], FlytePickleTransformer), + (typing.Sequence[int], FlytePickleTransformer), + (IntPickle, FlytePickleTransformer), + (AnnotatedIntPickle, FlytePickleTransformer), + (typing.Iterator[JSON], JSONIteratorTransformer), + (JSONIterator, JSONIteratorTransformer), + (AnyTypeAnnotatedList, ListTransformer), + (WineTypeListList, ListTransformer), + (WineTypeDict, DictTransformer), + ] + + +@pytest.mark.parametrize("t, expected_transformer", generate_type_engine_transformer_comprehensive_tests()) +def test_type_engine_get_transformer_comprehensive(t, expected_transformer): + """ + This test will test various combinations like dataclasses, annotated types, generics and regular types and + assert the right transformers are returned. 
+ """ + if isinstance(expected_transformer, SimpleTransformer): + underlying_type = expected_transformer.base_type + assert isinstance(TypeEngine.get_transformer(t), SimpleTransformer) + assert TypeEngine.get_transformer(t).base_type == underlying_type + else: + assert isinstance(TypeEngine.get_transformer(t), expected_transformer) + + +if sys.version_info >= (3, 10): + @pytest.mark.parametrize("t, expected_variants", [ + (int | float, [int, float]), + (int | float | None, [int, float, type(None)]), + (int | float | str, [int, float, str]), + ]) + @pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") + def test_union_type_comprehensive_604(t, expected_variants): + """ + This test will test various combinations like dataclasses, annotated types, generics and regular types and + assert the right transformers are returned. + """ + transformer = TypeEngine.get_transformer(t) + assert isinstance(transformer, UnionTransformer) + lt = transformer.get_literal_type(t) + assert [TypeEngine.guess_python_type(i) for i in lt.union_type.variants] == expected_variants + + +@pytest.mark.parametrize("t, expected_variants", [ + (typing.Union[int, str], [int, str]), + (typing.Union[str, None], [str, type(None)]), + (typing.Optional[int], [int, type(None)]), +]) +def test_union_comprehensive(t, expected_variants): + """ + This test will test various combinations like dataclasses, annotated types, generics and regular types and + assert the right transformers are returned. + """ + transformer = TypeEngine.get_transformer(t) + assert isinstance(transformer, UnionTransformer) + lt = transformer.get_literal_type(t) + assert [TypeEngine.guess_python_type(i) for i in lt.union_type.variants] == expected_variants + + +@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") +def test_structured_dataset_collection(): + WineType = Annotated[StructuredDataset, kwtypes(alcohol=float, malic_acid=float)] + WineTypeList = List[WineType] + WineTypeListList = List[WineTypeList] + + import pandas as pd + df = pd.DataFrame({"alcohol": [1.0, 2.0], "malic_acid": [2.0, 3.0]}) + + TypeEngine.to_literal(FlyteContext.current_context(), StructuredDataset(df), + WineType, TypeEngine.to_literal_type(WineType)) + + transformer = TypeEngine.get_transformer(WineTypeListList) + assert isinstance(transformer, ListTransformer) + lt = transformer.get_literal_type(WineTypeListList) + cols = lt.collection_type.collection_type.structured_dataset_type.columns + assert cols[0].name == "alcohol" + assert cols[0].literal_type.simple == SimpleType.FLOAT + assert cols[1].name == "malic_acid" + assert cols[1].literal_type.simple == SimpleType.FLOAT + + sd = StructuredDataset(df, format="parquet") + lv = TypeEngine.to_literal(FlyteContext.current_context(), [[sd]], WineTypeListList, lt) + assert lv is not None + + lv = TypeEngine.to_literal(FlyteContext.current_context(), [[StructuredDataset(df)]], + WineTypeListList, lt) + assert lv is not None + + +@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") +def test_structured_dataset_mismatch(): + import pandas as pd + + df = pd.DataFrame({"alcohol": [1.0, 2.0], "malic_acid": [2.0, 3.0]}) + transformer = TypeEngine.get_transformer(StructuredDataset) + with pytest.raises(TypeTransformerFailedError): + transformer.to_literal(FlyteContext.current_context(), df, StructuredDataset, TypeEngine.to_literal_type(StructuredDataset)) + + with pytest.raises(TypeTransformerFailedError): + 
TypeEngine.to_literal(FlyteContext.current_context(), df, StructuredDataset, TypeEngine.to_literal_type(StructuredDataset)) diff --git a/tests/flytekit/unit/core/test_type_engine_binary_idl.py b/tests/flytekit/unit/core/test_type_engine_binary_idl.py index 3426e8021d..348525b943 100644 --- a/tests/flytekit/unit/core/test_type_engine_binary_idl.py +++ b/tests/flytekit/unit/core/test_type_engine_binary_idl.py @@ -13,6 +13,7 @@ from mashumaro.codecs.msgpack import MessagePackEncoder from flytekit import task, workflow +from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContextManager from flytekit.core.type_engine import DataclassTransformer, TypeEngine from flytekit.models.literals import Binary, Literal, Scalar @@ -34,10 +35,8 @@ def test_simple_type_transformer(): for int_input in int_inputs: int_msgpack_bytes = encoder.encode(int_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=int_msgpack_bytes, - tag="msgpack"))) + scalar=Scalar(binary=Binary(value=int_msgpack_bytes, tag=MESSAGEPACK)) + ) int_output = TypeEngine.to_python_value(ctx, lv, int) assert int_input == int_output @@ -46,10 +45,8 @@ def test_simple_type_transformer(): for float_input in float_inputs: float_msgpack_bytes = encoder.encode(float_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=float_msgpack_bytes, - tag="msgpack"))) + scalar=Scalar(binary=Binary(value=float_msgpack_bytes, tag=MESSAGEPACK)) + ) float_output = TypeEngine.to_python_value(ctx, lv, float) assert float_input == float_output @@ -58,10 +55,8 @@ def test_simple_type_transformer(): for bool_input in bool_inputs: bool_msgpack_bytes = encoder.encode(bool_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=bool_msgpack_bytes, - tag="msgpack"))) + scalar=Scalar(binary=Binary(value=bool_msgpack_bytes, tag=MESSAGEPACK)) + ) bool_output = TypeEngine.to_python_value(ctx, lv, bool) assert bool_input == bool_output @@ -70,81 +65,72 @@ def test_simple_type_transformer(): for str_input in str_inputs: str_msgpack_bytes = encoder.encode(str_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=str_msgpack_bytes, - tag="msgpack"))) + scalar=Scalar(binary=Binary(value=str_msgpack_bytes, tag=MESSAGEPACK)) + ) str_output = TypeEngine.to_python_value(ctx, lv, str) assert str_input == str_output - datetime_inputs = [datetime.now(), - datetime(2024, 9, 18), - datetime(2024, 9, 18, 1), - datetime(2024, 9, 18, 1, 1), - datetime(2024, 9, 18, 1, 1, 1), - datetime(2024, 9, 18, 1, 1, 1, 1)] + datetime_inputs = [ + datetime.now(), + datetime(2024, 9, 18), + datetime(2024, 9, 18, 1), + datetime(2024, 9, 18, 1, 1), + datetime(2024, 9, 18, 1, 1, 1), + datetime(2024, 9, 18, 1, 1, 1, 1), + ] encoder = MessagePackEncoder(datetime) for datetime_input in datetime_inputs: datetime_msgpack_bytes = encoder.encode(datetime_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=datetime_msgpack_bytes, - tag="msgpack"))) + scalar=Scalar(binary=Binary(value=datetime_msgpack_bytes, tag=MESSAGEPACK)) + ) datetime_output = TypeEngine.to_python_value(ctx, lv, datetime) assert datetime_input == datetime_output - date_inputs = [date.today(), - date(2024, 9, 18)] + date_inputs = [date.today(), date(2024, 9, 18)] encoder = MessagePackEncoder(date) for date_input in date_inputs: date_msgpack_bytes = encoder.encode(date_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=date_msgpack_bytes, - tag="msgpack"))) + scalar=Scalar(binary=Binary(value=date_msgpack_bytes, tag=MESSAGEPACK)) + ) date_output = 
TypeEngine.to_python_value(ctx, lv, date) assert date_input == date_output - timedelta_inputs = [timedelta(days=1), - timedelta(days=1, seconds=1), - timedelta(days=1, seconds=1, microseconds=1), - timedelta( - days=1, - seconds=1, - microseconds=1, - milliseconds=1), + timedelta_inputs = [ + timedelta(days=1), + timedelta(days=1, seconds=1), + timedelta(days=1, seconds=1, microseconds=1), + timedelta(days=1, seconds=1, microseconds=1, milliseconds=1), + timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1), timedelta( - days=1, - seconds=1, - microseconds=1, - milliseconds=1, - minutes=1), + days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1 + ), timedelta( - days=1, - seconds=1, - microseconds=1, - milliseconds=1, - minutes=1, - hours=1), + days=1, + seconds=1, + microseconds=1, + milliseconds=1, + minutes=1, + hours=1, + weeks=1, + ), timedelta( - days=1, - seconds=1, - microseconds=1, - milliseconds=1, - minutes=1, - hours=1, - weeks=1), - timedelta(days=-1, seconds=-1, microseconds=-1, milliseconds=-1, minutes=-1, hours=-1, weeks=-1)] + days=-1, + seconds=-1, + microseconds=-1, + milliseconds=-1, + minutes=-1, + hours=-1, + weeks=-1, + ), + ] encoder = MessagePackEncoder(timedelta) for timedelta_input in timedelta_inputs: timedelta_msgpack_bytes = encoder.encode(timedelta_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=timedelta_msgpack_bytes, - tag="msgpack"))) + scalar=Scalar(binary=Binary(value=timedelta_msgpack_bytes, tag=MESSAGEPACK)) + ) timedelta_output = TypeEngine.to_python_value(ctx, lv, timedelta) assert timedelta_input == timedelta_output @@ -167,8 +153,7 @@ def test_untyped_dict(): }, { "list_in_dict": [ - {"inner_dict_1": [1, -2.5, "a"], - "inner_dict_2": [True, False, 3.14]}, + {"inner_dict_1": [1, -2.5, "a"], "inner_dict_2": [True, False, 3.14]}, [1, -2, 3, {"nested_list_dict": [False, "test"]}], ] }, @@ -201,10 +186,7 @@ def test_untyped_dict(): # dict_msgpack_bytes = msgpack.dumps(dict_input) dict_msgpack_bytes = MessagePackEncoder(dict).encode(dict_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=dict_msgpack_bytes, - tag="msgpack")) + scalar=Scalar(binary=Binary(value=dict_msgpack_bytes, tag=MESSAGEPACK)) ) dict_output = TypeEngine.to_python_value(ctx, lv, dict) assert dict_input == dict_output @@ -217,10 +199,7 @@ def test_list_transformer(): encoder = MessagePackEncoder(List[int]) list_int_msgpack_bytes = encoder.encode(list_int_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=list_int_msgpack_bytes, - tag="msgpack")) + scalar=Scalar(binary=Binary(value=list_int_msgpack_bytes, tag=MESSAGEPACK)) ) list_int_output = TypeEngine.to_python_value(ctx, lv, List[int]) assert list_int_input == list_int_output @@ -229,10 +208,7 @@ def test_list_transformer(): encoder = MessagePackEncoder(List[float]) list_float_msgpack_bytes = encoder.encode(list_float_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=list_float_msgpack_bytes, - tag="msgpack")) + scalar=Scalar(binary=Binary(value=list_float_msgpack_bytes, tag=MESSAGEPACK)) ) list_float_output = TypeEngine.to_python_value(ctx, lv, List[float]) assert list_float_input == list_float_output @@ -241,10 +217,7 @@ def test_list_transformer(): encoder = MessagePackEncoder(List[str]) list_str_msgpack_bytes = encoder.encode(list_str_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=list_str_msgpack_bytes, - tag="msgpack")) + scalar=Scalar(binary=Binary(value=list_str_msgpack_bytes, tag=MESSAGEPACK)) ) list_str_output = 
TypeEngine.to_python_value(ctx, lv, List[str]) assert list_str_input == list_str_output @@ -253,10 +226,7 @@ def test_list_transformer(): encoder = MessagePackEncoder(List[bool]) list_bool_msgpack_bytes = encoder.encode(list_bool_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=list_bool_msgpack_bytes, - tag="msgpack")) + scalar=Scalar(binary=Binary(value=list_bool_msgpack_bytes, tag=MESSAGEPACK)) ) list_bool_output = TypeEngine.to_python_value(ctx, lv, List[bool]) assert list_bool_input == list_bool_output @@ -265,10 +235,7 @@ def test_list_transformer(): encoder = MessagePackEncoder(List[List[int]]) list_list_int_msgpack_bytes = encoder.encode(list_list_int_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=list_list_int_msgpack_bytes, - tag="msgpack")) + scalar=Scalar(binary=Binary(value=list_list_int_msgpack_bytes, tag=MESSAGEPACK)) ) list_list_int_output = TypeEngine.to_python_value(ctx, lv, List[List[int]]) assert list_list_int_input == list_list_int_output @@ -278,22 +245,17 @@ def test_list_transformer(): list_list_float_msgpack_bytes = encoder.encode(list_list_float_input) lv = Literal( scalar=Scalar( - binary=Binary( - value=list_list_float_msgpack_bytes, - tag="msgpack")) + binary=Binary(value=list_list_float_msgpack_bytes, tag=MESSAGEPACK) + ) ) - list_list_float_output = TypeEngine.to_python_value( - ctx, lv, List[List[float]]) + list_list_float_output = TypeEngine.to_python_value(ctx, lv, List[List[float]]) assert list_list_float_input == list_list_float_output list_list_str_input = [["a", "b"], ["c", "d"]] encoder = MessagePackEncoder(List[List[str]]) list_list_str_msgpack_bytes = encoder.encode(list_list_str_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=list_list_str_msgpack_bytes, - tag="msgpack")) + scalar=Scalar(binary=Binary(value=list_list_str_msgpack_bytes, tag=MESSAGEPACK)) ) list_list_str_output = TypeEngine.to_python_value(ctx, lv, List[List[str]]) assert list_list_str_input == list_list_str_output @@ -303,12 +265,10 @@ def test_list_transformer(): list_list_bool_msgpack_bytes = encoder.encode(list_list_bool_input) lv = Literal( scalar=Scalar( - binary=Binary( - value=list_list_bool_msgpack_bytes, - tag="msgpack")) + binary=Binary(value=list_list_bool_msgpack_bytes, tag=MESSAGEPACK) + ) ) - list_list_bool_output = TypeEngine.to_python_value( - ctx, lv, List[List[bool]]) + list_list_bool_output = TypeEngine.to_python_value(ctx, lv, List[List[bool]]) assert list_list_bool_input == list_list_bool_output list_dict_str_int_input = [{"key1": -1, "key2": 2}] @@ -316,22 +276,18 @@ def test_list_transformer(): list_dict_str_int_msgpack_bytes = encoder.encode(list_dict_str_int_input) lv = Literal( scalar=Scalar( - binary=Binary(value=list_dict_str_int_msgpack_bytes, tag="msgpack") + binary=Binary(value=list_dict_str_int_msgpack_bytes, tag=MESSAGEPACK) ) ) - list_dict_str_int_output = TypeEngine.to_python_value( - ctx, lv, List[Dict[str, int]]) + list_dict_str_int_output = TypeEngine.to_python_value(ctx, lv, List[Dict[str, int]]) assert list_dict_str_int_input == list_dict_str_int_output list_dict_str_float_input = [{"key1": 1.0, "key2": -2.0}] encoder = MessagePackEncoder(List[Dict[str, float]]) - list_dict_str_float_msgpack_bytes = encoder.encode( - list_dict_str_float_input) + list_dict_str_float_msgpack_bytes = encoder.encode(list_dict_str_float_input) lv = Literal( scalar=Scalar( - binary=Binary( - value=list_dict_str_float_msgpack_bytes, - tag="msgpack") + binary=Binary(value=list_dict_str_float_msgpack_bytes, tag=MESSAGEPACK) 
) ) list_dict_str_float_output = TypeEngine.to_python_value( @@ -344,11 +300,10 @@ def test_list_transformer(): list_dict_str_str_msgpack_bytes = encoder.encode(list_dict_str_str_input) lv = Literal( scalar=Scalar( - binary=Binary(value=list_dict_str_str_msgpack_bytes, tag="msgpack") + binary=Binary(value=list_dict_str_str_msgpack_bytes, tag=MESSAGEPACK) ) ) - list_dict_str_str_output = TypeEngine.to_python_value( - ctx, lv, List[Dict[str, str]]) + list_dict_str_str_output = TypeEngine.to_python_value(ctx, lv, List[Dict[str, str]]) assert list_dict_str_str_input == list_dict_str_str_output list_dict_str_bool_input = [{"key1": True, "key2": False}] @@ -356,9 +311,7 @@ def test_list_transformer(): list_dict_str_bool_msgpack_bytes = encoder.encode(list_dict_str_bool_input) lv = Literal( scalar=Scalar( - binary=Binary( - value=list_dict_str_bool_msgpack_bytes, - tag="msgpack") + binary=Binary(value=list_dict_str_bool_msgpack_bytes, tag=MESSAGEPACK) ) ) list_dict_str_bool_output = TypeEngine.to_python_value( @@ -380,10 +333,8 @@ class InnerDC: h: Dict[int, bool] = field( default_factory=lambda: {0: False, 1: True, -1: False} ) - i: Dict[int, List[int]] = field( - default_factory=lambda: {0: [0, 1, -1]}) - j: Dict[int, Dict[int, int]] = field( - default_factory=lambda: {1: {-1: 0}}) + i: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) + j: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) k: dict = field(default_factory=lambda: {"key": "value"}) enum_status: Status = field(default=Status.PENDING) @@ -401,24 +352,18 @@ class DC: h: Dict[int, bool] = field( default_factory=lambda: {0: False, 1: True, -1: False} ) - i: Dict[int, List[int]] = field( - default_factory=lambda: {0: [0, 1, -1]}) - j: Dict[int, Dict[int, int]] = field( - default_factory=lambda: {1: {-1: 0}}) + i: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) + j: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) k: dict = field(default_factory=lambda: {"key": "value"}) inner_dc: InnerDC = field(default_factory=lambda: InnerDC()) enum_status: Status = field(default=Status.PENDING) - list_dict_int_inner_dc_input = [ - {1: InnerDC(), -1: InnerDC(), 0: InnerDC()}] + list_dict_int_inner_dc_input = [{1: InnerDC(), -1: InnerDC(), 0: InnerDC()}] encoder = MessagePackEncoder(List[Dict[int, InnerDC]]) - list_dict_int_inner_dc_msgpack_bytes = encoder.encode( - list_dict_int_inner_dc_input) + list_dict_int_inner_dc_msgpack_bytes = encoder.encode(list_dict_int_inner_dc_input) lv = Literal( scalar=Scalar( - binary=Binary( - value=list_dict_int_inner_dc_msgpack_bytes, - tag="msgpack") + binary=Binary(value=list_dict_int_inner_dc_msgpack_bytes, tag=MESSAGEPACK) ) ) list_dict_int_inner_dc_output = TypeEngine.to_python_value( @@ -431,11 +376,10 @@ class DC: list_dict_int_dc_msgpack_bytes = encoder.encode(list_dict_int_dc_input) lv = Literal( scalar=Scalar( - binary=Binary(value=list_dict_int_dc_msgpack_bytes, tag="msgpack") + binary=Binary(value=list_dict_int_dc_msgpack_bytes, tag=MESSAGEPACK) ) ) - list_dict_int_dc_output = TypeEngine.to_python_value( - ctx, lv, List[Dict[int, DC]]) + list_dict_int_dc_output = TypeEngine.to_python_value(ctx, lv, List[Dict[int, DC]]) assert list_dict_int_dc_input == list_dict_int_dc_output list_list_inner_dc_input = [[InnerDC(), InnerDC(), InnerDC()]] @@ -443,23 +387,17 @@ class DC: list_list_inner_dc_msgpack_bytes = encoder.encode(list_list_inner_dc_input) lv = Literal( scalar=Scalar( - binary=Binary( - 
value=list_list_inner_dc_msgpack_bytes, - tag="msgpack") + binary=Binary(value=list_list_inner_dc_msgpack_bytes, tag=MESSAGEPACK) ) ) - list_list_inner_dc_output = TypeEngine.to_python_value( - ctx, lv, List[List[InnerDC]]) + list_list_inner_dc_output = TypeEngine.to_python_value(ctx, lv, List[List[InnerDC]]) assert list_list_inner_dc_input == list_list_inner_dc_output list_list_dc_input = [[DC(), DC(), DC()]] encoder = MessagePackEncoder(List[List[DC]]) list_list_dc_msgpack_bytes = encoder.encode(list_list_dc_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=list_list_dc_msgpack_bytes, - tag="msgpack")) + scalar=Scalar(binary=Binary(value=list_list_dc_msgpack_bytes, tag=MESSAGEPACK)) ) list_list_dc_output = TypeEngine.to_python_value(ctx, lv, List[List[DC]]) assert list_list_dc_input == list_list_dc_output @@ -472,10 +410,8 @@ def test_dict_transformer(local_dummy_file, local_dummy_directory): encoder = MessagePackEncoder(Dict[str, int]) dict_str_int_msgpack_bytes = encoder.encode(dict_str_int_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=dict_str_int_msgpack_bytes, - tag="msgpack"))) + scalar=Scalar(binary=Binary(value=dict_str_int_msgpack_bytes, tag=MESSAGEPACK)) + ) dict_str_int_output = TypeEngine.to_python_value(ctx, lv, Dict[str, int]) assert dict_str_int_input == dict_str_int_output @@ -484,21 +420,18 @@ def test_dict_transformer(local_dummy_file, local_dummy_directory): dict_str_float_msgpack_bytes = encoder.encode(dict_str_float_input) lv = Literal( scalar=Scalar( - binary=Binary( - value=dict_str_float_msgpack_bytes, - tag="msgpack"))) - dict_str_float_output = TypeEngine.to_python_value( - ctx, lv, Dict[str, float]) + binary=Binary(value=dict_str_float_msgpack_bytes, tag=MESSAGEPACK) + ) + ) + dict_str_float_output = TypeEngine.to_python_value(ctx, lv, Dict[str, float]) assert dict_str_float_input == dict_str_float_output dict_str_str_input = {"key1": "a", "key2": "b"} encoder = MessagePackEncoder(Dict[str, str]) dict_str_str_msgpack_bytes = encoder.encode(dict_str_str_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=dict_str_str_msgpack_bytes, - tag="msgpack"))) + scalar=Scalar(binary=Binary(value=dict_str_str_msgpack_bytes, tag=MESSAGEPACK)) + ) dict_str_str_output = TypeEngine.to_python_value(ctx, lv, Dict[str, str]) assert dict_str_str_input == dict_str_str_output @@ -506,10 +439,8 @@ def test_dict_transformer(local_dummy_file, local_dummy_directory): encoder = MessagePackEncoder(Dict[str, bool]) dict_str_bool_msgpack_bytes = encoder.encode(dict_str_bool_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=dict_str_bool_msgpack_bytes, - tag="msgpack"))) + scalar=Scalar(binary=Binary(value=dict_str_bool_msgpack_bytes, tag=MESSAGEPACK)) + ) dict_str_bool_output = TypeEngine.to_python_value(ctx, lv, Dict[str, bool]) assert dict_str_bool_input == dict_str_bool_output @@ -518,52 +449,59 @@ def test_dict_transformer(local_dummy_file, local_dummy_directory): dict_str_list_int_msgpack_bytes = encoder.encode(dict_str_list_int_input) lv = Literal( scalar=Scalar( - binary=Binary( - value=dict_str_list_int_msgpack_bytes, - tag="msgpack"))) - dict_str_list_int_output = TypeEngine.to_python_value( - ctx, lv, Dict[str, List[int]]) + binary=Binary(value=dict_str_list_int_msgpack_bytes, tag=MESSAGEPACK) + ) + ) + dict_str_list_int_output = TypeEngine.to_python_value(ctx, lv, Dict[str, List[int]]) assert dict_str_list_int_input == dict_str_list_int_output dict_str_dict_str_int_input = {"key1": {"subkey1": 1, "subkey2": -2}} encoder = 
MessagePackEncoder(Dict[str, Dict[str, int]]) - dict_str_dict_str_int_msgpack_bytes = encoder.encode( - dict_str_dict_str_int_input) + dict_str_dict_str_int_msgpack_bytes = encoder.encode(dict_str_dict_str_int_input) lv = Literal( scalar=Scalar( - binary=Binary( - value=dict_str_dict_str_int_msgpack_bytes, - tag="msgpack"))) + binary=Binary(value=dict_str_dict_str_int_msgpack_bytes, tag=MESSAGEPACK) + ) + ) dict_str_dict_str_int_output = TypeEngine.to_python_value( - ctx, lv, Dict[str, Dict[str, int]]) + ctx, lv, Dict[str, Dict[str, int]] + ) assert dict_str_dict_str_int_input == dict_str_dict_str_int_output dict_str_dict_str_list_int_input = { - "key1": {"subkey1": [1, -2], "subkey2": [-3, 4]}} + "key1": {"subkey1": [1, -2], "subkey2": [-3, 4]} + } encoder = MessagePackEncoder(Dict[str, Dict[str, List[int]]]) dict_str_dict_str_list_int_msgpack_bytes = encoder.encode( - dict_str_dict_str_list_int_input) + dict_str_dict_str_list_int_input + ) lv = Literal( scalar=Scalar( binary=Binary( - value=dict_str_dict_str_list_int_msgpack_bytes, - tag="msgpack"))) + value=dict_str_dict_str_list_int_msgpack_bytes, tag=MESSAGEPACK + ) + ) + ) dict_str_dict_str_list_int_output = TypeEngine.to_python_value( - ctx, lv, Dict[str, Dict[str, List[int]]]) + ctx, lv, Dict[str, Dict[str, List[int]]] + ) assert dict_str_dict_str_list_int_input == dict_str_dict_str_list_int_output - dict_str_list_dict_str_int_input = { - "key1": [{"subkey1": -1}, {"subkey2": 2}]} + dict_str_list_dict_str_int_input = {"key1": [{"subkey1": -1}, {"subkey2": 2}]} encoder = MessagePackEncoder(Dict[str, List[Dict[str, int]]]) dict_str_list_dict_str_int_msgpack_bytes = encoder.encode( - dict_str_list_dict_str_int_input) + dict_str_list_dict_str_int_input + ) lv = Literal( scalar=Scalar( binary=Binary( - value=dict_str_list_dict_str_int_msgpack_bytes, - tag="msgpack"))) + value=dict_str_list_dict_str_int_msgpack_bytes, tag=MESSAGEPACK + ) + ) + ) dict_str_list_dict_str_int_output = TypeEngine.to_python_value( - ctx, lv, Dict[str, List[Dict[str, int]]]) + ctx, lv, Dict[str, List[Dict[str, int]]] + ) assert dict_str_list_dict_str_int_input == dict_str_list_dict_str_int_output # non-strict types @@ -571,24 +509,26 @@ def test_dict_transformer(local_dummy_file, local_dummy_directory): encoder = MessagePackEncoder(dict) dict_int_str_msgpack_bytes = encoder.encode(dict_int_str_input) lv = Literal( - scalar=Scalar( - binary=Binary( - value=dict_int_str_msgpack_bytes, - tag="msgpack"))) + scalar=Scalar(binary=Binary(value=dict_int_str_msgpack_bytes, tag=MESSAGEPACK)) + ) dict_int_str_output = TypeEngine.to_python_value(ctx, lv, dict) assert dict_int_str_input == dict_int_str_output dict_int_dict_int_list_int_input = {1: {-2: [1, -2]}, -3: {4: [-3, 4]}} encoder = MessagePackEncoder(Dict[int, Dict[int, List[int]]]) dict_int_dict_int_list_int_msgpack_bytes = encoder.encode( - dict_int_dict_int_list_int_input) + dict_int_dict_int_list_int_input + ) lv = Literal( scalar=Scalar( binary=Binary( - value=dict_int_dict_int_list_int_msgpack_bytes, - tag="msgpack"))) + value=dict_int_dict_int_list_int_msgpack_bytes, tag=MESSAGEPACK + ) + ) + ) dict_int_dict_int_list_int_output = TypeEngine.to_python_value( - ctx, lv, Dict[int, Dict[int, List[int]]]) + ctx, lv, Dict[int, Dict[int, List[int]]] + ) assert dict_int_dict_int_list_int_input == dict_int_dict_int_list_int_output @dataclass @@ -605,10 +545,8 @@ class InnerDC: h: Dict[int, bool] = field( default_factory=lambda: {0: False, 1: True, -1: False} ) - i: Dict[int, List[int]] = field( - 
default_factory=lambda: {0: [0, 1, -1]}) - j: Dict[int, Dict[int, int]] = field( - default_factory=lambda: {1: {-1: 0}}) + i: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) + j: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) k: dict = field(default_factory=lambda: {"key": "value"}) enum_status: Status = field(default=Status.PENDING) @@ -626,10 +564,8 @@ class DC: h: Dict[int, bool] = field( default_factory=lambda: {0: False, 1: True, -1: False} ) - i: Dict[int, List[int]] = field( - default_factory=lambda: {0: [0, 1, -1]}) - j: Dict[int, Dict[int, int]] = field( - default_factory=lambda: {1: {-1: 0}}) + i: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) + j: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) k: dict = field(default_factory=lambda: {"key": "value"}) inner_dc: InnerDC = field(default_factory=lambda: InnerDC()) enum_status: Status = field(default=Status.PENDING) @@ -639,21 +575,17 @@ class DC: dict_int_inner_dc_msgpack_bytes = encoder.encode(dict_int_inner_dc_input) lv = Literal( scalar=Scalar( - binary=Binary(value=dict_int_inner_dc_msgpack_bytes, tag="msgpack") + binary=Binary(value=dict_int_inner_dc_msgpack_bytes, tag=MESSAGEPACK) ) ) - dict_int_inner_dc_output = TypeEngine.to_python_value( - ctx, lv, Dict[int, InnerDC]) + dict_int_inner_dc_output = TypeEngine.to_python_value(ctx, lv, Dict[int, InnerDC]) assert dict_int_inner_dc_input == dict_int_inner_dc_output dict_int_dc = {1: DC(), -2: DC(), 0: DC()} encoder = MessagePackEncoder(Dict[int, DC]) dict_int_dc_msgpack_bytes = encoder.encode(dict_int_dc) lv = Literal( - scalar=Scalar( - binary=Binary( - value=dict_int_dc_msgpack_bytes, - tag="msgpack")) + scalar=Scalar(binary=Binary(value=dict_int_dc_msgpack_bytes, tag=MESSAGEPACK)) ) dict_int_dc_output = TypeEngine.to_python_value(ctx, lv, Dict[int, DC]) assert dict_int_dc == dict_int_dc_output @@ -685,16 +617,20 @@ def test_flytetypes_in_dataclass_wf(local_dummy_file, local_dummy_directory): @dataclass class InnerDC: flytefile: FlyteFile = field( - default_factory=lambda: FlyteFile(local_dummy_file)) + default_factory=lambda: FlyteFile(local_dummy_file) + ) flytedir: FlyteDirectory = field( - default_factory=lambda: FlyteDirectory(local_dummy_directory)) + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) @dataclass class DC: flytefile: FlyteFile = field( - default_factory=lambda: FlyteFile(local_dummy_file)) + default_factory=lambda: FlyteFile(local_dummy_file) + ) flytedir: FlyteDirectory = field( - default_factory=lambda: FlyteDirectory(local_dummy_directory)) + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) inner_dc: InnerDC = field(default_factory=lambda: InnerDC()) @task @@ -736,25 +672,29 @@ class InnerDC: d: bool = False e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2]) f: List[FlyteFile] = field( - default_factory=lambda: [ - FlyteFile(local_dummy_file)]) + default_factory=lambda: [FlyteFile(local_dummy_file)] + ) g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]]) - h: List[Dict[int, bool]] = field(default_factory=lambda: [ - {0: False}, {1: True}, {-1: True}]) - i: Dict[int, bool] = field(default_factory=lambda: { - 0: False, 1: True, -1: False}) - j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), - 1: FlyteFile(local_dummy_file), - -1: FlyteFile(local_dummy_file)}) - k: Dict[int, List[int]] = field( - default_factory=lambda: {0: [0, 1, -1]}) - l: Dict[int, Dict[int, 
int]] = field( - default_factory=lambda: {1: {-1: 0}}) + h: List[Dict[int, bool]] = field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) + i: Dict[int, bool] = field( + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Dict[int, FlyteFile] = field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) + k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) m: dict = field(default_factory=lambda: {"key": "value"}) - n: FlyteFile = field( - default_factory=lambda: FlyteFile(local_dummy_file)) + n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) o: FlyteDirectory = field( - default_factory=lambda: FlyteDirectory(local_dummy_directory)) + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) enum_status: Status = field(default=Status.PENDING) @dataclass @@ -766,24 +706,30 @@ class DC: e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2]) f: List[FlyteFile] = field( default_factory=lambda: [ - FlyteFile(local_dummy_file), ]) + FlyteFile(local_dummy_file), + ] + ) g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]]) - h: List[Dict[int, bool]] = field(default_factory=lambda: [ - {0: False}, {1: True}, {-1: True}]) - i: Dict[int, bool] = field(default_factory=lambda: { - 0: False, 1: True, -1: False}) - j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), - 1: FlyteFile(local_dummy_file), - -1: FlyteFile(local_dummy_file)}) - k: Dict[int, List[int]] = field( - default_factory=lambda: {0: [0, 1, -1]}) - l: Dict[int, Dict[int, int]] = field( - default_factory=lambda: {1: {-1: 0}}) + h: List[Dict[int, bool]] = field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) + i: Dict[int, bool] = field( + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Dict[int, FlyteFile] = field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) + k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) m: dict = field(default_factory=lambda: {"key": "value"}) - n: FlyteFile = field( - default_factory=lambda: FlyteFile(local_dummy_file)) + n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) o: FlyteDirectory = field( - default_factory=lambda: FlyteDirectory(local_dummy_directory)) + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) inner_dc: InnerDC = field(default_factory=lambda: InnerDC()) enum_status: Status = field(default=Status.PENDING) @@ -816,10 +762,24 @@ def t_inner(inner_dc: InnerDC): assert inner_dc.enum_status == Status.PENDING @task - def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: List[FlyteFile], g: List[List[int]], - h: List[Dict[int, bool]], i: Dict[int, bool], j: Dict[int, FlyteFile], - k: Dict[int, List[int]], l: Dict[int, Dict[int, int]], m: dict, - n: FlyteFile, o: FlyteDirectory, enum_status: Status): + def t_test_all_attributes( + a: int, + b: float, + c: str, + d: bool, + e: List[int], + f: List[FlyteFile], + g: List[List[int]], + h: List[Dict[int, bool]], + i: Dict[int, bool], + j: Dict[int, FlyteFile], + k: Dict[int, List[int]], + l: Dict[int, Dict[int, int]], + m: dict, + n: FlyteFile, + o: FlyteDirectory, + 
enum_status: Status, + ): # Strict type checks for simple types assert isinstance(a, int), f"a is not int, it's {type(a)}" assert a == -1 @@ -828,40 +788,55 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li assert isinstance(d, bool), f"d is not bool, it's {type(d)}" # Strict type checks for List[int] - assert isinstance(e, list) and all(isinstance(i, int) - for i in e), "e is not List[int]" + assert isinstance(e, list) and all( + isinstance(i, int) for i in e + ), "e is not List[int]" # Strict type checks for List[FlyteFile] - assert isinstance(f, list) and all(isinstance(i, FlyteFile) - for i in f), "f is not List[FlyteFile]" + assert isinstance(f, list) and all( + isinstance(i, FlyteFile) for i in f + ), "f is not List[FlyteFile]" # Strict type checks for List[List[int]] assert isinstance(g, list) and all( - isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]" + isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g + ), "g is not List[List[int]]" # Strict type checks for List[Dict[int, bool]] assert isinstance(h, list) and all( - isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h + isinstance(i, dict) + and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) + for i in h ), "h is not List[Dict[int, bool]]" # Strict type checks for Dict[int, bool] assert isinstance(i, dict) and all( - isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]" + isinstance(k, int) and isinstance(v, bool) for k, v in i.items() + ), "i is not Dict[int, bool]" # Strict type checks for Dict[int, FlyteFile] assert isinstance(j, dict) and all( - isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]" + isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items() + ), "j is not Dict[int, FlyteFile]" # Strict type checks for Dict[int, List[int]] assert isinstance(k, dict) and all( - isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in - k.items()), "k is not Dict[int, List[int]]" + isinstance(k, int) + and isinstance(v, list) + and all(isinstance(i, int) for i in v) + for k, v in k.items() + ), "k is not Dict[int, List[int]]" # Strict type checks for Dict[int, Dict[int, int]] assert isinstance(l, dict) and all( - isinstance(k, int) and isinstance(v, dict) and all( - isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items()) - for k, v in l.items()), "l is not Dict[int, Dict[int, int]]" + isinstance(k, int) + and isinstance(v, dict) + and all( + isinstance(sub_k, int) and isinstance(sub_v, int) + for sub_k, sub_v in v.items() + ) + for k, v in l.items() + ), "l is not Dict[int, Dict[int, int]]" # Strict type check for a generic dict assert isinstance(m, dict), "m is not dict" @@ -878,23 +853,50 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li @workflow def wf(dc: DC): t_inner(dc.inner_dc) - t_test_all_attributes(a=dc.a, b=dc.b, c=dc.c, - d=dc.d, e=dc.e, f=dc.f, - g=dc.g, h=dc.h, i=dc.i, - j=dc.j, k=dc.k, l=dc.l, - m=dc.m, n=dc.n, o=dc.o, enum_status=dc.enum_status) - - t_test_all_attributes(a=dc.inner_dc.a, b=dc.inner_dc.b, c=dc.inner_dc.c, - d=dc.inner_dc.d, e=dc.inner_dc.e, f=dc.inner_dc.f, - g=dc.inner_dc.g, h=dc.inner_dc.h, i=dc.inner_dc.i, - j=dc.inner_dc.j, k=dc.inner_dc.k, l=dc.inner_dc.l, - m=dc.inner_dc.m, n=dc.inner_dc.n, o=dc.inner_dc.o, 
enum_status=dc.inner_dc.enum_status) + t_test_all_attributes( + a=dc.a, + b=dc.b, + c=dc.c, + d=dc.d, + e=dc.e, + f=dc.f, + g=dc.g, + h=dc.h, + i=dc.i, + j=dc.j, + k=dc.k, + l=dc.l, + m=dc.m, + n=dc.n, + o=dc.o, + enum_status=dc.enum_status, + ) + + t_test_all_attributes( + a=dc.inner_dc.a, + b=dc.inner_dc.b, + c=dc.inner_dc.c, + d=dc.inner_dc.d, + e=dc.inner_dc.e, + f=dc.inner_dc.f, + g=dc.inner_dc.g, + h=dc.inner_dc.h, + i=dc.inner_dc.i, + j=dc.inner_dc.j, + k=dc.inner_dc.k, + l=dc.inner_dc.l, + m=dc.inner_dc.m, + n=dc.inner_dc.n, + o=dc.inner_dc.o, + enum_status=dc.inner_dc.enum_status, + ) wf(dc=DC()) def test_backward_compatible_with_dataclass_in_protobuf_struct( - local_dummy_file, local_dummy_directory): + local_dummy_file, local_dummy_directory +): # Flyte Console will send the input data as a protobuf Struct # This test also tests how Flyte Console performs attribute access on the # Struct object @@ -907,25 +909,29 @@ class InnerDC: d: bool = False e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2]) f: List[FlyteFile] = field( - default_factory=lambda: [ - FlyteFile(local_dummy_file)]) + default_factory=lambda: [FlyteFile(local_dummy_file)] + ) g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]]) - h: List[Dict[int, bool]] = field(default_factory=lambda: [ - {0: False}, {1: True}, {-1: True}]) - i: Dict[int, bool] = field(default_factory=lambda: { - 0: False, 1: True, -1: False}) - j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), - 1: FlyteFile(local_dummy_file), - -1: FlyteFile(local_dummy_file)}) - k: Dict[int, List[int]] = field( - default_factory=lambda: {0: [0, 1, -1]}) - l: Dict[int, Dict[int, int]] = field( - default_factory=lambda: {1: {-1: 0}}) + h: List[Dict[int, bool]] = field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) + i: Dict[int, bool] = field( + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Dict[int, FlyteFile] = field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) + k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) m: dict = field(default_factory=lambda: {"key": "value"}) - n: FlyteFile = field( - default_factory=lambda: FlyteFile(local_dummy_file)) + n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) o: FlyteDirectory = field( - default_factory=lambda: FlyteDirectory(local_dummy_directory)) + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) enum_status: Status = field(default=Status.PENDING) @dataclass @@ -937,24 +943,30 @@ class DC: e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2]) f: List[FlyteFile] = field( default_factory=lambda: [ - FlyteFile(local_dummy_file), ]) + FlyteFile(local_dummy_file), + ] + ) g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]]) - h: List[Dict[int, bool]] = field(default_factory=lambda: [ - {0: False}, {1: True}, {-1: True}]) - i: Dict[int, bool] = field(default_factory=lambda: { - 0: False, 1: True, -1: False}) - j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), - 1: FlyteFile(local_dummy_file), - -1: FlyteFile(local_dummy_file)}) - k: Dict[int, List[int]] = field( - default_factory=lambda: {0: [0, 1, -1]}) - l: Dict[int, Dict[int, int]] = field( - default_factory=lambda: {1: {-1: 0}}) + h: List[Dict[int, bool]] = field( +
default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) + i: Dict[int, bool] = field( + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Dict[int, FlyteFile] = field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) + k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) m: dict = field(default_factory=lambda: {"key": "value"}) - n: FlyteFile = field( - default_factory=lambda: FlyteFile(local_dummy_file)) + n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) o: FlyteDirectory = field( - default_factory=lambda: FlyteDirectory(local_dummy_directory)) + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) inner_dc: InnerDC = field(default_factory=lambda: InnerDC()) enum_status: Status = field(default=Status.PENDING) @@ -985,10 +997,24 @@ def t_inner(inner_dc: InnerDC): # enum: Status assert inner_dc.enum_status == Status.PENDING - def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: List[FlyteFile], g: List[List[int]], - h: List[Dict[int, bool]], i: Dict[int, bool], j: Dict[int, FlyteFile], - k: Dict[int, List[int]], l: Dict[int, Dict[int, int]], m: dict, - n: FlyteFile, o: FlyteDirectory, enum_status: Status): + def t_test_all_attributes( + a: int, + b: float, + c: str, + d: bool, + e: List[int], + f: List[FlyteFile], + g: List[List[int]], + h: List[Dict[int, bool]], + i: Dict[int, bool], + j: Dict[int, FlyteFile], + k: Dict[int, List[int]], + l: Dict[int, Dict[int, int]], + m: dict, + n: FlyteFile, + o: FlyteDirectory, + enum_status: Status, + ): # Strict type checks for simple types assert isinstance(a, int), f"a is not int, it's {type(a)}" assert a == -1 @@ -997,40 +1023,55 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li assert isinstance(d, bool), f"d is not bool, it's {type(d)}" # Strict type checks for List[int] - assert isinstance(e, list) and all(isinstance(i, int) - for i in e), "e is not List[int]" + assert isinstance(e, list) and all( + isinstance(i, int) for i in e + ), "e is not List[int]" # Strict type checks for List[FlyteFile] - assert isinstance(f, list) and all(isinstance(i, FlyteFile) - for i in f), "f is not List[FlyteFile]" + assert isinstance(f, list) and all( + isinstance(i, FlyteFile) for i in f + ), "f is not List[FlyteFile]" # Strict type checks for List[List[int]] assert isinstance(g, list) and all( - isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]" + isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g + ), "g is not List[List[int]]" # Strict type checks for List[Dict[int, bool]] assert isinstance(h, list) and all( - isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h + isinstance(i, dict) + and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) + for i in h ), "h is not List[Dict[int, bool]]" # Strict type checks for Dict[int, bool] assert isinstance(i, dict) and all( - isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]" + isinstance(k, int) and isinstance(v, bool) for k, v in i.items() + ), "i is not Dict[int, bool]" # Strict type checks for Dict[int, FlyteFile] assert isinstance(j, dict) and all( - isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not 
Dict[int, FlyteFile]" + isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items() + ), "j is not Dict[int, FlyteFile]" # Strict type checks for Dict[int, List[int]] assert isinstance(k, dict) and all( - isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in - k.items()), "k is not Dict[int, List[int]]" + isinstance(k, int) + and isinstance(v, list) + and all(isinstance(i, int) for i in v) + for k, v in k.items() + ), "k is not Dict[int, List[int]]" # Strict type checks for Dict[int, Dict[int, int]] assert isinstance(l, dict) and all( - isinstance(k, int) and isinstance(v, dict) and all( - isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items()) - for k, v in l.items()), "l is not Dict[int, Dict[int, int]]" + isinstance(k, int) + and isinstance(v, dict) + and all( + isinstance(sub_k, int) and isinstance(sub_v, int) + for sub_k, sub_v in v.items() + ) + for k, v in l.items() + ), "l is not Dict[int, Dict[int, int]]" # Strict type check for a generic dict assert isinstance(m, dict), "m is not dict" @@ -1050,46 +1091,79 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li DataclassTransformer()._make_dataclass_serializable(python_val=dc, python_type=DC) json_str = JSONEncoder(DC).encode(dc) upstream_output = Literal( - scalar=Scalar( - generic=_json_format.Parse( - json_str, - _struct.Struct()))) + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())) + ) downstream_input = TypeEngine.to_python_value( - FlyteContextManager.current_context(), upstream_output, DC) + FlyteContextManager.current_context(), upstream_output, DC + ) t_inner(downstream_input.inner_dc) - t_test_all_attributes(a=downstream_input.a, b=downstream_input.b, c=downstream_input.c, - d=downstream_input.d, e=downstream_input.e, f=downstream_input.f, - g=downstream_input.g, h=downstream_input.h, i=downstream_input.i, - j=downstream_input.j, k=downstream_input.k, l=downstream_input.l, - m=downstream_input.m, n=downstream_input.n, o=downstream_input.o, - enum_status=downstream_input.enum_status) - t_test_all_attributes(a=downstream_input.inner_dc.a, b=downstream_input.inner_dc.b, c=downstream_input.inner_dc.c, - d=downstream_input.inner_dc.d, e=downstream_input.inner_dc.e, f=downstream_input.inner_dc.f, - g=downstream_input.inner_dc.g, h=downstream_input.inner_dc.h, i=downstream_input.inner_dc.i, - j=downstream_input.inner_dc.j, k=downstream_input.inner_dc.k, l=downstream_input.inner_dc.l, - m=downstream_input.inner_dc.m, n=downstream_input.inner_dc.n, o=downstream_input.inner_dc.o, - enum_status=downstream_input.inner_dc.enum_status) + t_test_all_attributes( + a=downstream_input.a, + b=downstream_input.b, + c=downstream_input.c, + d=downstream_input.d, + e=downstream_input.e, + f=downstream_input.f, + g=downstream_input.g, + h=downstream_input.h, + i=downstream_input.i, + j=downstream_input.j, + k=downstream_input.k, + l=downstream_input.l, + m=downstream_input.m, + n=downstream_input.n, + o=downstream_input.o, + enum_status=downstream_input.enum_status, + ) + t_test_all_attributes( + a=downstream_input.inner_dc.a, + b=downstream_input.inner_dc.b, + c=downstream_input.inner_dc.c, + d=downstream_input.inner_dc.d, + e=downstream_input.inner_dc.e, + f=downstream_input.inner_dc.f, + g=downstream_input.inner_dc.g, + h=downstream_input.inner_dc.h, + i=downstream_input.inner_dc.i, + j=downstream_input.inner_dc.j, + k=downstream_input.inner_dc.k, + l=downstream_input.inner_dc.l, + m=downstream_input.inner_dc.m, + 
n=downstream_input.inner_dc.n, + o=downstream_input.inner_dc.o, + enum_status=downstream_input.inner_dc.enum_status, + ) def test_backward_compatible_with_untyped_dict_in_protobuf_struct(): # This is the old dataclass serialization behavior. # https://github.com/flyteorg/flytekit/blob/94786cfd4a5c2c3b23ac29dcd6f04d0553fa1beb/flytekit/core/type_engine.py#L1699-L1720 - dict_input = {"a": 1.0, "b": "str", - "c": False, "d": True, - "e": [1.0, 2.0, -1.0, 0.0], - "f": {"a": {"b": [1.0, -1.0]}}} + dict_input = { + "a": 1.0, + "b": "str", + "c": False, + "d": True, + "e": [1.0, 2.0, -1.0, 0.0], + "f": {"a": {"b": [1.0, -1.0]}}, + } - upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json.dumps(dict_input), _struct.Struct())), - metadata={"format": "json"}) + upstream_output = Literal( + scalar=Scalar( + generic=_json_format.Parse(json.dumps(dict_input), _struct.Struct()) + ), + metadata={"format": "json"}, + ) downstream_input = TypeEngine.to_python_value( - FlyteContextManager.current_context(), upstream_output, dict) + FlyteContextManager.current_context(), upstream_output, dict + ) assert dict_input == downstream_input def test_flyte_console_input_with_typed_dict_with_flyte_types_in_dataclass_in_protobuf_struct( - local_dummy_file, local_dummy_directory): + local_dummy_file, local_dummy_directory +): # TODO: We can add more nested cases for non-flyte types. """ Handles the case where Flyte Console provides input as a protobuf struct. @@ -1120,135 +1194,140 @@ def wf(dc: DC): dict_int_flyte_file = {"1": {"path": local_dummy_file}} json_str = json.dumps(dict_int_flyte_file) upstream_output = Literal( - scalar=Scalar( - generic=_json_format.Parse( - json_str, - _struct.Struct())), - metadata={ - "format": "json"}) + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), + metadata={"format": "json"}, + ) downstream_input = TypeEngine.to_python_value( - FlyteContextManager.current_context(), upstream_output, Dict[int, FlyteFile]) + FlyteContextManager.current_context(), upstream_output, Dict[int, FlyteFile] + ) assert downstream_input == {1: FlyteFile(local_dummy_file)} # FlyteConsole trims trailing ".0" when converting float-like strings dict_float_flyte_file = {"1": {"path": local_dummy_file}} json_str = json.dumps(dict_float_flyte_file) upstream_output = Literal( - scalar=Scalar( - generic=_json_format.Parse( - json_str, - _struct.Struct())), - metadata={ - "format": "json"}) + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), + metadata={"format": "json"}, + ) downstream_input = TypeEngine.to_python_value( - FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteFile]) + FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteFile] + ) assert downstream_input == {1.0: FlyteFile(local_dummy_file)} dict_float_flyte_file = {"1.0": {"path": local_dummy_file}} json_str = json.dumps(dict_float_flyte_file) upstream_output = Literal( - scalar=Scalar( - generic=_json_format.Parse( - json_str, - _struct.Struct())), - metadata={ - "format": "json"}) + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), + metadata={"format": "json"}, + ) downstream_input = TypeEngine.to_python_value( - FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteFile]) + FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteFile] + ) assert downstream_input == {1.0: FlyteFile(local_dummy_file)} dict_str_flyte_file = {"1": {"path": local_dummy_file}} json_str = 
json.dumps(dict_str_flyte_file) upstream_output = Literal( - scalar=Scalar( - generic=_json_format.Parse( - json_str, - _struct.Struct())), - metadata={ - "format": "json"}) + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), + metadata={"format": "json"}, + ) downstream_input = TypeEngine.to_python_value( - FlyteContextManager.current_context(), upstream_output, Dict[str, FlyteFile]) + FlyteContextManager.current_context(), upstream_output, Dict[str, FlyteFile] + ) assert downstream_input == {"1": FlyteFile(local_dummy_file)} dict_int_flyte_directory = {"1": {"path": local_dummy_directory}} json_str = json.dumps(dict_int_flyte_directory) upstream_output = Literal( - scalar=Scalar( - generic=_json_format.Parse( - json_str, - _struct.Struct())), - metadata={ - "format": "json"}) + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), + metadata={"format": "json"}, + ) downstream_input = TypeEngine.to_python_value( - FlyteContextManager.current_context(), upstream_output, Dict[int, FlyteDirectory]) + FlyteContextManager.current_context(), + upstream_output, + Dict[int, FlyteDirectory], + ) assert downstream_input == {1: FlyteDirectory(local_dummy_directory)} # FlyteConsole trims trailing ".0" when converting float-like strings dict_float_flyte_directory = {"1": {"path": local_dummy_directory}} json_str = json.dumps(dict_float_flyte_directory) upstream_output = Literal( - scalar=Scalar( - generic=_json_format.Parse( - json_str, - _struct.Struct())), - metadata={ - "format": "json"}) + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), + metadata={"format": "json"}, + ) downstream_input = TypeEngine.to_python_value( - FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteDirectory]) + FlyteContextManager.current_context(), + upstream_output, + Dict[float, FlyteDirectory], + ) assert downstream_input == {1.0: FlyteDirectory(local_dummy_directory)} dict_float_flyte_directory = {"1.0": {"path": local_dummy_directory}} json_str = json.dumps(dict_float_flyte_directory) upstream_output = Literal( - scalar=Scalar( - generic=_json_format.Parse( - json_str, - _struct.Struct())), - metadata={ - "format": "json"}) + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), + metadata={"format": "json"}, + ) downstream_input = TypeEngine.to_python_value( - FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteDirectory]) + FlyteContextManager.current_context(), + upstream_output, + Dict[float, FlyteDirectory], + ) assert downstream_input == {1.0: FlyteDirectory(local_dummy_directory)} dict_str_flyte_file = {"1": {"path": local_dummy_file}} json_str = json.dumps(dict_str_flyte_file) - upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), - metadata={"format": "json"}) - downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, - Dict[str, FlyteFile]) + upstream_output = Literal( + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), + metadata={"format": "json"}, + ) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, Dict[str, FlyteFile] + ) assert downstream_input == {"1": FlyteFile(local_dummy_file)} -def test_all_types_with_optional_in_dataclass_basemodel_wf( - local_dummy_file, local_dummy_directory): +def test_all_types_with_optional_in_dataclass_wf( + local_dummy_file, local_dummy_directory +): @dataclass class InnerDC: a: 
Optional[int] = -1 b: Optional[float] = 2.1 c: Optional[str] = "Hello, Flyte" d: Optional[bool] = False - e: Optional[List[int]] = field( - default_factory=lambda: [0, 1, 2, -1, -2]) + e: Optional[List[int]] = field(default_factory=lambda: [0, 1, 2, -1, -2]) f: Optional[List[FlyteFile]] = field( - default_factory=lambda: [FlyteFile(local_dummy_file)]) - g: Optional[List[List[int]]] = field( - default_factory=lambda: [[0], [1], [-1]]) + default_factory=lambda: [FlyteFile(local_dummy_file)] + ) + g: Optional[List[List[int]]] = field(default_factory=lambda: [[0], [1], [-1]]) h: Optional[List[Dict[int, bool]]] = field( - default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) i: Optional[Dict[int, bool]] = field( - default_factory=lambda: {0: False, 1: True, -1: False}) - j: Optional[Dict[int, FlyteFile]] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), - 1: FlyteFile(local_dummy_file), - -1: FlyteFile(local_dummy_file)}) + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Optional[Dict[int, FlyteFile]] = field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) k: Optional[Dict[int, List[int]]] = field( - default_factory=lambda: {0: [0, 1, -1]}) + default_factory=lambda: {0: [0, 1, -1]} + ) l: Optional[Dict[int, Dict[int, int]]] = field( - default_factory=lambda: {1: {-1: 0}}) + default_factory=lambda: {1: {-1: 0}} + ) m: Optional[dict] = field(default_factory=lambda: {"key": "value"}) n: Optional[FlyteFile] = field( - default_factory=lambda: FlyteFile(local_dummy_file)) + default_factory=lambda: FlyteFile(local_dummy_file) + ) o: Optional[FlyteDirectory] = field( - default_factory=lambda: FlyteDirectory(local_dummy_directory)) + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) enum_status: Optional[Status] = field(default=Status.PENDING) @dataclass @@ -1257,28 +1336,37 @@ class DC: b: Optional[float] = 2.1 c: Optional[str] = "Hello, Flyte" d: Optional[bool] = False - e: Optional[List[int]] = field( - default_factory=lambda: [0, 1, 2, -1, -2]) + e: Optional[List[int]] = field(default_factory=lambda: [0, 1, 2, -1, -2]) f: Optional[List[FlyteFile]] = field( - default_factory=lambda: [FlyteFile(local_dummy_file)]) - g: Optional[List[List[int]]] = field( - default_factory=lambda: [[0], [1], [-1]]) + default_factory=lambda: [FlyteFile(local_dummy_file)] + ) + g: Optional[List[List[int]]] = field(default_factory=lambda: [[0], [1], [-1]]) h: Optional[List[Dict[int, bool]]] = field( - default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) i: Optional[Dict[int, bool]] = field( - default_factory=lambda: {0: False, 1: True, -1: False}) - j: Optional[Dict[int, FlyteFile]] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), - 1: FlyteFile(local_dummy_file), - -1: FlyteFile(local_dummy_file)}) + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Optional[Dict[int, FlyteFile]] = field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) k: Optional[Dict[int, List[int]]] = field( - default_factory=lambda: {0: [0, 1, -1]}) + default_factory=lambda: {0: [0, 1, -1]} + ) l: Optional[Dict[int, Dict[int, int]]] = field( - default_factory=lambda: {1: {-1: 0}}) + default_factory=lambda: {1: {-1: 0}} + ) m: Optional[dict] = 
field(default_factory=lambda: {"key": "value"}) n: Optional[FlyteFile] = field( - default_factory=lambda: FlyteFile(local_dummy_file)) + default_factory=lambda: FlyteFile(local_dummy_file) + ) o: Optional[FlyteDirectory] = field( - default_factory=lambda: FlyteDirectory(local_dummy_directory)) + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) inner_dc: Optional[InnerDC] = field(default_factory=lambda: InnerDC()) enum_status: Optional[Status] = field(default=Status.PENDING) @@ -1311,15 +1399,24 @@ def t_inner(inner_dc: InnerDC): assert inner_dc.enum_status == Status.PENDING @task - def t_test_all_attributes(a: Optional[int], b: Optional[float], c: Optional[str], d: Optional[bool], - e: Optional[List[int]], f: Optional[List[FlyteFile]], - g: Optional[List[List[int]]], - h: Optional[List[Dict[int, bool]]], i: Optional[Dict[int, bool]], - j: Optional[Dict[int, FlyteFile]], - k: Optional[Dict[int, List[int]]], l: Optional[Dict[int, Dict[int, int]]], - m: Optional[dict], - n: Optional[FlyteFile], o: Optional[FlyteDirectory], - enum_status: Optional[Status]): + def t_test_all_attributes( + a: Optional[int], + b: Optional[float], + c: Optional[str], + d: Optional[bool], + e: Optional[List[int]], + f: Optional[List[FlyteFile]], + g: Optional[List[List[int]]], + h: Optional[List[Dict[int, bool]]], + i: Optional[Dict[int, bool]], + j: Optional[Dict[int, FlyteFile]], + k: Optional[Dict[int, List[int]]], + l: Optional[Dict[int, Dict[int, int]]], + m: Optional[dict], + n: Optional[FlyteFile], + o: Optional[FlyteDirectory], + enum_status: Optional[Status], + ): # Strict type checks for simple types assert isinstance(a, int), f"a is not int, it's {type(a)}" assert a == -1 @@ -1328,40 +1425,55 @@ def t_test_all_attributes(a: Optional[int], b: Optional[float], c: Optional[str] assert isinstance(d, bool), f"d is not bool, it's {type(d)}" # Strict type checks for List[int] - assert isinstance(e, list) and all(isinstance(i, int) - for i in e), "e is not List[int]" + assert isinstance(e, list) and all( + isinstance(i, int) for i in e + ), "e is not List[int]" # Strict type checks for List[FlyteFile] - assert isinstance(f, list) and all(isinstance(i, FlyteFile) - for i in f), "f is not List[FlyteFile]" + assert isinstance(f, list) and all( + isinstance(i, FlyteFile) for i in f + ), "f is not List[FlyteFile]" # Strict type checks for List[List[int]] assert isinstance(g, list) and all( - isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]" + isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g + ), "g is not List[List[int]]" # Strict type checks for List[Dict[int, bool]] assert isinstance(h, list) and all( - isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h + isinstance(i, dict) + and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) + for i in h ), "h is not List[Dict[int, bool]]" # Strict type checks for Dict[int, bool] assert isinstance(i, dict) and all( - isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]" + isinstance(k, int) and isinstance(v, bool) for k, v in i.items() + ), "i is not Dict[int, bool]" # Strict type checks for Dict[int, FlyteFile] assert isinstance(j, dict) and all( - isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]" + isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items() + ), "j is not Dict[int, FlyteFile]" # Strict type checks 
for Dict[int, List[int]] assert isinstance(k, dict) and all( - isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in - k.items()), "k is not Dict[int, List[int]]" + isinstance(k, int) + and isinstance(v, list) + and all(isinstance(i, int) for i in v) + for k, v in k.items() + ), "k is not Dict[int, List[int]]" # Strict type checks for Dict[int, Dict[int, int]] assert isinstance(l, dict) and all( - isinstance(k, int) and isinstance(v, dict) and all( - isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items()) - for k, v in l.items()), "l is not Dict[int, Dict[int, int]]" + isinstance(k, int) + and isinstance(v, dict) + and all( + isinstance(sub_k, int) and isinstance(sub_v, int) + for sub_k, sub_v in v.items() + ) + for k, v in l.items() + ), "l is not Dict[int, Dict[int, int]]" # Strict type check for a generic dict assert isinstance(m, dict), "m is not dict" @@ -1378,12 +1490,24 @@ def t_test_all_attributes(a: Optional[int], b: Optional[float], c: Optional[str] @workflow def wf(dc: DC): t_inner(dc.inner_dc) - t_test_all_attributes(a=dc.a, b=dc.b, c=dc.c, - d=dc.d, e=dc.e, f=dc.f, - g=dc.g, h=dc.h, i=dc.i, - j=dc.j, k=dc.k, l=dc.l, - m=dc.m, n=dc.n, o=dc.o, - enum_status=dc.enum_status) + t_test_all_attributes( + a=dc.a, + b=dc.b, + c=dc.c, + d=dc.d, + e=dc.e, + f=dc.f, + g=dc.g, + h=dc.h, + i=dc.i, + j=dc.j, + k=dc.k, + l=dc.l, + m=dc.m, + n=dc.n, + o=dc.o, + enum_status=dc.enum_status, + ) wf(dc=DC()) @@ -1433,26 +1557,47 @@ def t_inner(inner_dc: Optional[InnerDC]): return inner_dc @task - def t_test_all_attributes(a: Optional[int], b: Optional[float], c: Optional[str], d: Optional[bool], - e: Optional[List[int]], f: Optional[List[FlyteFile]], - g: Optional[List[List[int]]], - h: Optional[List[Dict[int, bool]]], i: Optional[Dict[int, bool]], - j: Optional[Dict[int, FlyteFile]], - k: Optional[Dict[int, List[int]]], l: Optional[Dict[int, Dict[int, int]]], - m: Optional[dict], - n: Optional[FlyteFile], o: Optional[FlyteDirectory], - enum_status: Optional[Status]): + def t_test_all_attributes( + a: Optional[int], + b: Optional[float], + c: Optional[str], + d: Optional[bool], + e: Optional[List[int]], + f: Optional[List[FlyteFile]], + g: Optional[List[List[int]]], + h: Optional[List[Dict[int, bool]]], + i: Optional[Dict[int, bool]], + j: Optional[Dict[int, FlyteFile]], + k: Optional[Dict[int, List[int]]], + l: Optional[Dict[int, Dict[int, int]]], + m: Optional[dict], + n: Optional[FlyteFile], + o: Optional[FlyteDirectory], + enum_status: Optional[Status], + ): return @workflow def wf(dc: DC): t_inner(dc.inner_dc) - t_test_all_attributes(a=dc.a, b=dc.b, c=dc.c, - d=dc.d, e=dc.e, f=dc.f, - g=dc.g, h=dc.h, i=dc.i, - j=dc.j, k=dc.k, l=dc.l, - m=dc.m, n=dc.n, o=dc.o, - enum_status=dc.enum_status) + t_test_all_attributes( + a=dc.a, + b=dc.b, + c=dc.c, + d=dc.d, + e=dc.e, + f=dc.f, + g=dc.g, + h=dc.h, + i=dc.i, + j=dc.j, + k=dc.k, + l=dc.l, + m=dc.m, + n=dc.n, + o=dc.o, + enum_status=dc.enum_status, + ) wf(dc=DC()) @@ -1464,8 +1609,9 @@ class DC: b: Union[int, bool, str, float] @task - def add(a: Union[int, bool, str, float], b: Union[int, - bool, str, float]) -> Union[int, bool, str, float]: + def add( + a: Union[int, bool, str, float], b: Union[int, bool, str, float] + ) -> Union[int, bool, str, float]: return a + b # type: ignore @workflow diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index f5c3fb5a3c..6664deaaeb 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ 
b/tests/flytekit/unit/core/test_type_hints.py @@ -20,21 +20,24 @@ import flytekit import flytekit.configuration from flytekit import Secret, SQLTask, dynamic, kwtypes, map_task -from flytekit.configuration import FastSerializationSettings, Image, ImageConfig +from flytekit.configuration import (FastSerializationSettings, Image, + ImageConfig) from flytekit.core import context_manager, launch_plan, promise from flytekit.core.condition import conditional +from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import ExecutionState from flytekit.core.data_persistence import FileAccessProvider, flyte_tmp_dir from flytekit.core.hash import HashMethod from flytekit.core.node import Node from flytekit.core.promise import NodeOutput, Promise, VoidPromise from flytekit.core.resources import Resources -from flytekit.types.pickle.pickle import FlytePickleTransformer from flytekit.core.task import TaskMetadata, task from flytekit.core.testing import patch, task_mock -from flytekit.core.type_engine import RestrictedTypeError, SimpleTransformer, TypeEngine, TypeTransformerFailedError +from flytekit.core.type_engine import (RestrictedTypeError, SimpleTransformer, + TypeEngine, TypeTransformerFailedError) from flytekit.core.workflow import workflow -from flytekit.exceptions.user import FlyteValidationException, FlyteFailureNodeInputMismatchException +from flytekit.exceptions.user import (FlyteFailureNodeInputMismatchException, + FlyteValidationException) from flytekit.models import literals as _literal_models from flytekit.models.core import types as _core_types from flytekit.models.interface import Parameter @@ -45,6 +48,7 @@ from flytekit.types.directory import FlyteDirectory, TensorboardLogs from flytekit.types.error import FlyteError from flytekit.types.file import FlyteFile +from flytekit.types.pickle.pickle import FlytePickleTransformer from flytekit.types.schema import FlyteSchema, SchemaOpenMode from flytekit.types.structured.structured_dataset import StructuredDataset @@ -83,7 +87,11 @@ def test_forwardref_namedtuple_output(): # This test case tests typing.NamedTuple outputs for cases where eg. 
# from __future__ import annotations is enabled, such that all type hints become ForwardRef @task - def my_task(a: int) -> typing.NamedTuple("OutputsBC", b=typing.ForwardRef("int"), c=typing.ForwardRef("str")): + def my_task( + a: int, + ) -> typing.NamedTuple( + "OutputsBC", b=typing.ForwardRef("int"), c=typing.ForwardRef("str") + ): ctx = flytekit.current_context() assert str(ctx.execution_id) == "ex:local:local:local" return a + 2, "hello world" @@ -109,7 +117,9 @@ def my_task(a: int): assert my_task(a=3) is None ctx = context_manager.FlyteContextManager.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_new_compilation_state()) as ctx: + with context_manager.FlyteContextManager.with_context( + ctx.with_new_compilation_state() + ) as ctx: outputs = my_task(a=3) assert isinstance(outputs, VoidPromise) @@ -155,7 +165,9 @@ def my_task() -> str: assert my_task() == "Hello world" ctx = context_manager.FlyteContextManager.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_new_compilation_state()) as ctx: + with context_manager.FlyteContextManager.with_context( + ctx.with_new_compilation_state() + ) as ctx: outputs = my_task() assert ctx.compilation_state is not None nodes = ctx.compilation_state.nodes @@ -181,23 +193,31 @@ def test_engine_file_output(): dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) - fs = FileAccessProvider(local_sandbox_dir="/tmp/flytetesting", raw_output_prefix="/tmp/flyteraw") + fs = FileAccessProvider( + local_sandbox_dir="/tmp/flytetesting", raw_output_prefix="/tmp/flyteraw" + ) ctx = context_manager.FlyteContextManager.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: + with context_manager.FlyteContextManager.with_context( + ctx.with_file_access(fs) + ) as ctx: # Write some text to a file not in that directory above test_file_location = "/tmp/sample.txt" with open(test_file_location, "w") as fh: fh.write("Hello World\n") - lit = TypeEngine.to_literal(ctx, test_file_location, os.PathLike, LiteralType(blob=basic_blob_type)) + lit = TypeEngine.to_literal( + ctx, test_file_location, os.PathLike, LiteralType(blob=basic_blob_type) + ) # Since we're using local as remote, we should be able to just read the file from the 'remote' location. with open(lit.scalar.blob.uri, "r") as fh: assert fh.readline() == "Hello World\n" # We should also be able to turn the thing back into regular python native thing. 
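For contrast with the msgpack literals above, the backward-compatibility tests earlier in this patch all hand-build the legacy literal shape: a protobuf Struct carrying JSON, tagged with {"format": "json"} metadata. A condensed sketch of that shape (illustrative only; it mirrors the untyped-dict test above and assumes a flytekit dev environment):

    # Condensed sketch of the legacy protobuf-Struct literal that the
    # backward-compatibility tests above construct by hand.
    import json

    from google.protobuf import json_format as _json_format
    from google.protobuf import struct_pb2 as _struct

    from flytekit.core.context_manager import FlyteContextManager
    from flytekit.core.type_engine import TypeEngine
    from flytekit.models.literals import Literal, Scalar

    dict_input = {"a": 1.0, "b": "str", "c": False}

    # Older flytekit (and Flyte Console) ship dicts as a generic Struct
    # literal with {"format": "json"} metadata instead of msgpack bytes.
    upstream_output = Literal(
        scalar=Scalar(
            generic=_json_format.Parse(json.dumps(dict_input), _struct.Struct())
        ),
        metadata={"format": "json"},
    )

    # The new decoders still need to accept this shape unchanged.
    downstream_input = TypeEngine.to_python_value(
        FlyteContextManager.current_context(), upstream_output, dict
    )
    assert downstream_input == dict_input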
- redownloaded_local_file_location = TypeEngine.to_python_value(ctx, lit, os.PathLike) + redownloaded_local_file_location = TypeEngine.to_python_value( + ctx, lit, os.PathLike + ) with open(redownloaded_local_file_location, "r") as fh: assert fh.readline() == "Hello World\n" @@ -365,7 +385,9 @@ def mimic_sub_wf(a: int) -> (str, str): with context_manager.FlyteContextManager.with_context( ctx.with_execution_state( - ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) + ctx.new_execution_state().with_params( + mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION + ) ) ) as ctx: a, b = mimic_sub_wf(a=3) @@ -400,7 +422,11 @@ def my_wf() -> FlyteSchema: with task_mock(sql) as mock: mock.return_value = pd.DataFrame(data={"x": [1, 2], "y": ["3", "4"]}) - assert (my_wf().open().all() == pd.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all() + assert ( + (my_wf().open().all() == pd.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})) + .all() + .all() + ) assert context_manager.FlyteContextManager.size() == 1 @@ -428,7 +454,11 @@ def my_wf() -> FlyteSchema: @patch(sql) def test_user_demo_test(mock_sql): mock_sql.return_value = pd.DataFrame(data={"x": [1, 2], "y": ["3", "4"]}) - assert (my_wf().open().all() == pd.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all() + assert ( + (my_wf().open().all() == pd.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})) + .all() + .all() + ) # Have to call because tests inside tests don't run test_user_demo_test() @@ -656,9 +686,15 @@ def my_wf(a: int, b: str) -> (str, typing.List[str]): ) ) ) as ctx: - new_exc_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION) - with context_manager.FlyteContextManager.with_context(ctx.with_execution_state(new_exc_state)) as ctx: - dynamic_job_spec = my_subwf.compile_into_workflow(ctx, my_subwf._task_function, a=5) + new_exc_state = ctx.execution_state.with_params( + mode=ExecutionState.Mode.TASK_EXECUTION + ) + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state(new_exc_state) + ) as ctx: + dynamic_job_spec = my_subwf.compile_into_workflow( + ctx, my_subwf._task_function, a=5 + ) assert len(dynamic_job_spec._nodes) == 5 assert len(dynamic_job_spec.tasks) == 1 @@ -735,7 +771,9 @@ def lister() -> typing.List[str]: return s assert len(lister.interface.outputs) == 1 - binding_data = lister.output_bindings[0].binding # the property should be named binding_data + binding_data = lister.output_bindings[ + 0 + ].binding # the property should be named binding_data assert binding_data.collection is not None assert len(binding_data.collection.bindings) == 10 @@ -809,7 +847,13 @@ def my_wf(a: int, b: str) -> (int, str): .else_() .fail("Unable to choose branch") ) - f = conditional("test2").if_(d == "hello ").then(t2(a="It is hello")).else_().then(t2(a="Not Hello!")) + f = ( + conditional("test2") + .if_(d == "hello ") + .then(t2(a="It is hello")) + .else_() + .then(t2(a="Not Hello!")) + ) return x, f x = my_wf(a=5, b="hello ") @@ -832,7 +876,13 @@ def t2(a: str) -> str: @workflow def my_wf(a: int, b: str) -> str: new_a = t1(a=a) - return conditional("test1").if_(new_a != 5).then(t2(a=b)).else_().fail("Unable to choose branch") + return ( + conditional("test1") + .if_(new_a != 5) + .then(t2(a=b)) + .else_() + .fail("Unable to choose branch") + ) with pytest.raises(ValueError): my_wf(a=4, b="hello") @@ -856,8 +906,16 @@ def t2(a: str) -> str: @workflow def my_wf(a: int, b: str) -> (int, str): x, y = t1(a=a) - d = conditional("test1").if_(x == 
4).then(t2(a=b)).elif_(x >= 5).then(t2(a=y)) - conditional("test2").if_(x == 4).then(t2(a=b)).elif_(x >= 5).then(t2(a=y)).else_().fail("blah") + d = ( + conditional("test1") + .if_(x == 4) + .then(t2(a=b)) + .elif_(x >= 5) + .then(t2(a=y)) + ) + conditional("test2").if_(x == 4).then(t2(a=b)).elif_(x >= 5).then( + t2(a=y) + ).else_().fail("blah") return x, d my_wf() @@ -951,7 +1009,9 @@ def my_subwf(a: int) -> typing.Tuple[str, str]: return y, v lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf) - lp_with_defaults = launch_plan.LaunchPlan.create("serialize_test2", my_subwf, default_inputs={"a": 3}) + lp_with_defaults = launch_plan.LaunchPlan.create( + "serialize_test2", my_subwf, default_inputs={"a": 3} + ) serialization_settings = flytekit.configuration.SerializationSettings( project="proj", @@ -968,7 +1028,9 @@ def my_subwf(a: int) -> typing.Tuple[str, str]: lp_model = get_serializable(OrderedDict(), serialization_settings, lp_with_defaults) assert len(lp_model.spec.default_inputs.parameters) == 1 assert not lp_model.spec.default_inputs.parameters["a"].required - assert lp_model.spec.default_inputs.parameters["a"].default == _literal_models.Literal( + assert lp_model.spec.default_inputs.parameters[ + "a" + ].default == _literal_models.Literal( scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3)) ) assert len(lp_model.spec.fixed_inputs.literals) == 0 @@ -1014,17 +1076,23 @@ def wf() -> FlyteSchema[kwtypes(x=int)]: w = t1() assert w is not None df = w.open(override_mode=SchemaOpenMode.READ).all() - result_df = df.reset_index(drop=True) == pd.DataFrame(data={"x": [1, 2], "y": ["3", "4"]}).reset_index(drop=True) + result_df = df.reset_index(drop=True) == pd.DataFrame( + data={"x": [1, 2], "y": ["3", "4"]} + ).reset_index(drop=True) assert result_df.all().all() df = t2(s=w.as_readonly()) df = df.open(override_mode=SchemaOpenMode.READ).all() - result_df = df.reset_index(drop=True) == pd.DataFrame(data={"x": [1, 2]}).reset_index(drop=True) + result_df = df.reset_index(drop=True) == pd.DataFrame( + data={"x": [1, 2]} + ).reset_index(drop=True) assert result_df.all().all() x = wf() df = x.open().all() - result_df = df.reset_index(drop=True) == pd.DataFrame(data={"x": [1, 2]}).reset_index(drop=True) + result_df = df.reset_index(drop=True) == pd.DataFrame( + data={"x": [1, 2]} + ).reset_index(drop=True) assert result_df.all().all() @@ -1281,7 +1349,11 @@ def my_wf(a: int) -> str: context_manager.FlyteContextManager.current_context().with_new_compilation_state() ): task_spec = get_serializable(OrderedDict(), serialization_settings, t1) - assert task_spec.template.container.env == {"FOO": "foofoo", "BAR": "bar", "BAZ": "baz"} + assert task_spec.template.container.env == { + "FOO": "foofoo", + "BAR": "bar", + "BAZ": "baz", + } def test_resources(): @@ -1315,13 +1387,19 @@ def my_wf(a: int) -> str: ): task_spec = get_serializable(OrderedDict(), serialization_settings, t1) assert task_spec.template.container.resources.requests == [ - _resource_models.ResourceEntry(_resource_models.ResourceName.EPHEMERAL_STORAGE, "500Mi"), + _resource_models.ResourceEntry( + _resource_models.ResourceName.EPHEMERAL_STORAGE, "500Mi" + ), _resource_models.ResourceEntry(_resource_models.ResourceName.CPU, "1"), ] assert task_spec.template.container.resources.limits == [ - _resource_models.ResourceEntry(_resource_models.ResourceName.EPHEMERAL_STORAGE, "501Mi"), + _resource_models.ResourceEntry( + _resource_models.ResourceName.EPHEMERAL_STORAGE, "501Mi" + ), 
_resource_models.ResourceEntry(_resource_models.ResourceName.CPU, "2"), - _resource_models.ResourceEntry(_resource_models.ResourceName.MEMORY, "400M"), + _resource_models.ResourceEntry( + _resource_models.ResourceName.MEMORY, "400M" + ), ] task_spec2 = get_serializable(OrderedDict(), serialization_settings, t2) @@ -1333,8 +1411,7 @@ def my_wf(a: int) -> str: def test_wf_explicitly_returning_empty_task(): @task - def t1(): - ... + def t1(): ... @workflow def my_subwf(): @@ -1381,7 +1458,9 @@ def foo() -> str: def foo2() -> str: return flytekit.current_context().secrets.get("group", "key") - os.environ[flytekit.current_context().secrets.get_secrets_env_var("group", "key")] = "super-secret-value2" + os.environ[ + flytekit.current_context().secrets.get_secrets_env_var("group", "key") + ] = "super-secret-value2" assert foo2() == "super-secret-value2" with pytest.raises(AssertionError): @@ -1428,11 +1507,19 @@ def my_subwf(a: int) -> typing.List[str]: nested_my_subwf = my_wf.get_all_tasks()[0] - ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings(settings) + ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( + settings + ) with context_manager.FlyteContextManager.with_context(ctx) as ctx: - es = ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION) - with context_manager.FlyteContextManager.with_context(ctx.with_execution_state(es)) as ctx: - dynamic_job_spec = nested_my_subwf.compile_into_workflow(ctx, nested_my_subwf._task_function, a=5) + es = ctx.new_execution_state().with_params( + mode=ExecutionState.Mode.TASK_EXECUTION + ) + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state(es) + ) as ctx: + dynamic_job_spec = nested_my_subwf.compile_into_workflow( + ctx, nested_my_subwf._task_function, a=5 + ) assert len(dynamic_job_spec._nodes) == 5 @@ -1526,7 +1613,9 @@ def t2(a: dict) -> str: output_lm = t2.dispatch_execute(ctx, lm) str_value = output_lm.literals["o0"].scalar.primitive.string_value - assert str_value == "K: k2 V: 2, K: k1 V: v1" or str_value == "K: k1 V: v1, K: k2 V: 2" + assert ( + str_value == "K: k2 V: 2, K: k1 V: v1" or str_value == "K: k1 V: v1, K: k2 V: 2" + ) def test_guess_dict2(): @@ -1538,7 +1627,10 @@ def t2(a: typing.List[dict]) -> str: return " ".join(strs) task_spec = get_serializable(OrderedDict(), serialization_settings, t2) - assert task_spec.template.interface.inputs["a"].type.collection_type.simple == SimpleType.STRUCT + assert ( + task_spec.template.interface.inputs["a"].type.collection_type.simple + == SimpleType.STRUCT + ) pt_map = TypeEngine.guess_python_types(task_spec.template.interface.inputs) assert pt_map == {"a": typing.List[dict]} @@ -1556,11 +1648,14 @@ def t2() -> dict: ctx = context_manager.FlyteContextManager.current_context() output_lm = t2.dispatch_execute(ctx, _literal_models.LiteralMap(literals={})) msgpack_bytes = msgpack.dumps({"k1": "v1", "k2": 3, 4: {"one": [1, "two", [3]]}}) - binary_idl_obj = Binary(value=msgpack_bytes, tag="msgpack") + binary_idl_obj = Binary(value=msgpack_bytes, tag=MESSAGEPACK) assert output_lm.literals["o0"].scalar.binary == binary_idl_obj -@pytest.mark.skipif(sys.version_info < (3, 9), reason="Use of dict hints is only supported in Python 3.9+") +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="Use of dict hints is only supported in Python 3.9+", +) def test_guess_dict4(): @dataclass class Foo(DataClassJsonMixin): @@ -1596,7 +1691,13 @@ def t2() -> Bar: assert 
dataclasses.is_dataclass(pt_map["o0"])

     output_lm = t2.dispatch_execute(ctx, _literal_models.LiteralMap(literals={}))
-    msgpack_bytes = msgpack.dumps({"x": 1, "y": {"hello": "world"}, "z": {"x": 1, "y": "foo", "z": {"hello": "world"}}})
+    msgpack_bytes = msgpack.dumps(
+        {
+            "x": 1,
+            "y": {"hello": "world"},
+            "z": {"x": 1, "y": "foo", "z": {"hello": "world"}},
+        }
+    )
     assert output_lm.literals["o0"].scalar.binary.value == msgpack_bytes
@@ -1625,7 +1726,7 @@ def foo3(a: typing.Dict) -> typing.Dict:
         return a

     @task
-    def foo4(input: DC1=DC1(1, 'a')) -> DC2:
+    def foo4(input: DC1 = DC1(1, "a")) -> DC2:
         return input  # type: ignore

     with pytest.raises(
@@ -1676,7 +1777,9 @@ def fail(a: int, b: str) -> typing.Tuple[int, str]:
         return a + 1, b

     @task
-    def failure_handler(a: int, b: str, err: typing.Optional[FlyteError]) -> typing.Tuple[int, str]:
+    def failure_handler(
+        a: int, b: str, err: typing.Optional[FlyteError]
+    ) -> typing.Tuple[int, str]:
         print(f"Handling error: {err}")
         return a + 1, b
@@ -1728,12 +1831,12 @@ def wf1(a: int = 3, b: str = "hello"):
     with pytest.raises(
         FlyteFailureNodeInputMismatchException,
         match="Mismatched Inputs Detected\n"
-        f"The failure node `{exec_prefix}tests.flytekit.unit.core.test_type_hints.t1` has "
-        "inputs that do not align with those expected by the workflow `tests.flytekit.unit.core.test_type_hints.wf1`.\n"
-        "Failure Node's Inputs: {'a': <class 'int'>}\n"
-        "Workflow's Inputs: {'a': <class 'int'>, 'b': <class 'str'>}\n"
-        "Action Required:\n"
-        "Please ensure that all input arguments in the failure node are provided and match the expected arguments specified in the workflow.",
+            f"The failure node `{exec_prefix}tests.flytekit.unit.core.test_type_hints.t1` has "
+            "inputs that do not align with those expected by the workflow `tests.flytekit.unit.core.test_type_hints.wf1`.\n"
+            "Failure Node's Inputs: {'a': <class 'int'>}\n"
+            "Workflow's Inputs: {'a': <class 'int'>, 'b': <class 'str'>}\n"
+            "Action Required:\n"
+            "Please ensure that all input arguments in the failure node are provided and match the expected arguments specified in the workflow.",
     ):
         wf1()
@@ -1754,7 +1857,9 @@ def test_union_type(exec_prefix):
     from flytekit.types.schema import FlyteSchema

-    ut = typing.Union[int, str, float, FlyteFile, FlyteSchema, typing.List[int], typing.Dict[str, int]]
+    ut = typing.Union[
+        int, str, float, FlyteFile, FlyteSchema, typing.List[int], typing.Dict[str, int]
+    ]

     @task
     def t1(a: ut) -> ut:
@@ -1787,7 +1892,7 @@ def wf2(a: typing.Union[int, str]) -> typing.Union[int, str]:
         TypeError,
         match=re.escape(
             f"Error encountered while converting inputs of '{exec_prefix}tests.flytekit.unit.core.test_type_hints.t2':\n"
-            r'  Cannot convert from Flyte Serialized object (Literal):'
+            r"  Cannot convert from Flyte Serialized object (Literal):"
         ),
     ):
         assert wf2(a="2") == "2"
@@ -1848,7 +1953,9 @@ def __eq__(self, other):
             MyInt,
             LiteralType(simple=SimpleType.INTEGER),
             lambda x: _literal_models.Literal(
-                scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=x.val))
+                scalar=_literal_models.Scalar(
+                    primitive=_literal_models.Primitive(integer=x.val)
+                )
             ),
             lambda x: MyInt(x.scalar.primitive.integer),
         )
@@ -1865,7 +1972,8 @@ def wf(a: int) -> int:
         return t1(a=a)

     with pytest.raises(
-        TypeError, match="Ambiguous choice of variant for union type. Both int and MyInt transformers match"
+        TypeError,
+        match="Ambiguous choice of variant for union type. 
Both int and MyInt transformers match", ): assert wf(a=10) == 10 @@ -1888,7 +1996,9 @@ def __eq__(self, other): MyInt, LiteralType(simple=SimpleType.INTEGER), lambda x: _literal_models.Literal( - scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=x.val)) + scalar=_literal_models.Scalar( + primitive=_literal_models.Primitive(integer=x.val) + ) ), lambda x: MyInt(x.scalar.primitive.integer), ) @@ -1931,7 +2041,9 @@ def plus_two( _literal_models.LiteralMap( literals={ "a": _literal_models.Literal( - scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3)) + scalar=_literal_models.Scalar( + primitive=_literal_models.Primitive(integer=3) + ) ) } ), @@ -1988,7 +2100,9 @@ def t2() -> pd.DataFrame: # Auxiliary task used to sum up the dataframes. It demonstrates that the use of `Annotated` does not # have any impact in the definition and execution of cached or uncached downstream tasks @task - def sum_dataframes(df0: pd.DataFrame, df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame: + def sum_dataframes( + df0: pd.DataFrame, df1: pd.DataFrame, df2: pd.DataFrame + ) -> pd.DataFrame: return df0 + df1 + df2 @workflow @@ -2000,7 +2114,12 @@ def wf() -> pd.DataFrame: df = wf() - expected_df = pd.DataFrame(data={"col1": [1 + 10 + 100, 2 + 20 + 200], "col2": [3 + 30 + 300, 4 + 40 + 400]}) + expected_df = pd.DataFrame( + data={ + "col1": [1 + 10 + 100, 2 + 20 + 200], + "col2": [3 + 30 + 300, 4 + 40 + 400], + } + ) assert expected_df.equals(df) @@ -2015,7 +2134,10 @@ def hash_pandas_dataframe(df: pd.DataFrame) -> str: def produce_list_of_annotated_dataframes() -> ( typing.List[Annotated[pd.DataFrame, HashMethod(hash_pandas_dataframe)]] ): - return [pd.DataFrame({"column_1": [1, 2, 3]}), pd.DataFrame({"column_1": [4, 5, 6]})] + return [ + pd.DataFrame({"column_1": [1, 2, 3]}), + pd.DataFrame({"column_1": [4, 5, 6]}), + ] @task def sum_list_of_pandas_dataframes(lst: typing.List[pd.DataFrame]) -> pd.DataFrame: diff --git a/tests/flytekit/unit/extras/pydantic_transformer/test_generice_idl_dataclass_in_pydantic_basemodel.py b/tests/flytekit/unit/extras/pydantic_transformer/test_generice_idl_dataclass_in_pydantic_basemodel.py new file mode 100644 index 0000000000..1eff5398ae --- /dev/null +++ b/tests/flytekit/unit/extras/pydantic_transformer/test_generice_idl_dataclass_in_pydantic_basemodel.py @@ -0,0 +1,117 @@ +from pydantic import BaseModel, Field +import pytest +import os +import copy + +from flytekit import task, workflow + +@pytest.fixture(autouse=True) +def prepare_env_variable(): + try: + origin_env = copy.deepcopy(os.environ.copy()) + os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "True" + yield + finally: + os.environ = origin_env + + +def test_dataclasss_in_pydantic_basemodel(): + from dataclasses import dataclass + + @dataclass + class InnerBM: + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + + class BM(BaseModel): + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + inner_bm: InnerBM = Field(default_factory=lambda: InnerBM()) + + @task + def t_bm(bm: BM): + assert isinstance(bm, BM) + assert isinstance(bm.inner_bm, InnerBM) + + @task + def t_inner(inner_bm: InnerBM): + assert isinstance(inner_bm, InnerBM) + + @task + def t_test_primitive_attributes(a: int, b: float, c: str, d: bool): + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert b == 3.14 + assert isinstance(c, str), f"c is not str, it's {type(c)}" + 
assert c == "Hello, Flyte" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + assert d is False + print("All primitive attributes passed strict type checks.") + + @workflow + def wf(bm: BM): + t_bm(bm=bm) + t_inner(inner_bm=bm.inner_bm) + t_test_primitive_attributes(a=bm.a, b=bm.b, c=bm.c, d=bm.d) + t_test_primitive_attributes( + a=bm.inner_bm.a, b=bm.inner_bm.b, c=bm.inner_bm.c, d=bm.inner_bm.d + ) + + bm = BM() + wf(bm=bm) + + +def test_pydantic_dataclasss_in_pydantic_basemodel(): + from pydantic.dataclasses import dataclass + + @dataclass + class InnerBM: + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + + class BM(BaseModel): + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + inner_bm: InnerBM = Field(default_factory=lambda: InnerBM()) + + @task + def t_bm(bm: BM): + assert isinstance(bm, BM) + assert isinstance(bm.inner_bm, InnerBM) + + @task + def t_inner(inner_bm: InnerBM): + assert isinstance(inner_bm, InnerBM) + + @task + def t_test_primitive_attributes(a: int, b: float, c: str, d: bool): + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert b == 3.14 + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert c == "Hello, Flyte" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + assert d is False + print("All primitive attributes passed strict type checks.") + + @workflow + def wf(bm: BM): + t_bm(bm=bm) + t_inner(inner_bm=bm.inner_bm) + t_test_primitive_attributes(a=bm.a, b=bm.b, c=bm.c, d=bm.d) + t_test_primitive_attributes( + a=bm.inner_bm.a, b=bm.inner_bm.b, c=bm.inner_bm.c, d=bm.inner_bm.d + ) + + bm = BM() + wf(bm=bm) diff --git a/tests/flytekit/unit/extras/pydantic_transformer/test_generice_idl_pydantic_basemodel_in_dataclass.py b/tests/flytekit/unit/extras/pydantic_transformer/test_generice_idl_pydantic_basemodel_in_dataclass.py new file mode 100644 index 0000000000..9035b299fc --- /dev/null +++ b/tests/flytekit/unit/extras/pydantic_transformer/test_generice_idl_pydantic_basemodel_in_dataclass.py @@ -0,0 +1,145 @@ +import copy +import os + +import pytest +from pydantic import BaseModel + +from flytekit import task, workflow + +@pytest.fixture(autouse=True) +def prepare_env_variable(): + try: + origin_env = copy.deepcopy(os.environ.copy()) + os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "True" + yield + finally: + os.environ = origin_env + + +def test_pydantic_basemodel_in_dataclass(): + from dataclasses import dataclass, field + + # Define InnerBM using Pydantic BaseModel + class InnerBM(BaseModel): + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + + # Define the dataclass DC + @dataclass + class DC: + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + inner_bm: InnerBM = field(default_factory=lambda: InnerBM()) + + # Task to check DC instance + @task + def t_dc(dc: DC): + assert isinstance(dc, DC) + assert isinstance(dc.inner_bm, InnerBM) + + # Task to check InnerBM instance + @task + def t_inner(inner_bm: InnerBM): + assert isinstance(inner_bm, InnerBM) + + # Task to check primitive attributes + @task + def t_test_primitive_attributes(a: int, b: float, c: str, d: bool): + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert b == 3.14 + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert c == "Hello, Flyte" + assert 
isinstance(d, bool), f"d is not bool, it's {type(d)}" + assert d is False + print("All primitive attributes passed strict type checks.") + + # Define the workflow + @workflow + def wf(dc: DC): + t_dc(dc=dc) + t_inner(inner_bm=dc.inner_bm) + t_test_primitive_attributes(a=dc.a, b=dc.b, c=dc.c, d=dc.d) + t_test_primitive_attributes( + a=dc.inner_bm.a, b=dc.inner_bm.b, c=dc.inner_bm.c, d=dc.inner_bm.d + ) + + # Create an instance of DC and run the workflow + dc_instance = DC() + with pytest.raises(Exception) as excinfo: + wf(dc=dc_instance) + + # Assert that the error message contains "UnserializableField" + assert "is not serializable" in str( + excinfo.value + ), f"Unexpected error: {excinfo.value}" + + +def test_pydantic_basemodel_in_pydantic_dataclass(): + from pydantic import Field + from pydantic.dataclasses import dataclass + + # Define InnerBM using Pydantic BaseModel + class InnerBM(BaseModel): + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + + # Define the Pydantic dataclass DC + @dataclass + class DC: + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + inner_bm: InnerBM = Field(default_factory=lambda: InnerBM()) + + # Task to check DC instance + @task + def t_dc(dc: DC): + assert isinstance(dc, DC) + assert isinstance(dc.inner_bm, InnerBM) + + # Task to check InnerBM instance + @task + def t_inner(inner_bm: InnerBM): + assert isinstance(inner_bm, InnerBM) + + # Task to check primitive attributes + @task + def t_test_primitive_attributes(a: int, b: float, c: str, d: bool): + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert b == 3.14 + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert c == "Hello, Flyte" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + assert d is False + print("All primitive attributes passed strict type checks.") + + # Define the workflow + @workflow + def wf(dc: DC): + t_dc(dc=dc) + t_inner(inner_bm=dc.inner_bm) + t_test_primitive_attributes(a=dc.a, b=dc.b, c=dc.c, d=dc.d) + t_test_primitive_attributes( + a=dc.inner_bm.a, b=dc.inner_bm.b, c=dc.inner_bm.c, d=dc.inner_bm.d + ) + + # Create an instance of DC and run the workflow + dc_instance = DC() + with pytest.raises(Exception) as excinfo: + wf(dc=dc_instance) + + # Assert that the error message contains "UnserializableField" + assert "is not serializable" in str( + excinfo.value + ), f"Unexpected error: {excinfo.value}" diff --git a/tests/flytekit/unit/extras/pydantic_transformer/test_generice_idl_pydantic_basemodel_transformer.py b/tests/flytekit/unit/extras/pydantic_transformer/test_generice_idl_pydantic_basemodel_transformer.py new file mode 100644 index 0000000000..8ef7acfbec --- /dev/null +++ b/tests/flytekit/unit/extras/pydantic_transformer/test_generice_idl_pydantic_basemodel_transformer.py @@ -0,0 +1,700 @@ +import copy +import os +import tempfile +from enum import Enum +from typing import Dict, List, Optional +from unittest.mock import patch + +import pytest +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct +from pydantic import BaseModel, Field + +from flytekit import task, workflow +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +from flytekit.models.literals import Literal, Scalar +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile +from 
flytekit.types.schema import FlyteSchema +from flytekit.types.structured import StructuredDataset + + +@pytest.fixture(autouse=True) +def prepare_env_variable(): + try: + origin_env = copy.deepcopy(os.environ.copy()) + os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "True" + yield + finally: + os.environ = origin_env + + +class Status(Enum): + PENDING = "pending" + APPROVED = "approved" + REJECTED = "rejected" + + +@pytest.fixture +def local_dummy_file(): + fd, path = tempfile.mkstemp() + try: + with os.fdopen(fd, "w") as tmp: + tmp.write("Hello FlyteFile") + yield path + finally: + os.remove(path) + + +@pytest.fixture +def local_dummy_directory(): + temp_dir = tempfile.TemporaryDirectory() + try: + with open(os.path.join(temp_dir.name, "file"), "w") as tmp: + tmp.write("Hello FlyteDirectory") + yield temp_dir.name + finally: + temp_dir.cleanup() + + +def test_flytetypes_in_pydantic_basemodel_wf(local_dummy_file, local_dummy_directory): + class InnerBM(BaseModel): + flytefile: FlyteFile = Field( + default_factory=lambda: FlyteFile(local_dummy_file) + ) + flytedir: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + + class BM(BaseModel): + flytefile: FlyteFile = Field( + default_factory=lambda: FlyteFile(local_dummy_file) + ) + flytedir: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + inner_bm: InnerBM = Field(default_factory=lambda: InnerBM()) + + @task + def t1(path: FlyteFile) -> FlyteFile: + return path + + @task + def t2(path: FlyteDirectory) -> FlyteDirectory: + return path + + @workflow + def wf(bm: BM) -> (FlyteFile, FlyteFile, FlyteDirectory, FlyteDirectory): + f1 = t1(path=bm.flytefile) + f2 = t1(path=bm.inner_bm.flytefile) + d1 = t2(path=bm.flytedir) + d2 = t2(path=bm.inner_bm.flytedir) + return f1, f2, d1, d2 + + o1, o2, o3, o4 = wf(bm=BM()) + with open(o1, "r") as fh: + assert fh.read() == "Hello FlyteFile" + + with open(o2, "r") as fh: + assert fh.read() == "Hello FlyteFile" + + with open(os.path.join(o3, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + + with open(os.path.join(o4, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + + +def test_all_types_in_pydantic_basemodel_wf(local_dummy_file, local_dummy_directory): + class InnerBM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str = "Hello, Flyte" + d: bool = False + e: dict = Field(default_factory=lambda: {"key": "value"}) + f: FlyteFile = Field(default_factory=lambda: FlyteFile(local_dummy_file)) + g: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + enum_status: Status = Field(default=Status.PENDING) + + class BM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str = "Hello, Flyte" + d: bool = False + e: dict = Field(default_factory=lambda: {"key": "value"}) + f: FlyteFile = Field(default_factory=lambda: FlyteFile(local_dummy_file)) + g: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + inner_bm: InnerBM = Field(default_factory=lambda: InnerBM()) + enum_status: Status = Field(default=Status.PENDING) + + @task + def t_inner(inner_bm: InnerBM): + assert type(inner_bm) is InnerBM + + assert type(inner_bm.f) is FlyteFile + with open(inner_bm.f, "r") as f: + assert f.read() == "Hello FlyteFile" + + assert type(inner_bm.g) is FlyteDirectory + assert not inner_bm.g.downloaded + with open(os.path.join(inner_bm.g, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + assert inner_bm.g.downloaded + + # enum: Status 
+ assert inner_bm.enum_status == Status.PENDING + + @task + def t_test_all_attributes( + a: int, + b: float, + c: str, + d: bool, + e: dict, + f: FlyteFile, + g: FlyteDirectory, + enum_status: Status, + ): + # Strict type checks for simple types + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + + # Strict type check for a generic dict + assert isinstance(e, dict), "m is not dict" + + # Strict type check for FlyteFile + assert isinstance(f, FlyteFile), "n is not FlyteFile" + + # Strict type check for FlyteDirectory + assert isinstance(g, FlyteDirectory), "o is not FlyteDirectory" + + # Strict type check for Enum + assert isinstance(enum_status, Status), "enum_status is not Status" + + print("All attributes passed strict type checks.") + + @workflow + def wf(bm: BM): + t_inner(bm.inner_bm) + t_test_all_attributes( + a=bm.a, + b=bm.b, + c=bm.c, + d=bm.d, + e=bm.e, + f=bm.f, + g=bm.g, + enum_status=bm.enum_status, + ) + + t_test_all_attributes( + a=bm.inner_bm.a, + b=bm.inner_bm.b, + c=bm.inner_bm.c, + d=bm.inner_bm.d, + e=bm.inner_bm.e, + f=bm.inner_bm.f, + g=bm.inner_bm.g, + enum_status=bm.inner_bm.enum_status, + ) + + wf(bm=BM()) + + +def test_all_types_with_optional_in_pydantic_basemodel_wf( + local_dummy_file, local_dummy_directory +): + class InnerBM(BaseModel): + a: Optional[int] = -1 + b: Optional[float] = 2.1 + c: Optional[str] = "Hello, Flyte" + d: Optional[bool] = False + m: Optional[dict] = Field(default_factory=lambda: {"key": "value"}) + n: Optional[FlyteFile] = Field( + default_factory=lambda: FlyteFile(local_dummy_file) + ) + o: Optional[FlyteDirectory] = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + enum_status: Optional[Status] = Field(default=Status.PENDING) + + class BM(BaseModel): + a: Optional[int] = -1 + b: Optional[float] = 2.1 + c: Optional[str] = "Hello, Flyte" + d: Optional[bool] = False + m: Optional[dict] = Field(default_factory=lambda: {"key": "value"}) + n: Optional[FlyteFile] = Field( + default_factory=lambda: FlyteFile(local_dummy_file) + ) + o: Optional[FlyteDirectory] = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + inner_bm: Optional[InnerBM] = Field(default_factory=lambda: InnerBM()) + enum_status: Optional[Status] = Field(default=Status.PENDING) + + @task + def t_inner(inner_bm: InnerBM): + assert type(inner_bm) is InnerBM + + # n: FlyteFile + assert type(inner_bm.n) is FlyteFile + with open(inner_bm.n, "r") as f: + assert f.read() == "Hello FlyteFile" + # o: FlyteDirectory + assert type(inner_bm.o) is FlyteDirectory + assert not inner_bm.o.downloaded + with open(os.path.join(inner_bm.o, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + assert inner_bm.o.downloaded + + # enum: Status + assert inner_bm.enum_status == Status.PENDING + + @task + def t_test_all_attributes( + a: Optional[int], + b: Optional[float], + c: Optional[str], + d: Optional[bool], + m: Optional[dict], + n: Optional[FlyteFile], + o: Optional[FlyteDirectory], + enum_status: Optional[Status], + ): + # Strict type checks for simple types + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + + # Strict type check for a 
generic dict + assert isinstance(m, dict), "m is not dict" + + # Strict type check for FlyteFile + assert isinstance(n, FlyteFile), "n is not FlyteFile" + + # Strict type check for FlyteDirectory + assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory" + + # Strict type check for Enum + assert isinstance(enum_status, Status), "enum_status is not Status" + + @workflow + def wf(bm: BM): + t_inner(bm.inner_bm) + t_test_all_attributes( + a=bm.a, + b=bm.b, + c=bm.c, + d=bm.d, + m=bm.m, + n=bm.n, + o=bm.o, + enum_status=bm.enum_status, + ) + + wf(bm=BM()) + + +def test_all_types_with_optional_and_none_in_pydantic_basemodel_wf( + local_dummy_file, local_dummy_directory +): + class InnerBM(BaseModel): + a: Optional[int] = None + b: Optional[float] = None + c: Optional[str] = None + d: Optional[bool] = None + e: Optional[List[int]] = None + f: Optional[List[FlyteFile]] = None + g: Optional[List[List[int]]] = None + h: Optional[List[Dict[int, bool]]] = None + i: Optional[Dict[int, bool]] = None + j: Optional[Dict[int, FlyteFile]] = None + k: Optional[Dict[int, List[int]]] = None + l: Optional[Dict[int, Dict[int, int]]] = None + m: Optional[dict] = None + n: Optional[FlyteFile] = None + o: Optional[FlyteDirectory] = None + enum_status: Optional[Status] = None + + class BM(BaseModel): + a: Optional[int] = None + b: Optional[float] = None + c: Optional[str] = None + d: Optional[bool] = None + e: Optional[List[int]] = None + f: Optional[List[FlyteFile]] = None + g: Optional[List[List[int]]] = None + h: Optional[List[Dict[int, bool]]] = None + i: Optional[Dict[int, bool]] = None + j: Optional[Dict[int, FlyteFile]] = None + k: Optional[Dict[int, List[int]]] = None + l: Optional[Dict[int, Dict[int, int]]] = None + m: Optional[dict] = None + n: Optional[FlyteFile] = None + o: Optional[FlyteDirectory] = None + inner_bm: Optional[InnerBM] = None + enum_status: Optional[Status] = None + + @task + def t_inner(inner_bm: Optional[InnerBM]): + return inner_bm + + @task + def t_test_all_attributes( + a: Optional[int], + b: Optional[float], + c: Optional[str], + d: Optional[bool], + e: Optional[List[int]], + f: Optional[List[FlyteFile]], + g: Optional[List[List[int]]], + h: Optional[List[Dict[int, bool]]], + i: Optional[Dict[int, bool]], + j: Optional[Dict[int, FlyteFile]], + k: Optional[Dict[int, List[int]]], + l: Optional[Dict[int, Dict[int, int]]], + m: Optional[dict], + n: Optional[FlyteFile], + o: Optional[FlyteDirectory], + enum_status: Optional[Status], + ): + return + + @workflow + def wf(bm: BM): + t_inner(bm.inner_bm) + t_test_all_attributes( + a=bm.a, + b=bm.b, + c=bm.c, + d=bm.d, + e=bm.e, + f=bm.f, + g=bm.g, + h=bm.h, + i=bm.i, + j=bm.j, + k=bm.k, + l=bm.l, + m=bm.m, + n=bm.n, + o=bm.o, + enum_status=bm.enum_status, + ) + + wf(bm=BM()) + + +def test_input_from_flyte_console_pydantic_basemodel( + local_dummy_file, local_dummy_directory +): + # Flyte Console will send the input data as protobuf Struct + + class InnerBM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str = "Hello, Flyte" + d: bool = False + e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: List[FlyteFile] = Field( + default_factory=lambda: [FlyteFile(local_dummy_file)] + ) + g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]]) + h: List[Dict[int, bool]] = Field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) + i: Dict[int, bool] = Field( + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Dict[int, FlyteFile] = Field( + default_factory=lambda: { + 0: 
FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) + k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}}) + m: dict = Field(default_factory=lambda: {"key": "value"}) + n: FlyteFile = Field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + enum_status: Status = Field(default=Status.PENDING) + + class BM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str = "Hello, Flyte" + d: bool = False + e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: List[FlyteFile] = Field( + default_factory=lambda: [ + FlyteFile(local_dummy_file), + ] + ) + g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]]) + h: List[Dict[int, bool]] = Field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) + i: Dict[int, bool] = Field( + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Dict[int, FlyteFile] = Field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) + k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}}) + m: dict = Field(default_factory=lambda: {"key": "value"}) + n: FlyteFile = Field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + inner_bm: InnerBM = Field(default_factory=lambda: InnerBM()) + enum_status: Status = Field(default=Status.PENDING) + + @task + def t_inner(inner_bm: InnerBM): + assert type(inner_bm) is InnerBM + + # f: List[FlyteFile] + for ff in inner_bm.f: + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # j: Dict[int, FlyteFile] + for _, ff in inner_bm.j.items(): + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # n: FlyteFile + assert type(inner_bm.n) is FlyteFile + with open(inner_bm.n, "r") as f: + assert f.read() == "Hello FlyteFile" + # o: FlyteDirectory + assert type(inner_bm.o) is FlyteDirectory + assert not inner_bm.o.downloaded + with open(os.path.join(inner_bm.o, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + assert inner_bm.o.downloaded + + # enum: Status + assert inner_bm.enum_status == Status.PENDING + + def t_test_all_attributes( + a: int, + b: float, + c: str, + d: bool, + e: List[int], + f: List[FlyteFile], + g: List[List[int]], + h: List[Dict[int, bool]], + i: Dict[int, bool], + j: Dict[int, FlyteFile], + k: Dict[int, List[int]], + l: Dict[int, Dict[int, int]], + m: dict, + n: FlyteFile, + o: FlyteDirectory, + enum_status: Status, + ): + # Strict type checks for simple types + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + + # Strict type checks for List[int] + assert isinstance(e, list) and all( + isinstance(i, int) for i in e + ), "e is not List[int]" + + # Strict type checks for List[FlyteFile] + assert isinstance(f, list) and all( + isinstance(i, FlyteFile) for i in f + ), "f is not List[FlyteFile]" + + # Strict type checks for List[List[int]] + 
assert isinstance(g, list) and all( + isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g + ), "g is not List[List[int]]" + + # Strict type checks for List[Dict[int, bool]] + assert isinstance(h, list) and all( + isinstance(i, dict) + and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) + for i in h + ), "h is not List[Dict[int, bool]]" + + # Strict type checks for Dict[int, bool] + assert isinstance(i, dict) and all( + isinstance(k, int) and isinstance(v, bool) for k, v in i.items() + ), "i is not Dict[int, bool]" + + # Strict type checks for Dict[int, FlyteFile] + assert isinstance(j, dict) and all( + isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items() + ), "j is not Dict[int, FlyteFile]" + + # Strict type checks for Dict[int, List[int]] + assert isinstance(k, dict) and all( + isinstance(k, int) + and isinstance(v, list) + and all(isinstance(i, int) for i in v) + for k, v in k.items() + ), "k is not Dict[int, List[int]]" + + # Strict type checks for Dict[int, Dict[int, int]] + assert isinstance(l, dict) and all( + isinstance(k, int) + and isinstance(v, dict) + and all( + isinstance(sub_k, int) and isinstance(sub_v, int) + for sub_k, sub_v in v.items() + ) + for k, v in l.items() + ), "l is not Dict[int, Dict[int, int]]" + + # Strict type check for a generic dict + assert isinstance(m, dict), "m is not dict" + + # Strict type check for FlyteFile + assert isinstance(n, FlyteFile), "n is not FlyteFile" + + # Strict type check for FlyteDirectory + assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory" + + # Strict type check for Enum + assert isinstance(enum_status, Status), "enum_status is not Status" + + print("All attributes passed strict type checks.") + + # This is the old dataclass serialization behavior. 
+ # https://github.com/flyteorg/flytekit/blob/94786cfd4a5c2c3b23ac29bmd6f04d0553fa1beb/flytekit/core/type_engine.py#L702-L728 + bm = BM() + json_str = bm.model_dump_json() + upstream_output = Literal( + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())) + ) + + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, BM + ) + t_inner(downstream_input.inner_bm) + t_test_all_attributes( + a=downstream_input.a, + b=downstream_input.b, + c=downstream_input.c, + d=downstream_input.d, + e=downstream_input.e, + f=downstream_input.f, + g=downstream_input.g, + h=downstream_input.h, + i=downstream_input.i, + j=downstream_input.j, + k=downstream_input.k, + l=downstream_input.l, + m=downstream_input.m, + n=downstream_input.n, + o=downstream_input.o, + enum_status=downstream_input.enum_status, + ) + t_test_all_attributes( + a=downstream_input.inner_bm.a, + b=downstream_input.inner_bm.b, + c=downstream_input.inner_bm.c, + d=downstream_input.inner_bm.d, + e=downstream_input.inner_bm.e, + f=downstream_input.inner_bm.f, + g=downstream_input.inner_bm.g, + h=downstream_input.inner_bm.h, + i=downstream_input.inner_bm.i, + j=downstream_input.inner_bm.j, + k=downstream_input.inner_bm.k, + l=downstream_input.inner_bm.l, + m=downstream_input.inner_bm.m, + n=downstream_input.inner_bm.n, + o=downstream_input.inner_bm.o, + enum_status=downstream_input.inner_bm.enum_status, + ) + + +def test_flyte_types_deserialization_not_called_when_using_constructor( + local_dummy_file, local_dummy_directory +): + # Mocking both FlyteFilePathTransformer and FlyteDirectoryPathTransformer + with patch( + "flytekit.types.file.FlyteFilePathTransformer.to_python_value" + ) as mock_file_to_python_value, patch( + "flytekit.types.directory.FlyteDirToMultipartBlobTransformer.to_python_value" + ) as mock_directory_to_python_value, patch( + "flytekit.types.structured.StructuredDatasetTransformerEngine.to_python_value" + ) as mock_structured_dataset_to_python_value, patch( + "flytekit.types.schema.FlyteSchemaTransformer.to_python_value" + ) as mock_schema_to_python_value: + + # Define your Pydantic model + class BM(BaseModel): + ff: FlyteFile = Field(default_factory=lambda: FlyteFile(local_dummy_file)) + fd: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + sd: StructuredDataset = Field(default_factory=lambda: StructuredDataset()) + fsc: FlyteSchema = Field(default_factory=lambda: FlyteSchema()) + + # Create an instance of BM (should not call the deserialization) + BM() + + mock_file_to_python_value.assert_not_called() + mock_directory_to_python_value.assert_not_called() + mock_structured_dataset_to_python_value.assert_not_called() + mock_schema_to_python_value.assert_not_called() + + +def test_flyte_types_deserialization_called_once_when_using_model_validate_json( + local_dummy_file, local_dummy_directory +): + """ + It's hard to mock flyte schema and structured dataset in tests, so we will only test FlyteFile and FlyteDirectory + """ + with patch( + "flytekit.types.file.FlyteFilePathTransformer.to_python_value" + ) as mock_file_to_python_value, patch( + "flytekit.types.directory.FlyteDirToMultipartBlobTransformer.to_python_value" + ) as mock_directory_to_python_value: + # Define your Pydantic model + class BM(BaseModel): + ff: FlyteFile = Field(default_factory=lambda: FlyteFile(local_dummy_file)) + fd: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + + # Create instances of FlyteFile and 
FlyteDirectory + bm = BM( + ff=FlyteFile(local_dummy_file), fd=FlyteDirectory(local_dummy_directory) + ) + + # Serialize and Deserialize with model_validate_json + json_str = bm.model_dump_json() + bm.model_validate_json( + json_data=json_str, strict=False, context={"deserialize": True} + ) + + # Assert that the to_python_value method was called once + mock_file_to_python_value.assert_called_once() + mock_directory_to_python_value.assert_called_once()
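Reviewer note: the autouse fixtures in the new test modules pin FLYTE_USE_OLD_DC_FORMAT=True so every test exercises the legacy protobuf Struct ("generic") literal path rather than the new msgpack Binary path. The sketch below shows that toggle end to end; it is a minimal illustration, not part of this diff. It assumes a flytekit checkout with this patch applied, and `DC` is a stand-in dataclass invented for the example.

    import os
    from dataclasses import dataclass

    from flytekit.core.context_manager import FlyteContextManager
    from flytekit.core.type_engine import TypeEngine


    @dataclass
    class DC:  # illustrative only; any mashumaro-serializable dataclass works
        a: int = -1
        b: str = "Hello, Flyte"


    ctx = FlyteContextManager.current_context()
    lt = TypeEngine.to_literal_type(DC)

    # Default path after this change: msgpack bytes in a Binary scalar.
    lit = TypeEngine.to_literal(ctx, DC(), DC, lt)
    assert lit.scalar.binary is not None and lit.scalar.binary.tag == "msgpack"

    # Opt-in fallback: the old protobuf Struct ("generic") representation.
    os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "True"
    try:
        lit = TypeEngine.to_literal(ctx, DC(), DC, lt)
        assert lit.scalar.generic is not None
    finally:
        # Clean up so the override does not leak into other code.
        del os.environ["FLYTE_USE_OLD_DC_FORMAT"]

Any value that str2bool does not treat as true (unset, "false", "0") leaves the msgpack path in effect, which is why the fixtures here restore os.environ on teardown instead of leaving the flag set between tests.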