From 0b02cee54e806c17b50fb731de65e2d005892e7c Mon Sep 17 00:00:00 2001 From: Tjark Miener Date: Fri, 5 Jul 2019 16:27:47 +0200 Subject: [PATCH] added a script to calculate the auc and acc of the predictions (using sklearn) --- scripts/get_prediction_acc_auc.py | 49 +++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 scripts/get_prediction_acc_auc.py diff --git a/scripts/get_prediction_acc_auc.py b/scripts/get_prediction_acc_auc.py new file mode 100644 index 00000000..9d06a02c --- /dev/null +++ b/scripts/get_prediction_acc_auc.py @@ -0,0 +1,49 @@ +import argparse +import numpy as np +import sklearn.metrics + +parser = argparse.ArgumentParser( + description=("Get acc and auc from the predictions csv file.")) +parser.add_argument('predictions_list_file', + help='list of paths to predictions csv file') +args = parser.parse_args() + +# Predictions list has the format: predicted_class,proton,gamma,tel_id,event_number,run_number,class_label +labels = [] +predictions = [] +gamma_predictions = [] +with open(args.predictions_list_file) as f: + for line in f: + if not line or line[0] == '#': continue + predicted_class,proton,gamma,tel_id,event_number,run_number,class_label = line.split(',') + labels.append(class_label.strip()) + gamma_predictions.append(gamma.strip()) + predictions.append(predicted_class.strip()) + labels = np.array(labels[1:]).astype(np.int) + gamma_predictions = np.array(gamma_predictions[1:]).astype(np.float) + predictions = np.array(predictions[1:]).astype(np.int) + +fpr, tpr, thresholds = sklearn.metrics.roc_curve(labels,gamma_predictions, pos_label=1) +auc = sklearn.metrics.auc(fpr, tpr) +print("auc = {}".format(auc)) + +acc = sklearn.metrics.accuracy_score(labels,predictions) +print("acc = {}%".format(acc*100)) + + + +''' +labels = tf.convert_to_tensor(labels, dtype=tf.float32) +predictions = tf.convert_to_tensor(predictions, dtype=tf.float32) + +#acc, update_op = tf.metrics.accuracy(labels,predictions) + +auc, update_op = tf.metrics.auc(labels,predictions) +print(auc) + +sess = tf.Session() + +result = sess.run(auc) +print(result) + +'''