-
Notifications
You must be signed in to change notification settings - Fork 0
/
benchmark_module.py
202 lines (169 loc) · 5.85 KB
/
benchmark_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
from operator import gt
from pathlib import Path
import networkx as nx
import cassiopeia as cas
import pandas as pd
import pickle as pic
from tqdm import tqdm, trange
import numpy as np
import inverse_whd
import pickle
import numpy as np
import matplotlib.pyplot as plt
import numba
import os
from cassiopeia.solver.CassiopeiaSolver import CassiopeiaSolver
class BenchmarkModule:
    """Reconstruct trees with a Cassiopeia solver and score them.

    For every replicate index ``i`` in ``range(numtrees)`` the module reads a
    character matrix ``cm{i}.txt`` and a pickled ground-truth tree
    ``tree{i}.pkl`` from ``gt_trees_dir``, writes the reconstructed newick
    string to ``<out_basefolder>/<test_name>/recon{i}``, and finally writes
    the triplets-correct and Robinson-Foulds tables to
    ``<out_basefolder>/<test_name>.triplets.csv`` and
    ``<out_basefolder>/<test_name>.rf.csv``.
    """

    # Metadata written verbatim at the start of every result row, in the order
    # NumberOfCells, Priors, Fitness, Stressor, Parameter, Algorithm.
    # NOTE(review): these are hard-coded (including the "SNJ" algorithm label)
    # and do not depend on the solver actually passed in -- presumably they
    # describe one specific simulated dataset; confirm before reusing.
    _ROW_METADATA = (400, "no_priors", "no_fit", "char", 40, "SNJ")

    def __init__(
        self,
        test_name: str,
        solver: "CassiopeiaSolver | None" = None,
        gt_trees_dir="/data/yosef2/users/richardz/projects/CassiopeiaV2-Reproducibility/priors/states100.pkl",
        numtrees=50,
        out_basefolder="./benchmarking/",
    ):
        """Store the benchmark configuration.

        Args:
            test_name: Name of this run; used as the reconstruction
                sub-folder name and as the prefix of the result CSV files.
            solver: Solver used for reconstruction. When ``None`` a
                ``NeighborJoiningSolver`` with ``add_root=True`` and the
                weighted Hamming dissimilarity is created for this instance.
            gt_trees_dir: Directory holding ``tree{i}.pkl`` and ``cm{i}.txt``.
                NOTE(review): the default ends in ``.pkl`` although it is
                joined with file names as a directory -- confirm it is
                intended.
            numtrees: Number of replicates; indices ``0 .. numtrees - 1``.
            out_basefolder: Root folder for all outputs.
        """
        if solver is None:
            # Built here rather than in the signature so that the solver is
            # not created at import time and shared between all instances.
            solver = cas.solver.NeighborJoiningSolver(
                add_root=True,
                dissimilarity_function=cas.solver.dissimilarity_functions.weighted_hamming_distance,
            )  # type: ignore
        self.test_name = test_name
        self.solver = solver
        self.gt_trees_dir = gt_trees_dir
        self.numtrees = numtrees
        self.out_basefolder = out_basefolder

    def get_gt_tree(self, i):
        """Load and return the unpickled ground-truth tree ``tree{i}.pkl``."""
        gt_tree_file = os.path.join(self.gt_trees_dir, f"tree{i}.pkl")
        # NOTE: unpickling can execute arbitrary code; only point
        # ``gt_trees_dir`` at trusted data.
        with open(gt_tree_file, "rb") as f:
            return pic.load(f)

    def get_recon_tree(self, i):
        """Return a ``CassiopeiaTree`` built from the saved ``recon{i}`` file."""
        recon_file = os.path.join(self.out_basefolder, self.test_name, f"recon{i}")
        # NOTE(review): the file *path* (not its contents) is passed as
        # ``tree``; this presumably relies on the underlying newick parser
        # accepting a path -- confirm.
        return cas.data.CassiopeiaTree(tree=recon_file)

    def run_solver(self, i, cm, collapse_mutationless_edges) -> str:
        """Solve a tree for character matrix ``cm`` and return its newick.

        Args:
            i: Replicate index. Unused here; kept so callers and overriding
                implementations keep the same signature.
            cm: Character matrix passed to ``CassiopeiaTree``.
            collapse_mutationless_edges: Forwarded to ``solver.solve``.

        Returns:
            The newick string of the solved tree.
        """
        recon_tree = cas.data.CassiopeiaTree(
            character_matrix=cm,
            missing_state_indicator=-1,
        )
        self.solver.solve(
            recon_tree,
            collapse_mutationless_edges=collapse_mutationless_edges,
        )
        return recon_tree.get_newick()

    def get_cm(self, i):
        """Read ``cm{i}.txt`` (first column as index) with ``-2`` mapped to ``-1``."""
        cm_file = os.path.join(self.gt_trees_dir, f"cm{i}.txt")
        cm = pd.read_table(cm_file, index_col=0)
        # -1 is the missing-state indicator used in ``run_solver``;
        # presumably -2 is a second missing-data code in the inputs -- confirm.
        return cm.replace(-2, -1)  # type: ignore

    def reconstruct(
        self,
        overwrite=False,
        collapse_mutationless_edges=True,
    ):
        """Reconstruct every replicate and save each newick to ``recon{i}``.

        Args:
            overwrite: When false, replicates whose output file already
                exists are skipped.
            collapse_mutationless_edges: Forwarded to ``run_solver``.
        """
        pbar = trange(self.numtrees)
        for i in pbar:
            pbar.set_description(f"Reconstructing tree {i}")

            recon_outfile = Path(
                os.path.join(self.out_basefolder, self.test_name, f"recon{i}")
            )
            recon_outfile.parent.mkdir(parents=True, exist_ok=True)
            if not overwrite and recon_outfile.exists():
                pbar.set_description(f"Skipping reconstruction {i}")
                continue

            cm = self.get_cm(i)
            recon_newick = self.run_solver(i, cm, collapse_mutationless_edges)

            # The file is only created after solving succeeded, so a failed
            # solve does not leave an empty file that a later run would skip.
            with open(recon_outfile, "w") as f:
                f.write(recon_newick)

    def evaluate(self, overwrite=False):
        """Score all reconstructions and write the two result CSV files.

        Args:
            overwrite: When false and both CSV files already exist, nothing
                is computed or written.
        """
        rf_out = Path(os.path.join(self.out_basefolder, f"{self.test_name}.rf.csv"))
        triplets_out = Path(
            os.path.join(self.out_basefolder, f"{self.test_name}.triplets.csv")
        )
        if not overwrite and rf_out.exists() and triplets_out.exists():
            return

        metadata_columns = [
            "NumberOfCells",
            "Priors",
            "Fitness",
            "Stressor",
            "Parameter",
            "Algorithm",
            "Replicate",
        ]
        triplets_columns = metadata_columns + ["Depth", "TripletsCorrect"]
        rf_columns = metadata_columns + [
            "UnNormalizedRobinsonFoulds",
            "MaxRobinsonFoulds",
            "NormalizedRobinsonFoulds",
        ]

        # Rows are collected in plain lists and turned into DataFrames once at
        # the end: ``DataFrame.append`` no longer exists in pandas >= 2.0 and
        # copied the whole frame on every call.
        triplets_rows = []
        rf_rows = []

        pbar = trange(self.numtrees)
        for i in pbar:
            pbar.set_description(f"Evaluating tree {i}")

            gt_tree = self.get_gt_tree(i)
            recon_tree = self.get_recon_tree(i)
            row_prefix = [*self._ROW_METADATA, i]

            # Only the first element of the returned tuple is used; it is
            # iterated as a mapping from depth to score.
            triplet_correct = cas.critique.triplets_correct(
                gt_tree,
                recon_tree,
                number_of_trials=1000,
                min_triplets_at_depth=50,
            )[0]
            for depth, score in triplet_correct.items():
                triplets_rows.append([*row_prefix, depth, score])

            rf, rf_max = cas.critique.robinson_foulds(gt_tree, recon_tree)
            rf_rows.append([*row_prefix, rf, rf_max, rf / rf_max])

        triplets_df = pd.DataFrame(triplets_rows, columns=triplets_columns)
        RF_df = pd.DataFrame(rf_rows, columns=rf_columns)

        triplets_df.to_csv(triplets_out)
        RF_df.to_csv(rf_out)
        return