Skip to content

Commit

Permalink
🎉 Language models: try loading from the Gensim API if loading the KV file from the local path fails.
Browse files (browse the repository at this point in the history)
  • Loading branch information
asaf-kali committed Jan 19, 2024
1 parent 69e55bc commit 8e9a13f
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 16 deletions.
9 changes: 8 additions & 1 deletion playground/offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
log = logging.getLogger(__name__)

model_id = ModelIdentifier(language="english", model_name="wiki-50", is_stemmed=False)
# model_id = ModelIdentifier(language="english", model_name="glove-twitter-25", is_stemmed=False)
# model_id = ModelIdentifier(language="english", model_name="google-300", is_stemmed=False)
# model_id = ModelIdentifier(language="hebrew", model_name="twitter", is_stemmed=False)
# model_id = ModelIdentifier(language="hebrew", model_name="ft-200", is_stemmed=False)
Expand All @@ -46,7 +47,13 @@ def run_offline(board: Board = ENGLISH_BOARDS[2]): # noqa: F405
game_runner = None
try:
# blue_hinter = GPTHinter(name="Yoda", api_key=GPT_API_KEY)
blue_hinter = NaiveHinter(name="Yoda", team_color=TeamColor.BLUE, model_adapter=adapter, max_group_size=4)
blue_hinter = NaiveHinter(
name="Yoda",
team_color=TeamColor.BLUE,
model_identifier=model_id,
model_adapter=adapter,
max_group_size=4,
)
red_hinter = NaiveHinter(
name="Einstein",
team_color=TeamColor.RED,
Expand Down
2 changes: 1 addition & 1 deletion playground/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def print_results(game_runner: Optional[GameRunner]):

def _print_board(state: GameState):
log.info("")
log.info(f"{state.board}")
log.info(f"\n{state.board}")


def _print_moves(game_runner: GameRunner):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

[tool.poetry]
name = "codenames-solvers"
version = "1.6.1"
version = "1.7.0"
description = "Solvers implementation for Codenames board game in python."
authors = ["Michael Kali <[email protected]>", "Asaf Kali <[email protected]>"]
readme = "README.md"
Expand Down
35 changes: 22 additions & 13 deletions solvers/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from threading import Lock
from typing import Dict

import gensim.downloader as gensim_api
from generic_iterative_stemmer.models import StemmedKeyedVectors
from gensim.models import KeyedVectors

Expand All @@ -13,8 +14,8 @@


class ModelCache:
def __init__(self):
self.language_data_folder = "~/.cache/language_data"
def __init__(self, language_data_folder: str = "~/.cache/language_data"):
self.language_data_folder = language_data_folder
self._cache: Dict[ModelIdentifier, KeyedVectors] = {}
self._main_lock = Lock()
self._model_locks: Dict[ModelIdentifier, Lock] = {}
Expand All @@ -35,25 +36,33 @@ def load_model(self, model_identifier: ModelIdentifier) -> KeyedVectors:
return self._cache[model_identifier]

def _load_model(self, model_identifier: ModelIdentifier) -> KeyedVectors:
# TODO: in case loading fails, try gensim downloader
# import gensim.downloader as api
# model = api.load("wiki-he")
log.info("Loading model...", extra={"model": model_identifier.dict()})
language_base_folder = expanduser(os.path.join(self.language_data_folder, model_identifier.language))
model = load_kv_format(
language_base_folder=language_base_folder,
model_name=model_identifier.model_name,
is_stemmed=model_identifier.is_stemmed,
)
log.info("Model loaded", extra={"model": model_identifier.dict()})
return model
try:
return load_kv_format(
language_base_folder=language_base_folder,
model_name=model_identifier.model_name,
is_stemmed=model_identifier.is_stemmed,
)
except Exception as e:
log.warning(f"Failed to load model: {e}", exc_info=True)
return load_from_gensim(model_identifier)


def load_kv_format(language_base_folder: str, model_name: str, is_stemmed: bool = False) -> KeyedVectors:
model_folder = os.path.join(language_base_folder, model_name)
file_path = os.path.join(model_folder, "model.kv") # TODO: This needs fixing
file_path = os.path.join(model_folder, "model.kv")
log.debug(f"Looking for [{model_name}] in {file_path}...")
if is_stemmed:
model = StemmedKeyedVectors.load(file_path)
else:
model = KeyedVectors.load(file_path)
log.debug(f"Successfully loaded [{model_name}] from {file_path}")
return model


def load_from_gensim(model_identifier: ModelIdentifier) -> KeyedVectors:
log.debug(f"Looking for [{model_identifier.model_name}] in gensim API...")
model = gensim_api.load(model_identifier.model_name)
log.debug(f"Successfully loaded [{model_identifier.model_name}] from gensim API")
return model
3 changes: 3 additions & 0 deletions solvers/models/identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ class ModelIdentifier(BaseModel):

def __hash__(self):
return hash(f"{self.language}-{self.model_name}-{self.is_stemmed}")

def __str__(self) -> str:
return f"{self.language}-{self.model_name}"

0 comments on commit 8e9a13f

Please sign in to comment.