-
Notifications
You must be signed in to change notification settings - Fork 0
/
precision_recall_average.py
executable file
·92 lines (75 loc) · 3.83 KB
/
precision_recall_average.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python
import sys
import argparse
import numpy as np
import math
from utils import filter_tail
from utils import load_data
from utils import argparse_parents
from utils import labels
def compute_precision_and_recall(data, filter_tail_percentage):
    """Average precision and recall over bins, with standard deviation and SEM.

    Recall is averaged over the bins whose ``real_size`` is greater than 0.
    Precision is averaged over the bins whose ``precision`` is not NaN.
    Standard deviations are population standard deviations (sum of squared
    differences divided by the number of values, not n - 1). The standard
    error of the mean is that deviation divided by sqrt(n).

    :param data: iterable of per-bin mappings with keys ``'real_size'``,
        ``'recall'`` and ``'precision'``. Values are converted with
        ``float()``, so numeric strings are accepted as well as numbers.
    :param filter_tail_percentage: if truthy, ``data`` is first passed through
        ``filter_tail.filter_tail(data, filter_tail_percentage)``.
    :return: tuple ``(avg_precision, avg_recall, std_deviation_precision,
        std_deviation_recall, std_error_precision, std_error_recall)``.
        When no bin contributes to a metric, that metric's three values are
        NaN (instead of raising ZeroDivisionError).
    """
    if filter_tail_percentage:
        data = filter_tail.filter_tail(data, filter_tail_percentage)

    recalls = []
    precisions = []
    # Single pass over ``data``, so a one-shot iterator is handled correctly.
    for row in data:
        # Only bins mapped to a genome of positive size contribute to recall.
        if float(row['real_size']) > 0:
            recalls.append(float(row['recall']))
        # NaN precision marks a bin without a defined precision; skip it.
        precision = float(row['precision'])
        if not math.isnan(precision):
            precisions.append(precision)

    avg_precision, std_deviation_precision, std_error_precision = _mean_std_sem(precisions)
    avg_recall, std_deviation_recall, std_error_recall = _mean_std_sem(recalls)
    return (avg_precision, avg_recall, std_deviation_precision,
            std_deviation_recall, std_error_precision, std_error_recall)


def _mean_std_sem(values):
    """Return (mean, population std deviation, standard error of the mean) of ``values``.

    All three are NaN when ``values`` is empty.
    """
    count = len(values)
    if count == 0:
        nan = float('nan')
        return nan, nan, nan
    mean = math.fsum(values) / count
    std_deviation = math.sqrt(math.fsum((value - mean) ** 2 for value in values) / count)
    return mean, std_deviation, std_deviation / math.sqrt(count)
def print_precision_recall_table_header(stream=sys.stdout):
    """Write the tab-separated column header line of the summary table.

    :param stream: writable text stream; defaults to ``sys.stdout`` as bound
        when this function was defined.
    """
    columns = (
        labels.TOOL,
        labels.AVG_PRECISION,
        labels.STD_DEV_PRECISION,
        labels.SEM_PRECISION,
        labels.AVG_RECALL,
        labels.STD_DEV_RECALL,
        labels.SEM_RECALL,
    )
    stream.write("\t".join(columns) + "\n")
def print_precision_recall(label, avg_precision, avg_recall, std_deviation_precision, std_deviation_recall, std_error_precision, std_error_recall,
                           stream=sys.stdout):
    """Write one tab-separated row of summary statistics.

    Column order is: label; then average, standard deviation and standard
    error of the mean for precision; then the same three for recall.
    Each number is formatted with three decimals. A falsy ``label`` is
    written as an empty string.

    :param stream: writable text stream; defaults to ``sys.stdout`` as bound
        when this function was defined.
    """
    ordered_values = (avg_precision, std_deviation_precision, std_error_precision,
                      avg_recall, std_deviation_recall, std_error_recall)
    cells = [label or ""]
    cells.extend(format(value, '.3f') for value in ordered_values)
    stream.write("\t".join(cells) + "\n")
def main():
    """Command-line entry point: summarize a per-bin precision/recall table.

    The table is read from the file given on the command line. If no file is
    given, it is read from standard input. If there is no file and stdin is
    an interactive terminal, the help text is printed and the program exits
    with status 1. The header row and one summary row are written to stdout.
    """
    parser = argparse.ArgumentParser(description="Compute precision and recall, including standard deviation and standard error of the mean, from table of precision and recall per genome. The table can be provided as file or via the standard input",
                                     parents=[argparse_parents.PARSER_MULTI])
    args = parser.parse_args()
    if not args.file and sys.stdin.isatty():
        parser.print_help()
        parser.exit(1)
    # An explicitly given file takes precedence. Otherwise fall back to stdin,
    # which is known to be non-interactive here (we exited above otherwise).
    # Choosing stdin merely because it is not a TTY would silently ignore the
    # file under pipes, cron or CI.
    metrics = load_data.load_tsv_table(args.file if args.file else sys.stdin)
    avg_precision, avg_recall, std_deviation_precision, std_deviation_recall, std_error_precision, std_error_recall = \
        compute_precision_and_recall(metrics, args.filter)
    print_precision_recall_table_header()
    print_precision_recall(args.label, avg_precision, avg_recall, std_deviation_precision, std_deviation_recall,
                           std_error_precision, std_error_recall)


if __name__ == "__main__":
    main()