Skip to content

Commit

Permalink
Merge pull request #252 from VikParuchuri/dev
Browse files Browse the repository at this point in the history
New layout model
  • Loading branch information
VikParuchuri authored Nov 27, 2024
2 parents cb86a92 + e05dbd0 commit 30ce562
Show file tree
Hide file tree
Showing 40 changed files with 3,440 additions and 4,718 deletions.
79 changes: 17 additions & 62 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@ model, processor = load_model(), load_processor()
predictions = batch_text_detection([image], model, processor)
```

## Layout analysis
## Layout and reading order

This command will write out a json file with the detected layout.
This command will write out a json file with the detected layout and reading order.

```shell
surya_layout DATA_PATH
Expand All @@ -215,14 +215,14 @@ The `results.json` file will contain a json dictionary where the keys are the in
- `bboxes` - detected bounding boxes for text
- `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
- `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
- `confidence` - the confidence of the model in the detected text (0-1). This is currently not very reliable.
- `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Text`, `Title`.
- `position` - the reading order of the box.
- `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`.
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.

**Performance tips**

Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `400MB` of VRAM, so very high batch sizes are possible. The default is a batch size `36`, which will use about 16GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `6`.
Setting the `LAYOUT_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `220MB` of VRAM, so very high batch sizes are possible. The default is a batch size `32`, which will use about 7GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `4`.

### From python

Expand All @@ -231,7 +231,6 @@ from PIL import Image
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.layout.model import load_model, load_processor
from surya.settings import settings

image = Image.open(IMAGE_PATH)
model = load_model()
Expand All @@ -244,52 +243,6 @@ line_predictions = batch_text_detection([image], det_model, det_processor)
layout_predictions = batch_layout_detection([image], model, processor, line_predictions)
```

## Reading order

This command will write out a json file with the detected reading order and layout.

```shell
surya_order DATA_PATH
```

- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--images` will save images of the pages and detected text lines (optional)
- `--max` specifies the maximum number of pages to process if you don't want to process everything
- `--results_dir` specifies the directory to save results to instead of the default

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:

- `bboxes` - detected bounding boxes for text
- `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
- `position` - the position in the reading order of the bbox, starting from 0.
- `label` - the label for the bbox. See the layout section of the documentation for a list of potential labels.
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.

**Performance tips**

Setting the `ORDER_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `360MB` of VRAM, so very high batch sizes are possible. The default is a batch size `32`, which will use about 11GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `4`.

### From python

```python
from PIL import Image
from surya.ordering import batch_ordering
from surya.model.ordering.processor import load_processor
from surya.model.ordering.model import load_model

image = Image.open(IMAGE_PATH)
# bboxes should be a list of lists with layout bboxes for the image in [x1,y1,x2,y2] format
# You can get this from the layout model, see above for usage
bboxes = [bbox1, bbox2, ...]

model = load_model()
processor = load_processor()

# order_predictions will be a list of dicts, one per image
order_predictions = batch_ordering([image], [bboxes], model, processor)
```

## Table Recognition

This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes. If you want to get a formatted markdown table, check out the [tabled](https://www.github.com/VikParuchuri/tabled) repo.
Expand Down Expand Up @@ -324,6 +277,9 @@ The `results.json` file will contain a json dictionary where the keys are the in

Setting the `TABLE_REC_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `150MB` of VRAM, so very high batch sizes are possible. The default is a batch size `64`, which will use about 10GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `8`.

### From python

See `table_recognition.py` for a code sample. Table recognition depends on extracting cells, so it is a little more involved to setup than other model types.

# Limitations

Expand Down Expand Up @@ -410,16 +366,15 @@ Then we calculate precision and recall for the whole dataset.

## Layout analysis

![Benchmark chart](static/images/benchmark_layout_chart.png)

| Layout Type | precision | recall |
| ----------- | --------- | ------ |
| Image | 0.97 | 0.96 |
| Table | 0.99 | 0.99 |
| Text | 0.9 | 0.97 |
| Title | 0.94 | 0.88 |
| Layout Type | precision | recall |
|---------------|-------------|----------|
| Image | 0.91265 | 0.93976 |
| List | 0.80849 | 0.86792 |
| Table | 0.84957 | 0.96104 |
| Text | 0.93019 | 0.94571 |
| Title | 0.92102 | 0.95404 |

Time per image - .4 seconds on GPU (A10).
Time per image - .13 seconds on GPU (A10).

**Methodology**

Expand All @@ -430,7 +385,7 @@ I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/

## Reading Order

75% mean accuracy, and .14 seconds per image on an A6000 GPU. See methodology for notes - this benchmark is not perfect measure of accuracy, and is more useful as a sanity check.
88% mean accuracy, and .4 seconds per image on an A10 GPU. See methodology for notes - this benchmark is not perfect measure of accuracy, and is more useful as a sanity check.

**Methodology**

Expand Down
20 changes: 8 additions & 12 deletions benchmark/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import json

from surya.benchmark.metrics import precision_recall
from surya.detection import batch_text_detection
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.layout.model import load_model, load_processor
from surya.model.layout.model import load_model
from surya.model.layout.processor import load_processor
from surya.input.processing import convert_if_not_rgb
from surya.layout import batch_layout_detection
from surya.postprocessing.heatmap import draw_bboxes_on_image
Expand All @@ -26,8 +25,6 @@ def main():

model = load_model()
processor = load_processor()
det_model = load_det_model()
det_processor = load_det_processor()

pathname = "layout_bench"
# These have already been shuffled randomly, so sampling from the start is fine
Expand All @@ -36,12 +33,10 @@ def main():
images = convert_if_not_rgb(images)

if settings.LAYOUT_STATIC_CACHE:
line_prediction = batch_text_detection(images[:1], det_model, det_processor)
batch_layout_detection(images[:1], model, processor, line_prediction)
batch_layout_detection(images[:1], model, processor)

start = time.time()
line_predictions = batch_text_detection(images, det_model, det_processor)
layout_predictions = batch_layout_detection(images, model, processor, line_predictions)
layout_predictions = batch_layout_detection(images, model, processor)
surya_time = time.time() - start

folder_name = os.path.basename(pathname).split(".")[0]
Expand All @@ -50,9 +45,10 @@ def main():

label_alignment = { # First is publaynet, second is surya
"Image": [["Figure"], ["Picture", "Figure"]],
"Table": [["Table"], ["Table"]],
"Text": [["Text", "List"], ["Text", "Formula", "Footnote", "Caption", "List-item"]],
"Title": [["Title"], ["Section-header", "Title"]]
"Table": [["Table"], ["Table", "Form", "TableOfContents"]],
"Text": [["Text"], ["Text", "Formula", "Footnote", "Caption", "TextInlineMath", "Code", "Handwriting"]],
"List": [["List"], ["ListItem"]],
"Title": [["Title"], ["SectionHeader", "Title"]]
}

page_metrics = collections.OrderedDict()
Expand Down
30 changes: 20 additions & 10 deletions benchmark/ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import json

from surya.input.processing import convert_if_not_rgb
from surya.model.ordering.model import load_model
from surya.model.ordering.processor import load_processor
from surya.ordering import batch_ordering
from surya.layout import batch_layout_detection
from surya.model.layout.model import load_model
from surya.model.layout.processor import load_processor
from surya.schema import Bbox
from surya.settings import settings
from surya.benchmark.metrics import rank_accuracy
import os
Expand All @@ -15,7 +16,7 @@


def main():
parser = argparse.ArgumentParser(description="Benchmark surya reading order model.")
parser = argparse.ArgumentParser(description="Benchmark surya layout for reading order.")
parser.add_argument("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None)
args = parser.parse_args()
Expand All @@ -31,10 +32,9 @@ def main():
dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split)
images = list(dataset["image"])
images = convert_if_not_rgb(images)
bboxes = list(dataset["bboxes"])

start = time.time()
order_predictions = batch_ordering(images, bboxes, model, processor)
layout_predictions = batch_layout_detection(images, model, processor)
surya_time = time.time() - start

folder_name = os.path.basename(pathname).split(".")[0]
Expand All @@ -43,11 +43,21 @@ def main():

page_metrics = collections.OrderedDict()
mean_accuracy = 0
for idx, order_pred in enumerate(order_predictions):
for idx, order_pred in enumerate(layout_predictions):
row = dataset[idx]
pred_labels = [str(l.position) for l in order_pred.bboxes]
labels = row["labels"]
accuracy = rank_accuracy(pred_labels, labels)
bboxes = row["bboxes"]
pred_positions = []
for label, bbox in zip(labels, bboxes):
max_intersection = 0
matching_idx = 0
for pred_box in order_pred.bboxes:
intersection = pred_box.intersection_pct(Bbox(bbox=bbox))
if intersection > max_intersection:
max_intersection = intersection
matching_idx = pred_box.position
pred_positions.append(matching_idx)
accuracy = rank_accuracy(pred_positions, labels)
mean_accuracy += accuracy
page_results = {
"accuracy": accuracy,
Expand All @@ -56,7 +66,7 @@ def main():

page_metrics[idx] = page_results

mean_accuracy /= len(order_predictions)
mean_accuracy /= len(layout_predictions)

out_data = {
"time": surya_time,
Expand Down
19 changes: 5 additions & 14 deletions detect_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
import json
from collections import defaultdict

from surya.detection import batch_text_detection
from surya.input.load import load_from_folder, load_from_file
from surya.layout import batch_layout_detection
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.layout.model import load_model, load_processor
from surya.model.layout.model import load_model
from surya.model.layout.processor import load_processor
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings
import os
Expand All @@ -27,8 +26,6 @@ def main():

model = load_model()
processor = load_processor()
det_model = load_det_model()
det_processor = load_det_processor()

if os.path.isdir(args.input_path):
images, names, _ = load_from_folder(args.input_path, args.max)
Expand All @@ -38,9 +35,7 @@ def main():
folder_name = os.path.basename(args.input_path).split(".")[0]

start = time.time()
line_predictions = batch_text_detection(images, det_model, det_processor)

layout_predictions = batch_layout_detection(images, model, processor, line_predictions, include_maps=args.debug)
layout_predictions = batch_layout_detection(images, model, processor)
result_path = os.path.join(args.results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)
if args.debug:
Expand All @@ -49,17 +44,13 @@ def main():
if args.images:
for idx, (image, layout_pred, name) in enumerate(zip(images, layout_predictions, names)):
polygons = [p.polygon for p in layout_pred.bboxes]
labels = [p.label for p in layout_pred.bboxes]
labels = [f"{p.label}-{p.position}" for p in layout_pred.bboxes]
bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels)
bbox_image.save(os.path.join(result_path, f"{name}_{idx}_layout.png"))

if args.debug:
heatmap = layout_pred.segmentation_map
heatmap.save(os.path.join(result_path, f"{name}_{idx}_segmentation.png"))

predictions_by_page = defaultdict(list)
for idx, (pred, name, image) in enumerate(zip(layout_predictions, names, images)):
out_pred = pred.model_dump(exclude=["segmentation_map"])
out_pred = pred.model_dump()
out_pred["page"] = len(predictions_by_page[name]) + 1
predictions_by_page[name].append(out_pred)

Expand Down
35 changes: 5 additions & 30 deletions ocr_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,19 @@
from surya.input.pdflines import get_page_text_lines, get_table_blocks
from surya.layout import batch_layout_detection
from surya.model.detection.model import load_model, load_processor
from surya.model.layout.model import load_model as load_layout_model, load_processor as load_layout_processor
from surya.model.layout.model import load_model as load_layout_model
from surya.model.layout.processor import load_processor as load_layout_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.model.table_rec.model import load_model as load_table_model
from surya.model.table_rec.processor import load_processor as load_table_processor
from surya.ordering import batch_ordering
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image
from PIL import Image
from surya.languages import CODE_TO_LANGUAGE
from surya.input.langs import replace_lang_with_code
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult, TableResult
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult
from surya.settings import settings
from surya.tables import batch_table_recognition
from surya.postprocessing.util import rescale_bboxes, rescale_bbox
Expand All @@ -43,10 +41,6 @@ def load_rec_cached():
def load_layout_cached():
return load_layout_model(), load_layout_processor()

@st.cache_resource()
def load_order_cached():
return load_order_model(), load_order_processor()


@st.cache_resource()
def load_table_cached():
Expand All @@ -61,24 +55,13 @@ def text_detection(img) -> (Image.Image, TextDetectionResult):


def layout_detection(img) -> (Image.Image, LayoutResult):
_, det_pred = text_detection(img)
pred = batch_layout_detection([img], layout_model, layout_processor, [det_pred])[0]
pred = batch_layout_detection([img], layout_model, layout_processor)[0]
polygons = [p.polygon for p in pred.bboxes]
labels = [p.label for p in pred.bboxes]
labels = [f"{p.label}-{p.position}" for p in pred.bboxes]
layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
return layout_img, pred


def order_detection(img) -> (Image.Image, OrderResult):
_, layout_pred = layout_detection(img)
bboxes = [l.bbox for l in layout_pred.bboxes]
pred = batch_ordering([img], [bboxes], order_model, order_processor)[0]
polys = [l.polygon for l in pred.bboxes]
positions = [str(l.position) for l in pred.bboxes]
order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=18)
return order_img, pred


def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes: bool, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
if skip_table_detection:
layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
Expand Down Expand Up @@ -171,7 +154,6 @@ def page_count(pdf_file):
det_model, det_processor = load_det_cached()
rec_model, rec_processor = load_rec_cached()
layout_model, layout_processor = load_layout_cached()
order_model, order_processor = load_order_cached()
table_model, table_processor = load_table_cached()


Expand Down Expand Up @@ -211,7 +193,6 @@ def page_count(pdf_file):
text_det = st.sidebar.button("Run Text Detection")
text_rec = st.sidebar.button("Run OCR")
layout_det = st.sidebar.button("Run Layout Analysis")
order_det = st.sidebar.button("Run Reading Order")
table_rec = st.sidebar.button("Run Table Rec")
use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.")
Expand Down Expand Up @@ -245,12 +226,6 @@ def page_count(pdf_file):
with text_tab:
st.text("\n".join([p.text for p in pred.text_lines]))

if order_det:
order_img, pred = order_detection(pil_image)
with col1:
st.image(order_img, caption="Reading Order", use_column_width=True)
st.json(pred.model_dump(), expanded=True)


if table_rec:
table_img, pred = table_recognition(pil_image, pil_image_highres, in_file, page_number - 1 if page_number else None, use_pdf_boxes, skip_table_detection)
Expand Down
Loading

0 comments on commit 30ce562

Please sign in to comment.