Skip to content

Commit

Permalink
[bugfix] Fix dataloader bug in GAT PPI data example (dmlc#1966)
Browse files Browse the repository at this point in the history
* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c.

* fix dataloader bug

Co-authored-by: xiang song(charlie.song) <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>
  • Loading branch information
3 people authored Aug 7, 2020
1 parent 18bfec2 commit dcf4641
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions examples/pytorch/gat/train_ppi.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,9 @@ def main(args):
if epoch % 5 == 0:
score_list = []
val_loss_list = []
for batch, valid_data in enumerate(valid_dataloader):
subgraph, feats, labels = valid_data
for batch, subgraph in enumerate(valid_dataloader):
subgraph = subgraph.to(device)
feats = feats.to(device)
labels = labels.to(device)
score, val_loss = evaluate(feats.float(), model, subgraph, labels.float(), loss_fcn)
score, val_loss = evaluate(subgraph.ndata['feat'], model, subgraph, subgraph.ndata['label'], loss_fcn)
score_list.append(score)
val_loss_list.append(val_loss)
mean_score = np.array(score_list).mean()
Expand Down

0 comments on commit dcf4641

Please sign in to comment.