diff --git a/examples/lep/cnn3d/train.py b/examples/lep/cnn3d/train.py index 2d1e583..df204ed 100644 --- a/examples/lep/cnn3d/train.py +++ b/examples/lep/cnn3d/train.py @@ -37,13 +37,10 @@ def major_vote(results): def compute_stats(df): results = major_vote(df) res = {} - all_true = results['true'].astype(np.int8) - all_pred = results['pred'].astype(np.int8) - res['auroc'] = sm.roc_auc_score(all_true, all_pred) - res['auprc'] = sm.average_precision_score(all_true, all_pred) - res['acc'] = sm.accuracy_score(all_true, all_pred.round()) - res['bal_acc'] = \ - sm.balanced_accuracy_score(all_true, all_pred.round()) + all_true = results['true'] + all_prob = results['avg_prob'] + res['auroc'] = sm.roc_auc_score(all_true, all_prob) + res['auprc'] = sm.average_precision_score(all_true, all_prob) return res diff --git a/examples/msp/cnn3d/data.py b/examples/msp/cnn3d/data.py index a652839..cadce7b 100644 --- a/examples/msp/cnn3d/data.py +++ b/examples/msp/cnn3d/data.py @@ -74,7 +74,7 @@ def __call__(self, item): transformed = { 'feature_original': self._voxelize(item['original_atoms'], mut_chain, mut_res, False), 'feature_mutated': self._voxelize(item['mutated_atoms'], mut_chain, mut_res, True), - 'label': int(item['label'] == '0'), # Convert to 0 for original, 1 for mutated + 'label': int(item['label']), # Convert to 0 for original, 1 for mutated 'id': item['id'] } return transformed diff --git a/examples/msp/cnn3d/train.py b/examples/msp/cnn3d/train.py index db97721..835a217 100644 --- a/examples/msp/cnn3d/train.py +++ b/examples/msp/cnn3d/train.py @@ -37,13 +37,10 @@ def major_vote(results): def compute_stats(df): results = major_vote(df) res = {} - all_true = results['true'].astype(np.int8) - all_pred = results['pred'].astype(np.int8) - res['auroc'] = sm.roc_auc_score(all_true, all_pred) - res['auprc'] = sm.average_precision_score(all_true, all_pred) - res['acc'] = sm.accuracy_score(all_true, all_pred.round()) - res['bal_acc'] = \ - sm.balanced_accuracy_score(all_true, all_pred.round()) + all_true = results['true'] + all_prob = results['avg_prob'].astype(np.float) + res['auroc'] = sm.roc_auc_score(all_true, all_prob) + res['auprc'] = sm.average_precision_score(all_true, all_prob) return res