modelio.py
import torch
import torch.nn as nn
import inspect
import functools

def store_config_args(func):
    """
    Class-method decorator that saves every argument provided to the
    function as a dictionary in 'self.config'. This is used to assist
    model loading - see LoadableModel.
    """

    # attrs, varargs, varkw, defaults = inspect.getargspec(func)
    fullargspec = inspect.getfullargspec(func)
    attrs = fullargspec.args
    varargs = fullargspec.varargs
    varkw = fullargspec.varkw
    defaults = fullargspec.defaults

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.config = {}

        # first save the default values
        if defaults:
            for attr, val in zip(reversed(attrs), reversed(defaults)):
                self.config[attr] = val

        # next handle positional args
        for attr, val in zip(attrs[1:], args):
            self.config[attr] = val

        # lastly handle keyword args
        if kwargs:
            for attr, val in kwargs.items():
                self.config[attr] = val

        return func(self, *args, **kwargs)

    return wrapper
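
# Illustrative note (not part of the original file): with a decorated
# constructor such as
#
#     @store_config_args
#     def __init__(self, nb_features=16, dropout=0.1):
#         ...
#
# the call Model(nb_features=32) first records the defaults and then the
# keyword overrides, leaving self.config equal to
# {'nb_features': 32, 'dropout': 0.1}, which is exactly what
# LoadableModel.load() later feeds back into __init__.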

class LoadableModel(nn.Module):
    """
    Base class for easy pytorch model loading without having to manually
    specify the architecture configuration at load time.

    We can cache the arguments used to construct the initial network, so that
    we can construct the exact same network when loading from file. The arguments
    provided to __init__ are automatically saved into the object (in self.config)
    if the __init__ method is decorated with the @store_config_args utility.
    """

    # this constructor just functions as a check to make sure that every
    # LoadableModel subclass has provided an internal config parameter
    # either manually or via store_config_args
    def __init__(self, *args, **kwargs):
        if not hasattr(self, 'config'):
            raise RuntimeError('models that inherit from LoadableModel must decorate the '
                               'constructor with @store_config_args')
        super().__init__(*args, **kwargs)

    def save(self, path):
        """
        Saves the model configuration and weights to a pytorch file.
        """
        # don't save the transformer_grid buffers - see SpatialTransformer doc for more info
        sd = self.state_dict().copy()
        grid_buffers = [key for key in sd.keys() if key.endswith('.grid')]
        for key in grid_buffers:
            sd.pop(key)
        torch.save({'config': self.config, 'model_state': sd}, path)

    @classmethod
    def load(cls, path, device):
        """
        Load a pytorch model configuration and weights.
        """
        checkpoint = torch.load(path, map_location=torch.device(device))
        model = cls(**checkpoint['config'])
        model.load_state_dict(checkpoint['model_state'], strict=False)
        return model
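
# --------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It shows
# the intended round trip with a hypothetical subclass 'TinyNet': the
# constructor is decorated with @store_config_args, so save()/load() can
# rebuild the exact same architecture from the checkpoint's cached config.
if __name__ == '__main__':

    class TinyNet(LoadableModel):
        @store_config_args
        def __init__(self, in_features=4, out_features=2):
            super().__init__()
            self.linear = nn.Linear(in_features, out_features)

        def forward(self, x):
            return self.linear(x)

    net = TinyNet(in_features=8)
    print(net.config)        # -> {'out_features': 2, 'in_features': 8}

    net.save('tiny_net.pt')  # writes {'config': ..., 'model_state': ...} to disk
    restored = TinyNet.load('tiny_net.pt', 'cpu')
    print(restored(torch.zeros(1, 8)).shape)  # -> torch.Size([1, 2])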