Skip to content

Commit

Permalink
0.3.5 - add cache dir environment variable
Browse files: browse the repository at this point in the history
  • Loading branch information
Benjoyo committed May 6, 2024
1 parent d5036a2 commit 66e56f5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
14 changes: 6 additions & 8 deletions bpm_ai_inference/ocr/tesseract.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

logger = logging.getLogger(__name__)

TESSDATA_DIR = "~/.bpm.ai/tessdata/"


@cachable()
class TesseractOCR(OCR):
Expand All @@ -32,8 +30,9 @@ class TesseractOCR(OCR):
def __init__(self):
if not has_pytesseract:
raise ImportError('pytesseract is not installed')
os.makedirs(TESSDATA_DIR, exist_ok=True)
os.environ["TESSDATA_PREFIX"] = TESSDATA_DIR
self.tessdata_dir = os.path.join(os.getenv("BPM_AI_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "bpm-ai")), "tessdata")
os.makedirs(self.tessdata_dir, exist_ok=True)
os.environ["TESSDATA_PREFIX"] = self.tessdata_dir

@override
async def _do_process(
Expand Down Expand Up @@ -82,12 +81,11 @@ def identify_image_language(self, image: Image) -> str:
text = pytesseract.image_to_string(image)
return indentify_language_iso_639_3(text) or "eng"

@staticmethod
def download_if_missing(lang: str):
def download_if_missing(self, lang: str):
lang_file = f'{lang}.traineddata'
tessdata_file_path = os.path.join(TESSDATA_DIR, lang_file)
tessdata_file_path = os.path.join(self.tessdata_dir, lang_file)
if not os.path.exists(tessdata_file_path):
logger.info(f'tesseract: {lang_file} not found in {TESSDATA_DIR}, downloading...')
logger.info(f'tesseract: {lang_file} not found in {self.tessdata_dir}, downloading...')
download_url = f'https://github.com/tesseract-ocr/tessdata_best/raw/main/{lang_file}'
urllib.request.urlretrieve(download_url, tessdata_file_path)
logger.info(f'tesseract: Downloaded {lang_file} to tessdata directory')
4 changes: 2 additions & 2 deletions bpm_ai_inference/translation/easy_nmt/easy_nmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def __init__(
self.config = None

if cache_folder is None:
if 'EASYNMT_CACHE' in os.environ:
cache_folder = os.environ['EASYNMT_CACHE']
if 'BPM_AI_CACHE_DIR' in os.environ:
cache_folder = os.path.join(os.environ['BPM_AI_CACHE_DIR'], 'easynmt_v2')
else:
cache_folder = os.path.join(torch.hub._get_torch_home(), 'easynmt_v2')
self._cache_folder = cache_folder
Expand Down

0 comments on commit 66e56f5

Please sign in to comment.