Skip to content

Commit

Permalink
Moved pretty print code out as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Nov 19, 2024
1 parent 71d2afa commit 24b126b
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 27 deletions.
27 changes: 3 additions & 24 deletions src/mixins/timeable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from .utils import doublewrap, normalize_unit
from .utils import doublewrap, pprint_stats_map


class TimeableMixin:
Expand Down Expand Up @@ -38,21 +38,6 @@ class TimeableMixin:
(None, "weeks"),
]

@classmethod
def _pprint_duration(cls, mean_sec: float, n_times: int = 1, std_seconds: float | None = None) -> str:
    """Formats a mean duration given in seconds as a short human-readable string.

    The mean is rescaled with ``normalize_unit`` against ``cls._CUTOFFS_AND_UNITS``; the standard
    deviation (when shown) is rescaled by the same ratio so both numbers share one unit.

    Args:
        mean_sec: The mean duration, in seconds.
        n_times: How many measurements the mean summarizes; shown as ``(xN)`` only when above 1.
        std_seconds: The standard deviation in seconds. Omitted from the output when it is ``None``
            or zero, because the check below is a plain truthiness test.

    Returns:
        A string such as ``"1.5 ± 0.5 sec (x2)"`` or ``"2.0 sec"``.
    """
    scaled_mean, unit_name = normalize_unit((mean_sec, "sec"), cls._CUTOFFS_AND_UNITS)

    spread = ""
    if std_seconds:
        # Apply the mean's rescaling ratio to the spread so the two values stay comparable.
        scaled_std = std_seconds * scaled_mean / mean_sec
        spread = f" ± {scaled_std:.1f}"

    summary = f"{scaled_mean:.1f}{spread} {unit_name}"
    return f"{summary} (x{n_times})" if n_times > 1 else summary

def __init__(self, *args, **kwargs):
self._timings = kwargs.get("_timings", defaultdict(list))

Expand Down Expand Up @@ -116,15 +101,9 @@ def _duration_stats(self):
return out

def _profile_durations(self, only_keys: set[str] | None = None):
    """Returns a pretty-printed, one-line-per-key summary of the recorded durations.

    Args:
        only_keys: If not ``None``, only keys contained in this collection are reported.

    Returns:
        The string produced by ``pprint_stats_map`` for the (possibly filtered) statistics, using
        ``self._CUTOFFS_AND_UNITS`` to pick display units.
    """
    # pprint_stats_map expects ``((value, unit), n_times, std)`` entries, so tag each mean as seconds.
    stats = {k: ((v, "sec"), n, s) for k, (v, n, s) in self._duration_stats.items()}

    if only_keys is not None:
        stats = {k: v for k, v in stats.items() if k in only_keys}

    # Ordering, key alignment and unit normalization all live in the shared helper.
    return pprint_stats_map(stats, self._CUTOFFS_AND_UNITS)
100 changes: 99 additions & 1 deletion src/mixins/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def normalize_unit(val: QUANT_T, cutoffs_and_units: UNITS_LIST_T) -> QUANT_T:
Raises:
LookupError: If the unit is not in the list of valid units.
Example:
Examples:
>>> normalize_unit((1000, "ms"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")])
(1.0, 's')
>>> normalize_unit((720000, "ms"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")])
Expand Down Expand Up @@ -68,3 +68,101 @@ def normalize_unit(val: QUANT_T, cutoffs_and_units: UNITS_LIST_T) -> QUANT_T:
if (upper_bound_factor is None) or (min_unit < upper_bound * upper_bound_factor):
return min_unit / upper_bound, unit
upper_bound *= upper_bound_factor


def _unit_scale(cutoffs_and_units: UNITS_LIST_T, unit: str) -> float:
    """Returns how many of the first-listed unit make up one ``unit``.

    Raises:
        LookupError: If ``unit`` does not appear in ``cutoffs_and_units``.
    """
    scale = 1
    for factor, name in cutoffs_and_units:
        if name == unit:
            return scale
        if factor is None:
            break
        scale *= factor
    raise LookupError(f"Unit {unit} not found in cutoffs_and_units!")


def pprint_quant(
    cutoffs_and_units: UNITS_LIST_T, val: QUANT_T, n_times: int = 1, std_val: float | None = None
) -> str:
    """Pretty prints a quantity with its unit, reflecting both a possible variance and # of times measured.

    Args:
        cutoffs_and_units: A list of tuples of valid cutoffs and units to normalize the passed units.
        val: A tuple of a number and a unit to be normalized.
        n_times: The number of times the quantity was measured.
        std_val: The standard deviation of the quantity (in the same units as val).

    Returns:
        A string representation of the quantity with its unit.

    Examples:
        >>> pprint_quant([(1000, "ms"), (60, "s"), (60, "min"), (None, "h")], (1000, "ms"))
        '1.0 s'
        >>> pprint_quant([(1000, "ms"), (60, "s"), (60, "min"), (None, "h")], (1000, "ms"), n_times=3)
        '1.0 s (x3)'
        >>> pprint_quant([(1000, "ms"), (60, "s"), (60, "min"), (None, "h")], (1000, "ms"), std_val=100)
        '1.0 ± 0.1 s'
        >>> pprint_quant([(1000, "ms"), (60, "s"), (60, "min")], (1200, "ms"), n_times=3, std_val=100)
        '1.2 ± 0.1 s (x3)'
    """
    norm_val, unit = normalize_unit(val, cutoffs_and_units)

    if std_val is None:
        mean_std_str = f"{norm_val:.1f} {unit}"
    else:
        if val[0]:
            # The ratio between the normalized and raw mean is exactly the unit conversion factor.
            unit_conversion_factor = norm_val / val[0]
        else:
            # A mean of zero makes that ratio undefined (it used to raise ZeroDivisionError), so
            # derive the factor from the unit table instead.
            unit_conversion_factor = _unit_scale(cutoffs_and_units, val[1]) / _unit_scale(
                cutoffs_and_units, unit
            )
        std_norm_val = std_val * unit_conversion_factor
        mean_std_str = f"{norm_val:.1f} ± {std_norm_val:.1f} {unit}"

    if n_times > 1:
        return f"{mean_std_str} (x{n_times})"
    return mean_std_str


def pprint_stats_map(
    stats: dict[str, tuple[QUANT_T, int, float | None]], cutoffs_and_units: UNITS_LIST_T
) -> str:
    """Pretty prints a dictionary of summary statistics of quantities with their units, respecting key length.

    Args:
        stats: A dictionary of tuples of a number and a unit to be normalized, the number of times measured,
            and the standard deviation of the quantity (in the same units as the number).
        cutoffs_and_units: A list of tuples of valid cutoffs and units to normalize the passed units.

    Returns:
        A string representation of the dictionary of quantities with their units, one line per key with the
        values aligned past the longest key. Lines are ordered by ascending total value (mean times number
        of measurements), so the largest total is printed last. An empty ``stats`` yields an empty string.

    Raises:
        ValueError: If the unit of any entry does not appear in ``cutoffs_and_units``.

    Examples:
        >>> print(pprint_stats_map(
        ...     {"foo": ((1000, "ms"), 3, 100), "foobar": ((1000, "ms"), 1, None)},
        ...     [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")]
        ... ))
        foobar: 1.0 s
        foo:    1.0 ± 0.1 s (x3)
        >>> print(pprint_stats_map(
        ...     {"foo": ((1000, "ms"), 3, 100), "foobar": ((72, "s"), 1, None)},
        ...     [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")]
        ... ))
        foo:    1.0 ± 0.1 s (x3)
        foobar: 1.2 min
        >>> pprint_stats_map(
        ...     {"foo": ((1000, "ms"), 3, 100), "foobar": ((72, "y"), 1, None)},
        ...     [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")]
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Unit y in stats key foobar not found in cutoffs_and_units!
        >>> pprint_stats_map({}, [(1000, "ms"), (None, "s")])
        ''
    """
    # Nothing to report; also avoids calling max() on an empty sequence below.
    if not stats:
        return ""

    def total_val(key: str) -> float:
        """Total measured amount for ``key`` (mean * count), expressed in the first-listed unit."""
        (x_val, x_unit), n_times = stats[key][0], stats[key][1]

        factor = 1
        for fac, unit in cutoffs_and_units:
            if unit == x_unit:
                return x_val * factor * n_times
            if fac is None:
                break
            factor *= fac

        # Reached both when the terminating ``(None, unit)`` entry is passed and when the list simply
        # runs out without one; either way the unit is unknown.
        raise ValueError(f"Unit {x_unit} in stats key {key} not found in cutoffs_and_units!")

    longest_key_length = max(len(k) for k in stats)
    ordered_keys = sorted(stats, key=total_val)
    return "\n".join(
        f"{k}:{' ' * (longest_key_length - len(k))} {pprint_quant(cutoffs_and_units, *stats[k])}"
        for k in ordered_keys
    )
6 changes: 4 additions & 2 deletions tests/test_timeable_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,11 @@ def test_times_and_profiling():
assert 0 == stats["decorated_takes_time_auto_key"][2]

got_str = T._profile_durations()
want_str = "decorated_takes_time_auto_key: 2.0 sec\ndecorated: 1.5 ± 0.5 sec (x2)"
want_str = (
"decorated_takes_time_auto_key: 2.0 ± 0.0 sec\ndecorated: 1.5 ± 0.5 sec (x2)"
)
assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"

got_str = T._profile_durations(only_keys=["decorated_takes_time_auto_key"])
want_str = "decorated_takes_time_auto_key: 2.0 sec"
want_str = "decorated_takes_time_auto_key: 2.0 ± 0.0 sec"
assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"

0 comments on commit 24b126b

Please sign in to comment.