
remove nonbatchable mode
jermainewang committed Oct 3, 2018
1 parent 3a3e5d4 commit 7d04c8c
Showing 10 changed files with 122 additions and 644 deletions.
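
Every hunk below makes the same mechanical change: the batchable keyword argument disappears from registration and message-passing calls, because batched execution is now the only mode. A minimal before/after sketch, assuming the call forms exercised by the tests in this commit; the message and reduce functions are illustrative stand-ins rather than the repository's test helpers, and g stands for any DGLGraph whose node representation is a single tensor.

    def _message(src, edge):
        # Each edge carries the source node representation as its message.
        return src

    def _reduce(node, msgs):
        # Sum incoming messages over the message axis and add them to the node state.
        return node + msgs.sum(1)

    def propagate(g):
        # Before this commit: g.update_all(_message, _reduce, batchable=True)
        # After this commit the flag is gone and batched execution is implied:
        g.update_all(_message, _reduce)
        return g.get_n_repr()
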
246 changes: 56 additions & 190 deletions python/dgl/graph.py

Large diffs are not rendered by default.

22 changes: 11 additions & 11 deletions tests/pytorch/test_batching.py
@@ -133,7 +133,7 @@ def test_batch_send():
def _fmsg(src, edge):
assert src['h'].shape == (5, D)
return {'m' : src['h']}
g.register_message_func(_fmsg, batchable=True)
g.register_message_func(_fmsg)
# many-many send
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
@@ -150,9 +150,9 @@ def _fmsg(src, edge):
def test_batch_recv():
# basic recv test
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_apply_node_func(apply_node_func, batchable=True)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func)
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
@@ -163,9 +163,9 @@ def test_batch_recv():

def test_update_routines():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_apply_node_func(apply_node_func, batchable=True)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func)

# send_and_recv
reduce_msg_shapes.clear()
@@ -209,7 +209,7 @@ def _reduce(node, msgs):
return node + msgs.sum(1)
old_repr = th.randn(5, 5)
g.set_n_repr(old_repr)
g.update_all(_message, _reduce, batchable=True)
g.update_all(_message, _reduce)
new_repr = g.get_n_repr()

assert th.allclose(new_repr[1:], old_repr[1:])
@@ -227,17 +227,17 @@ def _reduce(node, msgs):

old_repr = th.randn(2, 5)
g.set_n_repr(old_repr)
g.pull(0, _message, _reduce, batchable=True)
g.pull(0, _message, _reduce)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[1])
g.pull(1, _message, _reduce, batchable=True)
g.pull(1, _message, _reduce)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[1], old_repr[0])

old_repr = th.randn(2, 5)
g.set_n_repr(old_repr)
g.pull([0, 1], _message, _reduce, batchable=True)
g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0])
10 changes: 5 additions & 5 deletions tests/pytorch/test_batching_anonymous.py
@@ -129,7 +129,7 @@ def test_batch_send():
def _fmsg(hu, edge):
assert hu.shape == (5, D)
return hu
g.register_message_func(_fmsg, batchable=True)
g.register_message_func(_fmsg)
# many-many send
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
@@ -145,8 +145,8 @@ def _fmsg(hu, edge):

def test_batch_recv():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
@@ -157,8 +157,8 @@ def test_batch_recv():

def test_update_routines():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)

# send_and_recv
reduce_msg_shapes.clear()
52 changes: 26 additions & 26 deletions tests/pytorch/test_function.py
@@ -51,102 +51,102 @@ def reducer_none(node, msgs):
def test_copy_src():
# copy_src with both fields
g = generate_graph()
g.register_message_func(fn.copy_src(src='h', out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.copy_src(src='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))

# copy_src with only src field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_src(src='h'), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.copy_src(src='h'))
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))

# copy_src with no src field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_src(out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.copy_src(out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))

# copy src with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_src(), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.copy_src())
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))

def test_copy_edge():
# copy_edge with both fields
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h', out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.copy_edge(edge='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))

# copy_edge with only edge field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h'), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.copy_edge(edge='h'))
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))

# copy_edge with no edge field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_edge(out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.copy_edge(out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))

# copy edge with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_edge(), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.copy_edge())
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))

def test_src_mul_edge():
# src_mul_edge with all fields
g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))

g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h'), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.src_mul_edge(src='h', edge='h'))
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))

g = generate_graph1()
g.register_message_func(fn.src_mul_edge(out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.src_mul_edge(out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))

g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.src_mul_edge())
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))

g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=True)
g.register_reduce_func(reducer_none, batchable=True)
g.register_message_func(fn.src_mul_edge())
g.register_reduce_func(reducer_none)
g.update_all()
assert th.allclose(g.get_n_repr(),
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
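
The builtin message and reduce functions above follow the same pattern once batchable is gone. A hedged sketch of a copy-src/sum update after this commit; the field names 'h' and 'm' and the function name are illustrative, only the call shape is taken from the tests.

    import dgl.function as fn

    def copy_sum_update(g):
        # Copy each source node's 'h' field into message 'm', then sum the incoming
        # messages back into 'h'; before this commit the call ended with batchable=True.
        g.update_all(fn.copy_src(src='h', out='m'),
                     fn.sum(msgs='m', out='h'),
                     None)
        return g.get_n_repr()['h']
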
8 changes: 4 additions & 4 deletions tests/pytorch/test_graph_batch.py
@@ -71,8 +71,8 @@ def test_batch_sendrecv():
t2 = tree2()

bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src, batchable=True)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True)
bg.register_message_func(lambda src, edge: src)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1))
e1 = [(3, 1), (4, 1)]
e2 = [(2, 4), (0, 4)]

@@ -94,8 +94,8 @@ def test_batch_propagate():
t2 = tree2()

bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src, batchable=True)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True)
bg.register_message_func(lambda src, edge: src)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1))
# get leaves.

order = []
40 changes: 20 additions & 20 deletions tests/pytorch/test_specialization.py
@@ -38,23 +38,23 @@ def apply_func(hu):
g = generate_graph()
# update all
v1 = g.get_n_repr()[fld]
g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func, batchable=True)
g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.update_all(message_func, reduce_func, apply_func, batchable=True)
g.update_all(message_func, reduce_func, apply_func)
v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
# update all with edge weights
v1 = g.get_n_repr()[fld]
g.update_all(fn.src_mul_edge(src=fld, edge='e1'),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.update_all(fn.src_mul_edge(src=fld, edge='e2'),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v3 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.update_all(message_func_edge, reduce_func, apply_func, batchable=True)
g.update_all(message_func_edge, reduce_func, apply_func)
v4 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
assert th.allclose(v3, v4)
@@ -85,25 +85,25 @@ def apply_func(hu):
# send and recv
v1 = g.get_n_repr()[fld]
g.send_and_recv(u, v, fn.copy_src(src=fld),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.send_and_recv(u, v, message_func,
reduce_func, apply_func, batchable=True)
reduce_func, apply_func)
v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
# send and recv with edge weights
v1 = g.get_n_repr()[fld]
g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e1'),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e2'),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v3 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.send_and_recv(u, v, message_func_edge,
reduce_func, apply_func, batchable=True)
reduce_func, apply_func)
v4 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
assert th.allclose(v3, v4)
@@ -127,18 +127,18 @@ def reduce_func(hv, msgs):
# update all, mix of builtin and UDF
g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
[fn.sum(msgs='m1', out='v1'), reduce_func],
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2)

# run builtin with single message and reduce
g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None, batchable=True)
g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None)
v1 = g.get_n_repr()['v1']
assert th.allclose(v1, v2)

# 1 message, 2 reduces, using anonymous repr
g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True)
g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None)
v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3']
assert th.allclose(v1, v2)
@@ -147,15 +147,15 @@ def reduce_func(hv, msgs):
# update all with edge weights, 2 message, 3 reduces
g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
[fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3']
assert th.allclose(v1, v2)
assert th.allclose(v1, v3)

# run UDF with single message and reduce
g.update_all(message_func_edge, reduce_func, None, batchable=True)
g.update_all(message_func_edge, reduce_func, None)
v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2)

@@ -179,19 +179,19 @@ def reduce_func(hv, msgs):
g.send_and_recv(u, v,
[fn.copy_src(src=fld, out='m1'), message_func],
[fn.sum(msgs='m1', out='v1'), reduce_func],
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2)

# run builtin with single message and reduce
g.send_and_recv(u, v, fn.copy_src(src=fld), fn.sum(out='v1'),
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
assert th.allclose(v1, v2)

# 1 message, 2 reduces, using anonymous repr
g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True)
g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None)
v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3']
assert th.allclose(v1, v2)
@@ -201,7 +201,7 @@ def reduce_func(hv, msgs):
g.send_and_recv(u, v,
[fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
[fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3']
@@ -210,7 +210,7 @@ def reduce_func(hv, msgs):

# run UDF with single message and reduce
g.send_and_recv(u, v, message_func_edge,
reduce_func, None, batchable=True)
reduce_func, None)
v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2)

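
The send_and_recv variants lose the flag in the same way. A sketch of the per-edge form after this commit, under the same assumptions as the examples above; u and v are torch index tensors picking out the edges, and the field names are illustrative.

    import dgl.function as fn

    def send_recv_subset(g, u, v):
        # Trigger messages only along the edges (u[i], v[i]) and sum them at the
        # destination nodes; before this commit the call carried batchable=True.
        g.send_and_recv(u, v,
                        fn.copy_src(src='h', out='m'),
                        fn.sum(msgs='m', out='h'),
                        None)
        return g.get_n_repr()['h']
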