Skip to content

Commit

Permalink
Merge pull request #223 from VikParuchuri/dev
Browse files Browse the repository at this point in the history
Revert again
  • Loading branch information
VikParuchuri authored Oct 23, 2024
2 parents 2dce355 + 9076927 commit 54b6299
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 91 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
version = "0.6.9"
version = "0.6.10"
description = "OCR, layout, reading order, and table recognition in 90+ languages"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
Expand Down
57 changes: 12 additions & 45 deletions surya/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,55 +128,22 @@ def parallel_get_lines(preds, orig_sizes):
return result



def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
    """Run text-line detection over a batch of images and postprocess the heatmaps.

    Inference is delegated to ``batch_detection``, which yields
    ``(preds, orig_sizes)`` batches; each prediction is then turned into a
    ``TextDetectionResult`` by ``parallel_get_lines``.

    Args:
        images: Input images (one per page); also sizes the worker pool.
        model: Detection model passed through to ``batch_detection``.
        processor: Matching processor passed through to ``batch_detection``.
        batch_size: Optional inference batch size; ``None`` lets
            ``batch_detection`` pick its default.

    Returns:
        One ``TextDetectionResult`` per input image, in input order.
    """
    detection_generator = batch_detection(images, model, processor, batch_size=batch_size)

    results = []
    max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
    # Process-based postprocessing only pays off for enough images, and is
    # disabled under Streamlit (spawning processes there is problematic).
    parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

    if parallelize:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            for preds, orig_sizes in detection_generator:
                # executor.map preserves input order, so results stay aligned
                # with the images in this batch.
                batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
                results.extend(batch_results)
    else:
        for preds, orig_sizes in detection_generator:
            for pred, orig_size in zip(preds, orig_sizes):
                results.append(parallel_get_lines(pred, orig_size))

    return results
75 changes: 30 additions & 45 deletions surya/layout.py
Original file line number Diff line number Diff line change
def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
    """Run layout detection over a batch of images and postprocess into regions.

    NOTE(review): the signature was truncated in the pasted diff after
    ``detection_results: Op`` — completed here as
    ``Optional[List[TextDetectionResult]] = None`` based on how the body
    indexes it per image; confirm against the upstream file.

    Args:
        images: Input images (one per page); also sizes the worker pool.
        model: Layout model; its ``config.id2label`` maps class ids to labels.
        processor: Matching processor passed through to ``batch_detection``.
        detection_results: Optional per-image text-detection results used to
            refine layout regions; indexed in the same order as ``images``.
        batch_size: Optional inference batch size; ``None`` lets
            ``batch_detection`` pick its default.

    Returns:
        One ``LayoutResult`` per input image, in input order.
    """
    layout_generator = batch_detection(images, model, processor, batch_size=batch_size)
    id2label = model.config.id2label

    results = []
    max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
    # Process-based postprocessing only pays off for enough images, and is
    # disabled under Streamlit (spawning processes there is problematic).
    parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

    if parallelize:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            # img_idx tracks position across batches so detection_results
            # stays aligned with the flattened image order.
            img_idx = 0
            for preds, orig_sizes in layout_generator:
                futures = []
                for pred, orig_size in zip(preds, orig_sizes):
                    future = executor.submit(
                        parallel_get_regions,
                        pred,
                        orig_size,
                        id2label,
                        detection_results[img_idx] if detection_results else None
                    )
                    futures.append(future)
                    img_idx += 1

                # Collect in submission order so results match input order.
                for future in futures:
                    results.append(future.result())
    else:
        img_idx = 0
        for preds, orig_sizes in layout_generator:
            for pred, orig_size in zip(preds, orig_sizes):
                results.append(parallel_get_regions(
                    pred,
                    orig_size,
                    id2label,
                    detection_results[img_idx] if detection_results else None
                ))
                img_idx += 1

    return results

0 comments on commit 54b6299

Please sign in to comment.