Skip to content

Commit

Permalink
Add folder level inference
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 12, 2024
1 parent 1a42344 commit 411c140
Show file tree
Hide file tree
Showing 12 changed files with 836 additions and 39 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: Python package
on:
push:
tags:
- "v*.*.*"
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install python dependencies
run: |
pip install poetry
poetry install
poetry remove torch
poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Build package
run: |
poetry build
- name: Publish package
env:
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
run: |
poetry config pypi-token.pypi "$PYPI_TOKEN"
poetry publish
29 changes: 29 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: Integration test

on: [push]

env:
TORCH_DEVICE: "cpu"

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install python dependencies
run: |
pip install poetry
poetry install
poetry remove torch
poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Run benchmark test
run: |
poetry run python benchmark/detection.py --max 2
poetry run python scripts/verify_benchmark_scores.py results/benchmark/doclaynet_bench/results.json
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ test_data
training
wandb
notebooks
results
data

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
674 changes: 674 additions & 0 deletions LICENSE

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Surya is a multilingual document OCR toolkit. It can do:

- Accurate line-level text detection
- Line-level text recognition (coming soon)
- Text recognition (coming soon)
- Table and chart detection (coming soon)

It works on a range of documents and languages (see [usage](#usage) and [benchmarks](#benchmarks) for more details).
Expand Down Expand Up @@ -47,14 +47,15 @@ Model weights will automatically download the first time you run surya.

## Text line detection

You can detect text lines in an image or pdf with the following command. This will write out a json file with the detected bboxes, and optionally save images of the pages with the bboxes.
You can detect text lines in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected bboxes, and optionally save images of the pages with the bboxes.

Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use 280MB of VRAM, so very high batch sizes are possible. Depending on your CPU core count, `DETECTOR_BATCH_SIZE` might make a difference there too.

```
surya_detect PDF_PATH --images
surya_detect DATA_PATH --images
```

- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--images` will save images of the pages and detected text lines (optional)
- `--max` specifies the maximum number of pages to process if you don't want to process everything
- `--results_dir` specifies the directory to save results to instead of the default
Expand All @@ -63,7 +64,7 @@ This has worked with every language I've tried. It will work best with document

You can adjust `DETECTOR_NMS_THRESHOLD` and `DETECTOR_TEXT_THRESHOLD` if you don't get good results. Try lowering them to detect more text, and vice versa.

*Importing in Python*
**Importing in Python**

You can also do text detection from code with:

Expand Down Expand Up @@ -115,7 +116,7 @@ Tesseract is CPU-based, and surya is CPU or GPU. I ran the benchmarks on a syst
- tesseract - 32 CPU cores, or 8 workers using 4 cores each
- surya - 32 batch size, for 9GB VRAM usage

*Methodology*
**Methodology**

Surya predicts line-level bboxes, while tesseract and others predict word-level or character-level. It's also hard to find 100% correct datasets with line-level annotations. Merging bboxes can be noisy, so I chose not to use IoU as the metric for evaluation.

Expand All @@ -135,7 +136,7 @@ You can benchmark the performance of surya on your machine.
- Follow the manual install instructions above.
- `poetry install --group dev` # Installs dev dependencies

*Text line detection*
**Text line detection**

This will evaluate tesseract and surya for text line detection across a randomly sampled set of images from [doclaynet](https://huggingface.co/datasets/vikp/doclaynet_bench).

Expand All @@ -155,7 +156,7 @@ This was trained on 4x A6000s for about 5 days. It used a diverse set of 1M ima

# Commercial usage

*Text detection*
**Text detection**

The text detection model was trained from scratch, so it's okay for commercial usage. The weights are licensed cc-by-nc-sa-4.0, but I will waive that for any organization under $10M in gross revenue in the last 12 months.

Expand All @@ -167,5 +168,6 @@ This work would not have been possible without amazing open source AI work:

- [Segformer](https://arxiv.org/pdf/2105.15203.pdf) from NVIDIA
- [transformers](https://github.com/huggingface/transformers) from huggingface
- [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model

Thank you to everyone who makes open source AI possible.
2 changes: 0 additions & 2 deletions data/.gitignore

This file was deleted.

90 changes: 68 additions & 22 deletions detect_text.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import copy
import json
from collections import defaultdict

from PIL import Image

Expand All @@ -14,8 +15,58 @@
import filetype


def get_name_from_path(path):
return os.path.basename(path).split(".")[0]


def load_pdf(pdf_path, max_pages=None):
doc = open_pdf(pdf_path)
page_count = len(doc)
if max_pages:
page_count = min(max_pages, page_count)

page_indices = list(range(page_count))

images = get_page_images(doc, page_indices)
doc.close()
names = [get_name_from_path(pdf_path) for _ in page_indices]
return images, names


def load_image(image_path):
image = Image.open(image_path).convert("RGB")
name = get_name_from_path(image_path)
return [image], [name]


def load_from_file(input_path, max_pages=None):
input_type = filetype.guess(input_path)
if input_type.extension == "pdf":
return load_pdf(input_path, max_pages)
else:
return load_image(input_path)


def load_from_folder(folder_path, max_pages=None):
image_paths = [os.path.join(folder_path, image_name) for image_name in os.listdir(folder_path)]
image_paths = [ip for ip in image_paths if not os.path.isdir(ip) and not ip.startswith(".")]

images = []
names = []
for path in image_paths:
if filetype.guess(path).extension == "pdf":
image, name = load_pdf(path, max_pages)
images.extend(image)
names.extend(name)
else:
image, name = load_image(path)
images.extend(image)
names.extend(name)
return images, names


def main():
parser = argparse.ArgumentParser(description="Detect bboxes in an input file (PDF or image).")
parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).")
parser.add_argument("input_path", type=str, help="Path to pdf or image file to detect bboxes in.")
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya"))
parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
Expand All @@ -26,49 +77,44 @@ def main():
model = load_model()
processor = load_processor()

input_type = filetype.guess(args.input_path)
if input_type.extension == "pdf":
doc = open_pdf(args.input_path)
page_count = len(doc)
if args.max:
page_count = min(args.max, page_count)

page_indices = list(range(page_count))

images = get_page_images(doc, page_indices)
doc.close()
if os.path.isdir(args.input_path):
images, names = load_from_folder(args.input_path, args.max)
folder_name = os.path.basename(args.input_path)
else:
image = Image.open(args.input_path).convert("RGB")
images = [image]
images, names = load_from_file(args.input_path, args.max)
folder_name = os.path.basename(args.input_path).split(".")[0]

predictions = batch_inference(images, model, processor)

folder_name = os.path.basename(args.input_path).split(".")[0]
result_path = os.path.join(args.results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)

if args.images:
for idx, (image, pred) in enumerate(zip(images, predictions)):
for idx, (image, pred, name) in enumerate(zip(images, predictions, names)):
bbox_image = draw_bboxes_on_image(pred["bboxes"], copy.deepcopy(image))
bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png"))
bbox_image.save(os.path.join(result_path, f"{name}_{idx}_bbox.png"))

column_image = draw_lines_on_image(pred["vertical_lines"], copy.deepcopy(image))
column_image.save(os.path.join(result_path, f"{idx}_column.png"))
column_image.save(os.path.join(result_path, f"{name}_{idx}_column.png"))

if args.debug:
heatmap = pred["heatmap"]
heatmap.save(os.path.join(result_path, f"{idx}_heat.png"))
heatmap.save(os.path.join(result_path, f"{name}_{idx}_heat.png"))

affinity_map = pred["affinity_map"]
affinity_map.save(os.path.join(result_path, f"{idx}_affinity.png"))
affinity_map.save(os.path.join(result_path, f"{name}_{idx}_affinity.png"))

# Remove all the images from the predictions
for pred in predictions:
pred.pop("heatmap", None)
pred.pop("affinity_map", None)

predictions_by_page = defaultdict(list)
for idx, (pred, name) in enumerate(zip(predictions, names)):
pred["page_number"] = len(predictions_by_page[name]) + 1
predictions_by_page[name].append(pred)

with open(os.path.join(result_path, "results.json"), "w+") as f:
json.dump(predictions, f, indent=4)
json.dump(predictions_by_page, f)

print(f"Wrote results to {result_path}")

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
version = "0.1.0"
version = "0.1.2"
description = "Document OCR models for multilingual text detection and recognition"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 0 additions & 2 deletions results/.gitignore

This file was deleted.

20 changes: 20 additions & 0 deletions scripts/verify_benchmark_scores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import json
import argparse


def verify_scores(file_path):
with open(file_path, 'r') as file:
data = json.load(file)

scores = data["metrics"]["surya"]

if scores["precision"] <= 0.9 or scores["recall"] <= 0.9:
print(scores)
raise ValueError("Scores do not meet the required threshold")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Verify benchmark scores")
parser.add_argument("file_path", type=str, help="Path to the json file")
args = parser.parse_args()
verify_scores(args.file_path)
4 changes: 2 additions & 2 deletions surya/benchmark/tesseract.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def tesseract_bboxes(img):
bboxes = []
n_boxes = len(ocr['level'])
for i in range(n_boxes):
# TODO: it is possible to merge by line here, but it gives bad results. Find another way to get line-level.
# It is possible to merge by line here with line number, but it gives bad results.
_, x, y, w, h = ocr['text'][i], ocr['left'][i], ocr['top'][i], ocr['width'][i], ocr['height'][i]
bbox = (x, y, x + w, y + h)
bboxes.append(bbox)
Expand All @@ -28,7 +28,7 @@ def tesseract_parallel(imgs):
tess_parallel_cores = min(tess_parallel_cores, cpus)

# Tesseract uses 4 threads per instance
tess_parallel = tess_parallel_cores // 4
tess_parallel = max(tess_parallel_cores // 4, 1)

with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
tess_bboxes = executor.map(tesseract_bboxes, imgs)
Expand Down
5 changes: 2 additions & 3 deletions surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ def TORCH_DEVICE_MODEL(self) -> str:
DETECTOR_NMS_THRESHOLD: float = 0.35 # Threshold for non-maximum suppression

# Paths
BASE_DIR: str = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
DATA_DIR: str = os.path.join(BASE_DIR, "data")
RESULT_DIR: str = os.path.join(BASE_DIR, "results")
DATA_DIR: str = "data"
RESULT_DIR: str = "results"

@computed_field
@property
Expand Down

0 comments on commit 411c140

Please sign in to comment.