forked from hiroi-sora/GapTree_Sort_Algorithm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
166 lines (129 loc) · 5.53 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# 测试代码
import os
import json
import time
from pathlib import Path
# 测试图片
# test_image = "test/1.png" # 单栏布局,含有表格
test_image = "test/2.png" # 双栏布局,含有跨列的表格、图片
# test_image = "test/3.png" # 双栏布局,含有大量列内图片、表格
# test_image = "test/4.png" # 四栏布局(两页拼接),含有跨列标题
# test_image = "test/5.png" # 三栏布局,栏宽度差异大
# test_image = "test/5_r.png" # 样例5的旋转版本
# test_image = "test/6.png" # 竖排,双栏布局。(可视化显示可能有问题,但结果顺序是对的)
# test_image = "test/6_r.png" # 样例6的旋转版本
# 如果使用自己的图片,需要将OCR引擎【RapidOCR_json】放在本目录下。
# https://github.com/hiroi-sora/RapidOCR-json
# 也可以使用另外的任何OCR方式,或者从PDF中提取的文本。
# 要求:每个参与排序的元素块,都必须提供矩形包围盒的左上角和右下角坐标。
# ======================= 测试:获取文本块 =====================
def get_ocr_cache(image_path): # 加载OCR缓存json文件
absolute_path = os.path.abspath(image_path)
# 尝试查找与图片同名的OCR结果缓存文件。
json_path = Path(absolute_path).with_suffix(".json")
try:
with open(json_path, "r", encoding="utf-8") as file:
json_data = file.read()
input_dict = json.loads(json_data)
if input_dict["code"] == 100:
return input_dict["data"]
except Exception:
pass
return None
def get_rapidocr_json(image_path):
# 调用 RapidOCR-json 引擎。下载:
# https://github.com/hiroi-sora/RapidOCR-json
try:
from rapidocr import Rapid_pipe
ocr = Rapid_pipe("RapidOCR_json/RapidOCR-json.exe", {"maxSideLen": 4096})
absolute_path = os.path.abspath(image_path)
result = ocr.run(absolute_path)
if "code" in result and result["code"] == 100:
json_path = Path(absolute_path).with_suffix(".json")
print(f"OCR成功,获取{len(result['data'])}个文本块。")
print(f"写入缓存:{json_path}")
with open(json_path, "w", encoding="utf-8") as file:
file.write(json.dumps(result))
return result["data"]
except Exception:
pass
return None
text_blocks = get_ocr_cache(test_image)
if not text_blocks:
print(f"未获取缓存,尝试重新进行OCR。")
text_blocks = get_rapidocr_json(test_image)
if not text_blocks:
print(f"RapidOCR-json 调用失败。")
exit()
# ======================= 标准化bbox =====================
# 将旋转、竖排等文本块,转为标准横排。此步可忽略。
from preprocessing import linePreprocessing
t1 = time.time()
bboxes = linePreprocessing(text_blocks)
t2 = time.time()
print(f"预处理完毕。共{len(text_blocks)}个文本块,耗时{(t2-t1):.{8}f}s")
for i, tb in enumerate(text_blocks):
tb["bbox"] = bboxes[i] # 写入标准化的bbox
# ======================= 调用间隙树算法进行排序 =====================
from gap_tree import GapTree # gap_tree.py
def tb_bbox(tb): # 从文本块对象中,提取左上角、右下角坐标元组
b = tb["box"]
return (b[0][0], b[0][1], b[2][0], b[2][1])
gtree = GapTree(lambda tb: tb["bbox"])
t1 = time.time()
sorted_text_blocks = gtree.sort(text_blocks) # 输入文本块,进行排序
t2 = time.time()
print(f"排序完毕。共{len(text_blocks)}个文本块,耗时{(t2-t1):.{8}f}s")
# ======================= 进一步:区块内分析段落关系 =====================
from paragraph_parse import ParagraphParse
def get_info(tb): # 返回信息
b = tb["box"]
return ((b[0][0], b[0][1], b[2][0], b[2][1]), tb["text"])
def set_end(tb, end): # 获取预测的块尾分隔符
tb["end"] = end
# 也可以: tb["text"] += end
pp = ParagraphParse(get_info, set_end)
# 获取所有区块的文本块
nodes_text_blocks = gtree.get_nodes_text_blocks()
for tbs in nodes_text_blocks:
tbs = pp.run(tbs) # 预测结尾分隔符
for tb in tbs: # 输出文本和结尾分隔符
print(tb["text"], end=tb["end"])
print()
# ======================= 测试:结果可视化 =====================
try:
from visualize import visualize
except Exception:
print("无法加载结果可视化模块")
exit()
# 原始OCR预览图
pil_origin = visualize(text_blocks, test_image).get(isOrder=True)
# 排序后的预览图
pil_sorted = visualize(sorted_text_blocks, test_image).get(isOrder=True)
# 左右拼接 1
pil_show_1 = visualize.createContrast(pil_origin, pil_sorted)
# 竖切线 预览图
cut_tbs = []
for c in gtree.current_cuts:
x0 = c[0]
x1 = c[1]
y0 = gtree.current_rows[c[2]][0][0][1]
y1 = gtree.current_rows[c[3]][0][0][3]
cut_tbs.append({"box": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]], "text": ""})
pil_cuts = visualize(cut_tbs, test_image).get(isOrder=True)
# 树节点 预览图
node_tbs = []
for node in gtree.current_nodes:
if not node["units"]:
continue # 跳过没有块的根节点
x0 = node["x_left"]
x1 = node["x_right"]
y0 = gtree.current_rows[node["r_top"]][0][0][1]
y1 = gtree.current_rows[node["r_bottom"]][0][0][3]
node_tbs.append({"box": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]], "text": ""})
pil_nodes = visualize(node_tbs, test_image).get(isOrder=True)
# 左右拼接 2
pil_show_2 = visualize.createContrast(pil_cuts, pil_nodes)
print("可视化展示")
pil_show = visualize.createContrast(pil_show_1, pil_show_2)
pil_show.show()