From 3ba050e6f03d41bd4c49a96b219845af426c1de1 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Fri, 7 Aug 2020 17:04:36 +0800 Subject: [PATCH] remove the last softmax activation --- graphgallery/nn/layers/__init__.py | 2 +- graphgallery/nn/layers/misc.py | 21 +++++++++ graphgallery/nn/models/__init__.py | 12 ++--- .../nn/models/semisupervised/chebynet.py | 17 +++---- .../nn/models/semisupervised/clustergcn.py | 17 ++++--- .../nn/models/semisupervised/densegcn.py | 15 ++++--- .../{ => experimental}/edgeconv.py | 19 ++++---- .../{ => experimental}/gcn_mix.py | 4 +- .../semisupervised/{ => experimental}/gcnf.py | 45 ++++++++++--------- .../{ => experimental}/mediansage.py | 13 +++--- .../{ => experimental}/s_obvat.py | 16 ++++--- .../nn/models/semisupervised/fastgcn.py | 8 ++-- graphgallery/nn/models/semisupervised/gat.py | 15 ++++--- graphgallery/nn/models/semisupervised/gcn.py | 21 ++++----- graphgallery/nn/models/semisupervised/gmnn.py | 27 ++++++----- .../nn/models/semisupervised/graphsage.py | 13 +++--- graphgallery/nn/models/semisupervised/gwnn.py | 15 ++++--- graphgallery/nn/models/semisupervised/lgcn.py | 24 +++++----- .../nn/models/semisupervised/obvat.py | 16 ++++--- .../nn/models/semisupervised/robustgcn.py | 17 +++---- .../nn/models/semisupervised/sbvat.py | 11 +++-- .../semisupervised/semi_supervised_model.py | 43 ++++++++++-------- graphgallery/nn/models/semisupervised/sgc.py | 12 ++--- graphgallery/utils/context_manager.py | 0 graphgallery/utils/data_utils.py | 1 + graphgallery/utils/graph_utils.py | 3 +- 26 files changed, 230 insertions(+), 177 deletions(-) rename graphgallery/nn/models/semisupervised/{ => experimental}/edgeconv.py (92%) rename graphgallery/nn/models/semisupervised/{ => experimental}/gcn_mix.py (97%) rename graphgallery/nn/models/semisupervised/{ => experimental}/gcnf.py (82%) rename graphgallery/nn/models/semisupervised/{ => experimental}/mediansage.py (95%) rename graphgallery/nn/models/semisupervised/{ => experimental}/s_obvat.py (93%) mode change 100644 => 100755 graphgallery/utils/context_manager.py diff --git a/graphgallery/nn/layers/__init__.py b/graphgallery/nn/layers/__init__.py index 117dd39a..c640d113 100755 --- a/graphgallery/nn/layers/__init__.py +++ b/graphgallery/nn/layers/__init__.py @@ -11,4 +11,4 @@ from graphgallery.nn.layers.edgeconv import GraphEdgeConvolution from graphgallery.nn.layers.mediansage import MedianAggregator, MedianGCNAggregator from graphgallery.nn.layers.gcnf import GraphConvFeature -from graphgallery.nn.layers.misc import SparseConversion, Scale, Sample +from graphgallery.nn.layers.misc import SparseConversion, Scale, Sample, Gather diff --git a/graphgallery/nn/layers/misc.py b/graphgallery/nn/layers/misc.py index 11f21721..982b4b72 100755 --- a/graphgallery/nn/layers/misc.py +++ b/graphgallery/nn/layers/misc.py @@ -57,3 +57,24 @@ def get_config(self): def compute_output_shape(self, input_shapes): return tf.TensorShape(input_shapes[0]) + + +class Gather(Layer): + def __init__(self, axis=0, *args, **kwargs): + super().__init__(*args, **kwargs) + self.axis = axis + + def call(self, inputs): + params, indices = inputs + output = tf.gather(params, indices, axis=self.axis) + return output + + def get_config(self): + base_config = super().get_config() + return {**base_config, 'axis': self.axis} + + def compute_output_shape(self, input_shapes): + axis = self.axis + params_shape, indices_shape = input_shapes + output_shape = params_shape[:axis] + indices_shape + params_shape[axis + 1:] + return tf.TensorShape(output_shape) diff --git 
a/graphgallery/nn/models/__init__.py b/graphgallery/nn/models/__init__.py index f08638ac..5313d1b5 100755 --- a/graphgallery/nn/models/__init__.py +++ b/graphgallery/nn/models/__init__.py @@ -4,7 +4,6 @@ # (semi-)supervised model from graphgallery.nn.models.semisupervised.semi_supervised_model import SemiSupervisedModel from graphgallery.nn.models.semisupervised.gcn import GCN -from graphgallery.nn.models.semisupervised.gcn_mix import GCN_MIX from graphgallery.nn.models.semisupervised.sgc import SGC from graphgallery.nn.models.semisupervised.gat import GAT from graphgallery.nn.models.semisupervised.clustergcn import ClusterGCN @@ -17,11 +16,14 @@ from graphgallery.nn.models.semisupervised.lgcn import LGCN from graphgallery.nn.models.semisupervised.obvat import OBVAT from graphgallery.nn.models.semisupervised.sbvat import SBVAT -from graphgallery.nn.models.semisupervised.s_obvat import SimplifiedOBVAT from graphgallery.nn.models.semisupervised.gmnn import GMNN -from graphgallery.nn.models.semisupervised.edgeconv import EdgeGCN -from graphgallery.nn.models.semisupervised.mediansage import MedianSAGE -from graphgallery.nn.models.semisupervised.gcnf import GCNF + +# experimental +from graphgallery.nn.models.semisupervised.experimental.gcnf import GCNF +from graphgallery.nn.models.semisupervised.experimental.gcn_mix import GCN_MIX +from graphgallery.nn.models.semisupervised.experimental.s_obvat import SimplifiedOBVAT +from graphgallery.nn.models.semisupervised.experimental.edgeconv import EdgeGCN +from graphgallery.nn.models.semisupervised.experimental.mediansage import MedianSAGE # unsupervised model diff --git a/graphgallery/nn/models/semisupervised/chebynet.py b/graphgallery/nn/models/semisupervised/chebynet.py index 9aa13fa2..cfcdaaa6 100755 --- a/graphgallery/nn/models/semisupervised/chebynet.py +++ b/graphgallery/nn/models/semisupervised/chebynet.py @@ -1,10 +1,11 @@ import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax +from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy -from graphgallery.nn.layers import ChebyConvolution +from graphgallery.nn.layers import ChebyConvolution, Gather from graphgallery.sequence import FullBatchNodeSequence from graphgallery.nn.models import SemiSupervisedModel from graphgallery.utils.misc import chebyshev_polynomials @@ -37,7 +38,7 @@ class ChebyNet(SemiSupervisedModel): i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -51,7 +52,7 @@ class ChebyNet(SemiSupervisedModel): """ def __init__(self, adj, x, labels, order=2, norm_adj=-0.5, - norm_x='l1', device='CPU:0', seed=None, name=None, **kwargs): + norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -103,11 +104,11 @@ def build(self, hiddens=[16], activations=['relu'], dropouts=[0.5], l2_norms=[5e h = Dropout(rate=dropout)(h) h = ChebyConvolution(self.n_classes, order=self.order, use_bias=use_bias)([h, adj]) - h = tf.gather(h, index) - output = Softmax()(h) + h = Gather()([h, index]) - model = Model(inputs=[x, *adj, index], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model = Model(inputs=[x, *adj, index], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model def train_sequence(self, index): diff --git a/graphgallery/nn/models/semisupervised/clustergcn.py b/graphgallery/nn/models/semisupervised/clustergcn.py index 37a0c8b8..d736fa0e 100755 --- a/graphgallery/nn/models/semisupervised/clustergcn.py +++ b/graphgallery/nn/models/semisupervised/clustergcn.py @@ -2,9 +2,10 @@ import networkx as nx import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax +from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy from graphgallery.nn.layers import GraphConvolution, SparseConversion from graphgallery.nn.models import SemiSupervisedModel @@ -46,7 +47,7 @@ class ClusterGCN(SemiSupervisedModel): i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -60,7 +61,7 @@ class ClusterGCN(SemiSupervisedModel): """ def __init__(self, adj, x, labels, graph=None, n_clusters=None, - norm_adj=-0.5, norm_x='l1', device='CPU:0', seed=None, name=None, **kwargs): + norm_adj=-0.5, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels=labels, device=device, seed=seed, name=name, **kwargs) @@ -118,7 +119,7 @@ def build(self, hiddens=[32], activations=['relu'], dropouts=[0.5], l2_norms=[1e mask = Input(batch_shape=[None], dtype=tf.bool, name='mask') adj = SparseConversion()([edge_index, edge_weight]) - + h = x for hid, activation, dropout, l2_norm in zip(hiddens, activations, dropouts, l2_norms): h = Dropout(rate=dropout)(h) @@ -128,12 +129,10 @@ def build(self, hiddens=[32], activations=['relu'], dropouts=[0.5], l2_norms=[1e h = Dropout(rate=dropout)(h) h = GraphConvolution(self.n_classes, use_bias=use_bias)([h, adj]) h = tf.boolean_mask(h, mask) - output = Softmax()(h) - - model = Model(inputs=[x, edge_index, edge_weight, mask], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), - metrics=['accuracy']) + model = Model(inputs=[x, edge_index, edge_weight, mask], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model diff --git a/graphgallery/nn/models/semisupervised/densegcn.py b/graphgallery/nn/models/semisupervised/densegcn.py index ab5f5bdf..d58da56f 100755 --- a/graphgallery/nn/models/semisupervised/densegcn.py +++ b/graphgallery/nn/models/semisupervised/densegcn.py @@ -5,8 +5,9 @@ from tensorflow.keras.layers import Dropout, Softmax from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy -from graphgallery.nn.layers import DenseGraphConv +from graphgallery.nn.layers import DenseGraphConv, Gather from graphgallery.nn.models import SemiSupervisedModel from graphgallery.sequence import FullBatchNodeSequence from graphgallery.utils.shape_utils import set_equal_in_length @@ -42,7 +43,7 @@ class DenseGCN(SemiSupervisedModel): i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -56,7 +57,7 @@ class DenseGCN(SemiSupervisedModel): """ - def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x='l1', + def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -112,11 +113,11 @@ def build(self, hiddens=[16], activations=['relu'], dropouts=[0.5], l2_norms=[5e h = Dropout(rate=dropout)(h) h = DenseGraphConv(self.n_classes, use_bias=use_bias)([h, adj]) - h = tf.gather(h, index) - output = Softmax()(h) + h = Gather()([h, index]) - model = Model(inputs=[x, adj, index], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model = Model(inputs=[x, adj, index], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model def train_sequence(self, index): diff --git a/graphgallery/nn/models/semisupervised/edgeconv.py b/graphgallery/nn/models/semisupervised/experimental/edgeconv.py similarity index 92% rename from graphgallery/nn/models/semisupervised/edgeconv.py rename to graphgallery/nn/models/semisupervised/experimental/edgeconv.py index 46057f62..1900be8c 100755 --- a/graphgallery/nn/models/semisupervised/edgeconv.py +++ b/graphgallery/nn/models/semisupervised/experimental/edgeconv.py @@ -1,11 +1,12 @@ import numpy as np import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax +from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy -from graphgallery.nn.layers import GraphEdgeConvolution +from graphgallery.nn.layers import GraphEdgeConvolution, Gather from graphgallery.nn.models import SemiSupervisedModel from graphgallery.sequence import FullBatchNodeSequence from graphgallery.utils.shape_utils import set_equal_in_length @@ -44,7 +45,7 @@ class EdgeGCN(SemiSupervisedModel): i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -57,7 +58,7 @@ class EdgeGCN(SemiSupervisedModel): """ - def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x='l1', + def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -109,11 +110,11 @@ def build(self, hiddens=[16], activations=['relu'], dropouts=[0.5], l2_norms=[5e h = Dropout(rate=dropout)(h) h = GraphEdgeConvolution(self.n_classes, use_bias=use_bias)([h, edge_index, edge_weight]) - h = tf.gather(h, index) - output = Softmax()(h) - - model = Model(inputs=[x, edge_index, edge_weight, index], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + h = Gather()([h, index]) + + model = Model(inputs=[x, edge_index, edge_weight, index], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model def train_sequence(self, index): diff --git a/graphgallery/nn/models/semisupervised/gcn_mix.py b/graphgallery/nn/models/semisupervised/experimental/gcn_mix.py similarity index 97% rename from graphgallery/nn/models/semisupervised/gcn_mix.py rename to graphgallery/nn/models/semisupervised/experimental/gcn_mix.py index 44bbd76a..8407ac6e 100755 --- a/graphgallery/nn/models/semisupervised/gcn_mix.py +++ b/graphgallery/nn/models/semisupervised/experimental/gcn_mix.py @@ -36,7 +36,7 @@ class GCN_MIX(SemiSupervisedModel): i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -49,7 +49,7 @@ class GCN_MIX(SemiSupervisedModel): """ - def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x='l1', + def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) diff --git a/graphgallery/nn/models/semisupervised/gcnf.py b/graphgallery/nn/models/semisupervised/experimental/gcnf.py similarity index 82% rename from graphgallery/nn/models/semisupervised/gcnf.py rename to graphgallery/nn/models/semisupervised/experimental/gcnf.py index 8a9182fc..d2cdd0d4 100755 --- a/graphgallery/nn/models/semisupervised/gcnf.py +++ b/graphgallery/nn/models/semisupervised/experimental/gcnf.py @@ -1,10 +1,11 @@ import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax +from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy -from graphgallery.nn.layers import GraphConvFeature +from graphgallery.nn.layers import GraphConvFeature, Gather from graphgallery.nn.models import SemiSupervisedModel from graphgallery.sequence import FullBatchNodeSequence from graphgallery.utils.shape_utils import set_equal_in_length @@ -17,34 +18,34 @@ class GCNF(SemiSupervisedModel): Arguments: ---------- - adj: shape (N, N), Scipy sparse matrix if `is_adj_sparse=True`, + adj: shape (N, N), Scipy sparse matrix if `is_adj_sparse=True`, Numpy array-like (or matrix) if `is_adj_sparse=False`. - The input `symmetric` adjacency matrix, where `N` is the number + The input `symmetric` adjacency matrix, where `N` is the number of nodes in graph. - x: shape (N, F), Scipy sparse matrix if `is_x_sparse=True`, + x: shape (N, F), Scipy sparse matrix if `is_x_sparse=True`, Numpy array-like (or matrix) if `is_x_sparse=False`. The input node feature matrix, where `F` is the dimension of features. labels: Numpy array-like with shape (N,) The ground-truth labels for all nodes in graph. - norm_adj (Float scalar, optional): - The normalize rate for adjacency matrix `adj`. (default: :obj:`-0.5`, - i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) - norm_x (Boolean, optional): - Whether to use row-wise normalization for node feature matrix. - (default :bool: `True`) - device (String, optional): - The device where the model is running on. You can specified `CPU` or `GPU` + norm_adj (Float scalar, optional): + The normalize rate for adjacency matrix `adj`. (default: :obj:`-0.5`, + i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) + norm_x (String, optional): + How to normalize the node feature matrix. See `graphgallery.normalize_x` + (default :obj: `None`) + device (String, optional): + The device where the model is running on. You can specified `CPU` or `GPU` for the model. (default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) - seed (Positive integer, optional): - Used in combination with `tf.random.set_seed` & `np.random.seed` & `random.seed` - to create a reproducible sequence of tensors across multiple calls. + seed (Positive integer, optional): + Used in combination with `tf.random.set_seed` & `np.random.seed` & `random.seed` + to create a reproducible sequence of tensors across multiple calls. 
(default :obj: `None`, i.e., using random seed) - name (String, optional): + name (String, optional): Specified name for the model. (default: :str: `class.__name__`) """ - def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x='l1', + def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -100,11 +101,11 @@ def build(self, hiddens=[16], activations=['relu'], dropouts=[0.5], l2_norms=[5e # of the input data to remain the same # if ensure_shape: # h = tf.ensure_shape(h, [self.n_nodes, self.n_classes]) - h = tf.gather(h, index) - output = Softmax()(h) + h = Gather()([h, index]) - model = Model(inputs=[x, adj, index], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model = Model(inputs=[x, adj, index], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model diff --git a/graphgallery/nn/models/semisupervised/mediansage.py b/graphgallery/nn/models/semisupervised/experimental/mediansage.py similarity index 95% rename from graphgallery/nn/models/semisupervised/mediansage.py rename to graphgallery/nn/models/semisupervised/experimental/mediansage.py index 5a3fcd2d..3b3aacd9 100755 --- a/graphgallery/nn/models/semisupervised/mediansage.py +++ b/graphgallery/nn/models/semisupervised/experimental/mediansage.py @@ -1,9 +1,10 @@ import numpy as np import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax +from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy from graphgallery.nn.layers import MedianAggregator, MedianGCNAggregator from graphgallery.nn.models import SemiSupervisedModel @@ -34,7 +35,7 @@ class MedianSAGE(SemiSupervisedModel): `5` sencond-order neighbors, and the radius for `GraphSAGE` is `2`) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -48,7 +49,7 @@ class MedianSAGE(SemiSupervisedModel): """ - def __init__(self, adj, x, labels, n_samples=[15, 3], norm_x='l1', + def __init__(self, adj, x, labels, n_samples=[15, 3], norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -124,10 +125,10 @@ def build(self, hiddens=[32], activations=['relu'], dropouts=[0.5], l2_norms=[5e h = h[0] if output_normalize: h = tf.nn.l2_normalize(h, axis=1) - output = Softmax()(h) - model = Model(inputs=[x, nodes, *neighbors], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model = Model(inputs=[x, nodes, *neighbors], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model diff --git a/graphgallery/nn/models/semisupervised/s_obvat.py b/graphgallery/nn/models/semisupervised/experimental/s_obvat.py similarity index 93% rename from graphgallery/nn/models/semisupervised/s_obvat.py rename to graphgallery/nn/models/semisupervised/experimental/s_obvat.py index 5811926a..87e9db01 100755 --- a/graphgallery/nn/models/semisupervised/s_obvat.py +++ b/graphgallery/nn/models/semisupervised/experimental/s_obvat.py @@ -1,11 +1,12 @@ import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax +from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers from tensorflow.keras.initializers import TruncatedNormal +from tensorflow.keras.losses import SparseCategoricalCrossentropy -from graphgallery.nn.layers import GraphConvolution +from graphgallery.nn.layers import GraphConvolution, Gather from graphgallery.nn.models import SemiSupervisedModel from graphgallery.sequence import FullBatchNodeSequence from graphgallery.utils.bvat_utils import kl_divergence_with_logit, entropy_y_x, get_normalized_vector @@ -40,7 +41,7 @@ class SimplifiedOBVAT(SemiSupervisedModel): i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -54,7 +55,7 @@ class SimplifiedOBVAT(SemiSupervisedModel): """ - def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x='l1', + def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -109,10 +110,11 @@ def build(self, hiddens=[16], activations=['relu'], dropouts=[0.5], self.dropout_layers = dropout_layers logit = self.propagation(x, adj) - output = tf.gather(logit, index) - output = Softmax()(output) + output = Gather()([logit, index]) + model = Model(inputs=[x, adj, index], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) entropy_loss = entropy_y_x(logit) vat_loss = self.virtual_adversarial_loss(x, adj, logit, epsilon) diff --git a/graphgallery/nn/models/semisupervised/fastgcn.py b/graphgallery/nn/models/semisupervised/fastgcn.py index 6fbb4f29..e68591fe 100755 --- a/graphgallery/nn/models/semisupervised/fastgcn.py +++ b/graphgallery/nn/models/semisupervised/fastgcn.py @@ -3,6 +3,7 @@ from tensorflow.keras.layers import Dense, Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy from graphgallery.nn.layers import GraphConvolution from graphgallery.nn.models import SemiSupervisedModel @@ -103,10 +104,11 @@ def build(self, hiddens=[32], activations=['relu'], dropouts=[0.5], l2_norms=[5e kernel_regularizer=regularizers.l2(l2_norm))(h) h = Dropout(rate=dropout)(h) - output = GraphConvolution(self.n_classes, use_bias=use_bias, activation='softmax')([h, adj]) + h = GraphConvolution(self.n_classes, use_bias=use_bias)([h, adj]) - model = Model(inputs=[x, adj], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model = Model(inputs=[x, adj], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model def predict(self, index): diff --git a/graphgallery/nn/models/semisupervised/gat.py b/graphgallery/nn/models/semisupervised/gat.py index b5eb62df..9031ec66 100755 --- a/graphgallery/nn/models/semisupervised/gat.py +++ b/graphgallery/nn/models/semisupervised/gat.py @@ -4,8 +4,9 @@ from tensorflow.keras.layers import Dropout, Softmax from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy -from graphgallery.nn.layers import GraphAttention +from graphgallery.nn.layers import GraphAttention, Gather from graphgallery.nn.models import SemiSupervisedModel from graphgallery.sequence import FullBatchNodeSequence from graphgallery.utils.shape_utils import set_equal_in_length @@ -35,7 +36,7 @@ class GAT(SemiSupervisedModel): The normalize rate for adjacency matrix `adj`. (default: :obj: `None`) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -49,7 +50,7 @@ class GAT(SemiSupervisedModel): """ - def __init__(self, adj, x, labels, norm_adj=None, norm_x='l1', + def __init__(self, adj, x, labels, norm_adj=None, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -109,11 +110,11 @@ def build(self, hiddens=[8], n_heads=[8], activations=['elu'], dropouts=[0.6], l h = GraphAttention(self.n_classes, use_bias=use_bias, attn_heads=1, attn_heads_reduction='average')([h, adj]) - h = tf.gather(h, index) - output = Softmax()(h) + h = Gather()([h, index]) - model = Model(inputs=[x, adj, index], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model = Model(inputs=[x, adj, index], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model diff --git a/graphgallery/nn/models/semisupervised/gcn.py b/graphgallery/nn/models/semisupervised/gcn.py index 5c9e50b4..90d104d7 100755 --- a/graphgallery/nn/models/semisupervised/gcn.py +++ b/graphgallery/nn/models/semisupervised/gcn.py @@ -1,10 +1,11 @@ import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax +from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy -from graphgallery.nn.layers import GraphConvolution +from graphgallery.nn.layers import GraphConvolution, Gather from graphgallery.nn.models import SemiSupervisedModel from graphgallery.sequence import FullBatchNodeSequence from graphgallery.utils.shape_utils import set_equal_in_length @@ -32,9 +33,9 @@ class GCN(SemiSupervisedModel): norm_adj (Float scalar, optional): The normalize rate for adjacency matrix `adj`. (default: :obj:`-0.5`, i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) - norm_x (Boolean, optional): - Whether to use row-wise normalization for node feature matrix. - (default :bool: `True`) + norm_x (String, optional): + How to normalize the node feature matrix. See `graphgallery.normalize_x` + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -47,7 +48,7 @@ class GCN(SemiSupervisedModel): """ - def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x='l1', + def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -98,11 +99,11 @@ def build(self, hiddens=[16], activations=['relu'], dropouts=[0.5], l2_norms=[5e h = Dropout(rate=dropout)(h) h = GraphConvolution(self.n_classes, use_bias=use_bias)([h, adj]) - h = tf.gather(h, index) - output = Softmax()(h) + h = Gather()([h, index]) - model = Model(inputs=[x, adj, index], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model = Model(inputs=[x, adj, index], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model diff --git a/graphgallery/nn/models/semisupervised/gmnn.py b/graphgallery/nn/models/semisupervised/gmnn.py index 885747c2..e96de83b 100755 --- a/graphgallery/nn/models/semisupervised/gmnn.py +++ b/graphgallery/nn/models/semisupervised/gmnn.py @@ -1,11 +1,13 @@ import numpy as np import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax +from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import RMSprop, Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import CategoricalCrossentropy +from tensorflow.keras.activations import softmax -from graphgallery.nn.layers import GraphConvolution +from graphgallery.nn.layers import GraphConvolution, Gather from graphgallery.sequence import FullBatchNodeSequence from graphgallery.nn.models import SemiSupervisedModel from graphgallery.utils.shape_utils import set_equal_in_length @@ -34,7 +36,7 @@ class GMNN(SemiSupervisedModel): i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -47,7 +49,7 @@ class GMNN(SemiSupervisedModel): """ - def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x='l1', + def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -56,8 +58,8 @@ def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x='l1', self.norm_x = norm_x self.preprocess(adj, x) self.labels_onehot = np.eye(self.n_classes)[labels] - - self.custom_objects = {'GraphConvolution': GraphConvolution} + + self.custom_objects = {'GraphConvolution': GraphConvolution, 'Gather': Gather} def preprocess(self, adj, x): super().preprocess(adj, x) @@ -102,11 +104,11 @@ def build_GCN(x): h = Dropout(rate=dropout)(h) h = GraphConvolution(self.n_classes, use_bias=use_bias)([h, adj]) - h = tf.gather(h, index) - output = Softmax()(h) + h = Gather()([h, index]) - model = Model(inputs=[x, adj, index], outputs=output) - model.compile(loss='categorical_crossentropy', optimizer=RMSprop(lr=lr), metrics=['accuracy']) + model = Model(inputs=[x, adj, index], outputs=h) + model.compile(loss=CategoricalCrossentropy(from_logits=True), + optimizer=RMSprop(lr=lr), metrics=['accuracy']) return model # model_p @@ -155,6 +157,7 @@ def train(self, idx_train, idx_val=None, pre_train_epochs=100, # then train model_q again label_predict = self.model.predict_on_batch(astensors([label_predict, self.adj_norm, index_all])) + label_predict = softmax(label_predict) if tf.is_tensor(label_predict): label_predict = label_predict.numpy() @@ -170,9 +173,11 @@ def train(self, idx_train, idx_val=None, pre_train_epochs=100, monitor=monitor, early_stop_metric=early_stop_metric) histories.append(history) - # update training paras and all paras + + ############# Record paras ########### self.train_paras.update(Bunch(pre_train_epochs=pre_train_epochs)) self.paras.update(Bunch(pre_train_epochs=pre_train_epochs)) + ###################################### return histories diff --git a/graphgallery/nn/models/semisupervised/graphsage.py b/graphgallery/nn/models/semisupervised/graphsage.py index 6e5fd535..0244180b 100755 --- a/graphgallery/nn/models/semisupervised/graphsage.py +++ b/graphgallery/nn/models/semisupervised/graphsage.py @@ -1,9 +1,10 @@ import numpy as np import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax +from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy from graphgallery.nn.layers import MeanAggregator, GCNAggregator from graphgallery.nn.models import SemiSupervisedModel @@ -38,7 +39,7 @@ class GraphSAGE(SemiSupervisedModel): `5` sencond-order neighbors, and the radius for `GraphSAGE` is `2`) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -52,7 +53,7 @@ class GraphSAGE(SemiSupervisedModel): """ - def __init__(self, adj, x, labels, n_samples=[15, 5], norm_x='l1', + def __init__(self, adj, x, labels, n_samples=[15, 5], norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -127,10 +128,10 @@ def build(self, hiddens=[32], activations=['relu'], dropouts=[0.5], l2_norms=[5e h = h[0] if output_normalize: h = tf.nn.l2_normalize(h, axis=1) - output = Softmax()(h) - model = Model(inputs=[x, nodes, *neighbors], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model = Model(inputs=[x, nodes, *neighbors], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model diff --git a/graphgallery/nn/models/semisupervised/gwnn.py b/graphgallery/nn/models/semisupervised/gwnn.py index b06d6876..d713f8ea 100755 --- a/graphgallery/nn/models/semisupervised/gwnn.py +++ b/graphgallery/nn/models/semisupervised/gwnn.py @@ -3,8 +3,9 @@ from tensorflow.keras.layers import Dropout, Softmax from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy -from graphgallery.nn.layers import WaveletConvolution +from graphgallery.nn.layers import WaveletConvolution, Gather from graphgallery.nn.models import SemiSupervisedModel from graphgallery.sequence import FullBatchNodeSequence from graphgallery.utils.wavelet_utils import wavelet_basis @@ -44,7 +45,7 @@ class GWNN(SemiSupervisedModel): Whether to use row-normalize for wavelet matrix. (default :bool: `True`) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -57,7 +58,7 @@ class GWNN(SemiSupervisedModel): """ def __init__(self, adj, x, labels, order=3, wavelet_s=1.2, - threshold=1e-4, wavelet_normalize=True, norm_x='l1', + threshold=1e-4, wavelet_normalize=True, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -111,11 +112,11 @@ def build(self, hiddens=[16], activations=['relu'], dropouts=[0.5], l2_norms=[5e h = Dropout(rate=dropout)(h) h = WaveletConvolution(self.n_classes, use_bias=use_bias)([h, wavelet, inverse_wavelet]) - h = tf.gather(h, index) - output = Softmax()(h) + h = Gather()([h, index]) - model = Model(inputs=[x, wavelet, inverse_wavelet, index], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model = Model(inputs=[x, wavelet, inverse_wavelet, index], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model diff --git a/graphgallery/nn/models/semisupervised/lgcn.py b/graphgallery/nn/models/semisupervised/lgcn.py index e1f8822d..aa13ac64 100755 --- a/graphgallery/nn/models/semisupervised/lgcn.py +++ b/graphgallery/nn/models/semisupervised/lgcn.py @@ -2,9 +2,10 @@ import scipy.sparse as sp import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax, Concatenate, BatchNormalization +from tensorflow.keras.layers import Dropout, Concatenate, BatchNormalization from tensorflow.keras.optimizers import Nadam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy from graphgallery.nn.layers import Top_k_features, LGConvolution, DenseGraphConv from graphgallery.nn.models import SemiSupervisedModel @@ -36,7 +37,7 @@ class LGCN(SemiSupervisedModel): i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. (default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -50,7 +51,7 @@ class LGCN(SemiSupervisedModel): """ - def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x='l1', + def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -68,9 +69,8 @@ def preprocess(self, adj, x): if self.norm_adj: adj = normalize_adj(adj, self.norm_adj) - # TODO: the input adj can be dense matrix -# if sp.isspmatrix(adj): -# adj = adj.toarray() + if sp.isspmatrix(adj): + adj = adj.toarray() if self.norm_x: x = normalize_x(x, norm=self.norm_x) @@ -122,10 +122,10 @@ def build(self, hiddens=[32], n_filters=[8, 8], activations=[None], dropouts=[0. 
kernel_regularizer=regularizers.l2(l2_norms[-1]))([h, adj]) h = tf.boolean_mask(h, mask) - output = Softmax()(h) - model = Model(inputs=[x, adj, mask], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Nadam(lr=lr), metrics=['accuracy']) + model = Model(inputs=[x, adj, mask], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Nadam(lr=lr), metrics=['accuracy']) self.k = k self.model = model @@ -136,7 +136,8 @@ def train_sequence(self, index, batch_size=np.inf): index = get_indice_graph(self.adj_norm, index, batch_size) while index.size < self.k: index = get_indice_graph(self.adj_norm, index) - adj = self.adj_norm[index][:, index].toarray() + + adj = self.adj_norm[index][:, index] x = self.x_norm[index] mask = mask[index] labels = self.labels[index[mask]] @@ -150,9 +151,10 @@ def predict(self, index): index = asintarr(index) mask = sample_mask(index, self.n_nodes) index = get_indice_graph(self.adj_norm, index) + while index.size < self.k: index = get_indice_graph(self.adj_norm, index) - adj = self.adj_norm[index][:, index].toarray() + adj = self.adj_norm[index][:, index] x = self.x_norm[index] mask = mask[index] diff --git a/graphgallery/nn/models/semisupervised/obvat.py b/graphgallery/nn/models/semisupervised/obvat.py index 92d60048..879afd6f 100755 --- a/graphgallery/nn/models/semisupervised/obvat.py +++ b/graphgallery/nn/models/semisupervised/obvat.py @@ -1,11 +1,12 @@ import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax +from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers from tensorflow.keras.initializers import TruncatedNormal +from tensorflow.keras.losses import SparseCategoricalCrossentropy -from graphgallery.nn.layers import GraphConvolution +from graphgallery.nn.layers import GraphConvolution, Gather from graphgallery.nn.models import SemiSupervisedModel from graphgallery.sequence import FullBatchNodeSequence from graphgallery.utils.bvat_utils import kl_divergence_with_logit, entropy_y_x @@ -35,7 +36,7 @@ class OBVAT(SemiSupervisedModel): i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. 
(default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -49,7 +50,7 @@ class OBVAT(SemiSupervisedModel): """ - def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x='l1', + def __init__(self, adj, x, labels, norm_adj=-0.5, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -103,10 +104,11 @@ def build(self, hiddens=[16], activations=['relu'], dropouts=[0.5], l2_norms=[5e self.dropout_layers = dropout_layers logit = self.propagation(x, adj) - output = tf.gather(logit, index) - output = Softmax()(output) + output = Gather()([logit, index]) + model = Model(inputs=[x, adj, index], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.r_vadv = tf.Variable(TruncatedNormal(stddev=0.01)(shape=[self.n_nodes, self.n_features]), name="r_vadv") entropy_loss = entropy_y_x(logit) diff --git a/graphgallery/nn/models/semisupervised/robustgcn.py b/graphgallery/nn/models/semisupervised/robustgcn.py index 7580cda5..e66d98b5 100755 --- a/graphgallery/nn/models/semisupervised/robustgcn.py +++ b/graphgallery/nn/models/semisupervised/robustgcn.py @@ -1,10 +1,11 @@ import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax +from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy -from graphgallery.nn.layers import GaussionConvolution_F, GaussionConvolution_D, Sample +from graphgallery.nn.layers import GaussionConvolution_F, GaussionConvolution_D, Sample, Gather from graphgallery.nn.models import SemiSupervisedModel from graphgallery.sequence import FullBatchNodeSequence from graphgallery.utils.shape_utils import set_equal_in_length, repeat @@ -35,7 +36,7 @@ class RobustGCN(SemiSupervisedModel): and `-1.0`, respectively) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. (default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -48,7 +49,7 @@ class RobustGCN(SemiSupervisedModel): """ - def __init__(self, adj, x, labels, norm_adj=[-0.5, -1], norm_x='l1', + def __init__(self, adj, x, labels, norm_adj=[-0.5, -1], norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -108,11 +109,11 @@ def build(self, hiddens=[64], activations=['relu'], use_bias=False, dropouts=[0. 
var = Dropout(rate=dropout)(var) mean, var = GaussionConvolution_D(self.n_classes, gamma=gamma, use_bias=use_bias)([mean, var, *adj]) h = Sample(seed=self.seed)(mean, var) - h = tf.gather(h, index) - output = Softmax()(h) + h = Gather()([h, index]) - model = Model(inputs=[x, *adj, index], outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model = Model(inputs=[x, *adj, index], outputs=h) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model diff --git a/graphgallery/nn/models/semisupervised/sbvat.py b/graphgallery/nn/models/semisupervised/sbvat.py index 075e94b3..a76aa573 100755 --- a/graphgallery/nn/models/semisupervised/sbvat.py +++ b/graphgallery/nn/models/semisupervised/sbvat.py @@ -2,14 +2,14 @@ import tensorflow as tf from tensorflow.keras import Model, Input -from tensorflow.keras.layers import Dropout, Softmax +from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers from tensorflow.keras.losses import sparse_categorical_crossentropy from tensorflow.keras.activations import softmax from tensorflow.keras.metrics import SparseCategoricalAccuracy -from graphgallery.nn.layers import GraphConvolution +from graphgallery.nn.layers import GraphConvolution, Gather from graphgallery.nn.models import SemiSupervisedModel from graphgallery.sequence import NodeSampleSequence from graphgallery.utils.sample_utils import find_4o_nbrs @@ -43,7 +43,7 @@ class SBVAT(SemiSupervisedModel): i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. (default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) @@ -57,7 +57,7 @@ class SBVAT(SemiSupervisedModel): """ def __init__(self, adj, x, labels, n_samples=100, - norm_adj=-0.5, norm_x='l1', device='CPU:0', seed=None, name=None, **kwargs): + norm_adj=-0.5, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -114,8 +114,7 @@ def build(self, hiddens=[16], activations=['relu'], dropouts=[0.5], self.dropout_layers = dropout_layers logit = self.propagation(x, adj) - output = tf.gather(logit, index) - output = Softmax()(output) + output = Gather()([logit, index]) model = Model(inputs=[x, adj, index], outputs=output) self.model = model diff --git a/graphgallery/nn/models/semisupervised/semi_supervised_model.py b/graphgallery/nn/models/semisupervised/semi_supervised_model.py index 269ebd02..54f3ff06 100755 --- a/graphgallery/nn/models/semisupervised/semi_supervised_model.py +++ b/graphgallery/nn/models/semisupervised/semi_supervised_model.py @@ -12,7 +12,7 @@ from graphgallery.nn.models import BaseModel from graphgallery.utils.history import History from graphgallery.utils.tqdm import tqdm -from graphgallery import asintarr, astensors, Bunch +from graphgallery import asintarr, Bunch # Ignora warnings: # UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory. 
@@ -20,8 +20,8 @@ warnings.filterwarnings( 'ignore', '.*Converting sparse IndexedSlices to a dense Tensor of unknown shape.*') - - + + class SemiSupervisedModel(BaseModel): """ Base model for semi-supervised learning. @@ -172,8 +172,9 @@ def train(self, idx_train, idx_val=None, if isinstance(idx_train, Sequence): train_data = idx_train else: + idx_train = asintarr(idx_train) train_data = self.train_sequence(idx_train) - self.idx_train = asintarr(idx_train) + self.idx_train = idx_train validation = idx_val @@ -181,8 +182,9 @@ def train(self, idx_train, idx_val=None, if isinstance(idx_val, Sequence): val_data = idx_val else: + idx_val = asintarr(idx_val) val_data = self.test_sequence(idx_val) - self.idx_val = asintarr(idx_val) + self.idx_val = idx_val history = History(monitor_metric=monitor, early_stop_metric=early_stop_metric) @@ -324,8 +326,9 @@ def train_v2(self, idx_train, idx_val=None, if isinstance(idx_train, Sequence): train_data = idx_train else: + idx_train = asintarr(idx_train) train_data = self.train_sequence(idx_train) - self.idx_train = asintarr(idx_train) + self.idx_train = idx_train validation = idx_val @@ -333,8 +336,9 @@ def train_v2(self, idx_train, idx_val=None, if isinstance(idx_val, Sequence): val_data = idx_val else: + idx_val = asintarr(idx_val) val_data = self.test_sequence(idx_val) - self.idx_val = asintarr(idx_val) + self.idx_val = idx_val model = self.model if not isinstance(callbacks, callbacks_module.CallbackList): @@ -447,8 +451,9 @@ def test(self, index, **kwargs): if isinstance(index, Sequence): test_data = index else: + index = asintarr(index) test_data = self.test_sequence(index) - self.idx_test = asintarr(index) + self.idx_test = index if self.do_before_test: self.do_before_test(**kwargs) @@ -591,28 +596,30 @@ def get_weights(self): def trainable_variables(self): """Return the trainable weights of model, type `tf.Tensor`.""" return self.model.trainable_variables - + def reset_weights(self): """reset the model to the first time. - """ model = self.model - assert self.backup + if self.backup is None: + raise RuntimeError("You must store the `backup` before `reset_weights`." 
+ " `backup` will be automatically stored when the model is built.") for w, wb in zip(model.weights, self.backup): w.assign(wb) - + def reset_optimizer(self): - + model = self.model if hasattr(model, 'optimizer'): for var in model.optimizer.variables(): - var.assign(tf.zeros_like(var)) - + var.assign(tf.zeros_like(var)) + def reset_lr(self, value): model = self.model - assert hasattr(model, 'optimizer') - model.optimizer.learning_rate.assign(value) - + if not hasattr(model, 'optimizer'): + raise RuntimeError("The model has no attribute `optimizer`!") + model.optimizer.learning_rate.assign(value) + @property def close(self): """Close the session of model and set `built` to False.""" diff --git a/graphgallery/nn/models/semisupervised/sgc.py b/graphgallery/nn/models/semisupervised/sgc.py index 5e12d729..63e155f7 100755 --- a/graphgallery/nn/models/semisupervised/sgc.py +++ b/graphgallery/nn/models/semisupervised/sgc.py @@ -3,6 +3,7 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.optimizers import Adam from tensorflow.keras import regularizers +from tensorflow.keras.losses import SparseCategoricalCrossentropy from graphgallery.nn.layers import SGConvolution from graphgallery.nn.models import SemiSupervisedModel @@ -35,7 +36,7 @@ class SGC(SemiSupervisedModel): i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}) norm_x (String, optional): How to normalize the node feature matrix. See `graphgallery.normalize_x` - (default :str: `l1`) + (default :obj: `None`) device (String, optional): The device where the model is running on. You can specified `CPU` or `GPU` for the model. (default: :str: `CPU:0`, i.e., running on the 0-th `CPU`) seed (Positive integer, optional): Used in combination with `tf.random.set_seed` & `np.random.seed` & `random.seed` to create a reproducible sequence of tensors across multiple calls. (default :obj: `None`, i.e., using random seed) name (String, optional): Specified name for the model. (default: :str: `class.__name__`) """ def __init__(self, adj, x, labels, order=2, - norm_adj=-0.5, norm_x='l1', + norm_adj=-0.5, norm_x=None, device='CPU:0', seed=None, name=None, **kwargs): super().__init__(adj, x, labels, device=device, seed=seed, name=name, **kwargs) @@ -74,7 +75,7 @@ def preprocess(self, adj, x): x = SGConvolution(order=self.order)([x, adj]) self.x_norm, self.adj_norm = x, adj - def build(self, lr=0.2, l2_norms=5e-5, use_bias=True): + def build(self, lr=0.2, l2_norms=[5e-5], use_bias=True): ############# Record paras ########### l2_norms = repeat(l2_norms, 1) local_paras = locals() @@ -89,10 +90,11 @@ def build(self, lr=0.2, l2_norms=5e-5, use_bias=True): x = Input(batch_shape=[None, self.n_features], dtype=self.floatx, name='features') - output = Dense(self.n_classes, activation='softmax', use_bias=use_bias, kernel_regularizer=regularizers.l2(l2_norms[0]))(x) + output = Dense(self.n_classes, activation=None, use_bias=use_bias, kernel_regularizer=regularizers.l2(l2_norms[0]))(x) model = Model(inputs=x, outputs=output) - model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy']) + model.compile(loss=SparseCategoricalCrossentropy(from_logits=True), + optimizer=Adam(lr=lr), metrics=['accuracy']) self.model = model diff --git a/graphgallery/utils/context_manager.py b/graphgallery/utils/context_manager.py old mode 100644 new mode 100755 diff --git a/graphgallery/utils/data_utils.py b/graphgallery/utils/data_utils.py index 8cc69047..e47cb95b 100755 --- a/graphgallery/utils/data_utils.py +++ b/graphgallery/utils/data_utils.py @@ -93,6 +93,7 @@ def normalize(adj, alpha): adj = adj.tocsr(copy=False) else: adj = sp.diags(d_inv_sqrt) @ adj @ sp.diags(d_inv_sqrt) + adj = adj.A return adj.astype(config.floatx(), copy=False) diff --git a/graphgallery/utils/graph_utils.py 
b/graphgallery/utils/graph_utils.py index 508360c3..d444af30 100755 --- a/graphgallery/utils/graph_utils.py +++ b/graphgallery/utils/graph_utils.py @@ -86,7 +86,7 @@ def sample_neighbors(adj, nodes, n_neighbors): def get_indice_graph(adj, indices, size=np.inf, dropout=0.): if dropout > 0.: indices = np.random.choice(indices, int(indices.size*(1-dropout)), False) - neighbors = adj[indices].sum(axis=0).nonzero()[1] + neighbors = adj[indices].sum(axis=0).nonzero()[0] if neighbors.size > size - indices.size: neighbors = np.random.choice(list(neighbors), size-len(indices), False) indices = np.union1d(indices, neighbors) @@ -115,5 +115,4 @@ def largest_connected_components(adj, n_components=1): nodes_to_keep = [ idx for (idx, component) in enumerate(component_indices) if component in components_to_keep ] - # print("Selecting {0} largest connected components".format(n_components)) return nodes_to_keep \ No newline at end of file
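
The change repeated across all of these models is a single pattern: drop the trailing Softmax() layer, route node selection through the new serializable Gather layer instead of a bare tf.gather call, and compile with SparseCategoricalCrossentropy(from_logits=True) so the softmax and the cross-entropy are fused inside the loss, which is numerically more stable than a softmax layer followed by the string loss 'sparse_categorical_crossentropy'. Below is a minimal, self-contained sketch of that pattern; the Dense layers standing in for graph convolutions and the 32-feature / 7-class sizes are illustrative assumptions, not GraphGallery's actual architecture.

    import tensorflow as tf
    from tensorflow.keras import Model, Input
    from tensorflow.keras.layers import Dense, Dropout, Layer
    from tensorflow.keras.losses import SparseCategoricalCrossentropy
    from tensorflow.keras.optimizers import Adam

    class Gather(Layer):
        # Same idea as the Gather layer added to graphgallery/nn/layers/misc.py:
        # wrapping tf.gather in a Layer keeps node selection inside the
        # functional graph and makes it serializable via custom_objects.
        def __init__(self, axis=0, **kwargs):
            super().__init__(**kwargs)
            self.axis = axis

        def call(self, inputs):
            params, indices = inputs
            return tf.gather(params, indices, axis=self.axis)

    # Toy stand-in for a GraphGallery model; as in the library, the two
    # inputs are fed full-batch via a Sequence, so the features and the
    # index tensor need not share a batch size.
    x = Input(batch_shape=[None, 32], dtype='float32', name='features')
    index = Input(batch_shape=[None], dtype='int64', name='index')

    h = Dropout(rate=0.5)(x)
    h = Dense(16, activation='relu')(h)
    h = Dense(7)(h)             # raw logits: the final Softmax() is gone
    h = Gather()([h, index])    # replaces the former tf.gather(h, index)

    model = Model(inputs=[x, index], outputs=h)
    model.compile(loss=SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=Adam(lr=0.01), metrics=['accuracy'])

Because the compiled models now emit logits rather than probabilities, callers that need probabilities must apply the softmax themselves; that is why GMNN.train gains the label_predict = softmax(label_predict) line after predict_on_batch, and why FastGCN and SGC lose their 'softmax' activations in favor of the same from_logits loss.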