diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 42f20a2..6620ab3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,6 +36,10 @@ jobs: run: | poetry run python benchmark/layout.py --max 5 poetry run python scripts/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout + - name: Run ordering benchmark text + run: | + poetry run python benchmark/ordering.py --max 5 + poetry run python scripts/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering diff --git a/README.md b/README.md index e11643e..f615407 100644 --- a/README.md +++ b/README.md @@ -2,16 +2,20 @@ Surya is a document OCR toolkit that does: -- Accurate OCR in 90+ languages +- OCR in 90+ languages that benchmarks favorably vs cloud services - Line-level text detection in any language -- Layout analysis (table, image, header, etc detection) in any language +- Layout analysis (table, image, header, etc detection) +- Reading order detection It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details). -| Detection | OCR | Layout | -|:----------------------------------------------------------------:|:-----------------------------------------------------------------------:|:---------------------------------------------------------------------:| -| ![New York Times Article Detection](static/images/excerpt.png) | ![New York Times Article Recognition](static/images/excerpt_text.png) | ![New York Times Article Detection](static/images/excerpt_layout.png) | +| Detection | OCR | +|:----------------------------------------------------------------:|:-----------------------------------------------------------------------:| +| ![New York Times Article Detection](static/images/excerpt.png) | ![New York Times Article Recognition](static/images/excerpt_text.png) | +| Layout | Reading Order | +|:------------------------------------------------------------------:|:--------------------------------------------------------------------------:| +| ![New York Times Article Layout](static/images/excerpt_layout.png) | ![New York Times Article Reading Order](static/images/excerpt_reading.jpg) | Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision. 
@@ -21,19 +25,19 @@ Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who ## Examples -| Name | Text Detection | OCR | Layout | -|------------------|:-----------------------------------:|-----------------------------------------:|--------:| -| Japanese | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | -| Chinese | [Image](static/images/chinese.jpg) | [Image](static/images/chinese_text.jpg) | [Image](static/images/chinese_layout.jpg) | -| Hindi | [Image](static/images/hindi.jpg) | [Image](static/images/hindi_text.jpg) | [Image](static/images/hindi_layout.jpg) | -| Arabic | [Image](static/images/arabic.jpg) | [Image](static/images/arabic_text.jpg) | [Image](static/images/arabic_layout.jpg) | -| Chinese + Hindi | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | -| Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) | -| Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) | -| Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) | -| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | -| Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | -- | -| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | +| Name | Detection | OCR | Layout | Order | +|------------------|:-----------------------------------:|-----------------------------------------:|-------------------------------------------:|--------------------------------------------:| +| Japanese | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | [Image](static/images/japanese_reading.jpg) | +| Chinese | [Image](static/images/chinese.jpg) | [Image](static/images/chinese_text.jpg) | [Image](static/images/chinese_layout.jpg) | [Image](static/images/chinese_reading.jpg) | +| Hindi | [Image](static/images/hindi.jpg) | [Image](static/images/hindi_text.jpg) | [Image](static/images/hindi_layout.jpg) | [Image](static/images/hindi_reading.jpg) | +| Arabic | [Image](static/images/arabic.jpg) | [Image](static/images/arabic_text.jpg) | [Image](static/images/arabic_layout.jpg) | [Image](static/images/arabic_reading.jpg) | +| Chinese + Hindi | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | [Image](static/images/chi_hind_reading.jpg) | +| Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) | [Image](static/images/pres_reading.jpg) | +| Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) | [Image](static/images/paper_reading.jpg) | +| Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) | [Image](static/images/scanned_reading.jpg) | +| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | 
[Image](static/images/nyt_layout.jpg) | [Image](static/images/nyt_order.jpg) | +| Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | [Image](static/images/funsd_layout.jpg) | [Image](static/images/funsd_reading.jpg) | +| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | [Image](static/images/textbook_order.jpg) | # Installation @@ -61,11 +65,11 @@ pip install streamlit surya_gui ``` -Pass the `--math` command line argument to use the math detection model instead of the default model. This will detect math better, but will be worse at everything else. +Pass the `--math` command line argument to use the math text detection model instead of the default model. This will detect math better, but will be worse at everything else. ## OCR (text recognition) -You can OCR text in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected text and bboxes, and optionally save images of the reconstructed page. +This command will write out a json file with the detected text and bboxes: ```shell surya_ocr DATA_PATH --images --langs hi,en @@ -113,7 +117,7 @@ predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec ## Text line detection -You can detect text lines in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected bboxes. +This command will write out a json file with the detected bboxes. ```shell surya_detect DATA_PATH --images @@ -158,7 +162,7 @@ predictions = batch_text_detection([image], model, processor) ## Layout analysis -You can detect the layout of an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected layout. +This command will write out a json file with the detected layout. ```shell surya_layout DATA_PATH --images @@ -203,13 +207,58 @@ line_predictions = batch_text_detection([image], det_model, det_processor) layout_predictions = batch_layout_detection([image], model, processor, line_predictions) ``` +## Reading order + +This command will write out a json file with the detected reading order and layout. + +```shell +surya_order DATA_PATH --images +``` + +- `DATA_PATH` can be an image, pdf, or folder of images/pdfs +- `--images` will save images of the pages and detected text lines (optional) +- `--max` specifies the maximum number of pages to process if you don't want to process everything +- `--results_dir` specifies the directory to save results to instead of the default + +The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: + +- `bboxes` - detected bounding boxes for text + - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. + - `position` - the position in the reading order of the bbox, starting from 0. + - `label` - the label for the bbox. See the layout section of the documentation for a list of potential labels. +- `page` - the page number in the file +- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. 
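To make that structure concrete, here is a small sketch that loads a `results.json` written by `surya_order` and walks the boxes in reading order. The file path is hypothetical - substitute your own `--results_dir` and input name:

```python
import json

# Hypothetical path: <results_dir>/<input name without extension>/results.json
results_path = "results/surya/mydoc/results.json"

with open(results_path, encoding="utf-8") as f:
    results = json.load(f)

# Keys are input filenames without extensions; each value is a list of page dicts
for name, pages in results.items():
    for page in pages:
        print(f"{name} - page {page['page']}, image_bbox={page['image_bbox']}")
        # Boxes are written out sorted by reading order; sorting again just shows the key
        for box in sorted(page["bboxes"], key=lambda b: b["position"]):
            print(f"  {box['position']:>3} {box['label']:<12} {box['bbox']}")
```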
+ +**Performance tips** + +Setting the `ORDER_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `360MB` of VRAM, so very high batch sizes are possible. The default batch size is `32`, which will use about 11GB of VRAM. Setting it can also help on CPU, depending on your core count - the default CPU batch size is `4`. + +### From python + +```python +from PIL import Image +from surya.ordering import batch_ordering +from surya.model.ordering.processor import load_processor +from surya.model.ordering.model import load_model + +image = Image.open(IMAGE_PATH) +# bboxes should be a list of lists with layout bboxes for the image in [x1,y1,x2,y2] format +# You can get this from the layout model, see above for usage +bboxes = [bbox1, bbox2, ...] + +model = load_model() +processor = load_processor() + +# order_predictions will be a list of dicts, one per image +order_predictions = batch_ordering([image], [bboxes], model, processor) +``` + # Limitations - This is specialized for document OCR. It will likely not work on photos or other images. -- Surya is for OCR - the goal is to recognize the text lines correctly, not sort them into reading order. Surya will attempt to sort the lines, which will work in many cases, but use something like [marker](https://github.com/VikParuchuri/marker) or other postprocessing if you need to order the text. - It is for printed text, not handwriting (though it may work on some handwriting). - The text detection model has trained itself to ignore advertisements. -- You can find language support for OCR in `surya/languages.py`. Text detection and layout analysis will work with any language. +- You can find language support for OCR in `surya/languages.py`. Text detection, layout analysis, and reading order will work with any language. ## Troubleshooting @@ -232,7 +281,7 @@ If you want to develop surya, you can install it manually: ## OCR -![Benchmark chart](static/images/benchmark_rec_chart.png) +![Benchmark chart tesseract](static/images/benchmark_rec_chart.png) | Model | Time per page (s) | Avg similarity (⬆) | |-----------|-------------------|--------------------| @@ -243,12 +292,22 @@ If you want to develop surya, you can install it manually: Tesseract is CPU-based, and surya is CPU or GPU. I tried to cost-match the resources used, so I used a 1xA6000 (48GB VRAM) for surya, and 28 CPU cores for Tesseract (same price on Lambda Labs/DigitalOcean). +### Google Cloud Vision + +I benchmarked OCR against Google Cloud Vision since it has similar language coverage to Surya. + +![Benchmark chart google cloud](static/images/gcloud_rec_bench.png) + +[Full language results](static/images/gcloud_full_langs.png) + **Methodology** I measured normalized sentence similarity (0-1, higher is better) based on a set of real-world and synthetic pdfs. I sampled PDFs from common crawl, then filtered out the ones with bad OCR. I couldn't find PDFs for some languages, so I also generated simple synthetic PDFs for those. I used the reference line bboxes from the PDFs with both tesseract and surya, to just evaluate the OCR quality. +For Google Cloud, I aligned the output from Google Cloud with the ground truth. I had to skip RTL languages since they didn't align well. 
+ ## Text line detection ![Benchmark chart](static/images/benchmark_chart_small.png) @@ -297,6 +356,16 @@ I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/ - Precision - how well the predicted bboxes cover ground truth bboxes - Recall - how well ground truth bboxes cover predicted bboxes +## Reading Order + +75% mean accuracy, and 0.14 seconds per image on an A6000 GPU. See methodology for notes - this benchmark is not a perfect measure of accuracy, and is more useful as a sanity check. + +**Methodology** + +I benchmarked the reading order model on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with reading order and also layout information. I wanted to avoid using a cloud service for the ground truth. + +The accuracy is computed by finding if each pair of layout boxes is in the correct order, then taking the % that are correct. + ## Running your own benchmarks You can benchmark the performance of surya on your machine. @@ -343,6 +412,16 @@ python benchmark/layout.py - `--debug` will render images with detected text - `--results_dir` will let you specify a directory to save results to instead of the default one +**Reading Order** + +``` +python benchmark/ordering.py +``` + +- `--max` controls how many images to process for the benchmark +- `--debug` will render images with detected text +- `--results_dir` will let you specify a directory to save results to instead of the default one + # Training Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements. @@ -351,7 +430,7 @@ Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a m # Commercial usage -The text detection, layout analysis, and OCR models were trained from scratch, so they're okay for commercial usage. The weights are licensed cc-by-nc-sa-4.0, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period. +All models were trained from scratch, so they're okay for commercial usage. The weights are licensed cc-by-nc-sa-4.0, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period. If you want to remove the GPL license requirements for inference or use the weights commercially over the revenue limit, please contact me at surya@vikas.sh for dual licensing. @@ -364,4 +443,4 @@ This work would not have been possible without amazing open source AI work: - [transformers](https://github.com/huggingface/transformers) from huggingface - [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model -Thank you to everyone who makes open source AI possible. +Thank you to everyone who makes open source AI possible. 
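The pairwise accuracy used in the reading order benchmark above (implemented as `rank_accuracy` in `surya/benchmark/metrics.py` further down in this diff) can be illustrated with a minimal, self-contained sketch - `pairwise_order_accuracy` is an illustrative name, not part of the package:

```python
from itertools import permutations

def pairwise_order_accuracy(pred_order, true_order):
    # pred_order / true_order: reading-order rank of each box, aligned by index
    pairs = list(permutations(range(len(true_order)), 2))
    correct = sum(
        (pred_order[i] > pred_order[j]) == (true_order[i] > true_order[j])
        for i, j in pairs
    )
    return correct / len(pairs)

# Toy example: boxes 1 and 2 are swapped relative to the ground truth,
# so 10 of the 12 ordered pairs agree -> ~0.83
print(pairwise_order_accuracy([0, 2, 1, 3], [0, 1, 2, 3]))
```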
\ No newline at end of file diff --git a/benchmark/ordering.py b/benchmark/ordering.py new file mode 100644 index 0000000..fd301ee --- /dev/null +++ b/benchmark/ordering.py @@ -0,0 +1,79 @@ +import argparse +import collections +import copy +import json + +from surya.benchmark.metrics import precision_recall +from surya.model.ordering.model import load_model +from surya.model.ordering.processor import load_processor +from surya.postprocessing.heatmap import draw_bboxes_on_image +from surya.ordering import batch_ordering +from surya.settings import settings +from surya.benchmark.metrics import rank_accuracy +import os +import time +from tabulate import tabulate +import datasets + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark surya reading order model.") + parser.add_argument("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark")) + parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None) + args = parser.parse_args() + + model = load_model() + processor = load_processor() + + pathname = "order_bench" + # These have already been shuffled randomly, so sampling from the start is fine + split = "train" + if args.max is not None: + split = f"train[:{args.max}]" + dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split) + images = list(dataset["image"]) + images = [i.convert("RGB") for i in images] + bboxes = list(dataset["bboxes"]) + + start = time.time() + order_predictions = batch_ordering(images, bboxes, model, processor) + surya_time = time.time() - start + + folder_name = os.path.basename(pathname).split(".")[0] + result_path = os.path.join(args.results_dir, folder_name) + os.makedirs(result_path, exist_ok=True) + + page_metrics = collections.OrderedDict() + mean_accuracy = 0 + for idx, order_pred in enumerate(order_predictions): + row = dataset[idx] + pred_labels = [str(l.position) for l in order_pred.bboxes] + labels = row["labels"] + accuracy = rank_accuracy(pred_labels, labels) + mean_accuracy += accuracy + page_results = { + "accuracy": accuracy, + "box_count": len(labels) + } + + page_metrics[idx] = page_results + + mean_accuracy /= len(order_predictions) + + out_data = { + "time": surya_time, + "mean_accuracy": mean_accuracy, + "page_metrics": page_metrics + } + + with open(os.path.join(result_path, "results.json"), "w+") as f: + json.dump(out_data, f, indent=4) + + print(f"Mean accuracy is {mean_accuracy:.2f}.") + print(f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total.") + print("Mean accuracy is the % of correct ranking pairs.") + print(f"Wrote results to {result_path}") + + +if __name__ == "__main__": + main() diff --git a/ocr_app.py b/ocr_app.py index 53979c0..5560f1d 100644 --- a/ocr_app.py +++ b/ocr_app.py @@ -10,13 +10,16 @@ from surya.model.detection.segformer import load_model, load_processor from surya.model.recognition.model import load_model as load_rec_model from surya.model.recognition.processor import load_processor as load_rec_processor +from surya.model.ordering.processor import load_processor as load_order_processor +from surya.model.ordering.model import load_model as load_order_model +from surya.ordering import batch_ordering from surya.postprocessing.heatmap import draw_polys_on_image from surya.ocr import run_ocr from surya.postprocessing.text import draw_text_on_image from PIL import Image from surya.languages import CODE_TO_LANGUAGE from 
surya.input.langs import replace_lang_with_code -from surya.schema import OCRResult, TextDetectionResult, LayoutResult +from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult from surya.settings import settings parser = argparse.ArgumentParser(description="Run OCR on an image or PDF.") @@ -43,15 +46,19 @@ def load_rec_cached(): def load_layout_cached(): return load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT), load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) +@st.cache_resource() +def load_order_cached(): + return load_order_model(), load_order_processor() + -def text_detection(img) -> TextDetectionResult: +def text_detection(img) -> (Image.Image, TextDetectionResult): pred = batch_text_detection([img], det_model, det_processor)[0] polygons = [p.polygon for p in pred.bboxes] det_img = draw_polys_on_image(polygons, img.copy()) return det_img, pred -def layout_detection(img) -> LayoutResult: +def layout_detection(img) -> (Image.Image, LayoutResult): _, det_pred = text_detection(img) pred = batch_layout_detection([img], layout_model, layout_processor, [det_pred])[0] polygons = [p.polygon for p in pred.bboxes] @@ -60,8 +67,18 @@ def layout_detection(img) -> LayoutResult: return layout_img, pred +def order_detection(img) -> (Image.Image, OrderResult): + _, layout_pred = layout_detection(img) + bboxes = [l.bbox for l in layout_pred.bboxes] + pred = batch_ordering([img], [bboxes], order_model, order_processor)[0] + polys = [l.polygon for l in pred.bboxes] + positions = [str(l.position) for l in pred.bboxes] + order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=20) + return order_img, pred + + # Function for OCR -def ocr(img, langs: List[str]) -> OCRResult: +def ocr(img, langs: List[str]) -> (Image.Image, OCRResult): replace_lang_with_code(langs) img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor)[0] @@ -101,6 +118,7 @@ def page_count(pdf_file): det_model, det_processor = load_det_cached() rec_model, rec_processor = load_rec_cached() layout_model, layout_processor = load_layout_cached() +order_model, order_processor = load_order_cached() st.markdown(""" @@ -136,9 +154,13 @@ def page_count(pdf_file): text_det = st.sidebar.button("Run Text Detection") text_rec = st.sidebar.button("Run OCR") layout_det = st.sidebar.button("Run Layout Analysis") +order_det = st.sidebar.button("Run Reading Order") + +if pil_image is None: + st.stop() # Run Text Detection -if text_det and pil_image is not None: +if text_det: det_img, pred = text_detection(pil_image) with col1: st.image(det_img, caption="Detected Text", use_column_width=True) @@ -146,14 +168,14 @@ def page_count(pdf_file): # Run layout -if layout_det and pil_image is not None: +if layout_det: layout_img, pred = layout_detection(pil_image) with col1: st.image(layout_img, caption="Detected Layout", use_column_width=True) st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True) # Run OCR -if text_rec and pil_image is not None: +if text_rec: rec_img, pred = ocr(pil_image, languages) with col1: st.image(rec_img, caption="OCR Result", use_column_width=True) @@ -163,5 +185,11 @@ def page_count(pdf_file): with text_tab: st.text("\n".join([p.text for p in pred.text_lines])) +if order_det: + order_img, pred = order_detection(pil_image) + with col1: + st.image(order_img, caption="Reading Order", use_column_width=True) + st.json(pred.model_dump(), expanded=True) + with col2: st.image(pil_image, caption="Uploaded Image", use_column_width=True) 
\ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 342c0e3..e5e2a5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "surya-ocr" -version = "0.3.0" -description = "OCR, layout analysis, and line detection in 90+ languages" +version = "0.4.0" +description = "OCR, layout, reading order, and line detection in 90+ languages" authors = ["Vik Paruchuri "] readme = "README.md" license = "GPL-3.0-or-later" @@ -15,7 +15,8 @@ include = [ "ocr_text.py", "ocr_app.py", "run_ocr_app.py", - "detect_layout.py" + "detect_layout.py", + "reading_order.py", ] [tool.poetry.dependencies] @@ -48,6 +49,7 @@ surya_detect = "detect_text:main" surya_ocr = "ocr_text:main" surya_layout = "detect_layout:main" surya_gui = "run_ocr_app:run_app" +surya_order = "reading_order:main" [build-system] requires = ["poetry-core"] diff --git a/reading_order.py b/reading_order.py new file mode 100644 index 0000000..0e169b8 --- /dev/null +++ b/reading_order.py @@ -0,0 +1,81 @@ +import argparse +import copy +import json +from collections import defaultdict + +from surya.detection import batch_text_detection +from surya.input.load import load_from_folder, load_from_file +from surya.layout import batch_layout_detection +from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor +from surya.model.ordering.model import load_model +from surya.model.ordering.processor import load_processor +from surya.ordering import batch_ordering +from surya.postprocessing.heatmap import draw_polys_on_image +from surya.settings import settings +import os + + +def main(): + parser = argparse.ArgumentParser(description="Find reading order of an input file or folder (PDFs or image).") + parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to find reading order in.") + parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya")) + parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) + parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False) + args = parser.parse_args() + + model = load_model() + processor = load_processor() + + layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + + det_model = load_det_model() + det_processor = load_det_processor() + + if os.path.isdir(args.input_path): + images, names = load_from_folder(args.input_path, args.max) + folder_name = os.path.basename(args.input_path) + else: + images, names = load_from_file(args.input_path, args.max) + folder_name = os.path.basename(args.input_path).split(".")[0] + + line_predictions = batch_text_detection(images, det_model, det_processor) + layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions) + bboxes = [] + for layout_pred in layout_predictions: + bbox = [l.bbox for l in layout_pred.bboxes] + bboxes.append(bbox) + + order_predictions = batch_ordering(images, bboxes, model, processor) + result_path = os.path.join(args.results_dir, folder_name) + os.makedirs(result_path, exist_ok=True) + + if args.images: + for idx, (image, layout_pred, order_pred, name) in enumerate(zip(images, layout_predictions, order_predictions, names)): + polys = [l.polygon for l in order_pred.bboxes] + labels = [str(l.position) for l in 
order_pred.bboxes] + bbox_image = draw_polys_on_image(polys, copy.deepcopy(image), labels=labels, label_font_size=20) + bbox_image.save(os.path.join(result_path, f"{name}_{idx}_order.png")) + + predictions_by_page = defaultdict(list) + for idx, (layout_pred, pred, name, image) in enumerate(zip(layout_predictions, order_predictions, names, images)): + out_pred = pred.model_dump() + for bbox, layout_bbox in zip(out_pred["bboxes"], layout_pred.bboxes): + bbox["label"] = layout_bbox.label + + out_pred["page"] = len(predictions_by_page[name]) + 1 + predictions_by_page[name].append(out_pred) + + # Sort in reading order + for name in predictions_by_page: + for page_preds in predictions_by_page[name]: + page_preds["bboxes"] = sorted(page_preds["bboxes"], key=lambda x: x["position"]) + + with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: + json.dump(predictions_by_page, f, ensure_ascii=False) + + print(f"Wrote results to {result_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/verify_benchmark_scores.py b/scripts/verify_benchmark_scores.py index d343163..4e8db55 100644 --- a/scripts/verify_benchmark_scores.py +++ b/scripts/verify_benchmark_scores.py @@ -21,6 +21,12 @@ def verify_rec(data): raise ValueError("Scores do not meet the required threshold") +def verify_order(data): + score = data["mean_accuracy"] + if score < 0.75: + raise ValueError("Scores do not meet the required threshold") + + def verify_scores(file_path, bench_type): with open(file_path, 'r') as file: data = json.load(file) @@ -31,6 +37,8 @@ def verify_scores(file_path, bench_type): verify_rec(data) elif bench_type == "layout": verify_layout(data) + elif bench_type == "ordering": + verify_order(data) else: raise ValueError("Invalid benchmark type") diff --git a/static/images/arabic_reading.jpg b/static/images/arabic_reading.jpg new file mode 100644 index 0000000..df00ae6 Binary files /dev/null and b/static/images/arabic_reading.jpg differ diff --git a/static/images/chi_hind_reading.jpg b/static/images/chi_hind_reading.jpg new file mode 100644 index 0000000..338712b Binary files /dev/null and b/static/images/chi_hind_reading.jpg differ diff --git a/static/images/chinese_reading.jpg b/static/images/chinese_reading.jpg new file mode 100644 index 0000000..fcf49ab Binary files /dev/null and b/static/images/chinese_reading.jpg differ diff --git a/static/images/excerpt_reading.jpg b/static/images/excerpt_reading.jpg new file mode 100644 index 0000000..1ae15c0 Binary files /dev/null and b/static/images/excerpt_reading.jpg differ diff --git a/static/images/funsd_layout.jpg b/static/images/funsd_layout.jpg new file mode 100644 index 0000000..7f82a43 Binary files /dev/null and b/static/images/funsd_layout.jpg differ diff --git a/static/images/funsd_reading.jpg b/static/images/funsd_reading.jpg new file mode 100644 index 0000000..47ec1e1 Binary files /dev/null and b/static/images/funsd_reading.jpg differ diff --git a/static/images/gcloud_full_langs.png b/static/images/gcloud_full_langs.png new file mode 100644 index 0000000..95fd874 Binary files /dev/null and b/static/images/gcloud_full_langs.png differ diff --git a/static/images/gcloud_rec_bench.png b/static/images/gcloud_rec_bench.png new file mode 100644 index 0000000..f20d321 Binary files /dev/null and b/static/images/gcloud_rec_bench.png differ diff --git a/static/images/hindi_reading.jpg b/static/images/hindi_reading.jpg new file mode 100644 index 0000000..75d6986 Binary files /dev/null and b/static/images/hindi_reading.jpg differ diff 
--git a/static/images/japanese_reading.jpg b/static/images/japanese_reading.jpg new file mode 100644 index 0000000..26e7276 Binary files /dev/null and b/static/images/japanese_reading.jpg differ diff --git a/static/images/nyt_order.jpg b/static/images/nyt_order.jpg new file mode 100644 index 0000000..f98027b Binary files /dev/null and b/static/images/nyt_order.jpg differ diff --git a/static/images/paper_reading.jpg b/static/images/paper_reading.jpg new file mode 100644 index 0000000..c1675af Binary files /dev/null and b/static/images/paper_reading.jpg differ diff --git a/static/images/pres_reading.jpg b/static/images/pres_reading.jpg new file mode 100644 index 0000000..7f61e29 Binary files /dev/null and b/static/images/pres_reading.jpg differ diff --git a/static/images/scanned_reading.jpg b/static/images/scanned_reading.jpg new file mode 100644 index 0000000..bbf2bd5 Binary files /dev/null and b/static/images/scanned_reading.jpg differ diff --git a/static/images/textbook_order.jpg b/static/images/textbook_order.jpg new file mode 100644 index 0000000..fb06135 Binary files /dev/null and b/static/images/textbook_order.jpg differ diff --git a/surya/benchmark/metrics.py b/surya/benchmark/metrics.py index 9349bfb..afcb417 100644 --- a/surya/benchmark/metrics.py +++ b/surya/benchmark/metrics.py @@ -117,4 +117,23 @@ def mean_coverage(preds, references): if len(coverages) == 0: return 0 coverage = sum(coverages) / len(coverages) - return {"coverage": coverage} \ No newline at end of file + return {"coverage": coverage} + + +def rank_accuracy(preds, references): + # Preds and references need to be aligned so each position refers to the same bbox + pairs = [] + for i, pred in enumerate(preds): + for j, pred2 in enumerate(preds): + if i == j: + continue + pairs.append((i, j, pred > pred2)) + + # Find how many of the prediction rankings are correct + correct = 0 + for i, ref in enumerate(references): + for j, ref2 in enumerate(references): + if (i, j, ref > ref2) in pairs: + correct += 1 + + return correct / len(pairs) \ No newline at end of file diff --git a/surya/layout.py b/surya/layout.py index 582f12d..c1bfa14 100644 --- a/surya/layout.py +++ b/surya/layout.py @@ -50,29 +50,35 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea # Expand bbox to cover intersecting lines box_lines = defaultdict(list) used_lines = set() - for bbox_idx, bbox in enumerate(detected_boxes): - for line_idx, line_bbox in enumerate(line_bboxes): - if line_bbox.intersection_pct(bbox) >= .5 and line_idx not in used_lines: - box_lines[bbox_idx].append(line_bbox.bbox) - used_lines.add(line_idx) + + # We try 2 rounds of identifying the correct lines to snap to + # First round is majority intersection, second lowers the threshold + for thresh in [.5, .4]: + for bbox_idx, bbox in enumerate(detected_boxes): + for line_idx, line_bbox in enumerate(line_bboxes): + if line_bbox.intersection_pct(bbox) > thresh and line_idx not in used_lines: + box_lines[bbox_idx].append(line_bbox.bbox) + used_lines.add(line_idx) new_boxes = [] for bbox_idx, bbox in enumerate(detected_boxes): if bbox.label == "Picture" and bbox.area < 200: # Remove very small figures continue + # Skip if we didn't find any lines to snap to, except for Pictures and Formulas if bbox_idx not in box_lines and bbox.label not in ["Picture", "Formula"]: continue covered_lines = box_lines[bbox_idx] + # Snap non-picture layout boxes to correct text boundaries if len(covered_lines) > 0 and bbox.label not in ["Picture"]: min_x = min([line[0] for line 
in covered_lines]) min_y = min([line[1] for line in covered_lines]) max_x = max([line[2] for line in covered_lines]) max_y = max([line[3] for line in covered_lines]) + # Tables and formulas can contain text, but text isn't the whole area if bbox.label in ["Table", "Formula"]: - # Figures can tables can contain text, but text isn't the whole area min_x_box = min([b[0] for b in bbox.polygon]) min_y_box = min([b[1] for b in bbox.polygon]) max_x_box = max([b[0] for b in bbox.polygon]) @@ -97,6 +103,7 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea new_boxes.append(bbox) + # Merge tables together (sometimes one column is detected as a separate table) for i in range(5): # Up to 5 rounds of merging to_remove = set() for bbox_idx, bbox in enumerate(new_boxes): @@ -113,6 +120,7 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea new_boxes = [bbox for idx, bbox in enumerate(new_boxes) if idx not in to_remove] + # Ensure we account for all text lines in the layout unused_lines = [line for idx, line in enumerate(line_bboxes) if idx not in used_lines] for bbox in unused_lines: new_boxes.append(LayoutBox(polygon=bbox.polygon, label="Text", confidence=.5)) @@ -121,6 +129,19 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea bbox.rescale(list(reversed(heatmap.shape)), orig_size) detected_boxes = [bbox for bbox in new_boxes if bbox.area > 16] + + # Remove bboxes contained inside others, unless they're captions + contained_bbox = [] + for i, bbox in enumerate(detected_boxes): + for j, bbox2 in enumerate(detected_boxes): + if i == j: + continue + + if bbox2.intersection_pct(bbox) >= .95 and bbox2.label not in ["Caption"]: + contained_bbox.append(j) + + detected_boxes = [bbox for idx, bbox in enumerate(detected_boxes) if idx not in contained_bbox] + return detected_boxes diff --git a/surya/model/ordering/config.py b/surya/model/ordering/config.py new file mode 100644 index 0000000..fcf20f7 --- /dev/null +++ b/surya/model/ordering/config.py @@ -0,0 +1,8 @@ +from transformers import MBartConfig, DonutSwinConfig + + +class MBartOrderConfig(MBartConfig): + pass + +class VariableDonutSwinConfig(DonutSwinConfig): + pass \ No newline at end of file diff --git a/surya/model/ordering/decoder.py b/surya/model/ordering/decoder.py new file mode 100644 index 0000000..38b779f --- /dev/null +++ b/surya/model/ordering/decoder.py @@ -0,0 +1,556 @@ +import copy +from typing import Optional, List, Union, Tuple + +from transformers import MBartForCausalLM, MBartConfig +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions +from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder, MBartLearnedPositionalEmbedding, MBartDecoderLayer +from surya.model.ordering.config import MBartOrderConfig +import torch +import math + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + From llama + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MBartGQAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[MBartConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0, f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})" + assert embed_dim % self.num_kv_heads == 0, f"embed_dim ({self.embed_dim}) must be divisible by num_kv_heads ({self.num_kv_heads})" + + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + 
# reuse k, v, self_attention + key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + + # Expand kv heads, then match query shape + key_states = repeat_kv(key_states, self.num_kv_groups) + value_states = repeat_kv(value_states, self.num_kv_groups) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +MBART_ATTENTION_CLASSES = { + "eager": MBartGQAttention, + "flash_attention_2": None +} + + +class MBartOrderDecoderLayer(MBartDecoderLayer): + def __init__(self, config: MBartConfig): + nn.Module.__init__(self) + self.embed_dim = config.d_model + + self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + num_kv_heads=config.kv_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( + self.embed_dim, + config.decoder_attention_heads, + num_kv_heads=config.kv_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + +class BboxEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.x1_embed = nn.Embedding(config.max_width, config.d_model) + self.y1_embed = nn.Embedding(config.max_height, config.d_model) + self.x2_embed = nn.Embedding(config.max_width, config.d_model) + self.y2_embed = nn.Embedding(config.max_height, config.d_model) + self.w_embed = nn.Embedding(config.max_width, config.d_model) + self.h_embed = nn.Embedding(config.max_height, config.d_model) + self.cx_embed = nn.Embedding(config.max_width, config.d_model) + self.cy_embed = nn.Embedding(config.max_height, config.d_model) + self.box_pos_embed = nn.Embedding(config.max_position_embeddings, config.d_model) + + def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor, past_key_values_length: int): + x1, y1, x2, y2 = boxes.unbind(dim=-1) + # Shape is (batch_size, num_boxes/seq len, d_model) + w = x2 - x1 + h = y2 - y1 + # Center x and y in torch long tensors + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + cx = cx.long() + cy = cy.long() + + coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2) + embedded = coord_embeds + self.w_embed(w) + 
self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy) + + # Add in positional embeddings for the boxes + if past_key_values_length == 0: + for j in range(embedded.shape[0]): + box_start = input_box_counts[j, 0] + box_end = input_box_counts[j, 1] - 1 # Skip the sep token + box_count = box_end - box_start + embedded[j, box_start:box_end] = embedded[j, box_start:box_end] + self.box_pos_embed.weight[:box_count] + + return embedded + + +class MBartOrderDecoder(MBartDecoder): + def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None): + MBartPreTrainedModel.__init__(self, config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BboxEmbedding(config) if embed_tokens is None else embed_tokens + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = MBartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + # Language-specific MoE goes at second and second-to-last layer + self.layers = nn.ModuleList([MBartOrderDecoderLayer(config) for _ in range(config.decoder_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.layernorm_embedding = nn.LayerNorm(config.d_model) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_boxes: torch.LongTensor = None, + input_boxes_mask: Optional[torch.Tensor] = None, + input_boxes_counts: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_boxes is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_boxes is not None: + input = input_boxes + input_shape = input_boxes.size()[:-1] # Shape (batch_size, num_boxes) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_boxes, input_boxes_counts, 
past_key_values_length) * self.embed_scale + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = input_boxes_mask if (input_boxes_mask is not None and 0 in input_boxes_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + input_boxes_mask, input_shape, inputs_embeds, past_key_values_length + ) + + if past_key_values_length == 0: + box_ends = input_boxes_counts[:, 1] + box_starts = input_boxes_counts[:, 0] + input_shape_arranged = torch.arange(input_shape[1], device=attention_mask.device)[None, :] + # Enable all boxes to attend to each other (before the sep token) + # Ensure that the boxes are not attending to the padding tokens + boxes_end_mask = input_shape_arranged < box_ends[:, None] + boxes_start_mask = input_shape_arranged >= box_starts[:, None] + boxes_mask = boxes_end_mask & boxes_start_mask + boxes_mask = boxes_mask.unsqueeze(1).unsqueeze(1) # Enable proper broadcasting + attention_mask = attention_mask.masked_fill(boxes_mask, 0) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self._use_flash_attention_2: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input, past_key_values_length) + + hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != len(self.layers): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {attn_mask.size()[0]}." 
+ ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class MBartOrderDecoderWrapper(MBartPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. 
+ """ + + def __init__(self, config): + super().__init__(config) + self.decoder = MBartOrderDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class MBartOrder(MBartForCausalLM): + config_class = MBartOrderConfig + _tied_weights_keys = [] + + def __init__(self, config): + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + MBartPreTrainedModel.__init__(self, config) + self.model = MBartOrderDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_boxes: torch.LongTensor = None, + input_boxes_mask: Optional[torch.Tensor] = None, + input_boxes_counts: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_boxes=input_boxes, + input_boxes_mask=input_boxes_mask, + input_boxes_counts=input_boxes_counts, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) \ No newline at end of file diff --git a/surya/model/ordering/encoder.py b/surya/model/ordering/encoder.py new file mode 100644 index 0000000..59a9678 --- /dev/null +++ b/surya/model/ordering/encoder.py @@ -0,0 +1,83 @@ +from torch import nn +import torch +from typing import Optional, Tuple, Union +import collections +import math + +from transformers import DonutSwinPreTrainedModel +from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, \ + DonutSwinEncoder + +from surya.model.ordering.config import VariableDonutSwinConfig + +class VariableDonutSwinEmbeddings(DonutSwinEmbeddings): + """ + Construct the patch and position embeddings. Optionally, also the mask token. 
+ """ + + def __init__(self, config, use_mask_token=False): + super().__init__(config, use_mask_token) + + self.patch_embeddings = DonutSwinPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None + self.position_embeddings = None + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + + self.row_embeddings = None + self.column_embeddings = None + if config.use_2d_embeddings: + self.row_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim)) + self.column_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim)) + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None + ) -> Tuple[torch.Tensor]: + + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + # Layernorm across the last dimension (each patch is a single row) + embeddings = self.norm(embeddings) + batch_size, seq_len, embed_dim = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings[:, :seq_len, :] + + if self.row_embeddings is not None and self.column_embeddings is not None: + # Repeat the x position embeddings across the y axis like 0, 1, 2, 3, 0, 1, 2, 3, ... 
+ row_embeddings = self.row_embeddings[:, :output_dimensions[0], :].repeat_interleave(output_dimensions[1], dim=1) + column_embeddings = self.column_embeddings[:, :output_dimensions[1], :].repeat(1, output_dimensions[0], 1) + + embeddings = embeddings + row_embeddings + column_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +class VariableDonutSwinModel(DonutSwinModel): + config_class = VariableDonutSwinConfig + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = VariableDonutSwinEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid) + + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() \ No newline at end of file diff --git a/surya/model/ordering/encoderdecoder.py b/surya/model/ordering/encoderdecoder.py new file mode 100644 index 0000000..f7351f1 --- /dev/null +++ b/surya/model/ordering/encoderdecoder.py @@ -0,0 +1,90 @@ +from typing import Optional, Union, Tuple, List + +import torch +from transformers import VisionEncoderDecoderModel +from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput + + +class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel): + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + decoder_input_boxes: torch.LongTensor = None, + # Shape (batch_size, num_boxes, 4), all coords scaled 0 - 1000, with 1001 as padding + decoder_input_boxes_mask: torch.LongTensor = None, # Shape (batch_size, num_boxes), 0 if padding, 1 otherwise + decoder_input_boxes_counts: torch.LongTensor = None, # Shape (batch_size), number of boxes in each image + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[List[List[int]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + if encoder_outputs is None: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + encoder_outputs = self.encoder( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs_encoder, + ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + # else: + encoder_attention_mask = None + + # Decode + decoder_outputs = 
self.decoder( + input_boxes=decoder_input_boxes, + input_boxes_mask=decoder_input_boxes_mask, + input_boxes_counts=decoder_input_boxes_counts, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + past_key_values=past_key_values, + return_dict=return_dict, + labels=labels, + **kwargs_decoder, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqLMOutput( + loss=decoder_outputs.loss, + logits=decoder_outputs.logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) diff --git a/surya/model/ordering/model.py b/surya/model/ordering/model.py new file mode 100644 index 0000000..da551cb --- /dev/null +++ b/surya/model/ordering/model.py @@ -0,0 +1,34 @@ +from transformers import DetrConfig, BeitConfig, DetrImageProcessor, VisionEncoderDecoderConfig, AutoModelForCausalLM, \ + AutoModel +from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig +from surya.model.ordering.decoder import MBartOrder +from surya.model.ordering.encoder import VariableDonutSwinModel +from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel +from surya.model.ordering.processor import OrderImageProcessor +from surya.settings import settings + + +def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE): + config = VisionEncoderDecoderConfig.from_pretrained(checkpoint) + + decoder_config = vars(config.decoder) + decoder = MBartOrderConfig(**decoder_config) + config.decoder = decoder + + encoder_config = vars(config.encoder) + encoder = VariableDonutSwinConfig(**encoder_config) + config.encoder = encoder + + # Get transformers to load custom model + AutoModel.register(MBartOrderConfig, MBartOrder) + AutoModelForCausalLM.register(MBartOrderConfig, MBartOrder) + AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel) + + model = OrderVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype) + assert isinstance(model.decoder, MBartOrder) + assert isinstance(model.encoder, VariableDonutSwinModel) + + model = model.to(device) + model = model.eval() + print(f"Loading reading order model {checkpoint} on device {device} with dtype {dtype}") + return model \ No newline at end of file diff --git a/surya/model/ordering/processor.py b/surya/model/ordering/processor.py new file mode 100644 index 0000000..3262682 --- /dev/null +++ b/surya/model/ordering/processor.py @@ -0,0 +1,165 @@ +from copy import deepcopy +from typing import Dict, Union, Optional, List, Tuple + +import torch +from torch import TensorType +from transformers import DonutImageProcessor, DonutProcessor +from transformers.image_processing_utils import BatchFeature +from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, \ + valid_images, to_numpy_array +import numpy as np +from PIL import Image +import PIL +from surya.settings import settings + + +def load_processor(checkpoint=settings.ORDER_MODEL_CHECKPOINT): + processor = 
OrderImageProcessor.from_pretrained(checkpoint) + processor.size = settings.ORDER_IMAGE_SIZE + box_size = 1024 + max_tokens = 256 + processor.token_sep_id = max_tokens + box_size + 1 + processor.token_pad_id = max_tokens + box_size + 2 + processor.max_boxes = settings.ORDER_MAX_BOXES - 1 + processor.box_size = {"height": box_size, "width": box_size} + return processor + + +class OrderImageProcessor(DonutImageProcessor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.patch_size = kwargs.get("patch_size", (4, 4)) + + def process_inner(self, images: List[List]): + # This will be in list of lists format, with height x width x channel + assert isinstance(images[0], (list, np.ndarray)) + + # convert list of lists format to array + if isinstance(images[0], list): + # numpy unit8 needed for augmentation + np_images = [np.array(img, dtype=np.uint8) for img in images] + else: + np_images = [img.astype(np.uint8) for img in images] + np_images = [img.transpose(2, 0, 1) for img in np_images] # convert to CHW format + + assert np_images[0].shape[0] == 3 # RGB input images, channel dim last + + # Convert to float32 for rescale/normalize + np_images = [img.astype(np.float32) for img in np_images] + + # Rescale and normalize + np_images = [ + self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST) + for img in np_images + ] + np_images = [ + self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST) + for img in np_images + ] + + return np_images + + def process_boxes(self, boxes): + padded_boxes = [] + box_masks = [] + box_counts = [] + for b in boxes: + # Left pad for generation + padded_b = deepcopy(b) + padded_b.append([self.token_sep_id] * 4) # Sep token to indicate start of label predictions + padded_boxes.append(padded_b) + + max_boxes = max(len(b) for b in padded_boxes) + for i in range(len(padded_boxes)): + pad_len = max_boxes - len(padded_boxes[i]) + box_len = len(padded_boxes[i]) + box_mask = [0] * pad_len + [1] * box_len + padded_box = [[self.token_pad_id] * 4] * pad_len + padded_boxes[i] + padded_boxes[i] = padded_box + box_masks.append(box_mask) + box_counts.append([pad_len, max_boxes]) + + return padded_boxes, box_masks, box_counts + + def resize_img_and_boxes(self, img, boxes): + orig_dim = img.size + new_size = (self.size["width"], self.size["height"]) + img.thumbnail(new_size, Image.Resampling.LANCZOS) # Shrink largest dimension to fit new size + img = img.resize(new_size, Image.Resampling.LANCZOS) # Stretch smaller dimension to fit new size + + img = np.asarray(img, dtype=np.uint8) + + width, height = orig_dim + box_width, box_height = self.box_size["width"], self.box_size["height"] + for box in boxes: + # Rescale to 0-1024 + box[0] = box[0] / width * box_width + box[1] = box[1] / height * box_height + box[2] = box[2] / width * box_width + box[3] = box[3] / height * box_height + + if box[0] < 0: + box[0] = 0 + if box[1] < 0: + box[1] = 0 + if box[2] > box_width: + box[2] = box_width + if box[3] > box_height: + box[3] = box_height + + return img, boxes + + def preprocess( + self, + images: ImageInput, + boxes: List[List[int]], + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_thumbnail: bool = None, + do_align_long_axis: bool = None, + do_pad: bool = None, + random_padding: bool = False, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = 
None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + new_images = [] + new_boxes = [] + for img, box in zip(images, boxes): + if len(box) > self.max_boxes: + raise ValueError(f"Too many boxes, max is {self.max_boxes}") + img, box = self.resize_img_and_boxes(img, box) + new_images.append(img) + new_boxes.append(box) + + images = new_images + boxes = new_boxes + + # Convert to numpy for later processing steps + images = [to_numpy_array(image) for image in images] + + images = self.process_inner(images) + boxes, box_mask, box_counts = self.process_boxes(boxes) + data = { + "pixel_values": images, + "input_boxes": boxes, + "input_boxes_mask": box_mask, + "input_boxes_counts": box_counts, + } + return BatchFeature(data=data, tensor_type=return_tensors) \ No newline at end of file diff --git a/surya/ordering.py b/surya/ordering.py new file mode 100644 index 0000000..854e1a8 --- /dev/null +++ b/surya/ordering.py @@ -0,0 +1,137 @@ +from copy import deepcopy +from typing import List, Optional +import torch +from PIL import Image + +from surya.schema import OrderBox, OrderResult +from surya.settings import settings +from tqdm import tqdm +import numpy as np + + +def get_batch_size(): + batch_size = settings.ORDER_BATCH_SIZE + if batch_size is None: + batch_size = 8 + if settings.TORCH_DEVICE_MODEL == "mps": + batch_size = 8 + if settings.TORCH_DEVICE_MODEL == "cuda": + batch_size = 32 + return batch_size + + +def rank_elements(arr): + enumerated_and_sorted = sorted(enumerate(arr), key=lambda x: x[1]) + rank = [0] * len(arr) + + for rank_value, (original_index, value) in enumerate(enumerated_and_sorted): + rank[original_index] = rank_value + + return rank + + +def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor) -> List[OrderResult]: + assert all([isinstance(image, Image.Image) for image in images]) + assert len(images) == len(bboxes) + batch_size = get_batch_size() + + images = [image.convert("RGB") for image in images] + + output_order = [] + for i in tqdm(range(0, len(images), batch_size), desc="Finding reading order"): + batch_bboxes = deepcopy(bboxes[i:i+batch_size]) + batch_images = images[i:i+batch_size] + orig_sizes = [image.size for image in batch_images] + model_inputs = processor(images=batch_images, boxes=batch_bboxes) + + batch_pixel_values = model_inputs["pixel_values"] + batch_bboxes = model_inputs["input_boxes"] + batch_bbox_mask = model_inputs["input_boxes_mask"] + batch_bbox_counts = model_inputs["input_boxes_counts"] + + batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device) + batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device) + batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device) + batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device) + + token_count = 0 + past_key_values = None + encoder_outputs = None + batch_predictions = [[] for _ in range(len(batch_images))] + done = [False for _ in range(len(batch_images))] + while token_count < 
settings.ORDER_MAX_BOXES: + with torch.inference_mode(): + return_dict = model( + pixel_values=batch_pixel_values, + decoder_input_boxes=batch_bboxes, + decoder_input_boxes_mask=batch_bbox_mask, + decoder_input_boxes_counts=batch_bbox_counts, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + ) + logits = return_dict["logits"].detach().cpu() + + last_tokens = [] + last_token_mask = [] + min_val = torch.finfo(model.dtype).min + for j in range(logits.shape[0]): + label_count = batch_bbox_counts[j, 1] - batch_bbox_counts[j, 0] - 1 # Subtract 1 for the sep token + new_logits = logits[j, -1].clone() + new_logits[batch_predictions[j]] = min_val # Mask out already predicted tokens, we can only predict each token once + new_logits[label_count:] = min_val # Mask out all logit positions above the number of bboxes + pred = int(torch.argmax(new_logits, dim=-1).item()) + + # Add one to avoid colliding with the 1000 height/width token for bboxes + last_tokens.append([[pred + processor.box_size["height"] + 1] * 4]) + if len(batch_predictions[j]) == label_count - 1: # Minus one since we're appending the final label + last_token_mask.append([0]) + batch_predictions[j].append(pred) + done[j] = True + elif len(batch_predictions[j]) < label_count - 1: + last_token_mask.append([1]) + batch_predictions[j].append(pred) # Get rank prediction for given position + else: + last_token_mask.append([0]) + + # Break when we finished generating all sequences + if all(done): + break + + past_key_values = return_dict["past_key_values"] + encoder_outputs = (return_dict["encoder_last_hidden_state"],) + + batch_bboxes = torch.tensor(last_tokens, dtype=torch.long).to(model.device) + token_bbox_mask = torch.tensor(last_token_mask, dtype=torch.long).to(model.device) + batch_bbox_mask = torch.cat([batch_bbox_mask, token_bbox_mask], dim=1) + token_count += 1 + + for j, row_pred in enumerate(batch_predictions): + row_bboxes = bboxes[i+j] + assert len(row_pred) == len(row_bboxes), f"Mismatch between logits and bboxes. 
Logits: {len(row_pred)}, Bboxes: {len(row_bboxes)}" + + orig_size = orig_sizes[j] + ranks = [0] * len(row_bboxes) + + for box_idx in range(len(row_bboxes)): + ranks[row_pred[box_idx]] = box_idx + + order_boxes = [] + for row_bbox, rank in zip(row_bboxes, ranks): + order_box = OrderBox( + bbox=row_bbox, + position=rank, + ) + order_boxes.append(order_box) + + result = OrderResult( + bboxes=order_boxes, + image_bbox=[0, 0, orig_size[0], orig_size[1]], + ) + output_order.append(result) + return output_order + + + + + + diff --git a/surya/postprocessing/heatmap.py b/surya/postprocessing/heatmap.py index 835b3fe..ef401fd 100644 --- a/surya/postprocessing/heatmap.py +++ b/surya/postprocessing/heatmap.py @@ -184,7 +184,7 @@ def get_and_clean_boxes(textmap, processor_size, image_size, text_threshold=None return bboxes -def draw_bboxes_on_image(bboxes, image): +def draw_bboxes_on_image(bboxes, image, labels=None): draw = ImageDraw.Draw(image) for bbox in bboxes: @@ -193,10 +193,10 @@ def draw_bboxes_on_image(bboxes, image): return image -def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1): +def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1, label_font_size=10): draw = ImageDraw.Draw(image) font_path = get_font_path() - label_font = ImageFont.truetype(font_path, 10) + label_font = ImageFont.truetype(font_path, label_font_size) for i in range(len(corners)): poly = corners[i] diff --git a/surya/schema.py b/surya/schema.py index efcea77..545e5d1 100644 --- a/surya/schema.py +++ b/surya/schema.py @@ -115,11 +115,19 @@ def width(self): def area(self): return self.width * self.height + @property + def polygon(self): + return [[self.bbox[0], self.bbox[1]], [self.bbox[2], self.bbox[1]], [self.bbox[2], self.bbox[3]], [self.bbox[0], self.bbox[3]]] + class LayoutBox(PolygonBox): label: str +class OrderBox(Bbox): + position: int + + class ColumnLine(Bbox): vertical: bool horizontal: bool @@ -149,3 +157,8 @@ class LayoutResult(BaseModel): bboxes: List[LayoutBox] segmentation_map: Any image_bbox: List[float] + + +class OrderResult(BaseModel): + bboxes: List[OrderBox] + image_bbox: List[float] diff --git a/surya/settings.py b/surya/settings.py index 7ad957e..c72f7c3 100644 --- a/surya/settings.py +++ b/surya/settings.py @@ -68,9 +68,16 @@ def TORCH_DEVICE_DETECTION(self) -> str: RECOGNITION_PAD_VALUE: int = 0 # Should be 0 or 255 # Layout - LAYOUT_MODEL_CHECKPOINT: str = "vikp/surya_layout" + LAYOUT_MODEL_CHECKPOINT: str = "vikp/surya_layout2" LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench" + # Ordering + ORDER_MODEL_CHECKPOINT: str = "vikp/surya_order" + ORDER_IMAGE_SIZE: Dict = {"height": 1024, "width": 1024} + ORDER_MAX_BOXES: int = 256 + ORDER_BATCH_SIZE: Optional[int] = None # Defaults to 4 for CPU/MPS, 32 otherwise + ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench" + # Tesseract (for benchmarks only) TESSDATA_PREFIX: Optional[str] = None
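A note on the new encoder: VariableDonutSwinEmbeddings adds optional 2D position embeddings, tiling a row embedding and a column embedding over the patch grid so that the patch at (row r, column c) receives row[r] + col[c]. Below is a standalone sketch of that tiling; the sizes and the integer stand-in embeddings are illustrative only, not the model's real config.

import torch

# Toy sizes only; the real model uses config.embed_dim and the Swin patch grid.
embed_dim = 4
grid_h, grid_w = 2, 3  # stands in for output_dimensions = (height, width) in patches

row = torch.arange(grid_h).view(1, grid_h, 1).expand(1, grid_h, embed_dim)       # stand-in for row_embeddings
col = torch.arange(grid_w).view(1, grid_w, 1).expand(1, grid_w, embed_dim) * 10  # stand-in for column_embeddings

# Same tiling as VariableDonutSwinEmbeddings.forward: row embeddings are
# repeat_interleaved across the width, column embeddings are repeated once per row,
# so the patch at (r, c) ends up with row[r] + col[c].
row_tiled = row.repeat_interleave(grid_w, dim=1)  # shape (1, grid_h * grid_w, embed_dim)
col_tiled = col.repeat(1, grid_h, 1)              # shape (1, grid_h * grid_w, embed_dim)
pos = row_tiled + col_tiled

print(pos[0, :, 0].tolist())  # [0, 10, 20, 1, 11, 21] -> one row/column pair per patch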
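In surya/ordering.py, the decode loop appends one prediction per step and masks out indices that were already chosen; the loop at the end then inverts the predicted sequence into a per-box position before building OrderBox results. A toy sketch of that inversion, using a hypothetical prediction sequence:

# Hypothetical per-image prediction sequence from the decode loop (4 boxes)
row_pred = [2, 0, 3, 1]

# Same inversion as the end of batch_ordering: ranks[row_pred[step]] = step,
# so ranks[i] becomes the position assigned to input box i.
ranks = [0] * len(row_pred)
for step in range(len(row_pred)):
    ranks[row_pred[step]] = step

print(ranks)  # [1, 3, 0, 2]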
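Finally, a minimal end-to-end usage sketch of the reading-order pipeline introduced in this diff. The image path and bbox coordinates below are hypothetical; load_model, load_processor, batch_ordering, OrderResult and OrderBox are the pieces added above.

from PIL import Image

from surya.model.ordering.model import load_model
from surya.model.ordering.processor import load_processor
from surya.ordering import batch_ordering

# Hypothetical inputs: one page image and its layout bboxes as [x1, y1, x2, y2] pixel coords
image = Image.open("page.png")
bboxes = [[74, 101, 988, 257], [74, 281, 988, 603]]

model = load_model()
processor = load_processor()

order_results = batch_ordering([image], [bboxes], model, processor)
for box in order_results[0].bboxes:  # OrderResult.bboxes is a list of OrderBox
    print(box.position, box.bbox)    # position is the predicted reading-order rank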