diff --git a/pyproject.toml b/pyproject.toml index a32fa3c..45d611f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ [tool.poetry] name = "codenames-solvers" -version = "1.7.3" +version = "1.7.5" description = "Solvers implementation for Codenames board game in python." authors = ["Michael Kali ", "Asaf Kali "] readme = "README.md" diff --git a/solvers/models/cache.py b/solvers/models/cache.py index af242ac..3450d11 100644 --- a/solvers/models/cache.py +++ b/solvers/models/cache.py @@ -13,6 +13,17 @@ log = logging.getLogger(__name__) +class ModelLoadError(Exception): + def __init__(self, model_identifier: ModelIdentifier, message: str): + self.model_identifier = model_identifier + super().__init__(f"Failed to load model {model_identifier}: {message}") + + +class ModelNotFoundError(ModelLoadError): + def __init__(self, model_identifier: ModelIdentifier): + super().__init__(model_identifier, "Model not found") + + class ModelCache: def __init__(self, language_data_folder: str = "~/.cache/language_data"): self.language_data_folder = language_data_folder @@ -49,9 +60,15 @@ def _load_model(self, model_identifier: ModelIdentifier) -> KeyedVectors: model_name=model_identifier.model_name, is_stemmed=model_identifier.is_stemmed, ) - except Exception as e: - log.warning(f"Failed to load model: {e}", exc_info=True) - return load_from_gensim(model_identifier) + except Exception as local_load_error: + log.warning(f"Failed to load local model: {local_load_error}") + try: + return load_from_gensim(model_identifier) + except Exception as gensim_load_error: + log.warning(f"Failed to load model from gensim: {gensim_load_error}") + if isinstance(local_load_error, FileNotFoundError): + raise ModelNotFoundError(model_identifier) from gensim_load_error + raise ModelLoadError(model_identifier, str(local_load_error)) from gensim_load_error def load_kv_format(language_base_folder: str, model_name: str, is_stemmed: bool = False) -> KeyedVectors: diff --git a/solvers/models/identifier.py b/solvers/models/identifier.py index 1d91d09..ee25cca 100644 --- a/solvers/models/identifier.py +++ b/solvers/models/identifier.py @@ -10,4 +10,4 @@ def __hash__(self): return hash(f"{self.language}-{self.model_name}-{self.is_stemmed}") def __str__(self) -> str: - return f"{self.language}-{self.model_name}" + return f"{self.language}/{self.model_name}"