Skip to content

Commit

Permalink
Merge pull request #222 from VikParuchuri/dev
Browse files Browse the repository at this point in the history
Parallel layout and text detection
  • Loading branch information
VikParuchuri authored Oct 23, 2024
2 parents 49348d1 + ef8ec0d commit 2dce355
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
version = "0.6.8"
version = "0.6.9"
description = "OCR, layout, reading order, and table recognition in 90+ languages"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
Expand Down
60 changes: 49 additions & 11 deletions surya/detection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import contextlib
import multiprocessing
import threading
from queue import Queue
from typing import List, Tuple, Generator
Expand All @@ -16,6 +18,8 @@
from concurrent.futures import ProcessPoolExecutor
import torch.nn.functional as F

from surya.util.parallel import FakeParallel


def get_batch_size():
batch_size = settings.DETECTOR_BATCH_SIZE
Expand Down Expand Up @@ -127,18 +131,52 @@ def parallel_get_lines(preds, orig_sizes):
def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
detection_generator = batch_detection(images, model, processor, batch_size=batch_size)

results = []
postprocessing_futures = []
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

if parallelize:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
for preds, orig_sizes in detection_generator:
batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
results.extend(batch_results)
else:
for preds, orig_sizes in detection_generator:
for pred, orig_size in zip(preds, orig_sizes):
results.append(parallel_get_lines(pred, orig_size))
batch_queue = Queue()
processing_error = threading.Event()

def inference_producer():
try:
for batch in detection_generator:
batch_queue.put(batch)
if processing_error.is_set():
break
except Exception as e:
processing_error.set()
print("Error with batch detection", e)
finally:
batch_queue.put(None) # Signal end of batches

def postprocessing_consumer(executor):
while not processing_error.is_set():
batch = batch_queue.get()
if batch is None:
break

try:
preds, orig_sizes = batch
func = executor.submit if parallelize else FakeParallel
for pred, orig_size in zip(preds, orig_sizes):
postprocessing_futures.append(func(parallel_get_lines, pred, orig_size))
except Exception as e:
processing_error.set()
print("Error with postprocessing", e)

# Start producer and consumer threads
producer = threading.Thread(target=inference_producer, daemon=True)
producer.start()

with ProcessPoolExecutor(
max_workers=max_workers,
mp_context=multiprocessing.get_context("spawn")
) if parallelize else contextlib.nullcontext() as executor:
consumer = threading.Thread(target=postprocessing_consumer, args=(executor,), daemon=True)
consumer.start()
producer.join()
consumer.join()

results = [future.result() for future in postprocessing_futures]

return results
78 changes: 48 additions & 30 deletions surya/layout.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import contextlib
import multiprocessing
import threading
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
Expand All @@ -10,6 +12,7 @@
from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes
from surya.schema import LayoutResult, LayoutBox, TextDetectionResult
from surya.settings import settings
from surya.util.parallel import FakeParallel


def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
Expand Down Expand Up @@ -192,40 +195,55 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op
layout_generator = batch_detection(images, model, processor, batch_size=batch_size)
id2label = model.config.id2label

results = []
postprocessing_futures = []
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
batch_queue = Queue()
processing_error = threading.Event()

def inference_producer():
try:
for batch in layout_generator:
batch_queue.put(batch)
if processing_error.is_set():
break
except Exception as e:
processing_error.set()
print("Error in layout detection producer", e)
finally:
batch_queue.put(None) # Signal end of batches

def postprocessing_consumer(executor):
img_idx = 0
while not processing_error.is_set():
batch = batch_queue.get()
if batch is None:
break

if parallelize:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
img_idx = 0
for preds, orig_sizes in layout_generator:
futures = []
try:
preds, orig_sizes = batch
for pred, orig_size in zip(preds, orig_sizes):
future = executor.submit(
parallel_get_regions,
pred,
orig_size,
id2label,
detection_results[img_idx] if detection_results else None
)

futures.append(future)
func = executor.submit if parallelize else FakeParallel
future = func(parallel_get_regions, pred, orig_size, id2label, detection_results[img_idx] if detection_results else None)
postprocessing_futures.append(future)
img_idx += 1

for future in futures:
results.append(future.result())
else:
img_idx = 0
for preds, orig_sizes in layout_generator:
for pred, orig_size in zip(preds, orig_sizes):
results.append(parallel_get_regions(
pred,
orig_size,
id2label,
detection_results[img_idx] if detection_results else None
))

img_idx += 1
except Exception as e:
processing_error.set()
print("Error in layout postprocessing", e)

# Start producer and consumer threads
producer = threading.Thread(target=inference_producer, daemon=True)
producer.start()

with ProcessPoolExecutor(
max_workers=max_workers,
mp_context=multiprocessing.get_context("spawn")
) if parallelize else contextlib.nullcontext() as executor:
consumer = threading.Thread(target=postprocessing_consumer, args=(executor,), daemon=True)
consumer.start()
producer.join()
consumer.join()

results = [future.result() for future in postprocessing_futures]

return results
7 changes: 6 additions & 1 deletion surya/model/recognition/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def utf16_numbers_to_text(numbers):
byte_array.append(number & 0xFF) # Lower byte
byte_array.append((number >> 8) & 0xFF) # Upper byte

text = byte_array.decode('utf-16le', errors="ignore")
try:
text = byte_array.decode('utf-16le', errors="ignore")
except Exception as e:
print(f"Error decoding utf16: {e}")
text = ""

return text


Expand Down
6 changes: 6 additions & 0 deletions surya/util/parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class FakeParallel():
def __init__(self, func, *args):
self._result = func(*args)

def result(self):
return self._result

0 comments on commit 2dce355

Please sign in to comment.