-
Notifications
You must be signed in to change notification settings - Fork 2
/
cal_metrics.py
228 lines (196 loc) · 8.82 KB
/
cal_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#!/usr/bin/python
# -*- coding:utf-8 -*-
import argparse
import json
import os
import random
from copy import deepcopy
from collections import defaultdict
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
import statistics
import warnings
warnings.filterwarnings("ignore")
import numpy as np
from scipy.stats import spearmanr
from data.converter.pdb_to_list_blocks import pdb_to_list_blocks
from evaluation import diversity
from evaluation.dockq import dockq
from evaluation.rmsd import compute_rmsd
from utils.random_seed import setup_seed
from evaluation.seq_metric import aar, slide_aar
def _get_ref_pdb(_id, root_dir):
return os.path.join(root_dir, 'references', f'{_id}_ref.pdb')
def _get_gen_pdb(_id, number, root_dir, use_rosetta):
suffix = '_rosetta' if use_rosetta else ''
return os.path.join(root_dir, 'candidates', _id, f'{_id}_gen_{number}{suffix}.pdb')
def cal_metrics(items):
    """Compute sequence and structure metrics for one group of generated candidates.

    All entries in *items* are conditioned on the same binding pocket, so the
    reference complex is loaded once and every candidate is compared against it.

    Each item is a dict carrying at least: 'id', 'number', 'root_dir',
    'rec_chain', 'lig_chain', 'rosetta', the seq/struct/backbone-only flags,
    'pmetric' (the model's own confidence score), and — unless struct_only —
    'gen_seq' / 'ref_seq'.

    Returns a dict mapping each metric name to an aggregation dict
    (max / min / mean / random pick / pmetric-guided picks / spearman
    correlation with pmetric / individual values); diversity metrics are
    appended as plain scalars when more than one sequence is available.
    """
    root_dir = items[0]['root_dir']
    rec_chain, lig_chain = items[0]['rec_chain'], items[0]['lig_chain']
    # BUGFIX: the 'ref_pdb' recorded in the item was read and immediately
    # discarded; the path is reconstructed from the results layout instead.
    ref_pdb = _get_ref_pdb(items[0]['id'], root_dir)
    seq_only, struct_only, backbone_only = items[0]['seq_only'], items[0]['struct_only'], items[0]['backbone_only']

    # prepare reference CA coordinates; residues lacking a CA atom are masked
    # out so they do not contribute to the CA RMSD
    results = defaultdict(list)
    cand_seqs, cand_ca_xs = [], []
    rec_blocks, ref_pep_blocks = pdb_to_list_blocks(ref_pdb, [rec_chain, lig_chain])
    ref_ca_x, ca_mask = [], []
    for ref_block in ref_pep_blocks:
        if ref_block.has_unit('CA'):
            ca_mask.append(1)
            ref_ca_x.append(ref_block.get_unit_by_name('CA').get_coord())
        else:
            ca_mask.append(0)
            ref_ca_x.append([0, 0, 0])  # placeholder, masked out below
    ref_ca_x, ca_mask = np.array(ref_ca_x), np.array(ca_mask).astype(bool)

    for item in items:
        if not struct_only:
            # sequence metric: sliding-window amino-acid recovery
            cand_seqs.append(item['gen_seq'])
            results['Slide AAR'].append(slide_aar(item['gen_seq'], item['ref_seq'], aar))
        # structure metrics
        gen_pdb = _get_gen_pdb(item['id'], item['number'], root_dir, item['rosetta'])
        _, gen_pep_blocks = pdb_to_list_blocks(gen_pdb, [rec_chain, lig_chain])
        assert len(gen_pep_blocks) == len(ref_pep_blocks), f'{item}\t{len(ref_pep_blocks)}\t{len(gen_pep_blocks)}'
        # CA RMSD, computed after alignment and only over masked residues
        gen_ca_x = np.array([block.get_unit_by_name('CA').get_coord() for block in gen_pep_blocks])
        cand_ca_xs.append(gen_ca_x)
        rmsd = compute_rmsd(ref_ca_x[ca_mask], gen_ca_x[ca_mask], aligned=True)
        results['RMSD(CA)'].append(rmsd)
        if struct_only:
            # success rates at common RMSD thresholds
            results['RMSD<=2.0'].append(1 if rmsd <= 2.0 else 0)
            results['RMSD<=5.0'].append(1 if rmsd <= 5.0 else 0)
            results['RMSD<=10.0'].append(1 if rmsd <= 10.0 else 0)
        if backbone_only:
            continue  # side-chain dependent metrics below are meaningless
        # 5. DockQ
        dockq_score = dockq(gen_pdb, ref_pdb, lig_chain)
        results['DockQ'].append(dockq_score)
        if struct_only:
            # standard DockQ quality thresholds (acceptable / medium / high)
            results['DockQ>=0.23'].append(1 if dockq_score >= 0.23 else 0)
            results['DockQ>=0.49'].append(1 if dockq_score >= 0.49 else 0)
            results['DockQ>=0.80'].append(1 if dockq_score >= 0.80 else 0)
        # Full atom RMSD, only over atoms present in both structures
        if struct_only:
            gen_all_x, ref_all_x = [], []
            for gen_block, ref_block in zip(gen_pep_blocks, ref_pep_blocks):
                for ref_atom in ref_block:
                    if gen_block.has_unit(ref_atom.name):
                        ref_all_x.append(ref_atom.get_coord())
                        gen_all_x.append(gen_block.get_unit_by_name(ref_atom.name).get_coord())
            results['RMSD(full-atom)'].append(compute_rmsd(
                np.array(gen_all_x), np.array(ref_all_x), aligned=True
            ))

    pmets = [item['pmetric'] for item in items]
    indexes = list(range(len(items)))
    # aggregation: for each metric, summarize across candidates and relate the
    # values to the model's own confidence metric (pmetric)
    for name in results:
        vals = results[name]
        corr = spearmanr(vals, pmets, nan_policy='omit').statistic
        if np.isnan(corr):
            corr = 0
        aggr_res = {
            'max': max(vals),
            'min': min(vals),
            'mean': sum(vals) / len(vals),
            'random': vals[0],  # first candidate as a "random" pick
            # pmetric-guided selection: pick by the confidence score, using the
            # correlation sign to decide whether high pmetric means high value
            'max*': vals[(max if corr > 0 else min)(indexes, key=lambda i: pmets[i])],
            'min*': vals[(min if corr > 0 else max)(indexes, key=lambda i: pmets[i])],
            'pmet_corr': corr,
            'individual': vals,
            'individual_pmet': pmets
        }
        results[name] = aggr_res

    # diversity needs at least two sequences; skipped in sequence-only mode
    if len(cand_seqs) > 1 and not seq_only:
        seq_div, struct_div, co_div, consistency = diversity.diversity(cand_seqs, np.array(cand_ca_xs))
        results['Sequence Diversity'] = seq_div
        results['Struct Diversity'] = struct_div
        results['Codesign Diversity'] = co_div
        results['Consistency'] = consistency

    return results
def cnt_aa_dist(seqs):
    """Print the amino-acid frequency distribution over *seqs*.

    Counts every character across all sequences and prints one line per
    amino acid with its relative frequency, least frequent first.
    """
    cnts = {}
    for seq in seqs:
        for aa in seq:
            # dict.get avoids the explicit membership check of the original
            cnts[aa] = cnts.get(aa, 0) + 1
    total = sum(cnts.values())
    # ascending by count so the rarest residues print first
    for aa in sorted(cnts, key=lambda a: cnts[a]):
        print(f'\t{aa}: {cnts[aa] / total}')
def main(args):
    """Load per-sample generation results, evaluate them, and print a report.

    Reads one JSON object per line from ``args.results``, groups samples by
    target id, optionally filters by predicted dG, computes metrics per target
    (in parallel when ``args.num_workers > 1``), writes ``eval_report.json``
    next to the input, and prints aggregated statistics.
    """
    root_dir = os.path.dirname(args.results)
    # optional dG filter: only keep samples whose predicted dG is negative
    if args.filter_dG is None:
        filter_func = lambda _id, n: True
    else:
        dG_results = json.load(open(args.filter_dG, 'r'))
        filter_func = lambda _id, n: dG_results[_id]['all'][str(n)] < 0
    # load results (one JSON object per line), grouped by target id
    with open(args.results, 'r') as fin:
        lines = fin.read().strip().split('\n')
    id2items = {}
    for line in lines:
        item = json.loads(line)
        _id = item['id']
        if not filter_func(_id, item['number']):
            continue
        if _id not in id2items:
            id2items[_id] = []
        item['root_dir'] = root_dir
        item['rosetta'] = args.rosetta
        id2items[_id].append(item)
    if args.filter_dG is not None:
        # delete results with only one sample since it cannot calculate diversity
        del_ids = [_id for _id in id2items if len(id2items[_id]) < 2]
        for _id in del_ids:
            print(f'Deleting {_id} since it only has one sample passed the filter')
            del id2items[_id]
    # BUGFIX: capture the id list AFTER filtering, so indexes into `ids` stay
    # aligned with `metrics` (the original captured it before deletion and
    # could report the wrong target id for lowest/highest values)
    ids = list(id2items.keys())
    if args.num_workers > 1:
        metrics = process_map(cal_metrics, id2items.values(), max_workers=args.num_workers, chunksize=1)
    else:
        metrics = [cal_metrics(inputs) for inputs in tqdm(id2items.values())]
    # persist one JSON report line per target
    eval_results_path = os.path.join(os.path.dirname(args.results), 'eval_report.json')
    with open(eval_results_path, 'w') as fout:
        for i, _id in enumerate(id2items):
            metric = deepcopy(metrics[i])
            metric['id'] = _id
            fout.write(json.dumps(metric) + '\n')
    # individual level results
    print('Point-wise evaluation results:')
    for name in metrics[0]:
        vals = [item[name] for item in metrics]
        if isinstance(vals[0], dict):
            # aggregated metric: pick the best value per target
            if 'RMSD' in name and '<=' not in name:
                aggr = 'min'  # lower RMSD is better
            else:
                aggr = 'max'
            aggr_vals = [val[aggr] for val in vals]
            if '>=' in name or '<=' in name:  # percentage
                print(f'{name}: {sum(aggr_vals) / len(aggr_vals)}')
            else:
                if 'RMSD' in name:
                    print(f'{name}(median): {statistics.median(aggr_vals)}')  # unbounded, some extreme values will affect the mean but not the median
                else:
                    print(f'{name}(mean): {sum(aggr_vals) / len(aggr_vals)}')
            lowest_i = min(range(len(aggr_vals)), key=lambda i: aggr_vals[i])
            highest_i = max(range(len(aggr_vals)), key=lambda i: aggr_vals[i])
            print(f'\tlowest: {aggr_vals[lowest_i]}, id: {ids[lowest_i]}', end='')
            print(f'\thighest: {aggr_vals[highest_i]}, id: {ids[highest_i]}')
        else:
            # scalar metric (e.g. diversity values)
            print(f'{name} (mean): {sum(vals) / len(vals)}')
            lowest_i = min(range(len(vals)), key=lambda i: vals[i])
            highest_i = max(range(len(vals)), key=lambda i: vals[i])
            print(f'\tlowest: {vals[lowest_i]}, id: {ids[lowest_i]}')
            print(f'\thighest: {vals[highest_i]}, id: {ids[highest_i]}')
def parse():
    """Define and parse the command-line options for metric calculation."""
    arg_parser = argparse.ArgumentParser(description='calculate metrics')
    arg_parser.add_argument('--results', type=str, required=True,
                            help='Path to test set')
    arg_parser.add_argument('--num_workers', type=int, default=8,
                            help='Number of workers to use')
    arg_parser.add_argument('--rosetta', action='store_true',
                            help='Use the rosetta-refined structure')
    arg_parser.add_argument('--filter_dG', type=str, default=None,
                            help='Only calculate results on samples with dG<0')
    return arg_parser.parse_args()
if __name__ == '__main__':
    # fix the random seed so evaluation runs are reproducible
    setup_seed(0)
    main(parse())