Skip to content

Commit

Permalink
Added doctests for getting memray stats and adjusted error cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Nov 20, 2024
1 parent 14e28ea commit 145ed5e
Showing 1 changed file with 45 additions and 2 deletions.
47 changes: 45 additions & 2 deletions src/mixins/memtrackable.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,55 @@ class MemTrackableMixin:

@staticmethod
def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path) -> dict:
"""Extracts the memory stats from the tracker file and saves it to a json file and returns the stats.
Args:
memray_tracker_fp: The path to the memray tracker file.
memray_stats_fp: The path to save the memory statistics.
Returns:
The stats extracted from the memray tracker file.
Raises:
FileNotFoundError: If the memray tracker file is not found.
ValueError: If the stats file cannot be parsed.
Examples:
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as tmpdir:
... memray_tracker_fp = Path(tmpdir) / ".memray"
... memray_stats_fp = Path(tmpdir) / "memray_stats.json"
... with Tracker(memray_tracker_fp, follow_fork=True):
... A = np.ones((1000,), dtype=np.float64)
... stats = MemTrackableMixin.get_memray_stats(memray_tracker_fp, memray_stats_fp)
... print(stats['metadata']['peak_memory'])
8000
If you run it on a tracker file that is malformed, it won't work.
>>> with tempfile.TemporaryDirectory() as tmpdir:
... memray_tracker_fp = Path(tmpdir) / ".memray"
... memray_stats_fp = Path(tmpdir) / "memray_stats.json"
... memray_tracker_fp.touch()
... MemTrackableMixin.get_memray_stats(memray_tracker_fp, memray_stats_fp)
Traceback (most recent call last):
...
ValueError: Failed to extract and parse memray stats file at ...
If you run it on a non-existent file, it won't work.
>>> MemTrackableMixin.get_memray_stats(Path("non_existent.mem"), Path("non_existent.stats"))
Traceback (most recent call last):
...
FileNotFoundError: Memray tracker file not found at non_existent.mem
"""
if not memray_tracker_fp.is_file():
raise FileNotFoundError(f"Memray tracker file not found at {memray_tracker_fp}")

memray_stats_cmd = f"memray stats {memray_tracker_fp} --json -o {memray_stats_fp} -f"
subprocess.run(memray_stats_cmd, shell=True, check=True, capture_output=True)
try:
subprocess.run(memray_stats_cmd, shell=True, check=True, capture_output=True)
return json.loads(memray_stats_fp.read_text())
except Exception as e:
raise ValueError(f"Failed to parse memray stats file at {memray_stats_fp}") from e
raise ValueError(f"Failed to extract and parse memray stats file at {memray_stats_fp}") from e

def __init__(self, *args, **kwargs):
self._mem_stats = kwargs.get("_mem_stats", defaultdict(list))
Expand Down

0 comments on commit 145ed5e

Please sign in to comment.