From 71d2afa94bd5d5124a22e5eca3aa176faedb94e5 Mon Sep 17 00:00:00 2001
From: Matthew McDermott
Date: Tue, 19 Nov 2024 13:05:36 -0500
Subject: [PATCH 1/5] Moved unit normalization to a utils function and added doctests for it in prep for usage across both timeable and memtrackable

---
 src/mixins/timeable.py       | 24 ++----------------
 src/mixins/utils.py          | 49 ++++++++++++++++++++++++++++++++++++
 tests/test_timeable_mixin.py | 19 --------------
 3 files changed, 51 insertions(+), 41 deletions(-)

diff --git a/src/mixins/timeable.py b/src/mixins/timeable.py
index f8e8e42..9f7c362 100644
--- a/src/mixins/timeable.py
+++ b/src/mixins/timeable.py
@@ -7,7 +7,7 @@
 
 import numpy as np
 
-from .utils import doublewrap
+from .utils import doublewrap, normalize_unit
 
 
 class TimeableMixin:
@@ -38,29 +38,9 @@ class TimeableMixin:
         (None, "weeks"),
     ]
 
-    @classmethod
-    def _get_pprint_num_unit(cls, x: float, x_unit: str = "sec") -> tuple[float, str]:
-        x_unit_factor = 1
-        for fac, unit in cls._CUTOFFS_AND_UNITS:
-            if unit == x_unit:
-                break
-            if fac is None:
-                raise LookupError(
-                    f"Passed unit {x_unit} invalid! "
-                    f"Must be one of {', '.join(u for f, u in cls._CUTOFFS_AND_UNITS)}."
-                )
-            x_unit_factor *= fac
-
-        min_unit = x * x_unit_factor
-        upper_bound = 1
-        for upper_bound_factor, unit in cls._CUTOFFS_AND_UNITS:
-            if (upper_bound_factor is None) or (min_unit < upper_bound * upper_bound_factor):
-                return min_unit / upper_bound, unit
-            upper_bound *= upper_bound_factor
-
     @classmethod
     def _pprint_duration(cls, mean_sec: float, n_times: int = 1, std_seconds: float | None = None) -> str:
-        mean_time, mean_unit = cls._get_pprint_num_unit(mean_sec)
+        mean_time, mean_unit = normalize_unit((mean_sec, "sec"), cls._CUTOFFS_AND_UNITS)
 
         if std_seconds:
             std_time = std_seconds * mean_time / mean_sec
diff --git a/src/mixins/utils.py b/src/mixins/utils.py
index 9e74e5f..b6efa67 100644
--- a/src/mixins/utils.py
+++ b/src/mixins/utils.py
@@ -19,3 +19,52 @@ def new_dec(*args, **kwargs):
             return lambda realf: f(realf, *args, **kwargs)
 
     return new_dec
+
+
+QUANT_T = tuple[float | None, str]
+UNITS_LIST_T = list[QUANT_T]
+
+
+def normalize_unit(val: QUANT_T, cutoffs_and_units: UNITS_LIST_T) -> QUANT_T:
+    """Converts a quantity to the largest possible unit and returns the quantity and unit.
+
+    Args:
+        val: A tuple of a number and a unit to be normalized.
+        cutoffs_and_units: A list of tuples of valid cutoffs and units.
+
+    Returns:
+        A tuple of the normalized value and unit.
+
+    Raises:
+        LookupError: If the unit is not in the list of valid units.
+
+    Example:
+        >>> normalize_unit((1000, "ms"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")])
+        (1.0, 's')
+        >>> normalize_unit((720000, "ms"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")])
+        (12.0, 'min')
+        >>> normalize_unit((3600, "s"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")])
+        (1.0, 'h')
+        >>> normalize_unit((5000, "ns"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")])
+        Traceback (most recent call last):
+            ...
+        LookupError: Passed unit ns invalid! Must be one of ms, s, min, h.
+    """
+    x, x_unit = val
+    x_unit_factor = 1
+    for fac, unit in cutoffs_and_units:
+        if unit == x_unit:
+            break
+        if fac is None:
+            raise LookupError(
+                f"Passed unit {x_unit} invalid! "
+                f"Must be one of {', '.join(u for f, u in cutoffs_and_units)}."
+            )
+        x_unit_factor *= fac
+
+    min_unit = x * x_unit_factor
+    upper_bound = 1
+    for upper_bound_factor, unit in cutoffs_and_units:
+        if (upper_bound_factor is None) or (min_unit < upper_bound * upper_bound_factor):
+            return min_unit / upper_bound, unit
+        upper_bound *= upper_bound_factor
diff --git a/tests/test_timeable_mixin.py b/tests/test_timeable_mixin.py
index 286459e..117ec31 100644
--- a/tests/test_timeable_mixin.py
+++ b/tests/test_timeable_mixin.py
@@ -47,25 +47,6 @@ def test_benchmark_timing(benchmark):
     benchmark(T.decorated_takes_time, 0.00001)
 
 
-def test_pprint_num_unit():
-    assert (5, "μs") == TimeableMixin._get_pprint_num_unit(5 * 1e-6)
-
-    class Derived(TimeableMixin):
-        _CUTOFFS_AND_UNITS = [(10, "foo"), (2, "bar"), (None, "biz")]
-
-    assert (3, "biz") == Derived._get_pprint_num_unit(3, "biz")
-    assert (3, "foo") == Derived._get_pprint_num_unit(3 / 20, "biz")
-    assert (1.2, "biz") == Derived._get_pprint_num_unit(2.4 * 10, "foo")
-
-    try:
-        Derived._get_pprint_num_unit(1, "WRONG")
-        raise AssertionError("Should have raised an exception")
-    except LookupError:
-        pass
-    except Exception as e:
-        raise AssertionError(f"Raised the wrong exception: {e}")
-
-
 def test_context_manager():
     T = TimeableDerived()
 

From 24b126b0413ce7708817c0780a78c7b5b66cc72e Mon Sep 17 00:00:00 2001
From: Matthew McDermott
Date: Tue, 19 Nov 2024 15:28:04 -0500
Subject: [PATCH 2/5] Moved pretty print code out as well.

---
 src/mixins/timeable.py       |  27 ++--------
 src/mixins/utils.py          | 100 ++++++++++++++++++++++++++++++++++-
 tests/test_timeable_mixin.py |   6 ++-
 3 files changed, 106 insertions(+), 27 deletions(-)

diff --git a/src/mixins/timeable.py b/src/mixins/timeable.py
index 9f7c362..5457b64 100644
--- a/src/mixins/timeable.py
+++ b/src/mixins/timeable.py
@@ -7,7 +7,7 @@
 
 import numpy as np
 
-from .utils import doublewrap, normalize_unit
+from .utils import doublewrap, pprint_stats_map
 
 
 class TimeableMixin:
@@ -38,21 +38,6 @@ class TimeableMixin:
         (None, "weeks"),
     ]
 
-    @classmethod
-    def _pprint_duration(cls, mean_sec: float, n_times: int = 1, std_seconds: float | None = None) -> str:
-        mean_time, mean_unit = normalize_unit((mean_sec, "sec"), cls._CUTOFFS_AND_UNITS)
-
-        if std_seconds:
-            std_time = std_seconds * mean_time / mean_sec
-            mean_std_str = f"{mean_time:.1f} ± {std_time:.1f} {mean_unit}"
-        else:
-            mean_std_str = f"{mean_time:.1f} {mean_unit}"
-
-        if n_times > 1:
-            return f"{mean_std_str} (x{n_times})"
-        else:
-            return mean_std_str
-
     def __init__(self, *args, **kwargs):
         self._timings = kwargs.get("_timings", defaultdict(list))
 
@@ -116,15 +101,9 @@ def _duration_stats(self):
         return out
 
     def _profile_durations(self, only_keys: set[str] | None = None):
-        stats = self._duration_stats
+        stats = {k: ((v, "sec"), n, s) for k, (v, n, s) in self._duration_stats.items()}
 
         if only_keys is not None:
             stats = {k: v for k, v in stats.items() if k in only_keys}
 
-        longest_key_length = max(len(k) for k in stats)
-        ordered_keys = sorted(stats.keys(), key=lambda k: stats[k][0] * stats[k][1])
-        tfk_str = "\n".join(
-            (f"{k}:{' '*(longest_key_length - len(k))} " f"{self._pprint_duration(*stats[k])}")
-            for k in ordered_keys
-        )
-        return tfk_str
+        return pprint_stats_map(stats, self._CUTOFFS_AND_UNITS)
diff --git a/src/mixins/utils.py b/src/mixins/utils.py
index b6efa67..e2e9a8c 100644
--- a/src/mixins/utils.py
+++ b/src/mixins/utils.py
@@ -38,7 +38,7 @@ def normalize_unit(val: QUANT_T, cutoffs_and_units: UNITS_LIST_T) -> QUANT_T:
     Raises:
         LookupError: If the unit is not in the list of valid units.
 
-    Example:
+    Examples:
         >>> normalize_unit((1000, "ms"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")])
         (1.0, 's')
         >>> normalize_unit((720000, "ms"), [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")])
         (12.0, 'min')
@@ -68,3 +68,101 @@ def normalize_unit(val: QUANT_T, cutoffs_and_units: UNITS_LIST_T) -> QUANT_T:
         if (upper_bound_factor is None) or (min_unit < upper_bound * upper_bound_factor):
             return min_unit / upper_bound, unit
         upper_bound *= upper_bound_factor
+
+
+def pprint_quant(
+    cutoffs_and_units: UNITS_LIST_T, val: QUANT_T, n_times: int = 1, std_val: float | None = None
+) -> str:
+    """Pretty prints a quantity with its unit, reflecting both a possible variance and # of times measured.
+
+    Args:
+        cutoffs_and_units: A list of tuples of valid cutoffs and units to normalize the passed units.
+        val: A tuple of a number and a unit to be normalized.
+        n_times: The number of times the quantity was measured.
+        std_val: The standard deviation of the quantity (in the same units as val).
+
+    Returns:
+        A string representation of the quantity with its unit.
+
+    Examples:
+        >>> pprint_quant([(1000, "ms"), (60, "s"), (60, "min"), (None, "h")], (1000, "ms"))
+        '1.0 s'
+        >>> pprint_quant([(1000, "ms"), (60, "s"), (60, "min"), (None, "h")], (1000, "ms"), n_times=3)
+        '1.0 s (x3)'
+        >>> pprint_quant([(1000, "ms"), (60, "s"), (60, "min"), (None, "h")], (1000, "ms"), std_val=100)
+        '1.0 ± 0.1 s'
+        >>> pprint_quant([(1000, "ms"), (60, "s"), (60, "min")], (1200, "ms"), n_times=3, std_val=100)
+        '1.2 ± 0.1 s (x3)'
+    """
+    norm_val, unit = normalize_unit(val, cutoffs_and_units)
+    if std_val is not None:
+        unit_conversion_factor = norm_val / val[0]
+        std_norm_val = std_val * unit_conversion_factor
+
+        mean_std_str = f"{norm_val:.1f} ± {std_norm_val:.1f} {unit}"
+    else:
+        mean_std_str = f"{norm_val:.1f} {unit}"
+
+    if n_times > 1:
+        return f"{mean_std_str} (x{n_times})"
+    else:
+        return mean_std_str
+
+
+def pprint_stats_map(
+    stats: dict[str, tuple[QUANT_T, int, float | None]], cutoffs_and_units: UNITS_LIST_T
+) -> str:
+    """Pretty prints a dictionary of summary statistics of quantities with their units, respecting key length.
+
+    Args:
+        stats: A dictionary of tuples of a number and a unit to be normalized, the number of times measured,
+            and the standard deviation of the quantity (in the same units as the number).
+        cutoffs_and_units: A list of tuples of valid cutoffs and units to normalize the passed units.
+
+    Returns:
+        A string representation of the dictionary of quantities with their units. This string representation
+        will be ordered from smallest to largest total value (mean times number of measurements).
+
+    Examples:
+        >>> print(pprint_stats_map(
+        ...     {"foo": ((1000, "ms"), 3, 100), "foobar": ((1000, "ms"), 1, None)},
+        ...     [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")]
+        ... ))
+        foobar: 1.0 s
+        foo:    1.0 ± 0.1 s (x3)
+        >>> print(pprint_stats_map(
+        ...     {"foo": ((1000, "ms"), 3, 100), "foobar": ((72, "s"), 1, None)},
+        ...     [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")]
+        ... ))
+        foo:    1.0 ± 0.1 s (x3)
+        foobar: 1.2 min
+        >>> pprint_stats_map(
+        ...     {"foo": ((1000, "ms"), 3, 100), "foobar": ((72, "y"), 1, None)},
+        ...     [(1000, "ms"), (60, "s"), (60, "min"), (None, "h")]
+        ... )
+        Traceback (most recent call last):
+            ...
+        ValueError: Unit y in stats key foobar not found in cutoffs_and_units!
+    """
+
+    def total_val(key: str) -> float:
+        X_val, X_unit = stats[key][0]
+        factor = 1
+        for fac, unit in cutoffs_and_units:
+            if unit == X_unit:
+                break
+            if fac is None:
+                raise ValueError(f"Unit {X_unit} in stats key {key} not found in cutoffs_and_units!")
+            factor *= fac
+
+        base_unit_val = X_val * factor
+        n_times = stats[key][1]
+
+        return base_unit_val * n_times
+
+    longest_key_length = max(len(k) for k in stats)
+    ordered_keys = sorted(stats.keys(), key=total_val, reverse=False)
+    return "\n".join(
+        (f"{k}:{' '*(longest_key_length - len(k))} " f"{pprint_quant(cutoffs_and_units, *stats[k])}")
+        for k in ordered_keys
+    )
diff --git a/tests/test_timeable_mixin.py b/tests/test_timeable_mixin.py
index 117ec31..0a04226 100644
--- a/tests/test_timeable_mixin.py
+++ b/tests/test_timeable_mixin.py
@@ -79,9 +79,11 @@ def test_times_and_profiling():
     assert 0 == stats["decorated_takes_time_auto_key"][2]
 
     got_str = T._profile_durations()
-    want_str = "decorated_takes_time_auto_key: 2.0 sec\ndecorated:                      1.5 ± 0.5 sec (x2)"
+    want_str = (
+        "decorated_takes_time_auto_key: 2.0 ± 0.0 sec\ndecorated:                      1.5 ± 0.5 sec (x2)"
+    )
     assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"
 
     got_str = T._profile_durations(only_keys=["decorated_takes_time_auto_key"])
-    want_str = "decorated_takes_time_auto_key: 2.0 sec"
+    want_str = "decorated_takes_time_auto_key: 2.0 ± 0.0 sec"
     assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"

From daae98358310e473beddb193c311865980053b30 Mon Sep 17 00:00:00 2001
From: Matthew McDermott
Date: Tue, 19 Nov 2024 16:03:54 -0500
Subject: [PATCH 3/5] Added memtrackable mixin.

---
 pyproject.toml                   |   1 +
 src/mixins/__init__.py           |   3 +-
 src/mixins/memtrackable.py       | 111 +++++++++++++++++++++++++++++++
 src/mixins/timeable.py           |   2 +-
 tests/test_memtrackable_mixin.py |  93 ++++++++++++++++++++++++++
 tests/test_timeable_mixin.py     |   8 +--
 6 files changed, 211 insertions(+), 7 deletions(-)
 create mode 100644 src/mixins/memtrackable.py
 create mode 100644 tests/test_memtrackable_mixin.py

diff --git a/pyproject.toml b/pyproject.toml
index 6152edb..4b9719e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ classifiers = [
 dependencies = ["numpy"]
 
 [project.optional-dependencies]
+memtrackable = ["memray"]
 dev = ["pre-commit<4"]
 tests = ["pytest", "pytest-cov", "pytest-benchmark"]
 
diff --git a/src/mixins/__init__.py b/src/mixins/__init__.py
index 574b788..b9ca970 100644
--- a/src/mixins/__init__.py
+++ b/src/mixins/__init__.py
@@ -1,4 +1,5 @@
+from .memtrackable import MemTrackableMixin
 from .seedable import SeedableMixin
 from .timeable import TimeableMixin
 
-__all__ = ["SeedableMixin", "TimeableMixin"]
+__all__ = ["MemTrackableMixin", "SeedableMixin", "TimeableMixin"]
diff --git a/src/mixins/memtrackable.py b/src/mixins/memtrackable.py
new file mode 100644
index 0000000..30fd915
--- /dev/null
+++ b/src/mixins/memtrackable.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+import functools
+import json
+import subprocess
+from collections import defaultdict
+from contextlib import contextmanager
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import numpy as np
+from memray import Tracker
+
+from .utils import doublewrap, pprint_stats_map
+
+
+class MemTrackableMixin:
+    """A mixin class to add memory tracking functionality to a class for profiling its methods.
+
+    This mixin class provides the following functionality:
+        - Tracking the memory use of methods using the TrackMemoryAs decorator.
+        - Tracking the memory use of arbitrary code blocks using the _track_memory_as context manager.
+        - Profiling of the memory used across the life cycle of the class.
+
+    This class uses `memray` to track the memory usage of the methods.
+
+    Attributes:
+        _mem_stats: A dictionary of lists of memory usage statistics of tracked methods.
+            The keys of the dictionary are the names of the tracked code blocks / methods.
+            The values are lists of the memory used during the lifecycle of each tracked code block / method.
+    """
+
+    _CUTOFFS_AND_UNITS = [
+        (8, "b"),
+        (1000, "B"),
+        (1000, "kB"),
+        (1000, "MB"),
+        (1000, "GB"),
+        (1000, "TB"),
+        (1000, "PB"),
+    ]
+
+    @staticmethod
+    def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path) -> dict:
+        memray_stats_cmd = f"memray stats {memray_tracker_fp} --json -o {memray_stats_fp} -f"
+        subprocess.run(memray_stats_cmd, shell=True, check=True, capture_output=True)
+        try:
+            return json.loads(memray_stats_fp.read_text())
+        except Exception as e:
+            raise ValueError(f"Failed to parse memray stats file at {memray_stats_fp}") from e
+
+    def __init__(self, *args, **kwargs):
+        self._mem_stats = kwargs.get("_mem_stats", defaultdict(list))
+
+    def __assert_key_exists(self, key: str) -> None:
+        if not hasattr(self, "_mem_stats"):
+            raise AttributeError("self._mem_stats should exist!")
+        if key not in self._mem_stats:
+            raise AttributeError(f"{key} should exist in self._mem_stats!")
+
+    def _peak_mem_for(self, key: str) -> list[float]:
+        self.__assert_key_exists(key)
+
+        return [v["metadata"]["peak_memory"] for v in self._mem_stats[key]]
+
+    @contextmanager
+    def _track_memory_as(self, key: str):
+        if not hasattr(self, "_mem_stats"):
+            self._mem_stats = defaultdict(list)
+
+        memory_stats = {}
+        with TemporaryDirectory() as tmpdir:
+            memray_fp = Path(tmpdir) / ".memray"
+            memray_stats_fp = Path(tmpdir) / "memray_stats.json"
+
+            try:
+                with Tracker(memray_fp, follow_fork=True):
+                    yield
+            finally:
+                memory_stats.update(MemTrackableMixin.get_memray_stats(memray_fp, memray_stats_fp))
+                self._mem_stats[key].append(memory_stats)
+
+    @staticmethod
+    @doublewrap
+    def TrackMemoryAs(fn, key: str | None = None):
+        if key is None:
+            key = fn.__name__
+
+        @functools.wraps(fn)
+        def wrapper(self, *args, **kwargs):
+            with self._track_memory_as(key):
+                out = fn(self, *args, **kwargs)
+            return out
+
+        return wrapper
+
+    @property
+    def _memory_stats(self):
+        out = {}
+        for k in self._mem_stats:
+            arr = np.array(self._peak_mem_for(k))
+            out[k] = (arr.mean(), len(arr), None if len(arr) <= 1 else arr.std())
+        return out
+
+    def _profile_memory_usages(self, only_keys: set[str] | None = None):
+        stats = {k: ((v, "B"), n, s) for k, (v, n, s) in self._memory_stats.items()}
+
+        if only_keys is not None:
+            stats = {k: v for k, v in stats.items() if k in only_keys}
+
+        return pprint_stats_map(stats, self._CUTOFFS_AND_UNITS)
diff --git a/src/mixins/timeable.py b/src/mixins/timeable.py
index 5457b64..be55959 100644
--- a/src/mixins/timeable.py
+++ b/src/mixins/timeable.py
@@ -97,7 +97,7 @@ def _duration_stats(self):
         out = {}
         for k in self._timings:
             arr = np.array(self._times_for(k))
-            out[k] = (arr.mean(), len(arr), arr.std())
+            out[k] = (arr.mean(), len(arr), None if len(arr) <= 1 else arr.std())
         return out
 
     def _profile_durations(self, only_keys: set[str] | None = None):
diff --git a/tests/test_memtrackable_mixin.py b/tests/test_memtrackable_mixin.py
new file mode 100644
index 0000000..aeef381
--- /dev/null
+++ b/tests/test_memtrackable_mixin.py
@@ -0,0 +1,93 @@
+import numpy as np
+
+from mixins import MemTrackableMixin
+
+
+class MemTrackableDerived(MemTrackableMixin):
+    def __init__(self):
+        self.foo = "foo"
+        # Doesn't call super().__init__()! Should still work in this case.
+
+    def uses_contextlib(self, mem_size_64b: int = 800000):
+        with self._track_memory_as("using_contextlib"):
+            np.ones((mem_size_64b,), dtype=np.float64)
+
+    @MemTrackableMixin.TrackMemoryAs(key="decorated")
+    def decorated_takes_mem(self, mem_size_64b: int = 800000):
+        np.ones((mem_size_64b,), dtype=np.float64)
+
+    @MemTrackableMixin.TrackMemoryAs
+    def decorated_takes_mem_auto_key(self, mem_size_64b: int = 800000):
+        np.ones((mem_size_64b,), dtype=np.float64)
+
+
+def test_constructs():
+    MemTrackableMixin()
+    MemTrackableDerived()
+
+
+def test_errors_if_not_initialized():
+    M = MemTrackableDerived()
+    try:
+        M._peak_mem_for("foo")
+        raise AssertionError("Should have raised an exception!")
+    except AttributeError as e:
+        assert "self._mem_stats should exist!" in str(e)
+
+    M.uses_contextlib(mem_size_64b=8000000)  # 64 MB
+    try:
+        M._peak_mem_for("wrong_key")
+        raise AssertionError("Should have raised an exception!")
+    except AttributeError as e:
+        assert "wrong_key should exist in self._mem_stats!" in str(e)
+
+
+def test_benchmark_timing(benchmark):
+    M = MemTrackableDerived()
+
+    benchmark(M.decorated_takes_mem, 8000)
+
+
+def test_context_manager():
+    M = MemTrackableDerived()
+
+    M.uses_contextlib(mem_size_64b=8000000)  # 64 MB
+
+    mem_used = M._peak_mem_for("using_contextlib")[-1]
+    np.testing.assert_almost_equal(mem_used, 64 * 1000000, decimal=1)
+
+    M.uses_contextlib(mem_size_64b=80000)  # 0.64 MB
+
+    mem_used = M._peak_mem_for("using_contextlib")[-1]
+    np.testing.assert_almost_equal(mem_used, 64 * 10000, decimal=1)
+
+
+def test_decorators_and_profiling():
+    M = MemTrackableDerived()
+    M.decorated_takes_mem(mem_size_64b=16000)
+
+    mem_used = M._peak_mem_for("decorated")[-1]
+    np.testing.assert_almost_equal(mem_used, 64 * 2000, decimal=1)
+
+    M.decorated_takes_mem_auto_key(mem_size_64b=40000)
+    mem_used = M._peak_mem_for("decorated_takes_mem_auto_key")[-1]
+    np.testing.assert_almost_equal(mem_used, 64 * 5000, decimal=1)
+
+    M.decorated_takes_mem(mem_size_64b=8000)
+    stats = M._memory_stats
+
+    assert {"decorated", "decorated_takes_mem_auto_key"} == set(stats.keys())
+    np.testing.assert_almost_equal(64 * 1500, stats["decorated"][0], decimal=1)
+    assert 2 == stats["decorated"][1]
+    np.testing.assert_almost_equal(0.5 * 64 * 1000, stats["decorated"][2], decimal=1)
+    np.testing.assert_almost_equal(64 * 5000, stats["decorated_takes_mem_auto_key"][0], decimal=1)
+    assert 1 == stats["decorated_takes_mem_auto_key"][1]
+    assert stats["decorated_takes_mem_auto_key"][2] is None
+
+    got_str = M._profile_memory_usages()
+    want_str = "decorated:                    96.0 ± 32.0 kB (x2)\ndecorated_takes_mem_auto_key: 320.0 kB"
+    assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"
+
+    got_str = M._profile_memory_usages(only_keys=["decorated_takes_mem_auto_key"])
+    want_str = "decorated_takes_mem_auto_key: 320.0 kB"
+    assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"
diff --git a/tests/test_timeable_mixin.py b/tests/test_timeable_mixin.py
index 0a04226..4bbf732 100644
--- a/tests/test_timeable_mixin.py
+++ b/tests/test_timeable_mixin.py
@@ -76,14 +76,12 @@ def test_times_and_profiling():
     np.testing.assert_almost_equal(0.5, stats["decorated"][2], decimal=1)
     np.testing.assert_almost_equal(2, stats["decorated_takes_time_auto_key"][0], decimal=1)
     assert 1 == stats["decorated_takes_time_auto_key"][1]
-    assert 0 == stats["decorated_takes_time_auto_key"][2]
+    assert stats["decorated_takes_time_auto_key"][2] is None
 
     got_str = T._profile_durations()
-    want_str = (
-        "decorated_takes_time_auto_key: 2.0 ± 0.0 sec\ndecorated:                      1.5 ± 0.5 sec (x2)"
-    )
+    want_str = "decorated_takes_time_auto_key: 2.0 sec\ndecorated:                      1.5 ± 0.5 sec (x2)"
     assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"
 
     got_str = T._profile_durations(only_keys=["decorated_takes_time_auto_key"])
-    want_str = "decorated_takes_time_auto_key: 2.0 ± 0.0 sec"
+    want_str = "decorated_takes_time_auto_key: 2.0 sec"
     assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"

From 14e28eae3f0a1ab5c6e7a58b45622197366f86a9 Mon Sep 17 00:00:00 2001
From: Matthew McDermott
Date: Wed, 20 Nov 2024 09:58:10 -0500
Subject: [PATCH 4/5] Made tests install memray.

---
 .github/workflows/tests.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 2e0fab8..5da1d1b 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -28,7 +28,7 @@ jobs:
 
       - name: Install packages
         run: |
-          pip install -e .[tests]
+          pip install -e .[tests,memtrackable]
 
       #----------------------------------------------
       #             run test suite

From 145ed5ee9a79d82dc24f00b3b9d3a99b945b2d56 Mon Sep 17 00:00:00 2001
From: Matthew McDermott
Date: Wed, 20 Nov 2024 10:09:57 -0500
Subject: [PATCH 5/5] Added doctests for getting memray stats and adjusted error cases.

---
 src/mixins/memtrackable.py | 47 ++++++++++++++++++++++++++++++++++++--
 1 file changed, 45 insertions(+), 2 deletions(-)

diff --git a/src/mixins/memtrackable.py b/src/mixins/memtrackable.py
index 30fd915..2b6ee25 100644
--- a/src/mixins/memtrackable.py
+++ b/src/mixins/memtrackable.py
@@ -42,12 +42,55 @@ class MemTrackableMixin:
 
     @staticmethod
     def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path) -> dict:
+        """Extracts the memory stats from the tracker file, saves them to a JSON file, and returns them.
+
+        Args:
+            memray_tracker_fp: The path to the memray tracker file.
+            memray_stats_fp: The path to save the memory statistics.
+
+        Returns:
+            The stats extracted from the memray tracker file.
+
+        Raises:
+            FileNotFoundError: If the memray tracker file is not found.
+            ValueError: If the stats file cannot be parsed.
+
+        Examples:
+            >>> import tempfile
+            >>> with tempfile.TemporaryDirectory() as tmpdir:
+            ...     memray_tracker_fp = Path(tmpdir) / ".memray"
+            ...     memray_stats_fp = Path(tmpdir) / "memray_stats.json"
+            ...     with Tracker(memray_tracker_fp, follow_fork=True):
+            ...         A = np.ones((1000,), dtype=np.float64)
+            ...     stats = MemTrackableMixin.get_memray_stats(memray_tracker_fp, memray_stats_fp)
+            ...     print(stats['metadata']['peak_memory'])
+            8000
+
+            If you run it on a tracker file that is malformed, it won't work.
+            >>> with tempfile.TemporaryDirectory() as tmpdir:
+            ...     memray_tracker_fp = Path(tmpdir) / ".memray"
+            ...     memray_stats_fp = Path(tmpdir) / "memray_stats.json"
+            ...     memray_tracker_fp.touch()
+            ...     MemTrackableMixin.get_memray_stats(memray_tracker_fp, memray_stats_fp)
+            Traceback (most recent call last):
+                ...
+            ValueError: Failed to extract and parse memray stats file at ...
+
+            If you run it on a non-existent file, it won't work.
+            >>> MemTrackableMixin.get_memray_stats(Path("non_existent.mem"), Path("non_existent.stats"))
+            Traceback (most recent call last):
+                ...
+            FileNotFoundError: Memray tracker file not found at non_existent.mem
+        """
+        if not memray_tracker_fp.is_file():
+            raise FileNotFoundError(f"Memray tracker file not found at {memray_tracker_fp}")
+
         memray_stats_cmd = f"memray stats {memray_tracker_fp} --json -o {memray_stats_fp} -f"
-        subprocess.run(memray_stats_cmd, shell=True, check=True, capture_output=True)
         try:
+            subprocess.run(memray_stats_cmd, shell=True, check=True, capture_output=True)
             return json.loads(memray_stats_fp.read_text())
         except Exception as e:
-            raise ValueError(f"Failed to parse memray stats file at {memray_stats_fp}") from e
+            raise ValueError(f"Failed to extract and parse memray stats file at {memray_stats_fp}") from e
 
     def __init__(self, *args, **kwargs):
         self._mem_stats = kwargs.get("_mem_stats", defaultdict(list))