diff --git a/python/dgl/graph.py b/python/dgl/graph.py index 58212eb166c8..523c98325652 100644 --- a/python/dgl/graph.py +++ b/python/dgl/graph.py @@ -50,11 +50,11 @@ def __init__(self, self._msg_frame = FrameRef() self.reset_messages() # registered functions - self._message_func = (None, None) - self._reduce_func = (None, None) - self._edge_func = (None, None) - self._apply_node_func = (None, None) - self._apply_edge_func = (None, None) + self._message_func = None + self._reduce_func = None + self._edge_func = None + self._apply_node_func = None + self._apply_edge_func = None def add_nodes(self, num, reprs=None): """Add nodes. @@ -710,77 +710,57 @@ def get_e_repr_by_id(self, eid=ALL): else: return self._edge_frame.select_rows(eid) - def register_edge_func(self, - edge_func, - batchable=False): + def register_edge_func(self, edge_func): """Register global edge update function. Parameters ---------- edge_func : callable Message function on the edge. - batchable : bool - Whether the provided message function allows batch computing. """ - self._edge_func = (edge_func, batchable) + self._edge_func = edge_func - def register_message_func(self, - message_func, - batchable=False): + def register_message_func(self, message_func): """Register global message function. Parameters ---------- message_func : callable Message function on the edge. - batchable : bool - Whether the provided message function allows batch computing. """ - self._message_func = (message_func, batchable) + self._message_func = message_func - def register_reduce_func(self, - reduce_func, - batchable=False): + def register_reduce_func(self, reduce_func): """Register global message reduce function. Parameters ---------- reduce_func : str or callable Reduce function on incoming edges. - batchable : bool - Whether the provided reduce function allows batch computing. """ - self._reduce_func = (reduce_func, batchable) + self._reduce_func = reduce_func - def register_apply_node_func(self, - apply_node_func, - batchable=False): + def register_apply_node_func(self, apply_node_func): """Register global node apply function. Parameters ---------- apply_node_func : callable Apply function on the node. - batchable : bool - Whether the provided function allows batch computing. """ - self._apply_node_func = (apply_node_func, batchable) + self._apply_node_func = apply_node_func - def register_apply_edge_func(self, - apply_edge_func, - batchable=False): + def register_apply_edge_func(self, apply_edge_func): """Register global edge apply function. Parameters ---------- apply_edge_func : callable Apply function on the edge. - batchable : bool - Whether the provided function allows batch computing. """ - self._apply_edge_func = (apply_edge_func, batchable) + self._apply_edge_func = apply_edge_func - def apply_nodes(self, v, apply_node_func="default", batchable=False): + def apply_nodes(self, v, apply_node_func="default"): """Apply the function on node representations. Parameters @@ -789,27 +769,16 @@ def apply_nodes(self, v, apply_node_func="default", batchable=False): The node id(s). apply_node_func : callable The apply node function. - batchable : bool - Whether the provided function allows batch computing. """ if apply_node_func == "default": - apply_node_func, batchable = self._apply_node_func + apply_node_func = self._apply_node_func if not apply_node_func: # Skip none function call. return - if batchable: - new_repr = apply_node_func(self.get_n_repr(v)) - self.set_n_repr(new_repr, v) - else: - raise RuntimeError('Disabled') - if is_all(v): - v = self.nodes() - v = utils.toindex(v) - for vv in utils.node_iter(v): - ret = apply_node_func(_get_repr(self.nodes[vv])) - _set_repr(self.nodes[vv], ret) + new_repr = apply_node_func(self.get_n_repr(v)) + self.set_n_repr(new_repr, v) - def apply_edges(self, u, v, apply_edge_func="default", batchable=False): + def apply_edges(self, u, v, apply_edge_func="default"): """Apply the function on edge representations. Parameters @@ -820,27 +789,16 @@ def apply_edges(self, u, v, apply_edge_func="default", batchable=False): The dst node id(s). apply_edge_func : callable The apply edge function. - batchable : bool - Whether the provided function allows batch computing. """ if apply_edge_func == "default": - apply_edge_func, batchable = self._apply_edge_func + apply_edge_func = self._apply_edge_func if not apply_edge_func: # Skip none function call. return - if batchable: - new_repr = apply_edge_func(self.get_e_repr(u, v)) - self.set_e_repr(new_repr, u, v) - else: - if is_all(u) == is_all(v): - u, v = zip(*self.edges) - u = utils.toindex(u) - v = utils.toindex(v) - for uu, vv in utils.edge_iter(u, v): - ret = apply_edge_func(_get_repr(self.edges[uu, vv])) - _set_repr(self.edges[uu, vv], ret) + new_repr = apply_edge_func(self.get_e_repr(u, v)) + self.set_e_repr(new_repr, u, v) - def send(self, u, v, message_func="default", batchable=False): + def send(self, u, v, message_func="default"): """Trigger the message function on edge u->v The message function should be compatible with following signature: @@ -861,30 +819,13 @@ def send(self, u, v, message_func="default", batchable=False): The destination node(s). message_func : callable The message function. - batchable : bool - Whether the function allows batched computation. """ if message_func == "default": - message_func, batchable = self._message_func + message_func = self._message_func assert message_func is not None if isinstance(message_func, (tuple, list)): message_func = BundledMessageFunction(message_func) - if batchable: - self._batch_send(u, v, message_func) - else: - self._nonbatch_send(u, v, message_func) - - def _nonbatch_send(self, u, v, message_func): - raise RuntimeError('Disabled') - if is_all(u) and is_all(v): - u, v = self.cached_graph.edges() - else: - u = utils.toindex(u) - v = utils.toindex(v) - for uu, vv in utils.edge_iter(u, v): - ret = message_func(_get_repr(self.nodes[uu]), - _get_repr(self.edges[uu, vv])) - self.edges[uu, vv][__MSG__] = ret + self._batch_send(u, v, message_func) def _batch_send(self, u, v, message_func): if is_all(u) and is_all(v): @@ -908,7 +849,7 @@ def _batch_send(self, u, v, message_func): else: self._msg_frame.append({__MSG__ : msgs}) - def update_edge(self, u=ALL, v=ALL, edge_func="default", batchable=False): + def update_edge(self, u=ALL, v=ALL, edge_func="default"): """Update representation on edge u->v The edge function should be compatible with following signature: @@ -927,29 +868,11 @@ def update_edge(self, u=ALL, v=ALL, edge_func="default", batchable=False): The destination node(s). edge_func : callable The update function. - batchable : bool - Whether the function allows batched computation. """ if edge_func == "default": - edge_func, batchable = self._edge_func + edge_func = self._edge_func assert edge_func is not None - if batchable: - self._batch_update_edge(u, v, edge_func) - else: - self._nonbatch_update_edge(u, v, edge_func) - - def _nonbatch_update_edge(self, u, v, edge_func): - raise RuntimeError('Disabled') - if is_all(u) and is_all(v): - u, v = self.cached_graph.edges() - else: - u = utils.toindex(u) - v = utils.toindex(v) - for uu, vv in utils.edge_iter(u, v): - ret = edge_func(_get_repr(self.nodes[uu]), - _get_repr(self.nodes[vv]), - _get_repr(self.edges[uu, vv])) - _set_repr(self.edges[uu, vv], ret) + self._batch_update_edge(u, v, edge_func) def _batch_update_edge(self, u, v, edge_func): if is_all(u) and is_all(v): @@ -975,8 +898,7 @@ def _batch_update_edge(self, u, v, edge_func): def recv(self, u, reduce_func="default", - apply_node_func="default", - batchable=False): + apply_node_func="default"): """Receive and reduce in-coming messages and update representation on node u. It computes the new node state using the messages sent from the predecessors @@ -1006,34 +928,15 @@ def recv(self, The reduce function. apply_node_func : callable, optional The update function. - batchable : bool, optional - Whether the reduce and update function allows batched computation. """ if reduce_func == "default": - reduce_func, batchable = self._reduce_func + reduce_func = self._reduce_func assert reduce_func is not None if isinstance(reduce_func, (list, tuple)): reduce_func = BundledReduceFunction(reduce_func) - if batchable: - self._batch_recv(u, reduce_func) - else: - self._nonbatch_recv(u, reduce_func) + self._batch_recv(u, reduce_func) # optional apply nodes - self.apply_nodes(u, apply_node_func, batchable) - - def _nonbatch_recv(self, u, reduce_func): - raise RuntimeError('Disabled') - if is_all(u): - u = list(range(0, self.number_of_nodes())) - else: - u = utils.toindex(u) - for i, uu in enumerate(utils.node_iter(u)): - # reduce phase - msgs_batch = [self.edges[vv, uu].pop(__MSG__) - for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]] - if len(msgs_batch) != 0: - new_repr = reduce_func(_get_repr(self.nodes[uu]), msgs_batch) - _set_repr(self.nodes[uu], new_repr) + self.apply_nodes(u, apply_node_func) def _batch_recv(self, v, reduce_func): if self._msg_frame.num_rows == 0: @@ -1105,8 +1008,7 @@ def send_and_recv(self, u, v, message_func="default", reduce_func="default", - apply_node_func="default", - batchable=False): + apply_node_func="default"): """Trigger the message function on u->v and update v. Parameters @@ -1121,8 +1023,6 @@ def send_and_recv(self, The reduce function. apply_node_func : callable, optional The update function. - batchable : bool - Whether the reduce and update function allows batched computation. """ u = utils.toindex(u) v = utils.toindex(v) @@ -1132,34 +1032,28 @@ def send_and_recv(self, return unique_v = utils.toindex(F.unique(v.tousertensor())) - # TODO(minjie): better way to figure out `batchable` flag if message_func == "default": - message_func, batchable = self._message_func + message_func = self._message_func if reduce_func == "default": - reduce_func, _ = self._reduce_func + reduce_func = self._reduce_func assert message_func is not None assert reduce_func is not None - if batchable: - executor = scheduler.get_executor( - 'send_and_recv', self, src=u, dst=v, - message_func=message_func, reduce_func=reduce_func) - else: - executor = None - + executor = scheduler.get_executor( + 'send_and_recv', self, src=u, dst=v, + message_func=message_func, reduce_func=reduce_func) if executor: executor.run() else: - self.send(u, v, message_func, batchable=batchable) - self.recv(unique_v, reduce_func, None, batchable=batchable) - self.apply_nodes(unique_v, apply_node_func, batchable=batchable) + self.send(u, v, message_func) + self.recv(unique_v, reduce_func, None) + self.apply_nodes(unique_v, apply_node_func) def pull(self, v, message_func="default", reduce_func="default", - apply_node_func="default", - batchable=False): + apply_node_func="default"): """Pull messages from the node's predecessors and then update it. Parameters @@ -1172,24 +1066,20 @@ def pull(self, The reduce function. apply_node_func : callable, optional The update function. - batchable : bool - Whether the reduce and update function allows batched computation. """ v = utils.toindex(v) if len(v) == 0: return uu, vv, _ = self._graph.in_edges(v) - self.send_and_recv(uu, vv, message_func, reduce_func, - apply_node_func=None, batchable=batchable) + self.send_and_recv(uu, vv, message_func, reduce_func, apply_node_func=None) unique_v = F.unique(v.tousertensor()) - self.apply_nodes(unique_v, apply_node_func, batchable=batchable) + self.apply_nodes(unique_v, apply_node_func) def push(self, u, message_func="default", reduce_func="default", - apply_node_func="default", - batchable=False): + apply_node_func="default"): """Send message from the node to its successors and update them. Parameters @@ -1202,21 +1092,18 @@ def push(self, The reduce function. apply_node_func : callable The update function. - batchable : bool - Whether the reduce and update function allows batched computation. """ u = utils.toindex(u) if len(u) == 0: return uu, vv, _ = self._graph.out_edges(u) self.send_and_recv(uu, vv, message_func, - reduce_func, apply_node_func, batchable=batchable) + reduce_func, apply_node_func) def update_all(self, message_func="default", reduce_func="default", - apply_node_func="default", - batchable=False): + apply_node_func="default"): """Send messages through all the edges and update all nodes. Parameters @@ -1227,35 +1114,28 @@ def update_all(self, The reduce function. apply_node_func : callable, optional The update function. - batchable : bool - Whether the reduce and update function allows batched computation. """ if message_func == "default": - message_func, batchable = self._message_func + message_func = self._message_func if reduce_func == "default": - reduce_func, _ = self._reduce_func + reduce_func = self._reduce_func assert message_func is not None assert reduce_func is not None - if batchable: - executor = scheduler.get_executor( - "update_all", self, message_func=message_func, reduce_func=reduce_func) - else: - executor = None - + executor = scheduler.get_executor( + "update_all", self, message_func=message_func, reduce_func=reduce_func) if executor: executor.run() else: - self.send(ALL, ALL, message_func, batchable=batchable) - self.recv(ALL, reduce_func, None, batchable=batchable) - self.apply_nodes(ALL, apply_node_func, batchable=batchable) + self.send(ALL, ALL, message_func) + self.recv(ALL, reduce_func, None) + self.apply_nodes(ALL, apply_node_func) def propagate(self, iterator='bfs', message_func="default", reduce_func="default", apply_node_func="default", - batchable=False, **kwargs): """Propagate messages and update nodes using iterator. @@ -1274,8 +1154,6 @@ def propagate(self, The reduce function. apply_node_func : str or callable The update function. - batchable : bool - Whether the reduce and update function allows batched computation. iterator : str or generator of steps. The iterator of the graph. kwargs : keyword arguments, optional @@ -1288,7 +1166,7 @@ def propagate(self, # NOTE: the iteration can return multiple edges at each step. for u, v in iterator: self.send_and_recv(u, v, - message_func, reduce_func, apply_node_func, batchable) + message_func, reduce_func, apply_node_func) def subgraph(self, nodes): """Generate the subgraph among the given nodes. @@ -1350,15 +1228,3 @@ def merge(self, subgraphs, reduce_func='sum'): [sg._parent_eid for sg in to_merge], self._edge_frame.num_rows, reduce_func) - -def _get_repr(attr_dict): - if len(attr_dict) == 1 and __REPR__ in attr_dict: - return attr_dict[__REPR__] - else: - return attr_dict - -def _set_repr(attr_dict, attr): - if utils.is_dict_like(attr): - attr_dict.update(attr) - else: - attr_dict[__REPR__] = attr diff --git a/tests/pytorch/test_batching.py b/tests/pytorch/test_batching.py index f9b4e2435f2c..6711adc0b8f5 100644 --- a/tests/pytorch/test_batching.py +++ b/tests/pytorch/test_batching.py @@ -133,7 +133,7 @@ def test_batch_send(): def _fmsg(src, edge): assert src['h'].shape == (5, D) return {'m' : src['h']} - g.register_message_func(_fmsg, batchable=True) + g.register_message_func(_fmsg) # many-many send u = th.tensor([0, 0, 0, 0, 0]) v = th.tensor([1, 2, 3, 4, 5]) @@ -150,9 +150,9 @@ def _fmsg(src, edge): def test_batch_recv(): # basic recv test g = generate_graph() - g.register_message_func(message_func, batchable=True) - g.register_reduce_func(reduce_func, batchable=True) - g.register_apply_node_func(apply_node_func, batchable=True) + g.register_message_func(message_func) + g.register_reduce_func(reduce_func) + g.register_apply_node_func(apply_node_func) u = th.tensor([0, 0, 0, 4, 5, 6]) v = th.tensor([1, 2, 3, 9, 9, 9]) reduce_msg_shapes.clear() @@ -163,9 +163,9 @@ def test_batch_recv(): def test_update_routines(): g = generate_graph() - g.register_message_func(message_func, batchable=True) - g.register_reduce_func(reduce_func, batchable=True) - g.register_apply_node_func(apply_node_func, batchable=True) + g.register_message_func(message_func) + g.register_reduce_func(reduce_func) + g.register_apply_node_func(apply_node_func) # send_and_recv reduce_msg_shapes.clear() @@ -209,7 +209,7 @@ def _reduce(node, msgs): return node + msgs.sum(1) old_repr = th.randn(5, 5) g.set_n_repr(old_repr) - g.update_all(_message, _reduce, batchable=True) + g.update_all(_message, _reduce) new_repr = g.get_n_repr() assert th.allclose(new_repr[1:], old_repr[1:]) @@ -227,17 +227,17 @@ def _reduce(node, msgs): old_repr = th.randn(2, 5) g.set_n_repr(old_repr) - g.pull(0, _message, _reduce, batchable=True) + g.pull(0, _message, _reduce) new_repr = g.get_n_repr() assert th.allclose(new_repr[0], old_repr[0]) assert th.allclose(new_repr[1], old_repr[1]) - g.pull(1, _message, _reduce, batchable=True) + g.pull(1, _message, _reduce) new_repr = g.get_n_repr() assert th.allclose(new_repr[1], old_repr[0]) old_repr = th.randn(2, 5) g.set_n_repr(old_repr) - g.pull([0, 1], _message, _reduce, batchable=True) + g.pull([0, 1], _message, _reduce) new_repr = g.get_n_repr() assert th.allclose(new_repr[0], old_repr[0]) assert th.allclose(new_repr[1], old_repr[0]) diff --git a/tests/pytorch/test_batching_anonymous.py b/tests/pytorch/test_batching_anonymous.py index b35abe1c242c..ad431b05b6e3 100644 --- a/tests/pytorch/test_batching_anonymous.py +++ b/tests/pytorch/test_batching_anonymous.py @@ -129,7 +129,7 @@ def test_batch_send(): def _fmsg(hu, edge): assert hu.shape == (5, D) return hu - g.register_message_func(_fmsg, batchable=True) + g.register_message_func(_fmsg) # many-many send u = th.tensor([0, 0, 0, 0, 0]) v = th.tensor([1, 2, 3, 4, 5]) @@ -145,8 +145,8 @@ def _fmsg(hu, edge): def test_batch_recv(): g = generate_graph() - g.register_message_func(message_func, batchable=True) - g.register_reduce_func(reduce_func, batchable=True) + g.register_message_func(message_func) + g.register_reduce_func(reduce_func) u = th.tensor([0, 0, 0, 4, 5, 6]) v = th.tensor([1, 2, 3, 9, 9, 9]) reduce_msg_shapes.clear() @@ -157,8 +157,8 @@ def test_batch_recv(): def test_update_routines(): g = generate_graph() - g.register_message_func(message_func, batchable=True) - g.register_reduce_func(reduce_func, batchable=True) + g.register_message_func(message_func) + g.register_reduce_func(reduce_func) # send_and_recv reduce_msg_shapes.clear() diff --git a/tests/pytorch/test_function.py b/tests/pytorch/test_function.py index ca6702fc9197..2aef6975ee69 100644 --- a/tests/pytorch/test_function.py +++ b/tests/pytorch/test_function.py @@ -51,32 +51,32 @@ def reducer_none(node, msgs): def test_copy_src(): # copy_src with both fields g = generate_graph() - g.register_message_func(fn.copy_src(src='h', out='m'), batchable=True) - g.register_reduce_func(reducer_both, batchable=True) + g.register_message_func(fn.copy_src(src='h', out='m')) + g.register_reduce_func(reducer_both) g.update_all() assert th.allclose(g.get_n_repr()['h'], th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) # copy_src with only src field; the out field should use anonymous repr g = generate_graph() - g.register_message_func(fn.copy_src(src='h'), batchable=True) - g.register_reduce_func(reducer_out, batchable=True) + g.register_message_func(fn.copy_src(src='h')) + g.register_reduce_func(reducer_out) g.update_all() assert th.allclose(g.get_n_repr()['h'], th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) # copy_src with no src field; should use anonymous repr g = generate_graph1() - g.register_message_func(fn.copy_src(out='m'), batchable=True) - g.register_reduce_func(reducer_both, batchable=True) + g.register_message_func(fn.copy_src(out='m')) + g.register_reduce_func(reducer_both) g.update_all() assert th.allclose(g.get_n_repr()['h'], th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) # copy src with no fields; g = generate_graph1() - g.register_message_func(fn.copy_src(), batchable=True) - g.register_reduce_func(reducer_out, batchable=True) + g.register_message_func(fn.copy_src()) + g.register_reduce_func(reducer_out) g.update_all() assert th.allclose(g.get_n_repr()['h'], th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) @@ -84,32 +84,32 @@ def test_copy_src(): def test_copy_edge(): # copy_edge with both fields g = generate_graph() - g.register_message_func(fn.copy_edge(edge='h', out='m'), batchable=True) - g.register_reduce_func(reducer_both, batchable=True) + g.register_message_func(fn.copy_edge(edge='h', out='m')) + g.register_reduce_func(reducer_both) g.update_all() assert th.allclose(g.get_n_repr()['h'], th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) # copy_edge with only edge field; the out field should use anonymous repr g = generate_graph() - g.register_message_func(fn.copy_edge(edge='h'), batchable=True) - g.register_reduce_func(reducer_out, batchable=True) + g.register_message_func(fn.copy_edge(edge='h')) + g.register_reduce_func(reducer_out) g.update_all() assert th.allclose(g.get_n_repr()['h'], th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) # copy_edge with no edge field; should use anonymous repr g = generate_graph1() - g.register_message_func(fn.copy_edge(out='m'), batchable=True) - g.register_reduce_func(reducer_both, batchable=True) + g.register_message_func(fn.copy_edge(out='m')) + g.register_reduce_func(reducer_both) g.update_all() assert th.allclose(g.get_n_repr()['h'], th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) # copy edge with no fields; g = generate_graph1() - g.register_message_func(fn.copy_edge(), batchable=True) - g.register_reduce_func(reducer_out, batchable=True) + g.register_message_func(fn.copy_edge()) + g.register_reduce_func(reducer_out) g.update_all() assert th.allclose(g.get_n_repr()['h'], th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) @@ -117,36 +117,36 @@ def test_copy_edge(): def test_src_mul_edge(): # src_mul_edge with all fields g = generate_graph() - g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'), batchable=True) - g.register_reduce_func(reducer_both, batchable=True) + g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m')) + g.register_reduce_func(reducer_both) g.update_all() assert th.allclose(g.get_n_repr()['h'], th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) g = generate_graph() - g.register_message_func(fn.src_mul_edge(src='h', edge='h'), batchable=True) - g.register_reduce_func(reducer_out, batchable=True) + g.register_message_func(fn.src_mul_edge(src='h', edge='h')) + g.register_reduce_func(reducer_out) g.update_all() assert th.allclose(g.get_n_repr()['h'], th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) g = generate_graph1() - g.register_message_func(fn.src_mul_edge(out='m'), batchable=True) - g.register_reduce_func(reducer_both, batchable=True) + g.register_message_func(fn.src_mul_edge(out='m')) + g.register_reduce_func(reducer_both) g.update_all() assert th.allclose(g.get_n_repr()['h'], th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) g = generate_graph1() - g.register_message_func(fn.src_mul_edge(), batchable=True) - g.register_reduce_func(reducer_out, batchable=True) + g.register_message_func(fn.src_mul_edge()) + g.register_reduce_func(reducer_out) g.update_all() assert th.allclose(g.get_n_repr()['h'], th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) g = generate_graph1() - g.register_message_func(fn.src_mul_edge(), batchable=True) - g.register_reduce_func(reducer_none, batchable=True) + g.register_message_func(fn.src_mul_edge()) + g.register_reduce_func(reducer_none) g.update_all() assert th.allclose(g.get_n_repr(), th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) diff --git a/tests/pytorch/test_graph_batch.py b/tests/pytorch/test_graph_batch.py index 4501b2d6fee7..fe3765585ee5 100644 --- a/tests/pytorch/test_graph_batch.py +++ b/tests/pytorch/test_graph_batch.py @@ -71,8 +71,8 @@ def test_batch_sendrecv(): t2 = tree2() bg = dgl.batch([t1, t2]) - bg.register_message_func(lambda src, edge: src, batchable=True) - bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True) + bg.register_message_func(lambda src, edge: src) + bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1)) e1 = [(3, 1), (4, 1)] e2 = [(2, 4), (0, 4)] @@ -94,8 +94,8 @@ def test_batch_propagate(): t2 = tree2() bg = dgl.batch([t1, t2]) - bg.register_message_func(lambda src, edge: src, batchable=True) - bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True) + bg.register_message_func(lambda src, edge: src) + bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1)) # get leaves. order = [] diff --git a/tests/pytorch/test_specialization.py b/tests/pytorch/test_specialization.py index 4b1f7797e563..6327f53c4bb3 100644 --- a/tests/pytorch/test_specialization.py +++ b/tests/pytorch/test_specialization.py @@ -38,23 +38,23 @@ def apply_func(hu): g = generate_graph() # update all v1 = g.get_n_repr()[fld] - g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func, batchable=True) + g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func) v2 = g.get_n_repr()[fld] g.set_n_repr({fld : v1}) - g.update_all(message_func, reduce_func, apply_func, batchable=True) + g.update_all(message_func, reduce_func, apply_func) v3 = g.get_n_repr()[fld] assert th.allclose(v2, v3) # update all with edge weights v1 = g.get_n_repr()[fld] g.update_all(fn.src_mul_edge(src=fld, edge='e1'), - fn.sum(out=fld), apply_func, batchable=True) + fn.sum(out=fld), apply_func) v2 = g.get_n_repr()[fld] g.set_n_repr({fld : v1}) g.update_all(fn.src_mul_edge(src=fld, edge='e2'), - fn.sum(out=fld), apply_func, batchable=True) + fn.sum(out=fld), apply_func) v3 = g.get_n_repr()[fld] g.set_n_repr({fld : v1}) - g.update_all(message_func_edge, reduce_func, apply_func, batchable=True) + g.update_all(message_func_edge, reduce_func, apply_func) v4 = g.get_n_repr()[fld] assert th.allclose(v2, v3) assert th.allclose(v3, v4) @@ -85,25 +85,25 @@ def apply_func(hu): # send and recv v1 = g.get_n_repr()[fld] g.send_and_recv(u, v, fn.copy_src(src=fld), - fn.sum(out=fld), apply_func, batchable=True) + fn.sum(out=fld), apply_func) v2 = g.get_n_repr()[fld] g.set_n_repr({fld : v1}) g.send_and_recv(u, v, message_func, - reduce_func, apply_func, batchable=True) + reduce_func, apply_func) v3 = g.get_n_repr()[fld] assert th.allclose(v2, v3) # send and recv with edge weights v1 = g.get_n_repr()[fld] g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e1'), - fn.sum(out=fld), apply_func, batchable=True) + fn.sum(out=fld), apply_func) v2 = g.get_n_repr()[fld] g.set_n_repr({fld : v1}) g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e2'), - fn.sum(out=fld), apply_func, batchable=True) + fn.sum(out=fld), apply_func) v3 = g.get_n_repr()[fld] g.set_n_repr({fld : v1}) g.send_and_recv(u, v, message_func_edge, - reduce_func, apply_func, batchable=True) + reduce_func, apply_func) v4 = g.get_n_repr()[fld] assert th.allclose(v2, v3) assert th.allclose(v3, v4) @@ -127,18 +127,18 @@ def reduce_func(hv, msgs): # update all, mix of builtin and UDF g.update_all([fn.copy_src(src=fld, out='m1'), message_func], [fn.sum(msgs='m1', out='v1'), reduce_func], - None, batchable=True) + None) v1 = g.get_n_repr()['v1'] v2 = g.get_n_repr()['v2'] assert th.allclose(v1, v2) # run builtin with single message and reduce - g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None, batchable=True) + g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None) v1 = g.get_n_repr()['v1'] assert th.allclose(v1, v2) # 1 message, 2 reduces, using anonymous repr - g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True) + g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None) v2 = g.get_n_repr()['v2'] v3 = g.get_n_repr()['v3'] assert th.allclose(v1, v2) @@ -147,7 +147,7 @@ def reduce_func(hv, msgs): # update all with edge weights, 2 message, 3 reduces g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')], [fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')], - None, batchable=True) + None) v1 = g.get_n_repr()['v1'] v2 = g.get_n_repr()['v2'] v3 = g.get_n_repr()['v3'] @@ -155,7 +155,7 @@ def reduce_func(hv, msgs): assert th.allclose(v1, v3) # run UDF with single message and reduce - g.update_all(message_func_edge, reduce_func, None, batchable=True) + g.update_all(message_func_edge, reduce_func, None) v2 = g.get_n_repr()['v2'] assert th.allclose(v1, v2) @@ -179,19 +179,19 @@ def reduce_func(hv, msgs): g.send_and_recv(u, v, [fn.copy_src(src=fld, out='m1'), message_func], [fn.sum(msgs='m1', out='v1'), reduce_func], - None, batchable=True) + None) v1 = g.get_n_repr()['v1'] v2 = g.get_n_repr()['v2'] assert th.allclose(v1, v2) # run builtin with single message and reduce g.send_and_recv(u, v, fn.copy_src(src=fld), fn.sum(out='v1'), - None, batchable=True) + None) v1 = g.get_n_repr()['v1'] assert th.allclose(v1, v2) # 1 message, 2 reduces, using anonymous repr - g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True) + g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None) v2 = g.get_n_repr()['v2'] v3 = g.get_n_repr()['v3'] assert th.allclose(v1, v2) @@ -201,7 +201,7 @@ def reduce_func(hv, msgs): g.send_and_recv(u, v, [fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')], [fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')], - None, batchable=True) + None) v1 = g.get_n_repr()['v1'] v2 = g.get_n_repr()['v2'] v3 = g.get_n_repr()['v3'] @@ -210,7 +210,7 @@ def reduce_func(hv, msgs): # run UDF with single message and reduce g.send_and_recv(u, v, message_func_edge, - reduce_func, None, batchable=True) + reduce_func, None) v2 = g.get_n_repr()['v2'] assert th.allclose(v1, v2) diff --git a/tests/test_anonymous_repr.py b/tests/test_anonymous_repr.py deleted file mode 100644 index 90d5cf6a5695..000000000000 --- a/tests/test_anonymous_repr.py +++ /dev/null @@ -1,62 +0,0 @@ -from dgl import DGLGraph -from dgl.graph import __REPR__ - -def message_func(hu, e_uv): - return hu + e_uv - -def reduce_func(h, msgs): - return h + sum(msgs) - -def generate_graph(): - g = DGLGraph() - for i in range(10): - g.add_node(i, __REPR__=i+1) # 10 nodes. - # create a graph where 0 is the source and 9 is the sink - for i in range(1, 9): - g.add_edge(0, i, __REPR__=1) - g.add_edge(i, 9, __REPR__=1) - # add a back flow from 9 to 0 - g.add_edge(9, 0) - return g - -def check(g, h): - nh = [str(g.nodes[i][__REPR__]) for i in range(10)] - h = [str(x) for x in h] - assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h)) - -def test_sendrecv(): - g = generate_graph() - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - g.register_message_func(message_func) - g.register_reduce_func(reduce_func) - g.send(0, 1) - g.recv(1) - check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10]) - g.send(5, 9) - g.send(6, 9) - g.recv(9) - check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25]) - -def message_func_hybrid(src, edge): - return src[__REPR__] + edge - -def reduce_func_hybrid(node, msgs): - return node[__REPR__] + sum(msgs) - -def test_hybridrepr(): - g = generate_graph() - for i in range(10): - g.nodes[i]['id'] = -i - g.register_message_func(message_func_hybrid) - g.register_reduce_func(reduce_func_hybrid) - g.send(0, 1) - g.recv(1) - check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10]) - g.send(5, 9) - g.send(6, 9) - g.recv(9) - check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25]) - -if __name__ == '__main__': - test_sendrecv() - test_hybridrepr() diff --git a/tests/test_basics.py b/tests/test_basics.py deleted file mode 100644 index 800955f06b7f..000000000000 --- a/tests/test_basics.py +++ /dev/null @@ -1,111 +0,0 @@ -from dgl.graph import DGLGraph - -def message_func(src, edge): - return src['h'] - -def reduce_func(node, msgs): - return {'m' : sum(msgs)} - -def apply_func(node): - return {'h' : node['h'] + node['m']} - -def message_dict_func(src, edge): - return {'m' : src['h']} - -def reduce_dict_func(node, msgs): - return {'m' : sum([msg['m'] for msg in msgs])} - -def apply_dict_func(node): - return {'h' : node['h'] + node['m']} - -def generate_graph(): - g = DGLGraph() - for i in range(10): - g.add_node(i, h=i+1) # 10 nodes. - # create a graph where 0 is the source and 9 is the sink - for i in range(1, 9): - g.add_edge(0, i) - g.add_edge(i, 9) - # add a back flow from 9 to 0 - g.add_edge(9, 0) - return g - -def check(g, h): - nh = [str(g.nodes[i]['h']) for i in range(10)] - h = [str(x) for x in h] - assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h)) - -def register1(g): - g.register_message_func(message_func) - g.register_reduce_func(reduce_func) - g.register_apply_node_func(apply_func) - -def register2(g): - g.register_message_func(message_dict_func) - g.register_reduce_func(reduce_dict_func) - g.register_apply_node_func(apply_dict_func) - -def _test_sendrecv(g): - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - g.send(0, 1) - g.recv(1) - check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10]) - g.send(5, 9) - g.send(6, 9) - g.recv(9) - check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23]) - -def _test_multi_sendrecv(g): - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - # one-many - g.send(0, [1, 2, 3]) - g.recv([1, 2, 3]) - check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10]) - # many-one - g.send([6, 7, 8], 9) - g.recv(9) - check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34]) - # many-many - g.send([0, 0, 4, 5], [4, 5, 9, 9]) - g.recv([4, 5, 9]) - check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45]) - -def _test_update_routines(g): - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - g.send_and_recv(0, 1) - check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10]) - g.pull(9) - check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 55]) - g.push(0) - check(g, [1, 4, 4, 5, 6, 7, 8, 9, 10, 55]) - g.update_all() - check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108]) - -def test_sendrecv(): - g = generate_graph() - register1(g) - _test_sendrecv(g) - g = generate_graph() - register2(g) - _test_sendrecv(g) - -def test_multi_sendrecv(): - g = generate_graph() - register1(g) - _test_multi_sendrecv(g) - g = generate_graph() - register2(g) - _test_multi_sendrecv(g) - -def test_update_routines(): - g = generate_graph() - register1(g) - _test_update_routines(g) - g = generate_graph() - register2(g) - _test_update_routines(g) - -if __name__ == '__main__': - test_sendrecv() - test_multi_sendrecv() - test_update_routines() diff --git a/tests/test_basics2.py b/tests/test_basics2.py deleted file mode 100644 index 90b039a23992..000000000000 --- a/tests/test_basics2.py +++ /dev/null @@ -1,74 +0,0 @@ -from dgl import DGLGraph -from dgl.graph import __REPR__ - -def message_func(hu, e_uv): - return hu - -def message_not_called(hu, e_uv): - assert False - return hu - -def reduce_not_called(h, msgs): - assert False - return 0 - -def reduce_func(h, msgs): - return h + sum(msgs) - -def check(g, h): - nh = [str(g.nodes[i][__REPR__]) for i in range(10)] - h = [str(x) for x in h] - assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h)) - -def generate_graph(): - g = DGLGraph() - for i in range(10): - g.add_node(i, __REPR__=i+1) # 10 nodes. - # create a graph where 0 is the source and 9 is the sink - for i in range(1, 9): - g.add_edge(0, i) - g.add_edge(i, 9) - return g - -def test_no_msg_recv(): - g = generate_graph() - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - g.register_message_func(message_not_called) - g.register_reduce_func(reduce_not_called) - g.register_apply_node_func(lambda h : h + 1) - for i in range(10): - g.recv(i) - check(g, [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) - -def test_double_recv(): - g = generate_graph() - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - g.register_message_func(message_func) - g.register_reduce_func(reduce_func) - g.send(1, 9) - g.send(2, 9) - g.recv(9) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 15]) - g.register_reduce_func(reduce_not_called) - g.recv(9) - -def test_pull_0deg(): - g = DGLGraph() - g.add_node(0, h=2) - g.add_node(1, h=1) - g.add_edge(0, 1) - def _message(src, edge): - assert False - return src - def _reduce(node, msgs): - assert False - return node - def _update(node): - return {'h': node['h'] * 2} - g.pull(0, _message, _reduce, _update) - assert g.nodes[0]['h'] == 4 - -if __name__ == '__main__': - test_no_msg_recv() - test_double_recv() - test_pull_0deg() diff --git a/tests/test_function.py b/tests/test_function.py deleted file mode 100644 index 0cca9eaa35a8..000000000000 --- a/tests/test_function.py +++ /dev/null @@ -1,141 +0,0 @@ -import dgl -import dgl.function as fn -from dgl.graph import __REPR__ - -def generate_graph(): - g = dgl.DGLGraph() - for i in range(10): - g.add_node(i, h=i+1) # 10 nodes. - # create a graph where 0 is the source and 9 is the sink - for i in range(1, 9): - g.add_edge(0, i, h=1) - g.add_edge(i, 9, h=i+1) - # add a back flow from 9 to 0 - g.add_edge(9, 0, h=10) - return g - -def check(g, h, fld): - nh = [str(g.nodes[i][fld]) for i in range(10)] - h = [str(x) for x in h] - assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h)) - -def generate_graph1(): - """graph with anonymous repr""" - g = dgl.DGLGraph() - for i in range(10): - g.add_node(i, __REPR__=i+1) # 10 nodes. - # create a graph where 0 is the source and 9 is the sink - for i in range(1, 9): - g.add_edge(0, i, __REPR__=1) - g.add_edge(i, 9, __REPR__=i+1) - # add a back flow from 9 to 0 - g.add_edge(9, 0, __REPR__=10) - return g - -def test_copy_src(): - # copy_src with both fields - g = generate_graph() - g.register_message_func(fn.copy_src(src='h', out='m'), batchable=False) - g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h') - g.update_all() - check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h') - - # copy_src with only src field; the out field should use anonymous repr - g = generate_graph() - g.register_message_func(fn.copy_src(src='h'), batchable=False) - g.register_reduce_func(fn.sum(out='h'), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h') - g.update_all() - check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h') - - # copy_src with no src field; should use anonymous repr - g = generate_graph1() - g.register_message_func(fn.copy_src(out='m'), batchable=False) - g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__) - g.update_all() - check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h') - - # copy src with no fields; - g = generate_graph1() - g.register_message_func(fn.copy_src(), batchable=False) - g.register_reduce_func(fn.sum(out='h'), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__) - g.update_all() - check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h') - -def test_copy_edge(): - # copy_edge with both fields - g = generate_graph() - g.register_message_func(fn.copy_edge(edge='h', out='m'), batchable=False) - g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h') - g.update_all() - check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h') - - # copy_edge with only edge field; the out field should use anonymous repr - g = generate_graph() - g.register_message_func(fn.copy_edge(edge='h'), batchable=False) - g.register_reduce_func(fn.sum(out='h'), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h') - g.update_all() - check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h') - - # copy_edge with no edge field; should use anonymous repr - g = generate_graph1() - g.register_message_func(fn.copy_edge(out='m'), batchable=False) - g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__) - g.update_all() - check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h') - - # copy edge with no fields; - g = generate_graph1() - g.register_message_func(fn.copy_edge(), batchable=False) - g.register_reduce_func(fn.sum(out='h'), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__) - g.update_all() - check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h') - -def test_src_mul_edge(): - # src_mul_edge with all fields - g = generate_graph() - g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'), batchable=False) - g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h') - g.update_all() - check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h') - - g = generate_graph() - g.register_message_func(fn.src_mul_edge(src='h', edge='h'), batchable=False) - g.register_reduce_func(fn.sum(out='h'), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h') - g.update_all() - check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h') - - g = generate_graph1() - g.register_message_func(fn.src_mul_edge(out='m'), batchable=False) - g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__) - g.update_all() - check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h') - - g = generate_graph1() - g.register_message_func(fn.src_mul_edge(), batchable=False) - g.register_reduce_func(fn.sum(out='h'), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__) - g.update_all() - check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h') - - g = generate_graph1() - g.register_message_func(fn.src_mul_edge(), batchable=False) - g.register_reduce_func(fn.sum(), batchable=False) - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__) - g.update_all() - check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], __REPR__) - -if __name__ == '__main__': - test_copy_src() - test_copy_edge() - test_src_mul_edge()