From 6105e441426f97f31d96c54d6f35830028c2b3f6 Mon Sep 17 00:00:00 2001 From: Minjie Wang Date: Wed, 22 Aug 2018 12:25:05 -0400 Subject: [PATCH] Many fix and updates (#47) * subgraph copy from * WIP * cached members * Change all usage of id tensor to the new Index object; remove set device in DGLGraph; * subgraph merge API tested * add dict type reduced msg test --- examples/pytorch/gat/gat_batch.py | 2 - examples/pytorch/gcn/gcn_batch.py | 2 - examples/pytorch/gcn/gcn_spmv.py | 2 - examples/pytorch/generative_graph/model.py | 1 - python/dgl/backend/pytorch.py | 7 + python/dgl/batch.py | 4 +- python/dgl/builtin.py | 5 +- python/dgl/cached_graph.py | 63 +++--- python/dgl/context.py | 6 + python/dgl/frame.py | 126 +++++++++--- python/dgl/graph.py | 223 +++++++++++---------- python/dgl/scheduler.py | 26 ++- python/dgl/subgraph.py | 66 +++--- python/dgl/utils.py | 172 ++++++++++++++-- tests/pytorch/test_batching.py | 31 ++- tests/pytorch/test_cached_graph.py | 23 +-- tests/pytorch/test_frame.py | 19 +- tests/{ => pytorch}/test_graph_batch.py | 0 tests/pytorch/test_subgraph.py | 70 +++++-- tests/test_basics.py | 54 +++-- 20 files changed, 610 insertions(+), 292 deletions(-) rename tests/{ => pytorch}/test_graph_batch.py (100%) diff --git a/examples/pytorch/gat/gat_batch.py b/examples/pytorch/gat/gat_batch.py index fc5d7e3637d4..e6bc69010e5c 100644 --- a/examples/pytorch/gat/gat_batch.py +++ b/examples/pytorch/gat/gat_batch.py @@ -154,8 +154,6 @@ def main(args): # create GCN model g = DGLGraph(data.graph) - if cuda: - g.set_device(dgl.gpu(args.gpu)) # create model model = GAT(g, diff --git a/examples/pytorch/gcn/gcn_batch.py b/examples/pytorch/gcn/gcn_batch.py index e1297cdcb7a5..166bfa223b10 100644 --- a/examples/pytorch/gcn/gcn_batch.py +++ b/examples/pytorch/gcn/gcn_batch.py @@ -85,8 +85,6 @@ def main(args): # create GCN model g = DGLGraph(data.graph) - if cuda: - g.set_device(dgl.gpu(args.gpu)) model = GCN(g, in_feats, args.n_hidden, diff --git 
a/examples/pytorch/gcn/gcn_spmv.py b/examples/pytorch/gcn/gcn_spmv.py index 68e335a12da0..fa2c8e985165 100644 --- a/examples/pytorch/gcn/gcn_spmv.py +++ b/examples/pytorch/gcn/gcn_spmv.py @@ -79,8 +79,6 @@ def main(args): # create GCN model g = DGLGraph(data.graph) - if cuda: - g.set_device(dgl.gpu(args.gpu)) model = GCN(g, in_feats, args.n_hidden, diff --git a/examples/pytorch/generative_graph/model.py b/examples/pytorch/generative_graph/model.py index c02674142d4d..a8216575ac44 100644 --- a/examples/pytorch/generative_graph/model.py +++ b/examples/pytorch/generative_graph/model.py @@ -213,7 +213,6 @@ def masked_cross_entropy(x, label, mask=None): count, label, node_list, mask, active, label1, label1_tensor = ground_truth[0] label, node_list, mask, label1_tensor = move2cuda((label, node_list, mask, label1_tensor)) ground_truth[0] = (count, label, node_list, mask, active, label1, label1_tensor) - ground_truth[1][0].set_device(dgl.gpu(args.gpu)) optimizer.zero_grad() # create new empty graphs diff --git a/python/dgl/backend/pytorch.py b/python/dgl/backend/pytorch.py index caca05c73667..91303ed1ef2c 100644 --- a/python/dgl/backend/pytorch.py +++ b/python/dgl/backend/pytorch.py @@ -2,6 +2,7 @@ import torch as th import scipy.sparse +import dgl.context as context # Tensor types Tensor = th.Tensor @@ -73,3 +74,9 @@ def to_context(x, ctx): return x.cpu() else: raise RuntimeError('Invalid context', ctx) + +def get_context(x): + if x.device.type == 'cpu': + return context.cpu() + else: + return context.gpu(x.device.index) diff --git a/python/dgl/batch.py b/python/dgl/batch.py index 522f1a6c1744..5197a0ac3ea8 100644 --- a/python/dgl/batch.py +++ b/python/dgl/batch.py @@ -87,7 +87,7 @@ def unbatch(graph_batch): num_graphs = len(graph_list) # split and set node attrs attrs = [{} for _ in range(num_graphs)] # node attr dict for each graph - for key in graph_batch.get_n_attr_list(): + for key in graph_batch.node_attr_schemes(): vals = F.unpack(graph_batch.pop_n_repr(key), 
graph_batch.num_nodes) for attr, val in zip(attrs, vals): attr[key] = val @@ -96,7 +96,7 @@ def unbatch(graph_batch): # split and set edge attrs attrs = [{} for _ in range(num_graphs)] # edge attr dict for each graph - for key in graph_batch.get_e_attr_list(): + for key in graph_batch.edge_attr_schemes(): vals = F.unpack(graph_batch.pop_e_repr(key), graph_batch.num_edges) for attr, val in zip(attrs, vals): attr[key] = val diff --git a/python/dgl/builtin.py b/python/dgl/builtin.py index f838b2e97408..4201d7e1b6c0 100644 --- a/python/dgl/builtin.py +++ b/python/dgl/builtin.py @@ -8,7 +8,10 @@ def message_from_src(src, edge): def reduce_sum(node, msgs): if isinstance(msgs, list): - return sum(msgs) + if isinstance(msgs[0], dict): + return {k : sum(m[k] for m in msgs) for k in msgs[0].keys()} + else: + return sum(msgs) else: return F.sum(msgs, 1) diff --git a/python/dgl/cached_graph.py b/python/dgl/cached_graph.py index 2db53ceb05b5..5248de915e20 100644 --- a/python/dgl/cached_graph.py +++ b/python/dgl/cached_graph.py @@ -14,15 +14,21 @@ class CachedGraph: def __init__(self): self._graph = igraph.Graph(directed=True) - self._adjmat = None # cached adjacency matrix + self._freeze = False def add_nodes(self, num_nodes): + if self._freeze: + raise RuntimeError('Freezed cached graph cannot be mutated.') self._graph.add_vertices(num_nodes) def add_edge(self, u, v): + if self._freeze: + raise RuntimeError('Freezed cached graph cannot be mutated.') self._graph.add_edge(u, v) def add_edges(self, u, v): + if self._freeze: + raise RuntimeError('Freezed cached graph cannot be mutated.') # The edge will be assigned ids equal to the order. 
uvs = list(utils.edge_iter(u, v)) self._graph.add_edges(uvs) @@ -30,7 +36,7 @@ def add_edges(self, u, v): def get_edge_id(self, u, v): uvs = list(utils.edge_iter(u, v)) eids = self._graph.get_eids(uvs) - return utils.convert_to_id_tensor(eids) + return utils.toindex(eids) def in_edges(self, v): src = [] @@ -39,8 +45,8 @@ def in_edges(self, v): uu = self._graph.predecessors(vv) src += uu dst += [vv] * len(uu) - src = utils.convert_to_id_tensor(src) - dst = utils.convert_to_id_tensor(dst) + src = utils.toindex(src) + dst = utils.toindex(dst) return src, dst def out_edges(self, u): @@ -50,44 +56,51 @@ def out_edges(self, u): vv = self._graph.successors(uu) src += [uu] * len(vv) dst += vv - src = utils.convert_to_id_tensor(src) - dst = utils.convert_to_id_tensor(dst) + src = utils.toindex(src) + dst = utils.toindex(dst) return src, dst + def in_degrees(self, v): + degs = self._graph.indegree(list(v)) + return utils.toindex(degs) + + def num_edges(self): + return self._graph.ecount() + + @utils.cached_member def edges(self): elist = self._graph.get_edgelist() src = [u for u, _ in elist] dst = [v for _, v in elist] - src = utils.convert_to_id_tensor(src) - dst = utils.convert_to_id_tensor(dst) + src = utils.toindex(src) + dst = utils.toindex(dst) return src, dst - def in_degrees(self, v): - degs = self._graph.indegree(list(v)) - return utils.convert_to_id_tensor(degs) - + @utils.ctx_cached_member def adjmat(self, ctx): """Return a sparse adjacency matrix. The row dimension represents the dst nodes; the column dimension represents the src nodes. 
""" - if self._adjmat is None: - elist = self._graph.get_edgelist() - src = [u for u, _ in elist] - dst = [v for _, v in elist] - src = F.unsqueeze(utils.convert_to_id_tensor(src), 0) - dst = F.unsqueeze(utils.convert_to_id_tensor(dst), 0) - idx = F.pack([dst, src]) - n = self._graph.vcount() - dat = F.ones((len(elist),)) - self._adjmat = F.sparse_tensor(idx, dat, [n, n]) - # TODO(minjie): manually convert adjmat to context - self._adjmat = F.to_context(self._adjmat, ctx) - return self._adjmat + elist = self._graph.get_edgelist() + src = F.tensor([u for u, _ in elist], dtype=F.int64) + dst = F.tensor([v for _, v in elist], dtype=F.int64) + src = F.unsqueeze(src, 0) + dst = F.unsqueeze(dst, 0) + idx = F.pack([dst, src]) + n = self._graph.vcount() + dat = F.ones((len(elist),)) + mat = F.sparse_tensor(idx, dat, [n, n]) + mat = F.to_context(mat, ctx) + return mat + + def freeze(self): + self._freeze = True def create_cached_graph(dglgraph): cg = CachedGraph() cg.add_nodes(dglgraph.number_of_nodes()) cg._graph.add_edges(dglgraph.edge_list) + cg.freeze() return cg diff --git a/python/dgl/context.py b/python/dgl/context.py index 5ebc8d1ed84a..71103dc41f0e 100644 --- a/python/dgl/context.py +++ b/python/dgl/context.py @@ -8,6 +8,12 @@ def __init__(self, dev, devid=-1): def __str__(self): return '{}:{}'.format(self.device, self.device_id) + def __eq__(self, other): + return self.device == other.device and self.device_id == other.device_id + + def __hash__(self): + return hash((self.device, self.device_id)) + def gpu(gpuid): return Context('gpu', gpuid) diff --git a/python/dgl/frame.py b/python/dgl/frame.py index 252376301954..fedc25302f42 100644 --- a/python/dgl/frame.py +++ b/python/dgl/frame.py @@ -6,7 +6,7 @@ import dgl.backend as F from dgl.backend import Tensor -from dgl.utils import LazyDict +import dgl.utils as utils class Frame(MutableMapping): def __init__(self, data=None): @@ -77,15 +77,24 @@ def __len__(self): return self.num_columns class 
FrameRef(MutableMapping): + """Frame reference + + Parameters + ---------- + frame : dgl.frame.Frame + The underlying frame. + index : iterable of int + The rows that are referenced in the underlying frame. + """ def __init__(self, frame=None, index=None): self._frame = frame if frame is not None else Frame() if index is None: - self._index = slice(0, self._frame.num_rows) + self._index_data = slice(0, self._frame.num_rows) else: - # check no duplicate index + # check no duplication assert len(index) == len(np.unique(index)) - self._index = index - self._index_tensor = None + self._index_data = index + self._index = None @property def schemes(self): @@ -97,10 +106,10 @@ def num_columns(self): @property def num_rows(self): - if isinstance(self._index, slice): - return self._index.stop + if isinstance(self._index_data, slice): + return self._index_data.stop else: - return len(self._index) + return len(self._index_data) def __contains__(self, key): return key in self._frame @@ -114,15 +123,17 @@ def __getitem__(self, key): def select_rows(self, query): rowids = self._getrowid(query) def _lazy_select(key): - return F.gather_row(self._frame[key], rowids) - return LazyDict(_lazy_select, keys=self.schemes) + idx = rowids.totensor(F.get_context(self._frame[key])) + return F.gather_row(self._frame[key], idx) + return utils.LazyDict(_lazy_select, keys=self.schemes) def get_column(self, name): col = self._frame[name] if self.is_span_whole_column(): return col else: - return F.gather_row(col, self.index_tensor()) + idx = self.index().totensor(F.get_context(col)) + return F.gather_row(col, idx) def __setitem__(self, key, val): if isinstance(key, str): @@ -134,22 +145,26 @@ def add_column(self, name, col): shp = F.shape(col) if self.is_span_whole_column(): if self.num_columns == 0: - self._index = slice(0, shp[0]) + self._index_data = slice(0, shp[0]) self._clear_cache() assert shp[0] == self.num_rows self._frame[name] = col else: + colctx = F.get_context(col) if name in 
self._frame: fcol = self._frame[name] else: fcol = F.zeros((self._frame.num_rows,) + shp[1:]) - newfcol = F.scatter_row(fcol, self.index_tensor(), col) + fcol = F.to_context(fcol, colctx) + idx = self.index().totensor(colctx) + newfcol = F.scatter_row(fcol, idx, col) self._frame[name] = newfcol def update_rows(self, query, other): rowids = self._getrowid(query) for key, col in other.items(): - self._frame[key] = F.scatter_row(self._frame[key], rowids, col) + idx = rowids.totensor(F.get_context(self._frame[key])) + self._frame[key] = F.scatter_row(self._frame[key], idx, col) def __delitem__(self, key): if isinstance(key, str): @@ -161,10 +176,10 @@ def __delitem__(self, key): def delete_rows(self, query): query = F.asnumpy(query) - if isinstance(self._index, slice): - self._index = list(range(self._index.start, self._index.stop)) - arr = np.array(self._index, dtype=np.int32) - self._index = list(np.delete(arr, query)) + if isinstance(self._index_data, slice): + self._index_data = list(range(self._index_data.start, self._index_data.stop)) + arr = np.array(self._index_data, dtype=np.int32) + self._index_data = list(np.delete(arr, query)) self._clear_cache() def append(self, other): @@ -174,16 +189,16 @@ def append(self, other): self._frame.append(other) # update index if span_whole: - self._index = slice(0, self._frame.num_rows) - else: - new_idx = list(range(self._index.start, self._index.stop)) + self._index_data = slice(0, self._frame.num_rows) + elif contiguous: + new_idx = list(range(self._index_data.start, self._index_data.stop)) new_idx += list(range(old_nrows, self._frame.num_rows)) - self._index = new_idx + self._index_data = new_idx self._clear_cache() def clear(self): self._frame.clear() - self._index = slice(0, 0) + self._index_data = slice(0, 0) self._clear_cache() def __iter__(self): @@ -194,26 +209,73 @@ def __len__(self): def is_contiguous(self): # NOTE: this check could have false negative - return isinstance(self._index, slice) + return 
isinstance(self._index_data, slice) def is_span_whole_column(self): return self.is_contiguous() and self.num_rows == self._frame.num_rows def _getrowid(self, query): - if isinstance(self._index, slice): + if self.is_contiguous(): # shortcut for identical mapping return query else: - return F.gather_row(self.index_tensor(), query) + idxtensor = self.index().totensor() + return utils.toindex(F.gather_row(idxtensor, query.totensor())) - def index_tensor(self): - # TODO(minjie): context - if self._index_tensor is None: + def index(self): + if self._index is None: if self.is_contiguous(): - self._index_tensor = F.arange(self._index.stop, dtype=F.int64) + self._index = utils.toindex( + F.arange(self._index_data.stop, dtype=F.int64)) else: - self._index_tensor = F.tensor(self._index, dtype=F.int64) - return self._index_tensor + self._index = utils.toindex(self._index_data) + return self._index def _clear_cache(self): self._index_tensor = None + +def merge_frames(frames, indices, max_index, reduce_func): + """Merge a list of frames. + + The result frame contains `max_index` number of rows. For each frame in + the given list, its row is merged as follows: + + merged[indices[i][row]] += frames[i][row] + + Parameters + ---------- + frames : iterator of dgl.frame.FrameRef + A list of frames to be merged. + indices : iterator of dgl.utils.Index + The indices of the frame rows. + reduce_func : str + The reduce function (only 'sum' is supported currently) + + Returns + ------- + merged : FrameRef + The merged frame. + """ + assert reduce_func == 'sum' + assert len(frames) > 0 + schemes = frames[0].schemes + # create an adj to merge + # row index is equal to the concatenation of all the indices. 
+ row = sum([idx.tolist() for idx in indices], []) + col = list(range(len(row))) + n = max_index + m = len(row) + row = F.unsqueeze(F.tensor(row, dtype=F.int64), 0) + col = F.unsqueeze(F.tensor(col, dtype=F.int64), 0) + idx = F.pack([row, col]) + dat = F.ones((m,)) + adjmat = F.sparse_tensor(idx, dat, [n, m]) + ctx_adjmat = utils.CtxCachedObject(lambda ctx: F.to_context(adjmat, ctx)) + merged = {} + for key in schemes: + # the rhs of the spmv is the concatenation of all the frame columns + feats = F.pack([fr[key] for fr in frames]) + merged_feats = F.spmm(ctx_adjmat.get(F.get_context(feats)), feats) + merged[key] = merged_feats + merged = FrameRef(Frame(merged)) + return merged diff --git a/python/dgl/graph.py b/python/dgl/graph.py index c714f0519ea8..a2349527a93f 100644 --- a/python/dgl/graph.py +++ b/python/dgl/graph.py @@ -12,7 +12,7 @@ import dgl.builtin as builtin from dgl.cached_graph import CachedGraph, create_cached_graph import dgl.context as context -from dgl.frame import FrameRef +from dgl.frame import FrameRef, merge_frames from dgl.nx_adapt import nx_init import dgl.scheduler as scheduler import dgl.utils as utils @@ -62,12 +62,11 @@ def __init__(self, self._reduce_func = None self._update_func = None self._edge_func = None - self._context = context.cpu() - def get_n_attr_list(self): + def node_attr_schemes(self): return self._node_frame.schemes - def get_e_attr_list(self): + def edge_attr_schemes(self): return self._edge_frame.schemes def set_n_repr(self, hu, u=ALL): @@ -92,7 +91,7 @@ def set_n_repr(self, hu, u=ALL): if is_all(u): num_nodes = self.number_of_nodes() else: - u = utils.convert_to_id_tensor(u, self.context) + u = utils.toindex(u) num_nodes = len(u) if isinstance(hu, dict): for key, val in hu.items(): @@ -108,10 +107,9 @@ def set_n_repr(self, hu, u=ALL): self._node_frame[__REPR__] = hu else: if isinstance(hu, dict): - for key, val in hu.items(): - self._node_frame[key] = F.scatter_row(self._node_frame[key], u, val) + self._node_frame[u] = 
hu else: - self._node_frame[__REPR__] = F.scatter_row(self._node_frame[__REPR__], u, hu) + self._node_frame[u] = {__REPR__ : hu} def get_n_repr(self, u=ALL): """Get node(s) representation. @@ -127,9 +125,9 @@ def get_n_repr(self, u=ALL): else: return dict(self._node_frame) else: - u = utils.convert_to_id_tensor(u, self.context) + u = utils.toindex(u) if len(self._node_frame) == 1 and __REPR__ in self._node_frame: - return self._node_frame[__REPR__][u] + return self._node_frame.select_rows(u)[__REPR__] else: return self._node_frame.select_rows(u) @@ -168,10 +166,10 @@ def set_e_repr(self, h_uv, u=ALL, v=ALL): v_is_all = is_all(v) assert u_is_all == v_is_all if u_is_all: - num_edges = self.number_of_edges() + num_edges = self.cached_graph.num_edges() else: - u = utils.convert_to_id_tensor(u, self.context) - v = utils.convert_to_id_tensor(v, self.context) + u = utils.toindex(u) + v = utils.toindex(v) num_edges = max(len(u), len(v)) if isinstance(h_uv, dict): for key, val in h_uv.items(): @@ -188,10 +186,9 @@ def set_e_repr(self, h_uv, u=ALL, v=ALL): else: eid = self.cached_graph.get_edge_id(u, v) if isinstance(h_uv, dict): - for key, val in h_uv.items(): - self._edge_frame[key] = F.scatter_row(self._edge_frame[key], eid, val) + self._edge_frame[eid] = h_uv else: - self._edge_frame[__REPR__] = F.scatter_row(self._edge_frame[__REPR__], eid, h_uv) + self._edge_frame[eid] = {__REPR__ : h_uv} def set_e_repr_by_id(self, h_uv, eid=ALL): """Set edge(s) representation by edge id. 
@@ -205,9 +202,9 @@ def set_e_repr_by_id(self, h_uv, eid=ALL): """ # sanity check if is_all(eid): - num_edges = self.number_of_edges() + num_edges = self.cached_graph.num_edges() else: - eid = utils.convert_to_id_tensor(eid, self.context) + eid = utils.toindex(eid) num_edges = len(eid) if isinstance(h_uv, dict): for key, val in h_uv.items(): @@ -223,10 +220,9 @@ def set_e_repr_by_id(self, h_uv, eid=ALL): self._edge_frame[__REPR__] = h_uv else: if isinstance(h_uv, dict): - for key, val in h_uv.items(): - self._edge_frame[key] = F.scatter_row(self._edge_frame[key], eid, val) + self._edge_frame[eid] = h_uv else: - self._edge_frame[__REPR__] = F.scatter_row(self._edge_frame[__REPR__], eid, h_uv) + self._edge_frame[eid] = {__REPR__ : h_uv} def get_e_repr(self, u=ALL, v=ALL): """Get node(s) representation. @@ -247,11 +243,11 @@ def get_e_repr(self, u=ALL, v=ALL): else: return dict(self._edge_frame) else: - u = utils.convert_to_id_tensor(u, self.context) - v = utils.convert_to_id_tensor(v, self.context) + u = utils.toindex(u) + v = utils.toindex(v) eid = self.cached_graph.get_edge_id(u, v) if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: - return self._edge_frame[__REPR__][eid] + return self._edge_frame.select_rows(eid)[__REPR__] else: return self._edge_frame.select_rows(eid) @@ -279,27 +275,12 @@ def get_e_repr_by_id(self, eid=ALL): else: return dict(self._edge_frame) else: - eid = utils.convert_to_id_tensor(eid, self.context) + eid = utils.toindex(eid) if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: - return self._edge_frame[__REPR__][eid] + return self._edge_frame.select_rows(eid)[__REPR__] else: return self._edge_frame.select_rows(eid) - def set_device(self, ctx): - """Set device context for this graph. - - Parameters - ---------- - ctx : dgl.context.Context - The device context. 
- """ - self._context = ctx - - @property - def context(self): - """Get the device context of this graph.""" - return self._context - def register_message_func(self, message_func, batchable=False): @@ -356,27 +337,6 @@ def register_update_func(self, """ self._update_func = (update_func, batchable) - def readout(self, - readout_func, - nodes=ALL, - edges=ALL): - """Trigger the readout function on the specified nodes/edges. - - Parameters - ---------- - readout_func : callable - Readout function. - nodes : str, node, container or tensor - The nodes to get reprs from. - edges : str, pair of nodes, pair of containers or pair of tensors - The edges to get reprs from. - """ - nodes = self._nodes_or_all(nodes) - edges = self._edges_or_all(edges) - nstates = [self.nodes[n] for n in nodes] - estates = [self.edges[e] for e in edges] - return readout_func(nstates, estates) - def sendto(self, u, v, message_func=None, batchable=False): """Trigger the message function on edge u->v @@ -413,6 +373,9 @@ def _nonbatch_sendto(self, u, v, message_func): f_msg = _get_message_func(message_func) if is_all(u) and is_all(v): u, v = self.cached_graph.edges() + else: + u = utils.toindex(u) + v = utils.toindex(v) for uu, vv in utils.edge_iter(u, v): ret = f_msg(_get_repr(self.nodes[uu]), _get_repr(self.edges[uu, vv])) @@ -428,8 +391,8 @@ def _batch_sendto(self, u, v, message_func): edge_reprs = self.get_e_repr() msgs = message_func(src_reprs, edge_reprs) else: - u = utils.convert_to_id_tensor(u) - v = utils.convert_to_id_tensor(v) + u = utils.toindex(u) + v = utils.toindex(v) u, v = utils.edge_broadcasting(u, v) eid = self.cached_graph.get_edge_id(u, v) self.msg_graph.add_edges(u, v) @@ -475,6 +438,9 @@ def update_edge(self, u, v, edge_func=None, batchable=False): def _nonbatch_update_edge(self, u, v, edge_func): if is_all(u) and is_all(v): u, v = self.cached_graph.edges() + else: + u = utils.toindex(u) + v = utils.toindex(v) for uu, vv in utils.edge_iter(u, v): ret = 
edge_func(_get_repr(self.nodes[uu]), _get_repr(self.nodes[vv]), @@ -491,8 +457,8 @@ def _batch_update_edge(self, u, v, edge_func): new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs) self.set_e_repr(new_edge_reprs) else: - u = utils.convert_to_id_tensor(u) - v = utils.convert_to_id_tensor(v) + u = utils.toindex(u) + v = utils.toindex(v) u, v = utils.edge_broadcasting(u, v) eid = self.cached_graph.get_edge_id(u, v) # call the UDF @@ -559,6 +525,8 @@ def _nonbatch_recv(self, u, reduce_func, update_func): f_update = update_func if is_all(u): u = list(range(0, self.number_of_nodes())) + else: + u = utils.toindex(u) for i, uu in enumerate(utils.node_iter(u)): # reduce phase msgs_batch = [self.edges[vv, uu].pop(__MSG__) @@ -586,14 +554,16 @@ def _batch_recv(self, v, reduce_func, update_func): new_ns = f_update(reordered_ns, all_reduced_msgs) if is_all(v): # First do reorder and then replace the whole column. - _, indices = F.sort(reordered_v) - # TODO(minjie): manually convert ids to context. - indices = F.to_context(indices, self.context) + _, indices = F.sort(reordered_v.totensor()) + indices = utils.toindex(indices) + # TODO(minjie): following code should be included in Frame somehow. if isinstance(new_ns, dict): for key, val in new_ns.items(): - self._node_frame[key] = F.gather_row(val, indices) + idx = indices.totensor(F.get_context(val)) + self._node_frame[key] = F.gather_row(val, idx) else: - self._node_frame[__REPR__] = F.gather_row(new_ns, indices) + idx = indices.totensor(F.get_context(new_ns)) + self._node_frame[__REPR__] = F.gather_row(new_ns, idx) else: # Use setter to do reorder. 
self.set_n_repr(new_ns, reordered_v) @@ -605,9 +575,14 @@ def _batch_reduce(self, v, reduce_func): if is_all(v): v = list(range(self.number_of_nodes())) + + # freeze message graph + self.msg_graph.freeze() + # sanity checks - v = utils.convert_to_id_tensor(v) + v = utils.toindex(v) f_reduce = _get_reduce_func(reduce_func) + # degree bucketing degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v) reduced_msgs = [] @@ -617,8 +592,6 @@ def _batch_reduce(self, v, reduce_func): bkt_len = len(v_bkt) uu, vv = self.msg_graph.in_edges(v_bkt) in_msg_ids = self.msg_graph.get_edge_id(uu, vv) - # TODO(minjie): manually convert ids to context. - in_msg_ids = F.to_context(in_msg_ids, self.context) in_msgs = self._msg_frame.select_rows(in_msg_ids) # Reshape the column tensor to (B, Deg, ...). def _reshape_fn(msg): @@ -641,10 +614,14 @@ def _reshape_fn(msg): self.clear_messages() # Read the node states in the degree-bucketing order. - reordered_v = F.pack(v_buckets) + reordered_v = utils.toindex(F.pack( + [v_bkt.totensor() for v_bkt in v_buckets])) # Pack all reduced msgs together if isinstance(reduced_msgs[0], dict): - all_reduced_msgs = {key : F.pack(val) for key, val in reduced_msgs.items()} + keys = reduced_msgs[0].keys() + all_reduced_msgs = { + key : F.pack([msg[key] for msg in reduced_msgs]) + for key in keys} else: all_reduced_msgs = F.pack(reduced_msgs) @@ -697,6 +674,9 @@ def _nonbatch_update_by_edge( update_func): if is_all(u) and is_all(v): u, v = self.cached_graph.edges() + else: + u = utils.toindex(u) + v = utils.toindex(v) self._nonbatch_sendto(u, v, message_func) dst = set() for uu, vv in utils.edge_iter(u, v): @@ -713,12 +693,14 @@ def _batch_update_by_edge( self.update_all(message_func, reduce_func, update_func, True) elif message_func == 'from_src' and reduce_func == 'sum': # TODO(minjie): check the validity of edges u->v - u = utils.convert_to_id_tensor(u) - v = utils.convert_to_id_tensor(v) + u = utils.toindex(u) + v = utils.toindex(v) # 
TODO(minjie): broadcasting is optional for many-one input. u, v = utils.edge_broadcasting(u, v) # relabel destination nodes. new2old, old2new = utils.build_relabel_map(v) + u = u.totensor() + v = v.totensor() # TODO(minjie): should not directly use [] new_v = old2new[v] # create adj mat @@ -726,8 +708,8 @@ def _batch_update_by_edge( dat = F.ones((len(u),)) n = self.number_of_nodes() m = len(new2old) + # TODO(minjie): context adjmat = F.sparse_tensor(idx, dat, [m, n]) - adjmat = F.to_context(adjmat, self.context) # TODO(minjie): use lazy dict for reduced_msgs reduced_msgs = {} for key in self._node_frame.schemes: @@ -739,10 +721,10 @@ def _batch_update_by_edge( new_node_repr = update_func(node_repr, reduced_msgs) self.set_n_repr(new_node_repr, new2old) else: - u = utils.convert_to_id_tensor(u, self.context) - v = utils.convert_to_id_tensor(v, self.context) + u = utils.toindex(u) + v = utils.toindex(v) self._batch_sendto(u, v, message_func) - unique_v = F.unique(v) + unique_v = F.unique(v.totensor()) self._batch_recv(unique_v, reduce_func, update_func) def update_to(self, @@ -776,15 +758,17 @@ def update_to(self, assert reduce_func is not None assert update_func is not None if batchable: + v = utils.toindex(v) uu, vv = self.cached_graph.in_edges(v) - self.update_by_edge(uu, vv, message_func, - reduce_func, update_func, batchable) + self._batch_update_by_edge(uu, vv, message_func, + reduce_func, update_func) else: + v = utils.toindex(v) for vv in utils.node_iter(v): assert vv in self.nodes uu = list(self.pred[vv]) - self.sendto(uu, vv, message_func, batchable) - self.recv(vv, reduce_func, update_func, batchable) + self._nonbatch_sendto(uu, vv, message_func) + self._nonbatch_recv(vv, reduce_func, update_func) def update_from(self, u, @@ -817,15 +801,17 @@ def update_from(self, assert reduce_func is not None assert update_func is not None if batchable: + u = utils.toindex(u) uu, vv = self.cached_graph.out_edges(u) - self.update_by_edge(uu, vv, message_func, - 
reduce_func, update_func, batchable) + self._batch_update_by_edge(uu, vv, message_func, + reduce_func, update_func) else: + u = utils.toindex(u) for uu in utils.node_iter(u): assert uu in self.nodes for v in self.succ[uu]: - self.update_by_edge(uu, v, - message_func, reduce_func, update_func, batchable) + self._nonbatch_update_by_edge(uu, v, + message_func, reduce_func, update_func) def update_all(self, message_func=None, @@ -857,10 +843,10 @@ def update_all(self, if batchable: if message_func == 'from_src' and reduce_func == 'sum': # TODO(minjie): use lazy dict for reduced_msgs - adjmat = self.cached_graph.adjmat(self.context) reduced_msgs = {} for key in self._node_frame.schemes: col = self._node_frame[key] + adjmat = self.cached_graph.adjmat(F.get_context(col)) reduced_msgs[key] = F.spmm(adjmat, col) if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs: reduced_msgs = reduced_msgs[__REPR__] @@ -930,23 +916,52 @@ def subgraph(self, nodes): Returns ------- - G : DGLGraph + G : DGLSubGraph The subgraph. """ return dgl.DGLSubGraph(self, nodes) - def copy_from(self, graph): - """Copy node/edge features from the given graph. - - All old features will be removed. + def merge(self, subgraphs, reduce_func='sum'): + """Merge subgraph features back to this parent graph. Parameters ---------- - graph : DGLGraph - The graph to copy from. + subgraphs : iterator of DGLSubGraph + The subgraphs to be merged. + reduce_func : str + The reduce function (only 'sum' is supported currently) """ - # TODO - pass + # sanity check: all the subgraphs and the parent graph + # should have the same node/edge feature schemes. 
+ # merge node features + to_merge = [] + for sg in subgraphs: + if len(sg.node_attr_schemes()) == 0: + continue + if sg.node_attr_schemes() != self.node_attr_schemes(): + raise RuntimeError('Subgraph and parent graph do not ' + 'have the same node attribute schemes.') + to_merge.append(sg) + self._node_frame = merge_frames( + [sg._node_frame for sg in to_merge], + [sg._parent_nid for sg in to_merge], + self._node_frame.num_rows, + reduce_func) + + # merge edge features + to_merge.clear() + for sg in subgraphs: + if len(sg.edge_attr_schemes()) == 0: + continue + if sg.edge_attr_schemes() != self.edge_attr_schemes(): + raise RuntimeError('Subgraph and parent graph do not ' + 'have the same edge attribute schemes.') + to_merge.append(sg) + self._edge_frame = merge_frames( + [sg._edge_frame for sg in to_merge], + [sg._parent_eid for sg in to_merge], + self._edge_frame.num_rows, + reduce_func) def draw(self): """Plot the graph using dot.""" @@ -996,14 +1011,10 @@ def get_edge_id(self, u, v): eid : tensor The tensor contains edge id(s). """ + u = utils.toindex(u) + v = utils.toindex(v) return self.cached_graph.get_edge_id(u, v) - def _nodes_or_all(self, nodes): - return self.nodes() if nodes == ALL else nodes - - def _edges_or_all(self, edges): - return self.edges() if edges == ALL else edges - def _add_node_callback(self, node): #print('New node:', node) self._cached_graph = None diff --git a/python/dgl/scheduler.py b/python/dgl/scheduler.py index 94c6f0b05a14..bd15eb208508 100644 --- a/python/dgl/scheduler.py +++ b/python/dgl/scheduler.py @@ -1,14 +1,34 @@ """Schedule policies for graph computation.""" from __future__ import absolute_import -import dgl.backend as F import numpy as np +import dgl.backend as F +import dgl.utils as utils + def degree_bucketing(cached_graph, v): - degrees = F.asnumpy(cached_graph.in_degrees(v)) + """Create degree bucketing scheduling policy. 
+ + Parameters + ---------- + cached_graph : dgl.cached_graph.CachedGraph + the graph + v : dgl.utils.Index + the nodes to gather messages + + Returns + ------- + unique_degrees : list of int + list of unique degrees + v_bkt : list of dgl.utils.Index + list of node id buckets; nodes belong to the same bucket have + the same degree + """ + degrees = F.asnumpy(cached_graph.in_degrees(v).totensor()) unique_degrees = list(np.unique(degrees)) + v_np = np.array(v.tolist()) v_bkt = [] for deg in unique_degrees: idx = np.where(degrees == deg) - v_bkt.append(v[idx]) + v_bkt.append(utils.Index(v_np[idx])) return unique_degrees, v_bkt diff --git a/python/dgl/subgraph.py b/python/dgl/subgraph.py index 5266a781a183..bc12a482551e 100644 --- a/python/dgl/subgraph.py +++ b/python/dgl/subgraph.py @@ -13,46 +13,30 @@ class DGLSubGraph(DGLGraph): def __init__(self, parent, nodes): - # create subgraph and relabel - nx_sg = nx.DiGraph.subgraph(parent, nodes) - # node id - # TODO(minjie): context - nid = F.tensor(nodes, dtype=F.int64) - # edge id - # TODO(minjie): slow, context - u, v = zip(*nx_sg.edges) - u = list(u) - v = list(v) - eid = parent.cached_graph.get_edge_id(u, v) - - # relabel + super(DGLSubGraph, self).__init__() + # relabel nodes self._node_mapping = utils.build_relabel_dict(nodes) - nx_sg = nx.relabel.relabel_nodes(nx_sg, self._node_mapping) + self._parent_nid = utils.toindex(nodes) + eids = [] + # create subgraph + for eid, (u, v) in enumerate(parent.edge_list): + if u in self._node_mapping and v in self._node_mapping: + self.add_edge(self._node_mapping[u], + self._node_mapping[v]) + eids.append(eid) + self._parent_eid = utils.toindex(eids) + + def copy_from(self, parent): + """Copy node/edge features from the parent graph. + + All old features will be removed. 
- # init - self._edge_list = [] - nx_init(self, - self._add_node_callback, - self._add_edge_callback, - self._del_node_callback, - self._del_edge_callback, - nx_sg, - **parent.graph) - # cached graph and storage - self._cached_graph = None - if parent._node_frame.num_rows == 0: - self._node_frame = FrameRef() - else: - self._node_frame = FrameRef(Frame(parent._node_frame[nid])) - if parent._edge_frame.num_rows == 0: - self._edge_frame = FrameRef() - else: - self._edge_frame = FrameRef(Frame(parent._edge_frame[eid])) - # other class members - self._msg_graph = None - self._msg_frame = FrameRef() - self._message_func = parent._message_func - self._reduce_func = parent._reduce_func - self._update_func = parent._update_func - self._edge_func = parent._edge_func - self._context = parent._context + Parameters + ---------- + parent : DGLGraph + The parent graph to copy from. + """ + if parent._node_frame.num_rows != 0: + self._node_frame = FrameRef(Frame(parent._node_frame[self._parent_nid])) + if parent._edge_frame.num_rows != 0: + self._edge_frame = FrameRef(Frame(parent._edge_frame[self._parent_eid])) diff --git a/python/dgl/utils.py b/python/dgl/utils.py index 558984e654d3..c875e326662c 100644 --- a/python/dgl/utils.py +++ b/python/dgl/utils.py @@ -2,6 +2,9 @@ from __future__ import absolute_import from collections import Mapping +from functools import wraps +import numpy as np + import dgl.backend as F from dgl.backend import Tensor, SparseTensor @@ -11,18 +14,77 @@ def is_id_tensor(u): def is_id_container(u): """Return whether the input is a supported id container.""" - return isinstance(u, list) + return (getattr(u, '__iter__', None) is not None + and getattr(u, '__len__', None) is not None) + +class Index(object): + """Index class that can be easily converted to list/tensor.""" + def __init__(self, data): + self._list_data = None + self._tensor_data = None + self._ctx_data = dict() + self._dispatch(data) + + def _dispatch(self, data): + if is_id_tensor(data): + 
self._tensor_data = data
+        elif is_id_container(data):
+            self._list_data = data
+        else:
+            try:
+                self._list_data = [int(data)]
+            except:
+                raise TypeError('Error index data: %s' % str(data))
+
+    def tolist(self):
+        if self._list_data is None:
+            self._list_data = list(F.asnumpy(self._tensor_data))
+        return self._list_data
+
+    def totensor(self, ctx=None):
+        if self._tensor_data is None:
+            self._tensor_data = F.tensor(self._list_data, dtype=F.int64)
+        if ctx is None:
+            return self._tensor_data
+        if ctx not in self._ctx_data:
+            self._ctx_data[ctx] = F.to_context(self._tensor_data, ctx)
+        return self._ctx_data[ctx]
+
+    def __iter__(self):
+        return iter(self.tolist())
+
+    def __len__(self):
+        if self._list_data is not None:
+            return len(self._list_data)
+        else:
+            return len(self._tensor_data)
+
+    def __getitem__(self, i):
+        return self.tolist()[i]
+
+def toindex(x):
+    return x if isinstance(x, Index) else Index(x)
 
 def node_iter(n):
-    """Return an iterator that loops over the given nodes."""
-    n = convert_to_id_container(n)
-    for nn in n:
-        yield nn
+    """Return an iterator that loops over the given nodes.
+
+    Parameters
+    ----------
+    n : iterable
+        The node ids.
+    """
+    return iter(n)
 
 def edge_iter(u, v):
-    """Return an iterator that loops over the given edges."""
-    u = convert_to_id_container(u)
-    v = convert_to_id_container(v)
+    """Return an iterator that loops over the given edges.
+
+    Parameters
+    ----------
+    u : iterable
+        The src ids.
+    v : iterable
+        The dst ids.
+    """
     if len(u) == len(v):
         # many-many
         for uu, vv in zip(u, v):
@@ -38,8 +100,33 @@ def edge_iter(u, v):
     else:
         raise ValueError('Error edges:', u, v)
 
+def edge_broadcasting(u, v):
+    """Convert one-many and many-one edges to many-many.
+ + Parameters + ---------- + u : Index + The src id(s) + v : Index + The dst id(s) + + Returns + ------- + uu : Index + The src id(s) after broadcasting + vv : Index + The dst id(s) after broadcasting + """ + if len(u) != len(v) and len(u) == 1: + u = toindex(F.broadcast_to(u.totensor(), v.totensor())) + elif len(u) != len(v) and len(v) == 1: + v = toindex(F.broadcast_to(v.totensor(), u.totensor())) + else: + assert len(u) == len(v) + return u, v + +''' def convert_to_id_container(x): - """Convert the input to id container.""" if is_id_container(x): return x elif is_id_tensor(x): @@ -52,7 +139,6 @@ def convert_to_id_container(x): return None def convert_to_id_tensor(x, ctx=None): - """Convert the input to id tensor.""" if is_id_container(x): ret = F.tensor(x, dtype=F.int64) elif is_id_tensor(x): @@ -64,6 +150,7 @@ def convert_to_id_tensor(x, ctx=None): raise TypeError('Error node: %s' % str(x)) ret = F.to_context(ret, ctx) return ret +''' class LazyDict(Mapping): """A readonly dictionary that does not materialize the storage.""" @@ -110,7 +197,7 @@ def build_relabel_map(x): Parameters ---------- - x : int, tensor or container + x : Index The input ids. Returns @@ -122,7 +209,7 @@ def build_relabel_map(x): One can use advanced indexing to convert an old id tensor to a new id tensor: new_id = old_to_new[old_id] """ - x = convert_to_id_tensor(x) + x = x.totensor() unique_x, _ = F.sort(F.unique(x)) map_len = int(F.max(unique_x)) + 1 old_to_new = F.zeros(map_len, dtype=F.int64) @@ -150,12 +237,55 @@ def build_relabel_dict(x): relabel_dict[v] = i return relabel_dict -def edge_broadcasting(u, v): - """Convert one-many and many-one edges to many-many.""" - if len(u) != len(v) and len(u) == 1: - u = F.broadcast_to(u, v) - elif len(u) != len(v) and len(v) == 1: - v = F.broadcast_to(v, u) - else: - assert len(u) == len(v) - return u, v +class CtxCachedObject(object): + """A wrapper to cache object generated by different context. 
+ + Note: such wrapper may incur significant overhead if the wrapped object is very light. + + Parameters + ---------- + generator : callable + A callable function that can create the object given ctx as the only argument. + """ + def __init__(self, generator): + self._generator = generator + self._ctx_dict = {} + + def get(self, ctx): + if not ctx in self._ctx_dict: + self._ctx_dict[ctx] = self._generator(ctx) + return self._ctx_dict[ctx] + +def ctx_cached_member(func): + """Convenient class member function wrapper to cache the function result. + + The wrapped function must only have two arguments: `self` and `ctx`. The former is the + class object and the later is the context. It will check whether the class object is + freezed (by checking the `_freeze` member). If yes, it caches the function result in + the field prefixed by '_CACHED_' before the function name. + """ + cache_name = '_CACHED_' + func.__name__ + @wraps(func) + def wrapper(self, ctx): + if self._freeze: + # cache + if getattr(self, cache_name, None) is None: + bind_func = lambda _ctx : func(self, _ctx) + setattr(self, cache_name, CtxCachedObject(bind_func)) + return getattr(self, cache_name).get(ctx) + else: + return func(self, ctx) + return wrapper + +def cached_member(func): + cache_name = '_CACHED_' + func.__name__ + @wraps(func) + def wrapper(self): + if self._freeze: + # cache + if getattr(self, cache_name, None) is None: + setattr(self, cache_name, func(self)) + return getattr(self, cache_name) + else: + return func(self) + return wrapper diff --git a/tests/pytorch/test_batching.py b/tests/pytorch/test_batching.py index 230a44f3249a..96694adeeb8f 100644 --- a/tests/pytorch/test_batching.py +++ b/tests/pytorch/test_batching.py @@ -26,6 +26,17 @@ def update_func(node, accum): assert node['h'].shape == accum.shape return {'h' : node['h'] + accum} +def reduce_dict_func(node, msgs): + msgs = msgs['m'] + reduce_msg_shapes.add(tuple(msgs.shape)) + assert len(msgs.shape) == 3 + assert msgs.shape[2] 
== D + return {'m' : th.sum(msgs, 1)} + +def update_dict_func(node, accum): + assert node['h'].shape == accum['m'].shape + return {'h' : node['h'] + accum['m']} + def generate_graph(grad=False): g = DGLGraph() for i in range(10): @@ -149,7 +160,8 @@ def _fmsg(src, edge): v = th.tensor([9]) g.sendto(u, v) -def test_batch_recv(): +def test_batch_recv1(): + # basic recv test g = generate_graph() g.register_message_func(message_func, batchable=True) g.register_reduce_func(reduce_func, batchable=True) @@ -162,6 +174,20 @@ def test_batch_recv(): assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)}) reduce_msg_shapes.clear() +def test_batch_recv2(): + # recv test with dict type reduce message + g = generate_graph() + g.register_message_func(message_func, batchable=True) + g.register_reduce_func(reduce_dict_func, batchable=True) + g.register_update_func(update_dict_func, batchable=True) + u = th.tensor([0, 0, 0, 4, 5, 6]) + v = th.tensor([1, 2, 3, 9, 9, 9]) + reduce_msg_shapes.clear() + g.sendto(u, v) + g.recv(th.unique(v)) + assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)}) + reduce_msg_shapes.clear() + def test_update_routines(): g = generate_graph() g.register_message_func(message_func, batchable=True) @@ -210,6 +236,7 @@ def _test_delete(): test_batch_setter_getter() test_batch_setter_autograd() test_batch_send() - test_batch_recv() + test_batch_recv1() + test_batch_recv2() test_update_routines() #test_delete() diff --git a/tests/pytorch/test_cached_graph.py b/tests/pytorch/test_cached_graph.py index 3bd96d12af8d..4030b6ea415a 100644 --- a/tests/pytorch/test_cached_graph.py +++ b/tests/pytorch/test_cached_graph.py @@ -3,6 +3,7 @@ import networkx as nx from dgl import DGLGraph from dgl.cached_graph import * +from dgl.utils import Index def check_eq(a, b): assert a.shape == b.shape @@ -15,22 +16,18 @@ def test_basics(): g.add_edge(1, 3) g.add_edge(2, 4) g.add_edge(2, 5) + g.add_edge(0, 2) cg = create_cached_graph(g) - u = th.tensor([0, 1, 1, 2, 2]) - v = th.tensor([1, 2, 
3, 4, 5]) - check_eq(cg.get_edge_id(u, v), th.tensor([0, 1, 2, 3, 4])) - cg.add_edges(0, 2) - assert cg.get_edge_id(0, 2) == 5 - query = th.tensor([1, 2]) + u = Index(th.tensor([0, 0, 1, 1, 2, 2])) + v = Index(th.tensor([1, 2, 2, 3, 4, 5])) + check_eq(cg.get_edge_id(u, v).totensor(), th.tensor([0, 5, 1, 2, 3, 4])) + query = Index(th.tensor([1, 2])) s, d = cg.in_edges(query) - check_eq(s, th.tensor([0, 0, 1])) - check_eq(d, th.tensor([1, 2, 2])) + check_eq(s.totensor(), th.tensor([0, 0, 1])) + check_eq(d.totensor(), th.tensor([1, 2, 2])) s, d = cg.out_edges(query) - check_eq(s, th.tensor([1, 1, 2, 2])) - check_eq(d, th.tensor([2, 3, 4, 5])) - - print(cg._graph.get_adjacency()) - print(cg._graph.get_adjacency(eids=True)) + check_eq(s.totensor(), th.tensor([1, 1, 2, 2])) + check_eq(d.totensor(), th.tensor([2, 3, 4, 5])) if __name__ == '__main__': test_basics() diff --git a/tests/pytorch/test_frame.py b/tests/pytorch/test_frame.py index 72ed5a6b6740..9de6b54026b1 100644 --- a/tests/pytorch/test_frame.py +++ b/tests/pytorch/test_frame.py @@ -2,6 +2,7 @@ from torch.autograd import Variable import numpy as np from dgl.frame import Frame, FrameRef +from dgl.utils import Index N = 10 D = 5 @@ -112,7 +113,7 @@ def test_append2(): assert not f.is_span_whole_column() assert f.num_rows == 3 * N new_idx = list(range(N)) + list(range(2*N, 4*N)) - assert check_eq(f.index_tensor(), th.tensor(new_idx)) + assert check_eq(f.index().totensor(), th.tensor(new_idx)) assert data.num_rows == 4 * N def test_row1(): @@ -122,20 +123,20 @@ def test_row1(): # getter # test non-duplicate keys - rowid = th.tensor([0, 2]) + rowid = Index(th.tensor([0, 2])) rows = f[rowid] for k, v in rows.items(): assert v.shape == (len(rowid), D) assert check_eq(v, data[k][rowid]) # test duplicate keys - rowid = th.tensor([8, 2, 2, 1]) + rowid = Index(th.tensor([8, 2, 2, 1])) rows = f[rowid] for k, v in rows.items(): assert v.shape == (len(rowid), D) assert check_eq(v, data[k][rowid]) # setter - rowid = 
th.tensor([0, 2, 4]) + rowid = Index(th.tensor([0, 2, 4])) vals = {'a1' : th.zeros((len(rowid), D)), 'a2' : th.zeros((len(rowid), D)), 'a3' : th.zeros((len(rowid), D)), @@ -152,13 +153,13 @@ def test_row2(): # getter c1 = f['a1'] # test non-duplicate keys - rowid = th.tensor([0, 2]) + rowid = Index(th.tensor([0, 2])) rows = f[rowid] rows['a1'].backward(th.ones((len(rowid), D))) assert check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.])) c1.grad.data.zero_() # test duplicate keys - rowid = th.tensor([8, 2, 2, 1]) + rowid = Index(th.tensor([8, 2, 2, 1])) rows = f[rowid] rows['a1'].backward(th.ones((len(rowid), D))) assert check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.])) @@ -166,7 +167,7 @@ def test_row2(): # setter c1 = f['a1'] - rowid = th.tensor([0, 2, 4]) + rowid = Index(th.tensor([0, 2, 4])) vals = {'a1' : Variable(th.zeros((len(rowid), D)), requires_grad=True), 'a2' : Variable(th.zeros((len(rowid), D)), requires_grad=True), 'a3' : Variable(th.zeros((len(rowid), D)), requires_grad=True), @@ -210,14 +211,14 @@ def test_sharing(): f2_a1 = f2['a1'] # test write # update own ref should not been seen by the other. - f1[th.tensor([0, 1])] = { + f1[Index(th.tensor([0, 1]))] = { 'a1' : th.zeros([2, D]), 'a2' : th.zeros([2, D]), 'a3' : th.zeros([2, D]), } assert check_eq(f2['a1'], f2_a1) # update shared space should been seen by the other. 
- f1[th.tensor([2, 3])] = { + f1[Index(th.tensor([2, 3]))] = { 'a1' : th.ones([2, D]), 'a2' : th.ones([2, D]), 'a3' : th.ones([2, D]), diff --git a/tests/test_graph_batch.py b/tests/pytorch/test_graph_batch.py similarity index 100% rename from tests/test_graph_batch.py rename to tests/pytorch/test_graph_batch.py diff --git a/tests/pytorch/test_subgraph.py b/tests/pytorch/test_subgraph.py index 4bf2cf4da44a..eef0bf9f1ec5 100644 --- a/tests/pytorch/test_subgraph.py +++ b/tests/pytorch/test_subgraph.py @@ -24,35 +24,71 @@ def generate_graph(grad=False): g.set_e_repr({'l' : ecol}) return g -def test_subgraph(): +def test_basics(): g = generate_graph() h = g.get_n_repr()['h'] l = g.get_e_repr()['l'] - sg = g.subgraph([0, 2, 3, 6, 7, 9]) + nid = [0, 2, 3, 6, 7, 9] + eid = [2, 3, 4, 5, 10, 11, 12, 13, 16] + sg = g.subgraph(nid) + # the subgraph is empty initially + assert len(sg.get_n_repr()) == 0 + assert len(sg.get_e_repr()) == 0 + # the data is copied after explict copy from + sg.copy_from(g) + assert len(sg.get_n_repr()) == 1 + assert len(sg.get_e_repr()) == 1 sh = sg.get_n_repr()['h'] - check_eq(h[th.tensor([0, 2, 3, 6, 7, 9])], sh) + assert check_eq(h[nid], sh) ''' s, d, eid 0, 1, 0 1, 9, 1 - 0, 2, 2 - 2, 9, 3 - 0, 3, 4 - 3, 9, 5 + 0, 2, 2 1 + 2, 9, 3 1 + 0, 3, 4 1 + 3, 9, 5 1 0, 4, 6 4, 9, 7 0, 5, 8 - 5, 9, 9 - 0, 6, 10 - 6, 9, 11 - 0, 7, 12 - 7, 9, 13 + 5, 9, 9 3 + 0, 6, 10 1 + 6, 9, 11 1 3 + 0, 7, 12 1 + 7, 9, 13 1 3 0, 8, 14 - 8, 9, 15 - 9, 0, 16 + 8, 9, 15 3 + 9, 0, 16 1 ''' - eid = th.tensor([2, 3, 4, 5, 10, 11, 12, 13, 16]) - check_eq(l[eid], sg.get_e_repr()['l']) + assert check_eq(l[eid], sg.get_e_repr()['l']) + # update the node/edge features on the subgraph should NOT + # reflect to the parent graph. 
+ sg.set_n_repr({'h' : th.zeros((6, D))}) + assert check_eq(h, g.get_n_repr()['h']) + +def test_merge(): + g = generate_graph() + g.set_n_repr({'h' : th.zeros((10, D))}) + g.set_e_repr({'l' : th.zeros((17, D))}) + # subgraphs + sg1 = g.subgraph([0, 2, 3, 6, 7, 9]) + sg1.set_n_repr({'h' : th.ones((6, D))}) + sg1.set_e_repr({'l' : th.ones((9, D))}) + + sg2 = g.subgraph([0, 2, 3, 4]) + sg2.set_n_repr({'h' : th.ones((4, D)) * 2}) + + sg3 = g.subgraph([5, 6, 7, 8, 9]) + sg3.set_e_repr({'l' : th.ones((4, D)) * 3}) + + g.merge([sg1, sg2, sg3]) + + h = g.get_n_repr()['h'][:,0] + l = g.get_e_repr()['l'][:,0] + assert check_eq(h, th.tensor([3., 0., 3., 3., 2., 0., 1., 1., 0., 1.])) + assert check_eq(l, + th.tensor([0., 0., 1., 1., 1., 1., 0., 0., 0., 3., 1., 4., 1., 4., 0., 3., 1.])) if __name__ == '__main__': - test_subgraph() + test_basics() + test_merge() diff --git a/tests/test_basics.py b/tests/test_basics.py index 3b970398aec8..1cac632363e1 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -6,6 +6,12 @@ def message_func(src, edge): def update_func(node, accum): return {'h' : node['h'] + accum} +def message_dict_func(src, edge): + return {'m' : src['h']} + +def update_dict_func(node, accum): + return {'h' : node['h'] + accum['m']} + def generate_graph(): g = DGLGraph() for i in range(10): @@ -23,12 +29,18 @@ def check(g, h): h = [str(x) for x in h] assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h)) -def test_sendrecv(): - g = generate_graph() - check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +def register1(g): g.register_message_func(message_func) g.register_update_func(update_func) g.register_reduce_func('sum') + +def register2(g): + g.register_message_func(message_dict_func) + g.register_update_func(update_dict_func) + g.register_reduce_func('sum') + +def _test_sendrecv(g): + check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) g.sendto(0, 1) g.recv(1) check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10]) @@ -37,12 +49,8 @@ def test_sendrecv(): g.recv(9) check(g, [1, 3, 
3, 4, 5, 6, 7, 8, 9, 23]) -def test_multi_sendrecv(): - g = generate_graph() +def _test_multi_sendrecv(g): check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - g.register_message_func(message_func) - g.register_update_func(update_func) - g.register_reduce_func('sum') # one-many g.sendto(0, [1, 2, 3]) g.recv([1, 2, 3]) @@ -56,12 +64,8 @@ def test_multi_sendrecv(): g.recv([4, 5, 9]) check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45]) -def test_update_routines(): - g = generate_graph() +def _test_update_routines(g): check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - g.register_message_func(message_func) - g.register_update_func(update_func) - g.register_reduce_func('sum') g.update_by_edge(0, 1) check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10]) g.update_to(9) @@ -71,6 +75,30 @@ def test_update_routines(): g.update_all() check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108]) +def test_sendrecv(): + g = generate_graph() + register1(g) + _test_sendrecv(g) + g = generate_graph() + register2(g) + _test_sendrecv(g) + +def test_multi_sendrecv(): + g = generate_graph() + register1(g) + _test_multi_sendrecv(g) + g = generate_graph() + register2(g) + _test_multi_sendrecv(g) + +def test_update_routines(): + g = generate_graph() + register1(g) + _test_update_routines(g) + g = generate_graph() + register2(g) + _test_update_routines(g) + if __name__ == '__main__': test_sendrecv() test_multi_sendrecv()