forked from tttianhao/CLEAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train-triplet.py
138 lines (123 loc) · 5.6 KB
/
train-triplet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
import time
import os
import pickle
from CLEAN.dataloader import *
from CLEAN.model import *
from CLEAN.utils import *
import torch.nn as nn
import argparse
from CLEAN.distance_map import get_dist_map
def _str2bool(value):
    """Convert a command-line flag value to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string, so ``--verbose False`` used to *enable* verbose
    output. This parser accepts the common spellings explicitly and rejects
    anything else with an argparse error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # '' keeps the old bool('') == False result for an explicit empty value.
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


def parse():
    """Parse the training script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``learning_rate``, ``epoch``, ``model_name``,
        ``training_data``, ``hidden_dim``, ``out_dim``, ``adaptive_rate`` and
        ``verbose``. ``verbose`` accepts ``--verbose``, ``--verbose True`` or
        ``--verbose False``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-l', '--learning_rate', type=float, default=5e-4)
    parser.add_argument('-e', '--epoch', type=int, default=2000)
    parser.add_argument('-n', '--model_name', type=str, default='split10_triplet')
    parser.add_argument('-t', '--training_data', type=str, default='split10')
    parser.add_argument('-d', '--hidden_dim', type=int, default=512)
    parser.add_argument('-o', '--out_dim', type=int, default=128)
    # Epoch interval for optimizer reset / checkpoint / negative re-mining in main().
    parser.add_argument('--adaptive_rate', type=int, default=100)
    # nargs='?' + const=True: bare "--verbose" means True; an explicit value
    # is parsed by _str2bool instead of the always-truthy bool().
    parser.add_argument('--verbose', type=_str2bool, nargs='?',
                        const=True, default=False)
    args = parser.parse_args()
    return args
def get_dataloader(dist_map, id_ec, ec_id, args):
    """Return a shuffled DataLoader of (anchor, positive, negative) triplets.

    Negatives are mined from ``dist_map`` by ``mine_hard_negative`` (from the
    CLEAN package), called with 30 as its second argument. Presumably 30 is
    the number of closest candidates kept per EC; confirm in CLEAN.utils.
    The triplets are built by ``Triplet_dataset_with_mine_EC`` and served in
    shuffled batches of 6000.

    ``args`` is accepted for call-site symmetry but is not read here.
    """
    hard_negatives = mine_hard_negative(dist_map, 30)
    triplet_set = Triplet_dataset_with_mine_EC(id_ec, ec_id, hard_negatives)
    loader_options = dict(batch_size=6000, shuffle=True)
    return torch.utils.data.DataLoader(triplet_set, **loader_options)
def train(model, args, epoch, train_loader,
          optimizer, device, dtype, criterion):
    """Run one epoch of triplet-loss training and return the mean batch loss.

    Args:
        model: network producing embeddings; switched to train mode here.
        args: parsed CLI namespace. ``args.verbose`` enables per-batch
            logging, and ``args.learning_rate`` is the value printed as
            ``lr``; it is not read back from the optimizer.
        epoch: current epoch number, used only in log lines.
        train_loader: iterable yielding ``(anchor, positive, negative)``
            tensor triples. ``len()`` is only needed when verbose.
        optimizer: optimizer over ``model``'s parameters; zeroed and
            stepped once per batch.
        device, dtype: each input tensor is moved/cast to these before the
            forward pass.
        criterion: called as ``criterion(anchor_out, positive_out,
            negative_out)`` and must return a scalar loss tensor.

    Returns:
        float: the mean of the per-batch ``loss.item()`` values.

    Raises:
        ValueError: if ``train_loader`` yields no batches. Previously this
            surfaced as an ``UnboundLocalError`` on ``batch``.
    """
    model.train()
    total_loss = 0.
    n_batches = 0
    start_time = time.time()
    for batch, (anchor, positive, negative) in enumerate(train_loader):
        optimizer.zero_grad()
        anchor_out = model(anchor.to(device=device, dtype=dtype))
        positive_out = model(positive.to(device=device, dtype=dtype))
        negative_out = model(negative.to(device=device, dtype=dtype))
        loss = criterion(anchor_out, positive_out, negative_out)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches = batch + 1
        if args.verbose:
            lr = args.learning_rate
            # start_time is reset after every log line, so this is per batch.
            ms_per_batch = (time.time() - start_time) * 1000
            # Running mean so far; total_loss itself is a cumulative sum.
            cur_loss = total_loss / n_batches
            print(f'| epoch {epoch:3d} | {batch:5d}/{len(train_loader):5d} batches | '
                  f'lr {lr:02.4f} | ms/batch {ms_per_batch:6.4f} | '
                  f'loss {cur_loss:5.2f}')
            start_time = time.time()
    if n_batches == 0:
        raise ValueError('train_loader yielded no batches')
    # record running average training loss
    return total_loss / n_batches
def _load_pickle(path):
    """Return the object unpickled from ``path``, closing the file afterwards.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file. The paths used here are fixed project data files under ``./data``;
    never point them at untrusted input.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _remove_if_exists(path):
    """Delete ``path``; a file that was never written is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def main():
    """Train a LayerNormNet with triplet loss and save its weights.

    Inputs:
        - ``./data/<training_data>.csv`` supplies the id/EC mappings.
        - ``./data/distance_map/<training_data>_esm.pkl`` holds the ESM
          embeddings.
        - ``./data/distance_map/<training_data>.pkl`` holds the initial
          distance map.

    Every ``args.adaptive_rate`` epochs:
        - the Adam optimizer is re-created, which discards its state;
        - a rolling checkpoint ``./data/model/<model_name>_<epoch>.pth``
          replaces the previous one;
        - the distance map and hard-negative loader are rebuilt with the
          current model.

    Saving:
        - During the last 20% of epochs, the lowest-loss weights are written
          to ``./data/model/<model_name>.pth``.
        - After training, the rolling checkpoint is deleted and that same
          ``.pth`` path is overwritten with the final-epoch weights.
    """
    seed_everything()
    ensure_dirs('./data/model')
    args = parse()
    torch.backends.cudnn.benchmark = True
    id_ec, ec_id_dict = get_ec_id_dict('./data/' + args.training_data + '.csv')
    ec_id = {key: list(ec_id_dict[key]) for key in ec_id_dict.keys()}
    #======================== override args ====================#
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    dtype = torch.float32
    lr, epochs = args.learning_rate, args.epoch
    model_name = args.model_name
    final_path = './data/model/' + model_name + '.pth'
    print('==> device used:', device, '| dtype used: ',
          dtype, "\n==> args:", args)
    #======================== ESM embedding ===================#
    # loading ESM embedding for dist map
    esm_emb = _load_pickle(
        './data/distance_map/' + args.training_data + '_esm.pkl'
    ).to(device=device, dtype=dtype)
    dist_map = _load_pickle('./data/distance_map/' + args.training_data + '.pkl')
    #======================== initialize model =================#
    model = LayerNormNet(args.hidden_dim, args.out_dim, device, dtype)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
    criterion = nn.TripletMarginLoss(margin=1, reduction='mean')
    best_loss = float('inf')
    train_loader = get_dataloader(dist_map, id_ec, ec_id, args)
    print("The number of unique EC numbers: ", len(dist_map.keys()))
    #======================== training ========================#
    # Path of the rolling checkpoint currently on disk (None until the first one).
    last_checkpoint = None
    for epoch in range(1, epochs + 1):
        if epoch % args.adaptive_rate == 0:
            # Fresh Adam instance: discards accumulated moment estimates.
            optimizer = torch.optim.Adam(
                model.parameters(), lr=lr, betas=(0.9, 0.999))
            # save updated model, then delete the previous rolling checkpoint
            checkpoint = './data/model/' + model_name + '_' + str(epoch) + '.pth'
            torch.save(model.state_dict(), checkpoint)
            if last_checkpoint is not None:
                _remove_if_exists(last_checkpoint)
            last_checkpoint = checkpoint
            # re-mine hard negatives with the current model's embeddings
            dist_map = get_dist_map(
                ec_id_dict, esm_emb, device, dtype, model=model)
            train_loader = get_dataloader(dist_map, id_ec, ec_id, args)
        # -------------------------------------------------------------------- #
        epoch_start_time = time.time()
        train_loss = train(model, args, epoch, train_loader,
                           optimizer, device, dtype, criterion)
        # only save the current best model near the end of training
        if train_loss < best_loss and epoch > 0.8 * epochs:
            torch.save(model.state_dict(), final_path)
            best_loss = train_loss
            print(f'Best from epoch : {epoch:3d}; loss: {train_loss:6.4f}')
        elapsed = time.time() - epoch_start_time
        print('-' * 75)
        print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
              f'training loss {train_loss:6.4f}')
        print('-' * 75)
    # Remove the rolling checkpoint that actually exists. The old code assumed
    # "<name>_<last epoch>.pth", which is only written when epochs is a
    # multiple of adaptive_rate, and otherwise crashed before the final save.
    if last_checkpoint is not None:
        _remove_if_exists(last_checkpoint)
    # save final weights (torch.save overwrites the best-epoch file at this path)
    torch.save(model.state_dict(), final_path)
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()