diff --git a/bpm_ai_inference/ocr/tesseract.py b/bpm_ai_inference/ocr/tesseract.py index ec6141c..b81637d 100644 --- a/bpm_ai_inference/ocr/tesseract.py +++ b/bpm_ai_inference/ocr/tesseract.py @@ -19,8 +19,6 @@ logger = logging.getLogger(__name__) -TESSDATA_DIR = "~/.bpm.ai/tessdata/" - @cachable() class TesseractOCR(OCR): @@ -32,8 +30,9 @@ class TesseractOCR(OCR): def __init__(self): if not has_pytesseract: raise ImportError('pytesseract is not installed') - os.makedirs(TESSDATA_DIR, exist_ok=True) - os.environ["TESSDATA_PREFIX"] = TESSDATA_DIR + self.tessdata_dir = os.path.join(os.getenv("BPM_AI_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "bpm-ai")), "tessdata") + os.makedirs(self.tessdata_dir, exist_ok=True) + os.environ["TESSDATA_PREFIX"] = self.tessdata_dir @override async def _do_process( @@ -82,12 +81,11 @@ def identify_image_language(self, image: Image) -> str: text = pytesseract.image_to_string(image) return indentify_language_iso_639_3(text) or "eng" - @staticmethod - def download_if_missing(lang: str): + def download_if_missing(self, lang: str): lang_file = f'{lang}.traineddata' - tessdata_file_path = os.path.join(TESSDATA_DIR, lang_file) + tessdata_file_path = os.path.join(self.tessdata_dir, lang_file) if not os.path.exists(tessdata_file_path): - logger.info(f'tesseract: {lang_file} not found in {TESSDATA_DIR}, downloading...') + logger.info(f'tesseract: {lang_file} not found in {self.tessdata_dir}, downloading...') download_url = f'https://github.com/tesseract-ocr/tessdata_best/raw/main/{lang_file}' urllib.request.urlretrieve(download_url, tessdata_file_path) logger.info(f'tesseract: Downloaded {lang_file} to tessdata directory') diff --git a/bpm_ai_inference/translation/easy_nmt/easy_nmt.py b/bpm_ai_inference/translation/easy_nmt/easy_nmt.py index fb7ee5e..473955b 100644 --- a/bpm_ai_inference/translation/easy_nmt/easy_nmt.py +++ b/bpm_ai_inference/translation/easy_nmt/easy_nmt.py @@ -73,8 +73,8 @@ def __init__( self.config = None if cache_folder is None: - if 'EASYNMT_CACHE' in os.environ: - cache_folder = os.environ['EASYNMT_CACHE'] + if 'BPM_AI_CACHE_DIR' in os.environ: + cache_folder = os.path.join(os.environ['BPM_AI_CACHE_DIR'], 'easynmt_v2') else: cache_folder = os.path.join(torch.hub._get_torch_home(), 'easynmt_v2') self._cache_folder = cache_folder