diff --git a/examples/pytorch/caregnn/main.py b/examples/pytorch/caregnn/main.py
index 7fb3fee5ef0f..1b202d81dd73 100644
--- a/examples/pytorch/caregnn/main.py
+++ b/examples/pytorch/caregnn/main.py
@@ -3,6 +3,7 @@
 import torch as th
 from model import CAREGNN
 import torch.optim as optim
+from torch.nn.functional import softmax
 from sklearn.metrics import recall_score, roc_auc_score
 from utils import EarlyStopping
 
@@ -70,13 +71,13 @@ def main(args):
         tr_loss = loss_fn(logits_gnn[train_idx], labels[train_idx]) + \
             args.sim_weight * loss_fn(logits_sim[train_idx], labels[train_idx])
         tr_recall = recall_score(labels[train_idx].cpu(), logits_gnn.data[train_idx].argmax(dim=1).cpu())
-        tr_auc = roc_auc_score(labels[train_idx].cpu(), logits_gnn.data[train_idx][:, 1].cpu())
+        tr_auc = roc_auc_score(labels[train_idx].cpu(), softmax(logits_gnn, dim=1).data[train_idx][:, 1].cpu())
 
         # validation
         val_loss = loss_fn(logits_gnn[val_idx], labels[val_idx]) + \
             args.sim_weight * loss_fn(logits_sim[val_idx], labels[val_idx])
         val_recall = recall_score(labels[val_idx].cpu(), logits_gnn.data[val_idx].argmax(dim=1).cpu())
-        val_auc = roc_auc_score(labels[val_idx].cpu(), logits_gnn.data[val_idx][:, 1].cpu())
+        val_auc = roc_auc_score(labels[val_idx].cpu(), softmax(logits_gnn, dim=1).data[val_idx][:, 1].cpu())
 
         # backward
         optimizer.zero_grad()
@@ -106,7 +107,7 @@ def main(args):
     test_loss = loss_fn(logits_gnn[test_idx], labels[test_idx]) + \
         args.sim_weight * loss_fn(logits_sim[test_idx], labels[test_idx])
     test_recall = recall_score(labels[test_idx].cpu(), logits_gnn[test_idx].argmax(dim=1).cpu())
-    test_auc = roc_auc_score(labels[test_idx].cpu(), logits_gnn.data[test_idx][:, 1].cpu())
+    test_auc = roc_auc_score(labels[test_idx].cpu(), softmax(logits_gnn, dim=1).data[test_idx][:, 1].cpu())
     print("Test Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}".format(test_recall, test_auc, test_loss.item()))
diff --git a/examples/pytorch/caregnn/main_sampling.py b/examples/pytorch/caregnn/main_sampling.py
index 42e95e61dda7..54cee3937249 100644
--- a/examples/pytorch/caregnn/main_sampling.py
+++ b/examples/pytorch/caregnn/main_sampling.py
@@ -2,6 +2,7 @@
 import argparse
 import torch as th
 import torch.optim as optim
+from torch.nn.functional import softmax
 from sklearn.metrics import roc_auc_score, recall_score
 from utils import EarlyStopping
 
@@ -22,7 +23,7 @@ def evaluate(model, loss_fn, dataloader, device='cpu'):
             # compute loss
             loss += loss_fn(logits_gnn, label).item() + args.sim_weight * loss_fn(logits_sim, label).item()
             recall += recall_score(label.cpu(), logits_gnn.argmax(dim=1).detach().cpu())
-            auc += roc_auc_score(label.cpu(), logits_gnn[:, 1].detach().cpu())
+            auc += roc_auc_score(label.cpu(), softmax(logits_gnn, dim=1)[:, 1].detach().cpu())
             num_blocks += 1
 
     return recall / num_blocks, auc / num_blocks, loss / num_blocks
@@ -121,7 +122,7 @@ def main(args):
         blk_loss = loss_fn(logits_gnn, train_label) + args.sim_weight * loss_fn(logits_sim, train_label)
         tr_loss += blk_loss.item()
         tr_recall += recall_score(train_label.cpu(), logits_gnn.argmax(dim=1).detach().cpu())
-        tr_auc += roc_auc_score(train_label.cpu(), logits_gnn[:, 1].detach().cpu())
+        tr_auc += roc_auc_score(train_label.cpu(), softmax(logits_gnn, dim=1)[:, 1].detach().cpu())
         tr_blk += 1
 
         # backward
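
Note on the fix (illustrative, not part of the patch): roc_auc_score only needs
scores that rank positive samples above negative ones, but with a two-logit
classifier the raw class-1 logit and the softmax probability can order samples
differently, because softmax normalizes each row by both logits. A minimal
standalone sketch with made-up logits showing the two scores disagree:

    import torch as th
    from torch.nn.functional import softmax
    from sklearn.metrics import roc_auc_score

    # Toy logits: row 0 has the larger raw class-1 logit (2.0 > 1.0),
    # but row 1 has the larger class-1 probability after softmax.
    logits = th.tensor([[5.0, 2.0],
                        [0.0, 1.0]])
    labels = th.tensor([0, 1])

    raw = logits[:, 1]                   # tensor([2., 1.])
    prob = softmax(logits, dim=1)[:, 1]  # tensor([0.0474, 0.7311])

    print(roc_auc_score(labels.numpy(), raw.numpy()))   # 0.0
    print(roc_auc_score(labels.numpy(), prob.numpy()))  # 1.0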