
Commit

[Feature][GPU] Add function for setting weights of a sparse embedding on multiple GPUs. (dmlc#3047)

* add unit test

* Extend NDArrayPartition object

* Add method for setting embedding, and improve documentation

* Sync before returning

* Use name unique to sparse embedding class to avoid delete

Co-authored-by: xiang song(charlie.song) <[email protected]>
nv-dlasalle and classicsong authored Jun 22, 2021
1 parent 70af194 commit 7359481
Showing 3 changed files with 119 additions and 9 deletions.
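In short, the commit gives NodeEmbedding two collective methods: all_set_embedding(values) copies a replicated global tensor into the (possibly GPU-partitioned) embedding, and all_get_embedding() returns a CPU copy of the full embedding. Below is a minimal usage sketch, assuming a torch.distributed process group is already initialized on every participating process (as in the new test further down); the name 'demo_emb' and the constructor defaults are illustrative, not taken from this commit.

    import torch as th
    from dgl.nn import NodeEmbedding

    # Hedged sketch of the new API; run inside each worker after
    # th.distributed.init_process_group(), mirroring the test added below.
    # The deterministic values tensor guarantees every process passes the
    # same data, which the docstring of all_set_embedding requires.
    values = th.arange(1000 * 8, dtype=th.float32).reshape(1000, 8)
    emb = NodeEmbedding(1000, 8, 'demo_emb')   # num_embeddings, embedding_dim, name
    emb.all_set_embedding(values)              # collective: every process must call this
    out = emb.all_get_embedding()              # CPU copy of the whole embedding
    assert th.allclose(values, out)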
44 changes: 35 additions & 9 deletions python/dgl/nn/pytorch/sparse_emb.py
@@ -123,16 +123,11 @@ def __init__(self, num_embeddings, embedding_dim, name,
             if rank == 0:
                 # root process broadcasts nccl id
                 nccl_id = nccl.UniqueId()
-                self._store.set('nccl_root_id', str(nccl_id))
+                self._store.set('nccl_root_id_sparse_emb', str(nccl_id))
             else:
-                nccl_id = nccl.UniqueId(self._store.get('nccl_root_id'))
+                nccl_id = nccl.UniqueId(self._store.get('nccl_root_id_sparse_emb'))
             _COMM = nccl.Communicator(self._world_size, self._rank,
                                       nccl_id)
-            if self._rank == 0:
-                # clear the store entry for future communicators
-                self._store.delete_key('nccl_root_id')
-            th.distributed.barrier()

             self._comm = _COMM

             if not self._partition:
@@ -335,12 +330,43 @@ def weight(self):
"""
return self._tensor

def gather_embedding(self):
"""Return a copy of the embedding stored in CPU memory. If this is a
def all_set_embedding(self, values):
""" Set the values of the embedding. This method must be called by all
processes sharing the embedding with identical tensors for
:attr:`values`.
NOTE: This method must be called by all processes sharing the
embedding, or it may result in a deadlock.
Parameters
----------
values : Tensor
The global tensor to pull values from.
"""
if self._partition:
idxs = F.copy_to(
self._partition.get_local_indices(
self._comm.rank(),
ctx=F.context(self._tensor)),
F.context(values))
self._tensor[:] = F.copy_to(F.gather_row(values, idxs),
ctx=F.context(self._tensor))[:]
else:
if self._rank == 0:
self._tensor[:] = F.copy_to(values,
ctx=F.context(self._tensor))[:]
if th.distributed.is_initialized():
th.distributed.barrier()

def all_get_embedding(self):
""" Return a copy of the embedding stored in CPU memory. If this is a
multi-processing instance, the tensor will be returned in shared
memory. If the embedding is currently stored on multiple GPUs, all
processes must call this method in the same order.
NOTE: This method must be called by all processes sharing the
embedding, or it may result in a deadlock.
Returns
-------
torch.Tensor
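In the partitioned (multi-GPU) branch of all_set_embedding above, each rank pulls only the rows it owns out of the replicated values tensor and writes them into its local shard. Here is a plain-PyTorch sketch of that row gather, assuming a 'remainder' partition (the default mode of NDArrayPartition below) in which global row i belongs to rank i % world_size; the helper name is illustrative, not DGL API.

    import torch as th

    def local_rows(values, rank, world_size):
        # Rows owned by this rank under a remainder partition:
        # rank, rank + world_size, rank + 2 * world_size, ...
        owned = th.arange(rank, values.shape[0], world_size)
        # Analogous to F.gather_row(values, idxs) in the diff above.
        return values.index_select(0, owned)

    full = th.rand(10, 4)
    shard = local_rows(full, rank=1, world_size=3)   # rows 1, 4, 7
    assert shard.shape == (3, 4)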
16 changes: 16 additions & 0 deletions python/dgl/partition.py
@@ -419,13 +419,29 @@ def __init__(self, array_size, num_parts, mode='remainder', part_ranges=None):
                 array_size, num_parts)
         else:
             assert False, 'Unknown partition mode "{}"'.format(mode)
+        self._array_size = array_size
+        self._num_parts = num_parts

+    def num_parts(self):
+        """ Get the number of partitions.
+        """
+        return self._num_parts
+
+    def array_size(self):
+        """ Get the total size of the first dimension of the partitioned array.
+        """
+        return self._array_size
+
     def get(self):
         """ Get the C-handle for this object.
         """
         return self._partition

+    def get_local_indices(self, part, ctx):
+        """ Get the set of global indices in this given partition.
+        """
+        return self.map_to_global(F.arange(0, self.local_size(part), ctx=ctx), part)
+
     def local_size(self, part):
         """ Get the number of rows/items assigned to the given part.
         """
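For reference, here is a plain-Python sketch of what the new local_size and get_local_indices helpers compute under 'remainder' mode, assuming (as the mode name suggests) that global row i is assigned to part i % num_parts; these stand-ins are illustrative only, since the real logic lives behind the C-handle returned by get().

    import torch as th

    def remainder_local_size(array_size, num_parts, part):
        # Number of rows i in [0, array_size) with i % num_parts == part.
        return (array_size - part + num_parts - 1) // num_parts

    def remainder_local_indices(array_size, num_parts, part):
        # map_to_global(arange(0, local_size(part)), part) for remainder mode:
        # local index j maps to global index j * num_parts + part.
        return th.arange(part, array_size, num_parts)

    assert remainder_local_size(10, 3, 1) == 3
    assert remainder_local_indices(10, 3, 1).tolist() == [1, 4, 7]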
68 changes: 68 additions & 0 deletions tests/pytorch/test_sparse_emb.py
@@ -0,0 +1,68 @@
import multiprocessing as mp
import unittest, os
import pytest

import torch as th
import backend as F

from dgl.nn import NodeEmbedding


def initializer(emb):
    th.manual_seed(0)
    emb.uniform_(-1.0, 1.0)
    return emb

def check_all_set_all_get_func(device, init_emb):
    num_embs = init_emb.shape[0]
    emb_dim = init_emb.shape[1]
    dgl_emb = NodeEmbedding(num_embs, emb_dim, 'test', device=device)
    dgl_emb.all_set_embedding(init_emb)

    out_emb = dgl_emb.all_get_embedding()
    assert F.allclose(init_emb, out_emb)

def start_sparse_worker(rank, world_size, test, args):
    print('start sparse worker {}'.format(rank))
    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
        master_ip='127.0.0.1', master_port='12345')
    backend = 'gloo'
    device = F.ctx()
    if device.type == 'cuda':
        device = th.device(rank)
        th.cuda.set_device(device)
    th.distributed.init_process_group(backend=backend,
                                      init_method=dist_init_method,
                                      world_size=world_size,
                                      rank=rank)

    test(device, *args)
    th.distributed.barrier()

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("num_workers", [1, 2, 3])
def test_multiprocess_sparse_emb_get_set(num_workers):
    if F.ctx().type == 'cuda' and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    worker_list = []

    init_emb = th.rand([1000, 8])

    ctx = mp.get_context('spawn')
    for i in range(num_workers):
        p = ctx.Process(target=start_sparse_worker,
                        args=(i, num_workers, check_all_set_all_get_func, (init_emb,)))
        p.start()
        worker_list.append(p)

    for p in worker_list:
        p.join()
    for p in worker_list:
        assert p.exitcode == 0


if __name__ == '__main__':
    test_multiprocess_sparse_emb_get_set(1)
    test_multiprocess_sparse_emb_get_set(2)
    test_multiprocess_sparse_emb_get_set(3)
