Skip to content

Commit

Permalink
Optim code
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Sep 14, 2023
1 parent 34785ce commit 9a087e4
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 103 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,6 @@
### Changelog
- 2023-09-13 v0.0.4 update:
- Merge [pr #5](https://github.com/RapidAI/RapidLatexOCR/pull/5)
- Optim code
- 2023-07-15 v0.0.1 update:
- First release
130 changes: 65 additions & 65 deletions docs/doc_whl.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,70 +17,70 @@
- Welcome all friends to actively contribute to make this tool better.


### Use
1. Installation
1. pip install the `rapid_latex_ocr` library. Because packaging the model into the whl package exceeds the pypi limit (100M), the model needs to be downloaded separately.
```bash
pip install rapid_latex_ocr
```
2. Download the model ([Google Drive](https://drive.google.com/drive/folders/1e8BgLk1cPQDSZjgoLgloFYMAQWLTaroQ?usp=sharing) | [Baidu NetDisk](https://pan.baidu.com/s/1rnYmmKp2HhOkYVFehUiMNg?pwd=dh72)), when initializing, just specify the model path, see the next part for details.

|model name|size|
|---:|:---:|
|`image_resizer.onnx`|37.1M|
|`encoder.onnx`|84.8M|
|`decoder.onnx`|48.5M|

2. Use
- Used by python script:
```python
from rapid_latex_ocr import LatexOCR
image_resizer_path = 'models/image_resizer.onnx'
encoder_path = 'models/encoder.onnx'
decoder_path = 'models/decoder.onnx'
tokenizer_json = 'models/tokenizer.json'
model = LatexOCR(image_resizer_path=image_resizer_path,
encoder_path=encoder_path,
decoder_path=decoder_path,
tokenizer_json=tokenizer_json)
img_path = "tests/test_files/6.png"
with open(img_path, "rb") as f:
data = f.read()
result, elapse = model(data)
print(result)
# {\frac{x^{2}}{a^{2}}}-{\frac{y^{2}}{b^{2}}}=1
print(elapse)
# 0.4131628000000003
```
- Used by command line.
```bash
$ rapid_latex_ocr -h
usage: rapid_latex_ocr [-h] [-img_resizer IMAGE_RESIZER_PATH]
[-encdoer ENCODER_PATH] [-decoder DECODER_PATH]
[-tokenizer TOKENIZER_JSON]
img_path
positional arguments:
img_path Only img path of the formula.
optional arguments:
-h, --help show this help message and exit
-img_resizer IMAGE_RESIZER_PATH, --image_resizer_path IMAGE_RESIZER_PATH
-encdoer ENCODER_PATH, --encoder_path ENCODER_PATH
-decoder DECODER_PATH, --decoder_path DECODER_PATH
-tokenizer TOKENIZER_JSON, --tokenizer_json TOKENIZER_JSON
$ rapid_latex_ocr tests/test_files/6.png \
-img_resizer models/image_resizer.onnx \
-encoder models/encoder.onnx \
-decoder models/decoder.onnx \
-tokenizer models/tokenizer.json
# ('{\\frac{x^{2}}{a^{2}}}-{\\frac{y^{2}}{b^{2}}}=1', 0.47902780000000034)
```
### Installation
1. pip install the `rapid_latex_ocr` library. Because packaging the model into the whl package exceeds the pypi limit (100M), the model needs to be downloaded separately.
```bash
pip install rapid_latex_ocr
```
2. Download the model ([Google Drive](https://drive.google.com/drive/folders/1e8BgLk1cPQDSZjgoLgloFYMAQWLTaroQ?usp=sharing) | [Baidu NetDisk](https://pan.baidu.com/s/1rnYmmKp2HhOkYVFehUiMNg?pwd=dh72)), when initializing, just specify the model path, see the next part for details.

|model name|size|
|---:|:---:|
|`image_resizer.onnx`|37.1M|
|`encoder.onnx`|84.8M|
|`decoder.onnx`|48.5M|


### Usage
- Used by python script:
```python
from rapid_latex_ocr import LatexOCR
image_resizer_path = 'models/image_resizer.onnx'
encoder_path = 'models/encoder.onnx'
decoder_path = 'models/decoder.onnx'
tokenizer_json = 'models/tokenizer.json'
model = LatexOCR(image_resizer_path=image_resizer_path,
encoder_path=encoder_path,
decoder_path=decoder_path,
tokenizer_json=tokenizer_json)
img_path = "tests/test_files/6.png"
with open(img_path, "rb") as f:
data = f.read()
result, elapse = model(data)
print(result)
# {\frac{x^{2}}{a^{2}}}-{\frac{y^{2}}{b^{2}}}=1
print(elapse)
# 0.4131628000000003
```
- Used by command line.
```bash
$ rapid_latex_ocr -h
usage: rapid_latex_ocr [-h] [-img_resizer IMAGE_RESIZER_PATH]
[-encdoer ENCODER_PATH] [-decoder DECODER_PATH]
[-tokenizer TOKENIZER_JSON]
img_path
positional arguments:
img_path Only img path of the formula.
optional arguments:
-h, --help show this help message and exit
-img_resizer IMAGE_RESIZER_PATH, --image_resizer_path IMAGE_RESIZER_PATH
-encdoer ENCODER_PATH, --encoder_path ENCODER_PATH
-decoder DECODER_PATH, --decoder_path DECODER_PATH
-tokenizer TOKENIZER_JSON, --tokenizer_json TOKENIZER_JSON
$ rapid_latex_ocr tests/test_files/6.png \
-img_resizer models/image_resizer.onnx \
-encoder models/encoder.onnx \
-decoder models/decoder.onnx \
-tokenizer models/tokenizer.json
# ('{\\frac{x^{2}}{a^{2}}}-{\\frac{y^{2}}{b^{2}}}=1', 0.47902780000000034)
```

### See details for [RapidLatexOCR](https://github.com/RapidAI/RapidLatexOCR)
18 changes: 11 additions & 7 deletions rapid_latex_ocr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
if tokenizer_json is None:
raise FileNotFoundError("tokenizer_json must not be None.")

with open(config_path, "r") as f:
with open(config_path, "r", encoding="utf-8") as f:
args = yaml.load(f, Loader=yaml.FullLoader)

self.max_dims = [args.get("max_width"), args.get("max_height")]
Expand All @@ -68,23 +68,27 @@ def __call__(self, img: InputType) -> Tuple[str, float]:

try:
img = self.load_img(img)
except LoadImageError:
except LoadImageError as exc:
error_info = traceback.format_exc()
raise LoadImageError(
f"Load the img meets error. Error info is {error_info}"
)
) from exc

try:
resizered_img = self.loop_image_resizer(img)
except Exception:
except Exception as e:
error_info = traceback.format_exc()
raise ValueError(f"image resizer meets error. Error info is {error_info}")
raise ValueError(
f"image resizer meets error. Error info is {error_info}"
) from e

try:
dec = self.encoder_decoder(resizered_img, temperature=self.temperature)
except Exception:
except Exception as e:
error_info = traceback.format_exc()
raise ValueError(f"EncoderDecoder meets error. Error info is {error_info}")
raise ValueError(
f"EncoderDecoder meets error. Error info is {error_info}"
) from e

decode = self.tokenizer.token2str(dec)
pred = self.post_process(decode[0])
Expand Down
79 changes: 48 additions & 31 deletions rapid_latex_ocr/utils_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, model_path: Union[str, Path], num_threads: int = -1):
str(model_path), sess_options=self.sess_opt, providers=EP_list
)
except TypeError:
# 这里兼容ort 1.5.2
# compatible with onnxruntime 1.5.2
self.session = InferenceSession(str(model_path), sess_options=self.sess_opt)

def _init_sess_opt(self):
Expand Down Expand Up @@ -82,63 +82,66 @@ class ONNXRuntimeError(Exception):


class LoadImage:
def __init__(
self,
):
pass

def __call__(self, img: InputType) -> np.ndarray:
if not isinstance(img, InputType.__args__):
raise LoadImageError(
f"The img type {type(img)} does not in {InputType.__args__}"
)

img = self.load_img(img)

if img.ndim == 2:
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

if img.ndim == 3 and img.shape[2] == 4:
return self.cvt_four_to_three(img)

img = self.convert_img(img)
return img
# 支持背景为透明的png图片,nparray没跑通注释了,交由后来人吧
def is_image_transparent(self, img):
if img.mode == "RGBA":
# 如果图像是四通道的,抓取alpha通道
alpha = img.split()[3]
# 利用alpha通道的getextrema()函数获取图像的最小和最大alpha值
min_alpha = alpha.getextrema()[0]
# 图像即为透明,如果最小alpha值小于255(即存在alpha值为0,即透明像素)
# 创建一个白色背景图像
if min_alpha < 255:
bg = Image.new('RGBA', img.size, (255, 255, 255, 255))

# 合并背景图像与源图片
final_img = Image.alpha_composite(bg, img)
return final_img
return img
else:
return img # 不是四通道图像,即没有透明度

def load_img(self, img: InputType) -> np.ndarray:
if isinstance(img, (str, Path)):
self.verify_exist(img)
try:
img = np.array(self.is_image_transparent(Image.open(img)))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
img = np.array(Image.open(img))
except UnidentifiedImageError as e:
raise LoadImageError(f"cannot identify image file {img}") from e
return img

if isinstance(img, bytes):
img = np.array(self.is_image_transparent(Image.open(BytesIO(img))))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
img = np.array(Image.open(BytesIO(img)))
return img

if isinstance(img, np.ndarray):
return img

raise LoadImageError(f"{type(img)} is not supported!")

def convert_img(self, img: np.ndarray):
if img.ndim == 2:
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

if img.ndim == 3:
channel = img.shape[2]
if channel == 1:
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

if channel == 2:
return self.cvt_two_to_three(img)

if channel == 4:
return self.cvt_four_to_three(img)

if channel == 3:
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

raise LoadImageError(
f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
)

raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

@staticmethod
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
"""RGBA → RGB"""
"""RGBA → BGR"""
r, g, b, a = cv2.split(img)
new_img = cv2.merge((b, g, r))

Expand All @@ -149,6 +152,20 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
new_img = cv2.add(new_img, not_a)
return new_img

@staticmethod
def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
"""gray + alpha → BGR"""
img_gray = img[..., 0]
img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

img_alpha = img[..., 1]
not_a = cv2.bitwise_not(img_alpha)
not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
new_img = cv2.add(new_img, not_a)
return new_img

@staticmethod
def verify_exist(file_path: Union[str, Path]):
if not Path(file_path).exists():
Expand Down

0 comments on commit 9a087e4

Please sign in to comment.