From 1de6d9193ddc9a88a869d78a4d82ea3381544613 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Mon, 6 Jan 2025 17:19:17 -0500 Subject: [PATCH] Refactor recognition model --- benchmark/recognition.py | 13 +- ocr_app.py | 18 +- ocr_text.py | 17 +- surya/benchmark/tesseract.py | 10 +- surya/common/donut/encoder.py | 2 +- surya/common/load.py | 21 + surya/common/predictor.py | 15 +- surya/detection/__init__.py | 40 +- surya/detection/loader.py | 43 +++ .../model/__init__.py} | 0 surya/detection/{ => model}/config.py | 0 .../{model.py => model/encoderdecoder.py} | 2 +- surya/input/langs.py | 19 - surya/languages.py | 102 ----- surya/model/recognition/model.py | 60 --- surya/model/table_rec/encoder.py | 2 +- surya/ocr.py | 114 ------ surya/postprocessing/math/latex.py | 125 ------ surya/postprocessing/math/render.py | 88 ----- surya/postprocessing/text.py | 75 +--- surya/recognition.py | 191 --------- surya/recognition/__init__.py | 364 ++++++++++++++++++ surya/recognition/languages.py | 111 ++++++ surya/recognition/loader.py | 62 +++ surya/recognition/model/__init__.py | 0 .../model}/config.py | 0 .../model}/decoder.py | 2 +- .../model}/encoder.py | 0 .../model}/encoderdecoder.py | 22 +- surya/recognition/postprocessing.py | 29 ++ surya/{model => }/recognition/processor.py | 8 +- surya/{model => }/recognition/tokenizer.py | 2 +- surya/recognition/util.py | 22 ++ 33 files changed, 699 insertions(+), 880 deletions(-) create mode 100644 surya/common/load.py create mode 100644 surya/detection/loader.py rename surya/{postprocessing/affinity.py => detection/model/__init__.py} (100%) rename surya/detection/{ => model}/config.py (100%) rename surya/detection/{model.py => model/encoderdecoder.py} (99%) delete mode 100644 surya/input/langs.py delete mode 100644 surya/model/recognition/model.py delete mode 100644 surya/ocr.py delete mode 100644 surya/postprocessing/math/latex.py delete mode 100644 surya/postprocessing/math/render.py delete mode 100644 surya/recognition.py create mode 100644 surya/recognition/__init__.py create mode 100644 surya/recognition/languages.py create mode 100644 surya/recognition/loader.py create mode 100644 surya/recognition/model/__init__.py rename surya/{model/recognition => recognition/model}/config.py (100%) rename surya/{model/recognition => recognition/model}/decoder.py (98%) rename surya/{model/recognition => recognition/model}/encoder.py (100%) rename surya/{model/recognition => recognition/model}/encoderdecoder.py (82%) create mode 100644 surya/recognition/postprocessing.py rename surya/{model => }/recognition/processor.py (89%) rename surya/{model => }/recognition/tokenizer.py (98%) create mode 100644 surya/recognition/util.py diff --git a/benchmark/recognition.py b/benchmark/recognition.py index ece579ed..2c80c488 100644 --- a/benchmark/recognition.py +++ b/benchmark/recognition.py @@ -3,12 +3,10 @@ from benchmark.scoring import overlap_score from surya.input.processing import convert_if_not_rgb -from surya.model.recognition.model import load_model as load_recognition_model -from surya.model.recognition.processor import load_processor as load_recognition_processor -from surya.ocr import run_recognition from surya.postprocessing.text import draw_text_on_image +from surya.recognition import RecognitionPredictor from surya.settings import settings -from surya.languages import CODE_TO_LANGUAGE +from surya.recognition.languages import CODE_TO_LANGUAGE from surya.benchmark.tesseract import tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE import os 
import datasets @@ -30,8 +28,7 @@ def main(): parser.add_argument("--specify_language", action="store_true", help="Pass language codes into the model.", default=False) args = parser.parse_args() - rec_model = load_recognition_model() - rec_processor = load_recognition_processor() + rec_predictor = RecognitionPredictor() split = "train" if args.max: @@ -61,10 +58,10 @@ def main(): if settings.RECOGNITION_STATIC_CACHE: # Run through one batch to compile the model - run_recognition(images[:1], lang_list[:1], rec_model, rec_processor, bboxes=bboxes[:1]) + rec_predictor(images[:1], lang_list[:1], bboxes=bboxes[:1]) start = time.time() - predictions_by_image = run_recognition(images, lang_list if args.specify_language else n_list, rec_model, rec_processor, bboxes=bboxes) + predictions_by_image = rec_predictor(images, lang_list if args.specify_language else n_list, bboxes=bboxes) surya_time = time.time() - start surya_scores = defaultdict(list) diff --git a/ocr_app.py b/ocr_app.py index b7d5e8fd..ba8bb8ca 100644 --- a/ocr_app.py +++ b/ocr_app.py @@ -7,19 +7,18 @@ from surya.layout import batch_layout_detection from surya.detection import DetectionPredictor +from surya.recognition import RecognitionPredictor + from surya.model.layout.model import load_model as load_layout_model from surya.model.layout.processor import load_processor as load_layout_processor -from surya.model.recognition.model import load_model as load_rec_model -from surya.model.recognition.processor import load_processor as load_rec_processor from surya.model.table_rec.model import load_model as load_table_model from surya.model.table_rec.processor import load_processor as load_table_processor from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image -from surya.ocr import run_ocr + from surya.postprocessing.text import draw_text_on_image from PIL import Image -from surya.languages import CODE_TO_LANGUAGE -from surya.input.langs import replace_lang_with_code +from surya.recognition.languages import CODE_TO_LANGUAGE, replace_lang_with_code from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult from surya.settings import settings from surya.tables import batch_table_recognition @@ -32,17 +31,14 @@ def load_det_cached(): return DetectionPredictor() - @st.cache_resource() def load_rec_cached(): - return load_rec_model(), load_rec_processor() - + return RecognitionPredictor() @st.cache_resource() def load_layout_cached(): return load_layout_model(), load_layout_processor() - @st.cache_resource() def load_table_cached(): return load_table_model(), load_table_processor() @@ -139,7 +135,7 @@ def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Im # Function for OCR def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult): replace_lang_with_code(langs) - img_pred = run_ocr([img], [langs], det_predictor, rec_model, rec_processor, highres_images=[highres_img])[0] + img_pred = recognition_predictor([img], [langs], det_predictor, highres_images=[highres_img])[0] bboxes = [l.bbox for l in img_pred.text_lines] text = [l.text for l in img_pred.text_lines] @@ -178,7 +174,7 @@ def page_counter(pdf_file): col1, col2 = st.columns([.5, .5]) det_predictor = load_det_cached() -rec_model, rec_processor = load_rec_cached() +recognition_predictor = load_rec_cached() layout_model, layout_processor = load_layout_cached() table_model, table_processor 
= load_table_cached() ocr_error_model, ocr_error_processor = load_ocr_error_cached() diff --git a/ocr_text.py b/ocr_text.py index 99a1541f..15345060 100644 --- a/ocr_text.py +++ b/ocr_text.py @@ -4,13 +4,11 @@ import time from collections import defaultdict -from surya.input.langs import replace_lang_with_code +from surya.detection import DetectionPredictor +from surya.recognition.languages import replace_lang_with_code from surya.input.load import load_from_folder, load_from_file, load_lang_file -from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor -from surya.model.recognition.model import load_model as load_recognition_model -from surya.model.recognition.processor import load_processor as load_recognition_processor -from surya.ocr import run_ocr from surya.postprocessing.text import draw_text_on_image +from surya.recognition import RecognitionPredictor from surya.settings import settings @@ -49,17 +47,14 @@ def main(): else: image_langs = [None] * len(images) - det_processor = load_detection_processor() - det_model = load_detection_model() - - rec_model = load_recognition_model() - rec_processor = load_recognition_processor() + det_predictor = DetectionPredictor() + rec_predictor = RecognitionPredictor() result_path = os.path.join(args.results_dir, folder_name) os.makedirs(result_path, exist_ok=True) start = time.time() - predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor, highres_images=highres_images) + predictions_by_image = rec_predictor(images, image_langs, det_predictor=det_predictor, highres_images=highres_images) if args.debug: print(f"OCR took {time.time() - start:.2f} seconds") max_chars = max([len(l.text) for p in predictions_by_image for l in p.text_lines]) diff --git a/surya/benchmark/tesseract.py b/surya/benchmark/tesseract.py index 140d46c9..3dc5c7ee 100644 --- a/surya/benchmark/tesseract.py +++ b/surya/benchmark/tesseract.py @@ -7,9 +7,9 @@ from surya.settings import settings import os from concurrent.futures import ProcessPoolExecutor -from surya.detection import get_batch_size as get_det_batch_size -from surya.recognition import get_batch_size as get_rec_batch_size -from surya.languages import CODE_TO_LANGUAGE +from surya.recognition.languages import CODE_TO_LANGUAGE +from surya.recognition import RecognitionPredictor +from surya.detection import DetectionPredictor def surya_lang_to_tesseract(code: str) -> Optional[str]: @@ -33,7 +33,7 @@ def tesseract_ocr(img, bboxes, lang: str): def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None): - tess_parallel_cores = min(len(imgs), get_rec_batch_size()) + tess_parallel_cores = min(len(imgs), RecognitionPredictor.get_batch_size()) if not cpus: cpus = os.cpu_count() tess_parallel_cores = min(tess_parallel_cores, cpus) @@ -67,7 +67,7 @@ def tesseract_bboxes(img): def tesseract_parallel(imgs): # Tesseract uses 4 threads per instance - tess_parallel_cores = min(len(imgs), get_det_batch_size()) + tess_parallel_cores = min(len(imgs), DetectionPredictor.get_batch_size()) cpus = os.cpu_count() tess_parallel_cores = min(tess_parallel_cores, cpus) diff --git a/surya/common/donut/encoder.py b/surya/common/donut/encoder.py index a197e0bf..bf03d254 100644 --- a/surya/common/donut/encoder.py +++ b/surya/common/donut/encoder.py @@ -11,7 +11,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from 
transformers.utils import ModelOutput -from surya.model.recognition.config import DonutSwinConfig +from transformers import DonutSwinConfig _EXPECTED_OUTPUT_SHAPE = [1, 49, 1024] diff --git a/surya/common/load.py b/surya/common/load.py new file mode 100644 index 00000000..2b1ecec7 --- /dev/null +++ b/surya/common/load.py @@ -0,0 +1,21 @@ +from typing import Optional, Any + +import torch + +from surya.settings import settings + + +class ModelLoader: + def __init__(self, checkpoint: Optional[str] = None): + self.checkpoint = checkpoint + + def model( + self, + device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, + dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE) -> Any: + raise NotImplementedError() + + def processor( + self + ) -> Any: + raise NotImplementedError() \ No newline at end of file diff --git a/surya/common/predictor.py b/surya/common/predictor.py index db78ff84..23d39d21 100644 --- a/surya/common/predictor.py +++ b/surya/common/predictor.py @@ -1,21 +1,19 @@ from typing import Optional import torch +from surya.common.load import ModelLoader from surya.settings import settings class BasePredictor: + model_loader_cls = ModelLoader def __init__(self, checkpoint: Optional[str] = None, device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE): self.model = None self.processor = None - self.load_model(checkpoint, device, dtype) - self.load_processor(checkpoint) + loader = self.model_loader_cls(checkpoint) - def load_model(self, checkpoint: Optional[str] = None, device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE): - raise NotImplementedError() - - def load_processor(self, checkpoint: Optional[str] = None): - raise NotImplementedError() + self.model = loader.model(device, dtype) + self.processor = loader.processor() def to(self, device_dtype: torch.device | str | None = None): if self.model: @@ -23,7 +21,8 @@ def to(self, device_dtype: torch.device | str | None = None): else: raise ValueError("Model not loaded") - def get_batch_size(self): + @staticmethod + def get_batch_size(): raise NotImplementedError() def __call__(self, *args, **kwargs): diff --git a/surya/detection/__init__.py b/surya/detection/__init__.py index 97a4bcf0..94d65eb6 100644 --- a/surya/detection/__init__.py +++ b/surya/detection/__init__.py @@ -9,8 +9,8 @@ from tqdm import tqdm from surya.common.predictor import BasePredictor -from surya.detection.config import EfficientViTConfig -from surya.detection.model import EfficientViTForSemanticSegmentation + +from surya.detection.loader import DetectionModelLoader from surya.detection.parallel import FakeExecutor from surya.detection.processor import SegformerImageProcessor from surya.detection.util import get_total_splits, split_image @@ -20,6 +20,8 @@ class DetectionPredictor(BasePredictor): + model_loader_cls = DetectionModelLoader + def __call__(self, images: List[Image.Image], batch_size=None, include_maps=False) -> List[TextDetectionResult]: detection_generator = self.batch_detection(images, batch_size=batch_size, static_cache=settings.DETECTOR_STATIC_CACHE) @@ -34,38 +36,8 @@ def __call__(self, images: List[Image.Image], batch_size=None, include_maps=Fals return [future.result() for future in postprocessing_futures] - def load_model( - self, - checkpoint: Optional[str] = None, - device: Optional[torch.device | str] = None, - dtype: Optional[torch.dtype | str] = None - ): - if checkpoint is None: - checkpoint = 
settings.DETECTOR_MODEL_CHECKPOINT - config = EfficientViTConfig.from_pretrained(checkpoint) - model = EfficientViTForSemanticSegmentation.from_pretrained(checkpoint, torch_dtype=dtype, config=config, - ignore_mismatched_sizes=True) - model = model.to(device) - model = model.eval() - - if settings.DETECTOR_STATIC_CACHE: - torch.set_float32_matmul_precision('high') - torch._dynamo.config.cache_size_limit = 1 - torch._dynamo.config.suppress_errors = False - - print(f"Compiling detection model {checkpoint} on device {device} with dtype {dtype}") - model = torch.compile(model) - - print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}") - self.model = model - - def load_processor(self, checkpoint: Optional[str] = None): - if checkpoint is None: - checkpoint = settings.DETECTOR_MODEL_CHECKPOINT - - self.processor = SegformerImageProcessor.from_pretrained(checkpoint) - - def get_batch_size(self): + @staticmethod + def get_batch_size(): batch_size = settings.DETECTOR_BATCH_SIZE if batch_size is None: batch_size = 8 diff --git a/surya/detection/loader.py b/surya/detection/loader.py new file mode 100644 index 00000000..63452a99 --- /dev/null +++ b/surya/detection/loader.py @@ -0,0 +1,43 @@ +from typing import Optional + +import torch + +from surya.common.load import ModelLoader +from surya.detection.processor import SegformerImageProcessor + +from surya.detection.model.config import EfficientViTConfig +from surya.detection.model.encoderdecoder import EfficientViTForSemanticSegmentation +from surya.settings import settings + + +class DetectionModelLoader(ModelLoader): + def __init__(self, checkpoint: Optional[str] = None): + super().__init__(checkpoint) + + if self.checkpoint is None: + self.checkpoint = settings.DETECTOR_MODEL_CHECKPOINT + + def model( + self, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype | str] = None + ): + config = EfficientViTConfig.from_pretrained(self.checkpoint) + model = EfficientViTForSemanticSegmentation.from_pretrained(self.checkpoint, torch_dtype=dtype, config=config, + ignore_mismatched_sizes=True) + model = model.to(device) + model = model.eval() + + if settings.DETECTOR_STATIC_CACHE: + torch.set_float32_matmul_precision('high') + torch._dynamo.config.cache_size_limit = 1 + torch._dynamo.config.suppress_errors = False + + print(f"Compiling detection model {self.checkpoint} on device {device} with dtype {dtype}") + model = torch.compile(model) + + print(f"Loaded detection model {self.checkpoint} on device {device} with dtype {dtype}") + return model + + def processor(self): + return SegformerImageProcessor.from_pretrained(self.checkpoint) \ No newline at end of file diff --git a/surya/postprocessing/affinity.py b/surya/detection/model/__init__.py similarity index 100% rename from surya/postprocessing/affinity.py rename to surya/detection/model/__init__.py diff --git a/surya/detection/config.py b/surya/detection/model/config.py similarity index 100% rename from surya/detection/config.py rename to surya/detection/model/config.py diff --git a/surya/detection/model.py b/surya/detection/model/encoderdecoder.py similarity index 99% rename from surya/detection/model.py rename to surya/detection/model/encoderdecoder.py index f17fa965..8a2b21ae 100644 --- a/surya/detection/model.py +++ b/surya/detection/model/encoderdecoder.py @@ -18,7 +18,7 @@ from transformers import PreTrainedModel from transformers.modeling_outputs import SemanticSegmenterOutput -from surya.detection.config import EfficientViTConfig +from 
surya.detection.model.config import EfficientViTConfig def val2list(x: Union[List, Tuple, Any], repeat_time=1): diff --git a/surya/input/langs.py b/surya/input/langs.py deleted file mode 100644 index e347408f..00000000 --- a/surya/input/langs.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import List -from surya.languages import LANGUAGE_TO_CODE, CODE_TO_LANGUAGE - - -def replace_lang_with_code(langs: List[str]): - for i in range(len(langs)): - if langs[i].title() in LANGUAGE_TO_CODE: - langs[i] = LANGUAGE_TO_CODE[langs[i].title()] - if langs[i] not in CODE_TO_LANGUAGE: - raise ValueError(f"Language code {langs[i]} not found.") - - -def get_unique_langs(langs: List[List[str]]): - uniques = [] - for lang_list in langs: - for lang in lang_list: - if lang not in uniques: - uniques.append(lang) - return uniques \ No newline at end of file diff --git a/surya/languages.py b/surya/languages.py index d4bfbd48..e69de29b 100644 --- a/surya/languages.py +++ b/surya/languages.py @@ -1,102 +0,0 @@ -CODE_TO_LANGUAGE = { - "_math": "Math", - 'af': 'Afrikaans', - 'am': 'Amharic', - 'ar': 'Arabic', - 'as': 'Assamese', - 'az': 'Azerbaijani', - 'be': 'Belarusian', - 'bg': 'Bulgarian', - 'bn': 'Bengali', - 'br': 'Breton', - 'bs': 'Bosnian', - 'ca': 'Catalan', - 'cs': 'Czech', - 'cy': 'Welsh', - 'da': 'Danish', - 'de': 'German', - 'el': 'Greek', - 'en': 'English', - 'eo': 'Esperanto', - 'es': 'Spanish', - 'et': 'Estonian', - 'eu': 'Basque', - 'fa': 'Persian', - 'fi': 'Finnish', - 'fr': 'French', - 'fy': 'Western Frisian', - 'ga': 'Irish', - 'gd': 'Scottish Gaelic', - 'gl': 'Galician', - 'gu': 'Gujarati', - 'ha': 'Hausa', - 'he': 'Hebrew', - 'hi': 'Hindi', - 'hr': 'Croatian', - 'hu': 'Hungarian', - 'hy': 'Armenian', - 'id': 'Indonesian', - 'is': 'Icelandic', - 'it': 'Italian', - 'ja': 'Japanese', - 'jv': 'Javanese', - 'ka': 'Georgian', - 'kk': 'Kazakh', - 'km': 'Khmer', - 'kn': 'Kannada', - 'ko': 'Korean', - 'ku': 'Kurdish', - 'ky': 'Kyrgyz', - 'la': 'Latin', - 'lo': 'Lao', - 'lt': 'Lithuanian', - 'lv': 'Latvian', - 'mg': 'Malagasy', - 'mk': 'Macedonian', - 'ml': 'Malayalam', - 'mn': 'Mongolian', - 'mr': 'Marathi', - 'ms': 'Malay', - 'my': 'Burmese', - 'ne': 'Nepali', - 'nl': 'Dutch', - 'no': 'Norwegian', - 'om': 'Oromo', - 'or': 'Oriya', - 'pa': 'Punjabi', - 'pl': 'Polish', - 'ps': 'Pashto', - 'pt': 'Portuguese', - 'ro': 'Romanian', - 'ru': 'Russian', - 'sa': 'Sanskrit', - 'sd': 'Sindhi', - 'si': 'Sinhala', - 'sk': 'Slovak', - 'sl': 'Slovenian', - 'so': 'Somali', - 'sq': 'Albanian', - 'sr': 'Serbian', - 'su': 'Sundanese', - 'sv': 'Swedish', - 'sw': 'Swahili', - 'ta': 'Tamil', - 'te': 'Telugu', - 'th': 'Thai', - 'tl': 'Tagalog', - 'tr': 'Turkish', - 'ug': 'Uyghur', - 'uk': 'Ukrainian', - 'ur': 'Urdu', - 'uz': 'Uzbek', - 'vi': 'Vietnamese', - 'xh': 'Xhosa', - 'yi': 'Yiddish', - 'zh': 'Chinese', -} - -LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()} - - -def is_arabic(lang_code): - return lang_code in ["ar", "fa", "ps", "ug", "ur"] diff --git a/surya/model/recognition/model.py b/surya/model/recognition/model.py deleted file mode 100644 index 86c81024..00000000 --- a/surya/model/recognition/model.py +++ /dev/null @@ -1,60 +0,0 @@ -import warnings - -import torch - -warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated") - -import logging -logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) - -from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel -from surya.model.recognition.config import DonutSwinConfig, 
SuryaOCRConfig, SuryaOCRDecoderConfig, SuryaOCRTextEncoderConfig -from surya.model.recognition.encoder import DonutSwinModel -from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder -from surya.settings import settings - -torch.backends.cuda.enable_cudnn_sdp(settings.ENABLE_CUDNN_ATTENTION) -if not settings.ENABLE_EFFICIENT_ATTENTION: - print("Efficient attention is disabled. This will use significantly more VRAM.") - torch.backends.cuda.enable_mem_efficient_sdp(False) - torch.backends.cuda.enable_flash_sdp(False) - torch.backends.cuda.enable_math_sdp(True) - - -def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE) -> OCREncoderDecoderModel: - - config = SuryaOCRConfig.from_pretrained(checkpoint) - decoder_config = config.decoder - decoder = SuryaOCRDecoderConfig(**decoder_config) - config.decoder = decoder - - encoder_config = config.encoder - encoder = DonutSwinConfig(**encoder_config) - config.encoder = encoder - - text_encoder_config = config.text_encoder - text_encoder = SuryaOCRTextEncoderConfig(**text_encoder_config) - config.text_encoder = text_encoder - - model = OCREncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype) - - assert isinstance(model.decoder, SuryaOCRDecoder) - assert isinstance(model.encoder, DonutSwinModel) - assert isinstance(model.text_encoder, SuryaOCRTextEncoder) - - model = model.to(device) - model = model.eval() - - if settings.RECOGNITION_STATIC_CACHE: - torch.set_float32_matmul_precision('high') - torch._dynamo.config.cache_size_limit = 16 - torch._dynamo.config.suppress_errors = False - - - print(f"Compiling recognition model {checkpoint} on device {device} with dtype {dtype}") - model.encoder = torch.compile(model.encoder) - model.decoder = torch.compile(model.decoder) - model.text_encoder = torch.compile(model.text_encoder) - - print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}") - return model \ No newline at end of file diff --git a/surya/model/table_rec/encoder.py b/surya/model/table_rec/encoder.py index 513876bb..4bfb75c0 100644 --- a/surya/model/table_rec/encoder.py +++ b/surya/model/table_rec/encoder.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from typing import Optional, Union, Tuple -from surya.model.recognition.encoder import DonutSwinPreTrainedModel, DonutSwinModelOutput, DonutSwinEmbeddings, DonutSwinEncoder +from surya.common.donut.encoder import DonutSwinPreTrainedModel, DonutSwinModelOutput, DonutSwinEmbeddings, DonutSwinEncoder class DonutSwinModel(DonutSwinPreTrainedModel): diff --git a/surya/ocr.py b/surya/ocr.py deleted file mode 100644 index 1e59d933..00000000 --- a/surya/ocr.py +++ /dev/null @@ -1,114 +0,0 @@ -from copy import deepcopy -from typing import List -from PIL import Image - -from surya.detection import DetectionPredictor -from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image, convert_if_not_rgb -from surya.postprocessing.text import sort_text_lines -from surya.recognition import batch_recognition -from surya.schema import TextLine, OCRResult - - -def run_recognition(images: List[Image.Image], langs: List[List[str] | None], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None, batch_size=None) -> List[OCRResult]: - # Polygons need to be in corner format - [[x1, y1], [x2, y2], [x3, y3], [x4, y4]], bboxes in [x1, y1, x2, y2] format - assert bboxes is not None or polygons is 
not None - assert len(images) == len(langs), "You need to pass in one list of languages for each image" - - images = convert_if_not_rgb(images) - - slice_map = [] - all_slices = [] - all_langs = [] - for idx, (image, lang) in enumerate(zip(images, langs)): - if polygons is not None: - slices = slice_polys_from_image(image, polygons[idx]) - else: - slices = slice_bboxes_from_image(image, bboxes[idx]) - slice_map.append(len(slices)) - all_slices.extend(slices) - all_langs.extend([deepcopy(lang)] * len(slices)) - - rec_predictions, _ = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=batch_size) - - predictions_by_image = [] - slice_start = 0 - for idx, (image, lang) in enumerate(zip(images, langs)): - slice_end = slice_start + slice_map[idx] - image_lines = rec_predictions[slice_start:slice_end] - slice_start = slice_end - - text_lines = [] - for i in range(len(image_lines)): - if polygons is not None: - poly = polygons[idx][i] - else: - bbox = bboxes[idx][i] - poly = [[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]] - - text_lines.append(TextLine( - text=image_lines[i], - polygon=poly - )) - - pred = OCRResult( - text_lines=text_lines, - languages=lang, - image_bbox=[0, 0, image.size[0], image.size[1]] - ) - predictions_by_image.append(pred) - - return predictions_by_image - - -def run_ocr(images: List[Image.Image], langs: List[List[str] | None], det_predictor: DetectionPredictor, rec_model, rec_processor, detection_batch_size=None, recognition_batch_size=None, highres_images: List[Image.Image] | None = None) -> List[OCRResult]: - images = convert_if_not_rgb(images) - highres_images = convert_if_not_rgb(highres_images) if highres_images is not None else [None] * len(images) - det_predictions = det_predictor(images, batch_size=detection_batch_size) - - all_slices = [] - slice_map = [] - all_langs = [] - - for idx, (det_pred, image, highres_image, lang) in enumerate(zip(det_predictions, images, highres_images, langs)): - polygons = [p.polygon for p in det_pred.bboxes] - if highres_image: - width_scaler = highres_image.size[0] / image.size[0] - height_scaler = highres_image.size[1] / image.size[1] - scaled_polygons = [[[int(p[0] * width_scaler), int(p[1] * height_scaler)] for p in polygon] for polygon in polygons] - slices = slice_polys_from_image(highres_image, scaled_polygons) - else: - slices = slice_polys_from_image(image, polygons) - slice_map.append(len(slices)) - all_langs.extend([lang] * len(slices)) - all_slices.extend(slices) - - rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=recognition_batch_size) - - predictions_by_image = [] - slice_start = 0 - for idx, (image, det_pred, lang) in enumerate(zip(images, det_predictions, langs)): - slice_end = slice_start + slice_map[idx] - image_lines = rec_predictions[slice_start:slice_end] - line_confidences = confidence_scores[slice_start:slice_end] - slice_start = slice_end - - assert len(image_lines) == len(det_pred.bboxes) - - lines = [] - for text_line, confidence, bbox in zip(image_lines, line_confidences, det_pred.bboxes): - lines.append(TextLine( - text=text_line, - polygon=bbox.polygon, - bbox=bbox.bbox, - confidence=confidence - )) - - lines = sort_text_lines(lines) - - predictions_by_image.append(OCRResult( - text_lines=lines, - languages=lang, - image_bbox=det_pred.image_bbox - )) - - return predictions_by_image diff --git a/surya/postprocessing/math/latex.py b/surya/postprocessing/math/latex.py deleted file 
mode 100644 index b07e5fb8..00000000 --- a/surya/postprocessing/math/latex.py +++ /dev/null @@ -1,125 +0,0 @@ -import re -from ftfy import fix_text - - -def contains_math(text): - return text.startswith("$") or text.endswith("$") - - -def fix_math(text): - # Fix any issues with the text - text = fix_text(text) - - # Remove LaTeX labels and references - text = remove_labels(text) - text = replace_katex_invalid(text) - text = fix_fences(text) - return text - - -def remove_labels(text): - pattern = r'\\label\{[^}]*\}' - text = re.sub(pattern, '', text) - - ref_pattern = r'\\ref\{[^}]*\}' - text = re.sub(ref_pattern, '', text) - - pageref_pattern = r'\\pageref\{[^}]*\}' - text = re.sub(pageref_pattern, '', text) - return text - - -def replace_katex_invalid(string): - # KaTeX cannot render all LaTeX, so we need to replace some things - string = re.sub(r'\\tag\{.*?\}', '', string) - string = re.sub(r'\\(?:Bigg?|bigg?)\{(.*?)\}', r'\1', string) - string = re.sub(r'\\quad\\mbox\{(.*?)\}', r'\1', string) - string = re.sub(r'\\mbox\{(.*?)\}', r'\1', string) - string = remove_inner_dollars(string) - return string - - -def remove_inner_dollars(text): - def replace_dollar(match): - # Replace single $ with nothing, keep $$ intact - math_block = match.group(1) - return '$$' + math_block.replace('$', '') + '$$' - - pattern = r'\$\$(.*?)\$\$' - return re.sub(pattern, replace_dollar, text, flags=re.DOTALL) - - -def extract_latex_with_positions(text): - pattern = r'(\$\$.*?\$\$|\$.*?\$)' - matches = [] - for match in re.finditer(pattern, text, re.DOTALL): - matches.append((match.group(), match.start(), match.end())) - return matches - - -def slice_latex(text): - # Extract LaTeX blocks along with their positions - latex_blocks_with_positions = extract_latex_with_positions(text) - - chunks = [] - last_position = 0 - for block, start, end in latex_blocks_with_positions: - # Add text before the current LaTeX block, if any - if start > last_position: - chunks.append({"text": text[last_position:start], "type": "text"}) - # Add the LaTeX block - chunks.append({"text": block, "type": "latex"}) - last_position = end - # Add remaining text after the last LaTeX block, if any - if last_position < len(text): - chunks.append({"text": text[last_position:], "type": "text"}) - - return chunks - - -def is_latex(text): - latex_patterns = [ - r'\\(?:begin|end)\{[a-zA-Z]*\}', - r'\$.*?\$', - r'\$\$.*?\$\$', - r'\\[a-zA-Z]+', - r'\\[^a-zA-Z]', - ] - - combined_pattern = '|'.join(latex_patterns) - if re.search(combined_pattern, text, re.DOTALL): - return True - - return False - - -def fix_fences(text): - if text.startswith("$$") and not text.endswith("$$"): - if text[-1] == "$": - text += "$" - else: - text += "$$" - - if text.endswith("$$") and not text.startswith("$$"): - if text[0] == "$": - text = "$" + text - else: - text = "$$" + text - - if text.startswith("$") and not text.endswith("$"): - text = "$" + text + "$$" - - if text.endswith("$") and not text.startswith("$"): - text = "$$" + text + "$" - - return text - - -def strip_fences(text): - while text.startswith("$"): - text = text[1:] - while text.endswith("$"): - text = text[:-1] - return text - - diff --git a/surya/postprocessing/math/render.py b/surya/postprocessing/math/render.py deleted file mode 100644 index 761334a0..00000000 --- a/surya/postprocessing/math/render.py +++ /dev/null @@ -1,88 +0,0 @@ -from playwright.sync_api import sync_playwright -from PIL import Image -import io - - -def latex_to_pil(latex_code, target_width, target_height, fontsize=18): - 
html_template = """ - - - - - - - - -
{content}
- - - - """ - - formatted_latex = latex_code.replace('\n', '\\n').replace('"', '\\"') - with sync_playwright() as p: - browser = p.chromium.launch() - page = browser.new_page() - page.set_viewport_size({'width': target_width, 'height': target_height}) - - while fontsize <= 30: - html_content = html_template.replace("{content}", formatted_latex).replace("{fontsize}", str(fontsize)) - page.set_content(html_content) - - dimensions = page.evaluate("""() => { - const render = document.getElementById('content'); - return { - width: render.offsetWidth, - height: render.offsetHeight - }; - }""") - - if dimensions['width'] >= target_width or dimensions['height'] >= target_height: - fontsize -= 1 - break - else: - fontsize += 1 - - html_content = html_template.replace("{content}", formatted_latex).replace("{fontsize}", str(fontsize)) - page.set_content(html_content) - - screenshot_bytes = page.screenshot() - browser.close() - - image_stream = io.BytesIO(screenshot_bytes) - pil_image = Image.open(image_stream) - pil_image.load() - return pil_image \ No newline at end of file diff --git a/surya/postprocessing/text.py b/surya/postprocessing/text.py index 542a80cc..0c819fec 100644 --- a/surya/postprocessing/text.py +++ b/surya/postprocessing/text.py @@ -1,63 +1,8 @@ -import os from typing import List, Tuple - -import requests from PIL import Image, ImageDraw, ImageFont from surya.postprocessing.fonts import get_font_path from surya.schema import TextLine -from surya.settings import settings -from surya.postprocessing.math.latex import is_latex - - -def sort_text_lines(lines: List[TextLine] | List[dict], tolerance=1.25): - # Sorts in reading order. Not 100% accurate, this should only - # be used as a starting point for more advanced sorting. - vertical_groups = {} - for line in lines: - group_key = round(line.bbox[1] if isinstance(line, TextLine) else line["bbox"][1] / tolerance) * tolerance - if group_key not in vertical_groups: - vertical_groups[group_key] = [] - vertical_groups[group_key].append(line) - - # Sort each group horizontally and flatten the groups into a single list - sorted_lines = [] - for _, group in sorted(vertical_groups.items()): - sorted_group = sorted(group, key=lambda x: x.bbox[0] if isinstance(x, TextLine) else x["bbox"][0]) - sorted_lines.extend(sorted_group) - - return sorted_lines - - -def truncate_repetitions(text: str, min_len=15): - # From nougat, with some cleanup - if len(text) < 2 * min_len: - return text - - # try to find a length at which the tail is repeating - max_rep_len = None - for rep_len in range(min_len, int(len(text) / 2)): - # check if there is a repetition at the end - same = True - for i in range(0, rep_len): - if text[len(text) - rep_len - i - 1] != text[len(text) - i - 1]: - same = False - break - - if same: - max_rep_len = rep_len - - if max_rep_len is None: - return text - - lcs = text[-max_rep_len:] - - # remove all but the last repetition - text_to_truncate = text - while text_to_truncate.endswith(lcs): - text_to_truncate = text_to_truncate[:-max_rep_len] - - return text[:len(text_to_truncate)] def get_text_size(text, font): @@ -83,19 +28,6 @@ def render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font draw.text((x, y), text, fill="black", font=font) -def render_math(image, draw, text, s_bbox, bbox_width, bbox_height, font_path): - try: - from surya.postprocessing.math.render import latex_to_pil - box_font_size = max(10, min(int(.2 * bbox_height), 24)) - img = latex_to_pil(text, bbox_width, bbox_height, fontsize=box_font_size) - 
img.thumbnail((bbox_width, bbox_height)) - image.paste(img, (s_bbox[0], s_bbox[1])) - except Exception as e: - print(f"Failed to render math: {e}") - box_font_size = max(10, min(int(.75 * bbox_height), 24)) - render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size) - - def draw_text_on_image(bboxes, texts, image_size: Tuple[int, int], langs: List[str], font_path=None, max_font_size=60, res_upscale=2, has_math=False): if font_path is None: font_path = get_font_path(langs) @@ -109,10 +41,7 @@ def draw_text_on_image(bboxes, texts, image_size: Tuple[int, int], langs: List[s bbox_height = s_bbox[3] - s_bbox[1] # Shrink the text to fit in the bbox if needed - if has_math and is_latex(text): - render_math(image, draw, text, s_bbox, bbox_width, bbox_height, font_path) - else: - box_font_size = max(6, min(int(.75 * bbox_height), max_font_size)) - render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size) + box_font_size = max(6, min(int(.75 * bbox_height), max_font_size)) + render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size) return image diff --git a/surya/recognition.py b/surya/recognition.py deleted file mode 100644 index 452d9985..00000000 --- a/surya/recognition.py +++ /dev/null @@ -1,191 +0,0 @@ -import torch -from typing import List -from PIL import Image - -from surya.postprocessing.math.latex import fix_math, contains_math -from surya.postprocessing.text import truncate_repetitions -from surya.settings import settings -from tqdm import tqdm -import numpy as np -import torch.nn.functional as F - - -def get_batch_size(): - batch_size = settings.RECOGNITION_BATCH_SIZE - if batch_size is None: - batch_size = 32 - if settings.TORCH_DEVICE_MODEL == "mps": - batch_size = 64 # 12GB RAM max - if settings.TORCH_DEVICE_MODEL == "cuda": - batch_size = 256 - return batch_size - - -def pad_to_batch_size(tensor, batch_size): - current_batch_size = tensor.shape[0] - if current_batch_size >= batch_size: - return tensor - - pad_size = batch_size - current_batch_size - padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) - - return F.pad(tensor, padding, mode='constant', value=0) - -def batch_recognition(images: List[Image.Image], languages: List[List[str] | None], model, processor, batch_size=None): - assert all(isinstance(image, Image.Image) for image in images) - assert len(images) == len(languages) - - if len(images) == 0: - return [], [] - - if batch_size is None: - batch_size = get_batch_size() - - # Sort images by width, so similar length ones go together - sorted_pairs = sorted(enumerate(images), key=lambda x: x[1].width, reverse=False) - indices, images = zip(*sorted_pairs) - indices = list(indices) - images = list(images) - - output_text = [] - confidences = [] - for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"): - batch_images = images[i:i+batch_size] - batch_images = [image.convert("RGB") for image in batch_images] # also copies the images - real_batch_size = len(batch_images) - batch_langs = languages[i:i+real_batch_size] - has_math = [lang and "_math" in lang for lang in batch_langs] - - processed_batch = processor(text=[""] * len(batch_images), images=batch_images, langs=batch_langs) - - batch_pixel_values = processed_batch["pixel_values"] - batch_langs = processed_batch["langs"] - batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs] - max_input_length = max(len(tokens) for tokens in batch_decoder_input) - - # Pad decoder input to max length if 
needed, to ensure we can convert to a tensor - for idx, tokens in enumerate(batch_decoder_input): - if len(tokens) < max_input_length: - padding_length = max_input_length - len(tokens) - batch_decoder_input[idx] = [processor.tokenizer.pad_id] * padding_length + tokens - current_batch_size = len(batch_pixel_values) - - batch_pixel_values = torch.tensor(np.stack(batch_pixel_values, axis=0), dtype=model.dtype, device=model.device) - batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device) - if settings.RECOGNITION_STATIC_CACHE: - batch_pixel_values = pad_to_batch_size(batch_pixel_values, batch_size) - batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size) - - token_count = 0 - inference_token_count = batch_decoder_input.shape[-1] - batch_predictions = [[] for _ in range(current_batch_size)] - - decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1 - model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype) - model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype) - - sequence_scores = None - all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device) - encoder_hidden_states = None - - with torch.inference_mode(): - encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR - for z in range(0, batch_pixel_values.shape[0], encoder_batch_size): - encoder_pixel_values = batch_pixel_values[z:min(z + encoder_batch_size, batch_pixel_values.shape[0])] - encoder_hidden_states_batch = model.encoder(pixel_values=encoder_pixel_values).last_hidden_state - if encoder_hidden_states is None: - encoder_hidden_states = encoder_hidden_states_batch - else: - encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_batch], dim=0) - - text_encoder_input_ids = torch.arange( - model.text_encoder.config.query_token_count, - device=encoder_hidden_states.device, - dtype=torch.long - ).unsqueeze(0).expand(encoder_hidden_states.size(0), -1) - - encoder_text_hidden_states = model.text_encoder( - input_ids=text_encoder_input_ids, - cache_position=None, - attention_mask=None, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=None, - use_cache=False - ).hidden_states - del encoder_hidden_states - - if settings.RECOGNITION_STATIC_CACHE: - # Pad inputs to max batch size for static cache - encoder_text_hidden_states = pad_to_batch_size(encoder_text_hidden_states, batch_size) - batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size) - - while token_count < settings.RECOGNITION_MAX_TOKENS - 1: - is_prefill = token_count == 0 - #TODO: add attention mask - return_dict = model.decoder( - input_ids=batch_decoder_input, - encoder_hidden_states=encoder_text_hidden_states, - cache_position=decoder_position_ids, - use_cache=True, - prefill=is_prefill - ) - - decoder_position_ids = decoder_position_ids[-1:] + 1 - logits = return_dict["logits"][:current_batch_size] # Ignore batch padding - aux_logits = return_dict.get("aux_logits", None) - - preds = torch.argmax(logits[:, -1], dim=-1) - scores = torch.max(F.softmax(logits[:, -1], dim=-1), dim=-1).values.unsqueeze(1) - done = (preds == processor.tokenizer.eos_id) | (preds == processor.tokenizer.pad_id) - all_done = all_done | done - - if is_prefill: - sequence_scores = scores - else: - scores = scores.masked_fill(all_done, 0) - sequence_scores = torch.cat([sequence_scores, scores], dim=1) - 
- if all_done.all(): - break - - batch_decoder_input = preds.unsqueeze(1) - - for j, (pred, status) in enumerate(zip(preds, all_done)): - if not status: - batch_predictions[j].append(int(pred)) - - token_count += inference_token_count - inference_token_count = batch_decoder_input.shape[-1] - max_position_id = torch.max(decoder_position_ids).item() - decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1 + max_position_id - - if settings.RECOGNITION_STATIC_CACHE: - batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size) - - sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1) - detected_text = processor.tokenizer.batch_decode(batch_predictions) - detected_text = [truncate_repetitions(dt) for dt in detected_text] - - # Postprocess to fix LaTeX output (add $$ signs, etc) - detected_text = [fix_math(text) if math and contains_math(text) else text for text, math in zip(detected_text, has_math)] - - # Convert sequence_scores to list for the current batch - batch_confidences = sequence_scores.tolist() - - # Exclude padded results if real batch size is less than batch size - if settings.RECOGNITION_STATIC_CACHE: - detected_text = detected_text[:real_batch_size] - batch_confidences = batch_confidences[:real_batch_size] - - output_text.extend(detected_text) - confidences.extend(batch_confidences) - - del encoder_text_hidden_states - - output_text = sorted(zip(indices, output_text), key=lambda x: x[0]) - confidences = sorted(zip(indices, confidences), key=lambda x: x[0]) - output_text = [text for _, text in output_text] - confidences = [conf for _, conf in confidences] - return output_text, confidences - - diff --git a/surya/recognition/__init__.py b/surya/recognition/__init__.py new file mode 100644 index 00000000..a33c9321 --- /dev/null +++ b/surya/recognition/__init__.py @@ -0,0 +1,364 @@ +from copy import deepcopy +from typing import List + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm +import torch.nn.functional as F + +from surya.common.predictor import BasePredictor +from surya.detection import DetectionPredictor +from surya.input.processing import convert_if_not_rgb, slice_polys_from_image, slice_bboxes_from_image +from surya.recognition.loader import RecognitionModelLoader +from surya.recognition.postprocessing import truncate_repetitions +from surya.recognition.processor import SuryaProcessor +from surya.recognition.util import sort_text_lines +from surya.schema import OCRResult, TextLine +from surya.settings import settings + + +class RecognitionPredictor(BasePredictor): + model_loader_cls = RecognitionModelLoader + + def __call__( + self, + images: List[Image.Image], + langs: List[List[str] | None], + det_predictor: DetectionPredictor | None = None, + detection_batch_size: int | None = None, + recognition_batch_size: int | None = None, + highres_images: List[Image.Image] | None = None, + bboxes: List[List[List[int]]] | None = None, + polygons: List[List[List[List[int]]]] | None = None + ) -> List[OCRResult]: + assert len(images) == len(langs), "You need to pass in one list of languages for each image" + images = convert_if_not_rgb(images) + if highres_images is not None: + assert len(images) == len(highres_images), "You need to pass in one highres image for each image" + + highres_images = convert_if_not_rgb(highres_images) if highres_images is not None else [None] * len(images) + + if bboxes is None and polygons is None: + assert 
det_predictor is not None, "You need to pass in a detection predictor if you don't provide bboxes or polygons" + + # Detect then slice + flat = self.detect_and_slice_bboxes( + images, + langs, + det_predictor, + detection_batch_size=detection_batch_size, + highres_images=highres_images + ) + else: + if bboxes is not None: + assert len(images) == len(bboxes), "You need to pass in one list of bboxes for each image" + if polygons is not None: + assert len(images) == len(polygons), "You need to pass in one list of polygons for each image" + + flat = self.slice_bboxes( + images, + langs, + bboxes=bboxes, + polygons=polygons + ) + + rec_predictions, confidence_scores = self.batch_recognition( + flat["slices"], + flat["langs"], + batch_size=recognition_batch_size + ) + + predictions_by_image = [] + slice_start = 0 + for idx, (image, lang) in enumerate(zip(images, langs)): + slice_end = slice_start + flat["slice_map"][idx] + image_lines = rec_predictions[slice_start:slice_end] + line_confidences = confidence_scores[slice_start:slice_end] + polygons = flat["polygons"][slice_start:slice_end] + slice_start = slice_end + + lines = [] + for text_line, confidence, polygon in zip(image_lines, line_confidences, polygons): + lines.append(TextLine( + text=text_line, + polygon=polygon, + confidence=confidence + )) + + lines = sort_text_lines(lines) + predictions_by_image.append(OCRResult( + text_lines=lines, + languages=lang, + image_bbox=[0, 0, image.size[0], image.size[1]] + )) + + return predictions_by_image + + def detect_and_slice_bboxes( + self, + images: List[Image.Image], + langs: List[List[str] | None], + det_predictor: DetectionPredictor, + detection_batch_size: int | None = None, + highres_images: List[Image.Image] | None = None, + ): + det_predictions = det_predictor(images, batch_size=detection_batch_size) + + all_slices = [] + slice_map = [] + all_langs = [] + all_polygons = [] + + for idx, (det_pred, image, highres_image, lang) in enumerate(zip(det_predictions, images, highres_images, langs)): + polygons = [p.polygon for p in det_pred.bboxes] + if highres_image: + width_scaler = highres_image.size[0] / image.size[0] + height_scaler = highres_image.size[1] / image.size[1] + scaled_polygons = [[[int(p[0] * width_scaler), int(p[1] * height_scaler)] for p in polygon] for + polygon in polygons] + slices = slice_polys_from_image(highres_image, scaled_polygons) + else: + slices = slice_polys_from_image(image, polygons) + slice_map.append(len(slices)) + all_langs.extend([lang] * len(slices)) + all_slices.extend(slices) + all_polygons.extend(polygons) + + assert len(all_slices) == sum(slice_map) == len(all_langs) == len(all_polygons) + + return { + "slices": all_slices, + "slice_map": slice_map, + "langs": all_langs, + "polygons": all_polygons + } + + def slice_bboxes( + self, + images: List[Image.Image], + langs: List[List[str] | None], + bboxes: List[List[List[int]]] | None = None, + polygons: List[List[List[List[int]]]] | None = None + ): + assert bboxes is not None or polygons is not None + slice_map = [] + all_slices = [] + all_langs = [] + all_polygons = [] + for idx, (image, lang) in enumerate(zip(images, langs)): + if polygons is not None: + polys = polygons[idx] + slices = slice_polys_from_image(image, polys) + else: + slices = slice_bboxes_from_image(image, bboxes[idx]) + polys = [ + [[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]] + for bbox in bboxes[idx] + ] + slice_map.append(len(slices)) + all_slices.extend(slices) + all_langs.extend([deepcopy(lang)] * 
len(slices)) + all_polygons.extend(polys) + + assert len(all_slices) == sum(slice_map) == len(all_langs) == len(all_polygons) + + return { + "slices": all_slices, + "slice_map": slice_map, + "langs": all_langs, + "polygons": all_polygons + } + + @staticmethod + def get_batch_size(): + batch_size = settings.RECOGNITION_BATCH_SIZE + if batch_size is None: + batch_size = 32 + if settings.TORCH_DEVICE_MODEL == "mps": + batch_size = 64 # 12GB RAM max + if settings.TORCH_DEVICE_MODEL == "cuda": + batch_size = 256 + return batch_size + + def pad_to_batch_size(self, tensor, batch_size): + current_batch_size = tensor.shape[0] + if current_batch_size >= batch_size: + return tensor + + pad_size = batch_size - current_batch_size + padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) + + return F.pad(tensor, padding, mode='constant', value=0) + + def prepare_input(self, batch_langs, batch_pixel_values, batch_size): + batch_decoder_input = [[self.model.config.decoder_start_token_id] + lang for lang in batch_langs] + max_input_length = max(len(tokens) for tokens in batch_decoder_input) + + # Pad decoder input to max length if needed, to ensure we can convert to a tensor + for idx, tokens in enumerate(batch_decoder_input): + if len(tokens) < max_input_length: + padding_length = max_input_length - len(tokens) + batch_decoder_input[idx] = [self.processor.tokenizer.pad_id] * padding_length + tokens + current_batch_size = len(batch_pixel_values) + + batch_pixel_values = torch.tensor(np.stack(batch_pixel_values, axis=0), dtype=self.model.dtype, + device=self.model.device) + batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, + device=self.model.device) + if settings.RECOGNITION_STATIC_CACHE: + batch_pixel_values = self.pad_to_batch_size(batch_pixel_values, batch_size) + batch_decoder_input = self.pad_to_batch_size(batch_decoder_input, batch_size) + + return batch_pixel_values, batch_decoder_input, current_batch_size + + def batch_recognition( + self, + images: List[Image.Image], + languages: List[List[str] | None], + batch_size=None + ): + assert all(isinstance(image, Image.Image) for image in images) + assert len(images) == len(languages) + + if len(images) == 0: + return [], [] + + if batch_size is None: + batch_size = self.get_batch_size() + + # Sort images by width, so similar length ones go together + sorted_pairs = sorted(enumerate(images), key=lambda x: x[1].width, reverse=False) + indices, images = zip(*sorted_pairs) + indices = list(indices) + images = list(images) + + output_text = [] + confidences = [] + for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"): + batch_images = images[i:i + batch_size] + batch_images = [image.convert("RGB") for image in batch_images] # also copies the images + real_batch_size = len(batch_images) + batch_langs = languages[i:i + real_batch_size] + has_math = [lang and "_math" in lang for lang in batch_langs] + + processed_batch = self.processor(text=[""] * len(batch_images), images=batch_images, langs=batch_langs) + + batch_pixel_values = processed_batch["pixel_values"] + batch_langs = processed_batch["langs"] + batch_pixel_values, batch_decoder_input, current_batch_size = self.prepare_input( + batch_langs, + batch_pixel_values, + batch_size + ) + + token_count = 0 + inference_token_count = batch_decoder_input.shape[-1] + batch_predictions = [[] for _ in range(current_batch_size)] + + decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, + device=self.model.device).cumsum(0) - 1 + 
self.model.decoder.model._setup_cache(self.model.config, batch_size, self.model.device, self.model.dtype) + self.model.text_encoder.model._setup_cache(self.model.config, batch_size, self.model.device, self.model.dtype) + + sequence_scores = None + all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=self.model.device) + encoder_hidden_states = None + + with torch.inference_mode(): + encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR + for z in range(0, batch_pixel_values.shape[0], encoder_batch_size): + encoder_pixel_values = batch_pixel_values[ + z:min(z + encoder_batch_size, batch_pixel_values.shape[0])] + encoder_hidden_states_batch = self.model.encoder(pixel_values=encoder_pixel_values).last_hidden_state + if encoder_hidden_states is None: + encoder_hidden_states = encoder_hidden_states_batch + else: + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_batch], dim=0) + + text_encoder_input_ids = torch.arange( + self.model.text_encoder.config.query_token_count, + device=encoder_hidden_states.device, + dtype=torch.long + ).unsqueeze(0).expand(encoder_hidden_states.size(0), -1) + + encoder_text_hidden_states = self.model.text_encoder( + input_ids=text_encoder_input_ids, + cache_position=None, + attention_mask=None, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=None, + use_cache=False + ).hidden_states + del encoder_hidden_states + + if settings.RECOGNITION_STATIC_CACHE: + # Pad inputs to max batch size for static cache + encoder_text_hidden_states = self.pad_to_batch_size(encoder_text_hidden_states, batch_size) + batch_decoder_input = self.pad_to_batch_size(batch_decoder_input, batch_size) + + while token_count < settings.RECOGNITION_MAX_TOKENS - 1: + is_prefill = token_count == 0 + # TODO: add attention mask + return_dict = self.model.decoder( + input_ids=batch_decoder_input, + encoder_hidden_states=encoder_text_hidden_states, + cache_position=decoder_position_ids, + use_cache=True, + prefill=is_prefill + ) + + decoder_position_ids = decoder_position_ids[-1:] + 1 + logits = return_dict["logits"][:current_batch_size] # Ignore batch padding + + preds = torch.argmax(logits[:, -1], dim=-1) + scores = torch.max(F.softmax(logits[:, -1], dim=-1), dim=-1).values.unsqueeze(1) + done = (preds == self.processor.tokenizer.eos_id) | (preds == self.processor.tokenizer.pad_id) + all_done = all_done | done + + if is_prefill: + sequence_scores = scores + else: + scores = scores.masked_fill(all_done, 0) + sequence_scores = torch.cat([sequence_scores, scores], dim=1) + + if all_done.all(): + break + + batch_decoder_input = preds.unsqueeze(1) + + for j, (pred, status) in enumerate(zip(preds, all_done)): + if not status: + batch_predictions[j].append(int(pred)) + + token_count += inference_token_count + inference_token_count = batch_decoder_input.shape[-1] + max_position_id = torch.max(decoder_position_ids).item() + decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, + device=self.model.device).cumsum(0) - 1 + max_position_id + + if settings.RECOGNITION_STATIC_CACHE: + batch_decoder_input = self.pad_to_batch_size(batch_decoder_input, batch_size) + + sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1) + detected_text = self.processor.tokenizer.batch_decode(batch_predictions) + detected_text = [truncate_repetitions(dt) for dt in detected_text] + + # Convert sequence_scores to list for the current batch + batch_confidences = 
sequence_scores.tolist() + + # Exclude padded results if real batch size is less than batch size + if settings.RECOGNITION_STATIC_CACHE: + detected_text = detected_text[:real_batch_size] + batch_confidences = batch_confidences[:real_batch_size] + + output_text.extend(detected_text) + confidences.extend(batch_confidences) + + del encoder_text_hidden_states + + output_text = sorted(zip(indices, output_text), key=lambda x: x[0]) + confidences = sorted(zip(indices, confidences), key=lambda x: x[0]) + output_text = [text for _, text in output_text] + confidences = [conf for _, conf in confidences] + return output_text, confidences \ No newline at end of file diff --git a/surya/recognition/languages.py b/surya/recognition/languages.py new file mode 100644 index 00000000..2dff7a90 --- /dev/null +++ b/surya/recognition/languages.py @@ -0,0 +1,111 @@ +from typing import List + +CODE_TO_LANGUAGE = { + "_math": "Math", + 'af': 'Afrikaans', + 'am': 'Amharic', + 'ar': 'Arabic', + 'as': 'Assamese', + 'az': 'Azerbaijani', + 'be': 'Belarusian', + 'bg': 'Bulgarian', + 'bn': 'Bengali', + 'br': 'Breton', + 'bs': 'Bosnian', + 'ca': 'Catalan', + 'cs': 'Czech', + 'cy': 'Welsh', + 'da': 'Danish', + 'de': 'German', + 'el': 'Greek', + 'en': 'English', + 'eo': 'Esperanto', + 'es': 'Spanish', + 'et': 'Estonian', + 'eu': 'Basque', + 'fa': 'Persian', + 'fi': 'Finnish', + 'fr': 'French', + 'fy': 'Western Frisian', + 'ga': 'Irish', + 'gd': 'Scottish Gaelic', + 'gl': 'Galician', + 'gu': 'Gujarati', + 'ha': 'Hausa', + 'he': 'Hebrew', + 'hi': 'Hindi', + 'hr': 'Croatian', + 'hu': 'Hungarian', + 'hy': 'Armenian', + 'id': 'Indonesian', + 'is': 'Icelandic', + 'it': 'Italian', + 'ja': 'Japanese', + 'jv': 'Javanese', + 'ka': 'Georgian', + 'kk': 'Kazakh', + 'km': 'Khmer', + 'kn': 'Kannada', + 'ko': 'Korean', + 'ku': 'Kurdish', + 'ky': 'Kyrgyz', + 'la': 'Latin', + 'lo': 'Lao', + 'lt': 'Lithuanian', + 'lv': 'Latvian', + 'mg': 'Malagasy', + 'mk': 'Macedonian', + 'ml': 'Malayalam', + 'mn': 'Mongolian', + 'mr': 'Marathi', + 'ms': 'Malay', + 'my': 'Burmese', + 'ne': 'Nepali', + 'nl': 'Dutch', + 'no': 'Norwegian', + 'om': 'Oromo', + 'or': 'Oriya', + 'pa': 'Punjabi', + 'pl': 'Polish', + 'ps': 'Pashto', + 'pt': 'Portuguese', + 'ro': 'Romanian', + 'ru': 'Russian', + 'sa': 'Sanskrit', + 'sd': 'Sindhi', + 'si': 'Sinhala', + 'sk': 'Slovak', + 'sl': 'Slovenian', + 'so': 'Somali', + 'sq': 'Albanian', + 'sr': 'Serbian', + 'su': 'Sundanese', + 'sv': 'Swedish', + 'sw': 'Swahili', + 'ta': 'Tamil', + 'te': 'Telugu', + 'th': 'Thai', + 'tl': 'Tagalog', + 'tr': 'Turkish', + 'ug': 'Uyghur', + 'uk': 'Ukrainian', + 'ur': 'Urdu', + 'uz': 'Uzbek', + 'vi': 'Vietnamese', + 'xh': 'Xhosa', + 'yi': 'Yiddish', + 'zh': 'Chinese', +} + +LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()} + + +def is_arabic(lang_code): + return lang_code in ["ar", "fa", "ps", "ug", "ur"] + +def replace_lang_with_code(langs: List[str]): + for i in range(len(langs)): + if langs[i].title() in LANGUAGE_TO_CODE: + langs[i] = LANGUAGE_TO_CODE[langs[i].title()] + if langs[i] not in CODE_TO_LANGUAGE: + raise ValueError(f"Language code {langs[i]} not found.") diff --git a/surya/recognition/loader.py b/surya/recognition/loader.py new file mode 100644 index 00000000..8129b588 --- /dev/null +++ b/surya/recognition/loader.py @@ -0,0 +1,62 @@ +from typing import Optional + +import torch + +from surya.common.load import ModelLoader +from surya.recognition.model.config import SuryaOCRConfig, SuryaOCRDecoderConfig, DonutSwinConfig, SuryaOCRTextEncoderConfig +from 
surya.recognition.model.encoderdecoder import OCREncoderDecoderModel +from surya.recognition.processor import SuryaProcessor +from surya.settings import settings + +torch.backends.cuda.enable_cudnn_sdp(settings.ENABLE_CUDNN_ATTENTION) +if not settings.ENABLE_EFFICIENT_ATTENTION: + print("Efficient attention is disabled. This will use significantly more VRAM.") + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_math_sdp(True) + +class RecognitionModelLoader(ModelLoader): + def __init__(self, checkpoint: Optional[str] = None): + super().__init__(checkpoint) + + if self.checkpoint is None: + self.checkpoint = settings.RECOGNITION_MODEL_CHECKPOINT + + def model( + self, + device=settings.TORCH_DEVICE_MODEL, + dtype=settings.MODEL_DTYPE + ): + config = SuryaOCRConfig.from_pretrained(self.checkpoint) + decoder_config = config.decoder + decoder = SuryaOCRDecoderConfig(**decoder_config) + config.decoder = decoder + + encoder_config = config.encoder + encoder = DonutSwinConfig(**encoder_config) + config.encoder = encoder + + text_encoder_config = config.text_encoder + text_encoder = SuryaOCRTextEncoderConfig(**text_encoder_config) + config.text_encoder = text_encoder + + model = OCREncoderDecoderModel.from_pretrained(self.checkpoint, config=config, torch_dtype=dtype) + model = model.to(device) + model = model.eval() + + if settings.RECOGNITION_STATIC_CACHE: + torch.set_float32_matmul_precision('high') + torch._dynamo.config.cache_size_limit = 16 + torch._dynamo.config.suppress_errors = False + + print(f"Compiling recognition model {self.checkpoint} on device {device} with dtype {dtype}") + model.encoder = torch.compile(model.encoder) + model.decoder = torch.compile(model.decoder) + model.text_encoder = torch.compile(model.text_encoder) + + print(f"Loaded recognition model {self.checkpoint} on device {device} with dtype {dtype}") + return model + + def processor(self): + return SuryaProcessor() + diff --git a/surya/recognition/model/__init__.py b/surya/recognition/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/surya/model/recognition/config.py b/surya/recognition/model/config.py similarity index 100% rename from surya/model/recognition/config.py rename to surya/recognition/model/config.py diff --git a/surya/model/recognition/decoder.py b/surya/recognition/model/decoder.py similarity index 98% rename from surya/model/recognition/decoder.py rename to surya/recognition/model/decoder.py index 81df0d7d..1f810e84 100644 --- a/surya/model/recognition/decoder.py +++ b/surya/recognition/model/decoder.py @@ -6,7 +6,7 @@ from torch import nn from transformers.utils import ModelOutput -from surya.model.recognition.config import SuryaOCRTextEncoderConfig +from surya.recognition.model.config import SuryaOCRTextEncoderConfig from transformers.modeling_outputs import CausalLMOutput from surya.common.adetr.decoder import SuryaADETRDecoderModel, SuryaADETRDecoderPreTrainedModel, WrappedEmbedding from surya.settings import settings diff --git a/surya/model/recognition/encoder.py b/surya/recognition/model/encoder.py similarity index 100% rename from surya/model/recognition/encoder.py rename to surya/recognition/model/encoder.py diff --git a/surya/model/recognition/encoderdecoder.py b/surya/recognition/model/encoderdecoder.py similarity index 82% rename from surya/model/recognition/encoderdecoder.py rename to surya/recognition/model/encoderdecoder.py index 3175309f..4dc003ef 100644 --- 
a/surya/model/recognition/encoderdecoder.py +++ b/surya/recognition/model/encoderdecoder.py @@ -4,8 +4,8 @@ from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right -from surya.model.recognition.encoder import DonutSwinModel -from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder +from surya.recognition.model.encoder import DonutSwinModel +from surya.recognition.model.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder class OCREncoderDecoderModel(PreTrainedModel): @@ -116,24 +116,6 @@ def forward( encoder_last_hidden_state=encoder_outputs.last_hidden_state ) - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs - ): - decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) - decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None - input_dict = { - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "decoder_input_ids": decoder_inputs["input_ids"], - "encoder_outputs": encoder_outputs, - "past_key_values": decoder_inputs["past_key_values"], - "use_cache": use_cache, - } - return input_dict - def resize_token_embeddings(self, *args, **kwargs): raise NotImplementedError( "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the" diff --git a/surya/recognition/postprocessing.py b/surya/recognition/postprocessing.py new file mode 100644 index 00000000..d05b1e01 --- /dev/null +++ b/surya/recognition/postprocessing.py @@ -0,0 +1,29 @@ +def truncate_repetitions(text: str, min_len=15): + # From nougat, with some cleanup + if len(text) < 2 * min_len: + return text + + # try to find a length at which the tail is repeating + max_rep_len = None + for rep_len in range(min_len, int(len(text) / 2)): + # check if there is a repetition at the end + same = True + for i in range(0, rep_len): + if text[len(text) - rep_len - i - 1] != text[len(text) - i - 1]: + same = False + break + + if same: + max_rep_len = rep_len + + if max_rep_len is None: + return text + + lcs = text[-max_rep_len:] + + # remove all but the last repetition + text_to_truncate = text + while text_to_truncate.endswith(lcs): + text_to_truncate = text_to_truncate[:-max_rep_len] + + return text[:len(text_to_truncate)] \ No newline at end of file diff --git a/surya/model/recognition/processor.py b/surya/recognition/processor.py similarity index 89% rename from surya/model/recognition/processor.py rename to surya/recognition/processor.py index d67ae1d7..88f8f43b 100644 --- a/surya/model/recognition/processor.py +++ b/surya/recognition/processor.py @@ -1,16 +1,12 @@ from transformers import DonutProcessor from surya.common.donut.processor import SuryaEncoderImageProcessor -from surya.model.recognition.tokenizer import Byt5LangTokenizer +from surya.recognition.tokenizer import Byt5LangTokenizer from surya.settings import settings -def load_processor(): - return SuryaProcessor() - - class SuryaProcessor(DonutProcessor): - def __init__(self, image_processor=None, 
tokenizer=None, train=False, **kwargs): + def __init__(self, image_processor=None, tokenizer=None, **kwargs): + image_processor = SuryaEncoderImageProcessor.from_pretrained(settings.RECOGNITION_MODEL_CHECKPOINT) + image_processor.do_align_long_axis = True + image_processor.max_size = settings.RECOGNITION_IMAGE_SIZE diff --git a/surya/model/recognition/tokenizer.py b/surya/recognition/tokenizer.py similarity index 98% rename from surya/model/recognition/tokenizer.py rename to surya/recognition/tokenizer.py index 30018f5f..fba9d6d5 100644 --- a/surya/model/recognition/tokenizer.py +++ b/surya/recognition/tokenizer.py @@ -2,7 +2,7 @@ from transformers import ByT5Tokenizer import numpy as np import torch -from surya.model.recognition.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET +from surya.recognition.model.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET def text_to_utf16_numbers(text): diff --git a/surya/recognition/util.py b/surya/recognition/util.py new file mode 100644 index 00000000..4b21d7fc --- /dev/null +++ b/surya/recognition/util.py @@ -0,0 +1,22 @@ +from typing import List + +from surya.schema import TextLine + + +def sort_text_lines(lines: List[TextLine] | List[dict], tolerance=1.25): + # Sorts in reading order. Not 100% accurate, this should only + # be used as a starting point for more advanced sorting. + vertical_groups = {} + for line in lines: + # Parenthesize the ternary so the tolerance division applies to both TextLine and dict inputs + group_key = round((line.bbox[1] if isinstance(line, TextLine) else line["bbox"][1]) / tolerance) * tolerance + if group_key not in vertical_groups: + vertical_groups[group_key] = [] + vertical_groups[group_key].append(line) + + # Sort each group horizontally and flatten the groups into a single list + sorted_lines = [] + for _, group in sorted(vertical_groups.items()): + sorted_group = sorted(group, key=lambda x: x.bbox[0] if isinstance(x, TextLine) else x["bbox"][0]) + sorted_lines.extend(sorted_group) + + return sorted_lines \ No newline at end of file
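
A minimal sketch of the reading-order helper added above in surya/recognition/util.py, using plain dict inputs so it does not depend on the TextLine schema. The texts and bounding boxes are made up purely for illustration, and it assumes this patch has been applied to an installed copy of surya:

    from surya.recognition.util import sort_text_lines

    # The first two lines share the same top edge (y1 = 10), so they land in one
    # vertical group and are ordered left to right by x1; the third line starts
    # lower on the page and therefore sorts after them.
    lines = [
        {"text": "right column", "bbox": [320, 10, 600, 40]},
        {"text": "left column", "bbox": [10, 10, 300, 38]},
        {"text": "second row", "bbox": [10, 60, 600, 88]},
    ]

    for line in sort_text_lines(lines):
        print(line["text"])
    # Prints: left column, right column, second row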