Merge pull request #31 from VikParuchuri/dev
Truncate repetitions, dump out result as utf-8
VikParuchuri authored Feb 16, 2024
2 parents 7917bfd + fb8a76f commit 534c237
Showing 13 changed files with 252 additions and 132 deletions.
32 changes: 19 additions & 13 deletions README.md
@@ -77,14 +77,15 @@ surya_ocr DATA_PATH --images --langs hi,en
- `--max` specifies the maximum number of pages to process if you don't want to process everything
- `--start_page` specifies the page number to start processing from

- The `results.json` file will contain these keys for each page of the input document(s):
+ The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:

- - `text_lines` - the detected text in each line
- - `polys` - the polygons for each detected text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
- - `bboxes` - the axis-aligned rectangles for each detected text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
- - `language` - the languages specified for the page
- - `name` - the name of the file
- - `page_number` - the page number in the file
+ - `text_lines` - the detected text and bounding boxes for each line
+   - `text` - the text in the line
+   - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
+   - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
+ - `languages` - the languages specified for the page
+ - `page` - the page number in the file
+ - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
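
As a hypothetical illustration of this new layout (the filename and all coordinate values below are invented), a one-page input `report.pdf` would produce roughly:

```python
# Hypothetical example of the results.json structure described above.
# "report" is the input filename without its extension; all values are made up.
results = {
    "report": [
        {
            "text_lines": [
                {
                    "text": "Hello world",
                    "polygon": [[10, 12], [210, 12], [210, 40], [10, 40]],
                    "bbox": [10, 12, 210, 40],
                },
            ],
            "languages": ["en"],
            "page": 1,
            "image_bbox": [0, 0, 1654, 2339],
        },
    ],
}
```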

**Performance tips**

@@ -120,13 +121,17 @@ surya_detect DATA_PATH --images
- `--max` specifies the maximum number of pages to process if you don't want to process everything
- `--results_dir` specifies the directory to save results to instead of the default

- The `results.json` file will contain these keys for each page of the input document(s):
+ The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:

- - `polygons` - polygons for each detected text line (these are more accurate than the bboxes) in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
- - `bboxes` - axis-aligned rectangles for each detected text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
- - `vertical_lines` - vertical lines detected in the document in (x1, y1, x2, y2) format.
- - `horizontal_lines` - horizontal lines detected in the document in (x1, y1, x2, y2) format.
- - `page_number` - the page number of the document
+ - `bboxes` - detected bounding boxes for text
+   - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
+   - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
+ - `vertical_lines` - vertical lines detected in the document
+   - `bbox` - the axis-aligned line coordinates in (x1, y1, x2, y2) format.
+ - `horizontal_lines` - horizontal lines detected in the document
+   - `bbox` - the axis-aligned line coordinates in (x1, y1, x2, y2) format.
+ - `page` - the page number in the file
+ - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
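
As a sketch of how this detection output can be consumed (the results path and filename below are assumptions for illustration, not defaults confirmed by this diff):

```python
import json

# Load a detection results.json written by surya_detect (path assumed).
with open("results/surya/results.json") as f:
    results = json.load(f)

for page in results["report"]:                 # "report" = input filename stem (assumed)
    for line in page["bboxes"]:
        print(line["bbox"], line["polygon"])   # rectangle and polygon per text line
    for vline in page["vertical_lines"]:
        print(vline["bbox"])                   # (x1, y1, x2, y2) line coordinates
    print(page["page"], page["image_bbox"])
```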

**Performance tips**

@@ -149,6 +154,7 @@ predictions = batch_detection([image], model, processor)
# Limitations

- This is specialized for document OCR. It will likely not work on photos or other images.
+ - Surya is for OCR - the goal is to recognize the text lines correctly, not sort them into reading order. Surya will attempt to sort the lines, which will work in many cases, but use something like [marker](https://github.com/VikParuchuri/marker) or other postprocessing if you need to order the text.
- It is for printed text, not handwriting (though it may work on some handwriting).
- The model has trained itself to ignore advertisements.
- You can find language support for OCR in `surya/languages.py`. Text detection should work with any language.
4 changes: 2 additions & 2 deletions benchmark/detection.py
@@ -67,8 +67,8 @@ def main():

     page_metrics = collections.OrderedDict()
     for idx, (tb, sb, cb) in enumerate(zip(tess_predictions, predictions, correct_boxes)):
-        surya_boxes = sb["bboxes"]
-        surya_polys = sb["polygons"]
+        surya_boxes = [s.bbox for s in sb.bboxes]
+        surya_polys = [s.polygon for s in sb.bboxes]

         surya_metrics = precision_recall(surya_boxes, cb)
         tess_metrics = precision_recall(tb, cb)
6 changes: 4 additions & 2 deletions benchmark/recognition.py
@@ -63,7 +63,8 @@ def main():
     surya_scores = defaultdict(list)
     img_surya_scores = []
     for idx, (pred, ref_text, lang) in enumerate(zip(predictions_by_image, line_text, lang_list)):
-        image_score = overlap_score(pred["text_lines"], ref_text)
+        pred_text = [l.text for l in pred.text_lines]
+        image_score = overlap_score(pred_text, ref_text)
         img_surya_scores.append(image_score)
         for l in lang:
             surya_scores[CODE_TO_LANGUAGE[l]].append(image_score)
@@ -146,7 +147,8 @@ def main():
     for idx, (image, pred, ref_text, bbox, lang) in enumerate(zip(images, predictions_by_image, line_text, bboxes, lang_list)):
         pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
         ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
-        pred_image = draw_text_on_image(bbox, pred["text_lines"], image.size)
+        pred_text = [l.text for l in pred.text_lines]
+        pred_image = draw_text_on_image(bbox, pred_text, image.size)
         pred_image.save(os.path.join(result_path, pred_image_name))
         ref_image = draw_text_on_image(bbox, ref_text, image.size)
         ref_image.save(os.path.join(result_path, ref_image_name))
23 changes: 10 additions & 13 deletions detect_text.py
@@ -38,31 +38,28 @@ def main():

     if args.images:
         for idx, (image, pred, name) in enumerate(zip(images, predictions, names)):
-            bbox_image = draw_polys_on_image(pred["polygons"], copy.deepcopy(image))
+            polygons = [p.polygon for p in pred.bboxes]
+            bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image))
             bbox_image.save(os.path.join(result_path, f"{name}_{idx}_bbox.png"))

-            column_image = draw_lines_on_image(pred["vertical_lines"], copy.deepcopy(image))
+            column_image = draw_lines_on_image(pred.vertical_lines, copy.deepcopy(image))
             column_image.save(os.path.join(result_path, f"{name}_{idx}_column.png"))

             if args.debug:
-                heatmap = pred["heatmap"]
+                heatmap = pred.heatmap
                 heatmap.save(os.path.join(result_path, f"{name}_{idx}_heat.png"))

-                affinity_map = pred["affinity_map"]
+                affinity_map = pred.affinity_map
                 affinity_map.save(os.path.join(result_path, f"{name}_{idx}_affinity.png"))

-    # Remove all the images from the predictions
-    for pred in predictions:
-        pred.pop("heatmap", None)
-        pred.pop("affinity_map", None)
-
     predictions_by_page = defaultdict(list)
-    for idx, (pred, name) in enumerate(zip(predictions, names)):
-        pred["page_number"] = len(predictions_by_page[name]) + 1
-        predictions_by_page[name].append(pred)
+    for idx, (pred, name, image) in enumerate(zip(predictions, names, images)):
+        out_pred = pred.model_dump(exclude=["heatmap", "affinity_map"])
+        out_pred["page"] = len(predictions_by_page[name]) + 1
+        predictions_by_page[name].append(out_pred)

     with open(os.path.join(result_path, "results.json"), "w+") as f:
-        json.dump(predictions_by_page, f)
+        json.dump(predictions_by_page, f, ensure_ascii=False)

     print(f"Wrote results to {result_path}")

31 changes: 18 additions & 13 deletions ocr_app.py
@@ -12,6 +12,7 @@
 from PIL import Image
 from surya.languages import CODE_TO_LANGUAGE
 from surya.input.langs import replace_lang_with_code
+from surya.schema import OCRResult, DetectionResult


 @st.cache_resource()
@@ -24,18 +25,22 @@ def load_rec_cached():
     return load_rec_model(), load_rec_processor()


-def text_detection(img):
-    preds = batch_detection([img], det_model, det_processor)[0]
-    det_img = draw_polys_on_image(preds["polygons"], img.copy())
-    return det_img, preds
+def text_detection(img) -> DetectionResult:
+    pred = batch_detection([img], det_model, det_processor)[0]
+    polygons = [p.polygon for p in pred.bboxes]
+    det_img = draw_polys_on_image(polygons, img.copy())
+    return det_img, pred


 # Function for OCR
-def ocr(img, langs):
+def ocr(img, langs) -> OCRResult:
     replace_lang_with_code(langs)
-    pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor)[0]
-    rec_img = draw_text_on_image(pred["bboxes"], pred["text_lines"], img.size)
-    return rec_img, pred
+    img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor)[0]
+
+    bboxes = [l.bbox for l in img_pred.text_lines]
+    text = [l.text for l in img_pred.text_lines]
+    rec_img = draw_text_on_image(bboxes, text, img.size)
+    return rec_img, img_pred


 def open_pdf(pdf_file):
@@ -104,21 +109,21 @@ def page_count(pdf_file):

 # Run Text Detection
 if text_det and pil_image is not None:
-    det_img, preds = text_detection(pil_image)
+    det_img, pred = text_detection(pil_image)
     with col1:
         st.image(det_img, caption="Detected Text", use_column_width=True)
-        st.json(preds, expanded=True)
+        st.json(pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True)

 # Run OCR
 if text_rec and pil_image is not None:
     rec_img, pred = ocr(pil_image, languages)
     with col1:
         st.image(rec_img, caption="OCR Result", use_column_width=True)
-        json_tab, text_tab = st.tabs(["JSON", "Full Text"])
+        json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"])
         with json_tab:
-            st.json(pred, expanded=True)
+            st.json(pred.model_dump(), expanded=True)
         with text_tab:
-            st.text("\n".join(pred["text_lines"]))
+            st.text("\n".join([p.text for p in pred.text_lines]))

 with col2:
     st.image(pil_image, caption="Uploaded Image", use_column_width=True)
18 changes: 10 additions & 8 deletions ocr_text.py
@@ -58,19 +58,21 @@ def main():

     predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)

-    page_num = defaultdict(int)
-    for i, pred in enumerate(predictions_by_image):
-        pred["name"] = names[i]
-        pred["page"] = page_num[names[i]]
-        page_num[names[i]] += 1

     if args.images:
         for idx, (name, image, pred) in enumerate(zip(names, images, predictions_by_image)):
-            page_image = draw_text_on_image(pred["bboxes"], pred["text_lines"], image.size)
+            bboxes = [l.bbox for l in pred.text_lines]
+            pred_text = [l.text for l in pred.text_lines]
+            page_image = draw_text_on_image(bboxes, pred_text, image.size)
             page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png"))

+    out_preds = defaultdict(list)
+    for name, pred, image in zip(names, predictions_by_image, images):
+        out_pred = pred.model_dump()
+        out_pred["page"] = len(out_preds[name]) + 1
+        out_preds[name].append(out_pred)
+
     with open(os.path.join(result_path, "results.json"), "w+") as f:
-        json.dump(predictions_by_image, f)
+        json.dump(out_preds, f, ensure_ascii=False)

     print(f"Wrote results to {result_path}")

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,7 +1,7 @@
[tool.poetry]
name = "surya-ocr"
version = "0.2.1"
description = "Document OCR models for multilingual text detection and recognition"
version = "0.2.2"
description = "OCR and line detection in 90+ languages"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
license = "GPL-3.0-or-later"
22 changes: 12 additions & 10 deletions surya/detection.py
@@ -7,6 +7,7 @@
 from surya.postprocessing.heatmap import get_and_clean_boxes
 from surya.postprocessing.affinity import get_vertical_lines, get_horizontal_lines
 from surya.input.processing import prepare_image, split_image
+from surya.schema import DetectionResult
 from surya.settings import settings
 from tqdm import tqdm

@@ -20,7 +21,7 @@ def get_batch_size():
     return batch_size


-def batch_detection(images: List, model, processor):
+def batch_detection(images: List, model, processor) -> List[DetectionResult]:
     assert all([isinstance(image, Image.Image) for image in images])
     batch_size = get_batch_size()

@@ -94,18 +95,19 @@ def batch_detection(images: List, model, processor):
         affinity_size = list(reversed(affinity_map.shape))
         heatmap_size = list(reversed(heatmap.shape))
         bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes[i])
-        bbox_data = [bbox.model_dump() for bbox in bboxes]
         vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes[i])
         horizontal_lines = get_horizontal_lines(affinity_map, affinity_size, orig_sizes[i])

-        results.append({
-            "bboxes": [bbd["bbox"] for bbd in bbox_data],
-            "polygons": [bbd["corners"] for bbd in bbox_data],
-            "vertical_lines": vertical_lines,
-            "horizontal_lines": horizontal_lines,
-            "heatmap": heat_img,
-            "affinity_map": aff_img,
-        })
+        result = DetectionResult(
+            bboxes=bboxes,
+            vertical_lines=vertical_lines,
+            horizontal_lines=horizontal_lines,
+            heatmap=heat_img,
+            affinity_map=aff_img,
+            image_bbox=[0, 0, orig_sizes[i][0], orig_sizes[i][1]]
+        )
+
+        results.append(result)

     return results
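
The `DetectionResult` model comes from the new `surya/schema.py`, which is among the 13 changed files but not shown in this excerpt. A rough sketch inferred purely from the call sites in this diff (class names and field types here are assumptions, not the real definitions):

```python
# Sketch only -- inferred from usage in this diff, not the actual surya/schema.py.
from typing import Any, List, Optional

from pydantic import BaseModel


class PolygonBox(BaseModel):          # assumed name for the entries in `bboxes`
    polygon: List[List[int]]          # [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
    bbox: Optional[List[int]] = None  # [x1, y1, x2, y2], likely derived from polygon


class TextLine(PolygonBox):
    text: str


class ColumnLine(BaseModel):          # assumed shape of vertical/horizontal lines
    bbox: List[int]


class DetectionResult(BaseModel):
    bboxes: List[PolygonBox]
    vertical_lines: List[ColumnLine]
    horizontal_lines: List[ColumnLine]
    heatmap: Any                      # PIL images, excluded when dumping to JSON
    affinity_map: Any
    image_bbox: List[int]


class OCRResult(BaseModel):
    text_lines: List[TextLine]
    languages: List[str]
    image_bbox: List[int]
```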

63 changes: 43 additions & 20 deletions surya/ocr.py
@@ -7,10 +7,12 @@

 from surya.detection import batch_detection
 from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image
+from surya.postprocessing.text import truncate_repetitions, sort_text_lines
 from surya.recognition import batch_recognition
+from surya.schema import TextLine, OCRResult


-def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None):
+def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None) -> List[OCRResult]:
     # Polygons need to be in corner format - [[x1, y1], [x2, y2], [x3, y3], [x4, y4]], bboxes in [x1, y1, x2, y2] format
     assert bboxes is not None or polygons is not None
     slice_map = []
@@ -34,22 +36,30 @@ def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model
         image_lines = rec_predictions[slice_start:slice_end]
         slice_start = slice_end

-        pred = {
-            "text_lines": image_lines,
-            "language": lang
-        }
-
-        if polygons is not None:
-            pred["polys"] = polygons[idx]
-        else:
-            pred["bboxes"] = bboxes[idx]
+        text_lines = []
+        for i in range(len(image_lines)):
+            if polygons is not None:
+                poly = polygons[idx][i]
+            else:
+                bbox = bboxes[idx][i]
+                poly = [[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]]
+
+            text_lines.append(TextLine(
+                text=image_lines[i],
+                polygon=poly
+            ))
+
+        pred = OCRResult(
+            text_lines=text_lines,
+            languages=lang,
+            image_bbox=[0, 0, image.size[0], image.size[1]]
+        )
         predictions_by_image.append(pred)

     return predictions_by_image


-def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor):
+def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor) -> List[OCRResult]:
     det_predictions = batch_detection(images, det_model, det_processor)
     if det_model.device == "cuda":
         torch.cuda.empty_cache() # Empty cache from first model run
@@ -58,7 +68,8 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
     all_slices = []
     all_langs = []
     for idx, (image, det_pred, lang) in enumerate(zip(images, det_predictions, langs)):
-        slices = slice_polys_from_image(image, det_pred["polygons"])
+        polygons = [p.polygon for p in det_pred.bboxes]
+        slices = slice_polys_from_image(image, polygons)
         slice_map.append(len(slices))
         all_slices.extend(slices)
         all_langs.extend([lang] * len(slices))
@@ -72,12 +83,24 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
         image_lines = rec_predictions[slice_start:slice_end]
         slice_start = slice_end

-        assert len(image_lines) == len(det_pred["polygons"]) == len(det_pred["bboxes"])
-        predictions_by_image.append({
-            "text_lines": image_lines,
-            "polys": det_pred["polygons"],
-            "bboxes": det_pred["bboxes"],
-            "language": lang
-        })
+        assert len(image_lines) == len(det_pred.bboxes)
+
+        # Remove repeated characters
+        image_lines = [truncate_repetitions(l) for l in image_lines]
+        lines = []
+        for text_line, bbox in zip(image_lines, det_pred.bboxes):
+            lines.append(TextLine(
+                text=text_line,
+                polygon=bbox.polygon,
+                bbox=bbox.bbox
+            ))
+
+        lines = sort_text_lines(lines)
+
+        predictions_by_image.append(OCRResult(
+            text_lines=lines,
+            languages=lang,
+            image_bbox=det_pred.image_bbox
+        ))

     return predictions_by_image
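
`truncate_repetitions` and `sort_text_lines` are imported from `surya/postprocessing/text.py`, which is not shown in this excerpt. A minimal sketch of the idea behind each, illustrative only and not the actual implementations:

```python
def truncate_repetitions_sketch(text: str, min_repeats: int = 10) -> str:
    # OCR decoders can get stuck emitting the same short fragment forever.
    # If the line ends in many copies of one unit, keep the text before the
    # run plus a single copy of the unit.
    lower = text.lower()
    for rep_len in range(1, max(2, len(lower) // min_repeats)):
        unit = lower[-rep_len:]
        pos = len(lower)
        while pos >= rep_len and lower[pos - rep_len:pos] == unit:
            pos -= rep_len
        if (len(lower) - pos) // rep_len >= min_repeats:
            return text[:pos + rep_len]
    return text


def sort_text_lines_sketch(lines):
    # Approximate reading order: top-to-bottom by the bbox top edge, then
    # left-to-right by the bbox left edge. The real sort is likely more
    # tolerant of small vertical jitter between lines on the same row.
    return sorted(lines, key=lambda line: (line.bbox[1], line.bbox[0]))
```

For example, `truncate_repetitions_sketch("Total: " + "9" * 41)` collapses the runaway nines down to `"Total: 9"` (everything before the run, plus one copy of the repeated unit).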