Skip to content

Commit

Permalink
perf: fix enum performance (#1306)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Nov 21, 2024
1 parent 9f43461 commit e2e3fb3
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 34 deletions.
9 changes: 9 additions & 0 deletions openfisca_core/indexed_enums/_enum_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,14 @@ def __new__(
def __dir__(cls) -> list[str]:
    """List the attribute names of the class, in sorted order.

    The names ``indices``, ``names`` and ``enums`` are always part of the
    result, in addition to whatever the parent ``__dir__`` reports.

    Returns:
        A sorted list of unique attribute names.

    """
    attribute_names = set(super().__dir__())
    attribute_names.update(("indices", "names", "enums"))
    return sorted(attribute_names)

def __hash__(cls) -> int:
    """Hash the class through its name.

    The name string is hashed with ``object.__hash__`` rather than with
    ``str.__hash__``.

    Returns:
        The hash ``object.__hash__`` computes for ``cls.__name__``.

    """
    class_name = cls.__name__
    return object.__hash__(class_name)

def __eq__(cls, other: object) -> bool:
    """Tell whether ``other`` hashes to the same value as this class.

    Equality is decided by comparing ``hash(cls)`` with ``hash(other)``.

    Args:
        other: The object to compare against.

    Returns:
        ``True`` when both hashes are equal, ``False`` otherwise. An
        unhashable ``other`` (for instance a ``list`` or a ``dict``) is
        never equal to the class.

    """
    try:
        other_hash = hash(other)
    except TypeError:
        # ``hash`` raises on unhashable objects; an equality test must not
        # crash because of the right-hand operand's type.
        return False
    return hash(cls) == other_hash

def __ne__(cls, other: object) -> bool:
    """Tell whether ``other`` hashes to a different value than this class.

    Inequality is decided by comparing ``hash(cls)`` with ``hash(other)``.

    Args:
        other: The object to compare against.

    Returns:
        ``True`` when the hashes differ, ``False`` otherwise. An unhashable
        ``other`` (for instance a ``list`` or a ``dict``) is always
        different from the class.

    """
    try:
        other_hash = hash(other)
    except TypeError:
        # ``hash`` raises on unhashable objects; an inequality test must not
        # crash because of the right-hand operand's type.
        return True
    return hash(cls) != other_hash


__all__ = ["EnumType"]
31 changes: 13 additions & 18 deletions openfisca_core/indexed_enums/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ def _int_to_index(
... )
>>> _int_to_index(Road, 1)
Traceback (most recent call last):
TypeError: 'int' object is not iterable
array([1], dtype=uint8)
>>> _int_to_index(Road, [1])
array([1], dtype=uint8)
Expand All @@ -105,8 +104,7 @@ def _int_to_index(
array([1], dtype=uint8)
>>> _int_to_index(Road, numpy.array(1))
Traceback (most recent call last):
TypeError: iteration over a 0-d array
array([1], dtype=uint8)
>>> _int_to_index(Road, numpy.array([1]))
array([1], dtype=uint8)
Expand All @@ -118,9 +116,9 @@ def _int_to_index(
array([1, 1], dtype=uint8)
"""
return numpy.array(
[index for index in value if index < len(enum_class.__members__)], t.EnumDType
)
indices = enum_class.indices
values = numpy.array(value, copy=False)
return values[values < indices.size].astype(t.EnumDType)


def _str_to_index(
Expand Down Expand Up @@ -155,14 +153,13 @@ def _str_to_index(
... )
>>> _str_to_index(Road, "AVENUE")
array([], dtype=uint8)
array([1], dtype=uint8)
>>> _str_to_index(Road, ["AVENUE"])
array([1], dtype=uint8)
>>> _str_to_index(Road, numpy.array("AVENUE"))
Traceback (most recent call last):
TypeError: iteration over a 0-d array
array([1], dtype=uint8)
>>> _str_to_index(Road, numpy.array(["AVENUE"]))
array([1], dtype=uint8)
Expand All @@ -174,14 +171,12 @@ def _str_to_index(
array([1, 1], dtype=uint8)
"""
return numpy.array(
[
enum_class.__members__[name].index
for name in value
if name in enum_class._member_names_
],
t.EnumDType,
)
values = numpy.array(value, copy=False)
names = enum_class.names
mask = numpy.isin(values, names)
sorter = numpy.argsort(names)
result = sorter[numpy.searchsorted(names, values[mask], sorter=sorter)]
return result.astype(t.EnumDType)


__all__ = ["_enum_to_index", "_int_to_index", "_str_to_index"]
9 changes: 1 addition & 8 deletions openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,6 @@ def __init__(self, *__args: object, **__kwargs: object) -> None:
"""
self.index = len(self._member_names_)

# Bypass the slow Enum.__eq__
__eq__ = object.__eq__

# In Python 3, __hash__ must be defined if __eq__ is defined to stay
# hashable.
__hash__ = object.__hash__

def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"

Expand Down Expand Up @@ -199,7 +192,7 @@ def _encode_array(cls, value: t.VarArray) -> t.EnumArray:
indices = _int_to_index(cls, value)
elif _is_str_array(value): # type: ignore[unreachable]
indices = _str_to_index(cls, value)
elif _is_enum_array(value) and cls.__name__ is value[0].__class__.__name__:
elif _is_enum_array(value) and cls == value[0].__class__:
indices = _enum_to_index(value)
else:
raise EnumEncodingError(cls, value)
Expand Down
10 changes: 2 additions & 8 deletions openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,15 @@ def __eq__(self, other: object) -> t.BoolArray: # type: ignore[override]
return NotImplemented
if other is None:
return NotImplemented
if (
isinstance(other, type(t.Enum))
and other.__name__ is self.possible_values.__name__
):
if isinstance(other, type(t.Enum)) and other == self.possible_values:
result = (
self.view(numpy.ndarray)
== self.possible_values.indices[
self.possible_values.indices <= max(self)
]
)
return result
if (
isinstance(other, t.Enum)
and other.__class__.__name__ is self.possible_values.__name__
):
if isinstance(other, t.Enum) and other.__class__ == self.possible_values:
result = self.view(numpy.ndarray) == other.index
return result
# For NumPy >=1.26.x.
Expand Down

0 comments on commit e2e3fb3

Please sign in to comment.