diff --git a/openfisca_core/indexed_enums/_enum_type.py b/openfisca_core/indexed_enums/_enum_type.py index 8083a6d49..9b95364a2 100644 --- a/openfisca_core/indexed_enums/_enum_type.py +++ b/openfisca_core/indexed_enums/_enum_type.py @@ -66,5 +66,14 @@ def __new__( def __dir__(cls) -> list[str]: return sorted({"indices", "names", "enums", *super().__dir__()}) + def __hash__(cls) -> int: + return object.__hash__(cls.__name__) + + def __eq__(cls, other: object) -> bool: + return hash(cls) == hash(other) + + def __ne__(cls, other: object) -> bool: + return hash(cls) != hash(other) + __all__ = ["EnumType"] diff --git a/openfisca_core/indexed_enums/_utils.py b/openfisca_core/indexed_enums/_utils.py index aa676b92f..67c9e741b 100644 --- a/openfisca_core/indexed_enums/_utils.py +++ b/openfisca_core/indexed_enums/_utils.py @@ -92,8 +92,7 @@ def _int_to_index( ... ) >>> _int_to_index(Road, 1) - Traceback (most recent call last): - TypeError: 'int' object is not iterable + array([1], dtype=uint8) >>> _int_to_index(Road, [1]) array([1], dtype=uint8) @@ -105,8 +104,7 @@ def _int_to_index( array([1], dtype=uint8) >>> _int_to_index(Road, numpy.array(1)) - Traceback (most recent call last): - TypeError: iteration over a 0-d array + array([1], dtype=uint8) >>> _int_to_index(Road, numpy.array([1])) array([1], dtype=uint8) @@ -118,9 +116,9 @@ def _int_to_index( array([1, 1], dtype=uint8) """ - return numpy.array( - [index for index in value if index < len(enum_class.__members__)], t.EnumDType - ) + indices = enum_class.indices + values = numpy.array(value, copy=False) + return values[values < indices.size].astype(t.EnumDType) def _str_to_index( @@ -155,14 +153,13 @@ def _str_to_index( ... ) >>> _str_to_index(Road, "AVENUE") - array([], dtype=uint8) + array([1], dtype=uint8) >>> _str_to_index(Road, ["AVENUE"]) array([1], dtype=uint8) >>> _str_to_index(Road, numpy.array("AVENUE")) - Traceback (most recent call last): - TypeError: iteration over a 0-d array + array([1], dtype=uint8) >>> _str_to_index(Road, numpy.array(["AVENUE"])) array([1], dtype=uint8) @@ -174,14 +171,12 @@ def _str_to_index( array([1, 1], dtype=uint8) """ - return numpy.array( - [ - enum_class.__members__[name].index - for name in value - if name in enum_class._member_names_ - ], - t.EnumDType, - ) + values = numpy.array(value, copy=False) + names = enum_class.names + mask = numpy.isin(values, names) + sorter = numpy.argsort(names) + result = sorter[numpy.searchsorted(names, values[mask], sorter=sorter)] + return result.astype(t.EnumDType) __all__ = ["_enum_to_index", "_int_to_index", "_str_to_index"] diff --git a/openfisca_core/indexed_enums/enum.py b/openfisca_core/indexed_enums/enum.py index a733fd5da..43a893e85 100644 --- a/openfisca_core/indexed_enums/enum.py +++ b/openfisca_core/indexed_enums/enum.py @@ -117,13 +117,6 @@ def __init__(self, *__args: object, **__kwargs: object) -> None: """ self.index = len(self._member_names_) - # Bypass the slow Enum.__eq__ - __eq__ = object.__eq__ - - # In Python 3, __hash__ must be defined if __eq__ is defined to stay - # hashable. - __hash__ = object.__hash__ - def __repr__(self) -> str: return f"{self.__class__.__name__}.{self.name}" @@ -199,7 +192,7 @@ def _encode_array(cls, value: t.VarArray) -> t.EnumArray: indices = _int_to_index(cls, value) elif _is_str_array(value): # type: ignore[unreachable] indices = _str_to_index(cls, value) - elif _is_enum_array(value) and cls.__name__ is value[0].__class__.__name__: + elif _is_enum_array(value) and cls == value[0].__class__: indices = _enum_to_index(value) else: raise EnumEncodingError(cls, value) diff --git a/openfisca_core/indexed_enums/enum_array.py b/openfisca_core/indexed_enums/enum_array.py index 98f9b4c6a..65bc209a7 100644 --- a/openfisca_core/indexed_enums/enum_array.py +++ b/openfisca_core/indexed_enums/enum_array.py @@ -153,10 +153,7 @@ def __eq__(self, other: object) -> t.BoolArray: # type: ignore[override] return NotImplemented if other is None: return NotImplemented - if ( - isinstance(other, type(t.Enum)) - and other.__name__ is self.possible_values.__name__ - ): + if isinstance(other, type(t.Enum)) and other == self.possible_values: result = ( self.view(numpy.ndarray) == self.possible_values.indices[ @@ -164,10 +161,7 @@ def __eq__(self, other: object) -> t.BoolArray: # type: ignore[override] ] ) return result - if ( - isinstance(other, t.Enum) - and other.__class__.__name__ is self.possible_values.__name__ - ): + if isinstance(other, t.Enum) and other.__class__ == self.possible_values: result = self.view(numpy.ndarray) == other.index return result # For NumPy >=1.26.x.