diff --git a/pyproject.toml b/pyproject.toml
index e5e2a5e2..2fd3ae31 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.4.0"
+version = "0.4.1"
 description = "OCR, layout, reading order, and line detection in 90+ languages"
 authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
 readme = "README.md"
diff --git a/surya/detection.py b/surya/detection.py
index 9cbe6d22..198a1dec 100644
--- a/surya/detection.py
+++ b/surya/detection.py
@@ -23,9 +23,10 @@ def get_batch_size():
     return batch_size
 
 
-def batch_detection(images: List, model, processor) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
+def batch_detection(images: List, model, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
     assert all([isinstance(image, Image.Image) for image in images])
-    batch_size = get_batch_size()
+    if batch_size is None:
+        batch_size = get_batch_size()
     heatmap_count = model.config.num_labels
 
     images = [image.convert("RGB") for image in images]
@@ -109,8 +110,8 @@ def parallel_get_lines(preds, orig_sizes):
     return result
 
 
-def batch_text_detection(images: List, model, processor) -> List[TextDetectionResult]:
-    preds, orig_sizes = batch_detection(images, model, processor)
+def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
+    preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
     results = []
     if len(images) == 1: # Ensures we don't parallelize with streamlit
         for i in range(len(images)):
diff --git a/surya/layout.py b/surya/layout.py
index c1bfa142..104a860f 100644
--- a/surya/layout.py
+++ b/surya/layout.py
@@ -181,8 +181,8 @@ def parallel_get_regions(heatmaps: List[Image.Image], orig_size, id2label, detec
     return result
 
 
-def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None) -> List[LayoutResult]:
-    preds, orig_sizes = batch_detection(images, model, processor)
+def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
+    preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
     id2label = model.config.id2label
 
     results = []
diff --git a/surya/ocr.py b/surya/ocr.py
index 123e888d..51bc1f4f 100644
--- a/surya/ocr.py
+++ b/surya/ocr.py
@@ -60,7 +60,7 @@ def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model
     return predictions_by_image
 
 
-def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor) -> List[OCRResult]:
+def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor, batch_size=None) -> List[OCRResult]:
     det_predictions = batch_text_detection(images, det_model, det_processor)
     if det_model.device == "cuda":
         torch.cuda.empty_cache() # Empty cache from first model run
@@ -75,7 +75,7 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
         all_slices.extend(slices)
         all_langs.extend([lang] * len(slices))
 
-    rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor)
+    rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=batch_size)
 
     predictions_by_image = []
     slice_start = 0
diff --git a/surya/ordering.py b/surya/ordering.py
index 854e1a80..526e9729 100644
--- a/surya/ordering.py
+++ b/surya/ordering.py
@@ -30,10 +30,11 @@ def rank_elements(arr):
     return rank
 
 
-def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor) -> List[OrderResult]:
+def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor, batch_size=None) -> List[OrderResult]:
     assert all([isinstance(image, Image.Image) for image in images])
     assert len(images) == len(bboxes)
-    batch_size = get_batch_size()
+    if batch_size is None:
+        batch_size = get_batch_size()
 
     images = [image.convert("RGB") for image in images]
 
diff --git a/surya/recognition.py b/surya/recognition.py
index 5cd9dcd3..b8239a79 100644
--- a/surya/recognition.py
+++ b/surya/recognition.py
@@ -21,10 +21,12 @@ def get_batch_size():
     return batch_size
 
 
-def batch_recognition(images: List, languages: List[List[str]], model, processor):
+def batch_recognition(images: List, languages: List[List[str]], model, processor, batch_size=None):
     assert all([isinstance(image, Image.Image) for image in images])
     assert len(images) == len(languages)
-    batch_size = get_batch_size()
+
+    if batch_size is None:
+        batch_size = get_batch_size()
 
     images = [image.convert("RGB") for image in images]
 
diff --git a/surya/settings.py b/surya/settings.py
index c72f7c3f..64347f51 100644
--- a/surya/settings.py
+++ b/surya/settings.py
@@ -34,12 +34,16 @@ def TORCH_DEVICE_MODEL(self) -> str:
     @computed_field
     def TORCH_DEVICE_DETECTION(self) -> str:
         if self.TORCH_DEVICE is not None:
+            # Does not work with mps
+            if "mps" in self.TORCH_DEVICE:
+                return "cpu"
+
             return self.TORCH_DEVICE
 
-        # Does not work with mps
         if torch.cuda.is_available():
             return "cuda"
 
+        # Does not work with mps
        return "cpu"
 
     # Text detection