Skip to content

Commit

Permalink
Moved unit normalization to a utils function and added doctests for i…
Browse files Browse the repository at this point in the history
…t in prep for usage across both timeable and memtrackable
  • Loading branch information
mmcdermott committed Nov 19, 2024
1 parent 70b8d46 commit 71d2afa
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 41 deletions.
24 changes: 2 additions & 22 deletions src/mixins/timeable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from .utils import doublewrap
from .utils import doublewrap, normalize_unit


class TimeableMixin:
Expand Down Expand Up @@ -38,29 +38,9 @@ class TimeableMixin:
(None, "weeks"),
]

@classmethod
def _get_pprint_num_unit(cls, x: float, x_unit: str = "sec") -> tuple[float, str]:
x_unit_factor = 1
for fac, unit in cls._CUTOFFS_AND_UNITS:
if unit == x_unit:
break
if fac is None:
raise LookupError(
f"Passed unit {x_unit} invalid! "
f"Must be one of {', '.join(u for f, u in cls._CUTOFFS_AND_UNITS)}."
)
x_unit_factor *= fac

min_unit = x * x_unit_factor
upper_bound = 1
for upper_bound_factor, unit in cls._CUTOFFS_AND_UNITS:
if (upper_bound_factor is None) or (min_unit < upper_bound * upper_bound_factor):
return min_unit / upper_bound, unit
upper_bound *= upper_bound_factor

@classmethod
def _pprint_duration(cls, mean_sec: float, n_times: int = 1, std_seconds: float | None = None) -> str:
mean_time, mean_unit = cls._get_pprint_num_unit(mean_sec)
mean_time, mean_unit = normalize_unit((mean_sec, "sec"), cls._CUTOFFS_AND_UNITS)

if std_seconds:
std_time = std_seconds * mean_time / mean_sec
Expand Down
49 changes: 49 additions & 0 deletions src/mixins/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,52 @@ def new_dec(*args, **kwargs):
return lambda realf: f(realf, *args, **kwargs)

return new_dec


QUANT_T = tuple[float | None, str]
UNITS_LIST_T = list[QUANT_T]


def normalize_unit(val: QUANT_T, cutoffs_and_units: UNITS_LIST_T) -> QUANT_T:
"""Converts a quantity to the largest possible unit and returns the quantity and unit.
Args:
val: A tuple of a number and a unit to be normalized.
cutoffs_and_units: A list of tuples of valid cutoffs and units.
Returns:
A tuple of the normalized value and unit.
Raises:
LookupError: If the unit is not in the list of valid units.
Example:
>>> normalize_unit((1000, "ms"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")])
(1.0, 's')
>>> normalize_unit((720000, "ms"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")])
(12.0, 'min')
>>> normalize_unit((3600, "s"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")])
(1.0, 'h')
>>> normalize_unit((5000, "ns"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")])
Traceback (most recent call last):
...
LookupError: Passed unit ns invalid! Must be one of ms, s, min, h.
"""
x, x_unit = val
x_unit_factor = 1
for fac, unit in cutoffs_and_units:
if unit == x_unit:
break
if fac is None:
raise LookupError(
f"Passed unit {x_unit} invalid! "
f"Must be one of {', '.join(u for f, u in cutoffs_and_units)}."
)
x_unit_factor *= fac

min_unit = x * x_unit_factor
upper_bound = 1
for upper_bound_factor, unit in cutoffs_and_units:
if (upper_bound_factor is None) or (min_unit < upper_bound * upper_bound_factor):
return min_unit / upper_bound, unit
upper_bound *= upper_bound_factor
19 changes: 0 additions & 19 deletions tests/test_timeable_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,6 @@ def test_benchmark_timing(benchmark):
benchmark(T.decorated_takes_time, 0.00001)


def test_pprint_num_unit():
assert (5, "μs") == TimeableMixin._get_pprint_num_unit(5 * 1e-6)

class Derived(TimeableMixin):
_CUTOFFS_AND_UNITS = [(10, "foo"), (2, "bar"), (None, "biz")]

assert (3, "biz") == Derived._get_pprint_num_unit(3, "biz")
assert (3, "foo") == Derived._get_pprint_num_unit(3 / 20, "biz")
assert (1.2, "biz") == Derived._get_pprint_num_unit(2.4 * 10, "foo")

try:
Derived._get_pprint_num_unit(1, "WRONG")
raise AssertionError("Should have raised an exception")
except LookupError:
pass
except Exception as e:
raise AssertionError(f"Raised the wrong exception: {e}")


def test_context_manager():
T = TimeableDerived()

Expand Down

0 comments on commit 71d2afa

Please sign in to comment.