From 0a24cad22aeebcf0058430717492c6202ca419c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Tue, 12 Dec 2023 13:31:28 +0100 Subject: [PATCH] Account for dtype in the pixel array so the maximum value stays correct in the invert function (#3493) * fix(lib): fix This fixes an issue where the `invert` argument would only work for `uint8` dtypes. Now the `max` value is updated according to the pixel array dtype. Maybe we should add unit tests for that, but haven't found an obvious place to put unit tests. * chore(ci): add basic test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(ci): wrong attr name * Update tests/module/mobject/test_image.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Benjamin Hackl --- manim/mobject/types/image_mobject.py | 4 +++- tests/module/mobject/test_image.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 tests/module/mobject/test_image.py diff --git a/manim/mobject/types/image_mobject.py b/manim/mobject/types/image_mobject.py index cb0248d3a5..990fb67686 100644 --- a/manim/mobject/types/image_mobject.py +++ b/manim/mobject/types/image_mobject.py @@ -191,7 +191,9 @@ def __init__( self.pixel_array, self.pixel_array_dtype ) if self.invert: - self.pixel_array[:, :, :3] = 255 - self.pixel_array[:, :, :3] + self.pixel_array[:, :, :3] = ( + np.iinfo(self.pixel_array_dtype).max - self.pixel_array[:, :, :3] + ) super().__init__(scale_to_resolution, **kwargs) def get_pixel_array(self): diff --git a/tests/module/mobject/test_image.py b/tests/module/mobject/test_image.py new file mode 100644 index 0000000000..a79fce5b8a --- /dev/null +++ b/tests/module/mobject/test_image.py @@ -0,0 +1,14 @@ +import numpy as np +import pytest + +from manim import ImageMobject + + +@pytest.mark.parametrize("dtype", [np.uint8, np.uint16]) +def test_invert_image(dtype): + array = (255 * np.random.rand(10, 10, 4)).astype(dtype) + image = ImageMobject(array, pixel_array_dtype=dtype, invert=True) + assert image.pixel_array.dtype == dtype + + array[:, :, :3] = np.iinfo(dtype).max - array[:, :, :3] + assert np.allclose(array, image.pixel_array)