From 29f69eab26726974554c0d5966dc0873bfee4482 Mon Sep 17 00:00:00 2001 From: augray Date: Fri, 7 Oct 2022 16:58:10 -0700 Subject: [PATCH] Add support for specifying k8s tolerations (#192) People may want/need more control over which nodes their jobs run on. K8s tolerations is a mechanism many people use to achieve this. If we don't support them, people may not be able to run their Sematic funcs where they want them to go. This adds support for it. It also adds support for enums, because the natural way to support tolerations includes some enum types that are currently serialized/deserialized with Sematic's type stuff. It fails without actual enum support. I've wanted enum support a few times anyway, so I took this opportunity to "bite the bullet." --- docs/changelog.md | 3 + sematic/__init__.py | 3 + sematic/resolvers/resource_requirements.py | 153 +++++++++++++++++- sematic/resolvers/tests/BUILD | 10 ++ .../tests/test_resource_requirements.py | 55 +++++++ sematic/scheduling/kubernetes.py | 12 +- sematic/scheduling/tests/test_kubernetes.py | 27 ++++ sematic/types/BUILD | 1 + sematic/types/__init__.py | 1 + sematic/types/casting.py | 10 ++ sematic/types/registry.py | 31 +++- sematic/types/serialization.py | 11 ++ sematic/types/tests/test_registry.py | 13 +- sematic/types/types/BUILD | 8 + sematic/types/types/enum.py | 51 ++++++ sematic/types/types/tests/BUILD | 9 ++ sematic/types/types/tests/test_enum.py | 67 ++++++++ 17 files changed, 460 insertions(+), 5 deletions(-) create mode 100644 sematic/resolvers/tests/test_resource_requirements.py create mode 100644 sematic/types/types/enum.py create mode 100644 sematic/types/types/tests/test_enum.py diff --git a/docs/changelog.md b/docs/changelog.md index 7ae4e4b67..2be803a5d 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,9 @@ Lines for version numbers should always be formatted as `* MAJOR.MINOR.PATCH` with nothing else on the line. --> * HEAD + * [feature] Support Enums + * [feature] Allow specifying Kubernetes tolerations for cloud jobs + * [feature] Redesigned log UI * 0.15.1 * [bugfix] Ensure log ingestion happens when using bazel cloud jobs and `dev=False` * [bugfix] Avoid spamming the server with requests for logs for incomplete runs diff --git a/sematic/__init__.py b/sematic/__init__.py index 991079bfe..5aed9914f 100644 --- a/sematic/__init__.py +++ b/sematic/__init__.py @@ -27,6 +27,9 @@ from sematic.resolvers.resource_requirements import ( # noqa: F401,E402 KubernetesResourceRequirements, KubernetesSecretMount, + KubernetesToleration, + KubernetesTolerationEffect, + KubernetesTolerationOperator, ResourceRequirements, ) from sematic.versions import CURRENT_VERSION_STR as __version__ # noqa: F401,E402 diff --git a/sematic/resolvers/resource_requirements.py b/sematic/resolvers/resource_requirements.py index 684fd1bfb..08a476a8e 100644 --- a/sematic/resolvers/resource_requirements.py +++ b/sematic/resolvers/resource_requirements.py @@ -1,6 +1,7 @@ # Standard Library from dataclasses import dataclass, field -from typing import Dict +from enum import Enum, unique +from typing import Dict, List, Optional, Union KUBERNETES_SECRET_NAME = "sematic-func-secrets" @@ -49,6 +50,150 @@ class KubernetesSecretMount: file_secret_root_path: str = "/secrets" +@unique +class KubernetesTolerationOperator(Enum): + """The way that a toleration should be checked to see if it applies + + See Kubernetes documentation for more: + https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration/ + + Options + ------- + Equal: + value must be specified, and must be equal for the toleration and the taint + for the toleration to be considered to apply. In addition to this condition, + the "effect" must be equal for the toleration and the taint for the toleration + to be considered to apply. + Exists: + value is not required. If a taint with the given key exists on the node, + the toleration is considered to apply. In addition to this condition, + the "effect" must be equal for the toleration and the taint for the toleration + to be considered to apply. + """ + + Equal = "Equal" + Exists = "Exists" + + +@unique +class KubernetesTolerationEffect(Enum): + """The effect that the toleration is meant to tolerate + + See Kubernetes documentation for more: + https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration/ + + Options + ------- + NoSchedule: + The toleration indicates that the pod can run on the node even + if it has specified a NoSchedule taint, assuming the rest of + the toleration matches the taint. + PreferNoSchedule: + The toleration indicates that the pod can run on the node even + if it has specified a PreferNoSchedule taint, assuming the rest + of the toleration matches the taint. + NoExecute: + The pod will not be evicted from the node even if the node has + specified a NoExecute taint, assuming the rest of the toleration + matches the taint. + All: + The pod will not be evicted from the node even if the node has + any kind of taint, assuming the rest of the toleration + matches the taint. + """ + + NoSchedule = "NoSchedule" + PreferNoSchedule = "PreferNoSchedule" + NoExecute = "NoExecute" + All = "All" + + +@dataclass +class KubernetesToleration: + """Toleration for a node taint, enabling the pod for the function to run on the node + + See Kubernetes documentation for more: + https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration/ + + Attributes + ---------- + key: + The key for the node taint intended to be tolerated. If empty, means + to match all keys AND all values + operator: + The way to compare the key/value pair to the node taint's key/value pair + to see if the toleration applies + effect: + The effect of the node taint the toleration is intended to tolerate. + Leaving it empty means to tolerate all effects. + value: + If the operator is Equals, this value will be compared to the value + on the node taint to see if this toleration applies. + toleration_seconds: + Only specified when effect is NoExecute (otherwise is an error). It + specifies the amount of time the pod can continue executing on a node + with a NoExecute taint + """ + + key: Optional[str] = None + operator: KubernetesTolerationOperator = KubernetesTolerationOperator.Equal + effect: KubernetesTolerationEffect = KubernetesTolerationEffect.All + value: Optional[str] = None + toleration_seconds: Optional[int] = None + + def to_api_keyword_args(self) -> Dict[str, Optional[Union[str, int]]]: + """Convert to the format for kwargs the API python client API for tolerations""" + effect: Optional[str] = self.effect.value + if self.effect == KubernetesTolerationEffect.All: + # the actual API makes "all" the default behavior with no other way to + # specify + effect = None + operator = self.operator.value + return dict( + effect=effect, + key=self.key, + operator=operator, + toleration_seconds=self.toleration_seconds, + value=self.value, + ) + + def __post_init__(self): + """Ensure that the values in the toleration are valid; raise otherwise + + Raises + ------ + ValueError: + If the values are not valid + """ + if not (self.key is None or isinstance(self.key, str)): + raise ValueError(f"key must be None or a string, got: {self.key}") + if not isinstance(self.operator, KubernetesTolerationOperator): + raise ValueError( + f"operator must be a {KubernetesTolerationOperator}, got {self.operator}" + ) + if not isinstance(self.effect, KubernetesTolerationEffect): + raise ValueError( + f"effect must be a {KubernetesTolerationEffect}, got {self.effect}" + ) + if not (self.value is None or isinstance(self.value, str)): + raise ValueError(f"value must be None or a string, got: {self.value}") + if not ( + self.toleration_seconds is None or isinstance(self.toleration_seconds, int) + ): + raise ValueError( + "toleration_seconds must be None or an " + f"int, got: {self.toleration_seconds}" + ) + if ( + self.toleration_seconds is not None + and self.effect != KubernetesTolerationEffect.NoExecute + ): + raise ValueError( + "toleration_seconds should only be specified when the effect " + "is NoExecute." + ) + + @dataclass class KubernetesResourceRequirements: """Information on the Kubernetes resources required. @@ -69,11 +214,17 @@ class KubernetesResourceRequirements: secret_mounts: Requests to take the contents of Kubernetes secrets and expose them as environment variables or files on disk when running in the cloud. + tolerations: + If your Kubernetes configuration uses node taints to control which workloads + get scheduled on which nodes, this enables control over how your workload + interacts with these node taints. More information can be found here: + https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration/ """ node_selector: Dict[str, str] = field(default_factory=dict) requests: Dict[str, str] = field(default_factory=dict) secret_mounts: KubernetesSecretMount = field(default_factory=KubernetesSecretMount) + tolerations: List[KubernetesToleration] = field(default_factory=list) @dataclass diff --git a/sematic/resolvers/tests/BUILD b/sematic/resolvers/tests/BUILD index 6635fb88c..1c51f4051 100644 --- a/sematic/resolvers/tests/BUILD +++ b/sematic/resolvers/tests/BUILD @@ -23,6 +23,16 @@ pytest_test( ], ) + +pytest_test( + name = "test_resource_requirements", + srcs = ["test_resource_requirements.py"], + deps = [ + "//sematic/resolvers:resource_requirements", + "//sematic/types:serialization", + ], +) + pytest_test( name = "test_silent_resolver", srcs = ["test_silent_resolver.py"], diff --git a/sematic/resolvers/tests/test_resource_requirements.py b/sematic/resolvers/tests/test_resource_requirements.py new file mode 100644 index 000000000..f5626acf6 --- /dev/null +++ b/sematic/resolvers/tests/test_resource_requirements.py @@ -0,0 +1,55 @@ +import pytest + +# Sematic +from sematic.resolvers.resource_requirements import ( + KubernetesResourceRequirements, + KubernetesSecretMount, + KubernetesToleration, + KubernetesTolerationEffect, + KubernetesTolerationOperator, + ResourceRequirements, +) +from sematic.types.serialization import ( + value_from_json_encodable, + value_to_json_encodable, +) + + +def test_is_serializable(): + requirements = ResourceRequirements( + kubernetes=KubernetesResourceRequirements( + node_selector={"foo": "bar"}, + requests={"cpu": "500m", "memory": "100Gi"}, + secret_mounts=KubernetesSecretMount( + environment_secrets={"a": "b"}, + file_secret_root_path="/foo/bar", + file_secrets={"c": "d"}, + ), + tolerations=[ + KubernetesToleration( + key="k", + value="v", + effect=KubernetesTolerationEffect.NoExecute, + operator=KubernetesTolerationOperator.Equal, + toleration_seconds=42, + ) + ], + ) + ) + encoded = value_to_json_encodable(requirements, ResourceRequirements) + decoded = value_from_json_encodable(encoded, ResourceRequirements) + assert decoded == requirements + + +def test_validation(): + with pytest.raises( + ValueError, + match="toleration_seconds should only be specified when the effect is NoExecute.", + ): + KubernetesToleration( + key="k", + value="v", + effect=KubernetesTolerationEffect.PreferNoSchedule, + operator=KubernetesTolerationOperator.Equal, + toleration_seconds=42, + ) diff --git a/sematic/scheduling/kubernetes.py b/sematic/scheduling/kubernetes.py index 337463c74..c5a4ba0dd 100644 --- a/sematic/scheduling/kubernetes.py +++ b/sematic/scheduling/kubernetes.py @@ -203,6 +203,7 @@ def _schedule_kubernetes_job( volumes = [] volume_mounts = [] secret_env_vars = [] + tolerations = [] if resource_requirements is not None: node_selector = resource_requirements.kubernetes.node_selector resource_requests = resource_requirements.kubernetes.requests @@ -214,10 +215,17 @@ def _schedule_kubernetes_job( secret_env_vars.extend( _environment_secrets(resource_requirements.kubernetes.secret_mounts) ) + tolerations = [ + kubernetes.client.V1Toleration( # type: ignore + **toleration.to_api_keyword_args() # type: ignore + ) + for toleration in resource_requirements.kubernetes.tolerations + ] logger.debug("kubernetes node_selector %s", node_selector) logger.debug("kubernetes resource requests %s", resource_requests) logger.debug("kubernetes volumes and mounts: %s, %s", volumes, volume_mounts) logger.debug("kubernetes environment secrets: %s", secret_env_vars) + logger.debug("kubernetes tolerations: %s", tolerations) pod_name_env_var = kubernetes.client.V1EnvVar( # type: ignore name=KUBERNETES_POD_NAME_ENV_VAR, @@ -228,6 +236,8 @@ def _schedule_kubernetes_job( ), ) + # See client documentation here: + # https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1Job.md job = kubernetes.client.V1Job( # type: ignore api_version="batch/v1", kind="Job", @@ -282,7 +292,7 @@ def _schedule_kubernetes_job( ) ], volumes=volumes, - tolerations=[], + tolerations=tolerations, restart_policy="Never", ), ), diff --git a/sematic/scheduling/tests/test_kubernetes.py b/sematic/scheduling/tests/test_kubernetes.py index da2625ab9..8b9ed4b3d 100644 --- a/sematic/scheduling/tests/test_kubernetes.py +++ b/sematic/scheduling/tests/test_kubernetes.py @@ -8,6 +8,9 @@ from sematic.resolvers.resource_requirements import ( KubernetesResourceRequirements, KubernetesSecretMount, + KubernetesToleration, + KubernetesTolerationEffect, + KubernetesTolerationOperator, ResourceRequirements, ) from sematic.scheduling.kubernetes import ( @@ -47,6 +50,16 @@ def test_schedule_kubernetes_job( file_secrets=file_secrets, file_secret_root_path=secret_root, ), + tolerations=[ + KubernetesToleration( + key="foo", + operator=KubernetesTolerationOperator.Equal, + effect=KubernetesTolerationEffect.NoExecute, + value="bar", + toleration_seconds=42, + ), + KubernetesToleration(), + ], ) ) mock_user_settings.return_value = {"KUBERNETES_NAMESPACE": namespace} @@ -86,6 +99,20 @@ def test_schedule_kubernetes_job( assert container.resources.limits == requests assert container.resources.requests == requests + tolerations = job.spec.template.spec.tolerations + assert len(tolerations) == 2 + assert tolerations[0].key == "foo" + assert tolerations[0].value == "bar" + assert tolerations[0].effect == "NoExecute" + assert tolerations[0].operator == "Equal" + assert tolerations[0].toleration_seconds == 42 + + assert tolerations[1].key is None + assert tolerations[1].value is None + assert tolerations[1].effect is None + assert tolerations[1].operator == "Equal" + assert tolerations[1].toleration_seconds is None + IS_ACTIVE_CASES = [ ( diff --git a/sematic/types/BUILD b/sematic/types/BUILD index 5af545eab..785af902c 100644 --- a/sematic/types/BUILD +++ b/sematic/types/BUILD @@ -6,6 +6,7 @@ sematic_py_lib( ":type", "//sematic/types/types:bool", "//sematic/types/types:dataclass", + "//sematic/types/types:enum", "//sematic/types/types:float", "//sematic/types/types:integer", "//sematic/types/types:none", diff --git a/sematic/types/__init__.py b/sematic/types/__init__.py index 6c706cf0b..766a8a67e 100644 --- a/sematic/types/__init__.py +++ b/sematic/types/__init__.py @@ -8,6 +8,7 @@ import sematic.types.types.bool # noqa: F401 import sematic.types.types.dataclass # noqa: F401 import sematic.types.types.dict # noqa: F401 +import sematic.types.types.enum # noqa: F401 import sematic.types.types.float # noqa: F401 import sematic.types.types.integer # noqa: F401 import sematic.types.types.list # noqa: F401 diff --git a/sematic/types/casting.py b/sematic/types/casting.py index 4f43e0821..bbfb803e5 100644 --- a/sematic/types/casting.py +++ b/sematic/types/casting.py @@ -1,12 +1,14 @@ # Standard Library import dataclasses import typing +from enum import Enum # Sematic from sematic.types.registry import ( DataclassKey, get_can_cast_func, get_safe_cast_func, + is_enum, is_parameterized_generic, ) @@ -47,6 +49,11 @@ def can_cast_type( if can_cast_func is None and dataclasses.is_dataclass(to_type): can_cast_func = get_can_cast_func(DataclassKey) + if can_cast_func is None and is_enum(to_type): + # enum types can register their own handlers, but if they don't + # we can use the default enum handler + can_cast_func = get_can_cast_func(Enum) + if can_cast_func is not None: return can_cast_func(from_type, to_type) @@ -86,6 +93,9 @@ def safe_cast( if _safe_cast_func is None and dataclasses.is_dataclass(type_): _safe_cast_func = get_safe_cast_func(DataclassKey) + if _safe_cast_func is None and is_enum(type_): + _safe_cast_func = get_safe_cast_func(Enum) + if _safe_cast_func is not None: return _safe_cast_func(value, type_) diff --git a/sematic/types/registry.py b/sematic/types/registry.py index c501b56d4..8686fb9f4 100644 --- a/sematic/types/registry.py +++ b/sematic/types/registry.py @@ -1,4 +1,7 @@ # Standard Library +import inspect +import typing +from enum import Enum from typing import ( Any, Callable, @@ -59,6 +62,21 @@ _CAN_CAST_REGISTRY: Dict[RegistryKey, CanCastTypeCallable] = {} +def is_enum(type_: typing.Type[Any]) -> bool: + """Determine if the given type is an enum type or not + + Parameters + ---------- + type_: + The type being checked + + Returns + ------- + True if the type is an enum type, False otherwise. + """ + return inspect.isclass(type_) and issubclass(type_, Enum) + + def register_can_cast( *types: RegistryKey, ) -> Callable[[CanCastTypeCallable], CanCastTypeCallable]: @@ -280,7 +298,7 @@ def assert_supported(type_): subclasses_type = issubclass(type_, type) except TypeError: subclasses_type = False - if type(type_) is type or subclasses_type: + if type(type_) is type or subclasses_type or is_enum(type_): return if not is_parameterized_generic(type_, raise_for_unparameterized=True): raise TypeError( @@ -360,8 +378,17 @@ def _is_supported_registry_key(type_: RegistryKey) -> bool: subclasses_type = issubclass(type_, type) except TypeError: subclasses_type = False + try: + subclasses_enum = issubclass(type_, Enum) + except TypeError: + subclasses_enum = False is_unparameterized_generic = type_ in SUPPORTED_GENERIC_TYPING_ANNOTATIONS.keys() - return type(type_) is type or subclasses_type or is_unparameterized_generic + return ( + type(type_) is type + or subclasses_type + or is_unparameterized_generic + or subclasses_enum + ) def _validate_registry_keys(*types_: RegistryKey): diff --git a/sematic/types/serialization.py b/sematic/types/serialization.py index abe0372ae..b290f9fdc 100644 --- a/sematic/types/serialization.py +++ b/sematic/types/serialization.py @@ -10,6 +10,7 @@ import inspect import json import typing +from enum import Enum # Third-party import cloudpickle # type: ignore @@ -22,6 +23,7 @@ get_origin_type, get_to_json_encodable_func, get_to_json_encodable_summary_func, + is_enum, is_parameterized_generic, is_sematic_parametrized_generic_type, is_supported_type_annotation, @@ -39,6 +41,9 @@ def value_to_json_encodable(value: typing.Any, type_: typing.Any) -> typing.Any: if to_json_encodable_func is None and dataclasses.is_dataclass(type_): to_json_encodable_func = get_to_json_encodable_func(DataclassKey) + if to_json_encodable_func is None and is_enum(type_): + to_json_encodable_func = get_to_json_encodable_func(Enum) + # If we have a serializer, we use it if to_json_encodable_func is not None: return to_json_encodable_func(value, type_) @@ -67,6 +72,9 @@ def value_from_json_encodable( if from_json_encodable_func is None and dataclasses.is_dataclass(type_): from_json_encodable_func = get_from_json_encodable_func(DataclassKey) + if from_json_encodable_func is None and is_enum(type_): + from_json_encodable_func = get_from_json_encodable_func(Enum) + # If we have a deserializer we use it if from_json_encodable_func is not None: return from_json_encodable_func(json_encodable, type_) @@ -97,6 +105,9 @@ def get_json_encodable_summary(value: typing.Any, type_: typing.Any) -> typing.A DataclassKey ) + if to_json_encodable_summary_func is None and is_enum(type_): + to_json_encodable_summary_func = get_to_json_encodable_summary_func(Enum) + if to_json_encodable_summary_func is not None: return to_json_encodable_summary_func(value, type_) diff --git a/sematic/types/tests/test_registry.py b/sematic/types/tests/test_registry.py index d0e053a42..1a4bd5b84 100644 --- a/sematic/types/tests/test_registry.py +++ b/sematic/types/tests/test_registry.py @@ -1,6 +1,6 @@ -# Standard # Standard Library from dataclasses import dataclass +from enum import Enum, unique from typing import Any, List, Literal, Optional, Union # Third party @@ -15,6 +15,13 @@ ) +@unique +class Color(Enum): + RED = "RED" + GREEN = "GREEN" + BLUE = "BLUE" + + @dataclass class FooDataclass: foo: int @@ -26,6 +33,7 @@ class FooStandard: def test_validate_type_annotation(): validate_type_annotation(int) + validate_type_annotation(Color) validate_type_annotation(FooDataclass) validate_type_annotation(FooStandard) validate_type_annotation(Union[int, float]) @@ -46,12 +54,14 @@ def test_validate_type_annotation(): def test_is_supported_type_annotation(): assert is_supported_type_annotation(FooDataclass) + assert is_supported_type_annotation(Color) assert not is_supported_type_annotation(Union) def test_is_parameterized_generic(): assert not is_parameterized_generic(FooDataclass) assert not is_parameterized_generic(FooStandard) + assert not is_parameterized_generic(Color) assert is_parameterized_generic(Union[int, float]) assert is_parameterized_generic(Optional[int]) assert is_parameterized_generic(List[int]) @@ -65,6 +75,7 @@ def test_validate_registry_keys(): _validate_registry_keys(int, int) _validate_registry_keys(FooDataclass) _validate_registry_keys(FooStandard) + _validate_registry_keys(Color) _validate_registry_keys(Union) _validate_registry_keys(List) with pytest.raises( diff --git a/sematic/types/types/BUILD b/sematic/types/types/BUILD index bf2114737..bdf76bf90 100644 --- a/sematic/types/types/BUILD +++ b/sematic/types/types/BUILD @@ -8,6 +8,14 @@ sematic_py_lib( ], ) +sematic_py_lib( + name = "enum", + srcs = ["enum.py"], + deps = [ + "//sematic/types:registry", + ], +) + sematic_py_lib( name = "float", srcs = ["float.py"], diff --git a/sematic/types/types/enum.py b/sematic/types/types/enum.py new file mode 100644 index 000000000..ee782312a --- /dev/null +++ b/sematic/types/types/enum.py @@ -0,0 +1,51 @@ +# Standard Library +import typing +from enum import Enum + +# Sematic +from sematic.types.registry import ( + register_can_cast, + register_from_json_encodable, + register_to_json_encodable, + register_to_json_encodable_summary, +) + + +@register_can_cast(Enum) +def can_cast_type( + from_type: typing.Type[Enum], to_type: typing.Type[Enum] +) -> typing.Tuple[bool, typing.Optional[str]]: + """ + Type casting logic for `Enum`. + + The types must be equal + """ + if from_type is to_type: + return True, None + + return False, "{} does not match {}".format(from_type, to_type) + + +# Default safe_cast behavior is sufficient + + +@register_to_json_encodable_summary(Enum) +def _enum_summary(value: Enum, _) -> str: + return _enum_to_encodable(value, _) + + +@register_to_json_encodable(Enum) +def _enum_to_encodable(value: Enum, type_: typing.Type[Enum]) -> str: + # Ex: foo.bar.Color.RED + if not isinstance(value, type_): + raise ValueError(f"The value '{value}' is not a {type_.__name__}") + return value.name + + +@register_from_json_encodable(Enum) +def _enum_from_encodable(value: str, type_: typing.Type[Enum]): + value_name = value.split(".")[-1] + if not hasattr(type_, value_name): + raise ValueError(f"The type {type_.__name__} has no value '{value_name}'") + deserialized = type_[value_name] + return deserialized diff --git a/sematic/types/types/tests/BUILD b/sematic/types/types/tests/BUILD index d040f3548..ec46a3036 100644 --- a/sematic/types/types/tests/BUILD +++ b/sematic/types/types/tests/BUILD @@ -56,6 +56,15 @@ pytest_test( ], ) +pytest_test( + name = "test_enum", + srcs = ["test_enum.py"], + deps = [ + "//sematic/types/types:enum", + "//sematic/types:serialization", + ], +) + pytest_test( name = "test_dict", srcs = ["test_dict.py"], diff --git a/sematic/types/types/tests/test_enum.py b/sematic/types/types/tests/test_enum.py new file mode 100644 index 000000000..916f84db5 --- /dev/null +++ b/sematic/types/types/tests/test_enum.py @@ -0,0 +1,67 @@ +# Standard Library +from enum import Enum, unique + +import pytest + +# Sematic +from sematic.types.serialization import ( + value_from_json_encodable, + value_to_json_encodable, +) +from sematic.types.types.enum import ( + _enum_from_encodable, + _enum_to_encodable, + can_cast_type, +) + + +class SomethingExotic: + def __init__(self, number): + self.number = number + + +@unique +class Color(Enum): + RED = "RED" + GREEN = "GREEN" + BLUE = "BLUE" + + +@unique +class ExoticNumbers(Enum): + MEANING_OF_LIFE = SomethingExotic(42) + LUCKY_NUMBER = SomethingExotic(7) + + +def test_to_from_encodable(): + encoded = _enum_to_encodable(Color.RED, Color) + assert encoded == "RED" + + # ensures registration worked + assert encoded == value_to_json_encodable(Color.RED, Color) + decoded = _enum_from_encodable(encoded, Color) + assert decoded == value_from_json_encodable(encoded, Color) + assert decoded == Color.RED + + encoded = _enum_to_encodable(ExoticNumbers.LUCKY_NUMBER, ExoticNumbers) + assert encoded == "LUCKY_NUMBER" + decoded = _enum_from_encodable(encoded, ExoticNumbers) + assert decoded == ExoticNumbers.LUCKY_NUMBER + assert decoded.value.number == 7 + + with pytest.raises( + ValueError, match=r"The value 'ExoticNumbers.LUCKY_NUMBER' is not a Color" + ): + _enum_to_encodable(ExoticNumbers.LUCKY_NUMBER, Color) + + with pytest.raises(ValueError, match=r"The type Color has no value 'LUCKY_NUMBER'"): + _enum_from_encodable("LUCKY_NUMBER", Color) + + +def test_can_cast_type(): + assert can_cast_type(Color, Color) == (True, None) + assert can_cast_type(Color, SomethingExotic) == ( + False, + " does not match " + "", + )