Skip to content

Commit

Permalink
Refactor recognition model
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 6, 2025
1 parent 187fc8f commit 1de6d91
Show file tree
Hide file tree
Showing 33 changed files with 699 additions and 880 deletions.
13 changes: 5 additions & 8 deletions benchmark/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

from benchmark.scoring import overlap_score
from surya.input.processing import convert_if_not_rgb
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.ocr import run_recognition
from surya.postprocessing.text import draw_text_on_image
from surya.recognition import RecognitionPredictor
from surya.settings import settings
from surya.languages import CODE_TO_LANGUAGE
from surya.recognition.languages import CODE_TO_LANGUAGE
from surya.benchmark.tesseract import tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE
import os
import datasets
Expand All @@ -30,8 +28,7 @@ def main():
parser.add_argument("--specify_language", action="store_true", help="Pass language codes into the model.", default=False)
args = parser.parse_args()

rec_model = load_recognition_model()
rec_processor = load_recognition_processor()
rec_predictor = RecognitionPredictor()

split = "train"
if args.max:
Expand Down Expand Up @@ -61,10 +58,10 @@ def main():

if settings.RECOGNITION_STATIC_CACHE:
# Run through one batch to compile the model
run_recognition(images[:1], lang_list[:1], rec_model, rec_processor, bboxes=bboxes[:1])
rec_predictor(images[:1], lang_list[:1], bboxes=bboxes[:1])

start = time.time()
predictions_by_image = run_recognition(images, lang_list if args.specify_language else n_list, rec_model, rec_processor, bboxes=bboxes)
predictions_by_image = rec_predictor(images, lang_list if args.specify_language else n_list, bboxes=bboxes)
surya_time = time.time() - start

surya_scores = defaultdict(list)
Expand Down
18 changes: 7 additions & 11 deletions ocr_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,18 @@

from surya.layout import batch_layout_detection
from surya.detection import DetectionPredictor
from surya.recognition import RecognitionPredictor

from surya.model.layout.model import load_model as load_layout_model
from surya.model.layout.processor import load_processor as load_layout_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.table_rec.model import load_model as load_table_model
from surya.model.table_rec.processor import load_processor as load_table_processor
from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
from surya.ocr import run_ocr

from surya.postprocessing.text import draw_text_on_image
from PIL import Image
from surya.languages import CODE_TO_LANGUAGE
from surya.input.langs import replace_lang_with_code
from surya.recognition.languages import CODE_TO_LANGUAGE, replace_lang_with_code
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult
from surya.settings import settings
from surya.tables import batch_table_recognition
Expand All @@ -32,17 +31,14 @@
def load_det_cached():
return DetectionPredictor()


@st.cache_resource()
def load_rec_cached():
return load_rec_model(), load_rec_processor()

return RecognitionPredictor()

@st.cache_resource()
def load_layout_cached():
return load_layout_model(), load_layout_processor()


@st.cache_resource()
def load_table_cached():
return load_table_model(), load_table_processor()
Expand Down Expand Up @@ -139,7 +135,7 @@ def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Im
# Function for OCR
def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
replace_lang_with_code(langs)
img_pred = run_ocr([img], [langs], det_predictor, rec_model, rec_processor, highres_images=[highres_img])[0]
img_pred = recognition_predictor([img], [langs], det_predictor, highres_images=[highres_img])[0]

bboxes = [l.bbox for l in img_pred.text_lines]
text = [l.text for l in img_pred.text_lines]
Expand Down Expand Up @@ -178,7 +174,7 @@ def page_counter(pdf_file):
col1, col2 = st.columns([.5, .5])

det_predictor = load_det_cached()
rec_model, rec_processor = load_rec_cached()
recognition_predictor = load_rec_cached()
layout_model, layout_processor = load_layout_cached()
table_model, table_processor = load_table_cached()
ocr_error_model, ocr_error_processor = load_ocr_error_cached()
Expand Down
17 changes: 6 additions & 11 deletions ocr_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
import time
from collections import defaultdict

from surya.input.langs import replace_lang_with_code
from surya.detection import DetectionPredictor
from surya.recognition.languages import replace_lang_with_code
from surya.input.load import load_from_folder, load_from_file, load_lang_file
from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image
from surya.recognition import RecognitionPredictor
from surya.settings import settings


Expand Down Expand Up @@ -49,17 +47,14 @@ def main():
else:
image_langs = [None] * len(images)

det_processor = load_detection_processor()
det_model = load_detection_model()

rec_model = load_recognition_model()
rec_processor = load_recognition_processor()
det_predictor = DetectionPredictor()
rec_predictor = RecognitionPredictor()

result_path = os.path.join(args.results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)

start = time.time()
predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor, highres_images=highres_images)
predictions_by_image = rec_predictor(images, image_langs, det_predictor=det_predictor, highres_images=highres_images)
if args.debug:
print(f"OCR took {time.time() - start:.2f} seconds")
max_chars = max([len(l.text) for p in predictions_by_image for l in p.text_lines])
Expand Down
10 changes: 5 additions & 5 deletions surya/benchmark/tesseract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from surya.settings import settings
import os
from concurrent.futures import ProcessPoolExecutor
from surya.detection import get_batch_size as get_det_batch_size
from surya.recognition import get_batch_size as get_rec_batch_size
from surya.languages import CODE_TO_LANGUAGE
from surya.recognition.languages import CODE_TO_LANGUAGE
from surya.recognition import RecognitionPredictor
from surya.detection import DetectionPredictor


def surya_lang_to_tesseract(code: str) -> Optional[str]:
Expand All @@ -33,7 +33,7 @@ def tesseract_ocr(img, bboxes, lang: str):


def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None):
tess_parallel_cores = min(len(imgs), get_rec_batch_size())
tess_parallel_cores = min(len(imgs), RecognitionPredictor.get_batch_size())
if not cpus:
cpus = os.cpu_count()
tess_parallel_cores = min(tess_parallel_cores, cpus)
Expand Down Expand Up @@ -67,7 +67,7 @@ def tesseract_bboxes(img):

def tesseract_parallel(imgs):
# Tesseract uses 4 threads per instance
tess_parallel_cores = min(len(imgs), get_det_batch_size())
tess_parallel_cores = min(len(imgs), DetectionPredictor.get_batch_size())
cpus = os.cpu_count()
tess_parallel_cores = min(tess_parallel_cores, cpus)

Expand Down
2 changes: 1 addition & 1 deletion surya/common/donut/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from transformers.utils import ModelOutput
from surya.model.recognition.config import DonutSwinConfig
from transformers import DonutSwinConfig

_EXPECTED_OUTPUT_SHAPE = [1, 49, 1024]

Expand Down
21 changes: 21 additions & 0 deletions surya/common/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Optional, Any

import torch

from surya.settings import settings


class ModelLoader:
    """Abstract base for checkpoint loaders.

    Subclasses override :meth:`model` and :meth:`processor` to construct the
    concrete model/processor pair for ``self.checkpoint``.
    """

    def __init__(self, checkpoint: Optional[str] = None):
        # None means "use the subclass's default checkpoint".
        self.checkpoint = checkpoint

    def model(
        self,
        device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,
        dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE) -> Any:
        """Build and return the model on *device* with *dtype*. Must be overridden."""
        raise NotImplementedError()

    def processor(self) -> Any:
        """Build and return the matching processor. Must be overridden."""
        raise NotImplementedError()
15 changes: 7 additions & 8 deletions surya/common/predictor.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
from typing import Optional
import torch

from surya.common.load import ModelLoader
from surya.settings import settings


class BasePredictor:
model_loader_cls = ModelLoader
def __init__(self, checkpoint: Optional[str] = None, device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE):
self.model = None
self.processor = None
self.load_model(checkpoint, device, dtype)
self.load_processor(checkpoint)
loader = self.model_loader_cls(checkpoint)

def load_model(self, checkpoint: Optional[str] = None, device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE):
raise NotImplementedError()

def load_processor(self, checkpoint: Optional[str] = None):
raise NotImplementedError()
self.model = loader.model(device, dtype)
self.processor = loader.processor()

def to(self, device_dtype: torch.device | str | None = None):
if self.model:
self.model.to(device_dtype)
else:
raise ValueError("Model not loaded")

def get_batch_size(self):
@staticmethod
def get_batch_size():
raise NotImplementedError()

def __call__(self, *args, **kwargs):
Expand Down
40 changes: 6 additions & 34 deletions surya/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from tqdm import tqdm

from surya.common.predictor import BasePredictor
from surya.detection.config import EfficientViTConfig
from surya.detection.model import EfficientViTForSemanticSegmentation

from surya.detection.loader import DetectionModelLoader
from surya.detection.parallel import FakeExecutor
from surya.detection.processor import SegformerImageProcessor
from surya.detection.util import get_total_splits, split_image
Expand All @@ -20,6 +20,8 @@


class DetectionPredictor(BasePredictor):
model_loader_cls = DetectionModelLoader

def __call__(self, images: List[Image.Image], batch_size=None, include_maps=False) -> List[TextDetectionResult]:
detection_generator = self.batch_detection(images, batch_size=batch_size, static_cache=settings.DETECTOR_STATIC_CACHE)

Expand All @@ -34,38 +36,8 @@ def __call__(self, images: List[Image.Image], batch_size=None, include_maps=Fals

return [future.result() for future in postprocessing_futures]

def load_model(
self,
checkpoint: Optional[str] = None,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype | str] = None
):
if checkpoint is None:
checkpoint = settings.DETECTOR_MODEL_CHECKPOINT
config = EfficientViTConfig.from_pretrained(checkpoint)
model = EfficientViTForSemanticSegmentation.from_pretrained(checkpoint, torch_dtype=dtype, config=config,
ignore_mismatched_sizes=True)
model = model.to(device)
model = model.eval()

if settings.DETECTOR_STATIC_CACHE:
torch.set_float32_matmul_precision('high')
torch._dynamo.config.cache_size_limit = 1
torch._dynamo.config.suppress_errors = False

print(f"Compiling detection model {checkpoint} on device {device} with dtype {dtype}")
model = torch.compile(model)

print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}")
self.model = model

def load_processor(self, checkpoint: Optional[str] = None):
if checkpoint is None:
checkpoint = settings.DETECTOR_MODEL_CHECKPOINT

self.processor = SegformerImageProcessor.from_pretrained(checkpoint)

def get_batch_size(self):
@staticmethod
def get_batch_size():
batch_size = settings.DETECTOR_BATCH_SIZE
if batch_size is None:
batch_size = 8
Expand Down
43 changes: 43 additions & 0 deletions surya/detection/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Optional

import torch

from surya.common.load import ModelLoader
from surya.detection.processor import SegformerImageProcessor

from surya.detection.model.config import EfficientViTConfig
from surya.detection.model.encoderdecoder import EfficientViTForSemanticSegmentation
from surya.settings import settings


class DetectionModelLoader(ModelLoader):
    """Loader for the detection model and its Segformer image processor."""

    def __init__(self, checkpoint: Optional[str] = None):
        super().__init__(checkpoint)

        # Fall back to the configured default detector checkpoint when none is given.
        if self.checkpoint is None:
            self.checkpoint = settings.DETECTOR_MODEL_CHECKPOINT

    def model(
        self,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype | str] = None
    ):
        """Load the EfficientViT segmentation model from ``self.checkpoint``.

        The model is loaded with ``torch_dtype=dtype``, moved to ``device``,
        and put in eval mode. When ``settings.DETECTOR_STATIC_CACHE`` is set,
        it is additionally compiled with ``torch.compile``.
        """
        config = EfficientViTConfig.from_pretrained(self.checkpoint)
        model = EfficientViTForSemanticSegmentation.from_pretrained(self.checkpoint, torch_dtype=dtype, config=config,
                                                                    ignore_mismatched_sizes=True)
        model = model.to(device)
        model = model.eval()

        if settings.DETECTOR_STATIC_CACHE:
            # Configure dynamo BEFORE compiling: keep a single cached graph and
            # surface compilation errors rather than silently falling back.
            torch.set_float32_matmul_precision('high')
            torch._dynamo.config.cache_size_limit = 1
            torch._dynamo.config.suppress_errors = False

            print(f"Compiling detection model {self.checkpoint} on device {device} with dtype {dtype}")
            model = torch.compile(model)

        print(f"Loaded detection model {self.checkpoint} on device {device} with dtype {dtype}")
        return model

    def processor(self):
        """Return the Segformer image processor for ``self.checkpoint``."""
        return SegformerImageProcessor.from_pretrained(self.checkpoint)
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from transformers import PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput

from surya.detection.config import EfficientViTConfig
from surya.detection.model.config import EfficientViTConfig


def val2list(x: Union[List, Tuple, Any], repeat_time=1):
Expand Down
19 changes: 0 additions & 19 deletions surya/input/langs.py

This file was deleted.

Loading

0 comments on commit 1de6d91

Please sign in to comment.