From c141639869dff97bd9484266d42e7540b20cbe4e Mon Sep 17 00:00:00 2001 From: SWHL Date: Sun, 3 Nov 2024 22:14:40 +0800 Subject: [PATCH] feat: fix encode error when reading yaml --- .github/workflows/gen_whl_to_pypi.yml | 15 +----------- README.md | 9 +++++-- demo.py | 4 +-- rapid_latex_ocr/__init__.py | 4 +-- rapid_latex_ocr/main.py | 35 +++++++++++++++++++-------- rapid_latex_ocr/utils.py | 20 ++++++++++++--- tests/test_main.py | 4 +-- 7 files changed, 55 insertions(+), 36 deletions(-) diff --git a/.github/workflows/gen_whl_to_pypi.yml b/.github/workflows/gen_whl_to_pypi.yml index a542da2..d47d3c8 100644 --- a/.github/workflows/gen_whl_to_pypi.yml +++ b/.github/workflows/gen_whl_to_pypi.yml @@ -2,12 +2,6 @@ name: Push rapid_latex_ocr to pypi on: push: - # branches: [ main ] - # paths: - # - 'rapid_latex_ocr/**' - # - 'docs/doc_whl.md' - # - 'setup.py' - # - '.github/workflows/gen_whl_to_pypi.yml' tags: - v* @@ -51,14 +45,7 @@ jobs: pip install -r requirements.txt python -m pip install --upgrade pip pip install wheel get_pypi_latest_version - python setup.py bdist_wheel ${{ github.ref_name }} - - # - name: Publish distribution 📦 to Test PyPI - # uses: pypa/gh-action-pypi-publish@v1.5.0 - # with: - # password: ${{ secrets.TEST_PYPI_API_TOKEN }} - # repository_url: https://test.pypi.org/legacy/ - # packages_dir: dist/ + python setup.py bdist_wheel "${{ github.ref_name }}" - name: Publish distribution 📦 to PyPI uses: pypa/gh-action-pypi-publish@v1.5.0 diff --git a/README.md b/README.md index 82d9a88..c530b4a 100644 --- a/README.md +++ b/README.md @@ -62,9 +62,9 @@ pip install rapid_latex_ocr #### Used by python script ```python -from rapid_latex_ocr import LatexOCR +from rapid_latex_ocr import LaTeXOCR -model = LatexOCR() +model = LaTeXOCR() img_path = "tests/test_files/6.png" with open(img_path, "rb") as f: @@ -90,6 +90,11 @@ $ rapid_latex_ocr tests/test_files/6.png
Click to expand +#### 2024-11-03 v0.0.9 update + +- 修复读取配置文件编码问题 +- 引入`dataclasses`类,简化参数传递 + #### 2023-12-10 v0.0.6 update - Fixed issue [#12](https://github.com/RapidAI/RapidLaTeXOCR/issues/12) diff --git a/demo.py b/demo.py index dbc37d9..cdb048d 100644 --- a/demo.py +++ b/demo.py @@ -1,9 +1,9 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -from rapid_latex_ocr import LatexOCR +from rapid_latex_ocr import LaTeXOCR -model = LatexOCR() +model = LaTeXOCR() img_path = "tests/test_files/6.png" with open(img_path, "rb") as f: diff --git a/rapid_latex_ocr/__init__.py b/rapid_latex_ocr/__init__.py index c69bd69..c4132a7 100644 --- a/rapid_latex_ocr/__init__.py +++ b/rapid_latex_ocr/__init__.py @@ -1,6 +1,6 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -from .main import LatexOCR +from .main import LaTeXOCR -__all__ = ["LatexOCR"] +__all__ = ["LaTeXOCR"] diff --git a/rapid_latex_ocr/main.py b/rapid_latex_ocr/main.py index 7058cb3..23cf76e 100644 --- a/rapid_latex_ocr/main.py +++ b/rapid_latex_ocr/main.py @@ -5,6 +5,7 @@ import re import time import traceback +from dataclasses import dataclass from pathlib import Path from typing import Tuple, Union @@ -13,14 +14,26 @@ from PIL import Image from .models import EncoderDecoder -from .utils import PreProcess, TokenizerCls, DownloadModel +from .utils import DownloadModel, PreProcess, TokenizerCls, get_file_encode from .utils_load import InputType, LoadImage, LoadImageError, OrtInferSession cur_dir = Path(__file__).resolve().parent DEFAULT_CONFIG = cur_dir / "config.yaml" -class LatexOCR: +@dataclass +class LaTeXOCRInput: + max_width: int = 672 + max_height: int = 192 + min_height: int = 32 + min_width: int = 32 + bos_token: int = 1 + max_seq_len: int = 512 + eos_token: int = 2 + temperature: float = 0.00001 + + +class LaTeXOCR: def __init__( self, config_path: Union[str, Path] = DEFAULT_CONFIG, @@ -36,12 +49,14 @@ def __init__( self.get_model_path() - with 
open(config_path, "r", encoding="utf-8") as f: + file_encode = get_file_encode(config_path) + with open(config_path, "r", encoding=file_encode) as f: args = yaml.load(f, Loader=yaml.FullLoader) + input_params = LaTeXOCRInput(**args) - self.max_dims = [args.get("max_width"), args.get("max_height")] - self.min_dims = [args.get("min_width", 32), args.get("min_height", 32)] - self.temperature = args.get("temperature", 0.00001) + self.max_dims = [input_params.max_width, input_params.max_height] + self.min_dims = [input_params.min_width, input_params.min_height] + self.temperature = input_params.temperature self.load_img = LoadImage() @@ -52,9 +67,9 @@ def __init__( self.encoder_decoder = EncoderDecoder( encoder_path=self.encoder_path, decoder_path=self.decoder_path, - bos_token=args["bos_token"], - eos_token=args["eos_token"], - max_seq_len=args["max_seq_len"], + bos_token=input_params.bos_token, + eos_token=input_params.eos_token, + max_seq_len=input_params.max_seq_len, ) self.tokenizer = TokenizerCls(self.tokenizer_json) @@ -191,7 +206,7 @@ def main(): parser.add_argument("img_path", type=str, help="Only img path of the formula.") args = parser.parse_args() - engine = LatexOCR( + engine = LaTeXOCR( image_resizer_path=args.image_resizer_path, encoder_path=args.encoder_path, decoder_path=args.decoder_path, diff --git a/rapid_latex_ocr/utils.py b/rapid_latex_ocr/utils.py index cb07423..488f7b0 100644 --- a/rapid_latex_ocr/utils.py +++ b/rapid_latex_ocr/utils.py @@ -1,17 +1,18 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com +import io from pathlib import Path -from typing import List, Union, Optional +from typing import List, Optional, Union +import chardet import cv2 import numpy as np +import requests +import tqdm from PIL import Image from tokenizers import Tokenizer from tokenizers.models import BPE -import requests -import tqdm -import io class PreProcess: @@ -171,6 +172,17 @@ def save_file(save_path: Union[str, Path], file: bytes): 
f.write(file)


+def get_file_encode(file_path: Union[str, Path]) -> str:
+    """Return the chardet-guessed encoding of file_path, or "utf-8" as fallback."""
+    try:
+        with open(file_path, "rb") as f:
+            # Sniff the whole file: an ASCII-only 100-byte prefix hides later bytes.
+            encoding = chardet.detect(f.read())["encoding"]
+    except Exception:  # best effort: any failure means "assume UTF-8"
+        return "utf-8"
+    return encoding or "utf-8"  # chardet reports None when it cannot decide
+
+
 if __name__ == "__main__":
     downloader = DownloadModel()
     downloader("decoder.onnx")
diff --git a/tests/test_main.py b/tests/test_main.py
index 6edb5a4..bab3db5 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -10,9 +10,9 @@ sys.path.append(str(root_dir))

 import pytest

-from rapid_latex_ocr import LatexOCR
+from rapid_latex_ocr import LaTeXOCR

-model = LatexOCR()
+model = LaTeXOCR()

 img_dir = cur_dir / "test_files"
