Skip to content

Commit

Permalink
[Example] fix auc in caregnn example (dmlc#3647)
Browse files Browse the repository at this point in the history
Co-authored-by: zhjwy9343 <[email protected]>
  • Loading branch information
kayzliu and zhjwy9343 authored Jan 21, 2022
1 parent 5747637 commit ed4134e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
7 changes: 4 additions & 3 deletions examples/pytorch/caregnn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch as th
from model import CAREGNN
import torch.optim as optim
from torch.nn.functional import softmax
from sklearn.metrics import recall_score, roc_auc_score

from utils import EarlyStopping
Expand Down Expand Up @@ -70,13 +71,13 @@ def main(args):
args.sim_weight * loss_fn(logits_sim[train_idx], labels[train_idx])

tr_recall = recall_score(labels[train_idx].cpu(), logits_gnn.data[train_idx].argmax(dim=1).cpu())
tr_auc = roc_auc_score(labels[train_idx].cpu(), logits_gnn.data[train_idx][:, 1].cpu())
tr_auc = roc_auc_score(labels[train_idx].cpu(), softmax(logits_gnn, dim=1).data[train_idx][:, 1].cpu())

# validation
val_loss = loss_fn(logits_gnn[val_idx], labels[val_idx]) + \
args.sim_weight * loss_fn(logits_sim[val_idx], labels[val_idx])
val_recall = recall_score(labels[val_idx].cpu(), logits_gnn.data[val_idx].argmax(dim=1).cpu())
val_auc = roc_auc_score(labels[val_idx].cpu(), logits_gnn.data[val_idx][:, 1].cpu())
val_auc = roc_auc_score(labels[val_idx].cpu(), softmax(logits_gnn, dim=1).data[val_idx][:, 1].cpu())

# backward
optimizer.zero_grad()
Expand Down Expand Up @@ -106,7 +107,7 @@ def main(args):
test_loss = loss_fn(logits_gnn[test_idx], labels[test_idx]) + \
args.sim_weight * loss_fn(logits_sim[test_idx], labels[test_idx])
test_recall = recall_score(labels[test_idx].cpu(), logits_gnn[test_idx].argmax(dim=1).cpu())
test_auc = roc_auc_score(labels[test_idx].cpu(), logits_gnn.data[test_idx][:, 1].cpu())
test_auc = roc_auc_score(labels[test_idx].cpu(), softmax(logits_gnn, dim=1).data[test_idx][:, 1].cpu())

print("Test Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}".format(test_recall, test_auc, test_loss.item()))

Expand Down
5 changes: 3 additions & 2 deletions examples/pytorch/caregnn/main_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
import torch as th
import torch.optim as optim
from torch.nn.functional import softmax
from sklearn.metrics import roc_auc_score, recall_score

from utils import EarlyStopping
Expand All @@ -22,7 +23,7 @@ def evaluate(model, loss_fn, dataloader, device='cpu'):
# compute loss
loss += loss_fn(logits_gnn, label).item() + args.sim_weight * loss_fn(logits_sim, label).item()
recall += recall_score(label.cpu(), logits_gnn.argmax(dim=1).detach().cpu())
auc += roc_auc_score(label.cpu(), logits_gnn[:, 1].detach().cpu())
auc += roc_auc_score(label.cpu(), softmax(logits_gnn, dim=1)[:, 1].detach().cpu())
num_blocks += 1

return recall / num_blocks, auc / num_blocks, loss / num_blocks
Expand Down Expand Up @@ -121,7 +122,7 @@ def main(args):
blk_loss = loss_fn(logits_gnn, train_label) + args.sim_weight * loss_fn(logits_sim, train_label)
tr_loss += blk_loss.item()
tr_recall += recall_score(train_label.cpu(), logits_gnn.argmax(dim=1).detach().cpu())
tr_auc += roc_auc_score(train_label.cpu(), logits_gnn[:, 1].detach().cpu())
tr_auc += roc_auc_score(train_label.cpu(), softmax(logits_gnn, dim=1)[:, 1].detach().cpu())
tr_blk += 1

# backward
Expand Down

0 comments on commit ed4134e

Please sign in to comment.