The sslgraph package is a collection of benchmark datasets, data interfaces, evaluation tasks, and state-of-the-art algorithms for graph self-supervised learning. We aim to provide a unified and highly customizable framework for implementing graph self-supervised learning methods, standardized datasets, and unified performance evaluation for academic researchers interested in graph self-supervised learning. We cover the following three tasks:
- Unsupervised graph-level representation learning: learn graph-level representations on an unlabeled graph dataset, evaluated by 10-fold linear classification with SVC or logistic regression;
- Semi-supervised graph classification (and any other two-stage training task, including pre-training and finetuning on different datasets): pretrain graph encoders on a large number of unlabeled graphs, then finetune and evaluate the encoders on a smaller number of labeled graphs;
- Unsupervised node-level representation learning: learn node-level representations on an unlabeled graph dataset, evaluated by node-level linear classification with logistic regression.
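To make the graph-level evaluation protocol more concrete, below is a minimal pure-Python sketch of 10-fold cross-validated evaluation on frozen embeddings. It is not the sslgraph evaluator: the split is plain (non-stratified) k-fold, and the helper names k_fold_indices, cross_val_accuracy, and nearest_centroid_score are hypothetical. The nearest-centroid classifier only stands in for the SVC or logistic regression that the evaluator actually fits.

```python
import random
from statistics import mean
from typing import Callable, List, Sequence, Tuple


def k_fold_indices(n: int, k: int = 10, seed: int = 0) -> List[Tuple[List[int], List[int]]]:
    """Shuffle range(n) once and split it into k (train, test) index pairs."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]  # k disjoint folds that together cover every index
    splits = []
    for i, test in enumerate(folds):
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        splits.append((train, test))
    return splits


def nearest_centroid_score(train_x, train_y, test_x, test_y) -> float:
    """Hypothetical stand-in classifier: predict the label whose mean embedding is closest."""
    centroids = {}
    for label in set(train_y):
        members = [x for x, y in zip(train_x, train_y) if y == label]
        centroids[label] = [sum(col) / len(members) for col in zip(*members)]

    def predict(x):
        return min(centroids, key=lambda c: sum((a - b) ** 2 for a, b in zip(x, centroids[c])))

    return sum(predict(x) == y for x, y in zip(test_x, test_y)) / len(test_y)


def cross_val_accuracy(embeddings: Sequence[Sequence[float]], labels: Sequence[int],
                       fit_and_score: Callable = nearest_centroid_score,
                       k: int = 10, seed: int = 0) -> float:
    """Mean test accuracy over k folds; the embeddings are fixed, only the classifier is refit."""
    scores = []
    for train, test in k_fold_indices(len(labels), k, seed):
        scores.append(fit_and_score([embeddings[i] for i in train], [labels[i] for i in train],
                                    [embeddings[i] for i in test], [labels[i] for i in test]))
    return mean(scores)


# Usage: two well-separated classes of 1-D "graph embeddings".
emb = [[float(i)] for i in range(10)] + [[100.0 + i] for i in range(10)]
lab = [0] * 10 + [1] * 10
print(cross_val_accuracy(emb, lab))  # -> 1.0
```

The key design point of the protocol is that the encoder is trained once without labels and then frozen; only the lightweight classifier is refit on each training fold.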
The sslgraph package implements a unified and highly customizable framework for contrastive learning methods as a parent class. Specific contrastive methods can be implemented within the framework by specifying their encoder, view generation functions, and contrastive objective.
The current version includes the following components, and we will keep updating them as new methods come out.
- Encoders: GCN, GIN, and ResGCN (semi-supervised benchmark).
- View functions: sample-based (unified node-induced sampling, random walk sampling), feature transformation (node attribute masking), structure transformation (edge perturbation, graph diffusion), and their combinations or random choice.
- Objectives: InfoNCE (NT-Xent) and the Jensen-Shannon Estimator (JSE), supporting graph-level, node-level, or combined contrast with any number of views.
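As an illustration of how view functions and their combinations fit together, here is a minimal sketch that assumes a toy graph format of (node feature lists, edge list) instead of torch_geometric Data objects. The helpers node_attr_mask, edge_drop, compose, and random_choice are hypothetical, not sslgraph's API, and edge_drop implements only the edge-removal half of edge perturbation.

```python
import random
from typing import Callable, List, Sequence, Tuple

# Toy graph format (an assumption, not torch_geometric's Data): (node features, edge list).
Graph = Tuple[List[List[float]], List[Tuple[int, int]]]
ViewFn = Callable[[Graph], Graph]


def node_attr_mask(ratio: float, seed: int = 0) -> ViewFn:
    """Feature transformation: zero out the features of a random `ratio` of nodes."""
    rng = random.Random(seed)

    def view(g: Graph) -> Graph:
        x, edges = g
        masked = set(rng.sample(range(len(x)), int(len(x) * ratio)))
        new_x = [[0.0] * len(f) if i in masked else list(f) for i, f in enumerate(x)]
        return new_x, list(edges)  # copies, so the input graph is never mutated

    return view


def edge_drop(ratio: float, seed: int = 0) -> ViewFn:
    """Structure transformation: remove a random `ratio` of edges (drop-only perturbation)."""
    rng = random.Random(seed)

    def view(g: Graph) -> Graph:
        x, edges = g
        n_keep = len(edges) - int(len(edges) * ratio)
        keep = sorted(rng.sample(range(len(edges)), n_keep))  # preserve original edge order
        return [list(f) for f in x], [edges[i] for i in keep]

    return view


def compose(*fns: ViewFn) -> ViewFn:
    """Combination: apply several view functions in sequence."""
    def view(g: Graph) -> Graph:
        for fn in fns:
            g = fn(g)
        return g
    return view


def random_choice(fns: Sequence[ViewFn], seed: int = 0) -> ViewFn:
    """Random choice: apply one view function, picked uniformly at random on each call."""
    rng = random.Random(seed)
    return lambda g: rng.choice(fns)(g)


# Usage: a 4-cycle with 2-D node features and two augmented views of it.
g: Graph = ([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], [(0, 1), (1, 2), (2, 3), (3, 0)])
views_fn = [node_attr_mask(0.5), compose(node_attr_mask(0.5), edge_drop(0.5))]
views = [fn(g) for fn in views_fn]
```

Each entry of views_fn plays the role of one view generator: it takes a graph and returns a transformed copy, which is the contract the views_fn argument describes.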
Based on the framework and components, four state-of-the-art graph contrastive learning algorithms are implemented, with detailed examples for running evaluations with them. The implemented algorithms are:
- InfoGraph (graph-level): InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization
- GRACE (node-level): Deep Graph Contrastive Representation Learning
- MVGRL (graph & node): Contrastive Multi-View Representation Learning on Graphs
- GraphCL (graph-level): Graph Contrastive Learning with Augmentations
Although only the contrastive learning framework and methods are implemented, the evaluation tools are also compatible with predictive self-supervised learning methods, such as graph auto-encoders.
We have provided examples for running the four pre-implemented algorithms with the data interfaces and evaluation tasks. Please refer to the Jupyter notebooks for instructions.
Below are instructions to implement your own contrastive learning algorithms or perform evaluation with your own datasets.
All Dataset objects from torch_geometric are supported, such as TUDataset and MoleculeNet (see the full list of datasets in the torch_geometric documentation). For new datasets that are not included, please refer to the torch_geometric instructions for creating your own datasets. Each evaluation task requires different dataset inputs:
- Unsupervised graph-level representation learning (sslgraph.utils.eval_graph.EvalUnsupevised) requires a single (unlabeled) dataset.
- Semi-supervised graph classification and transfer learning (sslgraph.utils.eval_graph.EvalSemisupevised) requires two datasets, for model pretraining and finetuning respectively. The finetuning dataset must include graph labels.
- Unsupervised node-level representation learning (sslgraph.utils.eval_node.EvalUnsupevised) requires one dataset that includes all nodes in one graph, together with the training mask and test mask that indicate which nodes are used for training and testing.
Examples can be found in the Jupyter notebooks.
You can customize a contrastive model using the base class sslgraph.contrastive.model.Contrastive. You will need to specify the following arguments:
Contrastive(objective, views_fn, graph_level=True, node_level=False, z_dim=None, z_n_dim=None,
proj=None, proj_n=None, neg_by_crpt=False, tau=0.5, device=None, choice_model='last',
model_path='models')
- objective: String. Either 'NCE' or 'JSE'. The objective will be automatically adjusted for different representation levels (node, graph, or both).
- views_fn: List of functions. Each function corresponds to one view generation that takes graphs as input and outputs transformed graphs.
- graph_level, node_level: Boolean. At least one should be True. Whether to perform contrast among graph-level representations (GraphCL), node-level representations (GRACE), or both (DGI, InfoGraph, MVGRL). Default: graph_level=True, node_level=False.
- z_dim, z_n_dim: Integer or None. The dimensions of graph-level and node-level representations. Required if graph_level or node_level is set to True, respectively. When jumping knowledge is applied in your graph encoder, z_dim = z_n_dim * n_layers.
- proj, proj_n: String, callable model, or None. The projection heads for graph-level and node-level representations. If a string, should be either 'linear' or 'MLP'.
- neg_by_crpt: Boolean. Only required when using the JSE objective. If True, the model will generate corrupted graphs as negative samples. If False, the model will consider any pair of different samples in a batch as a negative pair.
- tau: Float in (0, 1]. The temperature, only required when using the NCE objective.
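The two objectives can be sketched in plain Python for two views of a batch, where z1[i] and z2[i] form the positive pair. This is a simplified illustration, not sslgraph's code: info_nce uses cosine similarity with only cross-view negatives (SimCLR-style NT-Xent additionally uses same-view negatives), and jse uses a dot-product critic with in-batch negatives, i.e. the neg_by_crpt=False case, in the InfoGraph-style Jensen-Shannon form. Vectors are assumed to be nonzero, and the function names are hypothetical.

```python
import math
from typing import List, Sequence

Vec = Sequence[float]


def cosine(a: Vec, b: Vec) -> float:
    """Cosine similarity; both vectors are assumed to be nonzero."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def info_nce(z1: List[Vec], z2: List[Vec], tau: float = 0.5) -> float:
    """InfoNCE: z2[i] is the positive for z1[i]; every z2[j] with j != i is a negative."""
    total = 0.0
    for i, a in enumerate(z1):
        logits = [cosine(a, b) / tau for b in z2]
        m = max(logits)  # log-sum-exp shift for numerical stability
        log_denom = m + math.log(sum(math.exp(v - m) for v in logits))
        total += log_denom - logits[i]  # -log softmax probability of the positive pair
    return total / len(z1)


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def jse(z1: List[Vec], z2: List[Vec]) -> float:
    """Jensen-Shannon loss with a dot-product critic and in-batch negatives."""
    log2 = math.log(2.0)
    pos, neg = [], []
    for i, a in enumerate(z1):
        for j, b in enumerate(z2):
            score = sum(x * y for x, y in zip(a, b))
            (pos if i == j else neg).append(score)
    e_pos = sum(log2 - softplus(-s) for s in pos) / len(pos)
    e_neg = sum(softplus(-s) + s - log2 for s in neg) / len(neg)
    return e_neg - e_pos  # minimizing this maximizes the JS mutual-information estimate


# Usage: matching views give a lower loss than mismatched ones under both objectives.
z = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
shifted = [[0.0, 1.0], [-1.0, 0.0], [1.0, 0.0]]
print(info_nce(z, z), info_nce(z, shifted))
print(jse(z, z), jse(z, shifted))
```

A smaller tau sharpens the softmax, so well-aligned views are penalized less while hard negatives weigh more. The JSE loss can be negative because it is the negated mutual-information estimate rather than a probability-based loss.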
Methods:
train(encoder, data_loader, optimizer, epochs, per_epoch_out=False)
- encoder: PyTorch nn.Module object. Callable with input graphs; returns graph-level representations, node-level representations, or both.
- data_loader: PyTorch DataLoader object.
- optimizer: PyTorch optimizer. Should be initialized with the parameters in encoder. Example: optimizer=Adam(encoder.parameters(), lr=0.001)
- epochs: Integer. Number of pretraining epochs.
- per_epoch_out: Boolean. If True, yield the encoder after every epoch. Otherwise, only yield the final encoder after the last epoch.
Return: A generator that yields tuples of (trained encoder, trained projection head) at each epoch or after the last epoch. When only one of graph_level and node_level is True, the trained projection head only contains the corresponding head. When both are True, the trained projection head is a tuple of (graph proj head, node proj head).
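Because train() returns a generator, no training happens until it is iterated. The toy stand-in below, toy_train, is hypothetical: strings replace the real encoder and projection head, and no data loader or optimizer is involved. It mimics only the documented yield behavior of per_epoch_out, to show the two ways of consuming the result.

```python
from typing import Iterator, List, Optional, Tuple


def toy_train(epochs: int, per_epoch_out: bool = False,
              log: Optional[List[int]] = None) -> Iterator[Tuple[str, str]]:
    """Hypothetical stand-in for Contrastive.train(): mimics only its yield pattern.

    Strings replace the real encoder and projection head; `log` records finished epochs.
    """
    encoder, head = "encoder@0", "head@0"
    for epoch in range(1, epochs + 1):
        # A real implementation would loop over data_loader and step the optimizer here.
        encoder, head = f"encoder@{epoch}", f"head@{epoch}"
        if log is not None:
            log.append(epoch)
        if per_epoch_out:
            yield encoder, head
    if not per_epoch_out:
        yield encoder, head


# Pattern 1 (per_epoch_out=True): inspect or evaluate the encoder after every epoch.
for enc, head in toy_train(3, per_epoch_out=True):
    print(enc, head)

# Pattern 2 (default): a single tuple after the last epoch; next() runs all epochs first.
final_encoder, final_head = next(toy_train(3))
```

With per_epoch_out=True, each yielded encoder can be passed to an evaluator inside the loop, which is convenient for choosing the best pretraining epoch.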
You may also define a subclass of Contrastive and override its methods (such as train()) as needed, so that the customized model can still be used with the evaluation tools. Follow the examples implementing GRACE, InfoGraph, GraphCL, and MVGRL.
If you find our library useful, please consider citing our work below.
@article{xie2021self,
title={Self-Supervised Learning of Graph Neural Networks: A Unified Review},
author={Xie, Yaochen and Xu, Zhao and Wang, Zhengyang and Ji, Shuiwang},
journal={arXiv preprint arXiv:2102.10757},
year={2021}
}
If you have any questions, please submit an issue or contact us by email.