would it be a good idea to add sorting to the detection predictor pipeline? #1836
mohamedsaeed8223 asked this question in Q&A

I'm currently working on an application that needs OCR results grouped and sorted. Most of my data consists of multi-column, grouped blocks of text that can't be sorted simply by x or y coordinates. I found the sorting method in the document builder module, but I was wondering whether it would be better to have it in the detection predictor pipeline, with flags like "reading direction" living there as well.
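The builder-side sorting mentioned here can be tuned without touching the detection pipeline. A minimal sketch, assuming the `DocumentBuilder` arguments and the predictor's `doc_builder` attribute as found in recent doctr releases:

```python
from doctr.models import ocr_predictor
from doctr.models.builder import DocumentBuilder

predictor = ocr_predictor(pretrained=True)
# Swap in a builder that also groups lines into blocks, which helps
# multi-column layouts (attribute and argument names assumed from doctr's source).
predictor.doc_builder = DocumentBuilder(
    resolve_lines=True,     # merge words into lines before sorting
    resolve_blocks=True,    # additionally merge lines into blocks
    paragraph_break=0.035,  # relative gap that triggers a new paragraph
)
```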
Replies: 1 comment
Hi @mohamedsaeed8223 👋, I understand the intuition behind it, but the accuracy of sorting the individual boxes is not high enough - a reliable order is only achieved by merging them into lines and sub-lines and sorting those, which in turn would mean that we would have to drag this information along the entire pipeline. However, you can easily build your own pipeline from the individual components. In the following example you can still use all the other features (like sorting), but without the need for the recognition model:

```python
from typing import Any
import cv2
import numpy as np
import torch
import requests
from doctr.io import DocumentFile
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.models.detection import detection_predictor
from doctr.models.predictor import OCRPredictor
# Fetch an example image
image_url = "https://huggingface.co/datasets/huggingfacejs/tasks/resolve/main/document-question-answering/document-question-answering-input.png"
bytes_data = requests.get(image_url).content
# Convert relative coordinates to absolute pixel values
def _to_absolute(geom, img_shape: tuple[int, int]) -> list[list[int]]:
    h, w = img_shape
    if len(geom) == 2:  # Assume straight pages = True -> [[xmin, ymin], [xmax, ymax]]
        (xmin, ymin), (xmax, ymax) = geom
        xmin, xmax = int(round(w * xmin)), int(round(w * xmax))
        ymin, ymax = int(round(h * ymin)), int(round(h * ymax))
        return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
    else:  # Assume straight pages = False -> polygon [[x1, y1], ..., [x4, y4]]: convert each point
        return [[int(point[0] * w), int(point[1] * h)] for point in geom]
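# Worked example with hypothetical values:
# _to_absolute([[0.1, 0.2], [0.5, 0.6]], (1000, 800))
# -> [[80, 200], [400, 200], [400, 600], [80, 600]]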
# Load the document
doc = DocumentFile.from_images(bytes_data)
# Mock the recognition predictor (NOTE: this can also be replaced with any other recognition model, e.g. TrOCR)
class MockedRecognitionPredictor(RecognitionPredictor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(
        self,
        crops: list[np.ndarray | torch.Tensor],
        **kwargs: Any,
    ) -> list[tuple[str, float]]:
        return [("Hello", 1.0) for _ in crops]
# Initialize the OCR predictor
det_predictor = detection_predictor(pretrained=True, assume_straight_pages=False)
model = OCRPredictor(
    det_predictor=det_predictor,
    reco_predictor=MockedRecognitionPredictor(pre_processor=None, model=torch.nn.Module()),
    assume_straight_pages=False,
    straighten_pages=False,
    detect_orientation=True,
    detect_language=True,
    disable_crop_orientation=False,
    disable_page_orientation=False,
)
res = model(doc)
json_res = res.export()
# Decode the image (only for visualization purposes)
image = cv2.imdecode(np.frombuffer(bytes_data, np.uint8), cv2.IMREAD_COLOR)
for page in json_res["pages"]:
    page_idx = page["page_idx"]  # The index of the page
    shape = page["dimensions"]  # The shape of the page (height, width)
    # Dict with the orientation of the page (angle in degrees, confidence)
    # (if detect_orientation is True and/or assume_straight_pages is False)
    orientation = page["orientation"]
    language = page["language"]  # The detected language of the page (if detect_language is True)
    for block in page["blocks"]:
        block_geom = _to_absolute(block["geometry"], shape)  # The geometry of the block (now absolute coordinates)
        # The average objectness score of the block (over lines in the block)
        block_objectness_score = block["objectness_score"]
        # Draw the block on the image
        cv2.polylines(image, [np.array(block_geom).reshape(-1, 1, 2)], True, (0, 255, 0), 2)
        for line in block["lines"]:
            line_geom = _to_absolute(line["geometry"], shape)  # The geometry of the line (now absolute coordinates)
            # The average objectness score of the line (over words in the line)
            line_objectness_score = line["objectness_score"]
            # Draw the line on the image
            cv2.polylines(image, [np.array(line_geom).reshape(-1, 1, 2)], True, (0, 0, 255), 2)
            for word in line["words"]:
                word_geom = _to_absolute(word["geometry"], shape)  # The geometry of the word (now absolute coordinates)
                word_objectness_score = word["objectness_score"]  # The objectness score of the word crop
                value = word["value"]  # The text value of the word
                confidence = word["confidence"]  # The confidence of the word
                # Dict with the orientation of the word crop (angle in degrees, confidence)
                word_crop_orientation = word["crop_orientation"]
                # Draw the word on the image
                # cv2.polylines(image, [np.array(word_geom).reshape(-1, 1, 2)], True, (255, 0, 0), 2)
# Save the final image with the drawn polygons
cv2.imwrite("output.png", image)
```