Skip to content

Commit

Permalink
Add has_index protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Sep 25, 2023
1 parent ca4125b commit 5216c63
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
12 changes: 9 additions & 3 deletions openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@ class Enum(enum.Enum):
have an index.
"""

#: Index of the enum.
index: int

# Tweak enums to add an index attribute to each enum item
def __init__(self, name: str) -> None:
# When the enum item is initialized, self._member_names_ contains the
def __new__(cls, name: str) -> Enum:
# When the enum item is initialized, cls._member_names_ contains the
# names of the previously initialized items, so its length is the index
# of this item.
self.index = len(self._member_names_)
new = object.__new__(cls)
new._value_ = name
new.index = len(cls._member_names_)
return new

# Bypass the slow Enum.__eq__
__eq__ = object.__eq__
Expand Down
21 changes: 12 additions & 9 deletions openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from openfisca_core.types import Enum
from typing import Any, NoReturn, Optional, Type
from numpy.typing import NDArray
from typing import Any, Iterable, NoReturn

import numpy

from .typing import HasIndex

class EnumArray(numpy.ndarray):

class EnumArray(NDArray[numpy.int_]):
"""
NumPy array subclass representing an array of enum items.
Expand All @@ -18,24 +20,25 @@ class EnumArray(numpy.ndarray):
# https://docs.scipy.org/doc/numpy-1.13.0/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array.
def __new__(
cls,
input_array: numpy.int_,
possible_values: Optional[Type[Enum]] = None,
input_array: NDArray[numpy.int_],
possible_values: Iterable[HasIndex] | None = None,
) -> EnumArray:
obj = numpy.asarray(input_array).view(cls)
obj.possible_values = possible_values
return obj

# See previous comment
def __array_finalize__(self, obj: Optional[numpy.int_]) -> None:
def __array_finalize__(self, obj: NDArray[numpy.int_] | EnumArray | None) -> None:
if obj is None:
return

self.possible_values = getattr(obj, "possible_values", None)

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: HasIndex | Any) -> bool:
# When comparing to an item of self.possible_values, use the item index
# to speed up the comparison.
if other.__class__.__name__ is self.possible_values.__name__:

if hasattr(other, "index"):
# Use view(ndarray) so that the result is a classic ndarray, not an
# EnumArray.
return self.view(numpy.ndarray) == other.index
Expand Down Expand Up @@ -79,7 +82,7 @@ def decode(self) -> numpy.object_:
list(self.possible_values),
)

def decode_to_str(self) -> numpy.str_:
def decode_to_str(self) -> NDArray[numpy.str_]:
"""
Return the array of string identifiers corresponding to self.
Expand Down
7 changes: 7 additions & 0 deletions openfisca_core/indexed_enums/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import Protocol


class HasIndex(Protocol):
"""Indexable class protocol."""

index: int

0 comments on commit 5216c63

Please sign in to comment.