-
Notifications
You must be signed in to change notification settings - Fork 112
/
Copy pathtrain_agent.py
executable file
·170 lines (138 loc) · 7.76 KB
/
train_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#!/usr/bin/env python
import torch
import pickle
import numpy as np
import time
import os
from shutil import copyfile
from model import RNN
from data_structs import Vocabulary, Experience
from scoring_functions import get_scoring_function
from utils import Variable, seq_to_smiles, fraction_valid_smiles, unique
from vizard_logger import VizardLog
def _load_rnn_weights(model, checkpoint_path):
    """Load the state dict stored at ``checkpoint_path`` into ``model.rnn``.

    Saved checkpoints may contain tensors that live on the GPU. When CUDA is
    not available, the storages are remapped to the CPU so loading still works.
    """
    if torch.cuda.is_available():
        state_dict = torch.load(checkpoint_path)
    else:
        state_dict = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    model.rnn.load_state_dict(state_dict)


def _log_agent_weights(logger, agent, prefix=""):
    """Log a subsample of the Agent's GRU layer-2 and embedding weights.

    ``prefix`` is prepended to every log name. For example, ``"init_"``
    labels the snapshot taken before training so that the Vizard bokeh app
    can plot it separately from the per-step snapshots.
    """
    rnn = agent.rnn
    logger.log(rnn.gru_2.weight_ih.cpu().data.numpy()[::100], prefix + "weight_GRU_layer_2_w_ih")
    logger.log(rnn.gru_2.weight_hh.cpu().data.numpy()[::100], prefix + "weight_GRU_layer_2_w_hh")
    logger.log(rnn.embedding.weight.cpu().data.numpy()[::30], prefix + "weight_GRU_embedding")
    logger.log(rnn.gru_2.bias_ih.cpu().data.numpy(), prefix + "weight_GRU_layer_2_b_ih")
    logger.log(rnn.gru_2.bias_hh.cpu().data.numpy(), prefix + "weight_GRU_layer_2_b_hh")


def _save_results(save_dir, agent, prior, experience, voc, scoring_function, n_samples=256):
    """Write this script, the experience memory, the Agent checkpoint and samples to ``save_dir``.

    The directory is created if it does not already exist. ``n_samples``
    sequences are drawn from the trained Agent, scored, and written to the
    ``sampled`` file together with their Prior log-likelihood.
    """
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    # Resolve the script next to this module rather than relative to the CWD,
    # so saving works no matter where the process was launched from.
    script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'train_agent.py')
    copyfile(script_path, os.path.join(save_dir, "train_agent.py"))

    experience.print_memory(os.path.join(save_dir, "memory"))
    torch.save(agent.rnn.state_dict(), os.path.join(save_dir, 'Agent.ckpt'))

    seqs, _, _ = agent.sample(n_samples)
    prior_likelihood, _ = prior.likelihood(Variable(seqs))
    prior_likelihood = prior_likelihood.data.cpu().numpy()
    sampled_smiles = seq_to_smiles(seqs, voc)
    sampled_scores = scoring_function(sampled_smiles)

    with open(os.path.join(save_dir, "sampled"), 'w') as f:
        f.write("SMILES Score PriorLogP\n")
        for smi, sc, prior_ll in zip(sampled_smiles, sampled_scores, prior_likelihood):
            f.write("{} {:5.2f} {:6.2f}\n".format(smi, sc, prior_ll))


def train_agent(restore_prior_from='data/Prior.ckpt',
                restore_agent_from='data/Prior.ckpt',
                scoring_function='tanimoto',
                scoring_function_kwargs=None,
                save_dir=None, learning_rate=0.0005,
                batch_size=64, n_steps=3000,
                num_processes=0, sigma=60,
                experience_replay=0):
    """Fine-tune an Agent RNN towards a scoring function, regularised by a fixed Prior.

    For each step, the method:

    - samples a batch from the Agent and removes duplicate sequences;
    - scores the resulting SMILES;
    - minimises the squared difference between the Agent likelihood and the
      augmented likelihood ``prior_likelihood + sigma * score``.

    When training finishes, the results are written to ``save_dir``.

    Args:
        restore_prior_from: checkpoint path for the (frozen) Prior network.
        restore_agent_from: checkpoint path for the initial Agent network.
            By default this is the same as the Prior.
        scoring_function: name passed to ``get_scoring_function``.
        scoring_function_kwargs: extra keyword arguments for the scoring
            function. ``None`` means no extra arguments.
        save_dir: output directory. Defaults to a timestamped folder under
            ``data/results``.
        learning_rate: Adam learning rate for the Agent.
        batch_size: sequences sampled per step, before de-duplication.
        n_steps: number of training steps.
        num_processes: forwarded to ``get_scoring_function``.
        sigma: weight of the score in the augmented likelihood.
        experience_replay: if truthy, once the experience buffer holds more
            than 4 entries, 4 of them are replayed into each step's loss.
    """
    voc = Vocabulary(init_from_file="data/Voc")
    start_time = time.time()
    Prior = RNN(voc)
    Agent = RNN(voc)
    logger = VizardLog('data/logs')

    # By default the Agent starts from the same weights as the Prior, but an
    # already-trained Agent can be restored instead.
    _load_rnn_weights(Prior, restore_prior_from)
    _load_rnn_weights(Agent, restore_agent_from)

    # We don't need gradients with respect to the Prior.
    for param in Prior.rnn.parameters():
        param.requires_grad = False

    optimizer = torch.optim.Adam(Agent.rnn.parameters(), lr=learning_rate)

    # Unpacking None with ** raises TypeError, so normalise the default first.
    if scoring_function_kwargs is None:
        scoring_function_kwargs = {}
    scoring_function = get_scoring_function(scoring_function=scoring_function,
                                            num_processes=num_processes,
                                            **scoring_function_kwargs)

    # For policy based RL, we normally train on-policy and correct for the fact
    # that more likely actions occur more often (which means the agent can get
    # biased towards them). Using experience replay is therefore not as
    # theoretically sound as it is for value based RL, but it seems to work well.
    experience = Experience(voc)

    # Log some network weights that can be dynamically plotted with the Vizard bokeh app.
    _log_agent_weights(logger, Agent, prefix="init_")

    # [step numbers, mean score per step] for Vizard plotting.
    step_score = [[], []]

    print("Model initialized, starting training...")

    for step in range(n_steps):
        # Sample from the Agent and keep only unique sequences.
        seqs, agent_likelihood, _ = Agent.sample(batch_size)
        unique_idxs = unique(seqs)
        seqs = seqs[unique_idxs]
        agent_likelihood = agent_likelihood[unique_idxs]

        # Prior likelihood and score of the sampled sequences.
        prior_likelihood, _ = Prior.likelihood(Variable(seqs))
        smiles = seq_to_smiles(seqs, voc)
        score = scoring_function(smiles)

        # Augmented likelihood is the target the Agent likelihood is pulled towards.
        augmented_likelihood = prior_likelihood + sigma * Variable(score)
        loss = torch.pow((augmented_likelihood - agent_likelihood), 2)

        # Experience replay: first mix in replayed sequences...
        if experience_replay and len(experience) > 4:
            exp_seqs, exp_score, exp_prior_likelihood = experience.sample(4)
            exp_agent_likelihood, _ = Agent.likelihood(exp_seqs.long())
            exp_augmented_likelihood = exp_prior_likelihood + sigma * exp_score
            exp_loss = torch.pow((Variable(exp_augmented_likelihood) - exp_agent_likelihood), 2)
            loss = torch.cat((loss, exp_loss), 0)
            agent_likelihood = torch.cat((agent_likelihood, exp_agent_likelihood), 0)

        # ...then add this step's samples to the buffer.
        prior_likelihood = prior_likelihood.data.cpu().numpy()
        new_experience = zip(smiles, score, prior_likelihood)
        experience.add_experience(new_experience)

        loss = loss.mean()
        # Regularizer that penalizes high likelihood for the entire sequence
        # (agent_likelihood is presumably a log-likelihood, i.e. <= 0 - confirm).
        loss_p = - (1 / agent_likelihood).mean()
        loss += 5 * 1e3 * loss_p

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Convert to numpy arrays so that we can print them.
        augmented_likelihood = augmented_likelihood.data.cpu().numpy()
        agent_likelihood = agent_likelihood.data.cpu().numpy()

        # step + 1 steps are done and n_steps - step - 1 remain. time_elapsed is
        # a float, so the division is true division on Python 2 as well.
        time_elapsed = (time.time() - start_time) / 3600
        time_left = time_elapsed * (n_steps - step - 1) / (step + 1)
        print("\n Step {} Fraction valid SMILES: {:4.1f} Time elapsed: {:.2f}h Time left: {:.2f}h".format(
            step, fraction_valid_smiles(smiles) * 100, time_elapsed, time_left))
        print(" Agent Prior Target Score SMILES")
        # De-duplication can leave fewer than 10 sequences in the batch.
        for i in range(min(10, len(smiles))):
            print(" {:6.2f} {:6.2f} {:6.2f} {:6.2f} {}".format(agent_likelihood[i],
                                                              prior_likelihood[i],
                                                              augmented_likelihood[i],
                                                              score[i],
                                                              smiles[i]))

        # Need this for Vizard plotting.
        step_score[0].append(step + 1)
        step_score[1].append(np.mean(score))

        _log_agent_weights(logger, Agent)
        smiles_report = "\n".join([smi + "\t" + str(round(sc, 2))
                                   for smi, sc in zip(smiles[:12], score[:12])])
        logger.log(smiles_report, "SMILES", dtype="text", overwrite=True)
        logger.log(np.array(step_score), "Scores")

    # If the entire training finishes, we create a new folder where we save this
    # python file, some sampled sequences and the contents of the experience
    # (which are the highest scored sequences seen during training).
    if not save_dir:
        save_dir = 'data/results/run_' + time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime())
    _save_results(save_dir, Agent, Prior, experience, voc, scoring_function)
if __name__ == "__main__":
    # Run a training job with default settings when executed as a script.
    # Pass an explicit empty mapping so the entry point never depends on how
    # train_agent treats a None default: unpacking None with ** raises TypeError.
    train_agent(scoring_function_kwargs={})