-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy path: sample_lm.py
executable file
·126 lines (83 loc) · 3.4 KB
/
sample_lm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python
from __future__ import division
import onmt
import onmt.markdown
import torch
import argparse
import math
import numpy
from onmt.model_factory import build_model
# Command-line interface. NOTE: the original description said 'translate.py'
# (a copy-paste error) and described -models as a single path although the
# code splits it on '|'; both help texts are corrected here.
parser = argparse.ArgumentParser(description='Average the parameters of several model checkpoints')
onmt.markdown.add_md_help_argument(parser)
parser.add_argument('-models', required=True,
                    help='Paths to the model .pt files, joined by |')
parser.add_argument('-output', default='model.averaged',
                    help="""Path to output averaged model""")
parser.add_argument('-gpu', type=int, default=-1,
                    help="Device to run on")
parser.add_argument('-method', default='mean',
                    help="method to average: mean|gmean")
def main():
    """Average several checkpoints' parameters into a single saved model.

    Reads -models (one string, checkpoint paths joined by '|'), combines
    the parameters with an arithmetic mean (-method mean) or geometric
    mean (-method gmean), and writes the result to -output. The saved
    checkpoint carries the FIRST model's 'opt' and 'dicts'.
    """
    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1

    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    # opt.models is a single string of checkpoint paths, split by |
    models = opt.models.split("|")
    n_models = len(models)

    print("Loading main model from %s ..." % models[0])
    # map_location keeps tensors on CPU regardless of where they were saved.
    checkpoint = torch.load(models[0], map_location=lambda storage, loc: storage)

    # Drop optimizer state early to save memory.
    if 'optim' in checkpoint:
        del checkpoint['optim']

    # Keep the first checkpoint's opt/dicts for the output. (Bug fix: the
    # original reused the loop variable `model_opt`, so the saved 'opt' was
    # whichever model happened to be loaded LAST.)
    main_model_opt = checkpoint['opt']
    dicts = checkpoint['dicts']

    main_model = build_model(main_model_opt, checkpoint['dicts'])
    main_model.load_state_dict(checkpoint['model'])

    if opt.cuda:
        main_model = main_model.cuda()

    for model_path in models[1:]:
        print("Loading model from %s ..." % model_path)
        checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']

        # delete optim information to save GPU memory
        if 'optim' in checkpoint:
            del checkpoint['optim']

        current_model = build_model(model_opt, checkpoint['dicts'])
        current_model.load_state_dict(checkpoint['model'])

        if opt.cuda:
            current_model = current_model.cuda()

        if opt.method == 'mean':
            # Accumulate the sum in-place; divided by n_models after the loop.
            for main_param, param in zip(main_model.parameters(), current_model.parameters()):
                main_param.data.add_(param.data)
        elif opt.method == 'gmean':
            # Accumulate the product; the n-th root is taken after the loop.
            # NOTE(review): a fractional pow_ of a negative product yields
            # NaN, so gmean is only meaningful when the elementwise products
            # stay non-negative — confirm before relying on it.
            for main_param, param in zip(main_model.parameters(), current_model.parameters()):
                main_param.data.mul_(param.data)
        else:
            raise NotImplementedError("unknown averaging method: %s" % opt.method)

    # Normalizing
    if opt.method == 'mean':
        for main_param in main_model.parameters():
            main_param.data.div_(n_models)
    elif opt.method == 'gmean':
        for main_param in main_model.parameters():
            main_param.data.pow_(1. / n_models)

    # Saving
    model_state_dict = main_model.state_dict()

    save_checkpoint = {
        'model': model_state_dict,
        'dicts': dicts,
        'opt': main_model_opt,
        'epoch': -1,
        'iteration': -1,
        'batchOrder': None,
        'optim': None
    }

    print("Saving averaged model to %s" % opt.output)
    torch.save(save_checkpoint, opt.output)
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()