Skip to content

Commit

Permalink
feat: fix encode error when reading yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Nov 3, 2024
1 parent b5258f8 commit c141639
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 36 deletions.
15 changes: 1 addition & 14 deletions .github/workflows/gen_whl_to_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@ name: Push rapid_latex_ocr to pypi

on:
push:
# branches: [ main ]
# paths:
# - 'rapid_latex_ocr/**'
# - 'docs/doc_whl.md'
# - 'setup.py'
# - '.github/workflows/gen_whl_to_pypi.yml'
tags:
- v*

Expand Down Expand Up @@ -51,14 +45,7 @@ jobs:
pip install -r requirements.txt
python -m pip install --upgrade pip
pip install wheel get_pypi_latest_version
python setup.py bdist_wheel ${{ github.ref_name }}
# - name: Publish distribution 📦 to Test PyPI
# uses: pypa/[email protected]
# with:
# password: ${{ secrets.TEST_PYPI_API_TOKEN }}
# repository_url: https://test.pypi.org/legacy/
# packages_dir: dist/
python setup.py bdist_wheel "${{ github.ref_name }}"
- name: Publish distribution 📦 to PyPI
uses: pypa/[email protected]
Expand Down
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ pip install rapid_latex_ocr
#### Used by python script

```python
from rapid_latex_ocr import LatexOCR
from rapid_latex_ocr import LaTeXOCR

model = LatexOCR()
model = LaTeXOCR()

img_path = "tests/test_files/6.png"
with open(img_path, "rb") as f:
Expand All @@ -90,6 +90,11 @@ $ rapid_latex_ocr tests/test_files/6.png
<details>
<summary>Click to expand</summary>

#### 2024-11-03 v0.0.9 update

- 修复读取配置文件编码问题
- 引入`dataclasses`类,简化参数传递

#### 2023-12-10 v0.0.6 update

- Fixed issue [#12](https://github.com/RapidAI/RapidLaTeXOCR/issues/12)
Expand Down
4 changes: 2 additions & 2 deletions demo.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from rapid_latex_ocr import LatexOCR
from rapid_latex_ocr import LaTeXOCR

model = LatexOCR()
model = LaTeXOCR()

img_path = "tests/test_files/6.png"
with open(img_path, "rb") as f:
Expand Down
4 changes: 2 additions & 2 deletions rapid_latex_ocr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from .main import LatexOCR
from .main import LaTeXOCR

__all__ = ["LatexOCR"]
__all__ = ["LaTeXOCR"]
35 changes: 25 additions & 10 deletions rapid_latex_ocr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import time
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple, Union

Expand All @@ -13,14 +14,26 @@
from PIL import Image

from .models import EncoderDecoder
from .utils import PreProcess, TokenizerCls, DownloadModel
from .utils import DownloadModel, PreProcess, TokenizerCls, get_file_encode
from .utils_load import InputType, LoadImage, LoadImageError, OrtInferSession

cur_dir = Path(__file__).resolve().parent
DEFAULT_CONFIG = cur_dir / "config.yaml"


class LatexOCR:
@dataclass
class LaTeXOCRInput:
max_width: int = 672
max_height: int = 192
min_height: int = 32
min_width: int = 32
bos_token: int = 1
max_seq_len: int = 512
eos_token: int = 2
temperature: float = 0.00001


class LaTeXOCR:
def __init__(
self,
config_path: Union[str, Path] = DEFAULT_CONFIG,
Expand All @@ -36,12 +49,14 @@ def __init__(

self.get_model_path()

with open(config_path, "r", encoding="utf-8") as f:
file_encode = get_file_encode(config_path)
with open(config_path, "r", encoding=file_encode) as f:
args = yaml.load(f, Loader=yaml.FullLoader)
input_params = LaTeXOCRInput(**args)

self.max_dims = [args.get("max_width"), args.get("max_height")]
self.min_dims = [args.get("min_width", 32), args.get("min_height", 32)]
self.temperature = args.get("temperature", 0.00001)
self.max_dims = [input_params.max_width, input_params.max_height]
self.min_dims = [input_params.min_width, input_params.min_height]
self.temperature = input_params.temperature

self.load_img = LoadImage()

Expand All @@ -52,9 +67,9 @@ def __init__(
self.encoder_decoder = EncoderDecoder(
encoder_path=self.encoder_path,
decoder_path=self.decoder_path,
bos_token=args["bos_token"],
eos_token=args["eos_token"],
max_seq_len=args["max_seq_len"],
bos_token=input_params.bos_token,
eos_token=input_params.eos_token,
max_seq_len=input_params.max_seq_len,
)
self.tokenizer = TokenizerCls(self.tokenizer_json)

Expand Down Expand Up @@ -191,7 +206,7 @@ def main():
parser.add_argument("img_path", type=str, help="Only img path of the formula.")
args = parser.parse_args()

engine = LatexOCR(
engine = LaTeXOCR(
image_resizer_path=args.image_resizer_path,
encoder_path=args.encoder_path,
decoder_path=args.decoder_path,
Expand Down
20 changes: 16 additions & 4 deletions rapid_latex_ocr/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import io
from pathlib import Path
from typing import List, Union, Optional
from typing import List, Optional, Union

import chardet
import cv2
import numpy as np
import requests
import tqdm
from PIL import Image
from tokenizers import Tokenizer
from tokenizers.models import BPE
import requests
import tqdm
import io


class PreProcess:
Expand Down Expand Up @@ -171,6 +172,17 @@ def save_file(save_path: Union[str, Path], file: bytes):
f.write(file)


def get_file_encode(file_path: Union[str, Path]) -> str:
try:
with open(file_path, "rb") as f:
raw_data = f.read(100)
result = chardet.detect(raw_data)
encoding = result["encoding"]
return encoding
except Exception:
return "utf-8"


if __name__ == "__main__":
downloader = DownloadModel()
downloader("decoder.onnx")
4 changes: 2 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
sys.path.append(str(root_dir))
import pytest

from rapid_latex_ocr import LatexOCR
from rapid_latex_ocr import LaTeXOCR

model = LatexOCR()
model = LaTeXOCR()

img_dir = cur_dir / "test_files"

Expand Down

0 comments on commit c141639

Please sign in to comment.