diff --git a/pyproject.toml b/pyproject.toml index 6152edb..4b9719e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ classifiers = [ dependencies = ["numpy"] [project.optional-dependencies] +memtrackable = ["memray"] dev = ["pre-commit<4"] tests = ["pytest", "pytest-cov", "pytest-benchmark"] diff --git a/src/mixins/__init__.py b/src/mixins/__init__.py index 574b788..b9ca970 100644 --- a/src/mixins/__init__.py +++ b/src/mixins/__init__.py @@ -1,4 +1,5 @@ +from .memtrackable import MemTrackableMixin from .seedable import SeedableMixin from .timeable import TimeableMixin -__all__ = ["SeedableMixin", "TimeableMixin"] +__all__ = ["MemTrackableMixin", "SeedableMixin", "TimeableMixin"] diff --git a/src/mixins/memtrackable.py b/src/mixins/memtrackable.py new file mode 100644 index 0000000..30fd915 --- /dev/null +++ b/src/mixins/memtrackable.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import functools +import json +import subprocess +from collections import defaultdict +from contextlib import contextmanager +from pathlib import Path +from tempfile import TemporaryDirectory + +import numpy as np +from memray import Tracker + +from .utils import doublewrap, pprint_stats_map + + +class MemTrackableMixin: + """A mixin class to add memory tracking functionality to a class for profiling its methods. + + This mixin class provides the following functionality: + - Tracking the memory use of methods using the TrackMemoryAs decorator. + - Timing of arbitrary code blocks using the _track_memory_as context manager. + - Profiling of the memory used across the life cycle of the class. + + This class uses `memray` to track the memory usage of the methods. + + Attributes: + _memory_usage: A dictionary of lists of memory usages of tracked methods. + The keys of the dictionary are the names of the tracked code blocks / methods. + The values are lists of memory used during the lifecycle of each tracked code blocks / methods. + """ + + _CUTOFFS_AND_UNITS = [ + (8, "b"), + (1000, "B"), + (1000, "kB"), + (1000, "MB"), + (1000, "GB"), + (1000, "TB"), + (1000, "PB"), + ] + + @staticmethod + def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path) -> dict: + memray_stats_cmd = f"memray stats {memray_tracker_fp} --json -o {memray_stats_fp} -f" + subprocess.run(memray_stats_cmd, shell=True, check=True, capture_output=True) + try: + return json.loads(memray_stats_fp.read_text()) + except Exception as e: + raise ValueError(f"Failed to parse memray stats file at {memray_stats_fp}") from e + + def __init__(self, *args, **kwargs): + self._mem_stats = kwargs.get("_mem_stats", defaultdict(list)) + + def __assert_key_exists(self, key: str) -> None: + if not hasattr(self, "_mem_stats"): + raise AttributeError("self._mem_stats should exist!") + if key not in self._mem_stats: + raise AttributeError(f"{key} should exist in self._mem_stats!") + + def _peak_mem_for(self, key: str) -> list[float]: + self.__assert_key_exists(key) + + return [v["metadata"]["peak_memory"] for v in self._mem_stats[key]] + + @contextmanager + def _track_memory_as(self, key: str): + if not hasattr(self, "_mem_stats"): + self._mem_stats = defaultdict(list) + + memory_stats = {} + with TemporaryDirectory() as tmpdir: + memray_fp = Path(tmpdir) / ".memray" + memray_stats_fp = Path(tmpdir) / "memray_stats.json" + + try: + with Tracker(memray_fp, follow_fork=True): + yield + finally: + memory_stats.update(MemTrackableMixin.get_memray_stats(memray_fp, memray_stats_fp)) + self._mem_stats[key].append(memory_stats) + + @staticmethod + @doublewrap + def TrackMemoryAs(fn, key: str | None = None): + if key is None: + key = fn.__name__ + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + with self._track_memory_as(key): + out = fn(self, *args, **kwargs) + return out + + return wrapper + + @property + def _memory_stats(self): + out = {} + for k in self._mem_stats: + arr = np.array(self._peak_mem_for(k)) + out[k] = (arr.mean(), len(arr), None if len(arr) <= 1 else arr.std()) + return out + + def _profile_memory_usages(self, only_keys: set[str] | None = None): + stats = {k: ((v, "B"), n, s) for k, (v, n, s) in self._memory_stats.items()} + + if only_keys is not None: + stats = {k: v for k, v in stats.items() if k in only_keys} + + return pprint_stats_map(stats, self._CUTOFFS_AND_UNITS) diff --git a/src/mixins/timeable.py b/src/mixins/timeable.py index 5457b64..be55959 100644 --- a/src/mixins/timeable.py +++ b/src/mixins/timeable.py @@ -97,7 +97,7 @@ def _duration_stats(self): out = {} for k in self._timings: arr = np.array(self._times_for(k)) - out[k] = (arr.mean(), len(arr), arr.std()) + out[k] = (arr.mean(), len(arr), None if len(arr) <= 1 else arr.std()) return out def _profile_durations(self, only_keys: set[str] | None = None): diff --git a/tests/test_memtrackable_mixin.py b/tests/test_memtrackable_mixin.py new file mode 100644 index 0000000..aeef381 --- /dev/null +++ b/tests/test_memtrackable_mixin.py @@ -0,0 +1,93 @@ +import numpy as np + +from mixins import MemTrackableMixin + + +class MemTrackableDerived(MemTrackableMixin): + def __init__(self): + self.foo = "foo" + # Doesn't call super().__init__()! Should still work in this case. + + def uses_contextlib(self, mem_size_64b: int = 800000): + with self._track_memory_as("using_contextlib"): + np.ones((mem_size_64b,), dtype=np.float64) + + @MemTrackableMixin.TrackMemoryAs(key="decorated") + def decorated_takes_mem(self, mem_size_64b: int = 800000): + np.ones((mem_size_64b,), dtype=np.float64) + + @MemTrackableMixin.TrackMemoryAs + def decorated_takes_mem_auto_key(self, mem_size_64b: int = 800000): + np.ones((mem_size_64b,), dtype=np.float64) + + +def test_constructs(): + MemTrackableMixin() + MemTrackableDerived() + + +def test_errors_if_not_initialized(): + M = MemTrackableDerived() + try: + M._peak_mem_for("foo") + raise AssertionError("Should have raised an exception!") + except AttributeError as e: + assert "self._mem_stats should exist!" in str(e) + + M.uses_contextlib(mem_size_64b=8000000) # 64 MB + try: + M._peak_mem_for("wrong_key") + raise AssertionError("Should have raised an exception!") + except AttributeError as e: + assert "wrong_key should exist in self._mem_stats!" in str(e) + + +def test_benchmark_timing(benchmark): + M = MemTrackableDerived() + + benchmark(M.decorated_takes_mem, 8000) + + +def test_context_manager(): + M = MemTrackableDerived() + + M.uses_contextlib(mem_size_64b=8000000) # 64 MB + + mem_used = M._peak_mem_for("using_contextlib")[-1] + np.testing.assert_almost_equal(mem_used, 64 * 1000000, decimal=1) + + M.uses_contextlib(mem_size_64b=80000) # 0.64 MB + + mem_used = M._peak_mem_for("using_contextlib")[-1] + np.testing.assert_almost_equal(mem_used, 64 * 10000, decimal=1) + + +def test_decorators_and_profiling(): + M = MemTrackableDerived() + M.decorated_takes_mem(mem_size_64b=16000) + + mem_used = M._peak_mem_for("decorated")[-1] + np.testing.assert_almost_equal(mem_used, 64 * 2000, decimal=1) + + M.decorated_takes_mem_auto_key(mem_size_64b=40000) + mem_used = M._peak_mem_for("decorated_takes_mem_auto_key")[-1] + np.testing.assert_almost_equal(mem_used, 64 * 5000, decimal=1) + + M.decorated_takes_mem(mem_size_64b=8000) + stats = M._memory_stats + + assert {"decorated", "decorated_takes_mem_auto_key"} == set(stats.keys()) + np.testing.assert_almost_equal(64 * 1500, stats["decorated"][0], decimal=1) + assert 2 == stats["decorated"][1] + np.testing.assert_almost_equal(0.5 * 64 * 1000, stats["decorated"][2], decimal=1) + np.testing.assert_almost_equal(64 * 5000, stats["decorated_takes_mem_auto_key"][0], decimal=1) + assert 1 == stats["decorated_takes_mem_auto_key"][1] + assert stats["decorated_takes_mem_auto_key"][2] is None + + got_str = M._profile_memory_usages() + want_str = "decorated: 96.0 ± 32.0 kB (x2)\ndecorated_takes_mem_auto_key: 320.0 kB" + assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}" + + got_str = M._profile_memory_usages(only_keys=["decorated_takes_mem_auto_key"]) + want_str = "decorated_takes_mem_auto_key: 320.0 kB" + assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}" diff --git a/tests/test_timeable_mixin.py b/tests/test_timeable_mixin.py index 0a04226..4bbf732 100644 --- a/tests/test_timeable_mixin.py +++ b/tests/test_timeable_mixin.py @@ -76,14 +76,12 @@ def test_times_and_profiling(): np.testing.assert_almost_equal(0.5, stats["decorated"][2], decimal=1) np.testing.assert_almost_equal(2, stats["decorated_takes_time_auto_key"][0], decimal=1) assert 1 == stats["decorated_takes_time_auto_key"][1] - assert 0 == stats["decorated_takes_time_auto_key"][2] + assert stats["decorated_takes_time_auto_key"][2] is None got_str = T._profile_durations() - want_str = ( - "decorated_takes_time_auto_key: 2.0 ± 0.0 sec\ndecorated: 1.5 ± 0.5 sec (x2)" - ) + want_str = "decorated_takes_time_auto_key: 2.0 sec\ndecorated: 1.5 ± 0.5 sec (x2)" assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}" got_str = T._profile_durations(only_keys=["decorated_takes_time_auto_key"]) - want_str = "decorated_takes_time_auto_key: 2.0 ± 0.0 sec" + want_str = "decorated_takes_time_auto_key: 2.0 sec" assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"