Skip to content

Commit

Permalink
Handle polygon bboxes
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 15, 2024
1 parent 3a68240 commit 7248218
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 19 deletions.
5 changes: 3 additions & 2 deletions benchmark/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from surya.model.segformer import load_model, load_processor
from surya.model.processing import open_pdf, get_page_images
from surya.detection import batch_inference
from surya.postprocessing.heatmap import draw_bboxes_on_image
from surya.postprocessing.heatmap import draw_bboxes_on_image, draw_polys_on_image
from surya.postprocessing.util import rescale_bbox
from surya.settings import settings
import os
Expand Down Expand Up @@ -68,6 +68,7 @@ def main():
page_metrics = collections.OrderedDict()
for idx, (tb, sb, cb) in enumerate(zip(tess_predictions, predictions, correct_boxes)):
surya_boxes = sb["bboxes"]
surya_polys = sb["polys"]

surya_metrics = precision_recall(surya_boxes, cb)
tess_metrics = precision_recall(tb, cb)
Expand All @@ -78,7 +79,7 @@ def main():
}

if args.debug:
bbox_image = draw_bboxes_on_image(surya_boxes, copy.deepcopy(images[idx]))
bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx]))
bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png"))

mean_metrics = {}
Expand Down
4 changes: 2 additions & 2 deletions detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from surya.model.processing import open_pdf, get_page_images
from surya.detection import batch_inference
from surya.postprocessing.affinity import draw_lines_on_image
from surya.postprocessing.heatmap import draw_bboxes_on_image
from surya.postprocessing.heatmap import draw_bboxes_on_image, draw_polys_on_image
from surya.settings import settings
import os
import filetype
Expand Down Expand Up @@ -90,7 +90,7 @@ def main():

if args.images:
for idx, (image, pred, name) in enumerate(zip(images, predictions, names)):
bbox_image = draw_bboxes_on_image(pred["bboxes"], copy.deepcopy(image))
bbox_image = draw_polys_on_image(pred["polygons"], copy.deepcopy(image))
bbox_image.save(os.path.join(result_path, f"{name}_{idx}_bbox.png"))

column_image = draw_lines_on_image(pred["vertical_lines"], copy.deepcopy(image))
Expand Down
4 changes: 3 additions & 1 deletion surya/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,13 @@ def batch_inference(images: List, model, processor):
affinity_size = list(reversed(affinity_map.shape))
heatmap_size = list(reversed(heatmap.shape))
bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes[i])
bbox_data = [bbox.model_dump() for bbox in bboxes]
vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes[i])
horizontal_lines = get_horizontal_lines(affinity_map, affinity_size, orig_sizes[i])

results.append({
"bboxes": bboxes,
"bboxes": [bbd["bbox"] for bbd in bbox_data],
"polygons": [bbd["corners"] for bbd in bbox_data],
"vertical_lines": vertical_lines,
"horizontal_lines": horizontal_lines,
"heatmap": heat_img,
Expand Down
28 changes: 15 additions & 13 deletions surya/postprocessing/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from PIL import ImageDraw

from surya.postprocessing.util import rescale_bbox
from surya.schema import PolygonBox
from surya.settings import settings


Expand Down Expand Up @@ -93,24 +94,14 @@ def get_detected_boxes(textmap, text_threshold=settings.DETECTOR_TEXT_THRESHOLD,
textmap = textmap.astype(np.float32)
boxes, labels = detect_boxes(textmap, text_threshold, low_text)
# From point form to box form
boxes = [
[box[0][0], box[0][1], box[1][0], box[2][1]]
for box in boxes
]

# Ensure correct box format
for box in boxes:
if box[0] > box[2]:
box[0], box[2] = box[2], box[0]
if box[1] > box[3]:
box[1], box[3] = box[3], box[1]
boxes = [PolygonBox(corners=box) for box in boxes]
return boxes


def get_and_clean_boxes(textmap, processor_size, image_size):
bboxes = get_detected_boxes(textmap)
bboxes = [rescale_bbox(bbox, processor_size, image_size) for bbox in bboxes]
bboxes = clean_contained_boxes(bboxes)
for bbox in bboxes:
bbox.rescale(processor_size, image_size)
return bboxes


Expand All @@ -122,3 +113,14 @@ def draw_bboxes_on_image(bboxes, image):

return image


def draw_polys_on_image(corners, image):
draw = ImageDraw.Draw(image)

for poly in corners:
poly = [(p[0], p[1]) for p in poly]
draw.polygon(poly, outline='red', width=1)

return image


20 changes: 19 additions & 1 deletion surya/postprocessing/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,22 @@ def rescale_bbox(bbox, processor_size, image_size):
new_bbox[1] = int(new_bbox[1] * height_scaler)
new_bbox[2] = int(new_bbox[2] * width_scaler)
new_bbox[3] = int(new_bbox[3] * height_scaler)
return new_bbox
return new_bbox


def rescale_point(point, processor_size, image_size):
# Point is in x, y format
page_width, page_height = processor_size

img_width, img_height = image_size
width_scaler = img_width / page_width
height_scaler = img_height / page_height

new_point = copy.deepcopy(point)
new_point[0] = int(new_point[0] * width_scaler)
new_point[1] = int(new_point[1] * height_scaler)
return new_point


def rescale_points(points, processor_size, image_size):
return [rescale_point(point, processor_size, image_size) for point in points]
80 changes: 80 additions & 0 deletions surya/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import copy
from typing import List, Tuple

from pydantic import BaseModel, field_validator, computed_field


class PolygonBox(BaseModel):
corners: List[List[float]]

@field_validator('corners')
@classmethod
def check_elements(cls, v: List[List[float]]) -> List[List[float]]:
if len(v) != 4:
raise ValueError('corner must have 4 elements')

for corner in v:
if len(corner) != 2:
raise ValueError('corner must have 2 elements')
return v

@property
def height(self):
return self.corners[1][1] - self.corners[0][1]

@property
def width(self):
return self.corners[1][0] - self.corners[0][0]

@property
def area(self):
return self.width * self.height

@computed_field
@property
def bbox(self) -> List[float]:
box = [self.corners[0][0], self.corners[0][1], self.corners[1][0], self.corners[2][1]]
if box[0] > box[2]:
box[0], box[2] = box[2], box[0]
if box[1] > box[3]:
box[1], box[3] = box[3], box[1]
return box


def rescale(self, processor_size, image_size):
# Point is in x, y format
page_width, page_height = processor_size

img_width, img_height = image_size
width_scaler = img_width / page_width
height_scaler = img_height / page_height

new_corners = copy.deepcopy(self.corners)
for corner in new_corners:
corner[0] = int(corner[0] * width_scaler)
corner[1] = int(corner[1] * height_scaler)
self.corners = new_corners



class Bbox(BaseModel):
bbox: List[float]

@field_validator('bbox')
@classmethod
def check_4_elements(cls, v: List[float]) -> List[float]:
if len(v) != 4:
raise ValueError('bbox must have 4 elements')
return v

@property
def height(self):
return self.bbox[3] - self.bbox[1]

@property
def width(self):
return self.bbox[2] - self.bbox[0]

@property
def area(self):
return self.width * self.height

0 comments on commit 7248218

Please sign in to comment.