Skip to content

Commit

Permalink
Benchmarks for biased sampling (#267)
Browse files Browse the repository at this point in the history
Added a `--biased` flag to `neighbor.py` and `hetero_neighbor.py` to
run benchmarks for biased sampling (using randomly generated edge weights).
Output for the command `python neighbor.py --biased --libraries pyg-lib --write-csv`:

![image](https://github.com/pyg-team/pyg-lib/assets/57872493/d5d01246-a368-4768-ad34-fda415ce2733)
Output for the command `python hetero_neighbor.py --biased --libraries pyg-lib --write-csv`:

![image](https://github.com/pyg-team/pyg-lib/assets/57872493/7dbd2d68-9709-4259-a1f7-bf880bec5330)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
3 people authored Oct 18, 2023
1 parent 6215fe7 commit 7055d40
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [0.4.0] - 2023-MM-DD
### Added
### Changed
- Added `--biased` parameter to run benchmarks for biased sampling ([#267](https://github.com/pyg-team/pyg-lib/pull/267))
### Removed

## [0.3.0] - 2023-10-11
Expand Down
10 changes: 10 additions & 0 deletions benchmark/sampler/hetero_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# TODO(kgajdamo): Enable sampling with replacement
# argparser.add_argument('--replace', action='store_true')
argparser.add_argument('--shuffle', action='store_true')
argparser.add_argument('--biased', action='store_true')
argparser.add_argument('--temporal', action='store_true')
argparser.add_argument('--temporal-strategy', choices=['uniform', 'last'],
default='uniform')
Expand All @@ -51,6 +52,7 @@ def test_hetero_neighbor(dataset, **kwargs):

colptr_dict, row_dict = dataset
num_nodes_dict = {k[-1]: v.size(0) - 1 for k, v in colptr_dict.items()}
num_edges_dict = {k: v.size(0) for k, v in row_dict.items()}

if args.temporal:
# generate random timestamps
Expand All @@ -60,6 +62,13 @@ def test_hetero_neighbor(dataset, **kwargs):
else:
node_time_dict = None

edge_weight_dict = None
if args.biased:
edge_weight_dict = {
edge_type: torch.rand(num_edges)
for edge_type, num_edges in num_edges_dict.items()
}

if args.shuffle:
node_perm = torch.randperm(num_nodes_dict['paper'])
else:
Expand All @@ -86,6 +95,7 @@ def test_hetero_neighbor(dataset, **kwargs):
num_neighbors_dict,
node_time_dict,
seed_time_dict=None,
edge_weight_dict=edge_weight_dict,
csc=True,
replace=False,
directed=True,
Expand Down
11 changes: 10 additions & 1 deletion benchmark/sampler/neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
])
argparser.add_argument('--replace', action='store_true')
argparser.add_argument('--shuffle', action='store_true')
argparser.add_argument('--biased', action='store_true')
argparser.add_argument('--temporal', action='store_true')
argparser.add_argument('--temporal-strategy', choices=['uniform', 'last'],
default='uniform')
Expand All @@ -45,7 +46,10 @@ def test_neighbor(dataset, **kwargs):
raise ValueError(
"Temporal sampling needs to create disjoint subgraphs")

(rowptr, col), num_nodes = dataset, dataset[0].size(0) - 1
rowptr, col = dataset
num_nodes = rowptr.numel() - 1
num_edges = col.numel()

if 'dgl' in args.libraries:
import dgl
dgl_graph = dgl.graph(
Expand All @@ -57,6 +61,10 @@ def test_neighbor(dataset, **kwargs):
else:
node_time = None

edge_weight = None
if args.biased:
edge_weight = torch.rand(num_edges)

if args.shuffle:
node_perm = torch.randperm(num_nodes)
else:
Expand All @@ -80,6 +88,7 @@ def test_neighbor(dataset, **kwargs):
num_neighbors,
time=node_time,
seed_time=None,
edge_weight=edge_weight,
replace=args.replace,
directed=args.directed,
disjoint=args.disjoint,
Expand Down

0 comments on commit 7055d40

Please sign in to comment.