Merge pull request #75 from VikParuchuri/dev

Add reading order model

VikParuchuri authored Apr 22, 2024
2 parents 3cdc3b6 + 1abd2f0 commit e8c98ac
Showing 34 changed files with 1,462 additions and 48 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/tests.yml
@@ -36,6 +36,10 @@ jobs:
run: |
poetry run python benchmark/layout.py --max 5
poetry run python scripts/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout
- name: Run ordering benchmark test
run: |
poetry run python benchmark/ordering.py --max 5
poetry run python scripts/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering
133 changes: 106 additions & 27 deletions README.md

Large diffs are not rendered by default.

79 changes: 79 additions & 0 deletions benchmark/ordering.py
@@ -0,0 +1,79 @@
import argparse
import collections
import copy
import json

from surya.benchmark.metrics import precision_recall
from surya.model.ordering.model import load_model
from surya.model.ordering.processor import load_processor
from surya.postprocessing.heatmap import draw_bboxes_on_image
from surya.ordering import batch_ordering
from surya.settings import settings
from surya.benchmark.metrics import rank_accuracy
import os
import time
from tabulate import tabulate
import datasets


def main():
parser = argparse.ArgumentParser(description="Benchmark surya reading order model.")
parser.add_argument("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None)
args = parser.parse_args()

model = load_model()
processor = load_processor()

pathname = "order_bench"
# These have already been shuffled randomly, so sampling from the start is fine
split = "train"
if args.max is not None:
split = f"train[:{args.max}]"
dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split)
images = list(dataset["image"])
images = [i.convert("RGB") for i in images]
bboxes = list(dataset["bboxes"])

start = time.time()
order_predictions = batch_ordering(images, bboxes, model, processor)
surya_time = time.time() - start

folder_name = os.path.basename(pathname).split(".")[0]
result_path = os.path.join(args.results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)

page_metrics = collections.OrderedDict()
mean_accuracy = 0
for idx, order_pred in enumerate(order_predictions):
row = dataset[idx]
pred_labels = [str(l.position) for l in order_pred.bboxes]
labels = row["labels"]
accuracy = rank_accuracy(pred_labels, labels)
mean_accuracy += accuracy
page_results = {
"accuracy": accuracy,
"box_count": len(labels)
}

page_metrics[idx] = page_results

mean_accuracy /= len(order_predictions)

out_data = {
"time": surya_time,
"mean_accuracy": mean_accuracy,
"page_metrics": page_metrics
}

with open(os.path.join(result_path, "results.json"), "w+") as f:
json.dump(out_data, f, indent=4)

print(f"Mean accuracy is {mean_accuracy:.2f}.")
print(f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total.")
print("Mean accuracy is the % of correct ranking pairs.")
print(f"Wrote results to {result_path}")


if __name__ == "__main__":
main()
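
For reference, the results.json this benchmark writes can be re-checked against the same 0.75 threshold the new CI step enforces. A minimal sketch, not part of this commit; the path matches the workflow step above:

import json

with open("results/benchmark/order_bench/results.json") as f:
    data = json.load(f)

# mean_accuracy is the fraction of correctly ordered bbox pairs (see rank_accuracy below)
assert data["mean_accuracy"] >= 0.75, "ordering benchmark below threshold"
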
42 changes: 35 additions & 7 deletions ocr_app.py
@@ -10,13 +10,16 @@
from surya.model.detection.segformer import load_model, load_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.ordering import batch_ordering
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image
from PIL import Image
from surya.languages import CODE_TO_LANGUAGE
from surya.input.langs import replace_lang_with_code
from surya.schema import OCRResult, TextDetectionResult, LayoutResult
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
from surya.settings import settings

parser = argparse.ArgumentParser(description="Run OCR on an image or PDF.")
@@ -43,15 +46,19 @@ def load_rec_cached():
def load_layout_cached():
return load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT), load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)

@st.cache_resource()
def load_order_cached():
return load_order_model(), load_order_processor()


def text_detection(img) -> TextDetectionResult:
def text_detection(img) -> (Image.Image, TextDetectionResult):
pred = batch_text_detection([img], det_model, det_processor)[0]
polygons = [p.polygon for p in pred.bboxes]
det_img = draw_polys_on_image(polygons, img.copy())
return det_img, pred


def layout_detection(img) -> LayoutResult:
def layout_detection(img) -> (Image.Image, LayoutResult):
_, det_pred = text_detection(img)
pred = batch_layout_detection([img], layout_model, layout_processor, [det_pred])[0]
polygons = [p.polygon for p in pred.bboxes]
@@ -60,8 +67,18 @@ def layout_detection(img) -> LayoutResult:
return layout_img, pred


def order_detection(img) -> (Image.Image, OrderResult):
_, layout_pred = layout_detection(img)
bboxes = [l.bbox for l in layout_pred.bboxes]
pred = batch_ordering([img], [bboxes], order_model, order_processor)[0]
polys = [l.polygon for l in pred.bboxes]
positions = [str(l.position) for l in pred.bboxes]
order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=20)
return order_img, pred


# Function for OCR
def ocr(img, langs: List[str]) -> OCRResult:
def ocr(img, langs: List[str]) -> (Image.Image, OCRResult):
replace_lang_with_code(langs)
img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor)[0]

@@ -101,6 +118,7 @@ def page_count(pdf_file):
det_model, det_processor = load_det_cached()
rec_model, rec_processor = load_rec_cached()
layout_model, layout_processor = load_layout_cached()
order_model, order_processor = load_order_cached()


st.markdown("""
@@ -136,24 +154,28 @@ def page_count(pdf_file):
text_det = st.sidebar.button("Run Text Detection")
text_rec = st.sidebar.button("Run OCR")
layout_det = st.sidebar.button("Run Layout Analysis")
order_det = st.sidebar.button("Run Reading Order")

if pil_image is None:
st.stop()

# Run Text Detection
if text_det and pil_image is not None:
if text_det:
det_img, pred = text_detection(pil_image)
with col1:
st.image(det_img, caption="Detected Text", use_column_width=True)
st.json(pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True)


# Run layout
if layout_det and pil_image is not None:
if layout_det:
layout_img, pred = layout_detection(pil_image)
with col1:
st.image(layout_img, caption="Detected Layout", use_column_width=True)
st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True)

# Run OCR
if text_rec and pil_image is not None:
if text_rec:
rec_img, pred = ocr(pil_image, languages)
with col1:
st.image(rec_img, caption="OCR Result", use_column_width=True)
@@ -163,5 +185,11 @@ def page_count(pdf_file):
with text_tab:
st.text("\n".join([p.text for p in pred.text_lines]))

if order_det:
order_img, pred = order_detection(pil_image)
with col1:
st.image(order_img, caption="Reading Order", use_column_width=True)
st.json(pred.model_dump(), expanded=True)

with col2:
st.image(pil_image, caption="Uploaded Image", use_column_width=True)
8 changes: 5 additions & 3 deletions pyproject.toml
@@ -1,7 +1,7 @@
[tool.poetry]
name = "surya-ocr"
version = "0.3.0"
description = "OCR, layout analysis, and line detection in 90+ languages"
version = "0.4.0"
description = "OCR, layout, reading order, and line detection in 90+ languages"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
license = "GPL-3.0-or-later"
@@ -15,7 +15,8 @@ include = [
"ocr_text.py",
"ocr_app.py",
"run_ocr_app.py",
"detect_layout.py"
"detect_layout.py",
"reading_order.py",
]

[tool.poetry.dependencies]
Expand Down Expand Up @@ -48,6 +49,7 @@ surya_detect = "detect_text:main"
surya_ocr = "ocr_text:main"
surya_layout = "detect_layout:main"
surya_gui = "run_ocr_app:run_app"
surya_order = "reading_order:main"

[build-system]
requires = ["poetry-core"]
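
With the surya_order entry point registered above, the reading order script can be run from the command line. An illustrative invocation (the input path is an assumption, not from this commit):

surya_order example.pdf --images --max 10
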
81 changes: 81 additions & 0 deletions reading_order.py
@@ -0,0 +1,81 @@
import argparse
import copy
import json
from collections import defaultdict

from surya.detection import batch_text_detection
from surya.input.load import load_from_folder, load_from_file
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
from surya.model.ordering.model import load_model
from surya.model.ordering.processor import load_processor
from surya.ordering import batch_ordering
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings
import os


def main():
parser = argparse.ArgumentParser(description="Find reading order of an input file or folder (PDFs or image).")
parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to find reading order in.")
parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya"))
parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False)
args = parser.parse_args()

model = load_model()
processor = load_processor()

layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)

det_model = load_det_model()
det_processor = load_det_processor()

if os.path.isdir(args.input_path):
images, names = load_from_folder(args.input_path, args.max)
folder_name = os.path.basename(args.input_path)
else:
images, names = load_from_file(args.input_path, args.max)
folder_name = os.path.basename(args.input_path).split(".")[0]

line_predictions = batch_text_detection(images, det_model, det_processor)
layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
bboxes = []
for layout_pred in layout_predictions:
bbox = [l.bbox for l in layout_pred.bboxes]
bboxes.append(bbox)

order_predictions = batch_ordering(images, bboxes, model, processor)
result_path = os.path.join(args.results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)

if args.images:
for idx, (image, layout_pred, order_pred, name) in enumerate(zip(images, layout_predictions, order_predictions, names)):
polys = [l.polygon for l in order_pred.bboxes]
labels = [str(l.position) for l in order_pred.bboxes]
bbox_image = draw_polys_on_image(polys, copy.deepcopy(image), labels=labels, label_font_size=20)
bbox_image.save(os.path.join(result_path, f"{name}_{idx}_order.png"))

predictions_by_page = defaultdict(list)
for idx, (layout_pred, pred, name, image) in enumerate(zip(layout_predictions, order_predictions, names, images)):
out_pred = pred.model_dump()
for bbox, layout_bbox in zip(out_pred["bboxes"], layout_pred.bboxes):
bbox["label"] = layout_bbox.label

out_pred["page"] = len(predictions_by_page[name]) + 1
predictions_by_page[name].append(out_pred)

# Sort in reading order
for name in predictions_by_page:
for page_preds in predictions_by_page[name]:
page_preds["bboxes"] = sorted(page_preds["bboxes"], key=lambda x: x["position"])

with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
json.dump(predictions_by_page, f, ensure_ascii=False)

print(f"Wrote results to {result_path}")


if __name__ == "__main__":
main()
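
The script above chains text detection, layout analysis, and the ordering model. A minimal programmatic sketch of the same pipeline for a single image (the file name and single-image setup are assumptions, not part of this commit):

from PIL import Image

from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.ordering import batch_ordering
from surya.settings import settings

image = Image.open("page.png").convert("RGB")  # assumed input image

# Text detection feeds layout analysis, which supplies the region bboxes
det_model, det_processor = load_det_model(), load_det_processor()
layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)

line_preds = batch_text_detection([image], det_model, det_processor)
layout_preds = batch_layout_detection([image], layout_model, layout_processor, line_preds)
bboxes = [[l.bbox for l in layout_preds[0].bboxes]]

# The ordering model assigns a reading position to each layout bbox
order_model, order_processor = load_order_model(), load_order_processor()
order_pred = batch_ordering([image], bboxes, order_model, order_processor)[0]
for box in sorted(order_pred.bboxes, key=lambda b: b.position):
    print(box.position, box.bbox)
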
8 changes: 8 additions & 0 deletions scripts/verify_benchmark_scores.py
@@ -21,6 +21,12 @@ def verify_rec(data):
raise ValueError("Scores do not meet the required threshold")


def verify_order(data):
score = data["mean_accuracy"]
if score < 0.75:
raise ValueError("Scores do not meet the required threshold")


def verify_scores(file_path, bench_type):
with open(file_path, 'r') as file:
data = json.load(file)
@@ -31,6 +37,8 @@ def verify_scores(file_path, bench_type):
verify_rec(data)
elif bench_type == "layout":
verify_layout(data)
elif bench_type == "ordering":
verify_order(data)
else:
raise ValueError("Invalid benchmark type")

Binary file added static/images/arabic_reading.jpg
Binary file added static/images/chi_hind_reading.jpg
Binary file added static/images/chinese_reading.jpg
Binary file added static/images/excerpt_reading.jpg
Binary file added static/images/funsd_layout.jpg
Binary file added static/images/funsd_reading.jpg
Binary file added static/images/gcloud_full_langs.png
Binary file added static/images/gcloud_rec_bench.png
Binary file added static/images/hindi_reading.jpg
Binary file added static/images/japanese_reading.jpg
Binary file added static/images/nyt_order.jpg
Binary file added static/images/paper_reading.jpg
Binary file added static/images/pres_reading.jpg
Binary file added static/images/scanned_reading.jpg
Binary file added static/images/textbook_order.jpg
21 changes: 20 additions & 1 deletion surya/benchmark/metrics.py
@@ -117,4 +117,23 @@ def mean_coverage(preds, references):
if len(coverages) == 0:
return 0
coverage = sum(coverages) / len(coverages)
return {"coverage": coverage}
return {"coverage": coverage}


def rank_accuracy(preds, references):
# Preds and references need to be aligned so each position refers to the same bbox
pairs = []
for i, pred in enumerate(preds):
for j, pred2 in enumerate(preds):
if i == j:
continue
pairs.append((i, j, pred > pred2))

# Find how many of the prediction rankings are correct
correct = 0
for i, ref in enumerate(references):
for j, ref2 in enumerate(references):
if (i, j, ref > ref2) in pairs:
correct += 1

return correct / len(pairs)
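
A quick illustrative check of the metric above, using made-up integer ranks rather than benchmark data: with references [0, 1, 2] and predictions [0, 2, 1], the two ordered pairs involving items 1 and 2 are swapped while the other four agree, so the score is 4/6 ≈ 0.67.

# Illustrative only: rank_accuracy scores the fraction of ordered pairs
# whose relative ranking matches between predictions and references.
print(rank_accuracy([0, 2, 1], [0, 1, 2]))  # 4 of 6 pairs correct -> ~0.67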