Skip to content

Commit

Permalink
[Feature] Node2vec (dmlc#2992)
Browse files Browse the repository at this point in the history
* add seal example

* 1. add paper infomation in examples/README
2. adjust codes
3. option test

* use latest `to_simple` to replace coalesce graph function

* remove outdated codes

* remove useless comment

* Node2vec
1.implement node2vec random walk c++ op
2.implement node2vec model
3.implement node2vec example

* add CMakeLists file modify

* refine c++ codes

* refine c++ codes

* add missing whitespace

* refine python codes

* add codes

* add node2vec_impl.h

* fix codes

* fix code style problem

* fixes

* remove

* lots of changes

* add benchmark

* fixes

Co-authored-by: smilexuhc <[email protected]>
Co-authored-by: Minjie Wang <[email protected]>
  • Loading branch information
3 people authored Jun 23, 2021
1 parent 7359481 commit e667545
Show file tree
Hide file tree
Showing 20 changed files with 1,135 additions and 113 deletions.
33 changes: 33 additions & 0 deletions benchmarks/benchmarks/api/bench_random_walk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import time
import dgl
import torch

from .. import utils

def _random_walk(g, seeds, length):
return dgl.sampling.random_walk(g, seeds, length=length)

def _node2vec(g, seeds, length):
return dgl.sampling.node2vec_random_walk(g, seeds, 1, 1, length)

@utils.benchmark('time')
@utils.parametrize_cpu('graph_name', ['cora', 'livejournal', 'friendster'])
@utils.parametrize('num_seeds', [10, 100, 1000])
@utils.parametrize('length', [2, 5, 10, 20])
@utils.parametrize('algorithm', ['_random_walk', '_node2vec'])
def track_time(graph_name, num_seeds, length, algorithm):
device = utils.get_bench_device()
graph = utils.get_graph(graph_name, 'csr')
seeds = torch.randint(0, graph.num_nodes(), (num_seeds,))
print(graph_name, num_seeds, length)
alg = globals()[algorithm]
# dry run
for i in range(5):
_ = alg(graph, seeds, length=length)

# timing
with utils.Timer() as t:
for i in range(50):
_ = alg(graph, seeds, length=length)

return t.elapsed_secs / 50
53 changes: 53 additions & 0 deletions examples/pytorch/node2vec/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# DGL Implementation of the Node2vec
This DGL example implements the graph embedding model proposed in the paper
[node2vec: Scalable Feature Learning for Networks](https://arxiv.org/abs/1607.00653)

The author's codes of implementation is in [Node2vec](https://github.com/aditya-grover/node2vec)


Example implementor
----------------------
This example was implemented by [Smile](https://github.com/Smilexuhc) during his intern work at the AWS Shanghai AI Lab.

The graph dataset used in this example
---------------------------------------

cora
- NumNodes: 2708
- NumEdges: 10556

ogbn-products
- NumNodes: 2449029
- NumEdges: 61859140


Dependencies
--------------------------------

- python 3.6+
- Pytorch 1.5.0+
- ogb


How to run example files
--------------------------------
To train a node2vec model:
```shell script
python main.py --task="train"
```

To time node2vec random walks:
```shell script
python main.py --task="time" --runs=10
```

Performance
-------------------------

**Setting:** `walk_length=50, p=0.25, q=4.0`
| Dataset | DGL | PyG |
| -------- | :---------: | :---------: |
| cora | 0.0092s | 0.0179s |
| products | 66.22s | 77.65s |
Note that the number in table are the average results of multiple trials.
For cora, we run 50 trials. For ogbn-products, we run 10 trials.
55 changes: 55 additions & 0 deletions examples/pytorch/node2vec/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import time
from dgl.sampling import node2vec_random_walk
from model import Node2vecModel
from utils import load_graph, parse_arguments


def time_randomwalk(graph, args):
"""
Test cost time of random walk
"""

start_time = time.time()

# default setting for testing
params = {'p': 0.25,
'q': 4,
'walk_length': 50}

for i in range(args.runs):
node2vec_random_walk(graph, graph.nodes(), **params)
end_time = time.time()
cost_time_avg = (end_time-start_time)/args.runs
print("Run dataset {} {} trials, mean run time: {:.3f}s".format(args.dataset, args.runs, cost_time_avg))


def train_node2vec(graph, eval_set, args):
"""
Train node2vec model
"""
trainer = Node2vecModel(graph,
embedding_dim=args.embedding_dim,
walk_length=args.walk_length,
p=args.p,
q=args.q,
num_walks=args.num_walks,
eval_set=eval_set,
eval_steps=1,
device=args.device)

trainer.train(epochs=args.epochs, batch_size=args.batch_size, learning_rate=0.01)


if __name__ == '__main__':

args = parse_arguments()
graph, eval_set = load_graph(args.dataset)

if args.task == 'train':
print("Perform training node2vec model")
train_node2vec(graph, eval_set, args)
elif args.task == 'time':
print("Timing random walks")
time_randomwalk(graph, args)
else:
raise ValueError('Task type error!')
Loading

0 comments on commit e667545

Please sign in to comment.