-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
24b126b
commit daae983
Showing
6 changed files
with
211 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .memtrackable import MemTrackableMixin | ||
from .seedable import SeedableMixin | ||
from .timeable import TimeableMixin | ||
|
||
__all__ = ["SeedableMixin", "TimeableMixin"] | ||
__all__ = ["MemTrackableMixin", "SeedableMixin", "TimeableMixin"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from __future__ import annotations | ||
|
||
import functools | ||
import json | ||
import subprocess | ||
from collections import defaultdict | ||
from contextlib import contextmanager | ||
from pathlib import Path | ||
from tempfile import TemporaryDirectory | ||
|
||
import numpy as np | ||
from memray import Tracker | ||
|
||
from .utils import doublewrap, pprint_stats_map | ||
|
||
|
||
class MemTrackableMixin:
    """A mixin class to add memory tracking functionality to a class for profiling its methods.

    This mixin class provides the following functionality:
        - Tracking the memory use of methods using the TrackMemoryAs decorator.
        - Tracking the memory use of arbitrary code blocks using the _track_memory_as context manager.
        - Profiling of the memory used across the life cycle of the class.

    This class uses `memray` to track the memory usage of the methods.

    Attributes:
        _mem_stats: A dictionary of lists of memray stats reports for tracked methods.
            The keys of the dictionary are the names of the tracked code blocks / methods.
            The values are lists holding one parsed ``memray stats --json`` report per tracked run of
            that code block / method; the peak memory is read from ``report["metadata"]["peak_memory"]``.
    """

    # (cutoff, unit) pairs handed to ``pprint_stats_map`` when pretty-printing.
    # NOTE(review): presumably each cutoff is the number of the previous unit that makes up this
    # unit (8 bits per byte, 1000 bytes per kB, ...) -- confirm against ``pprint_stats_map``.
    _CUTOFFS_AND_UNITS = [
        (8, "b"),
        (1000, "B"),
        (1000, "kB"),
        (1000, "MB"),
        (1000, "GB"),
        (1000, "TB"),
        (1000, "PB"),
    ]

    @staticmethod
    def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path) -> dict:
        """Run ``memray stats`` on a capture file and return the parsed JSON report.

        Args:
            memray_tracker_fp: Path of the memray capture file to summarize.
            memray_stats_fp: Path the JSON report is written to (overwritten via ``-f``).

        Returns:
            The JSON report decoded into a dictionary.

        Raises:
            subprocess.CalledProcessError: If the ``memray stats`` command exits with a non-zero status.
            ValueError: If the report file cannot be read or decoded as JSON.
        """
        # An argument list (no shell) keeps paths containing spaces or shell metacharacters intact.
        memray_stats_cmd = [
            "memray",
            "stats",
            str(memray_tracker_fp),
            "--json",
            "-o",
            str(memray_stats_fp),
            "-f",
        ]
        subprocess.run(memray_stats_cmd, check=True, capture_output=True)
        try:
            return json.loads(memray_stats_fp.read_text())
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
            raise ValueError(f"Failed to parse memray stats file at {memray_stats_fp}") from e

    def __init__(self, *args, **kwargs):
        """Initialize the stats store, optionally seeded from a ``_mem_stats`` keyword argument.

        All other positional and keyword arguments are accepted and ignored.
        """
        self._mem_stats = kwargs.get("_mem_stats", defaultdict(list))

    def __assert_key_exists(self, key: str) -> None:
        """Raise ``AttributeError`` unless ``self._mem_stats`` exists and contains ``key``."""
        if not hasattr(self, "_mem_stats"):
            raise AttributeError("self._mem_stats should exist!")
        if key not in self._mem_stats:
            raise AttributeError(f"{key} should exist in self._mem_stats!")

    def _peak_mem_for(self, key: str) -> list[float]:
        """Return the recorded peak memory of every tracked run stored under ``key``.

        Raises:
            AttributeError: If nothing has been tracked yet, or ``key`` was never tracked.
        """
        self.__assert_key_exists(key)

        return [v["metadata"]["peak_memory"] for v in self._mem_stats[key]]

    @contextmanager
    def _track_memory_as(self, key: str):
        """Context manager that tracks the enclosed block with memray and stores its stats under ``key``.

        The capture file and the JSON report live in a temporary directory that is removed on exit;
        only the parsed report is kept, appended to ``self._mem_stats[key]``.
        """
        # Subclasses may not have called ``__init__``; create the store lazily in that case.
        if not hasattr(self, "_mem_stats"):
            self._mem_stats = defaultdict(list)

        memory_stats = {}
        with TemporaryDirectory() as tmpdir:
            memray_fp = Path(tmpdir) / ".memray"
            memray_stats_fp = Path(tmpdir) / "memray_stats.json"

            try:
                with Tracker(memray_fp, follow_fork=True):
                    yield
            finally:
                # Runs even when the tracked block raises, so the run is still recorded.
                memory_stats.update(MemTrackableMixin.get_memray_stats(memray_fp, memray_stats_fp))
                self._mem_stats[key].append(memory_stats)

    @staticmethod
    @doublewrap
    def TrackMemoryAs(fn, key: str | None = None):
        """Decorator that tracks the memory use of a method under ``key``.

        Args:
            fn: The method to wrap; it is invoked as ``fn(self, *args, **kwargs)``.
            key: Name the stats are stored under. Defaults to ``fn.__name__``.

        Returns:
            The wrapped method, which returns whatever ``fn`` returns.
        """
        if key is None:
            key = fn.__name__

        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            with self._track_memory_as(key):
                out = fn(self, *args, **kwargs)
            return out

        return wrapper

    @property
    def _memory_stats(self):
        """Summarize peak memory per tracked key.

        Returns:
            A dictionary mapping each key to ``(mean, count, std)`` of its recorded peak memory
            values, where ``std`` is ``None`` when the key has at most one recorded run.
        """
        out = {}
        for k in self._mem_stats:
            arr = np.array(self._peak_mem_for(k))
            out[k] = (arr.mean(), len(arr), None if len(arr) <= 1 else arr.std())
        return out

    def _profile_memory_usages(self, only_keys: set[str] | None = None):
        """Return the pretty-printed memory summary produced by ``pprint_stats_map``.

        Args:
            only_keys: If given, only keys contained in this collection are reported.

        Returns:
            The result of ``pprint_stats_map`` applied to the per-key statistics, each mean tagged
            with the unit ``"B"``.
        """
        stats = {k: ((v, "B"), n, s) for k, (v, n, s) in self._memory_stats.items()}

        if only_keys is not None:
            stats = {k: v for k, v in stats.items() if k in only_keys}

        return pprint_stats_map(stats, self._CUTOFFS_AND_UNITS)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import numpy as np | ||
|
||
from mixins import MemTrackableMixin | ||
|
||
|
||
class MemTrackableDerived(MemTrackableMixin):
    """Small subclass used by the tests to exercise the context manager and both decorator forms."""

    def __init__(self):
        # Intentionally skips super().__init__(): the mixin should still work in this case.
        self.foo = "foo"

    def uses_contextlib(self, mem_size_64b: int = 800000):
        """Allocate ``mem_size_64b`` float64 values inside the ``using_contextlib`` tracker."""
        with self._track_memory_as("using_contextlib"):
            np.ones(shape=(mem_size_64b,), dtype=np.float64)

    @MemTrackableMixin.TrackMemoryAs(key="decorated")
    def decorated_takes_mem(self, mem_size_64b: int = 800000):
        """Allocate ``mem_size_64b`` float64 values; tracked under the explicit key ``decorated``."""
        np.ones(shape=(mem_size_64b,), dtype=np.float64)

    @MemTrackableMixin.TrackMemoryAs
    def decorated_takes_mem_auto_key(self, mem_size_64b: int = 800000):
        """Allocate ``mem_size_64b`` float64 values; decorated without an explicit key."""
        np.ones(shape=(mem_size_64b,), dtype=np.float64)
|
||
|
||
def test_constructs():
    """Both the bare mixin and the derived class can be instantiated without arguments."""
    for cls in (MemTrackableMixin, MemTrackableDerived):
        cls()
|
||
|
||
def test_errors_if_not_initialized():
    """Peak-memory lookups fail with AttributeError before any tracking and for unknown keys."""
    derived = MemTrackableDerived()

    # Nothing has been tracked yet, so the stats store does not exist on this instance.
    try:
        derived._peak_mem_for("foo")
    except AttributeError as err:
        assert "self._mem_stats should exist!" in str(err)
    else:
        raise AssertionError("Should have raised an exception!")

    derived.uses_contextlib(mem_size_64b=8000000)  # 64 MB

    # The store exists now, but this key was never tracked.
    try:
        derived._peak_mem_for("wrong_key")
    except AttributeError as err:
        assert "wrong_key should exist in self._mem_stats!" in str(err)
    else:
        raise AssertionError("Should have raised an exception!")
|
||
|
||
def test_benchmark_timing(benchmark):
    """Run the decorated method through the ``benchmark`` fixture with 8000 elements.

    NOTE(review): ``benchmark`` is presumably the pytest-benchmark fixture -- confirm.
    """
    tracked = MemTrackableDerived()
    target = tracked.decorated_takes_mem

    benchmark(target, 8000)
|
||
|
||
def test_context_manager():
    """The context manager records the expected peak memory for each tracked allocation."""
    tracked = MemTrackableDerived()

    # (number of float64 elements, expected peak bytes), run in order on the same instance.
    cases = (
        (8000000, 64 * 1000000),  # 64 MB
        (80000, 64 * 10000),  # 0.64 MB
    )

    for n_elements, want_bytes in cases:
        tracked.uses_contextlib(mem_size_64b=n_elements)

        # The most recent run is always the last entry stored under the key.
        latest_peak = tracked._peak_mem_for("using_contextlib")[-1]
        np.testing.assert_almost_equal(latest_peak, want_bytes, decimal=1)
|
||
|
||
def test_decorators_and_profiling():
    """Decorated methods record peak memory per key, and the summary/pretty-printer agree."""
    tracked = MemTrackableDerived()

    # Explicit key supplied to the decorator.
    tracked.decorated_takes_mem(mem_size_64b=16000)
    latest_peak = tracked._peak_mem_for("decorated")[-1]
    np.testing.assert_almost_equal(latest_peak, 64 * 2000, decimal=1)

    # Key defaults to the method name.
    tracked.decorated_takes_mem_auto_key(mem_size_64b=40000)
    latest_peak = tracked._peak_mem_for("decorated_takes_mem_auto_key")[-1]
    np.testing.assert_almost_equal(latest_peak, 64 * 5000, decimal=1)

    # A second run under "decorated" so that a standard deviation is reported for it.
    tracked.decorated_takes_mem(mem_size_64b=8000)
    summary = tracked._memory_stats

    assert {"decorated", "decorated_takes_mem_auto_key"} == set(summary.keys())

    explicit_mean, explicit_count, explicit_std = summary["decorated"]
    np.testing.assert_almost_equal(64 * 1500, explicit_mean, decimal=1)
    assert 2 == explicit_count
    np.testing.assert_almost_equal(0.5 * 64 * 1000, explicit_std, decimal=1)

    auto_mean, auto_count, auto_std = summary["decorated_takes_mem_auto_key"]
    np.testing.assert_almost_equal(64 * 5000, auto_mean, decimal=1)
    assert 1 == auto_count
    assert auto_std is None

    got_str = tracked._profile_memory_usages()
    want_str = "decorated: 96.0 ± 32.0 kB (x2)\ndecorated_takes_mem_auto_key: 320.0 kB"
    assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"

    # Restricting to one key drops the other from the report.
    got_str = tracked._profile_memory_usages(only_keys=["decorated_takes_mem_auto_key"])
    want_str = "decorated_takes_mem_auto_key: 320.0 kB"
    assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters