-
Notifications
You must be signed in to change notification settings - Fork 4
/
get_abag_recovery.py
124 lines (113 loc) · 5.46 KB
/
get_abag_recovery.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from tqdm import tqdm
import os
import pandas as pd
import torch
from utils.command_line_utils import _get_args
from utils.AntibodyMetricsReporter import AntibodyMetricsReporter
torch.set_grad_enabled(False)
def get_trials(mr):
if mr==1.0:
trials_masking = 1
trials_sampling = 100
elif mr >= 0.50 and mr < 1.0:
trials_masking = 20
trials_sampling = 100
else:
# low mr
trials_masking = 50
trials_sampling = 20
return trials_masking, trials_sampling
def run_ab_or_ag_masking(masking_rates=[1.0]):
row_list = []
for sel in ['Ab', 'Ag']:
print('Selection partner', sel)
for mr in tqdm(masking_rates):
trials_masking, trials_sampling = get_trials(mr)
print(mr, sel)
reporter = AntibodyMetricsReporter(partner_selection=sel,
region_selection=None,
intersect_with_contacts=False)
(outdict, iddict), (outdict_sampled, iddict_sampled, outdict_max_sampled) = \
reporter.run_trials_for_seqrec(mr=mr,
trials_masking=trials_masking,
trials_sampling=trials_sampling,
subdir='MaskAborAgContact'
)
for key in outdict:
df_row_dict = {}
df_row_dict['AndIntersection'] = True #default with not selecting region
df_row_dict['Mask'] = sel
df_row_dict['Region'] = key
df_row_dict['Accuracy'] = outdict[key]
df_row_dict['Trials_Masking'] = trials_masking
df_row_dict['Trials_Sampling'] = trials_sampling
df_row_dict['MaskRate'] = mr
#df_row_dict['PerTarget'] = iddict[key]
df_row_dict['Accuracy_Sampled'] = outdict_sampled[key]
df_row_dict['Accuracy_SampledMax'] = outdict_max_sampled[key]
#df_row_dict['PerTarget_Sampled'] = iddict_sampled[key]
row_list.append(df_row_dict)
return row_list
def run_region_masking(masking_rates=[1.0],
regions = ['h1', 'h2', 'h3', 'l1', 'l2', 'l3', 'non-cdr'],
contact_options = [True, False],
subdirs = ['MaskCDRandContact', 'MaskCDR']):
partner_sel = 'Ab'
row_list = []
for intersect_with_contacts, subdir in zip(contact_options, subdirs):
for region_sel in regions:
print('Selection partner', partner_sel)
for mr in tqdm(masking_rates):
trials_masking, trials_sampling = get_trials(mr)
print(partner_sel, mr, region_sel, intersect_with_contacts)
reporter = AntibodyMetricsReporter(partner_selection=partner_sel,
region_selection=region_sel,
intersect_with_contacts=intersect_with_contacts)
(outdict, iddict), (outdict_sampled, iddict_sampled, outdict_max_sampled) = \
reporter.run_trials_for_seqrec(mr=mr,
trials_masking=trials_masking,
trials_sampling=trials_sampling,
subdir=subdir,
suffix='_{}'.format(region_sel)
)
df_row_dict = {}
df_row_dict['AndIntersection'] = intersect_with_contacts
df_row_dict['Mask'] = region_sel
df_row_dict['Region'] = region_sel
df_row_dict['Accuracy'] = outdict[region_sel]
df_row_dict['Trials_Masking'] = trials_masking
df_row_dict['Trials_Sampling'] = trials_sampling
df_row_dict['MaskRate'] = mr
#df_row_dict['PerTarget'] = iddict[region_sel]
df_row_dict['Accuracy_Sampled'] = outdict_sampled[region_sel]
df_row_dict['Accuracy_SampledMax'] = outdict_max_sampled[region_sel]
#df_row_dict['PerTarget_Sampled'] = iddict_sampled[region_sel]
row_list.append(df_row_dict)
return row_list
args = _get_args()
do_not_overwrite=False
csv_file = os.path.join(args.output_dir, 'ConsolidatedAccuracy_MaskedCdrs.csv')
if os.path.exists(csv_file) and do_not_overwrite:
df_2 = pd.read_csv(csv_file)
if 'Unnamed: 0' in df_2.columns:
df_2.drop(columns=['Unnamed: 0'], inplace=True)
row_list = df_2.to_dict(orient='list')
else:
row_list = run_region_masking()
df_2 = pd.DataFrame(row_list)
df_2.to_csv(csv_file, index=False)
csv_file = os.path.join(args.output_dir, 'ConsolidatedAccuracy_MaskedAborAg.csv')
if os.path.exists(csv_file) and do_not_overwrite:
print('Found existing: ', csv_file)
df_1 = pd.read_csv(csv_file)
if 'Unnamed: 0' in df_1.columns:
df_1.drop(columns=['Unnamed: 0'], inplace=True)
print(df_1)
row_list = df_1.to_dict(orient='list')
print(row_list.keys())
else:
row_list = run_ab_or_ag_masking()
df_1 = pd.DataFrame(row_list)
df_1.to_csv(csv_file, index=False)
df = pd.concat([df_1, df_2]).reset_index()
df.to_csv(os.path.join(args.output_dir, 'ConsolidatedAccuracy_AbAgRegions_All.csv'))