"""
Configuration file for loading pretrained models using PyTorch hub
Usage example: torch.hub.load("castorini/howl", "hey_fire_fox")
"""
import os
import pathlib
import shutil
import typing
import zipfile
import torch
import howl.data.transform as transform
import howl.model as howl_model
from howl.context import InferenceContext
from howl.model.inference import FrameInferenceEngine, InferenceEngine
from howl.workspace import Workspace
dependencies = ["howl", "torch"]
_MODEL_URL = "https://github.com/castorini/howl-models/archive/v1.1.0.zip"
_MODEL_CACHE_FOLDER = "howl-models"


def hey_fire_fox(pretrained=True, **kwargs):
    """Pretrained wake word model for Firefox Voice"""
    engine, ctx = _load_model(pretrained, "res8", "howl/hey-fire-fox", **kwargs)
    return engine, ctx
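
# A minimal usage sketch (an illustration, not part of the hub API): the first call
# downloads the model release into the local cache, and `device` / `reload_models`
# are optional keyword arguments forwarded through **kwargs to `_load_model` below.
#
#   import torch
#   engine, ctx = torch.hub.load("castorini/howl", "hey_fire_fox", device="cpu", reload_models=False)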


def _load_model(
    pretrained: bool, model_name: str, workspace_path: str, device: str = "cpu", **kwargs
) -> typing.Tuple[InferenceEngine, InferenceContext]:
    """
    Loads a howl model from a workspace

    Arguments:
        pretrained (bool): load pretrained model weights
        model_name (str): name of the model to use
        workspace_path (str): relative path to the workspace from the root of howl-models
        device (str): device onto which the model weights are mapped

    Returns the inference engine and context
    """
    # Use a separate `reload_models` flag, since torch.hub.load pops the 'force_reload' flag
    reload_models = kwargs.pop("reload_models", False)
    cached_folder = _download_howl_models(reload_models)
    workspace_path = pathlib.Path(cached_folder) / workspace_path
    workspace = Workspace(workspace_path, delete_existing=False)

    # Load model settings
    settings = workspace.load_settings()

    # Set up the inference context
    use_frame = settings.training.objective == "frame"
    ctx = InferenceContext(
        vocab=settings.training.vocab, token_type=settings.training.token_type, use_blank=not use_frame
    )

    # Instantiate the registered model architecture and the ZMUV (zero-mean, unit-variance) transform
    zmuv_transform = transform.ZmuvTransform()
    model = howl_model.RegisteredModel.find_registered_class(model_name)(ctx.num_labels).eval()

    # Load pretrained weights
    if pretrained:
        zmuv_transform.load_state_dict(
            torch.load(str(workspace.path / "zmuv.pt.bin"), map_location=torch.device(device))
        )
        workspace.load_model(model, best=True)

    # Wrap the model in an inference engine
    model.streaming()
    if use_frame:
        engine = FrameInferenceEngine(
            int(settings.training.max_window_size_seconds * 1000),
            int(settings.training.eval_stride_size_seconds * 1000),
            model,
            zmuv_transform,
            ctx,
        )
    else:
        engine = InferenceEngine(model, zmuv_transform, ctx)
    return engine, ctx
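
# Sketch of consuming the returned engine (hedged: `InferenceEngine.infer` and the
# 16 kHz mono input below are assumptions; see howl.model.inference for the actual API):
#
#   engine, ctx = _load_model(True, "res8", "howl/hey-fire-fox", device="cpu")
#   audio = torch.zeros(16000)  # one second of silence at an assumed 16 kHz sample rate
#   if engine.infer(audio):
#       print("wake word detected")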


def _download_howl_models(reload_models: bool) -> str:
    """
    Downloads howl models from the GitHub release: https://github.com/castorini/howl-models

    Arguments:
        reload_models (bool): force a re-download even if the models are already cached

    Returns the cached howl-models path
    """
    # Create the base cache directory
    base_dir = pathlib.Path.home() / ".cache/howl"
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)
    cached_folder = os.path.join(base_dir, _MODEL_CACHE_FOLDER)

    # Check whether the existing cache should be used
    use_cache = (not reload_models) and os.path.exists(cached_folder)
    if not use_cache:
        # Clear the cache
        zip_path = os.path.join(base_dir, _MODEL_CACHE_FOLDER + ".zip")
        _remove_files(cached_folder)
        _remove_files(zip_path)

        print("Downloading howl-models...")
        torch.hub.download_url_to_file(_MODEL_URL, zip_path, progress=True)

        # Extract the archive next to the zip
        with zipfile.ZipFile(zip_path) as model_zipfile:
            # Find the name of the folder the archive extracts to
            extracted_name = model_zipfile.infolist()[0].filename
            extracted_path = os.path.join(base_dir, extracted_name)
            _remove_files(extracted_path)
            model_zipfile.extractall(base_dir)

        # Rename the extracted folder to the cache folder name
        shutil.move(extracted_path, cached_folder)
    return cached_folder
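
# For reference, the cache layout the function above produces (derived from the
# constants at the top of this file):
#
#   ~/.cache/howl/
#   ├── howl-models.zip   # release archive; cleared and re-downloaded when reload_models=True
#   └── howl-models/      # extracted release, containing workspaces such as howl/hey-fire-fox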


def _remove_files(path):
    """Removes the file or folder at `path` if it exists"""
    if os.path.exists(path):
        if os.path.isfile(path):
            os.remove(path)
        else:
            shutil.rmtree(path)