Skip to content

Commit

Permalink
more checks and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Oct 12, 2023
1 parent c3fd9c4 commit c3b6b50
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
4 changes: 3 additions & 1 deletion doctr/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_li
the angle of the general document orientation
"""

assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
if np.max(img) <= 1 and np.min(img) >= 0 or (np.max(img) <= 255 and np.min(img) >= 0 and img.shape[-1] == 1):
thresh = img.astype(np.uint8)
if np.max(img) <= 255 and np.min(img) >= 0 and img.shape[-1] == 3:
Expand Down Expand Up @@ -70,7 +71,8 @@ def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_li
if len(angles) == 0:
return 0 # in case no angles is found
else:
return -median_low(angles)
median = -median_low(angles)
return median if median != 0 else 0


def rectify_crops(
Expand Down
3 changes: 3 additions & 0 deletions tests/common/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def test_estimate_orientation(mock_image, mock_bitmap, mock_tilted_payslip):
angle_rotated = estimate_orientation(rotated)
assert abs(angle_rotated) < 1.0

with pytest.raises(AssertionError):
estimate_orientation(np.ones((10, 10, 10)))


def test_get_lang():
sentence = "This is a test sentence."
Expand Down
6 changes: 6 additions & 0 deletions tests/common/test_models_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def test_documentbuilder():
out = doc_builder(pages, [boxes, boxes], [[("hello", 1.0)] * words_per_page] * num_pages, [(100, 200), (100, 200)])
assert isinstance(out, Document)
assert len(out.pages) == num_pages
assert all([isinstance(page.page, np.ndarray) for page in out.pages]) and all(
[page.page.shape == (100, 200, 3) for page in out.pages]
)
# 1 Block & 1 line per page
assert len(out.pages[0].blocks) == 1 and len(out.pages[0].blocks[0].lines) == 1
assert len(out.pages[0].blocks[0].lines[0].words) == words_per_page
Expand Down Expand Up @@ -79,6 +82,9 @@ def test_kiedocumentbuilder():
)
assert isinstance(out, KIEDocument)
assert len(out.pages) == num_pages
assert all([isinstance(page.page, np.ndarray) for page in out.pages]) and all(
[page.page.shape == (100, 200, 3) for page in out.pages]
)
# 1 Block & 1 line per page
assert len(out.pages[0].predictions) == 1
assert len(out.pages[0].predictions[CLASS_NAME]) == words_per_page
Expand Down

0 comments on commit c3b6b50

Please sign in to comment.