diff --git a/examples/pytorch/gcn/gcn.py b/examples/pytorch/gcn/gcn.py index 1469f284f9c0..ac24b73c79e5 100644 --- a/examples/pytorch/gcn/gcn.py +++ b/examples/pytorch/gcn/gcn.py @@ -16,10 +16,10 @@ from dgl.data import register_data_args, load_data def gcn_msg(src, edge): - return src + return {'m' : src['h']} def gcn_reduce(node, msgs): - return torch.sum(msgs, 1) + return {'h' : torch.sum(msgs['m'], 1)} class NodeApplyModule(nn.Module): def __init__(self, in_feats, out_feats, activation=None): @@ -28,10 +28,10 @@ def __init__(self, in_feats, out_feats, activation=None): self.activation = activation def forward(self, node): - h = self.linear(node) + h = self.linear(node['h']) if self.activation: h = self.activation(h) - return h + return {'h' : h} class GCN(nn.Module): def __init__(self, @@ -54,14 +54,14 @@ def __init__(self, self.layers.append(NodeApplyModule(n_hidden, n_classes)) def forward(self, features): - self.g.set_n_repr(features) + self.g.set_n_repr({'h' : features}) for layer in self.layers: # apply dropout if self.dropout: - val = F.dropout(self.g.get_n_repr(), p=self.dropout) - self.g.set_n_repr(val) + g.apply_nodes(apply_node_func= + lambda node: F.dropout(node['h'], p=self.dropout)) self.g.update_all(gcn_msg, gcn_reduce, layer) - return self.g.pop_n_repr() + return self.g.pop_n_repr('h') def main(args): # load and preprocess dataset diff --git a/examples/pytorch/gcn/gcn_spmv.py b/examples/pytorch/gcn/gcn_spmv.py index c49c0f6b9451..ca724d9001bf 100644 --- a/examples/pytorch/gcn/gcn_spmv.py +++ b/examples/pytorch/gcn/gcn_spmv.py @@ -23,10 +23,10 @@ def __init__(self, in_feats, out_feats, activation=None): self.activation = activation def forward(self, node): - h = self.linear(node) + h = self.linear(node['h']) if self.activation: h = self.activation(h) - return h + return {'h' : h} class GCN(nn.Module): def __init__(self, @@ -49,14 +49,16 @@ def __init__(self, self.layers.append(NodeApplyModule(n_hidden, n_classes)) def forward(self, features): - self.g.set_n_repr(features) + self.g.set_n_repr({'h' : features}) for layer in self.layers: # apply dropout if self.dropout: - val = F.dropout(self.g.get_n_repr(), p=self.dropout) - self.g.set_n_repr(val) - self.g.update_all(fn.copy_src(), fn.sum(), layer) - return self.g.pop_n_repr() + g.apply_nodes(apply_node_func= + lambda node: F.dropout(node['h'], p=self.dropout)) + self.g.update_all(fn.copy_src(src='h', out='m'), + fn.sum(msgs='m', out='h'), + layer) + return self.g.pop_n_repr('h') def main(args): # load and preprocess dataset diff --git a/python/dgl/backend/pytorch.py b/python/dgl/backend/pytorch.py index 097e80957be4..809a71589638 100644 --- a/python/dgl/backend/pytorch.py +++ b/python/dgl/backend/pytorch.py @@ -93,23 +93,24 @@ def get_context(arr): return TVMContext( TVMContext.STR2MASK[arr.device.type], arr.device.index) -def _typestr(arr_dtype): +def get_tvmtype(arr): + arr_dtype = arr.dtype if arr_dtype in (th.float16, th.half): - return 'float16' + return TVMType('float16') elif arr_dtype in (th.float32, th.float): - return 'float32' + return TVMType('float32') elif arr_dtype in (th.float64, th.double): - return 'float64' + return TVMType('float64') elif arr_dtype in (th.int16, th.short): - return 'int16' + return TVMType('int16') elif arr_dtype in (th.int32, th.int): - return 'int32' + return TVMType('int32') elif arr_dtype in (th.int64, th.long): - return 'int64' + return TVMType('int64') elif arr_dtype == th.int8: - return 'int8' + return TVMType('int8') elif arr_dtype == th.uint8: - return 'uint8' + return TVMType('uint8') else: raise RuntimeError('Unsupported data type:', arr_dtype) @@ -130,20 +131,6 @@ def zerocopy_from_numpy(np_data): """Return a tensor that shares the numpy data.""" return th.from_numpy(np_data) - ''' - data = arr_data - assert data.is_contiguous() - arr = TVMArray() - shape = c_array(tvm_shape_index_t, tuple(data.shape)) - arr.data = ctypes.cast(data.data_ptr(), ctypes.c_void_p) - arr.shape = shape - arr.strides = None - arr.dtype = TVMType(_typestr(data.dtype)) - arr.ndim = len(shape) - arr.ctx = get_context(data) - return arr - ''' - def nonzero_1d(arr): """Return a 1D tensor with nonzero element indices in a 1D vector""" assert arr.dim() == 1 diff --git a/python/dgl/base.py b/python/dgl/base.py index 8658ee539b22..4693334d0f80 100644 --- a/python/dgl/base.py +++ b/python/dgl/base.py @@ -1,4 +1,9 @@ """Module for base types and utilities.""" +from __future__ import absolute_import + +import warnings + +from ._ffi.base import DGLError # A special argument for selecting all nodes/edges. ALL = "__ALL__" @@ -8,3 +13,5 @@ def is_all(arg): __MSG__ = "__MSG__" __REPR__ = "__REPR__" + +dgl_warning = warnings.warn diff --git a/python/dgl/frame.py b/python/dgl/frame.py index a1c34c096076..3cc00a1aaf53 100644 --- a/python/dgl/frame.py +++ b/python/dgl/frame.py @@ -1,4 +1,4 @@ -"""Columnar storage for graph attributes.""" +"""Columnar storage for DGLGraph.""" from __future__ import absolute_import from collections import MutableMapping @@ -6,178 +6,598 @@ from . import backend as F from .backend import Tensor +from .base import DGLError, dgl_warning from . import utils +class Scheme(object): + """The column scheme. + + Parameters + ---------- + shape : tuple of int + The feature shape. + dtype : TVMType + The feature data type. + """ + def __init__(self, shape, dtype): + self.shape = shape + self.dtype = dtype + + def __repr__(self): + return '{shape=%s, dtype=%s}' % (repr(self.shape), repr(self.dtype)) + + def __eq__(self, other): + return self.shape == other.shape and self.dtype == other.dtype + + def __ne__(self, other): + return not self.__eq__(other) + + @staticmethod + def infer_scheme(tensor): + """Infer the scheme of the given tensor.""" + return Scheme(tuple(F.shape(tensor)[1:]), F.get_tvmtype(tensor)) + +class Column(object): + """A column is a compact store of features of multiple nodes/edges. + + Currently, we use one dense tensor to batch all the feature tensors + together (along the first dimension). + + Parameters + ---------- + data : Tensor + The initial data of the column. + scheme : Scheme, optional + The scheme of the column. Will be inferred if not provided. + """ + def __init__(self, data, scheme=None): + self.data = data + self.scheme = scheme if scheme else Scheme.infer_scheme(data) + + def __len__(self): + """The column length.""" + return F.shape(self.data)[0] + + def __getitem__(self, idx): + """Return the feature data given the index. + + Parameters + ---------- + idx : utils.Index + The index. + + Returns + ------- + Tensor + The feature data + """ + user_idx = idx.tousertensor(F.get_context(self.data)) + return F.gather_row(self.data, user_idx) + + def __setitem__(self, idx, feats): + """Update the feature data given the index. + + The update is performed out-placely so it can be used in autograd mode. + For inplace write, please use ``update``. + + Parameters + ---------- + idx : utils.Index + The index. + feats : Tensor + The new features. + """ + self.update(idx, feats, inplace=False) + + def update(self, idx, feats, inplace): + """Update the feature data given the index. + + Parameters + ---------- + idx : utils.Index + The index. + feats : Tensor + The new features. + inplace : bool + If true, use inplace write. + """ + feat_scheme = Scheme.infer_scheme(feats) + if feat_scheme != self.scheme: + raise DGLError("Cannot update column of scheme %s using feature of scheme %s." + % (feat_scheme, self.scheme)) + user_idx = idx.tousertensor(F.get_context(self.data)) + if inplace: + # TODO(minjie): do not use [] operator directly + self.data[user_idx] = feats + else: + self.data = F.scatter_row(self.data, user_idx, feats) + + @staticmethod + def create(data): + """Create a new column using the given data.""" + if isinstance(data, Column): + return Column(data.data) + else: + return Column(data) + class Frame(MutableMapping): + """The columnar storage for node/edge features. + + The frame is a dictionary from feature fields to feature columns. + All columns should have the same number of rows (i.e. the same first dimension). + + Parameters + ---------- + data : dict-like, optional + The frame data in dictionary. If the provided data is another frame, + this frame will NOT share columns with the given frame. So any out-place + update on one will not reflect to the other. The inplace update will + be seen by both. This follows the semantic of python's container. + """ def __init__(self, data=None): if data is None: self._columns = dict() self._num_rows = 0 else: - self._columns = dict(data) - self._num_rows = F.shape(list(data.values())[0])[0] - for k, v in data.items(): - assert F.shape(v)[0] == self._num_rows + # Note that we always create a new column for the given data. + # This avoids two frames accidentally sharing the same column. + self._columns = {k : Column.create(v) for k, v in data.items()} + if len(self._columns) != 0: + self._num_rows = len(next(iter(self._columns.values()))) + else: + self._num_rows = 0 + # sanity check + for name, col in self._columns.items(): + if len(col) != self._num_rows: + raise DGLError('Expected all columns to have same # rows (%d), ' + 'got %d on %r.' % (self._num_rows, len(col), name)) + # Initializer for empty values. Initializer is a callable. + # If is none, then a warning will be raised + # in the first call and zero initializer will be used later. + self._initializer = None + + def set_initializer(self, initializer): + """Set the initializer for empty values. + + Initializer is a callable that returns a tensor given the shape and data type. + + Parameters + ---------- + initializer : callable + The initializer. + """ + self._initializer = initializer + + @property + def initializer(self): + """Return the initializer of this frame.""" + return self._initializer @property def schemes(self): - return set(self._columns.keys()) + """Return a dictionary of column name to column schemes.""" + return {k : col.scheme for k, col in self._columns.items()} @property def num_columns(self): + """Return the number of columns in this frame.""" return len(self._columns) @property def num_rows(self): + """Return the number of rows in this frame.""" return self._num_rows - def __contains__(self, key): - return key in self._columns - - def __getitem__(self, key): - # get column - return self._columns[key] - - def __setitem__(self, key, val): - # set column - self.add_column(key, val) - - def __delitem__(self, key): - # delete column - del self._columns[key] + def __contains__(self, name): + """Return true if the given column name exists.""" + return name in self._columns + + def __getitem__(self, name): + """Return the column of the given name. + + Parameters + ---------- + name : str + The column name. + + Returns + ------- + Column + The column. + """ + return self._columns[name] + + def __setitem__(self, name, data): + """Update the whole column. + + Parameters + ---------- + name : str + The column name. + col : Column or data convertible to Column + The column data. + """ + self.update_column(name, data) + + def __delitem__(self, name): + """Delete the whole column. + + Parameters + ---------- + name : str + The column name. + """ + del self._columns[name] if len(self._columns) == 0: self._num_rows = 0 - def add_column(self, name, col): + def add_column(self, name, scheme, ctx): + """Add a new column to the frame. + + The frame will be initialized by the initializer. + + Parameters + ---------- + name : str + The column name. + scheme : Scheme + The column scheme. + ctx : TVMContext + The column context. + """ + if name in self: + dgl_warning('Column "%s" already exists. Ignore adding this column again.' % name) + return + if self.num_rows == 0: + raise DGLError('Cannot add column "%s" using column schemes because' + ' number of rows is unknown. Make sure there is at least' + ' one column in the frame so number of rows can be inferred.') + if self.initializer is None: + dgl_warning('Initializer is not set. Use zero initializer instead.' + ' To suppress this warning, use `set_initializer` to' + ' explicitly specify which initializer to use.') + # TODO(minjie): handle data type + self.set_initializer(lambda shape, dtype : F.zeros(shape)) + # TODO(minjie): directly init data on the targer device. + init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype) + init_data = F.to_context(init_data, ctx) + self._columns[name] = Column(init_data, scheme) + + def update_column(self, name, data): + """Add or replace the column with the given name and data. + + Parameters + ---------- + name : str + The column name. + data : Column or data convertible to Column + The column data. + """ + col = Column.create(data) if self.num_columns == 0: - self._num_rows = F.shape(col)[0] - else: - assert F.shape(col)[0] == self._num_rows + self._num_rows = len(col) + elif len(col) != self._num_rows: + raise DGLError('Expected data to have %d rows, got %d.' % + (self._num_rows, len(col))) self._columns[name] = col def append(self, other): + """Append another frame's data into this frame. + + If the current frame is empty, it will just use the columns of the + given frame. Otherwise, the given data should contain all the + column keys of this frame. + + Parameters + ---------- + other : Frame or dict-like + The frame data to be appended. + """ + if not isinstance(other, Frame): + other = Frame(other) if len(self._columns) == 0: for key, col in other.items(): self._columns[key] = col + self._num_rows = other.num_rows else: for key, col in other.items(): - self._columns[key] = F.pack([self[key], col]) - # TODO(minjie): sanity check for num_rows - if len(self._columns) != 0: - self._num_rows = F.shape(list(self._columns.values())[0])[0] + sch = self._columns[key].scheme + other_sch = col.scheme + if sch != other_sch: + raise DGLError("Cannot append column of scheme %s to column of scheme %s." + % (other_scheme, sch)) + self._columns[key].data = F.pack( + [self._columns[key].data, col.data]) + self._num_rows += other.num_rows def clear(self): + """Clear this frame. Remove all the columns.""" self._columns = {} self._num_rows = 0 def __iter__(self): + """Return an iterator of columns.""" return iter(self._columns) def __len__(self): + """Return the number of columns.""" return self.num_columns + def keys(self): + """Return the keys.""" + return self._columns.keys() + class FrameRef(MutableMapping): - """Frame reference + """Reference object to a frame on a subset of rows. Parameters ---------- - frame : dgl.frame.Frame - The underlying frame. - index : iterable of int - The rows that are referenced in the underlying frame. + frame : Frame, optional + The underlying frame. If not given, the reference will point to a + new empty frame. + index : iterable of int, optional + The rows that are referenced in the underlying frame. If not given, + the whole frame is referenced. The index should be distinct (no + duplication is allowed). """ def __init__(self, frame=None, index=None): self._frame = frame if frame is not None else Frame() if index is None: self._index_data = slice(0, self._frame.num_rows) else: - # check no duplication - assert len(index) == len(np.unique(index)) + # TODO(minjie): check no duplication self._index_data = index self._index = None @property def schemes(self): + """Return the frame schemes. + + Returns + ------- + dict of str to Scheme + The frame schemes. + """ return self._frame.schemes @property def num_columns(self): + """Return the number of columns in the referred frame.""" return self._frame.num_columns @property def num_rows(self): + """Return the number of rows referred.""" if isinstance(self._index_data, slice): + # NOTE: we are assuming that the index is a slice ONLY IF + # index=None during construction. + # As such, start is always 0, and step is always 1. return self._index_data.stop else: return len(self._index_data) - def __contains__(self, key): - return key in self._frame + def set_initializer(self, initializer): + """Set the initializer for empty values. + + Initializer is a callable that returns a tensor given the shape and data type. + + Parameters + ---------- + initializer : callable + The initializer. + """ + self._frame.set_initializer(initializer) + + def index(self): + """Return the index object. + + Returns + ------- + utils.Index + The index. + """ + if self._index is None: + if self.is_contiguous(): + self._index = utils.toindex( + F.arange(self._index_data.stop, dtype=F.int64)) + else: + self._index = utils.toindex(self._index_data) + return self._index + + def __contains__(self, name): + """Return whether the column name exists.""" + return name in self._frame + + def __iter__(self): + """Return the iterator of the columns.""" + return iter(self._frame) + + def __len__(self): + """Return the number of columns.""" + return self.num_columns + + def keys(self): + """Return the keys.""" + return self._frame.keys() def __getitem__(self, key): + """Get data from the frame. + + If the provided key is string, the corresponding column data will be returned. + If the provided key is an index, the corresponding rows will be selected. The + returned rows are saved in a lazy dictionary so only the real selection happens + when the explicit column name is provided. + + Examples (using pytorch) + ------------------------ + >>> # create a frame of two columns and five rows + >>> f = Frame({'c1' : torch.zeros([5, 2]), 'c2' : torch.ones([5, 2])}) + >>> fr = FrameRef(f) + >>> # select the row 1 and 2, the returned `rows` is a lazy dictionary. + >>> rows = fr[Index([1, 2])] + >>> rows['c1'] # only select rows for 'c1' column; 'c2' column is not sliced. + + Parameters + ---------- + key : str or utils.Index + The key. + + Returns + ------- + Tensor or lazy dict or tensors + Depends on whether it is a column selection or row selection. + """ if isinstance(key, str): - return self.get_column(key) + return self.select_column(key) else: return self.select_rows(key) - def select_rows(self, query): - rowids = self._getrowid(query) - def _lazy_select(key): - idx = rowids.tousertensor(F.get_context(self._frame[key])) - return F.gather_row(self._frame[key], idx) - return utils.LazyDict(_lazy_select, keys=self.schemes) + def select_column(self, name): + """Return the column of the given name. + + If only part of the rows are referenced, the fetching the whole column will + also slice out the referenced rows. + + Parameters + ---------- + name : str + The column name. - def get_column(self, name): + Returns + ------- + Tensor + The column data. + """ col = self._frame[name] if self.is_span_whole_column(): - return col + return col.data else: - idx = self.index().tousertensor(F.get_context(col)) - return F.gather_row(col, idx) + return col[self.index()] + + def select_rows(self, query): + """Return the rows given the query. + + Parameters + ---------- + query : utils.Index + The rows to be selected. + + Returns + ------- + utils.LazyDict + The lazy dictionary from str to the selected data. + """ + rowids = self._getrowid(query) + return utils.LazyDict(lambda key: self._frame[key][rowids], keys=self.keys()) def __setitem__(self, key, val): + """Update the data in the frame. + + If the provided key is string, the corresponding column data will be updated. + The provided value should be one tensor that have the same scheme and length + as the column. + + If the provided key is an index, the corresponding rows will be updated. The + value provided should be a dictionary of string to the data of each column. + + All updates are performed out-placely to be work with autograd. For inplace + update, use ``update_column`` or ``update_rows``. + + Parameters + ---------- + key : str or utils.Index + The key. + val : Tensor or dict of tensors + The value. + """ if isinstance(key, str): - self.add_column(key, val) + self.update_column(key, val, inplace=False) else: - self.update_rows(key, val) - - def add_column(self, name, col, inplace=False): - shp = F.shape(col) + self.update_rows(key, val, inplace=False) + + def update_column(self, name, data, inplace): + """Update the column. + + If this frameref spans the whole column of the underlying frame, this is + equivalent to update the column of the frame. + + If this frameref only points to part of the rows, then update the column + here will correspond to update part of the column in the frame. Raise error + if the given column name does not exist. + + Parameters + ---------- + name : str + The column name. + data : Tensor + The update data. + inplace : bool + True if the update is performed inplacely. + """ if self.is_span_whole_column(): + col = Column.create(data) if self.num_columns == 0: - self._index_data = slice(0, shp[0]) + # the frame is empty + self._index_data = slice(0, len(col)) self._clear_cache() - assert shp[0] == self.num_rows self._frame[name] = col else: - colctx = F.get_context(col) - if name in self._frame: - fcol = self._frame[name] - else: - fcol = F.zeros((self._frame.num_rows,) + shp[1:]) - fcol = F.to_context(fcol, colctx) - idx = self.index().tousertensor(colctx) - if inplace: - self._frame[name] = fcol - self._frame[name][idx] = col - else: - newfcol = F.scatter_row(fcol, idx, col) - self._frame[name] = newfcol - - def update_rows(self, query, other, inplace=False): + if name not in self._frame: + feat_shape = F.shape(data)[1:] + feat_dtype = F.get_tvmtype(data) + ctx = F.get_context(data) + self._frame.add_column(name, Scheme(feat_shape, feat_dtype), ctx) + #raise DGLError('Cannot update column. Column "%s" does not exist.' + # ' Did you forget to init the column using `set_n_repr`' + # ' or `set_e_repr`?' % name) + fcol = self._frame[name] + fcol.update(self.index(), data, inplace) + + def update_rows(self, query, data, inplace): + """Update the rows. + + If the provided data has new column, it will be added to the frame. + + See Also + -------- + ``update_column`` + + Parameters + ---------- + query : utils.Index + The rows to be updated. + data : dict-like + The row data. + inplace : bool + True if the update is performed inplacely. + """ rowids = self._getrowid(query) - for key, col in other.items(): + for key, col in data.items(): if key not in self: # add new column tmpref = FrameRef(self._frame, rowids) - tmpref.add_column(key, col, inplace) - idx = rowids.tousertensor(F.get_context(self._frame[key])) - if inplace: - self._frame[key][idx] = col + tmpref.update_column(key, col, inplace) + #raise DGLError('Cannot update rows. Column "%s" does not exist.' + # ' Did you forget to init the column using `set_n_repr`' + # ' or `set_e_repr`?' % key) else: - self._frame[key] = F.scatter_row(self._frame[key], idx, col) + self._frame[key].update(rowids, col, inplace) def __delitem__(self, key): + """Delete data in the frame. + + If the provided key is a string, the corresponding column will be deleted. + If the provided key is an index object, the corresponding rows will be deleted. + + Please note that "deleted" rows are not really deleted, but simply removed + in the reference. As a result, if two FrameRefs point to the same Frame, deleting + from one ref will not relect on the other. By contrast, deleting columns is real. + + Parameters + ---------- + key : str or utils.Index + The key. + """ if isinstance(key, str): del self._frame[key] if len(self._frame) == 0: @@ -186,7 +606,18 @@ def __delitem__(self, key): self.delete_rows(key) def delete_rows(self, query): - query = F.asnumpy(query) + """Delete rows. + + Please note that "deleted" rows are not really deleted, but simply removed + in the reference. As a result, if two FrameRefs point to the same Frame, deleting + from one ref will not relect on the other. By contrast, deleting columns is real. + + Parameters + ---------- + query : utils.Index + The rows to be deleted. + """ + query = query.tolist() if isinstance(self._index_data, slice): self._index_data = list(range(self._index_data.start, self._index_data.stop)) arr = np.array(self._index_data, dtype=np.int32) @@ -194,6 +625,13 @@ def delete_rows(self, query): self._clear_cache() def append(self, other): + """Append another frame into this one. + + Parameters + ---------- + other : dict of str to tensor + The data to be appended. + """ span_whole = self.is_span_whole_column() contiguous = self.is_contiguous() old_nrows = self._frame.num_rows @@ -208,24 +646,23 @@ def append(self, other): self._clear_cache() def clear(self): + """Clear the frame.""" self._frame.clear() self._index_data = slice(0, 0) self._clear_cache() - def __iter__(self): - return iter(self._frame) - - def __len__(self): - return self.num_columns - def is_contiguous(self): - # NOTE: this check could have false negative + """Return whether this refers to a contiguous range of rows.""" + # NOTE: this check could have false negatives and false positives + # (step other than 1) return isinstance(self._index_data, slice) def is_span_whole_column(self): + """Return whether this refers to all the rows.""" return self.is_contiguous() and self.num_rows == self._frame.num_rows def _getrowid(self, query): + """Internal function to convert from the local row ids to the row ids of the frame.""" if self.is_contiguous(): # shortcut for identical mapping return query @@ -233,16 +670,8 @@ def _getrowid(self, query): idxtensor = self.index().tousertensor() return utils.toindex(F.gather_row(idxtensor, query.tousertensor())) - def index(self): - if self._index is None: - if self.is_contiguous(): - self._index = utils.toindex( - F.arange(self._index_data.stop, dtype=F.int64)) - else: - self._index = utils.toindex(self._index_data) - return self._index - def _clear_cache(self): + """Internal function to clear the cached object.""" self._index_tensor = None def merge_frames(frames, indices, max_index, reduce_func): @@ -267,6 +696,8 @@ def merge_frames(frames, indices, max_index, reduce_func): merged : FrameRef The merged frame. """ + # TODO(minjie) + assert False, 'Buggy code, disabled for now.' assert reduce_func == 'sum' assert len(frames) > 0 schemes = frames[0].schemes diff --git a/python/dgl/graph.py b/python/dgl/graph.py index 38cd8d28e63b..0ef4e273433f 100644 --- a/python/dgl/graph.py +++ b/python/dgl/graph.py @@ -504,25 +504,49 @@ def from_scipy_sparse_matrix(self, a): self._msg_graph.add_nodes(self._graph.number_of_nodes()) def node_attr_schemes(self): - """Return the node attribute schemes. + """Return the node feature schemes. Returns ------- - iterable - The set of attribute names + dict of str to schemes + The schemes of node feature columns. """ return self._node_frame.schemes def edge_attr_schemes(self): - """Return the edge attribute schemes. + """Return the edge feature schemes. Returns ------- - iterable - The set of attribute names + dict of str to schemes + The schemes of edge feature columns. """ return self._edge_frame.schemes + def set_n_initializer(self, initializer): + """Set the initializer for empty node features. + + Initializer is a callable that returns a tensor given the shape and data type. + + Parameters + ---------- + initializer : callable + The initializer. + """ + self._node_frame.set_initializer(initializer) + + def set_e_initializer(self, initializer): + """Set the initializer for empty edge features. + + Initializer is a callable that returns a tensor given the shape and data type. + + Parameters + ---------- + initializer : callable + The initializer. + """ + self._edge_frame.set_initializer(initializer) + def set_n_repr(self, hu, u=ALL, inplace=False): """Set node(s) representation. @@ -534,12 +558,17 @@ def set_n_repr(self, hu, u=ALL, inplace=False): Dictionary type is also supported for `hu`. In this case, each item will be treated as separate attribute of the nodes. + All update will be done out-placely to work with autograd unless the inplace + flag is true. + Parameters ---------- hu : tensor or dict of tensor - Node representation. + Node representation. u : node, container or tensor - The node(s). + The node(s). + inplace : bool + True if the update is done inplacely """ # sanity check if is_all(u): @@ -607,7 +636,7 @@ def pop_n_repr(self, key=__REPR__): """ return self._node_frame.pop(key) - def set_e_repr(self, h_uv, u=ALL, v=ALL): + def set_e_repr(self, h_uv, u=ALL, v=ALL, inplace=False): """Set edge(s) representation. To set multiple edge representations at once, pass `u` and `v` with tensors or @@ -618,6 +647,9 @@ def set_e_repr(self, h_uv, u=ALL, v=ALL): Dictionary type is also supported for `h_uv`. In this case, each item will be treated as separate attribute of the edges. + All update will be done out-placely to work with autograd unless the inplace + flag is true. + Parameters ---------- h_uv : tensor or dict of tensor @@ -626,28 +658,35 @@ def set_e_repr(self, h_uv, u=ALL, v=ALL): The source node(s). v : node, container or tensor The destination node(s). + inplace : bool + True if the update is done inplacely """ # sanity check u_is_all = is_all(u) v_is_all = is_all(v) assert u_is_all == v_is_all if u_is_all: - self.set_e_repr_by_id(h_uv, eid=ALL) + self.set_e_repr_by_id(h_uv, eid=ALL, inplace=inplace) else: u = utils.toindex(u) v = utils.toindex(v) _, _, eid = self._graph.edge_ids(u, v) - self.set_e_repr_by_id(h_uv, eid=eid) + self.set_e_repr_by_id(h_uv, eid=eid, inplace=inplace) - def set_e_repr_by_id(self, h_uv, eid=ALL): + def set_e_repr_by_id(self, h_uv, eid=ALL, inplace=False): """Set edge(s) representation by edge id. + All update will be done out-placely to work with autograd unless the inplace + flag is true. + Parameters ---------- h_uv : tensor or dict of tensor Edge representation. eid : int, container or tensor The edge id(s). + inplace : bool + True if the update is done inplacely """ # sanity check if is_all(eid): @@ -662,16 +701,18 @@ def set_e_repr_by_id(self, h_uv, eid=ALL): assert F.shape(h_uv)[0] == num_edges # set if is_all(eid): + # update column if utils.is_dict_like(h_uv): for key, val in h_uv.items(): self._edge_frame[key] = val else: self._edge_frame[__REPR__] = h_uv else: + # update row if utils.is_dict_like(h_uv): - self._edge_frame[eid] = h_uv + self._edge_frame.update_rows(eid, h_uv, inplace=inplace) else: - self._edge_frame[eid] = {__REPR__ : h_uv} + self._edge_frame.update_rows(eid, {__REPR__ : h_uv}, inplace=inplace) def get_e_repr(self, u=ALL, v=ALL): """Get node(s) representation. @@ -793,12 +834,12 @@ def register_apply_edge_func(self, apply_edge_func): """ self._apply_edge_func = apply_edge_func - def apply_nodes(self, v, apply_node_func="default"): + def apply_nodes(self, v=ALL, apply_node_func="default"): """Apply the function on node representations. Parameters ---------- - v : int, iterable of int, tensor + v : int, iterable of int, tensor, optional The node id(s). apply_node_func : callable The apply node function. @@ -952,8 +993,8 @@ def _batch_send(self, u, v, eid, message_func): self._msg_frame.update_rows( msg_target_rows, {k: F.gather_row(msgs[k], msg_update_rows.tousertensor()) - for k in msgs} - ) + for k in msgs}, + inplace=False) if len(msg_append_rows) > 0: new_u, new_v = zip(*new_uv) new_u = utils.toindex(new_u) @@ -961,14 +1002,13 @@ def _batch_send(self, u, v, eid, message_func): self._msg_graph.add_edges(new_u, new_v) self._msg_frame.append( {k: F.gather_row(msgs[k], msg_append_rows.tousertensor()) - for k in msgs} - ) + for k in msgs}) else: if len(msg_target_rows) > 0: self._msg_frame.update_rows( msg_target_rows, - {__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())} - ) + {__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())}, + inplace=False) if len(msg_append_rows) > 0: new_u, new_v = zip(*new_uv) new_u = utils.toindex(new_u) diff --git a/tests/pytorch/test_basics.py b/tests/pytorch/test_basics.py index 0b739f5cd673..10a86ecaba9f 100644 --- a/tests/pytorch/test_basics.py +++ b/tests/pytorch/test_basics.py @@ -20,22 +20,26 @@ def reduce_func(node, msgs): reduce_msg_shapes.add(tuple(msgs.shape)) assert len(msgs.shape) == 3 assert msgs.shape[2] == D - return {'m' : th.sum(msgs, 1)} + return {'accum' : th.sum(msgs, 1)} def apply_node_func(node): - return {'h' : node['h'] + node['m']} + return {'h' : node['h'] + node['accum']} def generate_graph(grad=False): g = DGLGraph() g.add_nodes(10) # 10 nodes. # create a graph where 0 is the source and 9 is the sink + # 17 edges for i in range(1, 9): g.add_edge(0, i) g.add_edge(i, 9) # add a back flow from 9 to 0 g.add_edge(9, 0) ncol = Variable(th.randn(10, D), requires_grad=grad) + accumcol = Variable(th.randn(10, D), requires_grad=grad) + ecol = Variable(th.randn(17, D), requires_grad=grad) g.set_n_repr({'h' : ncol}) + g.set_n_initializer(lambda shape, dtype : th.zeros(shape)) return g def test_batch_setter_getter(): @@ -46,8 +50,9 @@ def _pfc(x): g.set_n_repr({'h' : th.zeros((10, D))}) assert _pfc(g.get_n_repr()['h']) == [0.] * 10 # pop nodes + old_len = len(g.get_n_repr()) assert _pfc(g.pop_n_repr('h')) == [0.] * 10 - assert len(g.get_n_repr()) == 0 + assert len(g.get_n_repr()) == old_len - 1 g.set_n_repr({'h' : th.zeros((10, D))}) # set partial nodes u = th.tensor([1, 3, 5]) @@ -81,8 +86,9 @@ def _pfc(x): g.set_e_repr({'l' : th.zeros((17, D))}) assert _pfc(g.get_e_repr()['l']) == [0.] * 17 # pop edges + old_len = len(g.get_e_repr()) assert _pfc(g.pop_e_repr('l')) == [0.] * 17 - assert len(g.get_e_repr()) == 0 + assert len(g.get_e_repr()) == old_len - 1 g.set_e_repr({'l' : th.zeros((17, D))}) # set partial edges (many-many) u = th.tensor([0, 0, 2, 5, 9]) diff --git a/tests/pytorch/test_basics_anonymous.py b/tests/pytorch/test_basics_anonymous.py index ad431b05b6e3..81dbe653e7c9 100644 --- a/tests/pytorch/test_basics_anonymous.py +++ b/tests/pytorch/test_basics_anonymous.py @@ -30,8 +30,10 @@ def generate_graph(grad=False): g.add_edge(i, 9) # add a back flow from 9 to 0 g.add_edge(9, 0) - col = Variable(th.randn(10, D), requires_grad=grad) - g.set_n_repr(col) + ncol = Variable(th.randn(10, D), requires_grad=grad) + ecol = Variable(th.randn(17, D), requires_grad=grad) + g.set_n_repr(ncol) + g.set_e_repr(ecol) return g def test_batch_setter_getter(): diff --git a/tests/pytorch/test_frame.py b/tests/pytorch/test_frame.py index aaec133cd56e..927277773ecf 100644 --- a/tests/pytorch/test_frame.py +++ b/tests/pytorch/test_frame.py @@ -2,14 +2,11 @@ from torch.autograd import Variable import numpy as np from dgl.frame import Frame, FrameRef -from dgl.utils import Index +from dgl.utils import Index, toindex N = 10 D = 5 -def check_eq(a, b): - return a.shape == b.shape and np.allclose(a.numpy(), b.numpy()) - def check_fail(fn): try: fn() @@ -27,12 +24,13 @@ def test_create(): data = create_test_data() f1 = Frame() for k, v in data.items(): - f1.add_column(k, v) - assert f1.schemes == set(data.keys()) + f1.update_column(k, v) + print(f1.schemes) + assert f1.keys() == set(data.keys()) assert f1.num_columns == 3 assert f1.num_rows == N f2 = Frame(data) - assert f2.schemes == set(data.keys()) + assert f2.keys() == set(data.keys()) assert f2.num_columns == 3 assert f2.num_rows == N f1.clear() @@ -45,9 +43,9 @@ def test_column1(): f = Frame(data) assert f.num_rows == N assert len(f) == 3 - assert check_eq(f['a1'], data['a1']) + assert th.allclose(f['a1'].data, data['a1'].data) f['a1'] = data['a2'] - assert check_eq(f['a2'], data['a2']) + assert th.allclose(f['a2'].data, data['a2'].data) # add a different length column should fail def failed_add_col(): f['a4'] = th.zeros([N+1, D]) @@ -70,16 +68,15 @@ def test_column2(): f = FrameRef(data, [3, 4, 5, 6, 7]) assert f.num_rows == 5 assert len(f) == 3 - assert check_eq(f['a1'], data['a1'][3:8]) + assert th.allclose(f['a1'], data['a1'].data[3:8]) # set column should reflect on the referenced data f['a1'] = th.zeros([5, D]) - assert check_eq(data['a1'][3:8], th.zeros([5, D])) - # add new column should be padded with zero - f['a4'] = th.ones([5, D]) - assert len(data) == 4 - assert check_eq(data['a4'][0:3], th.zeros([3, D])) - assert check_eq(data['a4'][3:8], th.ones([5, D])) - assert check_eq(data['a4'][8:10], th.zeros([2, D])) + assert th.allclose(data['a1'].data[3:8], th.zeros([5, D])) + # add new partial column should fail with error initializer + f.set_initializer(lambda shape, dtype : assert_(False)) + def failed_add_col(): + f['a4'] = th.ones([5, D]) + assert check_fail(failed_add_col) def test_append1(): # test append API on Frame @@ -91,9 +88,14 @@ def test_append1(): f1.append(f2) assert f1.num_rows == 2 * N c1 = f1['a1'] - assert c1.shape == (2 * N, D) + assert c1.data.shape == (2 * N, D) truth = th.cat([data['a1'], data['a1']]) - assert check_eq(truth, c1) + assert th.allclose(truth, c1.data) + # append dict of different length columns should fail + f3 = {'a1' : th.zeros((3, D)), 'a2' : th.zeros((3, D)), 'a3' : th.zeros((2, D))} + def failed_append(): + f1.append(f3) + assert check_fail(failed_append) def test_append2(): # test append on FrameRef @@ -113,7 +115,7 @@ def test_append2(): assert not f.is_span_whole_column() assert f.num_rows == 3 * N new_idx = list(range(N)) + list(range(2*N, 4*N)) - assert check_eq(f.index().tousertensor(), th.tensor(new_idx)) + assert th.all(f.index().tousertensor() == th.tensor(new_idx, dtype=th.int64)) assert data.num_rows == 4 * N def test_row1(): @@ -127,13 +129,13 @@ def test_row1(): rows = f[rowid] for k, v in rows.items(): assert v.shape == (len(rowid), D) - assert check_eq(v, data[k][rowid]) + assert th.allclose(v, data[k][rowid]) # test duplicate keys rowid = Index(th.tensor([8, 2, 2, 1])) rows = f[rowid] for k, v in rows.items(): assert v.shape == (len(rowid), D) - assert check_eq(v, data[k][rowid]) + assert th.allclose(v, data[k][rowid]) # setter rowid = Index(th.tensor([0, 2, 4])) @@ -143,12 +145,14 @@ def test_row1(): } f[rowid] = vals for k, v in f[rowid].items(): - assert check_eq(v, th.zeros((len(rowid), D))) + assert th.allclose(v, th.zeros((len(rowid), D))) - # setting rows with new column should automatically add a new column - vals['a4'] = th.ones((len(rowid), D)) - f[rowid] = vals - assert len(f) == 4 + # setting rows with new column should raise error with error initializer + f.set_initializer(lambda shape, dtype : assert_(False)) + def failed_update_rows(): + vals['a4'] = th.ones((len(rowid), D)) + f[rowid] = vals + assert check_fail(failed_update_rows) def test_row2(): # test row getter/setter autograd compatibility @@ -161,13 +165,13 @@ def test_row2(): rowid = Index(th.tensor([0, 2])) rows = f[rowid] rows['a1'].backward(th.ones((len(rowid), D))) - assert check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.])) + assert th.allclose(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.])) c1.grad.data.zero_() # test duplicate keys rowid = Index(th.tensor([8, 2, 2, 1])) rows = f[rowid] rows['a1'].backward(th.ones((len(rowid), D))) - assert check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.])) + assert th.allclose(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.])) c1.grad.data.zero_() # setter @@ -180,8 +184,8 @@ def test_row2(): f[rowid] = vals c11 = f['a1'] c11.backward(th.ones((N, D))) - assert check_eq(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.])) - assert check_eq(vals['a1'].grad, th.ones((len(rowid), D))) + assert th.allclose(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.])) + assert th.allclose(vals['a1'].grad, th.ones((len(rowid), D))) assert vals['a2'].grad is None def test_row3(): @@ -201,8 +205,9 @@ def test_row3(): newidx = list(range(N)) newidx.pop(2) newidx.pop(2) + newidx = toindex(newidx) for k, v in f.items(): - assert check_eq(v, data[k][th.tensor(newidx)]) + assert th.allclose(v, data[k][newidx]) def test_sharing(): data = Frame(create_test_data()) @@ -210,10 +215,10 @@ def test_sharing(): f2 = FrameRef(data, index=[2, 3, 4, 5, 6]) # test read for k, v in f1.items(): - assert check_eq(data[k][0:4], v) + assert th.allclose(data[k].data[0:4], v) for k, v in f2.items(): - assert check_eq(data[k][2:7], v) - f2_a1 = f2['a1'] + assert th.allclose(data[k].data[2:7], v) + f2_a1 = f2['a1'].data # test write # update own ref should not been seen by the other. f1[Index(th.tensor([0, 1]))] = { @@ -221,7 +226,7 @@ def test_sharing(): 'a2' : th.zeros([2, D]), 'a3' : th.zeros([2, D]), } - assert check_eq(f2['a1'], f2_a1) + assert th.allclose(f2['a1'], f2_a1) # update shared space should been seen by the other. f1[Index(th.tensor([2, 3]))] = { 'a1' : th.ones([2, D]), @@ -229,7 +234,7 @@ def test_sharing(): 'a3' : th.ones([2, D]), } f2_a1[0:2] = th.ones([2, D]) - assert check_eq(f2['a1'], f2_a1) + assert th.allclose(f2['a1'], f2_a1) if __name__ == '__main__': test_create() diff --git a/tests/pytorch/test_specialization.py b/tests/pytorch/test_specialization.py index 6327f53c4bb3..fbd648a8e404 100644 --- a/tests/pytorch/test_specialization.py +++ b/tests/pytorch/test_specialization.py @@ -123,6 +123,7 @@ def reduce_func(hv, msgs): return {'v2': th.sum(msgs['m2'], 1)} g = generate_graph() + g.set_n_repr({'v1' : th.zeros((10,)), 'v2' : th.zeros((10,))}) fld = 'f2' # update all, mix of builtin and UDF g.update_all([fn.copy_src(src=fld, out='m1'), message_func], @@ -173,6 +174,8 @@ def reduce_func(hv, msgs): return {'v2' : th.sum(msgs['m2'], 1)} g = generate_graph() + g.set_n_repr({'v1' : th.zeros((10, D)), 'v2' : th.zeros((10, D)), + 'v3' : th.zeros((10, D))}) fld = 'f2' # send and recv, mix of builtin and UDF