Commit

Merge pull request #68 from VikParuchuri/dev
Add layout model
VikParuchuri authored Mar 26, 2024
2 parents ce8e95b + f36040d commit 3cdc3b6
Showing 30 changed files with 453 additions and 207 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/tests.yml
@@ -32,6 +32,10 @@ jobs:
        run: |
          poetry run python benchmark/recognition.py --max 2
          poetry run python scripts/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition
      - name: Run layout benchmark test
        run: |
          poetry run python benchmark/layout.py --max 5
          poetry run python scripts/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout
147 changes: 112 additions & 35 deletions README.md
@@ -4,13 +4,13 @@ Surya is a document OCR toolkit that does:

- Accurate OCR in 90+ languages
- Line-level text detection in any language
- Table and chart detection (coming soon)
- Layout analysis (table, image, header, etc detection) in any language

It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).

| Detection | OCR |
|:----------------------------------------------------------------:|:-----------------------------------------------------------------------:|
| ![New York Times Article Detection](static/images/excerpt.png) | ![New York Times Article Recognition](static/images/excerpt_text.png) |
| Detection | OCR | Layout |
|:----------------------------------------------------------------:|:-----------------------------------------------------------------------:|:---------------------------------------------------------------------:|
| ![New York Times Article Detection](static/images/excerpt.png) | ![New York Times Article Recognition](static/images/excerpt_text.png) | ![New York Times Article Layout](static/images/excerpt_layout.png) |


Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision.
@@ -21,27 +21,27 @@ Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who

## Examples

| Name | Text Detection | OCR |
|------------------|:-----------------------------------:|-----------------------------------------:|
| Japanese | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) |
| Chinese | [Image](static/images/chinese.jpg) | [Image](static/images/chinese_text.jpg) |
| Hindi | [Image](static/images/hindi.jpg) | [Image](static/images/hindi_text.jpg) |
| Arabic | [Image](static/images/arabic.jpg) | [Image](static/images/arabic_text.jpg) |
| Chinese + Hindi | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) |
| Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) |
| Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) |
| Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) |
| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) |
| Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) |
| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) |
| Name | Text Detection | OCR | Layout |
|------------------|:-----------------------------------:|-----------------------------------------:|--------:|
| Japanese | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) |
| Chinese | [Image](static/images/chinese.jpg) | [Image](static/images/chinese_text.jpg) | [Image](static/images/chinese_layout.jpg) |
| Hindi | [Image](static/images/hindi.jpg) | [Image](static/images/hindi_text.jpg) | [Image](static/images/hindi_layout.jpg) |
| Arabic | [Image](static/images/arabic.jpg) | [Image](static/images/arabic_text.jpg) | [Image](static/images/arabic_layout.jpg) |
| Chinese + Hindi | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) |
| Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) |
| Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) |
| Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) |
| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) |
| Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | -- |
| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) |

# Installation

You'll need Python 3.9+ and PyTorch. You may need to install the CPU version of torch first if you're not using a Mac or a GPU machine. See [here](https://pytorch.org/get-started/locally/) for more details.

Install with:

```
```shell
pip install surya-ocr
```

@@ -56,7 +56,7 @@ Model weights will automatically download the first time you run surya. Note th

I've included a streamlit app that lets you interactively try Surya on images or PDF files. Run it with:

```
```shell
pip install streamlit
surya_gui
```
@@ -67,7 +67,7 @@ Pass the `--math` command line argument to use the math detection model instead

You can OCR text in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected text and bboxes, and optionally save images of the reconstructed page.

```
```shell
surya_ocr DATA_PATH --images --langs hi,en
```

@@ -96,17 +96,17 @@ Setting the `RECOGNITION_BATCH_SIZE` env var properly will make a big difference

### From python

```
```python
from PIL import Image
from surya.ocr import run_ocr
from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.detection import segformer
from surya.model.recognition.model import load_model
from surya.model.recognition.processor import load_processor

image = Image.open(IMAGE_PATH)
langs = ["en"] # Replace with your languages
det_processor, det_model = load_det_processor(), load_det_model()
rec_model, rec_processor = load_rec_model(), load_rec_processor()
det_processor, det_model = segformer.load_processor(), segformer.load_model()
rec_model, rec_processor = load_model(), load_processor()

predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
```
@@ -115,7 +115,7 @@ predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec

You can detect text lines in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected bboxes.

```
```shell
surya_detect DATA_PATH --images
```

@@ -144,25 +144,72 @@ Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference wh

### From python

```
```python
from PIL import Image
from surya.detection import batch_detection
from surya.model.segformer import load_model, load_processor
from surya.detection import batch_text_detection
from surya.model.detection.segformer import load_model, load_processor

image = Image.open(IMAGE_PATH)
model, processor = load_model(), load_processor()

# predictions is a list of dicts, one per image
predictions = batch_detection([image], model, processor)
predictions = batch_text_detection([image], model, processor)
```

## Layout analysis

You can detect the layout of an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected layout.

```shell
surya_layout DATA_PATH --images
```

- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--images` will save images of the pages and detected layout regions (optional)
- `--max` specifies the maximum number of pages to process if you don't want to process everything
- `--results_dir` specifies the directory to save results to instead of the default

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:

- `bboxes` - detected bounding boxes for layout regions
  - `bbox` - the axis-aligned rectangle for the region in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
  - `polygon` - the polygon for the region in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
  - `confidence` - the confidence of the model in the detected region (0-1). This is currently not very reliable.
  - `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Text`, `Title`.
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
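
As a rough illustration of the structure described above, the sketch below counts how many boxes of each label appear on each page. It assumes the file is laid out exactly as listed (filename keys, one dictionary per page, a `bboxes` list whose entries carry a `label`). `count_labels` and the sample data are hypothetical and not part of surya.

```python
from collections import Counter


def count_labels(results):
    """Count layout labels per (filename, page) in a parsed results.json dict."""
    counts = {}
    for filename, pages in results.items():
        for page in pages:
            labels = Counter(box["label"] for box in page["bboxes"])
            counts[(filename, page["page"])] = labels
    return counts


# Hand-written sample that mimics the documented structure (not real output).
sample = {
    "doc": [
        {
            "page": 1,
            "image_bbox": [0, 0, 100, 100],
            "bboxes": [
                {"bbox": [0, 0, 50, 10], "label": "Title", "confidence": 0.9},
                {"bbox": [0, 20, 90, 40], "label": "Text", "confidence": 0.8},
                {"bbox": [0, 50, 90, 70], "label": "Text", "confidence": 0.7},
            ],
        }
    ]
}

label_counts = count_labels(sample)
print(label_counts[("doc", 1)])
```

For a real run, load the file with `json.load` and pass the resulting dictionary to the same function.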

**Performance tips**

Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `280MB` of VRAM, so very high batch sizes are possible. The default is a batch size of `32`, which will use about 9GB of VRAM. Raising the batch size can also help on CPU, depending on your core count; the default CPU batch size is `2`.
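
The arithmetic behind those numbers can be sketched as follows. The `280MB` per item figure comes from the text above. The helper name `estimate_batch_size` and the 1GB of headroom it reserves are assumptions for illustration, not surya settings.

```python
def estimate_batch_size(vram_gb, per_item_mb=280, headroom_gb=1.0):
    """Largest batch size that fits in the usable VRAM, never less than 1."""
    usable_mb = max(vram_gb - headroom_gb, 0) * 1000
    return max(int(usable_mb // per_item_mb), 1)


# The default batch size of 32 at 280MB per item:
default_usage_gb = 32 * 280 / 1000
print(f"Default batch uses about {default_usage_gb:.2f}GB")

# A card with 24GB of VRAM, keeping 1GB free (assumed headroom):
print(estimate_batch_size(24))
```

The resulting number would then be exported as `DETECTOR_BATCH_SIZE` in the shell before running the command.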

### From python

```python
from PIL import Image
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.settings import settings

image = Image.open(IMAGE_PATH)
model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
det_model = load_model()
det_processor = load_processor()

# layout_predictions is a list of dicts, one per image
line_predictions = batch_text_detection([image], det_model, det_processor)
layout_predictions = batch_layout_detection([image], model, processor, line_predictions)
```
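
A common next step is to keep only the regions with certain labels. The sketch below assumes each prediction exposes a `bboxes` sequence whose items have `bbox` and `label` attributes, which is how `benchmark/layout.py` in this commit accesses them. `boxes_with_labels` is a hypothetical helper, and the stand-in objects replace a real model run.

```python
from types import SimpleNamespace


def boxes_with_labels(prediction, labels):
    """Return the bbox of every detected region whose label is in `labels`."""
    return [b.bbox for b in prediction.bboxes if b.label in labels]


# Stand-in for one item of layout_predictions (no model is loaded here).
fake_prediction = SimpleNamespace(
    bboxes=[
        SimpleNamespace(bbox=[10, 10, 200, 40], label="Title"),
        SimpleNamespace(bbox=[10, 60, 200, 300], label="Table"),
        SimpleNamespace(bbox=[10, 320, 200, 500], label="Text"),
    ]
)

tables = boxes_with_labels(fake_prediction, {"Table"})
print(tables)
```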

# Limitations

- This is specialized for document OCR. It will likely not work on photos or other images.
- Surya is for OCR - the goal is to recognize the text lines correctly, not sort them into reading order. Surya will attempt to sort the lines, which will work in many cases, but use something like [marker](https://github.com/VikParuchuri/marker) or other postprocessing if you need to order the text.
- It is for printed text, not handwriting (though it may work on some handwriting).
- The model has trained itself to ignore advertisements.
- You can find language support for OCR in `surya/languages.py`. Text detection should work with any language.
- The text detection model has trained itself to ignore advertisements.
- You can find language support for OCR in `surya/languages.py`. Text detection and layout analysis will work with any language.

## Troubleshooting

@@ -172,7 +219,6 @@ If OCR isn't working properly:
- Preprocessing the image (binarizing, deskewing, etc) can help with very old/blurry images.
- You can adjust `DETECTOR_BLANK_THRESHOLD` and `DETECTOR_TEXT_THRESHOLD` if you don't get good results. `DETECTOR_BLANK_THRESHOLD` controls the space between lines - any prediction below this number will be considered blank space. `DETECTOR_TEXT_THRESHOLD` controls how text is joined - any number above this is considered text. `DETECTOR_TEXT_THRESHOLD` should always be higher than `DETECTOR_BLANK_THRESHOLD`, and both should be in the 0-1 range. Looking at the heatmap from the debug output of the detector can tell you how to adjust these (if you see faint things that look like boxes, lower the thresholds, and if you see bboxes being joined together, raise the thresholds).


# Manual install

If you want to develop surya, you can install it manually:
@@ -231,6 +277,26 @@ First calculate coverage for each bbox, then add a small penalty for double cove

Then we calculate precision and recall for the whole dataset.

## Layout analysis

![Benchmark chart](static/images/benchmark_layout_chart.png)

| Layout Type | precision | recall |
|---------------|-------------|----------|
| Image | 0.95 | 0.99 |
| Table | 0.95 | 0.96 |
| Text | 0.89 | 0.95 |
| Title | 0.92 | 0.89 |

Time per image - 0.79 seconds on GPU (A6000).

**Methodology**

I benchmarked the layout analysis on [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet), which was not in the training data. I had to align the PubLayNet labels with the surya layout labels. I was then able to find coverage for each layout type:

- Precision - how well the predicted bboxes cover ground truth bboxes
- Recall - how well ground truth bboxes cover predicted bboxes
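
To make "coverage" concrete, here is a simplified sketch over axis-aligned boxes. It is not surya's `precision_recall` function: it only checks the single best overlapping box, uses a .5 threshold (the value the benchmark script prints), and all names are hypothetical. The two metrics are the two directions of this check, with the roles of the predicted and ground truth boxes swapped; which direction is reported as precision and which as recall is defined by the surya function.

```python
def box_area(box):
    """Area of an (x1, y1, x2, y2) box."""
    return max(box[2] - box[0], 0) * max(box[3] - box[1], 0)


def intersection_area(a, b):
    """Area of the overlap between two (x1, y1, x2, y2) boxes."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    return max(x2 - x1, 0) * max(y2 - y1, 0)


def covered_share(targets, covering, threshold=0.5):
    """Share of target boxes with more than `threshold` of their area
    overlapped by the single best box in `covering`."""
    if not targets:
        return 1.0
    hits = 0
    for target in targets:
        area = box_area(target)
        best = max((intersection_area(target, c) for c in covering), default=0)
        if area > 0 and best / area > threshold:
            hits += 1
    return hits / len(targets)


predicted = [[0, 0, 10, 10], [20, 20, 30, 30]]
truth = [[0, 0, 10, 8]]

pred_covered_by_truth = covered_share(predicted, truth)
truth_covered_by_pred = covered_share(truth, predicted)
print(pred_covered_by_truth, truth_covered_by_pred)
```

Here one of the two predicted boxes overlaps the ground truth, so the first direction scores 0.5, while the only ground truth box is fully covered, so the second scores 1.0.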

## Running your own benchmarks

You can benchmark the performance of surya on your machine.
@@ -265,6 +331,17 @@ python benchmark/recognition.py --tesseract
- `--tesseract` will run the benchmark with tesseract. You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder.
- Set `RECOGNITION_BATCH_SIZE=864` to use the same batch size as the benchmark.

**Layout analysis**

This will evaluate surya on the PubLayNet dataset.

```
python benchmark/layout.py
```

- `--max` controls how many images to process for the benchmark
- `--debug` will save images with the ground truth layout boxes drawn on them
- `--results_dir` will let you specify a directory to save results to instead of the default one
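
The benchmark writes a `results.json` with `time`, `metrics`, and `page_metrics` keys. A minimal sketch for summarizing it is shown below, assuming `metrics` maps each layout type to `precision` and `recall` values as in the table above. `format_metrics` is a hypothetical helper, and the sample reuses two rows of that table with an illustrative `time`.

```python
def format_metrics(results):
    """Build one summary line per layout type from a parsed results dict."""
    rows = []
    for layout_type in sorted(results["metrics"]):
        scores = results["metrics"][layout_type]
        rows.append(
            f"{layout_type}: precision={scores['precision']:.2f} "
            f"recall={scores['recall']:.2f}"
        )
    return rows


# Hand-written sample in the shape the benchmark writes (time is illustrative).
sample_results = {
    "time": 3.95,
    "metrics": {
        "Table": {"precision": 0.95, "recall": 0.96},
        "Image": {"precision": 0.95, "recall": 0.99},
    },
    "page_metrics": {},
}

rows = format_metrics(sample_results)
print("\n".join(rows))
```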

# Training

@@ -274,7 +351,7 @@ Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a m

# Commercial usage

The text detection and OCR models were trained from scratch, so they're okay for commercial usage. The weights are licensed cc-by-nc-sa-4.0, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period.
The text detection, layout analysis, and OCR models were trained from scratch, so they're okay for commercial usage. The weights are licensed cc-by-nc-sa-4.0, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period.

If you want to remove the GPL license requirements for inference or use the weights commercially over the revenue limit, please contact me at [email protected] for dual licensing.

114 changes: 114 additions & 0 deletions benchmark/layout.py
@@ -0,0 +1,114 @@
import argparse
import collections
import copy
import json

from surya.benchmark.metrics import precision_recall
from surya.detection import batch_text_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.input.processing import open_pdf, get_page_images
from surya.layout import batch_layout_detection
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
from surya.postprocessing.util import rescale_bbox
from surya.settings import settings
import os
import time
from tabulate import tabulate
import datasets


def main():
    parser = argparse.ArgumentParser(description="Benchmark surya layout model.")
    parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
    parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=100)
    parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
    args = parser.parse_args()

    model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    det_model = load_model()
    det_processor = load_processor()

    pathname = "layout_bench"
    # These have already been shuffled randomly, so sampling from the start is fine
    dataset = datasets.load_dataset(settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
    images = list(dataset["image"])
    images = [i.convert("RGB") for i in images]

    start = time.time()
    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, model, processor, line_predictions)
    surya_time = time.time() - start

    folder_name = os.path.basename(pathname).split(".")[0]
    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    label_alignment = { # First is publaynet, second is surya
        "Image": [["Figure"], ["Picture", "Figure"]],
        "Table": [["Table"], ["Table"]],
        "Text": [["Text", "List"], ["Text", "Formula", "Footnote", "Caption", "List-item"]],
        "Title": [["Title"], ["Section-header", "Title"]]
    }

    page_metrics = collections.OrderedDict()
    for idx, pred in enumerate(layout_predictions):
        row = dataset[idx]
        all_correct_bboxes = []
        page_results = {}
        for label_name in label_alignment:
            correct_cats, surya_cats = label_alignment[label_name]
            correct_bboxes = [b for b, l in zip(row["bboxes"], row["labels"]) if l in correct_cats]
            all_correct_bboxes.extend(correct_bboxes)
            pred_bboxes = [b.bbox for b in pred.bboxes if b.label in surya_cats]

            metrics = precision_recall(pred_bboxes, correct_bboxes, penalize_double=False)
            weight = len(correct_bboxes)
            metrics["weight"] = weight
            page_results[label_name] = metrics

        page_metrics[idx] = page_results

        if args.debug:
            bbox_image = draw_bboxes_on_image(all_correct_bboxes, copy.deepcopy(images[idx]))
            bbox_image.save(os.path.join(result_path, f"{idx}_layout.png"))

    mean_metrics = collections.defaultdict(dict)
    layout_types = sorted(page_metrics[0].keys())
    metric_types = sorted(page_metrics[0][layout_types[0]].keys())
    metric_types.remove("weight")
    for l in layout_types:
        for m in metric_types:
            metric = []
            total = 0
            for page in page_metrics:
                metric.append(page_metrics[page][l][m] * page_metrics[page][l]["weight"])
                total += page_metrics[page][l]["weight"]

            value = sum(metric)
            if value > 0:
                value /= total
            mean_metrics[l][m] = value

    out_data = {
        "time": surya_time,
        "metrics": mean_metrics,
        "page_metrics": page_metrics
    }

    with open(os.path.join(result_path, "results.json"), "w+") as f:
        json.dump(out_data, f, indent=4)

    table_headers = ["Layout Type", ] + metric_types
    table_data = []
    for layout_type in layout_types:
        table_data.append([layout_type, ] + [f"{mean_metrics[layout_type][m]:.2f}" for m in metric_types])

    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
    print(f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total.")
    print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold.")
    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
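
The aggregation step in `main()` is a weighted mean: each page's metric is weighted by the number of ground truth boxes of that layout type on the page, so pages with more boxes count for more. A standalone sketch of that calculation, with a hypothetical `weighted_mean` helper and made-up numbers, looks like this:

```python
def weighted_mean(values, weights):
    """Mean of `values` weighted by `weights`; 0.0 when all weights are zero."""
    total = sum(weights)
    if total == 0:
        return 0.0
    return sum(v * w for v, w in zip(values, weights)) / total


# Two pages: recall 1.0 with 3 ground truth boxes, recall 0.5 with 1 box.
page_recalls = [1.0, 0.5]
page_weights = [3, 1]

overall = weighted_mean(page_recalls, page_weights)
print(overall)
```

The result (0.875) sits closer to the first page's score than a plain average (0.75) would, because that page contributes three of the four boxes.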