Enable setting batch size programmatically #91

Merged 1 commit on May 8, 2024
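
For context, here is how the new keyword argument can be used once this is merged. This is a minimal sketch; the loader imports reflect surya's 0.4.x layout and are assumptions, not part of this diff:

from PIL import Image

from surya.ocr import run_ocr
# Loader imports below are assumptions based on surya 0.4.x -- check your install.
from surya.model.detection import segformer
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor

image = Image.open("page.png")  # any PIL image
langs = [["en"]]  # one list of languages per image

det_processor, det_model = segformer.load_processor(), segformer.load_model()
rec_model, rec_processor = load_rec_model(), load_rec_processor()

# New in this PR: pass batch_size explicitly instead of relying on the
# batch-size settings / environment variables.
predictions = run_ocr([image], langs, det_model, det_processor, rec_model, rec_processor, batch_size=32)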
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.4.0"
+version = "0.4.1"
 description = "OCR, layout, reading order, and line detection in 90+ languages"
 authors = ["Vik Paruchuri <[email protected]>"]
 readme = "README.md"
9 changes: 5 additions & 4 deletions surya/detection.py
@@ -23,9 +23,10 @@ def get_batch_size():
     return batch_size
 
 
-def batch_detection(images: List, model, processor) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
+def batch_detection(images: List, model, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
     assert all([isinstance(image, Image.Image) for image in images])
-    batch_size = get_batch_size()
+    if batch_size is None:
+        batch_size = get_batch_size()
     heatmap_count = model.config.num_labels
 
     images = [image.convert("RGB") for image in images]
@@ -109,8 +110,8 @@ def parallel_get_lines(preds, orig_sizes):
     return result
 
 
-def batch_text_detection(images: List, model, processor) -> List[TextDetectionResult]:
-    preds, orig_sizes = batch_detection(images, model, processor)
+def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
+    preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
     results = []
     if len(images) == 1:  # Ensures we don't parallelize with streamlit
         for i in range(len(images)):
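
The same fallback pattern repeats in every module this PR touches: an explicit batch_size argument wins, and None falls back to the settings-derived default. A standalone sketch of that precedence (names are illustrative, not taken from the diff):

import os

def get_batch_size(default: int = 6) -> int:
    # Stand-in for the settings-derived default each surya module computes.
    return int(os.environ.get("DETECTOR_BATCH_SIZE", default))

def resolve_batch_size(batch_size=None) -> int:
    # Mirrors the fallback added throughout this PR.
    if batch_size is None:
        batch_size = get_batch_size()
    return batch_size

assert resolve_batch_size(16) == 16  # explicit argument wins
assert resolve_batch_size() == get_batch_size()  # None -> settings default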
4 changes: 2 additions & 2 deletions surya/layout.py
@@ -181,8 +181,8 @@ def parallel_get_regions(heatmaps: List[Image.Image], orig_size, id2label, detec
     return result
 
 
-def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None) -> List[LayoutResult]:
-    preds, orig_sizes = batch_detection(images, model, processor)
+def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
+    preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
     id2label = model.config.id2label
 
     results = []
4 changes: 2 additions & 2 deletions surya/ocr.py
@@ -60,7 +60,7 @@ def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model
     return predictions_by_image
 
 
-def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor) -> List[OCRResult]:
+def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor, batch_size=None) -> List[OCRResult]:
     det_predictions = batch_text_detection(images, det_model, det_processor)
     if det_model.device == "cuda":
         torch.cuda.empty_cache()  # Empty cache from first model run
@@ -75,7 +75,7 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
         all_slices.extend(slices)
         all_langs.extend([lang] * len(slices))
 
-    rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor)
+    rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=batch_size)
 
     predictions_by_image = []
     slice_start = 0
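
Note that run_ocr forwards batch_size only to batch_recognition; the initial batch_text_detection call still falls back to the detector's settings-derived default.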
5 changes: 3 additions & 2 deletions surya/ordering.py
@@ -30,10 +30,11 @@ def rank_elements(arr):
     return rank
 
 
-def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor) -> List[OrderResult]:
+def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor, batch_size=None) -> List[OrderResult]:
     assert all([isinstance(image, Image.Image) for image in images])
     assert len(images) == len(bboxes)
-    batch_size = get_batch_size()
+    if batch_size is None:
+        batch_size = get_batch_size()
 
     images = [image.convert("RGB") for image in images]
 
6 changes: 4 additions & 2 deletions surya/recognition.py
@@ -21,10 +21,12 @@ def get_batch_size():
     return batch_size
 
 
-def batch_recognition(images: List, languages: List[List[str]], model, processor):
+def batch_recognition(images: List, languages: List[List[str]], model, processor, batch_size=None):
     assert all([isinstance(image, Image.Image) for image in images])
     assert len(images) == len(languages)
-    batch_size = get_batch_size()
+
+    if batch_size is None:
+        batch_size = get_batch_size()
 
     images = [image.convert("RGB") for image in images]
 
6 changes: 5 additions & 1 deletion surya/settings.py
@@ -34,12 +34,16 @@ def TORCH_DEVICE_MODEL(self) -> str:
     @computed_field
     def TORCH_DEVICE_DETECTION(self) -> str:
         if self.TORCH_DEVICE is not None:
+            # Does not work with mps
+            if "mps" in self.TORCH_DEVICE:
+                return "cpu"
+
             return self.TORCH_DEVICE
 
-        # Does not work with mps
         if torch.cuda.is_available():
             return "cuda"
 
+        # Does not work with mps
         return "cpu"
 
     # Text detection
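
The settings change routes detection to CPU whenever an mps device is requested, since the detection model does not run on mps. A standalone sketch of the resolution order (illustrative, not the real Settings class):

from typing import Optional

import torch

def resolve_detection_device(torch_device: Optional[str]) -> str:
    # Sketch of TORCH_DEVICE_DETECTION's behavior after this PR.
    if torch_device is not None:
        if "mps" in torch_device:  # detection does not work with mps
            return "cpu"
        return torch_device
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"  # mps is never auto-selected for detection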