Skip to content

Commit

Permalink
Merge pull request #8 from mmcdermott/6_memtrackable
Browse files Browse the repository at this point in the history
Adds a memtrackable mixin for use in benchmarking.
  • Loading branch information
mmcdermott authored Nov 21, 2024
2 parents 70b8d46 + 145ed5e commit b02aef4
Show file tree
Hide file tree
Showing 8 changed files with 403 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:

- name: Install packages
run: |
pip install -e .[tests]
pip install -e .[tests,memtrackable]
#----------------------------------------------
# run test suite
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ classifiers = [
dependencies = ["numpy"]

[project.optional-dependencies]
memtrackable = ["memray"]
dev = ["pre-commit<4"]
tests = ["pytest", "pytest-cov", "pytest-benchmark"]

Expand Down
3 changes: 2 additions & 1 deletion src/mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .memtrackable import MemTrackableMixin
from .seedable import SeedableMixin
from .timeable import TimeableMixin

__all__ = ["SeedableMixin", "TimeableMixin"]
__all__ = ["MemTrackableMixin", "SeedableMixin", "TimeableMixin"]
154 changes: 154 additions & 0 deletions src/mixins/memtrackable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from __future__ import annotations

import functools
import json
import subprocess
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
from memray import Tracker

from .utils import doublewrap, pprint_stats_map


class MemTrackableMixin:
    """A mixin class to add memory tracking functionality to a class for profiling its methods.

    This mixin class provides the following functionality:
        - Tracking the memory use of methods using the TrackMemoryAs decorator.
        - Tracking the memory use of arbitrary code blocks using the _track_memory_as context manager.
        - Profiling of the memory used across the life cycle of the class.

    This class uses `memray` to track the memory usage of the methods.

    Attributes:
        _mem_stats: A dictionary of lists of memory statistics of tracked methods.
            The keys of the dictionary are the names of the tracked code blocks / methods.
            The values are lists holding one parsed `memray stats` JSON dictionary per tracked run of that
            code block / method.
    """

    # (factor, unit) pairs handed to `pprint_stats_map` when pretty-printing.
    # NOTE(review): presumably each factor is how many of that unit make up the next unit in the list
    # (8 bits per byte, 1000 bytes per kB, ...) -- confirm against `pprint_stats_map` in `.utils`.
    _CUTOFFS_AND_UNITS = [
        (8, "b"),
        (1000, "B"),
        (1000, "kB"),
        (1000, "MB"),
        (1000, "GB"),
        (1000, "TB"),
        (1000, "PB"),
    ]

    @staticmethod
    def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path) -> dict:
        """Extracts the memory stats from the tracker file and saves it to a json file and returns the stats.

        Args:
            memray_tracker_fp: The path to the memray tracker file.
            memray_stats_fp: The path to save the memory statistics.

        Returns:
            The stats extracted from the memray tracker file.

        Raises:
            FileNotFoundError: If the memray tracker file is not found.
            ValueError: If the stats file cannot be parsed.

        Examples:
            >>> import tempfile
            >>> with tempfile.TemporaryDirectory() as tmpdir:
            ...     memray_tracker_fp = Path(tmpdir) / ".memray"
            ...     memray_stats_fp = Path(tmpdir) / "memray_stats.json"
            ...     with Tracker(memray_tracker_fp, follow_fork=True):
            ...         A = np.ones((1000,), dtype=np.float64)
            ...     stats = MemTrackableMixin.get_memray_stats(memray_tracker_fp, memray_stats_fp)
            ...     print(stats['metadata']['peak_memory'])
            8000

            If you run it on a tracker file that is malformed, it won't work.

            >>> with tempfile.TemporaryDirectory() as tmpdir:
            ...     memray_tracker_fp = Path(tmpdir) / ".memray"
            ...     memray_stats_fp = Path(tmpdir) / "memray_stats.json"
            ...     memray_tracker_fp.touch()
            ...     MemTrackableMixin.get_memray_stats(memray_tracker_fp, memray_stats_fp)
            Traceback (most recent call last):
                ...
            ValueError: Failed to extract and parse memray stats file at ...

            If you run it on a non-existent file, it won't work.

            >>> MemTrackableMixin.get_memray_stats(Path("non_existent.mem"), Path("non_existent.stats"))
            Traceback (most recent call last):
                ...
            FileNotFoundError: Memray tracker file not found at non_existent.mem
        """
        if not memray_tracker_fp.is_file():
            raise FileNotFoundError(f"Memray tracker file not found at {memray_tracker_fp}")

        # Pass the command as an argument list (no shell) so that paths containing spaces or shell
        # metacharacters reach memray verbatim instead of being split or interpreted by a shell.
        memray_stats_cmd = [
            "memray",
            "stats",
            str(memray_tracker_fp),
            "--json",
            "-o",
            str(memray_stats_fp),
            "-f",
        ]
        try:
            subprocess.run(memray_stats_cmd, check=True, capture_output=True)
            return json.loads(memray_stats_fp.read_text(encoding="utf-8"))
        except (subprocess.SubprocessError, OSError, ValueError) as e:
            # Covers a failing / missing `memray` executable, an unreadable stats file, and invalid JSON
            # (`json.JSONDecodeError` and `UnicodeDecodeError` are both `ValueError` subclasses).
            raise ValueError(f"Failed to extract and parse memray stats file at {memray_stats_fp}") from e

    def __init__(self, *args, **kwargs):
        """Initializes the per-key memory statistics store.

        Args:
            *args: Ignored by this mixin.
            **kwargs: If a `_mem_stats` keyword is given, it is used as the statistics store; otherwise a new
                empty `defaultdict(list)` is created. All other keywords are ignored by this mixin.
        """
        self._mem_stats = kwargs.get("_mem_stats", defaultdict(list))

    def __assert_key_exists(self, key: str) -> None:
        """Raises an `AttributeError` unless `self._mem_stats` exists and contains `key`."""
        if not hasattr(self, "_mem_stats"):
            raise AttributeError("self._mem_stats should exist!")
        if key not in self._mem_stats:
            raise AttributeError(f"{key} should exist in self._mem_stats!")

    def _peak_mem_for(self, key: str) -> list[float]:
        """Returns the `peak_memory` metadata entry of every recorded run stored under `key`.

        Args:
            key: The name of the tracked code block / method.

        Returns:
            One peak memory value per recorded run, in recording order.

        Raises:
            AttributeError: If `self._mem_stats` does not exist or does not contain `key`.
        """
        self.__assert_key_exists(key)

        return [v["metadata"]["peak_memory"] for v in self._mem_stats[key]]

    @contextmanager
    def _track_memory_as(self, key: str):
        """Context manager that tracks the enclosed code block with memray and records its stats under `key`.

        The memray capture file and the derived stats JSON are written into a temporary directory that is
        removed when the context exits; only the parsed stats dictionary is kept, appended to
        `self._mem_stats[key]`.

        Args:
            key: The name under which the statistics of this run are stored.
        """
        # Lets the context manager work on instances whose `__init__` never ran.
        if not hasattr(self, "_mem_stats"):
            self._mem_stats = defaultdict(list)

        memory_stats = {}
        with TemporaryDirectory() as tmpdir:
            memray_fp = Path(tmpdir) / ".memray"
            memray_stats_fp = Path(tmpdir) / "memray_stats.json"

            try:
                with Tracker(memray_fp, follow_fork=True):
                    yield
            finally:
                # Runs whether or not the tracked block raised, so the stats of a failing run are
                # still recorded.
                memory_stats.update(MemTrackableMixin.get_memray_stats(memray_fp, memray_stats_fp))
                self._mem_stats[key].append(memory_stats)

    @staticmethod
    @doublewrap
    def TrackMemoryAs(fn, key: str | None = None):
        """Decorator that runs the wrapped method inside `self._track_memory_as(key)`.

        Args:
            fn: The method to wrap. It is called as `fn(self, *args, **kwargs)`.
            key: The name under which runs are recorded. Defaults to `fn.__name__`.

        Returns:
            The wrapped method, which returns whatever `fn` returns.
        """
        # NOTE(review): `doublewrap` presumably lets this be used both bare (`@TrackMemoryAs`) and with
        # arguments (`@TrackMemoryAs(key=...)`) -- confirm against `.utils`.
        if key is None:
            key = fn.__name__

        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            with self._track_memory_as(key):
                return fn(self, *args, **kwargs)

        return wrapper

    @property
    def _memory_stats(self):
        """Summarizes the recorded peak memory for every tracked key.

        Returns:
            A dictionary mapping each key to a tuple `(mean, count, std)` computed over the peak memory of
            its recorded runs. `std` is `None` when there is at most one recorded run.
        """
        out = {}
        for k in self._mem_stats:
            arr = np.array(self._peak_mem_for(k))
            out[k] = (arr.mean(), len(arr), None if len(arr) <= 1 else arr.std())
        return out

    def _profile_memory_usages(self, only_keys: set[str] | None = None):
        """Pretty-prints the peak memory summary via `pprint_stats_map`.

        Args:
            only_keys: If given, only keys contained in this set are included in the output.

        Returns:
            The result of `pprint_stats_map` applied to the (optionally filtered) statistics, with every
            mean tagged with the unit "B", and `self._CUTOFFS_AND_UNITS`.
        """
        stats = {k: ((v, "B"), n, s) for k, (v, n, s) in self._memory_stats.items()}

        if only_keys is not None:
            stats = {k: v for k, v in stats.items() if k in only_keys}

        return pprint_stats_map(stats, self._CUTOFFS_AND_UNITS)
49 changes: 4 additions & 45 deletions src/mixins/timeable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from .utils import doublewrap
from .utils import doublewrap, pprint_stats_map


class TimeableMixin:
Expand Down Expand Up @@ -38,41 +38,6 @@ class TimeableMixin:
(None, "weeks"),
]

@classmethod
def _get_pprint_num_unit(cls, x: float, x_unit: str = "sec") -> tuple[float, str]:
x_unit_factor = 1
for fac, unit in cls._CUTOFFS_AND_UNITS:
if unit == x_unit:
break
if fac is None:
raise LookupError(
f"Passed unit {x_unit} invalid! "
f"Must be one of {', '.join(u for f, u in cls._CUTOFFS_AND_UNITS)}."
)
x_unit_factor *= fac

min_unit = x * x_unit_factor
upper_bound = 1
for upper_bound_factor, unit in cls._CUTOFFS_AND_UNITS:
if (upper_bound_factor is None) or (min_unit < upper_bound * upper_bound_factor):
return min_unit / upper_bound, unit
upper_bound *= upper_bound_factor

@classmethod
def _pprint_duration(cls, mean_sec: float, n_times: int = 1, std_seconds: float | None = None) -> str:
mean_time, mean_unit = cls._get_pprint_num_unit(mean_sec)

if std_seconds:
std_time = std_seconds * mean_time / mean_sec
mean_std_str = f"{mean_time:.1f} ± {std_time:.1f} {mean_unit}"
else:
mean_std_str = f"{mean_time:.1f} {mean_unit}"

if n_times > 1:
return f"{mean_std_str} (x{n_times})"
else:
return mean_std_str

def __init__(self, *args, **kwargs):
self._timings = kwargs.get("_timings", defaultdict(list))

Expand Down Expand Up @@ -132,19 +97,13 @@ def _duration_stats(self):
out = {}
for k in self._timings:
arr = np.array(self._times_for(k))
out[k] = (arr.mean(), len(arr), arr.std())
out[k] = (arr.mean(), len(arr), None if len(arr) <= 1 else arr.std())
return out

def _profile_durations(self, only_keys: set[str] | None = None):
stats = self._duration_stats
stats = {k: ((v, "sec"), n, s) for k, (v, n, s) in self._duration_stats.items()}

if only_keys is not None:
stats = {k: v for k, v in stats.items() if k in only_keys}

longest_key_length = max(len(k) for k in stats)
ordered_keys = sorted(stats.keys(), key=lambda k: stats[k][0] * stats[k][1])
tfk_str = "\n".join(
(f"{k}:{' '*(longest_key_length - len(k))} " f"{self._pprint_duration(*stats[k])}")
for k in ordered_keys
)
return tfk_str
return pprint_stats_map(stats, self._CUTOFFS_AND_UNITS)
Loading

0 comments on commit b02aef4

Please sign in to comment.