-
Notifications
You must be signed in to change notification settings - Fork 0
/
results.py
executable file
·150 lines (137 loc) · 6.02 KB
/
results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python
import sys
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn import metrics
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import plot_precision_recall_curve
def get_predicts(demux):
    """
    Parse demuxlet's .best file into a DataFrame indexed by barcode.

    ``demux`` is anything ``pd.read_csv`` accepts. It is read as
    tab-separated and must provide the BARCODE, BEST, LLK12 and SNG.LLK1
    columns. LLK12 is requested (so it must exist) but is not used below.

    Returns a DataFrame with columns:
      type   -- the text of BEST before its first '-' (e.g. 'SNG', 'DBL')
      sample -- tuple of the '-'-separated fields after that prefix; for
                'DBL' rows only the first two fields are kept, dropping the
                trailing probability suffix
      prob   -- the SNG.LLK1 value, copied unchanged
    """
    table = pd.read_csv(
        demux,
        sep="\t",
        usecols=['BARCODE', 'BEST', 'LLK12', 'SNG.LLK1'],
        index_col='BARCODE',
    )
    # split "TYPE-rest" once to separate the droplet type from its samples
    parsed = table['BEST'].str.split('-', n=1, expand=True)
    parsed.columns = ['type', 'sample']
    # NOTE(review): despite the column name this is SNG.LLK1 (presumably a
    # log-likelihood), not a probability -- confirm intended score column
    parsed['prob'] = table['SNG.LLK1']
    fields = parsed['sample'].str.split('-')
    # tuples rather than lists: they are immutable and hashable
    samples = [
        tuple(parts[:2]) if kind == 'DBL' else tuple(parts)
        for kind, parts in zip(parsed['type'], fields)
    ]
    parsed['sample'] = pd.Series(samples, index=parsed.index, dtype=object)
    return parsed
def get_truth(truth):
    """
    Parse the true barcode assignments into a DataFrame shaped like the
    predictions (columns 'type' and 'sample', indexed by barcode).

    ``truth`` is anything ``pd.read_csv`` accepts: a headerless,
    tab-separated table whose three columns are read as sample name, an
    unused 'old' column, and barcode. Rows sharing a barcode are grouped.

    Returns a DataFrame indexed by barcode (sorted, groupby's default) with:
      type   -- 'SNG' when exactly one sample is listed, otherwise 'DBL'
      sample -- tuple of the sample names converted to str, in file order
    """
    rows = pd.read_csv(truth, sep="\t", header=None, names=['samp', 'old', 'BARCODE'])
    # select the column before apply: DataFrameGroupBy.apply over the grouping
    # column is deprecated (pandas 2.2), and only 'samp' is needed anyway
    barcodes = rows.groupby('BARCODE')['samp'].apply(list).to_frame('sample')
    # one sample -> singlet, two or more -> doublet; indexing ['SNG', 'DBL']
    # by len - 1 raised IndexError for barcodes listed with 3+ samples
    kinds = barcodes['sample'].map(lambda samps: 'SNG' if len(samps) == 1 else 'DBL')
    barcodes.insert(0, 'type', kinds, True)
    # hashable tuples of str, matching the string names parsed from demuxlet
    barcodes['sample'] = barcodes['sample'].apply(lambda samps: tuple(str(s) for s in samps))
    return barcodes
def type_metrics(predicts, truth):
    """
    Were droplets correctly classified by their type?

    Computes per-label precision, recall, F-score and support for the 'DBL'
    and 'SNG' labels, using ``truth['type']`` as ground truth. Predictions
    are looked up by the barcodes in ``truth`` (KeyError if one is missing).
    Predicted types outside DBL/SNG (e.g. 'AMB') lower recall but are not
    counted as false positives for either label.

    Returns a DataFrame with rows precision/recall/fscore/support and one
    column per label.
    """
    labels = ['DBL', 'SNG']
    # sklearn compares its inputs position by position and ignores the pandas
    # index, so put the predictions into truth's barcode order (dropping
    # predicted barcodes that have no truth row) before scoring
    predicted = predicts.loc[truth.index, 'type']
    scores = metrics.precision_recall_fscore_support(
        truth['type'],
        predicted,
        labels=labels,
    )
    return pd.DataFrame(
        scores,
        columns=labels,
        index=['precision', 'recall', 'fscore', 'support'],
    )
def hamming_exactmatch(predicts, truth, samples, exact=False):
    """
    Score predicted sample sets against the true ones for a set of droplets.

    ``predicts`` and ``truth`` each need a 'sample' column of tuples. They
    are compared row by row in their current order, so callers must align
    them first. ``samples`` is the ordered list of all sample labels and
    defines the one-hot columns.

    Returns sklearn's subset (exact-match) accuracy when ``exact`` is true,
    otherwise its Hamming loss.
    """
    encoder = MultiLabelBinarizer(classes=samples)
    # with explicit classes, fitting just stores them, so one encoder maps
    # both tables onto the same columns
    truth_onehot = encoder.fit_transform(truth['sample'])
    predicted_onehot = encoder.transform(predicts['sample'])
    score = metrics.accuracy_score if exact else metrics.hamming_loss
    return score(truth_onehot, predicted_onehot)
def hammings_accuracy(predicts, truth, exact=False):
    """
    Score sample assignments for true doublets, true singlets, and both.

    AMB (ambiguous) predictions are given the placeholder sample ('NA',), so
    they cannot match a real sample (unless one is literally named 'NA').
    Prediction rows are looked up by the barcodes in ``truth`` (KeyError if a
    true barcode is missing from the predictions). ``predicts`` is not
    modified.

    Returns a dict {'DBL': ..., 'SNG': ..., 'BOTH': ...} of Hamming losses,
    or of subset (exact-match) accuracies when ``exact`` is true.
    """
    preds = predicts.copy()
    # Rebuild the whole column instead of assigning a list of 1-tuples through
    # a masked .loc: pandas can treat such a nested list as a 2-D array and
    # store the string 'NA' (or raise) instead of the tuple ('NA',).
    is_amb = preds['type'] == 'AMB'
    preds['sample'] = pd.Series(
        [('NA',) if amb else samp for amb, samp in zip(is_amb, preds['sample'])],
        index=preds.index,
        dtype=object,
    )
    # one shared, sorted label set so every subset is encoded on the same columns
    samples = set(preds['sample'].explode().unique())
    samples |= set(truth['sample'].explode().unique())
    samples = sorted(samples)
    subsets = {'DBL': ['DBL'], 'SNG': ['SNG'], 'BOTH': ['DBL', 'SNG']}
    scores = {}
    for label, types in subsets.items():
        trth = truth[truth['type'].isin(types)]
        # scoring is positional, so align the predictions to these truth rows
        prds = preds.loc[trth.index]
        scores[label] = hamming_exactmatch(prds, trth, samples, exact)
    return scores
def cohen_kappa(predicts, truth):
    """
    Cohen's kappa between predicted and true sample for shared singlets.

    The two tables need not cover the same droplets, so only barcodes present
    in both and typed 'SNG' in both are compared. Each sample tuple is joined
    with '-' into a single string label, because sklearn's label checks expect
    scalar labels rather than tuples.

    Returns NaN when no barcode is a singlet in both tables.
    """
    # only barcodes present in both tables can be compared
    shared = truth.index.intersection(predicts.index)
    both_singlet = (
        (predicts.loc[shared, 'type'] == 'SNG').to_numpy()
        & (truth.loc[shared, 'type'] == 'SNG').to_numpy()
    )
    singlets = shared[both_singlet]
    if len(singlets) == 0:
        return float('nan')
    # the original version referenced `preds` before assigning it
    preds = predicts.loc[singlets, 'sample'].map('-'.join)
    trth = truth.loc[singlets, 'sample'].map('-'.join)
    return metrics.cohen_kappa_score(preds, trth)
def prc_curve(predicts, truth):
    """
    Plot a precision/recall curve for calling droplets singlets.

    ``predicts`` is a 1-D per-barcode score (main passes the 'prob' column,
    i.e. SNG.LLK1). Higher values are presumably more singlet-like -- confirm.
    ``truth`` supplies the labels: a droplet is positive when its true type is
    'SNG'. Scores are looked up by truth's barcodes (KeyError if one is
    missing).

    Returns a new matplotlib Figure whose first axes hold the curve.
    """
    # sklearn pairs labels and scores by position, not by pandas index, so
    # reorder the scores to match the truth rows first
    scores = predicts.loc[truth.index]
    precision, recall, _ = metrics.precision_recall_curve(
        truth['type'] == 'SNG',
        scores,
    )
    # own figure, so nothing is drawn over whatever pyplot figure is current
    fig, ax = plt.subplots()
    ax.plot(recall, precision)
    ax.set_xlim((0.0, 1.0))
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    return fig
def main(demux, truth, curve=False):
    """
    Score demuxlet's calls in ``demux`` against the simulated truth in ``truth``.

    Returns (type_scores, ham_score, accuracy_score): per-type precision and
    recall, Hamming losses, and exact-match accuracies. When ``curve`` is
    truthy, a fourth element is appended: the precision/recall Figure, with
    demuxlet's own SNG (recall, precision) point drawn as a blue dot.
    """
    predicts = get_predicts(demux)
    # set the per-barcode score aside; the scoring helpers use type/sample only
    probs = predicts.pop('prob')
    truth = get_truth(truth)
    type_scores = type_metrics(predicts, truth)
    ham_score = hammings_accuracy(predicts, truth)
    accuracy_score = hammings_accuracy(predicts, truth, exact=True)
    if not curve:
        return type_scores, ham_score, accuracy_score
    figure = prc_curve(probs, truth)
    sng_precision = type_scores.loc['precision', 'SNG']
    sng_recall = type_scores.loc['recall', 'SNG']
    figure.axes[0].plot(sng_recall, sng_precision, "ob")
    return type_scores, ham_score, accuracy_score, figure
def write_out(out, curve, type_scores, ham_score, accuracy_score, prc_curve=None):
    """
    Print the three score sections to the text stream ``out``.

    When ``curve`` (an output path) is truthy, the figure passed as
    ``prc_curve`` is also saved there with a tight bounding box.
    """
    sections = (
        ("precision/recall:", type_scores),
        ("\nhamming loss:", ham_score),
        ("\nsubset accuracy:", accuracy_score),
    )
    for heading, value in sections:
        print(heading, file=out)
        print(value, file=out)
    if curve:
        prc_curve.savefig(curve, bbox_inches='tight')
if __name__ == "__main__":
    # command-line entry point: score a .best file against the true labels
    # and print the report to stdout
    cli = argparse.ArgumentParser(description="Summarize results from a demultiplexing simulation.")
    cli.add_argument("demux", type=Path, help="demuxlet's .best file")
    cli.add_argument("truth", type=Path, help="the true labels")
    cli.add_argument(
        "curve",
        type=Path,
        nargs='?',
        default=None,
        help="the path to a file to which to write a precision/recall curve for the singlets if desired (default: don't do it)",
    )
    opts = cli.parse_args()
    # Path objects are always truthy, so this is True exactly when a path was given
    outputs = main(opts.demux, opts.truth, opts.curve is not None)
    write_out(sys.stdout, opts.curve, *outputs)