From 145ed5ee9a79d82dc24f00b3b9d3a99b945b2d56 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 20 Nov 2024 10:09:57 -0500 Subject: [PATCH] Added doctests for getting memray stats and adjusted error cases. --- src/mixins/memtrackable.py | 47 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/src/mixins/memtrackable.py b/src/mixins/memtrackable.py index 30fd915..2b6ee25 100644 --- a/src/mixins/memtrackable.py +++ b/src/mixins/memtrackable.py @@ -42,12 +42,55 @@ class MemTrackableMixin: @staticmethod def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path) -> dict: + """Extracts the memory stats from the tracker file and saves it to a json file and returns the stats. + + Args: + memray_tracker_fp: The path to the memray tracker file. + memray_stats_fp: The path to save the memory statistics. + + Returns: + The stats extracted from the memray tracker file. + + Raises: + FileNotFoundError: If the memray tracker file is not found. + ValueError: If the stats file cannot be parsed. + + Examples: + >>> import tempfile + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... memray_tracker_fp = Path(tmpdir) / ".memray" + ... memray_stats_fp = Path(tmpdir) / "memray_stats.json" + ... with Tracker(memray_tracker_fp, follow_fork=True): + ... A = np.ones((1000,), dtype=np.float64) + ... stats = MemTrackableMixin.get_memray_stats(memray_tracker_fp, memray_stats_fp) + ... print(stats['metadata']['peak_memory']) + 8000 + + If you run it on a tracker file that is malformed, it won't work. + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... memray_tracker_fp = Path(tmpdir) / ".memray" + ... memray_stats_fp = Path(tmpdir) / "memray_stats.json" + ... memray_tracker_fp.touch() + ... MemTrackableMixin.get_memray_stats(memray_tracker_fp, memray_stats_fp) + Traceback (most recent call last): + ... + ValueError: Failed to extract and parse memray stats file at ... + + If you run it on a non-existent file, it won't work. + >>> MemTrackableMixin.get_memray_stats(Path("non_existent.mem"), Path("non_existent.stats")) + Traceback (most recent call last): + ... + FileNotFoundError: Memray tracker file not found at non_existent.mem + """ + if not memray_tracker_fp.is_file(): + raise FileNotFoundError(f"Memray tracker file not found at {memray_tracker_fp}") + memray_stats_cmd = f"memray stats {memray_tracker_fp} --json -o {memray_stats_fp} -f" - subprocess.run(memray_stats_cmd, shell=True, check=True, capture_output=True) try: + subprocess.run(memray_stats_cmd, shell=True, check=True, capture_output=True) return json.loads(memray_stats_fp.read_text()) except Exception as e: - raise ValueError(f"Failed to parse memray stats file at {memray_stats_fp}") from e + raise ValueError(f"Failed to extract and parse memray stats file at {memray_stats_fp}") from e def __init__(self, *args, **kwargs): self._mem_stats = kwargs.get("_mem_stats", defaultdict(list))