This is the official code for *Discovering the Representation Bottleneck of Graph Neural Networks from Multi-order Interactions* ([arXiv](https://arxiv.org/abs/2205.07266)).
Unlike social networks or knowledge graphs, molecular graphs in 3D Euclidean space have no explicitly defined edges, so researchers usually employ KNN-graphs or fully-connected graphs to construct the connectivity between nodes or entities (e.g., atoms or residues). Our work reveals that these two standard graph construction methods can introduce an improper inductive bias, which prevents GNNs from learning interactions of the optimal order [1] (i.e., complexity) and therefore reduces their performance. To overcome this limitation, we design a new graph rewiring approach that dynamically adjusts the receptive field of each node or entity.
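For concreteness, the two standard constructions can be sketched with `torch_cluster` (installed alongside PyG below); the tensor shapes and the value of `k` are illustrative, not the settings used in our experiments.

```python
import torch
from torch_cluster import knn_graph

pos = torch.randn(32, 3)  # 3D coordinates of 32 atoms/residues

# KNN-graph: connect each node to its k nearest spatial neighbors
knn_edge_index = knn_graph(pos, k=8)

# Fully-connected graph: connect every ordered pair of distinct nodes
n = pos.size(0)
row = torch.arange(n).repeat_interleave(n)
col = torch.arange(n).repeat(n)
mask = row != col  # drop self-loops
fc_edge_index = torch.stack([row[mask], col[mask]])
```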
Install the necessary packages before running the code:
pip install torch
pip install scikit-learn
pip install einops
pip install matplotlib
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-{your-torch-version}+cu{your-cuda-version}.html
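For example, assuming torch 1.12.0 built against CUDA 11.3 (check yours with `python -c "import torch; print(torch.__version__, torch.version.cuda)"`), the last command becomes:
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html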
We examine the characteristics of GNNs on four different tasks. Among them, Newtonian dynamics and molecular dynamics (MD) simulations are node-level regression tasks, while Hamiltonian dynamics and molecular property prediction are graph-level regression tasks. Please follow the guidance below to generate and preprocess the data.
Here we follow Cranmer et al. (2020) [2] and adapt their code to generate the data.
# install necessary packages
pip install celluloid
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install jaxlib
Note that `jax.ops.index_update` was deprecated in jax 0.2.22, so we modify the code to use the `x0.at[...].set(...)` syntax instead. Moreover, loading jax might fail with a `Couldn't invoke ptxas` error, which occurs when the path of `ptxas` is not visible to the system. A possible solution is to install CUDA manually using the `install_cuda_11_1.sh` file.
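The migration mentioned above is mechanical; a minimal sketch (the array and indices are illustrative):

```python
import jax.numpy as jnp

x0 = jnp.zeros((4, 3))

# old API, removed after jax 0.2.22:
# x0 = jax.ops.index_update(x0, jax.ops.index[0, :], 1.0)

# new functional update syntax:
x0 = x0.at[0, :].set(1.0)
```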
Then run the following `sh` command to produce the raw data:
sudo bash install_cuda_11_1.sh
Finally, preprocess the raw data and save it as a `.pt` file for future use:
python data/dataset_nbody.py
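In later runs the preprocessed tensors can be loaded directly; a one-line sketch (the output filename is an assumption, see `data/dataset_nbody.py` for the actual path):

```python
import torch

dataset = torch.load('data/nbody.pt')  # hypothetical output path of dataset_nbody.py
```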
The MD dataset, ISO17, is provided by the Quantum Machine organization and is available on its official website. After downloading the source data, run the following script to preprocess it:
python data/dataset_iso17.py
The QM7 and QM8 datasets are also accessible from the same Quantum Machine website as the MD dataset. Run the following script to preprocess them:
python data/dataset_qm.py
Run the following command to pre-train a 3D GNN model:
python train.py --data=qm7 --method=egnn --gpu=0,1
# load a pretrained model
python test.py --data=qm7 --method=egnn --pretrain=1
# randomly initialize a model
python test.py --data=qm7 --method=egnn --pretrain=0
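Roughly, the `--pretrain` flag toggles whether saved weights are restored before evaluation; a hedged sketch of that logic (the module path and checkpoint filename are assumptions, not the repo's actual names):

```python
import torch
from models.egnn import EGNN  # hypothetical module path

def build_model(pretrain: bool) -> torch.nn.Module:
    model = EGNN()
    if pretrain:
        # restore the weights saved by train.py (hypothetical checkpoint path)
        state_dict = torch.load('checkpoints/qm7_egnn.pt', map_location='cpu')
        model.load_state_dict(state_dict)
    # with --pretrain=0 the model keeps its random initialization
    return model.eval()
```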
In the section *Revisiting Representation Bottlenecks of DNNs*, we additionally investigate the representation bottleneck of another commonly used type of DNN, i.e., CNNs. The `cnn` folder documents the corresponding implementation, which evaluates the multi-order interaction strengths of timm backbones for visual representation learning. It is based on Deng et al. [1] and their official code.
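As a refresher, the m-order interaction of Deng et al. [1] is I^(m)(i, j) = E_{S ⊆ N\{i,j}, |S| = m}[f(S∪{i,j}) − f(S∪{i}) − f(S∪{j}) + f(S)], where f(S) is the network output when only the variables in S are kept and all others are masked to a baseline. Below is a minimal Monte Carlo sketch, assuming such a masked scoring function `f` (an illustrative interface, not the actual API of the `cnn` code):

```python
import random

def delta_f(f, i, j, context):
    """Compute f(S∪{i,j}) - f(S∪{i}) - f(S∪{j}) + f(S) for a context S."""
    s = set(context)
    return f(s | {i, j}) - f(s | {i}) - f(s | {j}) + f(s)

def interaction(f, i, j, variables, m, num_samples=100):
    """Monte Carlo estimate of the m-order interaction I^(m)(i, j)."""
    pool = [v for v in variables if v not in (i, j)]
    total = 0.0
    for _ in range(num_samples):
        total += delta_f(f, i, j, random.sample(pool, m))
    return total / num_samples
```

The overall order-m interaction strength then averages |I^(m)(i, j)| over variable pairs and input samples.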
The `cnn` code works with PyTorch 1.8 or higher and timm. Below are the installation steps with a recent PyTorch:
conda create -n bottleneck python=3.8 pytorch=1.12 cudatoolkit=11.3 torchvision -c pytorch -y
conda activate bottleneck
pip install -r cnn/requirements.txt
Then, please download the datasets and place them under `./cnn/datasets`. CIFAR-10 will be downloaded automatically, while ImageNet should be downloaded and unzipped manually.
We only support the evaluation of pre-trained models. Please download the released pre-trained models from timm and place them in `./cnn/timm_hub`. Then run the following example on ImageNet with `./cnn/interaction_in1k.sh`:
cd cnn
bash interaction_in1k.sh
You can uncomment the setting you want to run, including the model name and checkpoints, at the top of the script.
The results will be saved in the `results` directory by default.
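For reference, a checkpoint placed in `./cnn/timm_hub` can be loaded through timm roughly as follows (the model name and filename are placeholders, not the exact ones used by the script):

```python
import timm

# build a backbone and restore weights from a local checkpoint
model = timm.create_model('resnet50', checkpoint_path='timm_hub/resnet50.pth')
model.eval()
```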
If you have any questions, please do not hesitate to contact Fang WU.
Please consider citing our paper if you find it helpful. Thank you! 😜
@article{wu2022discovering,
title={Discovering the Representation Bottleneck of Graph Neural Networks from Multi-order Interactions},
author={Wu, Fang and Li, Siyuan and Wu, Lirong and Li, Stan Z and Radev, Dragomir and Zhang, Qiang},
journal={arXiv preprint arXiv:2205.07266},
year={2022}
}
[1] Deng, H., Ren, Q., Chen, X., Zhang, H., Ren, J., & Zhang, Q. (2021). Discovering and explaining the representation bottleneck of DNNs. arXiv preprint arXiv:2111.06236.
[2] Cranmer, M., Sanchez-Gonzalez, A., Battaglia, P., Xu, R., Cranmer, K., Spergel, D., & Ho, S. (2020). Discovering symbolic models from deep learning with inductive biases. Advances in Neural Information Processing Systems, 33, 17429-17442.