Skip to content

Commit

Permalink
[prototype] compute orientation on segmentation map (#1336)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Nov 17, 2023
1 parent 6d92df5 commit e645ead
Show file tree
Hide file tree
Showing 20 changed files with 279 additions and 243 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ If both options are set to False, the predictor will always fit and return rotat
To interpret your model's predictions, you can visualize them interactively as follows:

```python
result.show(doc)
result.show()
```

![Visualization sample](docs/images/doctr_example_script.gif)
Expand Down
34 changes: 16 additions & 18 deletions doctr/io/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ class Page(Element):
Args:
----
page: image encoded as a numpy array in uint8
blocks: list of block elements
page_idx: the index of the page in the input raw document
dimensions: the page size in pixels in format (height, width)
Expand All @@ -248,13 +249,15 @@ class Page(Element):

def __init__(
self,
page: np.ndarray,
blocks: List[Block],
page_idx: int,
dimensions: Tuple[int, int],
orientation: Optional[Dict[str, Any]] = None,
language: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(blocks=blocks)
self.page = page
self.page_idx = page_idx
self.dimensions = dimensions
self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None)
Expand All @@ -267,17 +270,15 @@ def render(self, block_break: str = "\n\n") -> str:
def extra_repr(self) -> str:
return f"dimensions={self.dimensions}"

def show(self, page: np.ndarray, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None:
def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None:
"""Overlay the result on a given image
Args:
----
page: image encoded as a numpy array in uint8
interactive: whether the display should be interactive
preserve_aspect_ratio: pass True if you passed True to the predictor
**kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method
"""
visualize_page(self.export(), page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio)
visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio)
plt.show(**kwargs)

def synthesize(self, **kwargs) -> np.ndarray:
Expand Down Expand Up @@ -408,6 +409,7 @@ class KIEPage(Element):
Args:
----
predictions: Dictionary with list of block elements for each detection class
page: image encoded as a numpy array in uint8
page_idx: the index of the page in the input raw document
dimensions: the page size in pixels in format (height, width)
orientation: a dictionary with the value of the rotation angle in degress and confidence of the prediction
Expand All @@ -420,13 +422,15 @@ class KIEPage(Element):

def __init__(
self,
page: np.ndarray,
predictions: Dict[str, List[Prediction]],
page_idx: int,
dimensions: Tuple[int, int],
orientation: Optional[Dict[str, Any]] = None,
language: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(predictions=predictions)
self.page = page
self.page_idx = page_idx
self.dimensions = dimensions
self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None)
Expand All @@ -441,17 +445,17 @@ def render(self, prediction_break: str = "\n\n") -> str:
def extra_repr(self) -> str:
return f"dimensions={self.dimensions}"

def show(self, page: np.ndarray, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None:
def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None:
"""Overlay the result on a given image
Args:
----
page: image encoded as a numpy array in uint8
interactive: whether the display should be interactive
preserve_aspect_ratio: pass True if you passed True to the predictor
**kwargs: keyword arguments passed to the matplotlib.pyplot.show method
"""
visualize_kie_page(self.export(), page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio)
visualize_kie_page(
self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio
)
plt.show(**kwargs)

def synthesize(self, **kwargs) -> np.ndarray:
Expand Down Expand Up @@ -561,16 +565,10 @@ def render(self, page_break: str = "\n\n\n\n") -> str:
"""Renders the full text of the element"""
return page_break.join(p.render() for p in self.pages)

def show(self, pages: List[np.ndarray], **kwargs) -> None:
"""Overlay the result on a given image
Args:
----
pages: list of images encoded as numpy arrays in uint8
**kwargs: keyword arguments passed to the Page.show method
"""
for img, result in zip(pages, self.pages):
result.show(img, **kwargs)
def show(self, **kwargs) -> None:
"""Overlay the result on a given image"""
for result in self.pages:
result.show(**kwargs)

def synthesize(self, **kwargs) -> List[np.ndarray]:
"""Synthesize all pages from their predictions
Expand Down
69 changes: 21 additions & 48 deletions doctr/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
from langdetect import LangDetectException, detect_langs

__all__ = ["estimate_orientation", "get_bitmap_angle", "get_language", "invert_data_structure"]
__all__ = ["estimate_orientation", "get_language", "invert_data_structure"]


def get_max_width_length_ratio(contour: np.ndarray) -> float:
Expand All @@ -21,29 +21,37 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
----
contour: the contour from cv2.findContour
Returns: the maximum shape ratio
Returns:
-------
the maximum shape ratio
"""
_, (w, h), _ = cv2.minAreaRect(contour)
return max(w / h, h / w)


def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> float:
def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> int:
"""Estimate the angle of the general document orientation based on the
lines of the document and the assumption that they should be horizontal.
Args:
----
img: the img to analyze
img: the img or bitmap to analyze (H, W, C)
n_ct: the number of contours used for the orientation estimation
ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines
Returns:
-------
the angle of the general document orientation
"""
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray_img = cv2.medianBlur(gray_img, 5)
thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
max_value = np.max(img)
min_value = np.min(img)
if max_value <= 1 and min_value >= 0 or (max_value <= 255 and min_value >= 0 and img.shape[-1] == 1):
thresh = img.astype(np.uint8)
if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3:
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray_img = cv2.medianBlur(gray_img, 5)
thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

# try to merge words in lines
(h, w) = img.shape[:2]
Expand All @@ -69,47 +77,8 @@ def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_li
if len(angles) == 0:
return 0 # in case no angles is found
else:
return -median_low(angles)


def get_bitmap_angle(bitmap: np.ndarray, n_ct: int = 20, std_max: float = 3.0) -> float:
"""From a binarized segmentation map, find contours and fit min area rectangles to determine page angle
Args:
----
bitmap: binarized segmentation map
n_ct: number of contours to use to fit page angle
std_max: maximum deviation of the angle distribution to consider the mean angle reliable
Returns:
-------
The angle of the page
"""
# Find all contours on binarized seg map
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
# Sort contours
contours = sorted(contours, key=cv2.contourArea, reverse=True)

# Find largest contours and fit angles
# Track heights and widths to find aspect ratio (determine is rotation is clockwise)
angles, heights, widths = [], [], []
for ct in contours[:n_ct]:
_, (w, h), alpha = cv2.minAreaRect(ct)
widths.append(w)
heights.append(h)
angles.append(alpha)

if np.std(angles) > std_max:
# Edge case with angles of both 0 and 90°, or multi_oriented docs
angle = 0.0
else:
angle = -np.mean(angles)
# Determine rotation direction (clockwise/counterclockwise)
# Angle coverage: [-90°, +90°], half of the quadrant
if np.sum(widths) < np.sum(heights): # CounterClockwise
angle = 90 + angle

return angle
median = -median_low(angles)
return round(median) if abs(median) != 0 else 0


def rectify_crops(
Expand Down Expand Up @@ -154,9 +123,13 @@ def rectify_loc_preds(
def get_language(text: str) -> Tuple[str, float]:
"""Get languages of a text using langdetect model.
Get the language with the highest probability or no language if only a few words or a low probability
Args:
----
text (str): text
Returns:
-------
The detected language in ISO 639 code and confidence score
"""
try:
Expand Down
14 changes: 10 additions & 4 deletions doctr/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def extra_repr(self) -> str:

def __call__(
self,
pages: List[np.ndarray],
boxes: List[np.ndarray],
text_preds: List[List[Tuple[str, float]]],
page_shapes: List[Tuple[int, int]],
Expand All @@ -297,6 +298,7 @@ def __call__(
Args:
----
pages: list of N elements, where each element represents the page image
boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5)
or (*, 6) for all words for a given page
text_preds: list of N elements, where each element is the list of all word prediction (text + confidence)
Expand Down Expand Up @@ -325,6 +327,7 @@ def __call__(

_pages = [
Page(
page,
self._build_blocks(
page_boxes,
word_preds,
Expand All @@ -334,8 +337,8 @@ def __call__(
orientation,
language,
)
for _idx, shape, page_boxes, word_preds, orientation, language in zip(
range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
for page, _idx, shape, page_boxes, word_preds, orientation, language in zip(
pages, range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
)
]

Expand All @@ -356,6 +359,7 @@ class KIEDocumentBuilder(DocumentBuilder):

def __call__( # type: ignore[override]
self,
pages: List[np.ndarray],
boxes: List[Dict[str, np.ndarray]],
text_preds: List[Dict[str, List[Tuple[str, float]]]],
page_shapes: List[Tuple[int, int]],
Expand All @@ -366,6 +370,7 @@ def __call__( # type: ignore[override]
Args:
----
pages: list of N elements, where each element represents the page image
boxes: list of N dictionaries, where each element represents the localization predictions for a class,
of shape (*, 5) or (*, 6) for all predictions
text_preds: list of N dictionaries, where each element is the list of all word prediction
Expand Down Expand Up @@ -400,6 +405,7 @@ def __call__( # type: ignore[override]

_pages = [
KIEPage(
page,
{
k: self._build_blocks(
page_boxes[k],
Expand All @@ -412,8 +418,8 @@ def __call__( # type: ignore[override]
orientation,
language,
)
for _idx, shape, page_boxes, word_preds, orientation, language in zip(
range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
for page, _idx, shape, page_boxes, word_preds, orientation, language in zip(
pages, range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
)
]

Expand Down
17 changes: 13 additions & 4 deletions doctr/models/detection/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, List, Union
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -37,8 +37,9 @@ def __init__(
def forward(
self,
pages: List[Union[np.ndarray, torch.Tensor]],
return_maps: bool = False,
**kwargs: Any,
) -> List[np.ndarray]:
) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
# Dimension check
if any(page.ndim != 3 for page in pages):
raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
Expand All @@ -48,5 +49,13 @@ def forward(
self.model, processed_batches = set_device_and_dtype(
self.model, processed_batches, _params.device, _params.dtype
)
predicted_batches = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches]
return [pred for batch in predicted_batches for pred in batch]
predicted_batches = [
self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
]
preds = [pred for batch in predicted_batches for pred in batch["preds"]]
if return_maps:
seg_maps = [
pred.permute(1, 2, 0).detach().cpu().numpy() for batch in predicted_batches for pred in batch["out_map"]
]
return preds, seg_maps
return preds
15 changes: 11 additions & 4 deletions doctr/models/detection/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -37,14 +37,21 @@ def __init__(
def __call__(
self,
pages: List[Union[np.ndarray, tf.Tensor]],
return_maps: bool = False,
**kwargs: Any,
) -> List[Dict[str, np.ndarray]]:
) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
# Dimension check
if any(page.ndim != 3 for page in pages):
raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")

processed_batches = self.pre_processor(pages)
predicted_batches = [
self.model(batch, return_preds=True, training=False, **kwargs)["preds"] for batch in processed_batches
self.model(batch, return_preds=True, return_model_output=True, training=False, **kwargs)
for batch in processed_batches
]
return [pred for batch in predicted_batches for pred in batch]

preds = [pred for batch in predicted_batches for pred in batch["preds"]]
if return_maps:
seg_maps = [pred.numpy() for batch in predicted_batches for pred in batch["out_map"]]
return preds, seg_maps
return preds
Loading

0 comments on commit e645ead

Please sign in to comment.