Adds a memtrackable mixin for use in benchmarking. #8

Merged (5 commits) on Nov 21, 2024
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -28,7 +28,7 @@ jobs:

- name: Install packages
run: |
pip install -e .[tests]
pip install -e .[tests,memtrackable]

#----------------------------------------------
# run test suite
1 change: 1 addition & 0 deletions pyproject.toml
@@ -19,6 +19,7 @@ classifiers = [
dependencies = ["numpy"]

[project.optional-dependencies]
memtrackable = ["memray"]
dev = ["pre-commit<4"]
tests = ["pytest", "pytest-cov", "pytest-benchmark"]

3 changes: 2 additions & 1 deletion src/mixins/__init__.py
@@ -1,4 +1,5 @@
from .memtrackable import MemTrackableMixin
from .seedable import SeedableMixin
from .timeable import TimeableMixin

__all__ = ["SeedableMixin", "TimeableMixin"]
__all__ = ["MemTrackableMixin", "SeedableMixin", "TimeableMixin"]
154 changes: 154 additions & 0 deletions src/mixins/memtrackable.py
@@ -0,0 +1,154 @@
from __future__ import annotations

import functools
import json
import subprocess
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
from memray import Tracker

from .utils import doublewrap, pprint_stats_map


class MemTrackableMixin:
"""A mixin class to add memory tracking functionality to a class for profiling its methods.

This mixin class provides the following functionality:
- Tracking the memory use of methods using the TrackMemoryAs decorator.
    - Tracking the memory use of arbitrary code blocks using the _track_memory_as context manager.
- Profiling of the memory used across the life cycle of the class.

This class uses `memray` to track the memory usage of the methods.

Attributes:
        _mem_stats: A dictionary of lists of memory statistics for tracked code blocks / methods.
            The keys of the dictionary are the names of the tracked code blocks / methods.
            The values are lists of the memray statistics recorded for each run of a tracked code block / method.
"""

_CUTOFFS_AND_UNITS = [
(8, "b"),
(1000, "B"),
(1000, "kB"),
(1000, "MB"),
(1000, "GB"),
(1000, "TB"),
(1000, "PB"),
]

@staticmethod
def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path) -> dict:
"""Extracts the memory stats from the tracker file and saves it to a json file and returns the stats.

Args:
memray_tracker_fp: The path to the memray tracker file.
memray_stats_fp: The path to save the memory statistics.

Returns:
The stats extracted from the memray tracker file.

Raises:
FileNotFoundError: If the memray tracker file is not found.
ValueError: If the stats file cannot be parsed.

Examples:
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as tmpdir:
... memray_tracker_fp = Path(tmpdir) / ".memray"
... memray_stats_fp = Path(tmpdir) / "memray_stats.json"
... with Tracker(memray_tracker_fp, follow_fork=True):
... A = np.ones((1000,), dtype=np.float64)
... stats = MemTrackableMixin.get_memray_stats(memray_tracker_fp, memray_stats_fp)
... print(stats['metadata']['peak_memory'])
8000

            Calling it on a malformed tracker file raises a ValueError:
>>> with tempfile.TemporaryDirectory() as tmpdir:
... memray_tracker_fp = Path(tmpdir) / ".memray"
... memray_stats_fp = Path(tmpdir) / "memray_stats.json"
... memray_tracker_fp.touch()
... MemTrackableMixin.get_memray_stats(memray_tracker_fp, memray_stats_fp)
Traceback (most recent call last):
...
ValueError: Failed to extract and parse memray stats file at ...

            Calling it on a tracker file that does not exist raises a FileNotFoundError:
>>> MemTrackableMixin.get_memray_stats(Path("non_existent.mem"), Path("non_existent.stats"))
Traceback (most recent call last):
...
FileNotFoundError: Memray tracker file not found at non_existent.mem
"""
if not memray_tracker_fp.is_file():
raise FileNotFoundError(f"Memray tracker file not found at {memray_tracker_fp}")

memray_stats_cmd = f"memray stats {memray_tracker_fp} --json -o {memray_stats_fp} -f"
try:
subprocess.run(memray_stats_cmd, shell=True, check=True, capture_output=True)
return json.loads(memray_stats_fp.read_text())
except Exception as e:
raise ValueError(f"Failed to extract and parse memray stats file at {memray_stats_fp}") from e

def __init__(self, *args, **kwargs):
self._mem_stats = kwargs.get("_mem_stats", defaultdict(list))

def __assert_key_exists(self, key: str) -> None:
if not hasattr(self, "_mem_stats"):
raise AttributeError("self._mem_stats should exist!")
if key not in self._mem_stats:
raise AttributeError(f"{key} should exist in self._mem_stats!")

def _peak_mem_for(self, key: str) -> list[float]:
self.__assert_key_exists(key)

return [v["metadata"]["peak_memory"] for v in self._mem_stats[key]]

@contextmanager
def _track_memory_as(self, key: str):
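        """Tracks the memory used within the enclosed code block, recording the memray stats under the given key."""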
if not hasattr(self, "_mem_stats"):
self._mem_stats = defaultdict(list)

memory_stats = {}
with TemporaryDirectory() as tmpdir:
memray_fp = Path(tmpdir) / ".memray"
memray_stats_fp = Path(tmpdir) / "memray_stats.json"

try:
with Tracker(memray_fp, follow_fork=True):
yield
finally:
memory_stats.update(MemTrackableMixin.get_memray_stats(memray_fp, memray_stats_fp))
self._mem_stats[key].append(memory_stats)

@staticmethod
@doublewrap
def TrackMemoryAs(fn, key: str | None = None):
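        """Decorator tracking the memory used by the decorated method, under a key defaulting to the method's name."""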
if key is None:
key = fn.__name__

@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
with self._track_memory_as(key):
out = fn(self, *args, **kwargs)
return out

return wrapper

@property
def _memory_stats(self):
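        """Maps each tracked key to (mean peak memory in bytes, number of runs, std or None)."""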
out = {}
for k in self._mem_stats:
arr = np.array(self._peak_mem_for(k))
out[k] = (arr.mean(), len(arr), None if len(arr) <= 1 else arr.std())
return out

def _profile_memory_usages(self, only_keys: set[str] | None = None):
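        """Returns a pretty-printed summary of the peak memory used under each tracked key."""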
stats = {k: ((v, "B"), n, s) for k, (v, n, s) in self._memory_stats.items()}

if only_keys is not None:
stats = {k: v for k, v in stats.items() if k in only_keys}

return pprint_stats_map(stats, self._CUTOFFS_AND_UNITS)
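
For orientation, here is a minimal usage sketch of the new mixin. It is not part of this PR: the Benchmark class and its methods are invented for illustration, the mixin API (TrackMemoryAs, _track_memory_as, _profile_memory_usages) is taken from the file above, and running it requires the memray optional dependency.

import numpy as np

from mixins import MemTrackableMixin


class Benchmark(MemTrackableMixin):
    def __init__(self):
        super().__init__()  # sets up self._mem_stats

    @MemTrackableMixin.TrackMemoryAs
    def allocate(self, n: int):
        # Peak memory of each call is recorded under the key "allocate" (the method name).
        return np.ones((n,), dtype=np.float64)

    def allocate_in_block(self, n: int):
        # Arbitrary code blocks can be tracked with the context manager instead.
        with self._track_memory_as("allocate_in_block"):
            return np.ones((n,), dtype=np.float64)


b = Benchmark()
b.allocate(1_000_000)
b.allocate_in_block(1_000_000)
print(b._profile_memory_usages())  # pretty-printed mean peak memory per tracked key

Each tracked run appends a memray stats dictionary to b._mem_stats under its key; _profile_memory_usages() then reports the mean peak memory per key, with a standard deviation when a key has more than one run.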
49 changes: 4 additions & 45 deletions src/mixins/timeable.py
@@ -7,7 +7,7 @@

import numpy as np

from .utils import doublewrap
from .utils import doublewrap, pprint_stats_map


class TimeableMixin:
@@ -38,41 +38,6 @@ class TimeableMixin:
(None, "weeks"),
]

@classmethod
def _get_pprint_num_unit(cls, x: float, x_unit: str = "sec") -> tuple[float, str]:
x_unit_factor = 1
for fac, unit in cls._CUTOFFS_AND_UNITS:
if unit == x_unit:
break
if fac is None:
raise LookupError(
f"Passed unit {x_unit} invalid! "
f"Must be one of {', '.join(u for f, u in cls._CUTOFFS_AND_UNITS)}."
)
x_unit_factor *= fac

min_unit = x * x_unit_factor
upper_bound = 1
for upper_bound_factor, unit in cls._CUTOFFS_AND_UNITS:
if (upper_bound_factor is None) or (min_unit < upper_bound * upper_bound_factor):
return min_unit / upper_bound, unit
upper_bound *= upper_bound_factor

@classmethod
def _pprint_duration(cls, mean_sec: float, n_times: int = 1, std_seconds: float | None = None) -> str:
mean_time, mean_unit = cls._get_pprint_num_unit(mean_sec)

if std_seconds:
std_time = std_seconds * mean_time / mean_sec
mean_std_str = f"{mean_time:.1f} ± {std_time:.1f} {mean_unit}"
else:
mean_std_str = f"{mean_time:.1f} {mean_unit}"

if n_times > 1:
return f"{mean_std_str} (x{n_times})"
else:
return mean_std_str

def __init__(self, *args, **kwargs):
self._timings = kwargs.get("_timings", defaultdict(list))

@@ -132,19 +97,13 @@ def _duration_stats(self):
out = {}
for k in self._timings:
arr = np.array(self._times_for(k))
out[k] = (arr.mean(), len(arr), arr.std())
out[k] = (arr.mean(), len(arr), None if len(arr) <= 1 else arr.std())
return out

def _profile_durations(self, only_keys: set[str] | None = None):
stats = self._duration_stats
stats = {k: ((v, "sec"), n, s) for k, (v, n, s) in self._duration_stats.items()}

if only_keys is not None:
stats = {k: v for k, v in stats.items() if k in only_keys}

longest_key_length = max(len(k) for k in stats)
ordered_keys = sorted(stats.keys(), key=lambda k: stats[k][0] * stats[k][1])
tfk_str = "\n".join(
(f"{k}:{' '*(longest_key_length - len(k))} " f"{self._pprint_duration(*stats[k])}")
for k in ordered_keys
)
return tfk_str
return pprint_stats_map(stats, self._CUTOFFS_AND_UNITS)
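
Both mixins now delegate their pretty-printing to pprint_stats_map in .utils, which is not shown in this diff. As a rough sketch only, reconstructed from the formatting logic deleted from timeable.py above (the actual helper in src/mixins/utils.py may differ), it could look like:

from __future__ import annotations


def pprint_stats_map(stats: dict, cutoffs_and_units: list) -> str:
    """Pretty-prints a map of key -> ((mean, unit), n, std) using a table of unit cutoffs.

    cutoffs_and_units is an ordered list of (factor, unit) pairs, smallest unit first,
    where factor is how many of this unit make up the next one (None, if present,
    marks the largest unit).
    """

    def best_unit(x: float, x_unit: str) -> tuple[float, str]:
        # Convert x (expressed in x_unit) into the most readable unit in the table.
        x_unit_factor = 1
        for fac, unit in cutoffs_and_units:
            if unit == x_unit:
                break
            if fac is None:
                raise LookupError(f"Passed unit {x_unit} invalid!")
            x_unit_factor *= fac

        in_smallest = x * x_unit_factor
        upper_bound = 1
        for fac, unit in cutoffs_and_units:
            if fac is None or in_smallest < upper_bound * fac:
                return in_smallest / upper_bound, unit
            upper_bound *= fac

    def pprint_one(mean_and_unit: tuple[float, str], n: int, std: float | None) -> str:
        mean, unit = mean_and_unit
        pretty_mean, pretty_unit = best_unit(mean, unit)
        if std:
            pretty_std = std * pretty_mean / mean
            out = f"{pretty_mean:.1f} ± {pretty_std:.1f} {pretty_unit}"
        else:
            out = f"{pretty_mean:.1f} {pretty_unit}"
        return f"{out} (x{n})" if n > 1 else out

    longest_key = max(len(k) for k in stats)
    # Sort ascending by total (mean * count) so the most expensive keys print last.
    ordered_keys = sorted(stats, key=lambda k: stats[k][0][0] * stats[k][1])
    return "\n".join(
        f"{k}:{' ' * (longest_key - len(k))} {pprint_one(*stats[k])}" for k in ordered_keys
    )

Centralizing the formatter means each mixin only supplies its own _CUTOFFS_AND_UNITS table and the unit its raw statistics are measured in ("sec" for durations, "B" for peak memory in bytes).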