-
Notifications
You must be signed in to change notification settings - Fork 1
/
HBHistogram.py
401 lines (313 loc) · 13.5 KB
/
HBHistogram.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
# -*- coding: utf-8 -*-
# Standard imports
import argparse as ap
import glob
from collections import defaultdict
from pathlib import Path
# External imports
import numpy as np
import matplotlib as mpl
#mpl.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import cm
# PELE imports
from Helpers.PELEIterator import SimIt
from Helpers.ReportUtils import extract_PELE_ids
from Helpers.ReportUtils import extract_metrics
from Helpers.ReportUtils import get_metric_by_PELE_id
# Script information
__author__ = __maintainer__ = "Marti Municoy, Carles Perez"
__license__ = "GPL"
__version__ = "1.0.1"
# NOTE(review): the e-mail addresses look redacted by the page they were
# copied from ("[email protected]") - confirm the real contact values.
__email__ = "[email protected], [email protected]"
def parse_args(argv=None):
    """Parse the command-line arguments of the script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When None (the default) argparse reads them
        from ``sys.argv[1:]``, exactly as before.

    Returns
    -------
    tuple
        ``(hbonds_data_paths, mode, lim, epochs_to_ignore,
        trajectories_to_ignore, models_to_ignore, output,
        processors_number, PELE_output_path, PELE_report_name)``
    """
    parser = ap.ArgumentParser()
    parser.add_argument("hbonds_data_paths", metavar="PATH", type=str,
                        nargs='*',
                        help="Path to H bonds data files")
    parser.add_argument("-m", "--mode", choices=["count",
                                                 "frequent_interactions",
                                                 "relative_frequency",
                                                 "mean_energies"],
                        type=str, metavar="MODE",
                        default="count",
                        # '%%' is needed because argparse %-formats help
                        help="Selection of computation mode: " +
                        "(1) count - sum of all residues interacting, " +
                        "(2) frequent_interactions - sum of interactions " +
                        "present at least in a 10%% of the structures of" +
                        " the simulation, (3) relative_frequency - mean " +
                        "of interacting residues frequencies for each " +
                        "ligand, (4) mean_energies - mean interaction " +
                        "energies per H bond are calculated")
    parser.add_argument("-l", "--lim",
                        metavar="L", type=float, default=0.1,
                        help="Frequency limit for frequent_interactions "
                             "method")
    parser.add_argument("--epochs_to_ignore", nargs='*',
                        metavar="N", type=int, default=[],
                        help="PELE epochs whose H bonds will be ignored")
    parser.add_argument("--trajectories_to_ignore", nargs='*',
                        metavar="N", type=int, default=[],
                        help="PELE trajectories whose H bonds will be "
                             "ignored")
    parser.add_argument("--models_to_ignore", nargs='*',
                        metavar="N", type=int, default=[],
                        help="PELE models whose H bonds will be ignored")
    parser.add_argument("-o", "--output",
                        metavar="PATH", type=str, default=None,
                        help="Output path to save the plot")
    parser.add_argument("-n", "--processors_number",
                        metavar="N", type=int, default=None,
                        help="Number of processors")
    parser.add_argument("--PELE_output_path",
                        metavar="PATH", type=str, default='output',
                        help="Relative path to PELE output folder")
    parser.add_argument("--PELE_report_name",
                        metavar="PATH", type=str, default='report',
                        help="Name of PELE's reports")
    args = parser.parse_args(argv)
    return args.hbonds_data_paths, args.mode, args.lim, \
        args.epochs_to_ignore, args.trajectories_to_ignore, \
        args.models_to_ignore, args.output, args.processors_number, \
        args.PELE_output_path, args.PELE_report_name
def create_df(hb_path):
    """Read an H bonds data file as a list of whitespace-split rows.

    Despite its name this builds no DataFrame: the result is a plain list
    holding one list of string fields per line (blank lines give empty
    lists). The first six lines of the file are skipped as header.

    NOTE(review): an older comment spoke of "four header lines", but six
    lines are skipped - confirm against the H bonds file format.

    Parameters
    ----------
    hb_path : str or Path
        Path to the H bonds data file.

    Returns
    -------
    list of list of str
    """
    with open(hb_path) as handle:
        data_lines = handle.readlines()[6:]
    return [line.split() for line in data_lines]
def get_hbond_atoms_from_df(df, hb_path, epochs_to_ignore,
                            trajectories_to_ignore, models_to_ignore):
    """Group H bonds by PELE snapshot id ``(epoch, trajectory, model)``.

    Parameters
    ----------
    df : list of list of str
        Rows as returned by ``create_df``. The first three fields are parsed
        as integers (epoch, trajectory, model); the optional fourth field is
        a comma-separated list of ``chain:residue:atom`` entries.
    hb_path : str
        Path of the source file, only used in warning messages.
    epochs_to_ignore, trajectories_to_ignore, models_to_ignore : container
        Rows whose epoch, trajectory or model is contained here are skipped.

    Returns
    -------
    collections.defaultdict(list)
        Maps ``(epoch, trajectory, model)`` to a list of H bonds, each one
        the list of ``:``-separated fields of an entry. Rows without a
        fourth field add no key.
    """
    hbond_atoms = defaultdict(list)
    for row in df:
        try:
            epoch, trajectory, model = map(int, row[0:3])
        except (IndexError, ValueError):
            print("get_hbond_atoms_from_df Warning: found row with non" +
                  " valid format at {}:".format(hb_path))
            print(" {}".format(row))
            # Skip the row: continuing would reuse the previous row's ids
            # (misattributing its H bonds) or raise NameError on row one.
            continue
        if (epoch in epochs_to_ignore or
                trajectory in trajectories_to_ignore or
                model in models_to_ignore):
            continue
        if len(row) < 4:
            # No H bond field on this row (presumably a snapshot without
            # H bonds)
            continue
        for residue in row[3].split(','):
            hbond_atoms[(epoch, trajectory, model)].append(residue.split(":"))
    return hbond_atoms
def count(hbond_atoms):
    """Count absolute H bond occurrences per residue and atom.

    Parameters
    ----------
    hbond_atoms : dict
        Maps a snapshot id to its list of ``(chain, residue, atom)`` H bonds.
        Repeated H bonds within one snapshot are counted each time.

    Returns
    -------
    collections.defaultdict(dict)
        ``(chain, residue) -> {atom: count}``.
    """
    tally = defaultdict(dict)
    for snapshot_hbonds in hbond_atoms.values():
        for chain, residue, atom in snapshot_hbonds:
            per_atom = tally[(chain, residue)]
            per_atom[atom] = per_atom.get(atom, 0) + 1
    return tally
def count_norm(hbond_atoms):
    """Compute per-snapshot relative frequencies of each H bond.

    Each occurrence count is multiplied by ``1 / len(hbond_atoms)``, so only
    snapshots present as keys of ``hbond_atoms`` enter the denominator.
    Repeated H bonds within one snapshot are counted each time, so a value
    can exceed 1.

    Parameters
    ----------
    hbond_atoms : dict
        Maps a snapshot id to its list of ``(chain, residue, atom)`` H bonds.

    Returns
    -------
    collections.defaultdict(dict)
        ``(chain, residue) -> {atom: relative frequency}``; empty when
        ``hbond_atoms`` is empty.
    """
    frequencies = defaultdict(dict)
    if not hbond_atoms:
        return frequencies
    scale = 1 / len(hbond_atoms)
    for snapshot_hbonds in hbond_atoms.values():
        for chain, residue, atom in snapshot_hbonds:
            per_atom = frequencies[(chain, residue)]
            per_atom[atom] = per_atom.get(atom, 0) + 1
    # Rescale in place; only existing keys are reassigned
    for per_atom in frequencies.values():
        for atom in per_atom:
            per_atom[atom] *= scale
    return frequencies
def discard_non_frequent(counter, lim=0.1):
    """Flag the H bonds whose frequency is at least ``lim``.

    Every surviving ``(residue, atom)`` pair gets the value 1. Residues
    without any surviving atom are left out.

    Parameters
    ----------
    counter : dict
        ``(chain, residue) -> {atom: frequency}``.
    lim : float
        Minimum frequency (inclusive) to keep an H bond.

    Returns
    -------
    collections.defaultdict(dict)
        ``(chain, residue) -> {atom: 1}``.
    """
    frequent = defaultdict(dict)
    for (chain, residue), atom_freq in counter.items():
        kept_atoms = [atom for atom, freq in atom_freq.items() if freq >= lim]
        for atom in kept_atoms:
            frequent[(chain, residue)][atom] = 1
    return frequent
def count_energy(hbond_atoms, ie_by_PELE_id):
    """Average, per H bond, the metric of the snapshots that contain it.

    Parameters
    ----------
    hbond_atoms : dict
        Maps a PELE id to the list of ``(chain, residue, atom)`` H bonds of
        that snapshot.
    ie_by_PELE_id : mapping
        Maps each PELE id to a numeric value. It is indexed directly, so a
        missing id raises (KeyError for a plain dict). NOTE(review):
        ``main`` builds it from report column 4, presumably the energy -
        confirm.

    Returns
    -------
    collections.defaultdict(dict)
        ``(chain, residue) -> {atom: mean value}``; empty when
        ``hbond_atoms`` is empty. Means are returned without normalization.
    """
    energies = defaultdict(lambda: defaultdict(list))
    for PELE_id, hbs in hbond_atoms.items():
        # Preventing repeated H bonds in the same snapshot from weighting
        # that snapshot more than once
        for chain, residue, atom in set(map(tuple, hbs)):
            energies[(chain, residue)][atom].append(ie_by_PELE_id[PELE_id])

    means = defaultdict(dict)
    for (chain, residue), atom_ies in energies.items():
        for atom, ies in atom_ies.items():
            means[(chain, residue)][atom] = np.mean(ies)
    return means
def combine_results(general_results, mode):
    """Merge the per-file results into a single residue/atom mapping.

    Parameters
    ----------
    general_results : dict
        Maps each data file path to its ``(chain, residue) -> {atom: value}``
        result.
    mode : str
        How values of different files are merged:

        * ``'count'`` / ``'frequent_interactions'``: summed.
        * ``'relative_frequency'``: averaged over all files, a file lacking
          the H bond contributing 0.
        * ``'mean_energies'``: averaged over only the files that contain
          the H bond.

        Any other mode yields an empty result.

    Returns
    -------
    collections.defaultdict(dict)
        ``(chain, residue) -> {atom: merged value}``.
    """
    merged = defaultdict(dict)
    per_file = list(general_results.values())

    if mode in ("count", "frequent_interactions"):
        for hbonds in per_file:
            for residue_key, atom_values in hbonds.items():
                for atom, value in atom_values.items():
                    slot = merged[residue_key]
                    slot[atom] = slot.get(atom, 0) + value

    elif mode == "relative_frequency":
        # Every (chain, residue, atom) triplet seen in at least one file
        triplets = {}
        for hbonds in per_file:
            for residue_key, atom_values in hbonds.items():
                for atom in atom_values:
                    triplets[residue_key + (atom, )] = None
        for chain, residue, atom in triplets:
            samples = []
            for hbonds in per_file:
                atom_values = (hbonds[(chain, residue)]
                               if (chain, residue) in hbonds else {})
                samples.append(atom_values.get(atom, 0))
            merged[(chain, residue)][atom] = np.mean(samples)

    elif mode == "mean_energies":
        collected = defaultdict(lambda: defaultdict(list))
        for hbonds in per_file:
            for residue_key, atom_ies in hbonds.items():
                for atom, ie in atom_ies.items():
                    collected[residue_key][atom].append(ie)
        for residue_key, atom_ies in collected.items():
            for atom, ies in atom_ies.items():
                merged[residue_key][atom] = np.mean(ies)

    return merged
def generate_barplot(dictionary, mode, lim, output_path,
                     ylabel='COVID-19 Mpro residues'):
    """Draw a horizontal bar plot with one bar per (residue, atom) H bond.

    Bars are grouped by residue. Each group gets the next of the 10 'tab10'
    colours, cyclically, and a y tick at the centre of the group. Every bar
    is annotated with its atom name.

    Parameters
    ----------
    dictionary : dict
        ``(chain, residue) -> {atom: value}``, e.g. from ``combine_results``.
    mode : str
        Computation mode. It selects the x label. For ``'mean_energies'`` it
        also inverts the x axis and disables the hiding of small bars.
    lim : float
        Frequency limit, only shown in the ``'frequent_interactions'`` label.
    output_path : str, Path or None
        Where to save the figure (300 dpi). When None, or when its parent
        directory does not exist, the figure is shown interactively.
    ylabel : str, optional
        Label of the y axis.

    Returns
    -------
    None
        If there is nothing to plot, a warning is printed and no figure is
        created.
    """
    values = [value for atom_freq in dictionary.values()
              for value in atom_freq.values()]
    if not values:
        # max()/min() and the axis limits below need at least one value
        print("generate_barplot Warning: no H bond data to plot")
        return
    max_freq = max(values)
    min_freq = min(values)

    fig, ax = plt.subplots(1, figsize=(10, 8))
    fig.tight_layout()
    fig.subplots_adjust(left=0.12, right=0.98, top=0.98, bottom=0.08)

    # y ticks and labels handlers
    y = 1.4
    ys = []
    sub_ys = []
    sub_xs = []
    ylabels = []
    sub_ylabels = []

    # colormap handlers (cm.get_cmap was removed in matplotlib 3.9)
    norm = mpl.colors.Normalize(0, 10)
    cmap = plt.get_cmap('tab10')
    color_index = 0

    for residue, atom_freq in sorted(dictionary.items()):
        _ys = []
        for atom, freq in sorted(atom_freq.items()):
            # Hide bars below 1% of the maximum, except for energies
            if mode != 'mean_energies' and freq < max_freq / 100:
                continue
            _ys.append(y)
            sub_ylabels.append(atom)
            sub_xs.append(freq)
            plt.barh(y, freq, align='center', color=cmap(norm(color_index)))
            y += 0.9
        if _ys:
            # One tick per residue, centred on its group of bars
            ys.append(np.mean(_ys))
            ylabels.append(residue)
            sub_ys += _ys
            color_index = (color_index + 1) % 10
            y += 0.5

    plt.ylim((0, y))
    plt.ylabel(ylabel, fontweight='bold')
    plt.yticks(ys, ['{}:{}'.format(*i) for i in ylabels])

    if mode == "count":
        plt.xlabel('Absolute H bond counts', fontweight='bold')
    elif mode == "relative_frequency":
        plt.xlabel('Average of relative H bond frequencies', fontweight='bold')
    elif mode == "frequent_interactions":
        plt.xlabel('Absolute H bond counts with frequencies above ' +
                   '{}'.format(lim), fontweight='bold')
    elif mode == "mean_energies":
        # Inverted x axis (largest value on the left) with a 5% margin
        margin = (max_freq - min_freq) * 0.05
        if margin == 0:
            # A single distinct value would give identical x limits
            margin = abs(max_freq) * 0.05 or 1.0
        ax.set_xlim(max_freq + margin, min_freq - margin)
        plt.xlabel('Average of mean total energies for each H bond',
                   fontweight='bold')

    offset = 0 if mode == 'mean_energies' else max_freq * 0.025
    for sub_x, sub_y, sub_ylabel in zip(sub_xs, sub_ys, sub_ylabels):
        ax.text(sub_x + offset, sub_y, sub_ylabel.strip(),
                horizontalalignment='left', verticalalignment='center',
                size=7)

    ax.set_facecolor('whitesmoke')

    if output_path is not None:
        output_path = Path(output_path)
        if output_path.parent.is_dir():
            plt.savefig(str(output_path), dpi=300, transparent=True,
                        pad_inches=0.05)
            # Release the figure; nothing else uses it after saving
            plt.close(fig)
            return
        print("generate_barplot Warning: directory {} does not exist, "
              "showing the plot instead".format(output_path.parent))
    plt.show()
def main():
    """Run the H bond histogram analysis from the command line.

    Every path argument is expanded with glob. Each matched H bonds data
    file is parsed, filtered with the ignore lists and reduced according to
    the selected mode. The per-file results are then combined and plotted.

    When an output path is given, it is resolved relative to the directory
    of each data file, and the path computed for the last file is used.
    """
    hb_paths, mode, lim, epochs_to_ignore, trajectories_to_ignore, \
        models_to_ignore, relative_output_path, proc_number, \
        PELE_output_path, PELE_report_name = parse_args()

    # argparse yields a list (nargs='*'); also accept a single pattern
    if isinstance(hb_paths, str):
        hb_paths = [hb_paths]
    hb_paths_list = []
    for hb_path in hb_paths:
        hb_paths_list += glob.glob(hb_path)

    if not hb_paths_list:
        # Nothing to analyse (and output_path would stay undefined)
        print("main Warning: no H bonds data file matches "
              "{}".format(hb_paths))
        return

    output_path = None
    general_results = {}
    for hb_path in hb_paths_list:
        df = create_df(hb_path)
        # Calculate hbond_atoms, which is a dict with PELE_ids as key and
        # corresponding lists of H bonds as values
        hbond_atoms = get_hbond_atoms_from_df(df, hb_path,
                                              epochs_to_ignore,
                                              trajectories_to_ignore,
                                              models_to_ignore)
        if relative_output_path is not None:
            output_path = Path(hb_path).parent.joinpath(relative_output_path)

        if mode == "count":
            counter = count(hbond_atoms)
        elif mode == "relative_frequency":
            counter = count_norm(hbond_atoms)
        elif mode == "frequent_interactions":
            counter = discard_non_frequent(count_norm(hbond_atoms), lim)
        elif mode == "mean_energies":
            sim_it = SimIt(Path(hb_path).parent)
            # Use the report name given on the command line
            # (--PELE_report_name, default 'report')
            sim_it.build_repo_it(PELE_output_path, PELE_report_name)
            reports = list(sim_it.repo_it)
            PELE_ids = extract_PELE_ids(reports)
            # NOTE(review): (4, ) presumably selects the report column that
            # holds the energy - confirm in Helpers.ReportUtils
            metrics = extract_metrics(reports, (4, ), proc_number)
            ies = [list(map(float, np.concatenate(ies_chunk)))
                   for ies_chunk in metrics]
            ie_by_PELE_id = get_metric_by_PELE_id(PELE_ids, ies)
            counter = count_energy(hbond_atoms, ie_by_PELE_id)
        general_results[hb_path] = counter

    combined_results = combine_results(general_results, mode)
    generate_barplot(combined_results, mode, lim, output_path)


if __name__ == "__main__":
    main()