-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmol_mdp_ext.py
306 lines (266 loc) · 12.9 KB
/
mol_mdp_ext.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
from collections import defaultdict
import os.path
import numpy as np
from utils.molMDP import BlockMoleculeData, MolMDP
import utils.chem as chem
from rdkit import Chem
import networkx as nx
import model_atom, model_block, model_fingerprint
class BlockMoleculeDataExtended(BlockMoleculeData):
    """Block molecule with RDKit/SMILES accessors, copying and a dict export."""

    @property
    def mol(self):
        """First element of what ``chem.mol_from_frag`` returns for this
        molecule's blocks and junction bonds (used elsewhere in this file as
        an RDKit molecule)."""
        assembled = chem.mol_from_frag(jun_bonds=self.jbonds, frags=self.blocks)
        return assembled[0]

    @property
    def smiles(self):
        """SMILES string of ``self.mol`` as produced by ``Chem.MolToSmiles``."""
        rdkit_mol = self.mol
        return Chem.MolToSmiles(rdkit_mol)

    def copy(self):
        """Return a shallow copy.

        Each list attribute is duplicated one level deep: the outer lists are
        new objects, but their elements (e.g. individual bonds or stems) are
        shared with the original.
        """
        clone = BlockMoleculeDataExtended()
        for field in ("blockidxs", "blocks", "slices"):
            setattr(clone, field, list(getattr(self, field)))
        # numblocks is assigned as-is rather than wrapped in a new list.
        clone.numblocks = self.numblocks
        for field in ("jbonds", "stems"):
            setattr(clone, field, list(getattr(self, field)))
        return clone

    def as_dict(self):
        """Return the structural fields as a dict.

        ``blocks`` is not included. The values are the live attributes, not
        copies.
        """
        exported = ("blockidxs", "slices", "numblocks", "jbonds", "stems")
        return {field: getattr(self, field) for field in exported}
class MolMDPExtended(MolMDP):
    """MolMDP extended with parent enumeration, out-of-place edits, and
    conversion of molecules to model inputs and networkx graphs.

    ``build_translation_table()`` must be called before ``parents()``, and
    ``post_init()`` before ``mols2batch()`` / ``mol2repr()`` /
    ``get_nx_graph(..., true_block=True)``, since those methods read the
    attributes set there.
    """

    def build_translation_table(self):
        """Build a symmetry mapping for blocks. Necessary to compute parent
        transitions.

        Sets ``self.translation_table`` so that
        ``self.translation_table[blockidx][atom]`` is the index of a block
        with the same SMILES as ``blockidx`` whose first attachment atom
        (``block_rs[.][0]``) is ``atom``, or a symmetric equivalent of it.

        Raises:
            ValueError: if some attachment atom of a block has neither a
                duplicate block nor a symmetric correspondence.
        """
        num_blocks = len(self.block_mols)

        # Blocks have multiple ways of being attached. By default,
        # a new block is attached to the target stem by attaching
        # its kth atom, where k = block_rs[new_block_idx][0].
        # When computing a reverse action (from a parent), we may
        # wish to attach the new block to a different atom. In
        # the blocks library, there are duplicates of the same
        # block but with block_rs[block][0] set to a different
        # atom. Thus, for the reverse action we have to find out
        # which duplicate this corresponds to.
        # Here, we compute, for each block, what is the index
        # of the duplicate block, if someone wants to attach to
        # atom x of the block.
        # So atom_map[x] == bidx, such that block_rs[bidx][0] == x
        #
        # Group once by SMILES instead of rescanning every block for every
        # block; later entries overwrite earlier ones for the same atom,
        # exactly as a per-block scan in index order would.
        duplicates_by_smiles = defaultdict(dict)
        for j in range(num_blocks):
            duplicates_by_smiles[self.block_smi[j]][self.block_rs[j][0]] = j
        # Every block gets its own dict because symmetric entries are added
        # per block below.
        self.translation_table = {
            blockidx: dict(duplicates_by_smiles[self.block_smi[blockidx]])
            for blockidx in range(num_blocks)
        }

        # We're still missing some "duplicates", as some might be
        # symmetric versions of each other. For example, block CC with
        # block_rs == [0,1] has no duplicate, because the duplicate
        # with block_rs [1,0] would be a symmetric version (both C
        # atoms are the "same").
        # To test this, let's create nonsense molecules by attaching
        # duplicate blocks to a Gold atom, and testing whether they
        # are the same.
        gold = Chem.MolFromSmiles('[Au]')
        # If we find that two molecules are the same when attaching
        # them with two different atoms, then that means the atom
        # numbers are symmetries. We can add those to the table.
        for blockidx in range(num_blocks):
            block_mol = self.block_mols[blockidx]
            atom_map = self.translation_table[blockidx]
            for j in self.block_rs[blockidx]:
                if j in atom_map:
                    continue
                # molA only depends on (blockidx, j): build it once rather
                # than once per candidate atom.
                molA, _ = chem.mol_from_frag(
                    jun_bonds=[[0, 1, 0, j]],
                    frags=[gold, block_mol])
                smiles_a = Chem.MolToSmiles(molA)
                symmetric_duplicate = None
                for atom, block_duplicate in atom_map.items():
                    molB, _ = chem.mol_from_frag(
                        jun_bonds=[[0, 1, 0, atom]],
                        frags=[gold, block_mol])
                    if (smiles_a == Chem.MolToSmiles(molB) or
                            molA.HasSubstructMatch(molB)):
                        symmetric_duplicate = block_duplicate
                        break
                if symmetric_duplicate is None:
                    raise ValueError('block', blockidx, self.block_smi[blockidx],
                                     'has no duplicate for atom', j,
                                     'in position 0, and no symmetrical correspondance')
                atom_map[j] = symmetric_duplicate

    def parents(self, mol=None):
        """Return all the possible parents of molecule ``mol`` (or of the
        current molecule ``self.molecule`` if ``mol`` is None).

        Returns a list of (BlockMoleculeDataExtended, (block_idx, stem_idx))
        pairs such that for a pair (m, (b, s)),
        MolMDPExtended.add_block_to(m, b, s) == mol.

        Requires ``build_translation_table()`` to have been called.

        Raises:
            ValueError: if a removed stem cannot be translated to a duplicate
                or symmetric block, or if no parent can be found.
        """
        if mol is None:
            mol = self.molecule
        if len(mol.blockidxs) == 1:
            # If there's just a single block, then the only parent is
            # the empty block with the action that recreates that block
            return [(BlockMoleculeDataExtended(), (mol.blockidxs[0], 0))]

        # Compute how many blocks each block is connected to
        blocks_degree = defaultdict(int)
        for a, b, _, _ in mol.jbonds:
            blocks_degree[a] += 1
            blocks_degree[b] += 1
        # Keep only blocks of degree 1 (those are the ones that could
        # have just been added)
        blocks_degree_1 = [i for i, d in blocks_degree.items() if d == 1]

        # Form new molecules without these blocks
        parent_mols = []
        for rblockidx in blocks_degree_1:
            new_mol = mol.copy()
            # find which bond we're removing
            removed_bonds = [(jbidx, bond) for jbidx, bond in enumerate(new_mol.jbonds)
                             if rblockidx in bond[:2]]
            assert len(removed_bonds) == 1
            rjbidx, rbond = removed_bonds[0]
            # Pop the bond
            new_mol.jbonds.pop(rjbidx)
            # Remove the block. Builtin ``bool`` is used as the dtype because
            # the ``np.bool`` alias no longer exists in recent NumPy.
            mask = np.ones(len(new_mol.blockidxs), dtype=bool)
            mask[rblockidx] = False
            reindex = new_mol.delete_blocks(mask)
            # reindex maps old blockidx to new blockidx, since the
            # block the removed block was attached to might have its
            # index shifted by 1.

            # Compute which stem the bond was using on the surviving block,
            # and which atom of the removed block it was attached through.
            if rblockidx == rbond[1]:
                stem = [reindex[rbond[0]], rbond[2]]
                removed_stem_atom = rbond[3]
            else:
                stem = [reindex[rbond[1]], rbond[3]]
                removed_stem_atom = rbond[2]
            # and add it back (copying inner lists, since copy() is shallow)
            new_mol.stems = [list(i) for i in new_mol.stems] + [stem]
            # and we have a parent. The stem idx to recreate mol is
            # the last stem, since we appended `stem` in the back of
            # the stem list.
            # We also have to translate the block id to match the bond
            # we broke, see build_translation_table().
            blockid = mol.blockidxs[rblockidx]
            if removed_stem_atom not in self.translation_table[blockid]:
                raise ValueError('Could not translate removed stem to duplicate or symmetric block.')
            parent_mols.append([new_mol,
                                # action = (block_idx, stem_idx)
                                (self.translation_table[blockid][removed_stem_atom],
                                 len(new_mol.stems) - 1)])
        if not parent_mols:
            raise ValueError('Could not find any parents')
        return parent_mols

    def add_block_to(self, mol, block_idx, stem_idx=None, atmidx=None):
        """Out-of-place version of add_block.

        Copies ``mol``, adds block ``block_idx`` to the copy and returns it;
        ``mol`` itself is not reassigned. When ``mol`` has no blocks yet,
        ``stem_idx`` is ignored (forced to None).
        """
        if mol.numblocks == 0:
            stem_idx = None
        new_mol = mol.copy()
        new_mol.add_block(block_idx,
                          block=self.block_mols[block_idx],
                          block_r=self.block_rs[block_idx],
                          stem_idx=stem_idx, atmidx=atmidx)
        return new_mol

    def remove_jbond_from(self, mol, jbond_idx=None, atmidx=None):
        """Out-of-place ``remove_jbond``: apply it to a copy of ``mol`` and
        return the copy."""
        new_mol = mol.copy()
        new_mol.remove_jbond(jbond_idx, atmidx)
        return new_mol

    def a2mol(self, acts):
        """Build a molecule by applying ``acts`` in order to an empty one.

        Each action is unpacked into ``add_block_to(mol, *action)``; actions
        whose first element is negative are skipped.
        """
        mol = BlockMoleculeDataExtended()
        for i in acts:
            if i[0] >= 0:
                mol = self.add_block_to(mol, *i)
        return mol

    def reset(self):
        """Replace the current molecule with an empty one. Returns None."""
        self.molecule = BlockMoleculeDataExtended()
        return None

    def post_init(self, device, repr_type, include_bonds=False, include_nblocks=False):
        """Store representation settings and derive block/stem lookup tables.

        Args:
            device: stored as ``self.device``.
            repr_type: one of 'block_graph', 'atom_graph' or
                'morgan_fingerprint' (the values dispatched on by
                ``mols2batch`` / ``mol2repr``).
            include_bonds: forwarded as ``bonds=`` for 'atom_graph'.
            include_nblocks: forwarded as ``nblocks=`` for 'atom_graph'.
        """
        self.device = device
        self.repr_type = repr_type
        self.max_num_atm = max(self.block_natm)
        # see model_block.mol2graph
        # Unique block SMILES, sorted: duplicates of a block that differ only
        # in their attachment order collapse to one "true" block.
        self.true_block_set = sorted(set(self.block_smi))
        # Cumulative offsets (starting at 0), one slot range per true block,
        # sized by the largest attachment atom index of the first library
        # block with that SMILES, plus one.
        self.stem_type_offset = np.int32([0] + list(np.cumsum([
            max(self.block_rs[self.block_smi.index(i)]) + 1 for i in self.true_block_set])))
        self.num_stem_types = self.stem_type_offset[-1]
        # Library block index -> index into true_block_set.
        self.true_blockidx = [self.true_block_set.index(i) for i in self.block_smi]
        self.num_true_blocks = len(self.true_block_set)
        self.include_nblocks = include_nblocks
        self.include_bonds = include_bonds
        self.molcache = {}

    def mols2batch(self, mols):
        """Batch ``mols`` with the model module matching ``self.repr_type``.

        Returns None if ``repr_type`` is not one of the known values.
        """
        if self.repr_type == 'block_graph':
            return model_block.mols2batch(mols, self)
        elif self.repr_type == 'atom_graph':
            return model_atom.mols2batch(mols, self)
        elif self.repr_type == 'morgan_fingerprint':
            return model_fingerprint.mols2batch(mols, self)

    def mol2repr(self, mol=None):
        """Convert ``mol`` (default: ``self.molecule``) to the model input
        representation selected by ``self.repr_type``.

        Raises:
            ValueError: if ``repr_type`` is not a known representation.
        """
        if mol is None:
            mol = self.molecule
        # NOTE: results are not cached; self.molcache is currently unused here.
        if self.repr_type == 'block_graph':
            return model_block.mol2graph(mol, self, self.floatX)
        if self.repr_type == 'atom_graph':
            return model_atom.mol2graph(mol, self, self.floatX,
                                        bonds=self.include_bonds,
                                        nblocks=self.include_nblocks)
        if self.repr_type == 'morgan_fingerprint':
            return model_fingerprint.mol2fp(mol, self, self.floatX)
        # Fail with an explicit message instead of an UnboundLocalError.
        raise ValueError('unknown repr_type: %r' % (self.repr_type,))

    def get_nx_graph(self, mol: BlockMoleculeData, true_block=False):
        """Return a ``networkx.DiGraph`` describing ``mol``.

        Nodes are positions in ``mol.blockidxs`` and carry a "block"
        attribute (the block index, or the true-block index when
        ``true_block`` is set). Each junction bond yields two directed edges,
        one per direction, whose "bond" attribute is the atom pair in the
        direction of the edge.
        """
        true_blockidx = self.true_blockidx
        G = nx.DiGraph()
        blockidxs = [true_blockidx[xx] for xx in mol.blockidxs] if true_block else mol.blockidxs

        G.add_nodes_from([(ix, {"block": blockidxs[ix]}) for ix in range(len(blockidxs))])

        if len(mol.jbonds) > 0:
            edges = []
            for jbond in mol.jbonds:
                edges.append((jbond[0], jbond[1],
                              {"bond": [jbond[2], jbond[3]]}))
                edges.append((jbond[1], jbond[0],
                              {"bond": [jbond[3], jbond[2]]}))
            G.add_edges_from(edges)
        return G

    def graphs_are_isomorphic(self, g1, g2):
        """Isomorphism test that also requires matching "block" node
        attributes and "bond" edge attributes (via the module-level
        ``node_match`` / ``edge_match``)."""
        return nx.algorithms.is_isomorphic(g1, g2, node_match=node_match, edge_match=edge_match)
def node_match(x1, x2):
    """Node comparator: two node attribute dicts match when their "block"
    entries are equal."""
    left_block = x1["block"]
    right_block = x2["block"]
    return left_block == right_block
def edge_match(x1, x2):
    """Edge comparator: two edge attribute dicts match when their "bond"
    entries are equal."""
    left_bond = x1["bond"]
    right_bond = x2["bond"]
    return left_bond == right_bond
def test_mdp_parent():
    """Randomized self-test for ``MolMDPExtended.parents``.

    Loads the block library from ``./data/blocks_PDB_105.json``, then
    (1) checks that re-applying each returned parent action recreates the
    molecule, and (2) checks that every random molecule can be walked back
    to the empty molecule without errors.
    """
    mdp = MolMDPExtended("./data/blocks_PDB_105.json")
    mdp.build_translation_table()
    import tqdm
    rng = np.random.RandomState(142)
    # Size of the block library. Kept under its own name so that the
    # per-molecule step count below cannot overwrite it; otherwise only the
    # first few library blocks would ever be sampled.
    num_library_blocks = mdp.num_blocks

    def random_molecule():
        # Grow a molecule by 1-9 random block additions, stopping early
        # once a non-empty molecule has no free stems left.
        mdp.molecule = mol = BlockMoleculeDataExtended()
        num_steps = rng.randint(1, 10)
        for _ in range(num_steps):
            if len(mol.blocks) and not len(mol.stems):
                break
            mdp.add_block(rng.randint(num_library_blocks),
                          rng.randint(max(1, len(mol.stems))))
        return mol

    # First let's test that the parent-finding method is
    # correct. i.e. let's test that the (mol, (parent, action)) pairs
    # are such that add_block_to(parent, action) == mol
    for _ in tqdm.tqdm(range(10000)):
        mol = random_molecule()
        parents = mdp.parents(mol)
        s = mol.smiles
        for p, (a, b) in parents:
            c = mdp.add_block_to(p, a, b)
            if c.smiles != s:
                # SMILES might differ but this might still be the same mol
                # we can check this way but it's a bit more costly
                assert c.mol.HasSubstructMatch(mol.mol)

    # Now let's test whether we can always backtrack to the root from
    # any molecule without any errors
    for _ in tqdm.tqdm(range(10000)):
        mol = random_molecule()
        while len(mol.blocks):
            parents = mdp.parents(mol)
            mol = parents[rng.randint(len(parents))][0]
# Script entry point: run the parent-finding self-test when executed directly.
if __name__ == "__main__":
    test_mdp_parent()