Skip to content

Commit

Permalink
Add inclusion of number of times a timing was called to profile_timings.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Feb 5, 2023
1 parent ae19875 commit ca1e711
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
14 changes: 9 additions & 5 deletions mixins/timeable.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,16 @@ def _get_pprint_num_unit(cls, seconds: float) -> Tuple[float, str]:
upper_bound *= upper_bound_factor

@classmethod
def _pprint_duration(cls, mean_sec: float, std_seconds: Optional[float] = None) -> str:
def _pprint_duration(cls, mean_sec: float, n_times: int = 1, std_seconds: Optional[float] = None) -> str:
mean_time, mean_unit = cls._get_pprint_num_unit(mean_sec)

if std_seconds:
std_time = std_seconds * mean_time/mean_sec
return f"{mean_time:.1f} ± {std_time:.1f} {mean_unit}"
else: return f"{mean_time:.1f} {mean_unit}"
mean_std_str = f"{mean_time:.1f} ± {std_time:.1f} {mean_unit}"
else: mean_std_str = f"{mean_time:.1f} {mean_unit}"

if n_times > 1: return f"{mean_std_str} (x{n_times})"
else: return mean_std_str

def __init__(self, *args, **kwargs):
self._timings = kwargs.get('_timings', defaultdict(list))
Expand Down Expand Up @@ -90,7 +94,7 @@ def _duration_stats(self):
out = {}
for k in self._timings:
arr = np.array(self._times_for(k))
out[k] = (arr.mean(), arr.std())
out[k] = (arr.mean(), len(arr), arr.std())
return out

def _profile_durations(self, only_keys: Optional[Set[str]] = None):
Expand All @@ -100,7 +104,7 @@ def _profile_durations(self, only_keys: Optional[Set[str]] = None):
stats = {k: v for k, v in stats.items() if k in only_keys}

longest_key_length = max(len(k) for k in stats)
ordered_keys = sorted(stats.keys(), key=lambda k: stats[k][0])
ordered_keys = sorted(stats.keys(), key=lambda k: stats[k][0] * stats[k][1])
tfk_str = '\n'.join(
(
f"{k}:{' '*(longest_key_length - len(k))} "
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "ml_mixins"
version = "0.0.3"
version = "0.0.4"
authors = [
{ name="Matthew B. A. McDermott", email="[email protected]" },
]
Expand Down
10 changes: 6 additions & 4 deletions tests/test_timeable_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,16 @@ def test_times_and_profiling(self):

self.assertEqual({'decorated', 'decorated_takes_time_auto_key'}, set(stats.keys()))
np.testing.assert_almost_equal(1.5, stats['decorated'][0], decimal=1)
np.testing.assert_almost_equal(0.5, stats['decorated'][1], decimal=1)
self.assertEqual(2, stats['decorated'][1])
np.testing.assert_almost_equal(0.5, stats['decorated'][2], decimal=1)
np.testing.assert_almost_equal(2, stats['decorated_takes_time_auto_key'][0], decimal=1)
self.assertEqual(0, stats['decorated_takes_time_auto_key'][1])
self.assertEqual(1, stats['decorated_takes_time_auto_key'][1])
self.assertEqual(0, stats['decorated_takes_time_auto_key'][2])

got_str = T._profile_durations()
want_str = (
"decorated: 1.5 ± 0.5 sec\n"
"decorated_takes_time_auto_key: 2.0 sec"
"decorated_takes_time_auto_key: 2.0 sec\n"
"decorated: 1.5 ± 0.5 sec (x2)"
)
self.assertEqual(want_str, got_str, msg=f"Want:\n{want_str}\nGot:\n{got_str}")

Expand Down

0 comments on commit ca1e711

Please sign in to comment.