diff --git a/tests/test_utils.py b/tests/test_utils.py index e2c14ce..e91a48e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,9 +5,9 @@ def test_overlay_mask(): + # RGB image img = Image.fromarray(np.zeros((4, 4, 3)).astype(np.uint8)) mask = Image.fromarray(255 * np.ones((4, 4)).astype(np.uint8)) - overlayed = utils.overlay_mask(img, mask, alpha=0.7) # Check object type @@ -16,3 +16,11 @@ def test_overlay_mask(): assert np.all(np.asarray(overlayed)[..., 0] == 0) assert np.all(np.asarray(overlayed)[..., 1] == 0) assert np.all(np.asarray(overlayed)[..., 2] == 39) + + # grayscale image + img = Image.fromarray(np.zeros((4, 4)).astype(np.uint8)) + mask = Image.fromarray(255 * np.ones((4, 4)).astype(np.uint8)) + overlayed = utils.overlay_mask(img, mask, alpha=0.7) + + # Verify value + assert np.all(np.asarray(overlayed) == 39) diff --git a/torchcam/utils.py b/torchcam/utils.py index 24b7419..e67a0d0 100644 --- a/torchcam/utils.py +++ b/torchcam/utils.py @@ -39,9 +39,16 @@ def overlay_mask(img: Image, mask: Image, colormap: str = "jet", alpha: float = if not isinstance(alpha, float) or alpha < 0 or alpha >= 1: raise ValueError("alpha argument is expected to be of type float between 0 and 1") + if len(img.getbands()) not in {1, 3}: + raise ValueError("img argument needs to be a grayscale or RGB image") + cmap = cm.get_cmap(colormap) # Resize mask and apply colormap overlay = mask.resize(img.size, resample=Resampling.BICUBIC) - overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, :3]).astype(np.uint8) + + overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, 2 if len(img.getbands()) == 1 else slice(0, 3)]).astype( + np.uint8 + ) + # Overlay the image with the mask return fromarray((alpha * np.asarray(img) + (1 - alpha) * cast(np.ndarray, overlay)).astype(np.uint8))