diff --git a/changes/299.general.rst b/changes/299.general.rst new file mode 100644 index 00000000..c13ddf4b --- /dev/null +++ b/changes/299.general.rst @@ -0,0 +1 @@ +Add infrastructure for testing memory usage diff --git a/src/stcal/testing_helpers.py b/src/stcal/testing_helpers.py new file mode 100644 index 00000000..7c322a30 --- /dev/null +++ b/src/stcal/testing_helpers.py @@ -0,0 +1,51 @@ +import tracemalloc + +MEMORY_UNIT_CONVERSION = {"B": 1, "KB": 1024, "MB": 1024 ** 2, "GB": 1024 ** 3, "TB": 1024 ** 4} + +class MemoryThresholdExceeded(Exception): + pass + + +class MemoryThreshold: + """ + Context manager to check peak memory usage against an expected threshold. + + example usage: + with MemoryThreshold(expected_usage): + # code that should not exceed expected + + If the code in the with statement uses more than the expected_usage + memory a ``MemoryThresholdExceeded`` exception + will be raised. + + Note that this class does not prevent allocations beyond the threshold + and only checks the actual peak allocations to the threshold at the + end of the with statement. + """ + + def __init__(self, expected_usage): + """ + Parameters + ---------- + expected_usage : str + Expected peak memory usage expressed as a whitespace-separated string + with a number and a memory unit (e.g. "100 KB"). + Supported units are "B", "KB", "MB", "GB", "TB". + """ + expected, self.units = expected_usage.upper().split() + self.expected_usage_bytes = float(expected) * MEMORY_UNIT_CONVERSION[self.units] + + def __enter__(self): + tracemalloc.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + if peak > self.expected_usage_bytes: + scaling = MEMORY_UNIT_CONVERSION[self.units] + msg = ("Peak memory usage exceeded expected usage: " + f"{peak / scaling:.2f} {self.units} > " + f"{self.expected_usage_bytes / scaling:.2f} {self.units} ") + raise MemoryThresholdExceeded(msg) diff --git a/tests/outlier_detection/test_median.py b/tests/outlier_detection/test_median.py index 5361e44b..62a30dac 100644 --- a/tests/outlier_detection/test_median.py +++ b/tests/outlier_detection/test_median.py @@ -10,6 +10,7 @@ _OnDiskMedian, nanmedian3D, ) +from stcal.testing_helpers import MemoryThreshold def test_disk_appendable_array(tmp_path): @@ -194,3 +195,40 @@ def test_nanmedian3D(): assert med.dtype == np.float32 assert np.allclose(med, np.nanmedian(cube, axis=0), equal_nan=True) + + +@pytest.mark.parametrize("in_memory", [True, False]) +def test_memory_computer(in_memory, tmp_path): + """ + Analytically calculate how much memory the median computation + is supposed to take, then ensure that the implementation + stays near that. + + in_memory=True case allocates the following memory: + - one cube size + - median array == one frame size + + in_memory=False case allocates the following memory: + - one buffer size, which by default is the frame size + - median array == one frame size + + add a half-frame-size buffer to the expected memory usage in both cases + """ + shp = (20, 500, 500) + cube_size = np.dtype("float32").itemsize * shp[0] * shp[1] * shp[2] #bytes + frame_size = cube_size / shp[0] + + # calculate expected memory usage + if in_memory: + expected_mem = cube_size + frame_size*1.5 + else: + expected_mem = frame_size * 2.5 + + # compute the median while tracking memory usage + with MemoryThreshold(str(expected_mem) + " B"): + computer = MedianComputer(shp, in_memory=in_memory, tempdir=tmp_path) + for i in range(shp[0]): + frame = np.full(shp[1:], i, dtype=np.float32) + computer.append(frame, i) + del frame + computer.evaluate() diff --git a/tests/test_infrastructure.py b/tests/test_infrastructure.py new file mode 100644 index 00000000..37e68482 --- /dev/null +++ b/tests/test_infrastructure.py @@ -0,0 +1,16 @@ +"""Tests of custom testing infrastructure""" + +import pytest +import numpy as np +from stcal.testing_helpers import MemoryThreshold, MemoryThresholdExceeded + + +def test_memory_threshold(): + with MemoryThreshold("10 KB"): + buff = np.ones(1000, dtype=np.uint8) + + +def test_memory_threshold_exceeded(): + with pytest.raises(MemoryThresholdExceeded): + with MemoryThreshold("500. B"): + buff = np.ones(10000, dtype=np.uint8) \ No newline at end of file