average_checkpoints.py
import os

import torch


def average_checkpoints(last):
    """Average the model parameters stored in the checkpoint files listed in `last`."""
    avg = None
    for path in last:
        # Load the checkpoint on CPU and take its "state_dict" entry.
        states = torch.load(path, map_location=lambda storage, loc: storage)["state_dict"]
        # Keep only the keys under the "model." prefix and strip that prefix.
        states = {k[len("model."):]: v for k, v in states.items() if k.startswith("model.")}
        if avg is None:
            avg = states
        else:
            # Accumulate the parameters of the remaining checkpoints in place.
            for k in avg.keys():
                avg[k] += states[k]
    # Divide the accumulated tensors by the number of checkpoints.
    for k in avg.keys():
        if avg[k] is not None:
            if avg[k].is_floating_point():
                avg[k] /= len(last)
            else:
                # Integer tensors (e.g. step counters) are averaged with floor division.
                avg[k] //= len(last)
    return avg


def ensemble(args):
    """Average the checkpoints of the last 10 epochs and save the result to disk."""
    # Paths of the checkpoints saved for the last 10 training epochs.
    last = [
        os.path.join(args.exp_dir, args.exp_name, f"epoch={n}.ckpt")
        for n in range(args.max_epochs - 10, args.max_epochs)
    ]
    model_path = os.path.join(args.exp_dir, args.exp_name, "model_avg_10.pth")
    torch.save(average_checkpoints(last), model_path)
    return model_path
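

# A minimal usage sketch, not part of the original module: it assumes the script is run
# directly and that the flag names below (`--exp-dir`, `--exp-name`, `--max-epochs`) map to
# the `exp_dir`, `exp_name`, and `max_epochs` attributes that `ensemble` reads from `args`;
# the argument names and help strings are illustrative only.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Average the checkpoints of the last 10 epochs.")
    parser.add_argument("--exp-dir", type=str, required=True, help="Experiment root directory.")
    parser.add_argument("--exp-name", type=str, required=True, help="Experiment subdirectory name.")
    parser.add_argument("--max-epochs", type=int, required=True, help="Total number of training epochs.")
    args = parser.parse_args()
    print(ensemble(args))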