diff --git a/python/dgl/graphbolt/feature_fetcher.py b/python/dgl/graphbolt/feature_fetcher.py
index cf9d5f4104c2..832d00181588 100644
--- a/python/dgl/graphbolt/feature_fetcher.py
+++ b/python/dgl/graphbolt/feature_fetcher.py
@@ -166,15 +166,15 @@ def _cooperative_exchange(self, data):
             self.node_feature_keys, Dict
         ) or isinstance(self.edge_feature_keys, Dict)
         if is_heterogeneous:
-            node_features = {key: {} for key, _ in data.node_features.keys()}
-            for (key, ntype), feature in data.node_features.items():
+            node_features = {key: {} for _, key in data.node_features.keys()}
+            for (ntype, key), feature in data.node_features.items():
                 node_features[key][ntype] = feature
-            for key, feature in node_features.items():
+            for key, feature in sorted(node_features.items()):
                 new_feature = CooperativeConvFunction.apply(subgraph, feature)
                 for ntype, tensor in new_feature.items():
-                    data.node_features[(key, ntype)] = tensor
+                    data.node_features[(ntype, key)] = tensor
         else:
-            for key in data.node_features:
+            for key in sorted(data.node_features):
                 feature = data.node_features[key]
                 new_feature = CooperativeConvFunction.apply(subgraph, feature)
                 data.node_features[key] = new_feature
diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py
index 22c5ae316c71..133f8d0c9835 100644
--- a/python/dgl/graphbolt/impl/cooperative_conv.py
+++ b/python/dgl/graphbolt/impl/cooperative_conv.py
@@ -42,19 +42,21 @@ def forward(
             seed_sizes,
         )
         outs = {}
-        for ntype, typed_tensor in convert_to_hetero(tensor).items():
+        for ntype, typed_tensor in sorted(convert_to_hetero(tensor).items()):
             out = typed_tensor.new_empty(
-                (sum(counts_sent[ntype]),) + typed_tensor.shape[1:]
+                (sum(counts_sent.get(ntype, [0])),) + typed_tensor.shape[1:],
+                requires_grad=typed_tensor.requires_grad,
             )
+            default_splits = [0] * torch.distributed.get_world_size()
             all_to_all(
-                torch.split(out, counts_sent[ntype]),
+                torch.split(out, counts_sent.get(ntype, default_splits)),
                 torch.split(
-                    typed_tensor[seed_inverse_ids[ntype]],
-                    counts_received[ntype],
+                    typed_tensor[seed_inverse_ids.get(ntype, slice(None))],
+                    counts_received.get(ntype, default_splits),
                 ),
             )
             outs[ntype] = out
-        return revert_to_homo(out)
+        return revert_to_homo(outs)
 
     @staticmethod
     def backward(
@@ -69,7 +71,9 @@ def backward(
         ) = ctx.communication_variables
         delattr(ctx, "communication_variables")
         outs = {}
-        for ntype, typed_grad_output in convert_to_hetero(grad_output).items():
+        for ntype, typed_grad_output in sorted(
+            convert_to_hetero(grad_output).items()
+        ):
             out = typed_grad_output.new_empty(
                 (sum(counts_received[ntype]),) + typed_grad_output.shape[1:]
             )
diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py
index 7ddba6d7ccac..059fca7ef4c0 100644
--- a/python/dgl/graphbolt/impl/neighbor_sampler.py
+++ b/python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -561,7 +561,7 @@ def _seeds_cooperative_exchange_1(minibatch):
             seeds_offsets = {"_N": seeds_offsets}
         num_ntypes = len(seeds_offsets)
         counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
-        for i, offsets in enumerate(seeds_offsets.values()):
+        for i, (_, offsets) in enumerate(sorted(seeds_offsets.items())):
             counts_sent[
                 torch.arange(i, world_size * num_ntypes, num_ntypes)
             ] = offsets.diff()
@@ -589,7 +589,7 @@ def _seeds_cooperative_exchange_2(minibatch):
         seeds_received = {}
         counts_sent = {}
         counts_received = {}
-        for i, (ntype, typed_seeds) in enumerate(seeds.items()):
+        for i, (ntype, typed_seeds) in enumerate(sorted(seeds.items())):
             idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
             typed_counts_sent = subgraph._counts_sent[idx].tolist()
             typed_counts_received = subgraph._counts_received[idx].tolist()
diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py
index 88fc9c124de5..85029bb627e0 100644
--- a/python/dgl/graphbolt/subgraph_sampler.py
+++ b/python/dgl/graphbolt/subgraph_sampler.py
@@ -236,7 +236,9 @@ def _seeds_cooperative_exchange_1_wait_future(minibatch):
         else:
             minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets}
         counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
-        for i, offsets in enumerate(minibatch._seeds_offsets.values()):
+        for i, (_, offsets) in enumerate(
+            sorted(minibatch._seeds_offsets.items())
+        ):
             counts_sent[
                 torch.arange(i, world_size * num_ntypes, num_ntypes)
             ] = offsets.diff()
@@ -261,7 +263,7 @@ def _seeds_cooperative_exchange_2(minibatch):
         seeds_received = {}
         counts_sent = {}
         counts_received = {}
-        for i, (ntype, typed_seeds) in enumerate(seeds.items()):
+        for i, (ntype, typed_seeds) in enumerate(sorted(seeds.items())):
             idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
             typed_counts_sent = minibatch._counts_sent[idx].tolist()
             typed_counts_received = minibatch._counts_received[idx].tolist()
diff --git a/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py b/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
index f88e011f4385..3eb5bd591752 100644
--- a/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
+++ b/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
@@ -37,13 +37,16 @@ def test_rank_sort_and_unique_and_compact(dtype, rank):
         assert_equal(offsets1, offsets2)
         assert offsets1.is_pinned() and offsets2.is_pinned()
 
-    res3 = torch.ops.graphbolt.rank_sort(nodes_list1, rank, WORLD_SIZE)
+    # Test with the reverse order of ntypes. See if results are equivalent.
+    res3 = torch.ops.graphbolt.rank_sort(nodes_list1[::-1], rank, WORLD_SIZE)
 
     # This function is deterministic. Call with identical arguments and check.
-    for (nodes1, idx1, offsets1), (nodes3, idx3, offsets3) in zip(res1, res3):
+    for (nodes1, idx1, offsets1), (nodes3, idx3, offsets3) in zip(
+        res1, reversed(res3)
+    ):
         assert_equal(nodes1, nodes3)
         assert_equal(idx1, idx3)
-        assert_equal(offsets1, offsets3)
+        assert_equal(offsets1.diff(), offsets3.diff())
 
     # The dependency on the rank argument is simply a permutation.
     res4 = torch.ops.graphbolt.rank_sort(nodes_list1, 0, WORLD_SIZE)
@@ -57,12 +60,12 @@ def test_rank_sort_and_unique_and_compact(dtype, rank):
                 nodes1[off1[j] : off1[j + 1]], nodes4[off4[i] : off4[i + 1]]
             )
 
-    unique, compacted, offsets = gb.unique_and_compact(
-        nodes_list1[:1], rank, WORLD_SIZE
-    )
+    nodes = {str(i): [typed_seeds] for i, typed_seeds in enumerate(nodes_list1)}
 
-    nodes1, idx1, offsets1 = res1[0]
+    unique, compacted, offsets = gb.unique_and_compact(nodes, rank, WORLD_SIZE)
 
-    assert_equal(unique, nodes1)
-    assert_equal(compacted[0], idx1)
-    assert_equal(offsets, offsets1)
+    for i in nodes.keys():
+        nodes1, idx1, offsets1 = res1[int(i)]
+        assert_equal(unique[i], nodes1)
+        assert_equal(compacted[i][0], idx1)
+        assert_equal(offsets[i], offsets1)
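Note on the `sorted()` calls above (reviewer sketch, not part of the patch): every rank participates in the same per-ntype `all_to_all` collectives, so all ranks must issue them in the same order. Python dicts iterate in insertion order, which can differ from rank to rank for heterogeneous inputs; sorting the keys pins down one canonical order. A minimal illustration with plain `torch.distributed` (the function and variable names here are illustrative, not GraphBolt API):

```python
# Sketch: per-ntype collectives must run in the same order on every rank.
# Iterating a dict directly depends on insertion order, which may differ
# across ranks; sorting the keys makes the order deterministic.
import torch
import torch.distributed as dist


def exchange_typed_features(typed_features):
    """typed_features: dict of ntype -> tensor, built in rank-local order."""
    for ntype in sorted(typed_features):  # canonical order on every rank
        send = typed_features[ntype]
        recv = torch.empty_like(send)
        # All ranks reach this collective with the same ntype at the same
        # step. Without sorting, rank 0 could be exchanging "user" tensors
        # while rank 1 exchanges "item" tensors, mismatching the buffers.
        dist.all_to_all_single(recv, send)
        typed_features[ntype] = recv
```

The same invariant motivates the sorted enumeration in the `counts_sent`/`counts_received` bookkeeping: slot `i` of the packed counts tensor must refer to the same ntype on every rank, which holds only if all ranks enumerate node types in one canonical order.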