GraphGallery 0.1.5

@EdisonLeeeee released this 28 Jul 13:39
· 702 commits to master since this release

Changes

  • Add the model.show() method to display the model's parameters.
  • Improve the model.build method so that it is more flexible to use.
  • Fix the ClusterGCN retracing bug in tf.function for TensorFlow versions >= 2.2.0.

Example of GCN model

from graphgallery.nn.models import GCN
# adj is a SciPy sparse matrix, x is a NumPy array
model = GCN(adj, x, labels, device='GPU', seed=123)
# build your GCN model (pass arguments to customize the hyper-parameters)
model.build()
# train your model; here idx_train and idx_val are NumPy arrays
his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# test your model
loss, accuracy = model.test(idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')

On the Cora dataset:

loss 1.02, acc 95.00%, val_loss 1.41, val_acc 77.40%: 100%|██████████| 100/100 [00:02<00:00, 37.07it/s]
Test loss 1.4123, Test accuracy 81.20%

Build your model

You can use the following statements to build your model:

# one hidden layer with 32 hidden units and the ReLU activation function
>>> model.build(hiddens=32, activations='relu')

# two hidden layers with 32 and 64 hidden units, both using the ReLU activation function
>>> model.build(hiddens=[32, 64], activations='relu')

# two hidden layers with 32 and 64 hidden units, using the ReLU and ELU activation functions respectively
>>> model.build(hiddens=[32, 64], activations=['relu', 'elu'])

# other parameters such as `dropouts` and `l2_norms` (if the model has them) are specified in the same way.
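
The scalar-or-list convention above can be sketched in plain Python. This is only an illustration of the idea, not GraphGallery's internal code: the helper names `per_layer` and `normalize_build_args` are hypothetical and were made up for this example.

```python
def per_layer(value, n_layers):
    """Expand a scalar or list hyper-parameter to one entry per hidden layer.

    Hypothetical helper, not part of GraphGallery.
    """
    if isinstance(value, (list, tuple)):
        if len(value) != n_layers:
            raise ValueError(
                f"expected {n_layers} values, got {len(value)}"
            )
        return list(value)
    # A single value is shared by every hidden layer.
    return [value] * n_layers


def normalize_build_args(hiddens, activations):
    """Return (hiddens, activations) as two lists of equal length."""
    if not isinstance(hiddens, (list, tuple)):
        hiddens = [hiddens]
    hiddens = list(hiddens)
    return hiddens, per_layer(activations, len(hiddens))


print(normalize_build_args(32, 'relu'))
print(normalize_build_args([32, 64], 'relu'))
print(normalize_build_args([32, 64], ['relu', 'elu']))
```

Under this sketch, a single activation name is repeated for every hidden layer, while a list must match the number of hidden layers exactly.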

Train or test your model

See the model.train and model.test methods for more details.

Hyper-parameters

You can simply call model.show() to display all of your hyper-parameters.
Alternatively, call model.show('model') or model.show('train') to display only the model parameters or only the training parameters.
NOTE: you need to install texttable first.
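
For readers unfamiliar with texttable, the following is a minimal sketch of how a dictionary of hyper-parameters might be rendered as a table. It is not the implementation of model.show(): the `render_params` function and the parameter values are assumptions for illustration, and the code falls back to plain text when texttable is not installed.

```python
try:
    from texttable import Texttable
except ImportError:  # texttable is optional in this sketch
    Texttable = None


def render_params(params):
    """Render a dict of hyper-parameters as a two-column table (hypothetical helper)."""
    rows = [["Parameter", "Value"]]
    rows += [[name, str(value)] for name, value in params.items()]
    if Texttable is None:
        # Plain-text fallback: one "name: value" pair per line.
        return "\n".join(f"{name}: {value}" for name, value in rows[1:])
    table = Texttable()
    # Treat both columns as text so values are not reformatted as numbers.
    table.set_cols_dtype(["t", "t"])
    table.add_rows(rows)  # the first row is used as the header
    return table.draw()


# Illustrative values only; the real defaults depend on the model.
example_params = {"hiddens": [32, 64], "activations": ["relu", "elu"], "epochs": 100}
print(render_params(example_params))
```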

Visualization

  • Accuracy
import matplotlib.pyplot as plt
plt.plot(his.history['acc'])
plt.plot(his.history['val_acc'])
plt.legend(['Accuracy', 'Val Accuracy'])
plt.xlabel('Epochs')

  • Loss
import matplotlib.pyplot as plt
plt.plot(his.history['loss'])
plt.plot(his.history['val_loss'])
plt.legend(['Loss', 'Val Loss'])
plt.xlabel('Epochs')

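
Besides plotting, the same history can be inspected numerically. The sketch below assumes that his.history is a dictionary mapping metric names ('acc', 'val_acc', 'loss', 'val_loss') to per-epoch lists, as the plotting code above suggests. The `best_epoch` helper and the sample numbers are made up for illustration.

```python
def best_epoch(history, metric='val_acc', mode='max'):
    """Return (1-based epoch, value) where `metric` is best (hypothetical helper)."""
    values = history[metric]
    if not values:
        raise ValueError(f"no values recorded for {metric!r}")
    pick = max if mode == 'max' else min
    # Choose the index whose value is largest (or smallest); ties keep the earliest epoch.
    index = pick(range(len(values)), key=values.__getitem__)
    return index + 1, values[index]


# Made-up history with the same keys as his.history.
history = {
    'acc': [0.60, 0.85, 0.95],
    'val_acc': [0.55, 0.78, 0.77],
    'loss': [1.90, 1.30, 1.02],
    'val_loss': [1.80, 1.45, 1.41],
}

epoch, value = best_epoch(history, 'val_acc', 'max')
print(f'Best val_acc {value:.2%} at epoch {epoch}')
epoch, value = best_epoch(history, 'val_loss', 'min')
print(f'Lowest val_loss {value:.4f} at epoch {epoch}')
```

With a real run, pass his.history in place of the made-up dictionary.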