Skip to content

Commit

Permalink
Fixed flipped label 3dcnn msp
Browse files Browse the repository at this point in the history
  • Loading branch information
psuriana committed Jun 5, 2021
1 parent 5701486 commit 4b6b011
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 15 deletions.
11 changes: 4 additions & 7 deletions examples/lep/cnn3d/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,10 @@ def major_vote(results):
def compute_stats(df):
results = major_vote(df)
res = {}
all_true = results['true'].astype(np.int8)
all_pred = results['pred'].astype(np.int8)
res['auroc'] = sm.roc_auc_score(all_true, all_pred)
res['auprc'] = sm.average_precision_score(all_true, all_pred)
res['acc'] = sm.accuracy_score(all_true, all_pred.round())
res['bal_acc'] = \
sm.balanced_accuracy_score(all_true, all_pred.round())
all_true = results['true']
all_prob = results['avg_prob']
res['auroc'] = sm.roc_auc_score(all_true, all_prob)
res['auprc'] = sm.average_precision_score(all_true, all_prob)
return res


Expand Down
2 changes: 1 addition & 1 deletion examples/msp/cnn3d/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __call__(self, item):
transformed = {
'feature_original': self._voxelize(item['original_atoms'], mut_chain, mut_res, False),
'feature_mutated': self._voxelize(item['mutated_atoms'], mut_chain, mut_res, True),
'label': int(item['label'] == '0'), # Convert to 0 for original, 1 for mutated
'label': int(item['label']), # Convert to 0 for original, 1 for mutated
'id': item['id']
}
return transformed
Expand Down
11 changes: 4 additions & 7 deletions examples/msp/cnn3d/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,10 @@ def major_vote(results):
def compute_stats(df):
results = major_vote(df)
res = {}
all_true = results['true'].astype(np.int8)
all_pred = results['pred'].astype(np.int8)
res['auroc'] = sm.roc_auc_score(all_true, all_pred)
res['auprc'] = sm.average_precision_score(all_true, all_pred)
res['acc'] = sm.accuracy_score(all_true, all_pred.round())
res['bal_acc'] = \
sm.balanced_accuracy_score(all_true, all_pred.round())
all_true = results['true']
all_prob = results['avg_prob'].astype(np.float)
res['auroc'] = sm.roc_auc_score(all_true, all_prob)
res['auprc'] = sm.average_precision_score(all_true, all_prob)
return res


Expand Down

0 comments on commit 4b6b011

Please sign in to comment.