Skip to content

Commit

Permalink
avoid overriding itertools.batched, if any
Browse files Browse the repository at this point in the history
  • Loading branch information
bertsky authored Sep 27, 2024
1 parent 45e20b1 commit 842bd92
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions ocrd_calamari/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,19 @@
# Name of this tool (presumably the executable / ocrd-tool.json key — confirm against usage).
TOOL = "ocrd-calamari-recognize"

# Maximum number of line images grouped into one batch before prediction.
BATCH_SIZE = 64
def batched_length_limited(iterable, n, limit=32000):
    """Yield tuples of up to ``n`` items, splitting batches that are too long.

    Works like ``itertools.batched`` (``batched('ABCDEFG', 3) → ABC DEF G``),
    except that every item must expose a ``.shape`` whose second entry is
    treated as its length (presumably the line image width — confirm against
    the caller). Whenever the longest item in a batch, multiplied by the
    nominal batch size ``n``, exceeds ``limit``, that batch is re-batched
    with half the batch size, recursively, until it fits or ``n`` reaches 1.

    Args:
        iterable: items to group; each must provide ``shape[1]``.
        n: nominal (maximum) batch size, at least 1.
        limit: upper bound for ``max length in batch * n``.

    Yields:
        Non-empty tuples of consecutive items from ``iterable``.

    Raises:
        ValueError: if ``n`` is smaller than 1 (raised on first iteration,
            since this is a generator).
    """
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(itertools.islice(iterator, n)):
        # implement poor man's batch bucketing to avoid OOM:
        maxlen = max(image.shape[1] for image in batch)
        if maxlen * n > limit and n > 1:
            # Pass the caller's limit down; otherwise the sub-batches would
            # silently be checked against the default limit instead.
            yield from batched_length_limited(batch, n // 2, limit)
        else:
            yield batch

class CalamariRecognize(Processor):
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -195,7 +195,7 @@ def process(self):
# avoid too large a batch size (causing OOM on CPU or GPU)
fun = lambda x: self.predictor.predict_raw(x, progress_bar=False)
results = itertools.chain.from_iterable(
map(fun, itertools.batched(images, BATCH_SIZE)))
map(fun, batched_length_limited(images, BATCH_SIZE)))

for line, line_coords, raw_results in zip(lines, coords, results):
for i, p in enumerate(raw_results):
Expand Down

0 comments on commit 842bd92

Please sign in to comment.