Skip to content

Commit

Permalink
feat: adding overflow disable option to cat axes (#883)
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Schreiner <[email protected]>
  • Loading branch information
henryiii authored Sep 11, 2023
1 parent 0a8e283 commit 5341ae3
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 15 deletions.
8 changes: 7 additions & 1 deletion include/bh_python/axis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,20 @@ BHP_SPECIALIZE_NAME(integer_oflow)
BHP_SPECIALIZE_NAME(integer_growth)
BHP_SPECIALIZE_NAME(integer_circular)

// Integer category axis aliases, one per supported option set:
//   category_int_none   - option::none_t (no extra bins requested)
//   category_int        - library-default options; the Python layer picks this
//                         alias when only "overflow" is enabled
//   category_int_growth - option::growth_t
using category_int_none = bh::axis::category<int, metadata_t, option::none_t>;
using category_int = bh::axis::category<int, metadata_t>;
using category_int_growth = bh::axis::category<int, metadata_t, option::growth_t>;

// NOTE(review): BHP_SPECIALIZE_NAME is defined outside this view; presumably it
// attaches the alias name as a string to each type - confirm against the macro.
BHP_SPECIALIZE_NAME(category_int_none)
BHP_SPECIALIZE_NAME(category_int)
BHP_SPECIALIZE_NAME(category_int_growth)

// String category axis aliases, one per supported option set:
//   category_str_none   - option::none_t (no extra bins requested)
//   category_str        - option::overflow_t (spelled out explicitly here)
//   category_str_growth - option::growth_t
using category_str_none = bh::axis::category<std::string, metadata_t, option::none_t>;
using category_str = bh::axis::category<std::string, metadata_t, option::overflow_t>;
using category_str_growth
= bh::axis::category<std::string, metadata_t, option::growth_t>;

// NOTE(review): BHP_SPECIALIZE_NAME is defined outside this view; presumably it
// attaches the alias name as a string to each type - confirm against the macro.
BHP_SPECIALIZE_NAME(category_str_none)
BHP_SPECIALIZE_NAME(category_str)
BHP_SPECIALIZE_NAME(category_str_growth)

Expand Down Expand Up @@ -306,7 +310,9 @@ using axis_variant = bh::axis::variant<axis::regular_uoflow,
axis::category_int_growth,
axis::category_str,
axis::category_str_growth,
axis::boolean>;
axis::boolean,
axis::category_int_none,
axis::category_str_none>;

// This saves a little typing
using vector_axis_variant = std::vector<axis_variant>;
Expand Down
2 changes: 2 additions & 0 deletions src/boost_histogram/_core/axis/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class _BaseCatInt(_BaseAxis):
def __iter__(self) -> Iterator[int]: ...
def bin(self, arg0: int) -> int: ...

# Concrete integer-category axis stubs; each inherits its interface from _BaseCatInt.
# IntCategory picks _none when neither growth nor overflow is requested,
# the plain variant for overflow only, and _growth whenever growth is requested.
class category_int_none(_BaseCatInt): ...
class category_int(_BaseCatInt): ...
class category_int_growth(_BaseCatInt): ...

Expand All @@ -107,6 +108,7 @@ class _BaseCatStr(_BaseAxis):
def __iter__(self) -> Iterator[str]: ...
def bin(self, arg0: int) -> str: ...

# Concrete string-category axis stubs; each inherits its interface from _BaseCatStr.
# StrCategory picks _none when neither growth nor overflow is requested,
# the plain variant for overflow only, and _growth whenever growth is requested.
class category_str_none(_BaseCatStr): ...
class category_str(_BaseCatStr): ...
class category_str_growth(_BaseCatStr): ...

Expand Down
30 changes: 20 additions & 10 deletions src/boost_histogram/_internal/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,15 +583,15 @@ def _repr_args_(self) -> list[str]:

if self.traits.growth:
ret.append("growth=True")
elif self.traits.circular:
ret.append("circular=True")
elif not self.traits.overflow:
ret.append("overflow=False")

ret += super()._repr_args_()
return ret


@set_module("boost_histogram.axis")
@register({ca.category_str_growth, ca.category_str})
@register({ca.category_str_growth, ca.category_str, ca.category_str_none})
class StrCategory(BaseCategory, family=boost_histogram):
__slots__ = ()

Expand All @@ -601,6 +601,7 @@ def __init__(
*,
metadata: Any = None,
growth: bool = False,
overflow: bool = True,
__dict__: dict[str, Any] | None = None,
):
"""
Expand All @@ -618,21 +619,25 @@ def __init__(
growth : bool = False
Allow the axis to grow if a value is encountered out of range.
Be careful, the axis will grow as large as needed.
overflow : bool = True
Include an overflow bin for "missed" hits. Ignored if growth=True.
__dict__: Optional[Dict[str, Any]] = None
The full metadata dictionary
"""

options = _opts(growth=growth)
options = _opts(growth=growth, overflow=overflow)

ax: ca._BaseCatStr

# henryiii: We currently expand "abc" to "a", "b", "c" - some
# Python interfaces protect against that

if options == {"growth"}:
if "growth" in options:
ax = ca.category_str_growth(tuple(categories))
elif options == set():
elif options == {"overflow"}:
ax = ca.category_str(tuple(categories))
elif not options:
ax = ca.category_str_none(tuple(categories))
else:
raise KeyError("Unsupported collection of options")

Expand All @@ -659,7 +664,7 @@ def _repr_args_(self) -> list[str]:


@set_module("boost_histogram.axis")
@register({ca.category_int, ca.category_int_growth})
@register({ca.category_int, ca.category_int_growth, ca.category_int_none})
class IntCategory(BaseCategory, family=boost_histogram):
__slots__ = ()

Expand All @@ -669,6 +674,7 @@ def __init__(
*,
metadata: Any = None,
growth: bool = False,
overflow: bool = True,
__dict__: dict[str, Any] | None = None,
):
"""
Expand All @@ -686,17 +692,21 @@ def __init__(
growth : bool = False
Allow the axis to grow if a value is encountered out of range.
Be careful, the axis will grow as large as needed.
overflow : bool = True
Include an overflow bin for "missed" hits. Ignored if growth=True.
__dict__: Optional[Dict[str, Any]] = None
The full metadata dictionary
"""

options = _opts(growth=growth)
options = _opts(growth=growth, overflow=overflow)
ax: ca._BaseCatInt

if options == {"growth"}:
if "growth" in options:
ax = ca.category_int_growth(tuple(categories))
elif options == set():
elif options == {"overflow"}:
ax = ca.category_int(tuple(categories))
elif not options:
ax = ca.category_int_none(tuple(categories))
else:
raise KeyError("Unsupported collection of options")

Expand Down
8 changes: 6 additions & 2 deletions src/register_axis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,14 @@ void register_axes(py::module& mod) {
axis::integer_circular>(
mod, [](auto ax) { ax.def(py::init<int, int>(), "start"_a, "stop"_a); });

register_axis_each<axis::category_int, axis::category_int_growth>(
register_axis_each<axis::category_int,
axis::category_int_growth,
axis::category_int_none>(
mod, [](auto ax) { ax.def(py::init<std::vector<int>>(), "categories"_a); });

register_axis_each<axis::category_str, axis::category_str_growth>(mod, [](auto ax) {
register_axis_each<axis::category_str,
axis::category_str_growth,
axis::category_str_none>(mod, [](auto ax) {
ax.def(py::init<std::vector<std::string>>(), "categories"_a);
});

Expand Down
4 changes: 4 additions & 0 deletions tests/test_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,14 @@
(bh.axis.Integer, (1, 2), "g", {}),
(bh.axis.Integer, (1, 2), "", {"circular": True}),
(bh.axis.IntCategory, ((1, 2, 3),), "", {}),
(bh.axis.IntCategory, ((1, 2, 3),), "o", {}),
(bh.axis.IntCategory, ((1, 2, 3),), "g", {}),
(bh.axis.IntCategory, ((1, 2, 3),), "go", {}),
(bh.axis.IntCategory, ((),), "g", {}),
(bh.axis.StrCategory, (tuple("ABC"),), "", {}),
(bh.axis.StrCategory, (tuple("ABC"),), "o", {}),
(bh.axis.StrCategory, (tuple("ABC"),), "g", {}),
(bh.axis.StrCategory, (tuple("ABC"),), "go", {}),
(bh.axis.StrCategory, ((),), "g", {}),
],
)
Expand Down
15 changes: 13 additions & 2 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,17 @@ def test_growing_cats():
assert h.size == 4


def test_noflow_cats():
    """Fills that miss a category axis built with ``overflow=False`` are dropped.

    Both axes are created without an overflow bin, so an entry whose value is
    not a known category on either axis must not be stored anywhere.
    """
    h = bh.Histogram(
        bh.axis.IntCategory([1, 2, 3], overflow=False),
        bh.axis.StrCategory(["hi"], overflow=False),
    )

    # Only (1, "hi") and (3, "hi") match a category on both axes;
    # (2, "ho") and (4, "ho") miss at least one axis.
    h.fill([1, 2, 3, 4], ["hi", "ho", "hi", "ho"])

    assert h.sum() == 2
    # sum() ignores flow bins by default, so the check above would also pass if
    # the axes still had overflow bins. Including flow bins pins the real
    # contract: the missed fills were discarded, not parked in an overflow bin.
    assert h.sum(flow=True) == 2


def test_metadata_add():
h1 = bh.Histogram(
bh.axis.IntCategory([1, 2, 3]), bh.axis.StrCategory(["1", "2", "3"])
Expand Down Expand Up @@ -655,9 +666,9 @@ def test_rebin_nd():


# CLASSIC: This used to have metadata too, but that does not compare equal
def test_pickle_0():
def test_pickle_0(flow):
a = bh.Histogram(
bh.axis.IntCategory([0, 1, 2]),
bh.axis.IntCategory([0, 1, 2], overflow=flow),
bh.axis.Integer(0, 20),
bh.axis.Regular(2, 0.0, 20.0, underflow=False, overflow=False),
bh.axis.Variable([0.0, 1.0, 2.0]),
Expand Down

0 comments on commit 5341ae3

Please sign in to comment.