Skip to content

Commit

Permalink
Add bad OCR detection to app
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Dec 19, 2024
1 parent 2525ee0 commit 20b7b62
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
59 changes: 55 additions & 4 deletions ocr_app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import tempfile
from typing import List

import pypdfium2
Expand All @@ -15,6 +16,7 @@
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.table_rec.model import load_model as load_table_model
from surya.model.table_rec.processor import load_processor as load_table_processor
from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image
Expand All @@ -24,7 +26,9 @@
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult
from surya.settings import settings
from surya.tables import batch_table_recognition
from surya.postprocessing.util import rescale_bboxes, rescale_bbox
from surya.postprocessing.util import rescale_bbox
from pdftext.extraction import plain_text_output
from surya.ocr_error import batch_ocr_error_detection


@st.cache_resource()
Expand All @@ -46,6 +50,39 @@ def load_layout_cached():
def load_table_cached():
    """Return the table recognition ``(model, processor)`` pair."""
    model = load_table_model()
    processor = load_table_processor()
    return model, processor

@st.cache_resource()
def load_ocr_error_cached():
    """Return the OCR error detection ``(model, tokenizer)`` pair.

    The second element comes from ``load_ocr_error_processor``, which this file
    imports as an alias of the OCR error model's ``load_tokenizer``.
    """
    model = load_ocr_error_model()
    processor = load_ocr_error_processor()
    return model, processor


def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
    """Judge whether the text embedded in a PDF looks good or garbled.

    Text is extracted from a window of pages centred on the middle of the
    document, cut into evenly spaced samples, and each sample is labelled by
    the OCR error detection model (module-level ``ocr_error_model`` and
    ``ocr_error_processor``).

    Args:
        pdf_file: Uploaded file object; ``getvalue()`` must return the PDF bytes.
        page_count: Total number of pages in the PDF (an int).
        sample_len: Maximum number of characters in each sample.
        max_samples: Maximum number of samples passed to the model.
        max_pages: Number of pages taken on each side of the middle page.

    Returns:
        A ``(message, labels)`` tuple: a human-readable verdict and the list of
        per-sample labels, or ``["no text"]`` when there is too little text.
    """
    with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
        f.write(pdf_file.getvalue())
        # The extractor opens the file by name, so the buffered bytes must be
        # pushed out before it reads them.
        f.flush()

        # Sample the text from the middle of the PDF
        page_middle = page_count // 2
        first_page = max(page_middle - max_pages, 0)
        last_page = min(page_middle + max_pages, page_count)
        text = plain_text_output(f.name, page_range=range(first_page, last_page))

    sample_gap = len(text) // max_samples
    if not text or sample_gap == 0:
        return "This PDF has no text or very little text", ["no text"]

    # Never let samples overlap: space them at least one sample length apart.
    sample_gap = max(sample_gap, sample_len)

    # Split the text into samples for the model. The stride can leave a short
    # leftover tail past the last full step; capping at max_samples drops that
    # extra fragment so it cannot skew the bad-sample ratio.
    starts = range(0, len(text), sample_gap)[:max_samples]
    samples = [text[start:start + sample_len] for start in starts]

    results = batch_ocr_error_detection(samples, ocr_error_model, ocr_error_processor)
    label = "This PDF has good text."
    if results.labels.count("bad") / len(results.labels) > .2:
        label = "This PDF may have garbled or bad OCR text."
    return label, results.labels


def text_detection(img) -> (Image.Image, TextDetectionResult):
pred = batch_text_detection([img], det_model, det_processor)[0]
Expand Down Expand Up @@ -139,13 +176,16 @@ def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
)
png = list(renderer)[0]
png_image = png.convert("RGB")
doc.close()
return png_image


@st.cache_data()
def page_counter(pdf_file):
    """Return the number of pages in ``pdf_file``.

    Named ``page_counter`` rather than ``page_count`` because the script also
    assigns a module-level variable called ``page_count``.

    Args:
        pdf_file: Uploaded PDF, passed straight to ``open_pdf``.

    Returns:
        The page count reported by ``len()`` on the opened document.
    """
    doc = open_pdf(pdf_file)
    try:
        return len(doc)
    finally:
        # Release the document handle even if reading the length fails.
        doc.close()


st.set_page_config(layout="wide")
Expand All @@ -155,6 +195,7 @@ def page_count(pdf_file):
# Module-level model/processor pairs, each obtained from its *_cached loader.
# Helper functions in this file read these globals directly, e.g.
# run_ocr_errors uses ocr_error_model and ocr_error_processor.
rec_model, rec_processor = load_rec_cached()
layout_model, layout_processor = load_layout_cached()
table_model, table_processor = load_table_cached()
ocr_error_model, ocr_error_processor = load_ocr_error_cached()


st.markdown("""
Expand All @@ -179,8 +220,9 @@ def page_count(pdf_file):

filetype = in_file.type
whole_image = False
page_count = None
if "pdf" in filetype:
page_count = page_count(in_file)
page_count = page_counter(in_file)
page_number = st.sidebar.number_input(f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count)

pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI)
Expand All @@ -194,6 +236,7 @@ def page_count(pdf_file):
# Sidebar actions: one button per analysis the app can run.
text_rec = st.sidebar.button("Run OCR")
layout_det = st.sidebar.button("Run Layout Analysis")
table_rec = st.sidebar.button("Run Table Rec")
# Consumed by the `if ocr_errors:` branch further down, which only supports PDF uploads.
ocr_errors = st.sidebar.button("Run bad PDF text detection")
# Options that only affect table recognition (see each checkbox's help text).
use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.")

Expand Down Expand Up @@ -233,5 +276,13 @@ def page_count(pdf_file):
st.image(table_img, caption="Table Recognition", use_container_width=True)
st.json([p.model_dump() for p in pred], expanded=True)

# Bad-text detection reads the text layer of a PDF, so it is skipped for other uploads.
if ocr_errors:
    if "pdf" not in filetype:
        st.error("This feature only works with PDFs.")
    else:
        # page_count is only set for PDF uploads; for anything else it stays
        # None, which run_ocr_errors cannot do page arithmetic on.
        label, results = run_ocr_errors(in_file, page_count)
        with col1:
            st.write(label)
            st.json(results)

# Always show the uploaded page/image in the right-hand column.
with col2:
    st.image(pil_image, caption="Uploaded Image", use_container_width=True)
2 changes: 1 addition & 1 deletion surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
COMPILE_TABLE_REC: bool = False

# OCR Error Detection
OCR_ERROR_MODEL_CHECKPOINT: str = "tarun-menta/ocr_error_detection"
OCR_ERROR_MODEL_CHECKPOINT: str = "datalab-to/ocr_error_detection"
OCR_ERROR_BATCH_SIZE: Optional[int] = None
COMPILE_OCR_ERROR: bool = False

Expand Down

0 comments on commit 20b7b62

Please sign in to comment.