Skip to content

Commit

Permalink
Added memtrackable mixin.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Nov 19, 2024
1 parent 24b126b commit daae983
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 7 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ classifiers = [
dependencies = ["numpy"]

[project.optional-dependencies]
memtrackable = ["memray"]
dev = ["pre-commit<4"]
tests = ["pytest", "pytest-cov", "pytest-benchmark"]

Expand Down
3 changes: 2 additions & 1 deletion src/mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .memtrackable import MemTrackableMixin
from .seedable import SeedableMixin
from .timeable import TimeableMixin

__all__ = ["SeedableMixin", "TimeableMixin"]
__all__ = ["MemTrackableMixin", "SeedableMixin", "TimeableMixin"]
111 changes: 111 additions & 0 deletions src/mixins/memtrackable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from __future__ import annotations

import functools
import json
import subprocess
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
from memray import Tracker

from .utils import doublewrap, pprint_stats_map


class MemTrackableMixin:
    """A mixin class to add memory tracking functionality to a class for profiling its methods.

    This mixin class provides the following functionality:
      - Tracking the memory use of methods using the ``TrackMemoryAs`` decorator.
      - Tracking the memory use of arbitrary code blocks using the ``_track_memory_as`` context manager.
      - Profiling of the memory used across the life cycle of the class.

    This class uses `memray` to track the memory usage of the methods.

    Attributes:
        _mem_stats: A dictionary of lists of memray statistics for tracked code.
            The keys of the dictionary are the names of the tracked code blocks / methods.
            The values are lists holding one parsed ``memray stats --json`` report per tracked run of that
            code block / method. The attribute is created lazily by ``_track_memory_as`` when the
            subclass never called ``__init__`` on this mixin.
    """

    # Unit table handed to ``pprint_stats_map`` (defined in ``.utils``, not shown here).
    # NOTE(review): presumably each entry is (factor relative to the previous unit, unit label);
    # confirm against that helper before changing the values.
    _CUTOFFS_AND_UNITS = [
        (8, "b"),
        (1000, "B"),
        (1000, "kB"),
        (1000, "MB"),
        (1000, "GB"),
        (1000, "TB"),
        (1000, "PB"),
    ]

    @staticmethod
    def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path) -> dict:
        """Run ``memray stats`` on a tracker capture file and return the parsed JSON report.

        Args:
            memray_tracker_fp: Path of the capture file written by ``memray.Tracker``.
            memray_stats_fp: Path the JSON report is written to (overwritten because of ``-f``).

        Returns:
            The decoded contents of the JSON report.

        Raises:
            subprocess.CalledProcessError: If the ``memray stats`` command exits with a non-zero status.
            FileNotFoundError: If the ``memray`` executable cannot be found.
            ValueError: If the report file cannot be read or parsed.
        """
        # An argument list (no shell) keeps paths containing spaces or shell metacharacters intact
        # and avoids handing file names to a shell for interpretation.
        memray_stats_cmd = [
            "memray",
            "stats",
            str(memray_tracker_fp),
            "--json",
            "-o",
            str(memray_stats_fp),
            "-f",
        ]
        subprocess.run(memray_stats_cmd, check=True, capture_output=True)
        try:
            return json.loads(memray_stats_fp.read_text())
        except Exception as e:
            raise ValueError(f"Failed to parse memray stats file at {memray_stats_fp}") from e

    def __init__(self, *args, **kwargs):
        """Initialise the stats store, optionally from a ``_mem_stats`` keyword argument.

        All other positional and keyword arguments are accepted and ignored.
        """
        self._mem_stats = kwargs.get("_mem_stats", defaultdict(list))

    def __assert_key_exists(self, key: str) -> None:
        """Raise ``AttributeError`` unless ``self._mem_stats`` exists and contains ``key``."""
        if not hasattr(self, "_mem_stats"):
            raise AttributeError("self._mem_stats should exist!")
        if key not in self._mem_stats:
            raise AttributeError(f"{key} should exist in self._mem_stats!")

    def _peak_mem_for(self, key: str) -> list[float]:
        """Return the recorded peak memory of every tracked run stored under ``key``.

        Raises:
            AttributeError: If nothing has been tracked yet or ``key`` was never tracked.
        """
        self.__assert_key_exists(key)

        return [v["metadata"]["peak_memory"] for v in self._mem_stats[key]]

    @contextmanager
    def _track_memory_as(self, key: str):
        """Context manager that tracks the enclosed block with memray and stores its stats under ``key``.

        The capture file and the JSON report live in a temporary directory that is removed on exit.
        Stats are collected in a ``finally`` clause, so they are also recorded when the block raises.
        """
        # Subclasses may skip ``__init__`` of this mixin, so create the store on first use.
        if not hasattr(self, "_mem_stats"):
            self._mem_stats = defaultdict(list)

        with TemporaryDirectory() as tmpdir:
            memray_fp = Path(tmpdir) / ".memray"
            memray_stats_fp = Path(tmpdir) / "memray_stats.json"

            try:
                with Tracker(memray_fp, follow_fork=True):
                    yield
            finally:
                # Compute the stats before touching ``self._mem_stats[key]`` so that a failing
                # stats run does not leave an empty entry behind in the defaultdict.
                memory_stats = MemTrackableMixin.get_memray_stats(memray_fp, memray_stats_fp)
                self._mem_stats[key].append(memory_stats)

    @staticmethod
    @doublewrap
    def TrackMemoryAs(fn, key: str | None = None):
        """Decorator that tracks every call of the decorated method via ``_track_memory_as``.

        Args:
            fn: The method to wrap; its first argument must be the instance (``self``).
            key: Name the stats are stored under. Defaults to ``fn.__name__``.

        ``doublewrap`` comes from ``.utils`` (not shown here).
        NOTE(review): it presumably allows both the bare ``@TrackMemoryAs`` form and the
        ``@TrackMemoryAs(key=...)`` form; confirm against that helper.
        """
        if key is None:
            key = fn.__name__

        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            with self._track_memory_as(key):
                return fn(self, *args, **kwargs)

        return wrapper

    @property
    def _memory_stats(self):
        """Summarise peak memory per key as ``(mean, number of runs, std)``.

        The standard deviation is ``None`` when a key has at most one recorded run.
        """
        out = {}
        for k in self._mem_stats:
            arr = np.array(self._peak_mem_for(k))
            out[k] = (arr.mean(), len(arr), None if len(arr) <= 1 else arr.std())
        return out

    def _profile_memory_usages(self, only_keys: set[str] | None = None):
        """Return the result of ``pprint_stats_map`` for the peak-memory summary.

        Args:
            only_keys: If given, only keys contained in this collection are reported.
        """
        # Means are tagged with the unit "B" before being handed to the formatter.
        stats = {k: ((v, "B"), n, s) for k, (v, n, s) in self._memory_stats.items()}

        if only_keys is not None:
            stats = {k: v for k, v in stats.items() if k in only_keys}

        return pprint_stats_map(stats, self._CUTOFFS_AND_UNITS)
2 changes: 1 addition & 1 deletion src/mixins/timeable.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _duration_stats(self):
out = {}
for k in self._timings:
arr = np.array(self._times_for(k))
out[k] = (arr.mean(), len(arr), arr.std())
out[k] = (arr.mean(), len(arr), None if len(arr) <= 1 else arr.std())
return out

def _profile_durations(self, only_keys: set[str] | None = None):
Expand Down
93 changes: 93 additions & 0 deletions tests/test_memtrackable_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import numpy as np

from mixins import MemTrackableMixin


class MemTrackableDerived(MemTrackableMixin):
    """Test fixture that exercises the context-manager and decorator entry points of the mixin."""

    def __init__(self):
        self.foo = "foo"
        # Doesn't call super().__init__()! Should still work in this case.

    def uses_contextlib(self, mem_size_64b: int = 800000):
        """Allocate ``mem_size_64b`` float64 values inside a block tracked as "using_contextlib"."""
        with self._track_memory_as("using_contextlib"):
            # The array is deliberately not bound to a name; only the allocation matters.
            np.ones((mem_size_64b,), dtype=np.float64)

    @MemTrackableMixin.TrackMemoryAs(key="decorated")
    def decorated_takes_mem(self, mem_size_64b: int = 800000):
        """Allocate ``mem_size_64b`` float64 values; tracked under the explicit key "decorated"."""
        np.ones((mem_size_64b,), dtype=np.float64)

    @MemTrackableMixin.TrackMemoryAs
    def decorated_takes_mem_auto_key(self, mem_size_64b: int = 800000):
        """Allocate ``mem_size_64b`` float64 values; decorated without an explicit key."""
        np.ones((mem_size_64b,), dtype=np.float64)


def test_constructs():
    """Both the bare mixin and the derived fixture class must instantiate without arguments."""
    for cls in (MemTrackableMixin, MemTrackableDerived):
        cls()


def test_errors_if_not_initialized():
    """Looking up peak memory must fail clearly before any tracking and for unknown keys."""
    tracked = MemTrackableDerived()

    # Nothing has been tracked yet, so the stats store does not exist at all.
    try:
        tracked._peak_mem_for("foo")
    except AttributeError as err:
        assert "self._mem_stats should exist!" in str(err)
    else:
        raise AssertionError("Should have raised an exception!")

    tracked.uses_contextlib(mem_size_64b=8000000)  # 64 MB

    # The store exists now, but only under the key used by ``uses_contextlib``.
    try:
        tracked._peak_mem_for("wrong_key")
    except AttributeError as err:
        assert "wrong_key should exist in self._mem_stats!" in str(err)
    else:
        raise AssertionError("Should have raised an exception!")


def test_benchmark_timing(benchmark):
    """Run the decorated method under the ``benchmark`` fixture with an 8000-element allocation."""
    tracked = MemTrackableDerived()
    benchmark(tracked.decorated_takes_mem, 8000)


def test_context_manager():
    """The context manager must record the expected peak memory for each tracked run."""
    tracked = MemTrackableDerived()

    # (number of float64 elements, expected peak bytes): 64 MB first, then 0.64 MB.
    cases = (
        (8000000, 64 * 1000000),
        (80000, 64 * 10000),
    )
    for n_float64, want_bytes in cases:
        tracked.uses_contextlib(mem_size_64b=n_float64)

        # Only the most recent run under this key is checked.
        latest_peak = tracked._peak_mem_for("using_contextlib")[-1]
        np.testing.assert_almost_equal(latest_peak, want_bytes, decimal=1)


def test_decorators_and_profiling():
    """Decorated methods must record per-key stats and produce the expected profile strings."""
    tracked = MemTrackableDerived()
    almost_equal = np.testing.assert_almost_equal

    # Explicit key: 16000 float64 values.
    tracked.decorated_takes_mem(mem_size_64b=16000)
    almost_equal(tracked._peak_mem_for("decorated")[-1], 64 * 2000, decimal=1)

    # Automatic key (the method name): 40000 float64 values.
    tracked.decorated_takes_mem_auto_key(mem_size_64b=40000)
    almost_equal(tracked._peak_mem_for("decorated_takes_mem_auto_key")[-1], 64 * 5000, decimal=1)

    # A second, smaller run under "decorated" gives that key two samples.
    tracked.decorated_takes_mem(mem_size_64b=8000)
    stats = tracked._memory_stats

    assert {"decorated", "decorated_takes_mem_auto_key"} == set(stats.keys())

    decorated_mean, decorated_count, decorated_std = stats["decorated"]
    almost_equal(64 * 1500, decorated_mean, decimal=1)
    assert 2 == decorated_count
    almost_equal(0.5 * 64 * 1000, decorated_std, decimal=1)

    auto_mean, auto_count, auto_std = stats["decorated_takes_mem_auto_key"]
    almost_equal(64 * 5000, auto_mean, decimal=1)
    assert 1 == auto_count
    assert auto_std is None

    # (keyword arguments for the profiler, expected rendered profile)
    profile_cases = (
        ({}, "decorated: 96.0 ± 32.0 kB (x2)\ndecorated_takes_mem_auto_key: 320.0 kB"),
        ({"only_keys": ["decorated_takes_mem_auto_key"]}, "decorated_takes_mem_auto_key: 320.0 kB"),
    )
    for profile_kwargs, want_str in profile_cases:
        got_str = tracked._profile_memory_usages(**profile_kwargs)
        assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"
8 changes: 3 additions & 5 deletions tests/test_timeable_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,12 @@ def test_times_and_profiling():
np.testing.assert_almost_equal(0.5, stats["decorated"][2], decimal=1)
np.testing.assert_almost_equal(2, stats["decorated_takes_time_auto_key"][0], decimal=1)
assert 1 == stats["decorated_takes_time_auto_key"][1]
assert 0 == stats["decorated_takes_time_auto_key"][2]
assert stats["decorated_takes_time_auto_key"][2] is None

got_str = T._profile_durations()
want_str = (
"decorated_takes_time_auto_key: 2.0 ± 0.0 sec\ndecorated: 1.5 ± 0.5 sec (x2)"
)
want_str = "decorated_takes_time_auto_key: 2.0 sec\ndecorated: 1.5 ± 0.5 sec (x2)"
assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"

got_str = T._profile_durations(only_keys=["decorated_takes_time_auto_key"])
want_str = "decorated_takes_time_auto_key: 2.0 ± 0.0 sec"
want_str = "decorated_takes_time_auto_key: 2.0 sec"
assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"

0 comments on commit daae983

Please sign in to comment.