From c3b6b504121e2097b985d82c3b9f228e0c162530 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 5 Oct 2023 11:50:00 +0200 Subject: [PATCH] more checks and tests --- doctr/models/_utils.py | 4 +++- tests/common/test_models.py | 3 +++ tests/common/test_models_builder.py | 6 ++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/doctr/models/_utils.py b/doctr/models/_utils.py index 8ed94f345b..af15487f9a 100644 --- a/doctr/models/_utils.py +++ b/doctr/models/_utils.py @@ -39,6 +39,7 @@ def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_li the angle of the general document orientation """ + assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported" if np.max(img) <= 1 and np.min(img) >= 0 or (np.max(img) <= 255 and np.min(img) >= 0 and img.shape[-1] == 1): thresh = img.astype(np.uint8) if np.max(img) <= 255 and np.min(img) >= 0 and img.shape[-1] == 3: @@ -70,7 +71,8 @@ def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_li if len(angles) == 0: return 0 # in case no angles is found else: - return -median_low(angles) + median = -median_low(angles) + return median if median != 0 else 0 def rectify_crops( diff --git a/tests/common/test_models.py b/tests/common/test_models.py index fe59f7a3ce..c4f1534965 100644 --- a/tests/common/test_models.py +++ b/tests/common/test_models.py @@ -52,6 +52,9 @@ def test_estimate_orientation(mock_image, mock_bitmap, mock_tilted_payslip): angle_rotated = estimate_orientation(rotated) assert abs(angle_rotated) < 1.0 + with pytest.raises(AssertionError): + estimate_orientation(np.ones((10, 10, 10))) + def test_get_lang(): sentence = "This is a test sentence." diff --git a/tests/common/test_models_builder.py b/tests/common/test_models_builder.py index 90c681b5f6..0a8edadb39 100644 --- a/tests/common/test_models_builder.py +++ b/tests/common/test_models_builder.py @@ -29,6 +29,9 @@ def test_documentbuilder(): out = doc_builder(pages, [boxes, boxes], [[("hello", 1.0)] * words_per_page] * num_pages, [(100, 200), (100, 200)]) assert isinstance(out, Document) assert len(out.pages) == num_pages + assert all([isinstance(page.page, np.ndarray) for page in out.pages]) and all( + [page.page.shape == (100, 200, 3) for page in out.pages] + ) # 1 Block & 1 line per page assert len(out.pages[0].blocks) == 1 and len(out.pages[0].blocks[0].lines) == 1 assert len(out.pages[0].blocks[0].lines[0].words) == words_per_page @@ -79,6 +82,9 @@ def test_kiedocumentbuilder(): ) assert isinstance(out, KIEDocument) assert len(out.pages) == num_pages + assert all([isinstance(page.page, np.ndarray) for page in out.pages]) and all( + [page.page.shape == (100, 200, 3) for page in out.pages] + ) # 1 Block & 1 line per page assert len(out.pages[0].predictions) == 1 assert len(out.pages[0].predictions[CLASS_NAME]) == words_per_page