-
Notifications
You must be signed in to change notification settings - Fork 2
/
setup_latent_guidance.py
58 lines (48 loc) · 1.88 KB
/
setup_latent_guidance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/usr/bin/python
# -*- coding:utf-8 -*-
import os
import yaml
import argparse
from tqdm import tqdm
import torch
from generate import get_best_ckpt, to_device
from data import create_dataloader, create_dataset
def main(args):
    """Measure consecutive-latent distances on the test set and store them.

    Loads the model checkpoint, encodes every test batch with
    ``model.autoencoder.encode``, computes the Euclidean distance between
    neighbouring latent points whose position ids differ by exactly one, and
    hands the mean and standard deviation of those distances to
    ``model.set_consec_dist``. The updated model is then saved back to the
    same checkpoint path, overwriting the original file.

    Args:
        args: Parsed command-line namespace providing ``config`` (path to a
            YAML file with ``dataset`` and ``dataloader`` sections), ``ckpt``
            (used directly when it ends with ``.ckpt``, otherwise passed to
            ``get_best_ckpt``) and ``gpu`` (CUDA device index; ``-1`` selects
            the CPU).

    Raises:
        ValueError: If fewer than two consecutive-position distances were
            collected, because the standard deviation would be undefined
            (NaN) and would otherwise be written into the checkpoint.
    """
    # Use a context manager so the config file handle is closed promptly.
    with open(args.config, 'r') as config_file:
        config = yaml.safe_load(config_file)

    # load model
    b_ckpt = args.ckpt if args.ckpt.endswith('.ckpt') else get_best_ckpt(args.ckpt)
    print(f'Using checkpoint {b_ckpt}')
    # NOTE(review): torch.load unpickles the stored object, so only trusted
    # checkpoints should be used here. Newer PyTorch releases may require
    # weights_only=False to load a full model object -- confirm against the
    # PyTorch version this project pins.
    model = torch.load(b_ckpt, map_location='cpu')
    device = torch.device('cpu' if args.gpu == -1 else f'cuda:{args.gpu}')
    model.to(device)
    model.eval()

    # load data
    _, _, test_set = create_dataset(config['dataset'])
    test_loader = create_dataloader(test_set, config['dataloader'])

    all_dists = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            batch = to_device(batch, device)
            H, Z, _, _ = model.autoencoder.encode(
                batch['X'], batch['S'], batch['mask'], batch['position_ids'],
                batch['lengths'], batch['atom_mask'], no_randomness=True
            )
            # Position ids of the masked entries; assumed to line up
            # one-to-one with the rows of Z -- TODO confirm.
            pos = batch['position_ids'][batch['mask']]
            # presumably Z has a singleton second dimension -- TODO confirm
            Z = Z.squeeze(1)
            # Distance between each latent point and its predecessor.
            dists = torch.norm(Z[1:] - Z[:-1], dim=-1)
            pos_dist = pos[1:] - pos[:-1]
            # Keep only pairs whose position ids are exactly one apart.
            # NOTE(review): a pair straddling two samples in a batch would
            # also be kept if its ids happen to differ by one -- verify.
            dists = dists[pos_dist == 1]
            all_dists.append(dists)

    # Fail before touching the checkpoint when the statistics are undefined:
    # torch.cat rejects an empty list and torch.std of one value is NaN.
    num_pairs = sum(d.numel() for d in all_dists)
    if num_pairs < 2:
        raise ValueError(
            f'Need at least 2 consecutive latent pairs to compute mean/std, got {num_pairs}'
        )

    all_dists = torch.cat(all_dists, dim=0)
    mean, std = torch.mean(all_dists), torch.std(all_dists)
    print(mean, std)
    model.set_consec_dist(mean.item(), std.item())
    # Overwrites the checkpoint that was loaded above.
    torch.save(model, b_ckpt)
def parse():
    """Build the command-line parser and return the parsed arguments.

    Options:
        --config: required string.
        --ckpt: required string.
        --gpu: integer, defaults to 0.

    Returns:
        The ``argparse.Namespace`` produced from the process arguments.
    """
    cli = argparse.ArgumentParser(
        description='Calculate distance between consecutive latent points'
    )
    # Both string options are mandatory and share the same settings.
    for flag in ('--config', '--ckpt'):
        cli.add_argument(flag, type=str, required=True)
    cli.add_argument('--gpu', type=int, default=0)
    namespace = cli.parse_args()
    return namespace
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the computation.
    cli_args = parse()
    main(cli_args)