diff --git a/src/lvmopstools/actor.py b/src/lvmopstools/actor.py index fd1c9bb..6e6286f 100644 --- a/src/lvmopstools/actor.py +++ b/src/lvmopstools/actor.py @@ -32,9 +32,9 @@ "CheckError", "ActorState", "ErrorCodes", + "ErrorCodesBase", "ErrorData", - "extend_enum", - "verify_error_codes", + "create_error_codes", ] @@ -66,35 +66,6 @@ } -def extend_enum(inherited_enum): - """A decorator to extend an enum with new values. - - Taken from https://stackoverflow.com/a/64045773 - - """ - - def wrapper(added_enum): - joined = {} - for item in inherited_enum: - joined[item.name] = item.value - for item in added_enum: - joined[item.name] = item.value - - new_enum_class = enum.Enum(added_enum.__name__, joined) - - # Add methods from the inherited class. - for attr in inherited_enum.__dict__: - if attr.startswith("__"): - continue - if hasattr(new_enum_class, attr): - continue - setattr(new_enum_class, attr, getattr(inherited_enum, attr)) - - return new_enum_class - - return wrapper - - class ActorState(enum.Flag): """Defines the possible states of the actor.""" @@ -121,24 +92,9 @@ class ErrorData: description: str = "" -class verify_error_codes: - """Verifies that all error codes are instances of ErrorData.""" - - def __call__(self, enumeration): - for enum_item in enumeration.__members__.values(): - if not isinstance(enum_item.value, ErrorData): - name = enum_item.name - raise ValueError(f"Error code {name} must be an instance of ErrorData.") - - return enumeration - - -@verify_error_codes() -class ErrorCodes(enum.Enum): +class ErrorCodesBase(enum.Enum): """Enumeration of error codes""" - UNKNOWN = ErrorData(9999, True, "Unknown error.") - @classmethod def get_error_code(cls, error_code: int): """Returns the :obj:`.ErrorCodes` that matches the ``error_code`` value.""" @@ -150,6 +106,28 @@ def get_error_code(cls, error_code: int): raise ValueError(f"Error code {error_code} not found.") +def create_error_codes( + error_codes: dict[str, tuple | list | ErrorData], + name: str = "ErrorCodes", + include_unknown: bool = True, +) -> Any: + """Creates an enumeration of error codes.""" + + error_codes_enum: dict[str, ErrorData] = {} + for error_name, error_data in error_codes.items(): + if not isinstance(error_data, ErrorData): + error_data = ErrorData(*error_data) + error_codes_enum[error_name.upper()] = error_data + + if include_unknown and "UNKNOWN" not in error_codes_enum: + error_codes_enum["UNKNOWN"] = ErrorData(9999, True, "Unknown error") + + return ErrorCodesBase(name, error_codes_enum) + + +ErrorCodes = create_error_codes({"UNKNOWN": ErrorData(9999, True, "Unknown error")}) + + @click.command(cls=CluCommand, name="actor-state") async def actor_state(command: Command[LVMActor], *args, **kwargs): """Returns the actor state.""" @@ -189,7 +167,7 @@ class CheckError(Exception): def __init__( self, message: str = "", - error_code: ErrorCodes | int = ErrorCodes.UNKNOWN, + error_code: ErrorCodesBase | int = ErrorCodes.UNKNOWN, ): if isinstance(error_code, int): self.error_code = ErrorCodes.get_error_code(error_code) @@ -324,7 +302,7 @@ def update_state( async def troubleshoot( self, - error_code: ErrorCodes = ErrorCodes.UNKNOWN, + error_code: ErrorCodesBase = ErrorCodes.UNKNOWN, exception: Exception | None = None, traceback_frame: int = 0, ): @@ -394,7 +372,7 @@ async def _check_internal(self): @abc.abstractmethod async def _troubleshoot_internal( self, - error_code: ErrorCodes, + error_code: ErrorCodesBase, exception: Exception | None = None, ): """Handles internal troubleshooting. diff --git a/tests/test_actor.py b/tests/test_actor.py index 04fef42..93695a3 100644 --- a/tests/test_actor.py +++ b/tests/test_actor.py @@ -9,7 +9,6 @@ from __future__ import annotations import asyncio -import enum import sys import pytest @@ -24,28 +23,28 @@ ErrorCodes, ErrorData, LVMActor, - extend_enum, - verify_error_codes, + create_error_codes, ) -def test_extend_enum(): - @extend_enum(ErrorCodes) - class ExtraErrorCodes(enum.Enum): - SOME_FAILURE_MODE = ErrorData(1, critical=False, description="Test error") +def test_create_error_codes(): - assert hasattr(ExtraErrorCodes, "SOME_FAILURE_MODE") - assert hasattr(ExtraErrorCodes, "UNKNOWN") + ErrorCodesTest = create_error_codes( + { + "CODE1": (1, True), + "CODE2": (2, False, "Non-critical error"), + "CODE3": ErrorData(3, True, "Critical error"), + } + ) - assert hasattr(ExtraErrorCodes, "get_error_code") + assert "CODE1" in ErrorCodesTest.__members__ + assert ErrorCodesTest.CODE1.value.code == 1 + assert ErrorCodesTest.CODE1.value.critical + assert ErrorCodesTest.CODE2.value.code == 2 + assert ErrorCodesTest.CODE2.value.critical is False -def test_verify_error_codes_fails(): - with pytest.raises(ValueError): - - @verify_error_codes() - class NewErrorCodes(enum.Enum): - SOME_FAILURE_MODE = 1 + assert ErrorCodesTest.CODE3.value.code == 3 async def test_command_actor_state(lvm_actor: LVMActor):