From 4eb9a9d06ac965ad6a6ba3d86c40144ff5a4b6cd Mon Sep 17 00:00:00 2001 From: CfromBU <2649624957@qq.com> Date: Wed, 18 Dec 2024 02:24:31 +0000 Subject: [PATCH] change test case for distGB --- tests/distributed/test_distributed_sampling.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/distributed/test_distributed_sampling.py b/tests/distributed/test_distributed_sampling.py index 3f6592ff19a1..e6b28be8c83a 100644 --- a/tests/distributed/test_distributed_sampling.py +++ b/tests/distributed/test_distributed_sampling.py @@ -1882,9 +1882,9 @@ def create_hetero_graph(dense=False, empty=False): generate_ip_config("rpc_ip_config.txt", num_server, num_server) g = create_hetero_graph() - indices = torch.randperm(g.num_edges("r34"))[:10] + eids = torch.randperm(g.num_edges("r34"))[:10] mask = torch.zeros(g.num_edges("r34"), dtype=torch.bool) - mask[indices] = True + mask[eids] = True num_parts = num_server @@ -1900,17 +1900,14 @@ def create_hetero_graph(dense=False, empty=False): store_eids=True, ) - pserver_list = [] - part_config = tmpdir / "test_sampling.json" dgl.distributed.initialize("rpc_ip_config.txt") dist_graph = DistGraph("test_sampling", part_config=part_config) - print(dist_graph.local_partition) os.environ["DGL_DIST_DEBUG"] = "1" - edges = {("n3", "r34", "n4"): indices} + edges = {("n3", "r34", "n4"): eids} sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10], mask="mask") loader = dgl.dataloading.DistEdgeDataLoader( dist_graph, edges, sampler, batch_size=64