Skip to content

Commit

Permalink
Small change for kvstore api (dmlc#981)
Browse files Browse the repository at this point in the history
* Small change for kvstore api

* fix ci

* fix ci
  • Loading branch information
aksnzhy authored Nov 8, 2019
1 parent 0b4935d commit a0193fd
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 48 deletions.
30 changes: 15 additions & 15 deletions examples/mxnet/dis_kvstore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,31 +37,31 @@ def start_client(args):

if client.get_id() == 0:
client.pull(name='embed_0', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='embed_0', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("embed_0:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
print(mx.nd.concat(msg_0.data, msg_1.data, dim=0))

client.pull(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='embed_1', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("embed_1:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
print(mx.nd.concat(msg_0.data, msg_1.data, dim=0))

client.pull(name='server_embed', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='server_embed', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("server_embed:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
print(mx.nd.concat(msg_0.data, msg_1.data, dim=0))

# Shut-down all the servers
if client.get_id() == 0:
Expand Down
30 changes: 15 additions & 15 deletions examples/pytorch/dis_kvstore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,31 +37,31 @@ def start_client(args):

if client.get_id() == 0:
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='embed_0', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("embed_0:")
print(th.cat([new_tensor_0, new_tensor_1]))
print(th.cat([msg_0.data, msg_1.data]))

client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='embed_1', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("embed_1:")
print(th.cat([new_tensor_0, new_tensor_1]))
print(th.cat([msg_0.data, msg_1.data]))

client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='server_embed', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("server_embed:")
print(th.cat([new_tensor_0, new_tensor_1]))
print(th.cat([msg_0.data, msg_1.data]))

# Shut-down all the servers
if client.get_id() == 0:
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/contrib/dis_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def pull_wait(self):
"""
msg = _recv_kv_msg(self._receiver)
assert msg.type == KVMsgType.PULL_BACK, 'Recv kv msg error.'
return msg.rank, msg.data
return msg

def barrier(self):
"""Barrier for all client nodes
Expand Down
35 changes: 18 additions & 17 deletions tests/compute/test_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def start_client(server_embed):
client.barrier()

client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor = client.pull_wait()
assert server_id == 0
msg = client.pull_wait()
assert msg.rank == 0

target_tensor_0 = th.tensor(
[[ 0., 0., 0.],
Expand All @@ -55,46 +55,47 @@ def start_client(server_embed):
[ 0., 0., 0.],
[10., 10., 10.]])

assert th.equal(new_tensor, target_tensor_0) == True
assert th.equal(msg.data, target_tensor_0) == True

client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor = client.pull_wait()
msg = client.pull_wait()

target_tensor_1 = th.tensor([ 0., 0., 5., 0., 10.])

assert th.equal(new_tensor, target_tensor_1) == True
assert th.equal(msg.data, target_tensor_1) == True

client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))

_, tensor_0 = client.pull_wait()
_, tensor_1 = client.pull_wait()
_, tensor_2 = client.pull_wait()
_, tensor_3 = client.pull_wait()
_, tensor_4 = client.pull_wait()
msg_0 = client.pull_wait()
msg_1 = client.pull_wait()
msg_2 = client.pull_wait()
msg_3 = client.pull_wait()
msg_4 = client.pull_wait()

target_tensor_2 = th.tensor([ 2., 2., 7., 2., 12.])

assert th.equal(tensor_0, target_tensor_0) == True
assert th.equal(tensor_1, target_tensor_1) == True
assert th.equal(tensor_2, target_tensor_0) == True
assert th.equal(tensor_3, target_tensor_1) == True
assert th.equal(tensor_4, target_tensor_2) == True
assert th.equal(msg_0.data, target_tensor_0) == True
assert th.equal(msg_1.data, target_tensor_1) == True
assert th.equal(msg_2.data, target_tensor_0) == True
assert th.equal(msg_3.data, target_tensor_1) == True
assert th.equal(msg_4.data, target_tensor_2) == True

server_embed += target_tensor_2

client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
_, tensor_5 = client.pull_wait()
msg_5 = client.pull_wait()

assert th.equal(tensor_5, target_tensor_2 * 2) == True
assert th.equal(msg_5.data, target_tensor_2 * 2) == True

client.shut_down()

if __name__ == '__main__':
server_embed = th.tensor([2., 2., 2., 2., 2.])
# use pytorch shared memory
server_embed.share_memory_()

pid = os.fork()
Expand Down

0 comments on commit a0193fd

Please sign in to comment.