diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index c10c70f..76c34af 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -11,10 +11,10 @@ jobs: - name: Set up Python 3.11 uses: actions/setup-python@v4 with: - python-version: 3.11 + python-version: "3.11" - name: Install python dependencies run: | - pip install poetry + pip install poetry==1.8.3 poetry install - name: Build package run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6620ab3..0b230d6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: - name: Set up Python 3.11 uses: actions/setup-python@v4 with: - python-version: 3.11 + python-version: "3.11" - name: Install apt dependencies run: | sudo apt-get update @@ -24,6 +24,9 @@ jobs: poetry install poetry remove torch poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu + - name: Run ruff lint for checking code style + run: | + poetry run ruff check --select I . - name: Run detection benchmark test run: | poetry run python benchmark/detection.py --max 2 @@ -40,6 +43,3 @@ jobs: run: | poetry run python benchmark/ordering.py --max 5 poetry run python scripts/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering - - - diff --git a/benchmark/detection.py b/benchmark/detection.py index d149abb..623479c 100644 --- a/benchmark/detection.py +++ b/benchmark/detection.py @@ -2,20 +2,21 @@ import collections import copy import json +import os +import time + +import datasets +from tabulate import tabulate from surya.benchmark.bbox import get_pdf_lines from surya.benchmark.metrics import precision_recall from surya.benchmark.tesseract import tesseract_parallel -from surya.model.detection.model import load_model, load_processor -from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb from surya.detection import batch_text_detection +from surya.input.processing import convert_if_not_rgb, get_page_images, open_pdf +from surya.model.detection.model import load_model, load_processor from surya.postprocessing.heatmap import draw_polys_on_image from surya.postprocessing.util import rescale_bbox from surya.settings import settings -import os -import time -from tabulate import tabulate -import datasets def main(): diff --git a/benchmark/gcloud_label.py b/benchmark/gcloud_label.py index 5c9012d..489722f 100644 --- a/benchmark/gcloud_label.py +++ b/benchmark/gcloud_label.py @@ -1,14 +1,15 @@ import argparse +import hashlib +import io import json +import os from collections import defaultdict import datasets -from surya.settings import settings from google.cloud import vision -import hashlib -import os from tqdm import tqdm -import io + +from surya.settings import settings DATA_DIR = os.path.join(settings.BASE_DIR, settings.DATA_DIR) RESULT_DIR = os.path.join(settings.BASE_DIR, settings.RESULT_DIR) @@ -146,4 +147,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/benchmark/layout.py b/benchmark/layout.py index d2370d0..b620715 100644 --- a/benchmark/layout.py +++ b/benchmark/layout.py @@ -2,19 +2,20 @@ import collections import copy import json +import os +import time + +import datasets +from tabulate import tabulate from surya.benchmark.metrics import precision_recall from surya.detection import batch_text_detection -from surya.model.detection.model import load_model, load_processor -from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb +from surya.input.processing import convert_if_not_rgb, get_page_images, open_pdf from surya.layout import batch_layout_detection -from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image +from surya.model.detection.model import load_model, load_processor +from surya.postprocessing.heatmap import draw_bboxes_on_image, draw_polys_on_image from surya.postprocessing.util import rescale_bbox from surya.settings import settings -import os -import time -from tabulate import tabulate -import datasets def main(): diff --git a/benchmark/ordering.py b/benchmark/ordering.py index fc6a333..ed291a2 100644 --- a/benchmark/ordering.py +++ b/benchmark/ordering.py @@ -2,16 +2,17 @@ import collections import copy import json +import os +import time +import datasets + +from surya.benchmark.metrics import rank_accuracy from surya.input.processing import convert_if_not_rgb from surya.model.ordering.model import load_model from surya.model.ordering.processor import load_processor from surya.ordering import batch_ordering from surya.settings import settings -from surya.benchmark.metrics import rank_accuracy -import os -import time -import datasets def main(): diff --git a/benchmark/profile.sh b/benchmark/profile.sh index f6cecc4..8f2e88d 100644 --- a/benchmark/profile.sh +++ b/benchmark/profile.sh @@ -1 +1,2 @@ -python -m cProfile -s time -o data/profile.pstats detect_text.py data/benchmark/nyt2.pdf --max 10 \ No newline at end of file +#!/bin/bash +python -m cProfile -s time -o data/profile.pstats detect_text.py data/benchmark/nyt2.pdf --max 10 diff --git a/benchmark/pymupdf_test.py b/benchmark/pymupdf_test.py index 7d82b61..3a3129a 100644 --- a/benchmark/pymupdf_test.py +++ b/benchmark/pymupdf_test.py @@ -2,9 +2,8 @@ import os from surya.benchmark.bbox import get_pdf_lines +from surya.input.processing import get_page_images, open_pdf from surya.postprocessing.heatmap import draw_bboxes_on_image - -from surya.input.processing import open_pdf, get_page_images from surya.settings import settings diff --git a/benchmark/recognition.py b/benchmark/recognition.py index 28f1be8..17d6b0e 100644 --- a/benchmark/recognition.py +++ b/benchmark/recognition.py @@ -1,22 +1,28 @@ import argparse +import json +import os +import time from collections import defaultdict +import datasets import torch +from tabulate import tabulate from benchmark.scoring import overlap_score +from surya.benchmark.tesseract import ( + TESS_CODE_TO_LANGUAGE, + surya_lang_to_tesseract, + tesseract_ocr_parallel, +) from surya.input.processing import convert_if_not_rgb +from surya.languages import CODE_TO_LANGUAGE from surya.model.recognition.model import load_model as load_recognition_model -from surya.model.recognition.processor import load_processor as load_recognition_processor +from surya.model.recognition.processor import ( + load_processor as load_recognition_processor, +) from surya.ocr import run_recognition from surya.postprocessing.text import draw_text_on_image from surya.settings import settings -from surya.languages import CODE_TO_LANGUAGE -from surya.benchmark.tesseract import tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE -import os -import datasets -import json -import time -from tabulate import tabulate KEY_LANGUAGES = ["Chinese", "Spanish", "English", "Arabic", "Hindi", "Bengali", "Russian", "Japanese"] diff --git a/benchmark/scoring.py b/benchmark/scoring.py index 50bf089..16f3870 100644 --- a/benchmark/scoring.py +++ b/benchmark/scoring.py @@ -19,4 +19,4 @@ def overlap_score(pred_lines: List[str], reference_lines: List[str]): line_weights.append(line_weight) line_scores = [line_scores[i] * line_weights[i] for i in range(len(line_scores))] - return sum(line_scores) / sum(line_weights) \ No newline at end of file + return sum(line_scores) / sum(line_weights) diff --git a/benchmark/tesseract_test.py b/benchmark/tesseract_test.py index 49ca86b..12b84fe 100644 --- a/benchmark/tesseract_test.py +++ b/benchmark/tesseract_test.py @@ -2,9 +2,8 @@ import os from surya.benchmark.tesseract import tesseract_bboxes +from surya.input.processing import get_page_images, open_pdf from surya.postprocessing.heatmap import draw_bboxes_on_image - -from surya.input.processing import open_pdf, get_page_images from surya.settings import settings diff --git a/benchmark/viz.sh b/benchmark/viz.sh index f642982..6d1816c 100644 --- a/benchmark/viz.sh +++ b/benchmark/viz.sh @@ -1 +1,2 @@ -snakeviz data/profile.pstats \ No newline at end of file +#!/bin/bash +snakeviz data/profile.pstats diff --git a/detect_layout.py b/detect_layout.py index 8e791b7..d2a080f 100644 --- a/detect_layout.py +++ b/detect_layout.py @@ -1,15 +1,15 @@ import argparse import copy import json +import os from collections import defaultdict from surya.detection import batch_text_detection -from surya.input.load import load_from_folder, load_from_file +from surya.input.load import load_from_file, load_from_folder from surya.layout import batch_layout_detection from surya.model.detection.model import load_model, load_processor from surya.postprocessing.heatmap import draw_polys_on_image from surya.settings import settings -import os def main(): diff --git a/detect_text.py b/detect_text.py index e2ecc4d..e952b5d 100644 --- a/detect_text.py +++ b/detect_text.py @@ -1,17 +1,18 @@ import argparse import copy import json +import os import time from collections import defaultdict -from surya.input.load import load_from_folder, load_from_file -from surya.model.detection.model import load_model, load_processor +from tqdm import tqdm + from surya.detection import batch_text_detection +from surya.input.load import load_from_file, load_from_folder +from surya.model.detection.model import load_model, load_processor from surya.postprocessing.affinity import draw_lines_on_image from surya.postprocessing.heatmap import draw_polys_on_image from surya.settings import settings -import os -from tqdm import tqdm def main(): diff --git a/ocr_app.py b/ocr_app.py index e8f20e3..ee9c6f3 100644 --- a/ocr_app.py +++ b/ocr_app.py @@ -3,23 +3,25 @@ import pypdfium2 import streamlit as st +from PIL import Image + from surya.detection import batch_text_detection +from surya.input.langs import replace_lang_with_code +from surya.languages import CODE_TO_LANGUAGE from surya.layout import batch_layout_detection from surya.model.detection.model import load_model, load_processor +from surya.model.ordering.model import load_model as load_order_model +from surya.model.ordering.processor import load_processor as load_order_processor from surya.model.recognition.model import load_model as load_rec_model from surya.model.recognition.processor import load_processor as load_rec_processor -from surya.model.ordering.processor import load_processor as load_order_processor -from surya.model.ordering.model import load_model as load_order_model +from surya.ocr import run_ocr from surya.ordering import batch_ordering from surya.postprocessing.heatmap import draw_polys_on_image -from surya.ocr import run_ocr from surya.postprocessing.text import draw_text_on_image -from PIL import Image -from surya.languages import CODE_TO_LANGUAGE -from surya.input.langs import replace_lang_with_code -from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult +from surya.schema import LayoutResult, OCRResult, OrderResult, TextDetectionResult from surya.settings import settings + @st.cache_resource() def load_det_cached(): checkpoint = settings.DETECTOR_MODEL_CHECKPOINT @@ -181,4 +183,4 @@ def page_count(pdf_file): st.json(pred.model_dump(), expanded=True) with col2: - st.image(pil_image, caption="Uploaded Image", use_column_width=True) \ No newline at end of file + st.image(pil_image, caption="Uploaded Image", use_column_width=True) diff --git a/ocr_text.py b/ocr_text.py index e624c14..69a7ecb 100644 --- a/ocr_text.py +++ b/ocr_text.py @@ -1,15 +1,18 @@ -import os import argparse import json +import os from collections import defaultdict import torch -from surya.input.langs import replace_lang_with_code, get_unique_langs -from surya.input.load import load_from_folder, load_from_file, load_lang_file -from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor +from surya.input.langs import get_unique_langs, replace_lang_with_code +from surya.input.load import load_from_file, load_from_folder, load_lang_file +from surya.model.detection.model import load_model as load_detection_model +from surya.model.detection.model import load_processor as load_detection_processor from surya.model.recognition.model import load_model as load_recognition_model -from surya.model.recognition.processor import load_processor as load_recognition_processor +from surya.model.recognition.processor import ( + load_processor as load_recognition_processor, +) from surya.model.recognition.tokenizer import _tokenize from surya.ocr import run_ocr from surya.postprocessing.text import draw_text_on_image diff --git a/poetry.lock b/poetry.lock index 28087ff..ca55048 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2315,11 +2315,11 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -2385,8 +2385,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -2987,6 +2987,7 @@ python-versions = ">=3.8" files = [ {file = "PyMuPDFb-1.24.6-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:21e3ed890f736def68b9a031122ae1fb854d5cb9a53aa144b6e2ca3092416a6b"}, {file = "PyMuPDFb-1.24.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8704d2dfadc9448ce184597d8b0f9c30143e379ac948a517f9c4db7c0c71ed51"}, + {file = "PyMuPDFb-1.24.6-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:01662584d5cfa7a91f77585f13fc23a12291cfd76a57e0a28dd5a56bf521cb2c"}, {file = "PyMuPDFb-1.24.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1f7657353529ae3f88575c83ee49eac9adea311a034b9c97248a65cee7df0e5"}, {file = "PyMuPDFb-1.24.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cebc2cedb870d1e1168e2f502eb06f05938f6df69103b0853a2b329611ec19a7"}, {file = "PyMuPDFb-1.24.6-py3-none-win32.whl", hash = "sha256:ac4b865cd1e239db04674f85e02844a0e405f8255ee7a74dfee0d86aad0d3576"}, @@ -3144,7 +3145,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3701,6 +3701,33 @@ files = [ {file = "rpds_py-0.18.1.tar.gz", hash = "sha256:dc48b479d540770c811fbd1eb9ba2bb66951863e448efec2e2c102625328e92f"}, ] +[[package]] +name = "ruff" +version = "0.5.5" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.5.5-py3-none-linux_armv6l.whl", hash = "sha256:605d589ec35d1da9213a9d4d7e7a9c761d90bba78fc8790d1c5e65026c1b9eaf"}, + {file = "ruff-0.5.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:00817603822a3e42b80f7c3298c8269e09f889ee94640cd1fc7f9329788d7bf8"}, + {file = "ruff-0.5.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:187a60f555e9f865a2ff2c6984b9afeffa7158ba6e1eab56cb830404c942b0f3"}, + {file = "ruff-0.5.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe26fc46fa8c6e0ae3f47ddccfbb136253c831c3289bba044befe68f467bfb16"}, + {file = "ruff-0.5.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4ad25dd9c5faac95c8e9efb13e15803cd8bbf7f4600645a60ffe17c73f60779b"}, + {file = "ruff-0.5.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f70737c157d7edf749bcb952d13854e8f745cec695a01bdc6e29c29c288fc36e"}, + {file = "ruff-0.5.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:cfd7de17cef6ab559e9f5ab859f0d3296393bc78f69030967ca4d87a541b97a0"}, + {file = "ruff-0.5.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a09b43e02f76ac0145f86a08e045e2ea452066f7ba064fd6b0cdccb486f7c3e7"}, + {file = "ruff-0.5.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0b856cb19c60cd40198be5d8d4b556228e3dcd545b4f423d1ad812bfdca5884"}, + {file = "ruff-0.5.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3687d002f911e8a5faf977e619a034d159a8373514a587249cc00f211c67a091"}, + {file = "ruff-0.5.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ac9dc814e510436e30d0ba535f435a7f3dc97f895f844f5b3f347ec8c228a523"}, + {file = "ruff-0.5.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:af9bdf6c389b5add40d89b201425b531e0a5cceb3cfdcc69f04d3d531c6be74f"}, + {file = "ruff-0.5.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:d40a8533ed545390ef8315b8e25c4bb85739b90bd0f3fe1280a29ae364cc55d8"}, + {file = "ruff-0.5.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cab904683bf9e2ecbbe9ff235bfe056f0eba754d0168ad5407832928d579e7ab"}, + {file = "ruff-0.5.5-py3-none-win32.whl", hash = "sha256:696f18463b47a94575db635ebb4c178188645636f05e934fdf361b74edf1bb2d"}, + {file = "ruff-0.5.5-py3-none-win_amd64.whl", hash = "sha256:50f36d77f52d4c9c2f1361ccbfbd09099a1b2ea5d2b2222c586ab08885cf3445"}, + {file = "ruff-0.5.5-py3-none-win_arm64.whl", hash = "sha256:3191317d967af701f1b73a31ed5788795936e423b7acce82a2b63e26eb3e89d6"}, + {file = "ruff-0.5.5.tar.gz", hash = "sha256:cc5516bdb4858d972fbc31d246bdb390eab8df1a26e2353be2dbc0c2d7f5421a"}, +] + [[package]] name = "safetensors" version = "0.4.3" @@ -4819,4 +4846,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13,!=3.9.7" -content-hash = "d250e5223075069c0561f95e970624731feb7ddc20f1bc7b8ef6dd826a8f3085" +content-hash = "822edb2014af50155b4a3acb81385ff4823aac092aaef60be5c42761c87c3dce" diff --git a/pyproject.toml b/pyproject.toml index 26cd6bd..40a1e2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ rapidfuzz = "^3.6.1" arabic-reshaper = "^3.0.0" streamlit = "^1.31.0" playwright = "^1.41.2" +ruff = "^0.5.5" [tool.poetry.scripts] surya_detect = "detect_text:main" diff --git a/reading_order.py b/reading_order.py index cc30ad2..c12016d 100644 --- a/reading_order.py +++ b/reading_order.py @@ -1,13 +1,14 @@ -import os import argparse import copy import json +import os from collections import defaultdict from surya.detection import batch_text_detection -from surya.input.load import load_from_folder, load_from_file +from surya.input.load import load_from_file, load_from_folder from surya.layout import batch_layout_detection -from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor +from surya.model.detection.model import load_model as load_det_model +from surya.model.detection.model import load_processor as load_det_processor from surya.model.ordering.model import load_model from surya.model.ordering.processor import load_processor from surya.ordering import batch_ordering diff --git a/run_ocr_app.py b/run_ocr_app.py index 50d2840..9b1f35b 100644 --- a/run_ocr_app.py +++ b/run_ocr_app.py @@ -1,6 +1,6 @@ import argparse -import subprocess import os +import subprocess def run_app(): @@ -10,4 +10,4 @@ def run_app(): subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"}) if __name__ == "__main__": - run_app() \ No newline at end of file + run_app() diff --git a/scripts/verify_benchmark_scores.py b/scripts/verify_benchmark_scores.py index 4e8db55..5d2d8f8 100644 --- a/scripts/verify_benchmark_scores.py +++ b/scripts/verify_benchmark_scores.py @@ -1,5 +1,5 @@ -import json import argparse +import json def verify_layout(data): diff --git a/surya/benchmark/bbox.py b/surya/benchmark/bbox.py index b7593e8..653db60 100644 --- a/surya/benchmark/bbox.py +++ b/surya/benchmark/bbox.py @@ -1,4 +1,5 @@ import fitz as pymupdf + from surya.postprocessing.util import rescale_bbox @@ -19,4 +20,4 @@ def get_pdf_lines(pdf_path, img_sizes): line_boxes = [rescale_bbox(bbox, (pwidth, pheight), img_size) for bbox in line_boxes] page_lines.append(line_boxes) - return page_lines \ No newline at end of file + return page_lines diff --git a/surya/benchmark/metrics.py b/surya/benchmark/metrics.py index afcb417..29c786e 100644 --- a/surya/benchmark/metrics.py +++ b/surya/benchmark/metrics.py @@ -1,8 +1,9 @@ +from concurrent.futures import ProcessPoolExecutor from functools import partial from itertools import repeat import numpy as np -from concurrent.futures import ProcessPoolExecutor + def intersection_area(box1, box2): x_left = max(box1[0], box2[0]) @@ -136,4 +137,4 @@ def rank_accuracy(preds, references): if (i, j, ref > ref2) in pairs: correct += 1 - return correct / len(pairs) \ No newline at end of file + return correct / len(pairs) diff --git a/surya/benchmark/tesseract.py b/surya/benchmark/tesseract.py index a2d025e..f56da3d 100644 --- a/surya/benchmark/tesseract.py +++ b/surya/benchmark/tesseract.py @@ -1,3 +1,5 @@ +import os +from concurrent.futures import ProcessPoolExecutor from typing import List, Optional import numpy as np @@ -5,13 +7,11 @@ from pytesseract import Output from tqdm import tqdm -from surya.input.processing import slice_bboxes_from_image -from surya.settings import settings -import os -from concurrent.futures import ProcessPoolExecutor from surya.detection import get_batch_size as get_det_batch_size -from surya.recognition import get_batch_size as get_rec_batch_size +from surya.input.processing import slice_bboxes_from_image from surya.languages import CODE_TO_LANGUAGE +from surya.recognition import get_batch_size as get_rec_batch_size +from surya.settings import settings def surya_lang_to_tesseract(code: str) -> Optional[str]: diff --git a/surya/detection.py b/surya/detection.py index 08e9852..706d282 100644 --- a/surya/detection.py +++ b/surya/detection.py @@ -1,18 +1,23 @@ +from concurrent.futures import ProcessPoolExecutor from typing import List, Tuple -import torch import numpy as np +import torch +import torch.nn.functional as F from PIL import Image +from tqdm import tqdm +from surya.input.processing import ( + convert_if_not_rgb, + get_total_splits, + prepare_image_detection, + split_image, +) from surya.model.detection.model import EfficientViTForSemanticSegmentation -from surya.postprocessing.heatmap import get_and_clean_boxes from surya.postprocessing.affinity import get_vertical_lines -from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb +from surya.postprocessing.heatmap import get_and_clean_boxes from surya.schema import TextDetectionResult from surya.settings import settings -from tqdm import tqdm -from concurrent.futures import ProcessPoolExecutor -import torch.nn.functional as F def get_batch_size(): diff --git a/surya/input/langs.py b/surya/input/langs.py index e347408..780bacd 100644 --- a/surya/input/langs.py +++ b/surya/input/langs.py @@ -1,5 +1,6 @@ from typing import List -from surya.languages import LANGUAGE_TO_CODE, CODE_TO_LANGUAGE + +from surya.languages import CODE_TO_LANGUAGE, LANGUAGE_TO_CODE def replace_lang_with_code(langs: List[str]): @@ -16,4 +17,4 @@ def get_unique_langs(langs: List[List[str]]): for lang in lang_list: if lang not in uniques: uniques.append(lang) - return uniques \ No newline at end of file + return uniques diff --git a/surya/input/load.py b/surya/input/load.py index aa8f1a1..0014eaf 100644 --- a/surya/input/load.py +++ b/surya/input/load.py @@ -1,10 +1,11 @@ -import PIL - -from surya.input.processing import open_pdf, get_page_images +import json import os + import filetype +import PIL from PIL import Image -import json + +from surya.input.processing import get_page_images, open_pdf def get_name_from_path(path): diff --git a/surya/input/processing.py b/surya/input/processing.py index 9933279..968bfcd 100644 --- a/surya/input/processing.py +++ b/surya/input/processing.py @@ -1,11 +1,12 @@ +import math from typing import List import cv2 import numpy as np -import math import pypdfium2 -from PIL import Image, ImageOps, ImageDraw import torch +from PIL import Image, ImageDraw, ImageOps + from surya.settings import settings @@ -113,4 +114,4 @@ def slice_and_pad_poly(image_array: np.array, coordinates): cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE rectangle_image = Image.fromarray(cropped_polygon) - return rectangle_image \ No newline at end of file + return rectangle_image diff --git a/surya/layout.py b/surya/layout.py index 89f2a65..2e8b804 100644 --- a/surya/layout.py +++ b/surya/layout.py @@ -1,12 +1,17 @@ from collections import defaultdict from concurrent.futures import ProcessPoolExecutor from typing import List, Optional -from PIL import Image + import numpy as np +from PIL import Image from surya.detection import batch_detection -from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes -from surya.schema import LayoutResult, LayoutBox, TextDetectionResult +from surya.postprocessing.heatmap import ( + get_and_clean_boxes, + get_detected_boxes, + keep_largest_boxes, +) +from surya.schema import LayoutBox, LayoutResult, TextDetectionResult from surya.settings import settings @@ -201,4 +206,4 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op for future in futures: results.append(future.result()) - return results \ No newline at end of file + return results diff --git a/surya/model/detection/config.py b/surya/model/detection/config.py index bdbe0a1..e3b431c 100644 --- a/surya/model/detection/config.py +++ b/surya/model/detection/config.py @@ -48,4 +48,4 @@ def __init__( self.decoder_layer_hidden_size = decoder_layer_hidden_size self.semantic_loss_ignore_index = semantic_loss_ignore_index - self.initializer_range = initializer_range \ No newline at end of file + self.initializer_range = initializer_range diff --git a/surya/model/detection/model.py b/surya/model/detection/model.py index 7199a29..ca09d81 100644 --- a/surya/model/detection/model.py +++ b/surya/model/detection/model.py @@ -7,13 +7,12 @@ Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit """ -from typing import Optional, Union, Tuple from functools import partial +from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F - from transformers import PreTrainedModel from transformers.modeling_outputs import SemanticSegmenterOutput @@ -764,4 +763,4 @@ def forward( loss=None, logits=logits, hidden_states=encoder_hidden_states - ) \ No newline at end of file + ) diff --git a/surya/model/detection/processor.py b/surya/model/detection/processor.py index 822d7d1..c2d648c 100644 --- a/surya/model/detection/processor.py +++ b/surya/model/detection/processor.py @@ -2,8 +2,13 @@ from typing import Any, Dict, List, Optional, Union import numpy as np - -from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +import PIL.Image +import torch +from transformers.image_processing_utils import ( + BaseImageProcessor, + BatchFeature, + get_size_dict, +) from transformers.image_transforms import to_channel_dimension_format from transformers.image_utils import ( IMAGENET_DEFAULT_MEAN, @@ -17,10 +22,6 @@ from transformers.utils import TensorType -import PIL.Image -import torch - - class SegformerImageProcessor(BaseImageProcessor): r""" Constructs a Segformer image processor. @@ -281,4 +282,4 @@ def preprocess( ] data = {"pixel_values": images} - return BatchFeature(data=data, tensor_type=return_tensors) \ No newline at end of file + return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/surya/model/ordering/config.py b/surya/model/ordering/config.py index fcf20f7..69d08ab 100644 --- a/surya/model/ordering/config.py +++ b/surya/model/ordering/config.py @@ -1,8 +1,8 @@ -from transformers import MBartConfig, DonutSwinConfig +from transformers import DonutSwinConfig, MBartConfig class MBartOrderConfig(MBartConfig): pass class VariableDonutSwinConfig(DonutSwinConfig): - pass \ No newline at end of file + pass diff --git a/surya/model/ordering/decoder.py b/surya/model/ordering/decoder.py index 89fc3eb..3ef170f 100644 --- a/surya/model/ordering/decoder.py +++ b/surya/model/ordering/decoder.py @@ -1,15 +1,27 @@ import copy -from typing import Optional, List, Union, Tuple +import math +from typing import List, Optional, Tuple, Union -from transformers import MBartForCausalLM, MBartConfig +import torch from torch import nn +from transformers import MBartConfig, MBartForCausalLM from transformers.activations import ACT2FN -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask -from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions -from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder, MBartLearnedPositionalEmbedding, MBartDecoderLayer +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from transformers.models.mbart.modeling_mbart import ( + MBartDecoder, + MBartDecoderLayer, + MBartLearnedPositionalEmbedding, + MBartPreTrainedModel, +) + from surya.model.ordering.config import MBartOrderConfig -import torch -import math def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -554,4 +566,4 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, - ) \ No newline at end of file + ) diff --git a/surya/model/ordering/encoder.py b/surya/model/ordering/encoder.py index ff001b1..861f60b 100644 --- a/surya/model/ordering/encoder.py +++ b/surya/model/ordering/encoder.py @@ -1,15 +1,20 @@ -from torch import nn -import torch -from typing import Optional, Tuple, Union import collections import math +from typing import Optional, Tuple, Union +import torch +from torch import nn from transformers import DonutSwinPreTrainedModel -from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, \ - DonutSwinEncoder +from transformers.models.donut.modeling_donut_swin import ( + DonutSwinEmbeddings, + DonutSwinEncoder, + DonutSwinModel, + DonutSwinPatchEmbeddings, +) from surya.model.ordering.config import VariableDonutSwinConfig + class VariableDonutSwinEmbeddings(DonutSwinEmbeddings): """ Construct the patch and position embeddings. Optionally, also the mask token. @@ -80,4 +85,4 @@ def __init__(self, config, add_pooling_layer=True, use_mask_token=False, **kwarg self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None # Initialize weights and apply final processing - self.post_init() \ No newline at end of file + self.post_init() diff --git a/surya/model/ordering/encoderdecoder.py b/surya/model/ordering/encoderdecoder.py index f7351f1..ee4ea68 100644 --- a/surya/model/ordering/encoderdecoder.py +++ b/surya/model/ordering/encoderdecoder.py @@ -1,8 +1,8 @@ -from typing import Optional, Union, Tuple, List +from typing import List, Optional, Tuple, Union import torch from transformers import VisionEncoderDecoderModel -from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput +from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel): diff --git a/surya/model/ordering/model.py b/surya/model/ordering/model.py index 8c92fee..93bbab0 100644 --- a/surya/model/ordering/model.py +++ b/surya/model/ordering/model.py @@ -1,5 +1,12 @@ -from transformers import DetrConfig, BeitConfig, DetrImageProcessor, VisionEncoderDecoderConfig, AutoModelForCausalLM, \ - AutoModel +from transformers import ( + AutoModel, + AutoModelForCausalLM, + BeitConfig, + DetrConfig, + DetrImageProcessor, + VisionEncoderDecoderConfig, +) + from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig from surya.model.ordering.decoder import MBartOrder from surya.model.ordering.encoder import VariableDonutSwinModel @@ -31,4 +38,4 @@ def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH model = model.to(device) model = model.eval() print(f"Loaded reading order model {checkpoint} on device {device} with dtype {dtype}") - return model \ No newline at end of file + return model diff --git a/surya/model/ordering/processor.py b/surya/model/ordering/processor.py index c6f463b..0497123 100644 --- a/surya/model/ordering/processor.py +++ b/surya/model/ordering/processor.py @@ -1,15 +1,22 @@ from copy import deepcopy -from typing import Dict, Union, Optional, List, Tuple +from typing import Dict, List, Optional, Tuple, Union +import numpy as np +import PIL import torch +from PIL import Image from torch import TensorType from transformers import DonutImageProcessor, DonutProcessor from transformers.image_processing_utils import BatchFeature -from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, \ - valid_images, to_numpy_array -import numpy as np -from PIL import Image -import PIL +from transformers.image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + make_list_of_images, + to_numpy_array, + valid_images, +) + from surya.settings import settings @@ -153,4 +160,4 @@ def preprocess( "input_boxes_mask": box_mask, "input_boxes_counts": box_counts, } - return BatchFeature(data=data, tensor_type=return_tensors) \ No newline at end of file + return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/surya/model/recognition/config.py b/surya/model/recognition/config.py index 23d9bbf..cace7e6 100644 --- a/surya/model/recognition/config.py +++ b/surya/model/recognition/config.py @@ -1,4 +1,4 @@ -from transformers import T5Config, MBartConfig, DonutSwinConfig +from transformers import DonutSwinConfig, MBartConfig, T5Config class MBartMoEConfig(MBartConfig): @@ -108,4 +108,4 @@ class VariableDonutSwinConfig(DonutSwinConfig): 'xh': 90, 'yi': 91, 'zh': 92 -} \ No newline at end of file +} diff --git a/surya/model/recognition/decoder.py b/surya/model/recognition/decoder.py index fe2d4f4..a277d23 100644 --- a/surya/model/recognition/decoder.py +++ b/surya/model/recognition/decoder.py @@ -1,14 +1,18 @@ import copy -from typing import Optional, List, Union, Tuple +import math +from typing import List, Optional, Tuple, Union -from transformers import MBartForCausalLM, MBartConfig +import torch from torch import nn +from transformers import MBartConfig, MBartForCausalLM from transformers.activations import ACT2FN -from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions -from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from transformers.models.mbart.modeling_mbart import MBartDecoder, MBartPreTrainedModel + from surya.model.recognition.config import MBartMoEConfig -import torch -import math class MBartLearnedPositionalEmbedding(nn.Embedding): diff --git a/surya/model/recognition/encoder.py b/surya/model/recognition/encoder.py index f01f35c..4f9a5fa 100644 --- a/surya/model/recognition/encoder.py +++ b/surya/model/recognition/encoder.py @@ -1,10 +1,21 @@ -from torch import nn -import torch from typing import Optional, Tuple, Union -from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, \ - DonutSwinEncoder, DonutSwinModelOutput, DonutSwinEncoderOutput, DonutSwinAttention, DonutSwinDropPath, \ - DonutSwinIntermediate, DonutSwinOutput, window_partition, window_reverse +import torch +from torch import nn +from transformers.models.donut.modeling_donut_swin import ( + DonutSwinAttention, + DonutSwinDropPath, + DonutSwinEmbeddings, + DonutSwinEncoder, + DonutSwinEncoderOutput, + DonutSwinIntermediate, + DonutSwinModel, + DonutSwinModelOutput, + DonutSwinOutput, + DonutSwinPatchEmbeddings, + window_partition, + window_reverse, +) from surya.model.recognition.config import VariableDonutSwinConfig diff --git a/surya/model/recognition/model.py b/surya/model/recognition/model.py index 1ee2563..1cdcf80 100644 --- a/surya/model/recognition/model.py +++ b/surya/model/recognition/model.py @@ -5,13 +5,21 @@ warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated") import logging + logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) from typing import List, Optional, Tuple -from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, AutoModel, AutoModelForCausalLM + +from transformers import ( + AutoModel, + AutoModelForCausalLM, + VisionEncoderDecoderConfig, + VisionEncoderDecoderModel, +) + from surya.model.recognition.config import MBartMoEConfig, VariableDonutSwinConfig -from surya.model.recognition.encoder import VariableDonutSwinModel from surya.model.recognition.decoder import MBartMoE +from surya.model.recognition.encoder import VariableDonutSwinModel from surya.settings import settings diff --git a/surya/model/recognition/processor.py b/surya/model/recognition/processor.py index 1e5193a..551fc70 100644 --- a/surya/model/recognition/processor.py +++ b/surya/model/recognition/processor.py @@ -1,14 +1,21 @@ -from typing import Dict, Union, Optional, List, Iterable +from typing import Dict, Iterable, List, Optional, Union import cv2 +import numpy as np +import PIL +from PIL import Image from torch import TensorType from transformers import DonutImageProcessor, DonutProcessor from transformers.image_processing_utils import BatchFeature -from transformers.image_transforms import pad, normalize -from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, get_image_size -import numpy as np -from PIL import Image -import PIL +from transformers.image_transforms import normalize, pad +from transformers.image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + make_list_of_images, +) + from surya.model.recognition.tokenizer import Byt5LangTokenizer from surya.settings import settings @@ -213,4 +220,4 @@ def __call__(self, *args, **kwargs): else: inputs["labels"] = encodings["input_ids"] inputs["langs"] = encodings["langs"] - return inputs \ No newline at end of file + return inputs diff --git a/surya/model/recognition/tokenizer.py b/surya/model/recognition/tokenizer.py index 27c062c..bfca6fb 100644 --- a/surya/model/recognition/tokenizer.py +++ b/surya/model/recognition/tokenizer.py @@ -1,9 +1,11 @@ from itertools import chain from typing import List, Union -from transformers import ByT5Tokenizer + import numpy as np import torch -from surya.model.recognition.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET +from transformers import ByT5Tokenizer + +from surya.model.recognition.config import LANGUAGE_MAP, TOKEN_OFFSET, TOTAL_TOKENS def text_to_utf16_numbers(text): diff --git a/surya/ocr.py b/surya/ocr.py index 1744098..c270d62 100644 --- a/surya/ocr.py +++ b/surya/ocr.py @@ -1,11 +1,16 @@ from typing import List + from PIL import Image from surya.detection import batch_text_detection -from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image, convert_if_not_rgb +from surya.input.processing import ( + convert_if_not_rgb, + slice_bboxes_from_image, + slice_polys_from_image, +) from surya.postprocessing.text import sort_text_lines from surya.recognition import batch_recognition -from surya.schema import TextLine, OCRResult +from surya.schema import OCRResult, TextLine def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None, batch_size=None) -> List[OCRResult]: diff --git a/surya/ordering.py b/surya/ordering.py index 0b87ba1..c485c0c 100644 --- a/surya/ordering.py +++ b/surya/ordering.py @@ -1,14 +1,15 @@ from copy import deepcopy from typing import List + +import numpy as np import torch from PIL import Image +from tqdm import tqdm from surya.input.processing import convert_if_not_rgb from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel from surya.schema import OrderBox, OrderResult from surya.settings import settings -from tqdm import tqdm -import numpy as np def get_batch_size(): diff --git a/surya/postprocessing/affinity.py b/surya/postprocessing/affinity.py index 4cb538c..2c43e9a 100644 --- a/surya/postprocessing/affinity.py +++ b/surya/postprocessing/affinity.py @@ -2,7 +2,6 @@ import cv2 import numpy as np - from PIL import Image, ImageDraw from surya.postprocessing.util import get_line_angle, rescale_bbox @@ -162,4 +161,4 @@ def get_vertical_lines(image, processor_size, image_size, divisor=20, x_toleranc # Always start with top left of page vertical_lines[0].bbox[1] = 0 - return vertical_lines \ No newline at end of file + return vertical_lines diff --git a/surya/postprocessing/fonts.py b/surya/postprocessing/fonts.py index e9e1878..fc476bb 100644 --- a/surya/postprocessing/fonts.py +++ b/surya/postprocessing/fonts.py @@ -1,5 +1,6 @@ -from typing import List, Optional import os +from typing import List, Optional + import requests from surya.settings import settings @@ -21,4 +22,4 @@ def get_font_path(langs: Optional[List[str]] = None) -> str: for chunk in r.iter_content(chunk_size=8192): f.write(chunk) - return font_path \ No newline at end of file + return font_path diff --git a/surya/postprocessing/heatmap.py b/surya/postprocessing/heatmap.py index 9cc14cb..52d8558 100644 --- a/surya/postprocessing/heatmap.py +++ b/surya/postprocessing/heatmap.py @@ -1,15 +1,15 @@ +import math from typing import List, Tuple -import numpy as np import cv2 -import math +import numpy as np from PIL import ImageDraw, ImageFont from surya.postprocessing.fonts import get_font_path +from surya.postprocessing.text import get_text_size from surya.postprocessing.util import rescale_bbox from surya.schema import PolygonBox from surya.settings import settings -from surya.postprocessing.text import get_text_size def keep_largest_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]: diff --git a/surya/postprocessing/math/latex.py b/surya/postprocessing/math/latex.py index b07e5fb..557da4c 100644 --- a/surya/postprocessing/math/latex.py +++ b/surya/postprocessing/math/latex.py @@ -1,4 +1,5 @@ import re + from ftfy import fix_text diff --git a/surya/postprocessing/math/render.py b/surya/postprocessing/math/render.py index 761334a..aeabe5a 100644 --- a/surya/postprocessing/math/render.py +++ b/surya/postprocessing/math/render.py @@ -1,7 +1,8 @@ -from playwright.sync_api import sync_playwright -from PIL import Image import io +from PIL import Image +from playwright.sync_api import sync_playwright + def latex_to_pil(latex_code, target_width, target_height, fontsize=18): html_template = """ @@ -85,4 +86,4 @@ def latex_to_pil(latex_code, target_width, target_height, fontsize=18): image_stream = io.BytesIO(screenshot_bytes) pil_image = Image.open(image_stream) pil_image.load() - return pil_image \ No newline at end of file + return pil_image diff --git a/surya/postprocessing/text.py b/surya/postprocessing/text.py index fea9c3e..8d62312 100644 --- a/surya/postprocessing/text.py +++ b/surya/postprocessing/text.py @@ -5,9 +5,9 @@ from PIL import Image, ImageDraw, ImageFont from surya.postprocessing.fonts import get_font_path +from surya.postprocessing.math.latex import is_latex from surya.schema import TextLine from surya.settings import settings -from surya.postprocessing.math.latex import is_latex def sort_text_lines(lines: List[TextLine], tolerance=1.25): diff --git a/surya/postprocessing/util.py b/surya/postprocessing/util.py index 3da0e9b..67d1632 100644 --- a/surya/postprocessing/util.py +++ b/surya/postprocessing/util.py @@ -1,5 +1,5 @@ -import math import copy +import math def get_line_angle(x1, y1, x2, y2): @@ -41,4 +41,4 @@ def rescale_point(point, processor_size, image_size): def rescale_points(points, processor_size, image_size): - return [rescale_point(point, processor_size, image_size) for point in points] \ No newline at end of file + return [rescale_point(point, processor_size, image_size) for point in points] diff --git a/surya/recognition.py b/surya/recognition.py index 883ebfe..042a44f 100644 --- a/surya/recognition.py +++ b/surya/recognition.py @@ -1,14 +1,15 @@ from typing import List + +import numpy as np import torch +import torch.nn.functional as F from PIL import Image +from tqdm import tqdm from surya.input.processing import convert_if_not_rgb -from surya.postprocessing.math.latex import fix_math, contains_math +from surya.postprocessing.math.latex import contains_math, fix_math from surya.postprocessing.text import truncate_repetitions from surya.settings import settings -from tqdm import tqdm -import numpy as np -import torch.nn.functional as F def get_batch_size(): diff --git a/surya/schema.py b/surya/schema.py index 129f991..d5e4e48 100644 --- a/surya/schema.py +++ b/surya/schema.py @@ -1,7 +1,7 @@ import copy -from typing import List, Tuple, Any, Optional +from typing import Any, List, Optional, Tuple -from pydantic import BaseModel, field_validator, computed_field +from pydantic import BaseModel, computed_field, field_validator from surya.postprocessing.util import rescale_bbox diff --git a/surya/settings.py b/surya/settings.py index 9e47ae8..4ad742a 100644 --- a/surya/settings.py +++ b/surya/settings.py @@ -1,10 +1,10 @@ +import os from typing import Dict, Optional +import torch from dotenv import find_dotenv from pydantic import computed_field from pydantic_settings import BaseSettings -import torch -import os class Settings(BaseSettings): @@ -83,4 +83,4 @@ class Config: extra = "ignore" -settings = Settings() \ No newline at end of file +settings = Settings()