-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for distributed sampling (#246)
This code belongs to the part of the whole distributed training for PyG. ## Description Distributed training neighbor sampling differs from the sampling currently implemented in pyg-lib. During distributed training nodes from one batch can be sampled by different machines (and therefore different samplers). The result of this is incorrect subtree/subgraph node indexing. To achieve correct results it is necessary to sample by one hop and then synchronise outputs between machines. Proposed algorithm: 1. First sample only global node ids (`sampled_nodes`) with duplicates in `neighbor_sample`. 2. Do not sample rows and cols but save information of how many neighbors were sampled by each node (`cumm_sum_sampled_nbrs_per_node`). 3. After each layer: synchronise and merge outputs from different machines and take new seed nodes (without duplicates) from sampled_nodes. 4. Sample next layer and continue 1-3 until all layers are sampled. 5. Perform global to local mappings using mapper and create (row, col) based on a `sampled_nodes_with_duplicates` and `sampled_nbrs_per_node`. Step 3. was implemented in pytorch_geometric. ## Added - new argument `distributed` to the `neighbor_sample` function to enable the algorithm described above. - new argument `batch` to the `neighbor_sample` function that allows to specify the initial subgraph indices for seed nodes (used with disjoint). - new return value `cumm_sum_sampled_nbrs_per_node` to the `neighbor_sample` function to return cumulative sum of the sampled neighbors per each node. - new function `relabel_neighborhood` that is used after sampling all layers and its purpose is to relabel global indices of the sampled nodes to the local subtree/subgraph indices (row, col). - new function `hetero_relabel_neighborhood` (same as `relabel_neighborhood` but for heterogeneous graphs). Returns (row_dict and col_dict). - unit tests --------- Co-authored-by: rusty1s <[email protected]>
- Loading branch information
Showing
6 changed files
with
358 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.