diff --git a/docs/source/api/python/transform.rst b/docs/source/api/python/transform.rst
index d65c3a0fcbbc..32d8870dc202 100644
--- a/docs/source/api/python/transform.rst
+++ b/docs/source/api/python/transform.rst
@@ -13,6 +13,12 @@ BaseTransform
     :members: __call__, __repr__
     :show-inheritance:
 
+Compose
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: Compose
+    :show-inheritance:
+
 AddSelfLoop
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -55,8 +61,50 @@ AddMetaPaths
 .. autoclass:: AddMetaPaths
     :show-inheritance:
 
-KNNGraph
+GCNNorm
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: GCNNorm
+    :show-inheritance:
+
+PPR
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: PPR
+    :show-inheritance:
+
+HeatKernel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: HeatKernel
+    :show-inheritance:
+
+GDC
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: GDC
+    :show-inheritance:
+
+NodeShuffle
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: NodeShuffle
+    :show-inheritance:
+
+DropNode
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: DropNode
+    :show-inheritance:
+
+DropEdge
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: DropEdge
+    :show-inheritance:
+
+AddEdge
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-.. autoclass:: KNNGraph
+.. autoclass:: AddEdge
     :show-inheritance:
diff --git a/python/dgl/backend/backend.py b/python/dgl/backend/backend.py
index c3b893bd9e4c..0ea3ace3f8f1 100644
--- a/python/dgl/backend/backend.py
+++ b/python/dgl/backend/backend.py
@@ -570,6 +570,21 @@ def exp(input):
     """
     pass
 
+def inverse(input):
+    """Returns the inverse matrix of a square matrix if it exists.
+
+    Parameters
+    ----------
+    input : Tensor
+        The input square matrix.
+
+    Returns
+    -------
+    Tensor
+        The output tensor.
+    """
+    pass
+
 def sqrt(input):
     """Returns a new tensor with the square root of the elements of the input tensor `input`.
 
@@ -1057,6 +1072,21 @@ def equal(x, y):
     """
     pass
 
+def allclose(x, y, rtol=1e-4, atol=1e-4):
+    """Compares whether all elements are close.
+
+    Parameters
+    ----------
+    x : Tensor
+        First tensor
+    y : Tensor
+        Second tensor
+    rtol : float, optional
+        Relative tolerance
+    atol : float, optional
+        Absolute tolerance
+    """
+
 def logical_not(input):
     """Perform a logical not operation. Equivalent to np.logical_not
 
diff --git a/python/dgl/backend/mxnet/tensor.py b/python/dgl/backend/mxnet/tensor.py
index 2a3244c631f7..881208393fd9 100644
--- a/python/dgl/backend/mxnet/tensor.py
+++ b/python/dgl/backend/mxnet/tensor.py
@@ -191,6 +191,9 @@ def argsort(input, dim, descending):
 def exp(input):
     return nd.exp(input)
 
+def inverse(input):
+    return nd.linalg_inverse(input)
+
 def sqrt(input):
     return nd.sqrt(input)
 
@@ -327,6 +330,9 @@ def boolean_mask(input, mask):
 def equal(x, y):
     return x == y
 
+def allclose(x, y, rtol=1e-4, atol=1e-4):
+    return np.allclose(x.asnumpy(), y.asnumpy(), rtol=rtol, atol=atol)
+
 def logical_not(input):
     return nd.logical_not(input)
 
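Not part of the patch: a minimal sketch of the intended semantics of the new `inverse` and `allclose` wrappers declared in backend.py above and implemented per framework in the surrounding hunks, assuming the PyTorch backend is active; the 2x2 matrix is illustrative only.

import torch
from dgl import backend as F

m = torch.tensor([[4., 7.], [2., 6.]])
m_inv = F.inverse(m)  # matrix inverse computed by the active backend
# inverse(m) @ m should recover the identity up to the default 1e-4 tolerances
assert F.allclose(m_inv @ m, torch.eye(2))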
diff --git a/python/dgl/backend/pytorch/tensor.py b/python/dgl/backend/pytorch/tensor.py
index f08d26552455..3f4ac995b71d 100644
--- a/python/dgl/backend/pytorch/tensor.py
+++ b/python/dgl/backend/pytorch/tensor.py
@@ -14,8 +14,8 @@
 from ...function.base import TargetCode
 from ...base import dgl_warning
 
-if LooseVersion(th.__version__) < LooseVersion("1.5.0"):
-    raise Exception("Detected an old version of PyTorch. Please update torch>=1.5.0 "
+if LooseVersion(th.__version__) < LooseVersion("1.8.0"):
+    raise Exception("Detected an old version of PyTorch. Please update torch>=1.8.0 "
                     "for the best experience.")
 
 def data_type_dict():
@@ -164,6 +164,9 @@ def argtopk(input, k, dim, descending=True):
 def exp(input):
     return th.exp(input)
 
+def inverse(input):
+    return th.inverse(input)
+
 def sqrt(input):
     return th.sqrt(input)
 
@@ -276,6 +279,9 @@ def boolean_mask(input, mask):
 def equal(x, y):
     return x == y
 
+def allclose(x, y, rtol=1e-4, atol=1e-4):
+    return th.allclose(x, y, rtol=rtol, atol=atol)
+
 def logical_not(input):
     return ~input
 
diff --git a/python/dgl/backend/tensorflow/tensor.py b/python/dgl/backend/tensorflow/tensor.py
index b54343c9e13e..b919118ac375 100644
--- a/python/dgl/backend/tensorflow/tensor.py
+++ b/python/dgl/backend/tensorflow/tensor.py
@@ -244,6 +244,10 @@ def exp(input):
     return tf.exp(input)
 
 
+def inverse(input):
+    return tf.linalg.inv(input)
+
+
 def sqrt(input):
     return tf.sqrt(input)
 
@@ -396,6 +400,11 @@ def equal(x, y):
     return x == y
 
 
+def allclose(x, y, rtol=1e-4, atol=1e-4):
+    return np.allclose(tf.convert_to_tensor(x).numpy(),
+                       tf.convert_to_tensor(y).numpy(), rtol=rtol, atol=atol)
+
+
 def logical_not(input):
     return ~input
 
diff --git a/python/dgl/transform/module.py b/python/dgl/transform/module.py
index 7955aa1a6fec..460ee3cb7220 100644
--- a/python/dgl/transform/module.py
+++ b/python/dgl/transform/module.py
@@ -14,11 +14,21 @@
 # limitations under the License.
 #
 """Modules for transform"""
+# pylint: disable= no-member, arguments-differ, invalid-name
+
+from scipy.linalg import expm
 
 from .. import convert
 from .. import backend as F
+from .. import function as fn
 from . import functional
 
+try:
+    import torch
+    from torch.distributions import Bernoulli
+except ImportError:
+    pass
+
 __all__ = [
     'BaseTransform',
     'AddSelfLoop',
@@ -28,7 +38,15 @@
     'LineGraph',
     'KHopGraph',
     'AddMetaPaths',
-    'Compose'
+    'Compose',
+    'GCNNorm',
+    'PPR',
+    'HeatKernel',
+    'GDC',
+    'NodeShuffle',
+    'DropNode',
+    'DropEdge',
+    'AddEdge'
 ]
 
 def update_graph_structure(g, data_dict, copy_edata=True):
@@ -672,3 +690,568 @@ def __call__(self, g):
     def __repr__(self):
         args = ['  ' + str(transform) for transform in self.transforms]
         return self.__class__.__name__ + '([\n' + ',\n'.join(args) + '\n])'
+
+class GCNNorm(BaseTransform):
+    r"""
+
+    Description
+    -----------
+    Apply symmetric adjacency normalization to an input graph and save the result edge weights,
+    as described in `Semi-Supervised Classification with Graph Convolutional Networks
+    `__.
+
+    For a heterogeneous graph, this only applies to symmetric canonical edge types, whose source
+    and destination node types are identical.
+
+    Parameters
+    ----------
+    eweight_name : str, optional
+        :attr:`edata` name to retrieve and store edge weights. The edge weights are optional.
+
+    Example
+    -------
+
+    >>> import dgl
+    >>> import torch
+    >>> from dgl import GCNNorm
+    >>> transform = GCNNorm()
+    >>> g = dgl.graph(([0, 1, 2], [0, 0, 1]))
+
+    Case1: Transform an unweighted graph
+
+    >>> g = transform(g)
+    >>> print(g.edata['w'])
+    tensor([0.5000, 0.7071, 0.0000])
+
+    Case2: Transform a weighted graph
+
+    >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3])
+    >>> g = transform(g)
+    >>> print(g.edata['w'])
+    tensor([0.3333, 0.6667, 0.0000])
+    """
+    def __init__(self, eweight_name='w'):
+        self.eweight_name = eweight_name
+
+    def calc_etype(self, c_etype, g):
+        r"""
+
+        Description
+        -----------
+        Get edge weights for an edge type.
+        """
+        ntype = c_etype[0]
+        with g.local_scope():
+            if self.eweight_name in g.edges[c_etype].data:
+                g.update_all(fn.copy_e(self.eweight_name, 'm'), fn.sum('m', 'deg'), etype=c_etype)
+                deg_inv_sqrt = 1. / F.sqrt(g.nodes[ntype].data['deg'])
+                g.nodes[ntype].data['w'] = F.replace_inf_with_zero(deg_inv_sqrt)
+                g.apply_edges(lambda edge: {'w': edge.src['w'] * edge.data[self.eweight_name] *
+                                                 edge.dst['w']},
+                              etype=c_etype)
+            else:
+                deg = g.in_degrees(etype=c_etype)
+                deg_inv_sqrt = 1. / F.sqrt(F.astype(deg, F.float32))
+                g.nodes[ntype].data['w'] = F.replace_inf_with_zero(deg_inv_sqrt)
+                g.apply_edges(lambda edges: {'w': edges.src['w'] * edges.dst['w']}, etype=c_etype)
+            return g.edges[c_etype].data['w']
+
+    def __call__(self, g):
+        result = dict()
+        for c_etype in g.canonical_etypes:
+            utype, _, vtype = c_etype
+            if utype == vtype:
+                result[c_etype] = self.calc_etype(c_etype, g)
+
+        for c_etype, eweight in result.items():
+            g.edges[c_etype].data[self.eweight_name] = eweight
+        return g
+
+class PPR(BaseTransform):
+    r"""
+
+    Description
+    -----------
+    Apply personalized PageRank (PPR) to an input graph for diffusion, as introduced in
+    `The pagerank citation ranking: Bringing order to the web
+    `__. A sparsification will be applied to the
+    weighted adjacency matrix after diffusion. Specifically, edges whose weight is below
+    a threshold will be dropped.
+
+    This module only works for homogeneous graphs.
+
+    Parameters
+    ----------
+    alpha : float, optional
+        Restart probability, which commonly lies in :math:`[0.05, 0.2]`.
+    eweight_name : str, optional
+        :attr:`edata` name to retrieve and store edge weights. If it does
+        not exist in an input graph, this module initializes a weight of 1
+        for all edges. The edge weights should be a tensor of shape :math:`(E)`,
+        where E is the number of edges.
+    eps : float, optional
+        The threshold to preserve edges in sparsification after diffusion. Edges of a
+        weight smaller than eps will be dropped.
+    avg_degree : int, optional
+        The desired average node degree of the result graph. This is the other way to
+        control the sparsity of the result graph and will only be effective if
+        :attr:`eps` is not given.
+
+    Example
+    -------
+
+    >>> import dgl
+    >>> import torch
+    >>> from dgl import PPR
+
+    >>> transform = PPR(avg_degree=2)
+    >>> g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]))
+    >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
+    >>> new_g = transform(g)
+    >>> print(new_g.edata['w'])
+    tensor([0.1500, 0.1500, 0.1500, 0.0255, 0.0163, 0.1500, 0.0638, 0.0383, 0.1500,
+            0.0510, 0.0217, 0.1500])
+    """
+    def __init__(self, alpha=0.15, eweight_name='w', eps=None, avg_degree=5):
+        self.alpha = alpha
+        self.eweight_name = eweight_name
+        self.eps = eps
+        self.avg_degree = avg_degree
+
+    def get_eps(self, num_nodes, mat):
+        r"""
+
+        Description
+        -----------
+        Get the threshold for graph sparsification.
+        """
+        if self.eps is None:
+            # Infer from self.avg_degree
+            if self.avg_degree > num_nodes:
+                return float('-inf')
+            sorted_weights = torch.sort(mat.flatten(), descending=True).values
+            return sorted_weights[self.avg_degree * num_nodes - 1]
+        else:
+            return self.eps
+
+    def __call__(self, g):
+        # Step1: PPR diffusion
+        # (α - 1) A
+        device = g.device
+        eweight = (self.alpha - 1) * g.edata.get(self.eweight_name, F.ones(
+            (g.num_edges(),), F.float32, device))
+        num_nodes = g.num_nodes()
+        mat = F.zeros((num_nodes, num_nodes), F.float32, device)
+        src, dst = g.edges()
+        src, dst = F.astype(src, F.int64), F.astype(dst, F.int64)
+        mat[dst, src] = eweight
+        # I_n + (α - 1) A
+        nids = F.astype(g.nodes(), F.int64)
+        mat[nids, nids] = mat[nids, nids] + 1
+        # α (I_n + (α - 1) A)^-1
+        diff_mat = self.alpha * F.inverse(mat)
+
+        # Step2: sparsification
+        num_nodes = g.num_nodes()
+        eps = self.get_eps(num_nodes, diff_mat)
+        dst, src = (diff_mat >= eps).nonzero(as_tuple=False).t()
+        data_dict = {g.canonical_etypes[0]: (src, dst)}
+        new_g = update_graph_structure(g, data_dict, copy_edata=False)
+        new_g.edata[self.eweight_name] = diff_mat[dst, src]
+
+        return new_g
+
+def is_bidirected(g):
+    """Return whether the graph is a bidirected graph.
+
+    A graph is bidirected if for any edge :math:`(u, v)` in :math:`G` with weight :math:`w`,
+    there exists an edge :math:`(v, u)` in :math:`G` with the same weight.
+    """
+    src, dst = g.edges()
+    num_nodes = g.num_nodes()
+
+    # Sort first by src then dst
+    idx_src_dst = src * num_nodes + dst
+    perm_src_dst = F.argsort(idx_src_dst, dim=0, descending=False)
+    src1, dst1 = src[perm_src_dst], dst[perm_src_dst]
+
+    # Sort first by dst then src
+    idx_dst_src = dst * num_nodes + src
+    perm_dst_src = F.argsort(idx_dst_src, dim=0, descending=False)
+    src2, dst2 = src[perm_dst_src], dst[perm_dst_src]
+
+    return F.allclose(src1, dst2) and F.allclose(src2, dst1)
+
+# pylint: disable=C0103
+class HeatKernel(BaseTransform):
+    r"""
+
+    Description
+    -----------
+    Apply heat kernel to an input graph for diffusion, as introduced in
+    `Diffusion kernels on graphs and other discrete structures
+    `__.
+    A sparsification will be applied to the weighted adjacency matrix after diffusion.
+    Specifically, edges whose weight is below a threshold will be dropped.
+
+    This module only works for homogeneous graphs.
+
+    Parameters
+    ----------
+    t : float, optional
+        Diffusion time, which commonly lies in :math:`[2, 10]`.
+    eweight_name : str, optional
+        :attr:`edata` name to retrieve and store edge weights. If it does
+        not exist in an input graph, this module initializes a weight of 1
+        for all edges. The edge weights should be a tensor of shape :math:`(E)`,
+        where E is the number of edges.
+    eps : float, optional
+        The threshold to preserve edges in sparsification after diffusion. Edges of a
+        weight smaller than eps will be dropped.
+    avg_degree : int, optional
+        The desired average node degree of the result graph. This is the other way to
+        control the sparsity of the result graph and will only be effective if
+        :attr:`eps` is not given.
+
+    Example
+    -------
+
+    >>> import dgl
+    >>> import torch
+    >>> from dgl import HeatKernel
+
+    >>> transform = HeatKernel(avg_degree=2)
+    >>> g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]))
+    >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
+    >>> new_g = transform(g)
+    >>> print(new_g.edata['w'])
+    tensor([0.1353, 0.1353, 0.1353, 0.0541, 0.0406, 0.1353, 0.1353, 0.0812, 0.1353,
+            0.1083, 0.0541, 0.1353])
+    """
+    def __init__(self, t=2., eweight_name='w', eps=None, avg_degree=5):
+        self.t = t
+        self.eweight_name = eweight_name
+        self.eps = eps
+        self.avg_degree = avg_degree
+
+    def get_eps(self, num_nodes, mat):
+        r"""
+
+        Description
+        -----------
+        Get the threshold for graph sparsification.
+        """
+        if self.eps is None:
+            # Infer from self.avg_degree
+            if self.avg_degree > num_nodes:
+                return float('-inf')
+            sorted_weights = torch.sort(mat.flatten(), descending=True).values
+            return sorted_weights[self.avg_degree * num_nodes - 1]
+        else:
+            return self.eps
+
+    def __call__(self, g):
+        # Step1: heat kernel diffusion
+        # t A
+        device = g.device
+        eweight = self.t * g.edata.get(self.eweight_name, F.ones(
+            (g.num_edges(),), F.float32, device))
+        num_nodes = g.num_nodes()
+        mat = F.zeros((num_nodes, num_nodes), F.float32, device)
+        src, dst = g.edges()
+        src, dst = F.astype(src, F.int64), F.astype(dst, F.int64)
+        mat[dst, src] = eweight
+        # t (A - I_n)
+        nids = F.astype(g.nodes(), F.int64)
+        mat[nids, nids] = mat[nids, nids] - self.t
+
+        if is_bidirected(g):
+            e, V = torch.linalg.eigh(mat, UPLO='U')
+            diff_mat = V @ torch.diag(e.exp()) @ V.t()
+        else:
+            diff_mat_np = expm(mat.cpu().numpy())
+            diff_mat = torch.Tensor(diff_mat_np).to(device)
+
+        # Step2: sparsification
+        num_nodes = g.num_nodes()
+        eps = self.get_eps(num_nodes, diff_mat)
+        dst, src = (diff_mat >= eps).nonzero(as_tuple=False).t()
+        data_dict = {g.canonical_etypes[0]: (src, dst)}
+        new_g = update_graph_structure(g, data_dict, copy_edata=False)
+        new_g.edata[self.eweight_name] = diff_mat[dst, src]
+
+        return new_g
+
+class GDC(BaseTransform):
+    r"""
+
+    Description
+    -----------
+    Apply graph diffusion convolution (GDC) to an input graph, as introduced in
+    `Diffusion Improves Graph Learning `__. A sparsification
+    will be applied to the weighted adjacency matrix after diffusion. Specifically, edges whose
+    weight is below a threshold will be dropped.
+
+    This module only works for homogeneous graphs.
+
+    Parameters
+    ----------
+    coefs : list[float], optional
+        List of coefficients :math:`\theta_k`, one for each power of the adjacency matrix.
+    eweight_name : str, optional
+        :attr:`edata` name to retrieve and store edge weights. If it does
+        not exist in an input graph, this module initializes a weight of 1
+        for all edges. The edge weights should be a tensor of shape :math:`(E)`,
+        where E is the number of edges.
+    eps : float, optional
+        The threshold to preserve edges in sparsification after diffusion. Edges of a
+        weight smaller than eps will be dropped.
+    avg_degree : int, optional
+        The desired average node degree of the result graph. This is the other way to
+        control the sparsity of the result graph and will only be effective if
+        :attr:`eps` is not given.
+
+    Example
+    -------
+
+    >>> import dgl
+    >>> import torch
+    >>> from dgl import GDC
+
+    >>> transform = GDC([0.3, 0.2, 0.1], avg_degree=2)
+    >>> g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]))
+    >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
+    >>> new_g = transform(g)
+    >>> print(new_g.edata['w'])
+    tensor([0.3000, 0.3000, 0.0200, 0.3000, 0.0400, 0.3000, 0.1000, 0.0600, 0.3000,
+            0.0800, 0.0200, 0.3000])
+    """
+    def __init__(self, coefs, eweight_name='w', eps=None, avg_degree=5):
+        self.coefs = coefs
+        self.eweight_name = eweight_name
+        self.eps = eps
+        self.avg_degree = avg_degree
+
+    def get_eps(self, num_nodes, mat):
+        r"""
+
+        Description
+        -----------
+        Get the threshold for graph sparsification.
+        """
+        if self.eps is None:
+            # Infer from self.avg_degree
+            if self.avg_degree > num_nodes:
+                return float('-inf')
+            sorted_weights = torch.sort(mat.flatten(), descending=True).values
+            return sorted_weights[self.avg_degree * num_nodes - 1]
+        else:
+            return self.eps
+
+    def __call__(self, g):
+        # Step1: diffusion
+        # A
+        device = g.device
+        eweight = g.edata.get(self.eweight_name, F.ones(
+            (g.num_edges(),), F.float32, device))
+        num_nodes = g.num_nodes()
+        adj = F.zeros((num_nodes, num_nodes), F.float32, device)
+        src, dst = g.edges()
+        src, dst = F.astype(src, F.int64), F.astype(dst, F.int64)
+        adj[dst, src] = eweight
+
+        # theta_0 I_n
+        mat = torch.eye(num_nodes, device=device)
+        diff_mat = self.coefs[0] * mat
+        # add theta_k A^k
+        for coef in self.coefs[1:]:
+            mat = mat @ adj
+            diff_mat += coef * mat
+
+        # Step2: sparsification
+        num_nodes = g.num_nodes()
+        eps = self.get_eps(num_nodes, diff_mat)
+        dst, src = (diff_mat >= eps).nonzero(as_tuple=False).t()
+        data_dict = {g.canonical_etypes[0]: (src, dst)}
+        new_g = update_graph_structure(g, data_dict, copy_edata=False)
+        new_g.edata[self.eweight_name] = diff_mat[dst, src]
+
+        return new_g
+
+class NodeShuffle(BaseTransform):
+    r"""
+
+    Description
+    -----------
+    Randomly shuffle the nodes.
+
+    Example
+    -------
+
+    >>> import dgl
+    >>> import torch
+    >>> from dgl import NodeShuffle
+
+    >>> transform = NodeShuffle()
+    >>> g = dgl.graph(([0, 1], [1, 2]))
+    >>> g.ndata['h1'] = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
+    >>> g.ndata['h2'] = torch.tensor([[7., 8.], [9., 10.], [11., 12.]])
+    >>> g = transform(g)
+    >>> print(g.ndata['h1'])
+    tensor([[5., 6.],
+            [3., 4.],
+            [1., 2.]])
+    >>> print(g.ndata['h2'])
+    tensor([[11., 12.],
+            [ 9., 10.],
+            [ 7.,  8.]])
+    """
+    def __call__(self, g):
+        for ntype in g.ntypes:
+            nids = F.astype(g.nodes(ntype), F.int64)
+            perm = F.rand_shuffle(nids)
+            for key, feat in g.nodes[ntype].data.items():
+                g.nodes[ntype].data[key] = feat[perm]
+        return g
+
+# pylint: disable=C0103
+class DropNode(BaseTransform):
+    r"""
+
+    Description
+    -----------
+    Randomly drop nodes, as described in
+    `Graph Contrastive Learning with Augmentations `__.
+
+    Parameters
+    ----------
+    p : float, optional
+        Probability of a node to be dropped.
+
+    Example
+    -------
+
+    >>> import dgl
+    >>> import torch
+    >>> from dgl import DropNode
+
+    >>> transform = DropNode()
+    >>> g = dgl.rand_graph(5, 20)
+    >>> g.ndata['h'] = torch.arange(g.num_nodes())
+    >>> g.edata['h'] = torch.arange(g.num_edges())
+    >>> new_g = transform(g)
+    >>> print(new_g)
+    Graph(num_nodes=3, num_edges=7,
+          ndata_schemes={'h': Scheme(shape=(), dtype=torch.int64)}
+          edata_schemes={'h': Scheme(shape=(), dtype=torch.int64)})
+    >>> print(new_g.ndata['h'])
+    tensor([0, 1, 2])
+    >>> print(new_g.edata['h'])
+    tensor([0, 6, 14, 5, 17, 3, 11])
+    """
+    def __init__(self, p=0.5):
+        self.p = p
+        self.dist = Bernoulli(p)
+
+    def __call__(self, g):
+        # Fast path
+        if self.p == 0:
+            return g
+
+        for ntype in g.ntypes:
+            samples = self.dist.sample(torch.Size([g.num_nodes(ntype)]))
+            nids_to_remove = g.nodes(ntype)[samples.bool().to(g.device)]
+            g.remove_nodes(nids_to_remove, ntype=ntype)
+        return g
+
+# pylint: disable=C0103
+class DropEdge(BaseTransform):
+    r"""
+
+    Description
+    -----------
+    Randomly drop edges, as described in
+    `DropEdge: Towards Deep Graph Convolutional Networks on Node Classification
+    `__ and `Graph Contrastive Learning with Augmentations
+    `__.
+
+    Parameters
+    ----------
+    p : float, optional
+        Probability of an edge to be dropped.
+
+    Example
+    -------
+
+    >>> import dgl
+    >>> import torch
+    >>> from dgl import DropEdge
+
+    >>> transform = DropEdge()
+    >>> g = dgl.rand_graph(5, 20)
+    >>> g.edata['h'] = torch.arange(g.num_edges())
+    >>> new_g = transform(g)
+    >>> print(new_g)
+    Graph(num_nodes=5, num_edges=12,
+          ndata_schemes={}
+          edata_schemes={'h': Scheme(shape=(), dtype=torch.int64)})
+    >>> print(new_g.edata['h'])
+    tensor([0, 1, 3, 7, 8, 10, 11, 12, 13, 15, 18, 19])
+    """
+    def __init__(self, p=0.5):
+        self.p = p
+        self.dist = Bernoulli(p)
+
+    def __call__(self, g):
+        # Fast path
+        if self.p == 0:
+            return g
+
+        for c_etype in g.canonical_etypes:
+            samples = self.dist.sample(torch.Size([g.num_edges(c_etype)]))
+            eids_to_remove = g.edges(form='eid', etype=c_etype)[samples.bool().to(g.device)]
+            g.remove_edges(eids_to_remove, etype=c_etype)
+        return g
+
+class AddEdge(BaseTransform):
+    r"""
+
+    Description
+    -----------
+    Randomly add edges, as described in `Graph Contrastive Learning with Augmentations
+    `__.
+
+    Parameters
+    ----------
+    ratio : float, optional
+        Number of edges to add divided by the number of existing edges.
+
+    Example
+    -------
+
+    >>> import dgl
+    >>> from dgl import AddEdge
+
+    >>> transform = AddEdge()
+    >>> g = dgl.rand_graph(5, 20)
+    >>> new_g = transform(g)
+    >>> print(new_g.num_edges())
+    24
+    """
+    def __init__(self, ratio=0.2):
+        self.ratio = ratio
+
+    def __call__(self, g):
+        # Fast path
+        if self.ratio == 0.:
+            return g
+
+        device = g.device
+        idtype = g.idtype
+        for c_etype in g.canonical_etypes:
+            utype, _, vtype = c_etype
+            num_edges_to_add = int(g.num_edges(c_etype) * self.ratio)
+            src = F.randint([num_edges_to_add], idtype, device, low=0, high=g.num_nodes(utype))
+            dst = F.randint([num_edges_to_add], idtype, device, low=0, high=g.num_nodes(vtype))
+            g.add_edges(src, dst, etype=c_etype)
+        return g
diff --git a/tests/compute/test_graph.py b/tests/compute/test_graph.py
index ed8e24f8483c..45bfe41e085d 100644
--- a/tests/compute/test_graph.py
+++ b/tests/compute/test_graph.py
@@ -345,8 +345,8 @@ def test_empty_data_initialized():
     assert len(g.ndata["ha"]) == 1
 
 def test_is_sorted():
-    u_src, u_dst = edge_pair_input(False)
-    s_src, s_dst = edge_pair_input(True)
+    u_src, u_dst = edge_pair_input(False)
+    s_src, s_dst = edge_pair_input(True)
 
     u_src = F.tensor(u_src, dtype=F.int32)
     u_dst = F.tensor(u_dst, dtype=F.int32)
@@ -409,7 +409,7 @@ def test_formats():
         fail = False
     finally:
         assert not fail
-    
+
 if __name__ == '__main__':
     test_query()
     test_mutation()
diff --git a/tests/compute/test_transform.py b/tests/compute/test_transform.py
index d882304298fc..8b196a8289bd 100644
--- a/tests/compute/test_transform.py
+++ b/tests/compute/test_transform.py
@@ -23,6 +23,7 @@
 import dgl.partition
 import backend as F
 import unittest
+import math
 
 from utils import parametrize_dtype
 from test_heterograph import create_test_heterograph3, create_test_heterograph4, create_test_heterograph5
@@ -2156,5 +2157,144 @@ def test_module_compose(idtype):
     eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
     assert eset == {(0, 1), (1, 2), (1, 0), (2, 1), (0, 0), (1, 1), (2, 2)}
 
+@parametrize_dtype
+def test_module_gcnnorm(idtype):
+    g = dgl.heterograph({
+        ('A', 'r1', 'A'): ([0, 1, 2], [0, 0, 1]),
+        ('A', 'r2', 'B'): ([0, 0], [1, 1]),
+        ('B', 'r3', 'B'): ([0, 1, 2], [0, 0, 1])
+    }, idtype=idtype, device=F.ctx())
+    g.edges['r3'].data['w'] = F.tensor([0.1, 0.2, 0.3])
+    transform = dgl.GCNNorm()
+    new_g = transform(g)
+    assert 'w' not in new_g.edges[('A', 'r2', 'B')].data
+    assert F.allclose(new_g.edges[('A', 'r1', 'A')].data['w'],
+                      F.tensor([1./2, 1./math.sqrt(2), 0.]))
+    assert F.allclose(new_g.edges[('B', 'r3', 'B')].data['w'], F.tensor([1./3, 2./3, 0.]))
+
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
+@parametrize_dtype
+def test_module_ppr(idtype):
+    g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx())
+    g.ndata['h'] = F.randn((6, 2))
+    transform = dgl.PPR(avg_degree=2)
+    new_g = transform(g)
+    assert new_g.idtype == g.idtype
+    assert new_g.device == g.device
+    assert new_g.num_nodes() == g.num_nodes()
+    src, dst = new_g.edges()
+    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
+    assert eset == {(0, 0), (0, 2), (0, 4), (1, 1), (1, 3), (1, 5), (2, 2),
+                    (2, 3), (2, 4), (3, 3), (3, 5), (4, 3), (4, 4), (4, 5), (5, 5)}
+    assert F.allclose(g.ndata['h'], new_g.ndata['h'])
+    assert 'w' in new_g.edata
+
+    # Prior edge weights
+    g.edata['w'] = F.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
+    new_g = transform(g)
+    src, dst = new_g.edges()
+    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
+    assert eset == {(0, 0), (1, 1), (1, 3), (2, 2), (2, 3), (2, 4),
+                    (3, 3), (3, 5), (4, 3), (4, 4), (4, 5), (5, 5)}
+
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
+@parametrize_dtype
+def test_module_heat_kernel(idtype):
+    # Case1: directed graph
+    g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx())
+    g.ndata['h'] = F.randn((6, 2))
+    transform = dgl.HeatKernel(avg_degree=1)
+    new_g = transform(g)
+    assert new_g.idtype == g.idtype
+    assert new_g.device == g.device
+    assert new_g.num_nodes() == g.num_nodes()
+    src, dst = new_g.edges()
+    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
+    assert eset == {(0, 2), (0, 4), (1, 3), (1, 5), (2, 3), (2, 4), (3, 5), (4, 5)}
+    assert F.allclose(g.ndata['h'], new_g.ndata['h'])
+    assert 'w' in new_g.edata
+
+    # Case2: weighted undirected graph
+    g = dgl.graph(([0, 1, 2, 3], [1, 0, 3, 2]), idtype=idtype, device=F.ctx())
+    g.edata['w'] = F.tensor([0.1, 0.2, 0.3, 0.4])
+    new_g = transform(g)
+    src, dst = new_g.edges()
+    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
+    assert eset == {(0, 0), (1, 1), (2, 2), (3, 3)}
+
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
+@parametrize_dtype
+def test_module_gdc(idtype):
+    transform = dgl.GDC([0.1, 0.2, 0.1], avg_degree=1)
+    g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx())
+    g.ndata['h'] = F.randn((6, 2))
+    new_g = transform(g)
+    assert new_g.idtype == g.idtype
+    assert new_g.device == g.device
+    assert new_g.num_nodes() == g.num_nodes()
+    src, dst = new_g.edges()
+    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
+    assert eset == {(0, 0), (0, 2), (0, 4), (1, 1), (1, 3), (1, 5), (2, 2), (2, 3),
+                    (2, 4), (3, 3), (3, 5), (4, 3), (4, 4), (4, 5), (5, 5)}
+    assert F.allclose(g.ndata['h'], new_g.ndata['h'])
+    assert 'w' in new_g.edata
+
+    # Prior edge weights
+    g.edata['w'] = F.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
+    new_g = transform(g)
+    src, dst = new_g.edges()
+    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
+    assert eset == {(0, 0), (1, 1), (2, 2), (3, 3), (4, 3), (4, 4), (5, 5)}
+
+@parametrize_dtype
+def test_module_node_shuffle(idtype):
+    transform = dgl.NodeShuffle()
+    g = dgl.heterograph({
+        ('A', 'r', 'B'): ([0, 1], [1, 2]),
+    }, idtype=idtype, device=F.ctx())
+    new_g = transform(g)
+
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
+@parametrize_dtype
+def test_module_drop_node(idtype):
+    transform = dgl.DropNode()
+    g = dgl.heterograph({
+        ('A', 'r', 'B'): ([0, 1], [1, 2]),
+    }, idtype=idtype, device=F.ctx())
+    new_g = transform(g)
+    assert new_g.idtype == g.idtype
+    assert new_g.device == g.device
+    assert new_g.ntypes == g.ntypes
+    assert new_g.canonical_etypes == g.canonical_etypes
+
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
+@parametrize_dtype
+def test_module_drop_edge(idtype):
+    transform = dgl.DropEdge()
+    g = dgl.heterograph({
+        ('A', 'r1', 'B'): ([0, 1], [1, 2]),
+        ('C', 'r2', 'C'): ([3, 4, 5], [6, 7, 8])
+    }, idtype=idtype, device=F.ctx())
+    new_g = transform(g)
+    assert new_g.idtype == g.idtype
+    assert new_g.device == g.device
+    assert new_g.ntypes == g.ntypes
+    assert new_g.canonical_etypes == g.canonical_etypes
+
+@parametrize_dtype
+def test_module_add_edge(idtype):
+    transform = dgl.AddEdge()
+    g = dgl.heterograph({
+        ('A', 'r1', 'B'): ([0, 1, 2, 3, 4], [1, 2, 3, 4, 5]),
+        ('C', 'r2', 'C'): ([0, 1, 2, 3, 4], [1, 2, 3, 4, 5])
+    }, idtype=idtype, device=F.ctx())
+    new_g = transform(g)
+    assert new_g.num_edges(('A', 'r1', 'B')) == 6
+    assert new_g.num_edges(('C', 'r2', 'C')) == 6
+    assert new_g.idtype == g.idtype
+    assert new_g.device == g.device
+    assert new_g.ntypes == g.ntypes
+    assert new_g.canonical_etypes == g.canonical_etypes
+
 if __name__ == '__main__':
     test_partition_with_halo()
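Not part of the patch: a minimal end-to-end sketch of how the new transform modules compose with the existing Compose transform, assuming the PyTorch backend; the pipeline and the toy graph are illustrative only.

import dgl
from dgl import AddEdge, Compose, DropEdge, GCNNorm

# add ~20% random edges, drop each edge with probability 0.1,
# then store symmetric GCN normalization weights in edata['w']
transform = Compose([AddEdge(ratio=0.2), DropEdge(p=0.1), GCNNorm()])
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
g = transform(g)
print(g.edata['w'])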