Skip to content

Commit

Permalink
change test case for distGB
Browse files Browse the repository at this point in the history
  • Loading branch information
CfromBU committed Dec 18, 2024
1 parent 23c6cd7 commit 4eb9a9d
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions tests/distributed/test_distributed_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1882,9 +1882,9 @@ def create_hetero_graph(dense=False, empty=False):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)

g = create_hetero_graph()
indices = torch.randperm(g.num_edges("r34"))[:10]
eids = torch.randperm(g.num_edges("r34"))[:10]
mask = torch.zeros(g.num_edges("r34"), dtype=torch.bool)
mask[indices] = True
mask[eids] = True

num_parts = num_server

Expand All @@ -1900,17 +1900,14 @@ def create_hetero_graph(dense=False, empty=False):
store_eids=True,
)

pserver_list = []

part_config = tmpdir / "test_sampling.json"

dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", part_config=part_config)
print(dist_graph.local_partition)

os.environ["DGL_DIST_DEBUG"] = "1"

edges = {("n3", "r34", "n4"): indices}
edges = {("n3", "r34", "n4"): eids}
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10], mask="mask")
loader = dgl.dataloading.DistEdgeDataLoader(
dist_graph, edges, sampler, batch_size=64
Expand Down

0 comments on commit 4eb9a9d

Please sign in to comment.