Skip to content

Commit

Permalink
[Bugfix] Improve CompGCN (dmlc#3663)
Browse files Browse the repository at this point in the history
Co-authored-by: Mufei Li <[email protected]>
  • Loading branch information
nxznm and mufeili authored Jan 23, 2022
1 parent d3930ba commit 9a6b81e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
16 changes: 8 additions & 8 deletions examples/pytorch/compGCN/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ This example was implemented by [zhjwy9343](https://github.com/zhjwy9343) and [K

Dependencies
----------------------
- pytorch 1.7.1
- dgl 0.6.0
- numpy 1.19.4
- pytorch 1.9.0
- dgl 0.7.1
- numpy 1.20.3
- ordered_set 4.0.2

Dataset
Expand Down Expand Up @@ -67,11 +67,11 @@ Performance
| Dataset | FB15k-237 | WN18RR |
|---------| ------------------------ | ------------------------ |
| Metric | Paper / ours (dgl) | Paper / ours (dgl) |
| MRR | 0.355 / 0.349 | 0.479 / 0.471 |
| MR | 197 / 208 | 3533 / 3550 |
| Hit@10 | 0.535 / 0.526 | 0.546 / 0.532 |
| Hit@3 | 0.390 / 0.381 | 0.494 / 0.480 |
| Hit@1 | 0.264 / 0.260 | 0.443 / 0.438 |
| MRR | 0.355 / 0.348 | 0.479 / 0.466 |
| MR | 197 / 208 | 3533 / 3542 |
| Hit@10 | 0.535 / 0.527 | 0.546 / 0.525 |
| Hit@3 | 0.390 / 0.380 | 0.494 / 0.476 |
| Hit@1 | 0.264 / 0.259 | 0.443 / 0.435 |



Expand Down
6 changes: 3 additions & 3 deletions examples/pytorch/compGCN/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def main(args):

# compute loss
tr_loss = loss_fn(logits, label)
train_loss.append(tr_loss)
train_loss.append(tr_loss.item())

# backward
optimizer.zero_grad()
Expand All @@ -142,7 +142,7 @@ def main(args):
print("saving model...")
else:
kill_cnt += 1
if kill_cnt > 25:
if kill_cnt > 100:
print('early stop.')
break
print("In epoch {}, Train Loss: {:.4f}, Valid MRR: {:.5}\n, Train time: {}, Valid time: {}"\
Expand All @@ -164,7 +164,7 @@ def main(args):
parser.add_argument('--score_func', dest='score_func', default='conve', help='Score Function for Link prediction')
parser.add_argument('--opn', dest='opn', default='ccorr', help='Composition Operation to be used in CompGCN')

parser.add_argument('--batch', dest='batch_size', default=128, type=int, help='Batch size')
parser.add_argument('--batch', dest='batch_size', default=1024, type=int, help='Batch size')
parser.add_argument('--gpu', type=int, default='0', help='Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0')
parser.add_argument('--epoch', dest='max_epochs', type=int, default=500, help='Number of epochs')
parser.add_argument('--l2', type=float, default=0.0, help='L2 Regularization for Optimizer')
Expand Down
2 changes: 1 addition & 1 deletion examples/pytorch/compGCN/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def ccorr(a, b):
-------
Tensor, having the same dimension as the input a.
"""
return th.irfft(com_mult(conj(th.rfft(a, 1)), th.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],))
return th.fft.irfftn(th.conj(th.fft.rfftn(a, (-1))) * th.fft.rfftn(b, (-1)), (-1))

#identify in/out edges, compute edge norm for each and store in edata
def in_out_norm(graph):
Expand Down

0 comments on commit 9a6b81e

Please sign in to comment.