diff --git a/src/mixins/timeable.py b/src/mixins/timeable.py index 9f7c362..5457b64 100644 --- a/src/mixins/timeable.py +++ b/src/mixins/timeable.py @@ -7,7 +7,7 @@ import numpy as np -from .utils import doublewrap, normalize_unit +from .utils import doublewrap, pprint_stats_map class TimeableMixin: @@ -38,21 +38,6 @@ class TimeableMixin: (None, "weeks"), ] - @classmethod - def _pprint_duration(cls, mean_sec: float, n_times: int = 1, std_seconds: float | None = None) -> str: - mean_time, mean_unit = normalize_unit((mean_sec, "sec"), cls._CUTOFFS_AND_UNITS) - - if std_seconds: - std_time = std_seconds * mean_time / mean_sec - mean_std_str = f"{mean_time:.1f} ± {std_time:.1f} {mean_unit}" - else: - mean_std_str = f"{mean_time:.1f} {mean_unit}" - - if n_times > 1: - return f"{mean_std_str} (x{n_times})" - else: - return mean_std_str - def __init__(self, *args, **kwargs): self._timings = kwargs.get("_timings", defaultdict(list)) @@ -116,15 +101,9 @@ def _duration_stats(self): return out def _profile_durations(self, only_keys: set[str] | None = None): - stats = self._duration_stats + stats = {k: ((v, "sec"), n, s) for k, (v, n, s) in self._duration_stats.items()} if only_keys is not None: stats = {k: v for k, v in stats.items() if k in only_keys} - longest_key_length = max(len(k) for k in stats) - ordered_keys = sorted(stats.keys(), key=lambda k: stats[k][0] * stats[k][1]) - tfk_str = "\n".join( - (f"{k}:{' '*(longest_key_length - len(k))} " f"{self._pprint_duration(*stats[k])}") - for k in ordered_keys - ) - return tfk_str + return pprint_stats_map(stats, self._CUTOFFS_AND_UNITS) diff --git a/src/mixins/utils.py b/src/mixins/utils.py index b6efa67..e2e9a8c 100644 --- a/src/mixins/utils.py +++ b/src/mixins/utils.py @@ -38,7 +38,7 @@ def normalize_unit(val: QUANT_T, cutoffs_and_units: UNITS_LIST_T) -> QUANT_T: Raises: LookupError: If the unit is not in the list of valid units. 
- Example: + Examples: >>> normalize_unit((1000, "ms"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")]) (1.0, 's') >>> normalize_unit((720000, "ms"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")]) @@ -68,3 +68,101 @@ def normalize_unit(val: QUANT_T, cutoffs_and_units: UNITS_LIST_T) -> QUANT_T: if (upper_bound_factor is None) or (min_unit < upper_bound * upper_bound_factor): return min_unit / upper_bound, unit upper_bound *= upper_bound_factor + + +def pprint_quant( + cutoffs_and_units: UNITS_LIST_T, val: QUANT_T, n_times: int = 1, std_val: float | None = None +) -> str: + """Pretty prints a quantity with its unit, reflecting both a possible variance and # of times measured. + + Args: + val: A tuple of a number and a unit to be normalized. + cutoffs_and_units: A list of tuples of valid cutoffs and units to normalize the passed units. + n_times: The number of times the quantity was measured. + std_val: The standard deviation of the quantity (in the same units as val). + + Returns: + A string representation of the quantity with its unit. 
+ + Examples: + >>> pprint_quant([(1000, "ms"), (60, "s"), (60, "min"), (None, "h")], (1000, "ms")) + '1.0 s' + >>> pprint_quant([(1000, "ms"), (60, "s"), (60, "min"), (None, "h")], (1000, "ms"), n_times=3) + '1.0 s (x3)' + >>> pprint_quant([(1000, "ms"), (60, "s"), (60, "min"), (None, "h")], (1000, "ms"), std_val=100) + '1.0 ± 0.1 s' + >>> pprint_quant([(1000, "ms"), (60, "s"), (60, "min")], (1200, "ms"), n_times=3, std_val=100) + '1.2 ± 0.1 s (x3)' + """ + norm_val, unit = normalize_unit(val, cutoffs_and_units) + if std_val is not None: + unit_conversion_factor = norm_val / val[0] + std_norm_val = std_val * unit_conversion_factor + + mean_std_str = f"{norm_val:.1f} ± {std_norm_val:.1f} {unit}" + else: + mean_std_str = f"{norm_val:.1f} {unit}" + + if n_times > 1: + return f"{mean_std_str} (x{n_times})" + else: + return mean_std_str + + +def pprint_stats_map( + stats: dict[str, tuple[QUANT_T, int, float | None]], cutoffs_and_units: UNITS_LIST_T +) -> str: + """Pretty prints a dictionary of summary statistics of quantities with their units, respecting key length. + + Args: + stats: A dictionary of tuples of a number and a unit to be normalized, the number of times measured, + and the standard deviation of the quantity (in the same units as the number). + cutoffs_and_units: A list of tuples of valid cutoffs and units to normalize the passed units. + + Returns: + A string representation of the dictionary of quantities with their units. This string representation + is ordered by ascending total value (mean times number of measurements), so the largest total is last. + + Examples: + >>> print(pprint_stats_map( + ... {"foo": ((1000, "ms"), 3, 100), "foobar": ((1000, "ms"), 1, None)}, + ... [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")] + ... )) + foobar: 1.0 s + foo: 1.0 ± 0.1 s (x3) + >>> print(pprint_stats_map( + ... {"foo": ((1000, "ms"), 3, 100), "foobar": ((72, "s"), 1, None)}, + ... [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")] + ... 
)) + foo: 1.0 ± 0.1 s (x3) + foobar: 1.2 min + >>> pprint_stats_map( + ... {"foo": ((1000, "ms"), 3, 100), "foobar": ((72, "y"), 1, None)}, + ... [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")] + ... ) + Traceback (most recent call last): + ... + ValueError: Unit y in stats key foobar not found in cutoffs_and_units! + """ + + def total_val(key: str) -> float: + X_val, X_unit = stats[key][0] + factor = 1 + for fac, unit in cutoffs_and_units: + if unit == X_unit: + break + if fac is None: + raise ValueError(f"Unit {X_unit} in stats key {key} not found in cutoffs_and_units!") + factor *= fac + + base_unit_val = X_val * factor + n_times = stats[key][1] + + return base_unit_val * n_times + + longest_key_length = max(len(k) for k in stats) + ordered_keys = sorted(stats.keys(), key=total_val, reverse=False) + return "\n".join( + (f"{k}:{' '*(longest_key_length - len(k))} " f"{pprint_quant(cutoffs_and_units, *stats[k])}") + for k in ordered_keys + ) diff --git a/tests/test_timeable_mixin.py b/tests/test_timeable_mixin.py index 117ec31..0a04226 100644 --- a/tests/test_timeable_mixin.py +++ b/tests/test_timeable_mixin.py @@ -79,9 +79,11 @@ def test_times_and_profiling(): assert 0 == stats["decorated_takes_time_auto_key"][2] got_str = T._profile_durations() - want_str = "decorated_takes_time_auto_key: 2.0 sec\ndecorated: 1.5 ± 0.5 sec (x2)" + want_str = ( + "decorated_takes_time_auto_key: 2.0 ± 0.0 sec\ndecorated: 1.5 ± 0.5 sec (x2)" + ) assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}" got_str = T._profile_durations(only_keys=["decorated_takes_time_auto_key"]) - want_str = "decorated_takes_time_auto_key: 2.0 sec" + want_str = "decorated_takes_time_auto_key: 2.0 ± 0.0 sec" assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"