cpittet/DGM_pytorch

DGM_pytorch

Code for the paper "Differentiable Graph Module (DGM) for Graph Convolutional Networks" by Anees Kazi*, Luca Cosmo*, Seyed-Ahmad Ahmadi, Nassir Navab, and Michael Bronstein.

Installation

Create a Conda virtual environment and install all the necessary packages:

conda create -n DGMenv python=3.8
conda activate DGMenv
conda install -c anaconda cmake=3.19
conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 cudatoolkit=10.1 -c pytorch
pip install pytorch_lightning==1.3.8

pip install torch-scatter==2.0.8 -f https://data.pyg.org/whl/torch-1.8.1+cu101.html
pip install torch-sparse==0.6.12 -f https://data.pyg.org/whl/torch-1.8.1+cu101.html
pip install torch-geometric
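
Because the commands above pin specific versions, it can help to confirm what actually ended up in the environment. The following is an optional, hypothetical helper (for example saved as check_env.py; it is not part of this repository) that uses only the standard library's importlib.metadata (Python 3.8+) to compare installed versions against the pins. torch-geometric is installed unpinned above, so it is not checked. Packages installed through conda may not always expose pip metadata, so a "not installed" report for torch is worth double-checking with python -c "import torch; print(torch.__version__)".

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned by the installation commands in this README.
# torch-geometric is installed unpinned, so it is deliberately left out.
PINNED = {
    "torch": "1.8.1",
    "torchvision": "0.9.1",
    "torchaudio": "0.8.1",
    "pytorch-lightning": "1.3.8",
    "torch-scatter": "2.0.8",
    "torch-sparse": "0.6.12",
}


def version_mismatches(pinned, lookup=version):
    """Return {package: installed_version_or_None} for every package that differs."""
    problems = {}
    for name, wanted in pinned.items():
        try:
            found = lookup(name)
        except PackageNotFoundError:
            problems[name] = None  # not installed, or no pip metadata visible
            continue
        # Ignore a local build tag such as "+cu101" when comparing.
        if found.split("+", 1)[0] != wanted:
            problems[name] = found
    return problems


if __name__ == "__main__":
    bad = version_mismatches(PINNED)
    if not bad:
        print("All pinned packages match.")
    for name, found in sorted(bad.items()):
        print(f"{name}: expected {PINNED[name]}, found {found or 'not installed'}")
```

The lookup function is injectable so the comparison logic can be exercised without any of the heavy packages installed.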

Training

To train a model with the default options, run the following command:

python train.py

Notes

The graph sampling code is based on a modified version of the KeOps library (www.kernel-operations.io) to speed up the computation. In particular, the argKmin function of the original library has been modified to handle the stochasticity of the sampling strategy: samples drawn from a Gumbel distribution are added to the input before the reduction is performed.
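
The idea behind a Gumbel-perturbed argKmin can be illustrated without KeOps. Below is a minimal pure-Python sketch of Gumbel-top-k neighbour sampling, assuming edge log-probabilities of -t * ||x_i - x_j||^2 up to a constant (the temperature t, the function name gumbel_knn and the exclusion of self-loops are illustrative assumptions, not taken from this repository). Because argKmin keeps the smallest values, the sketch subtracts the Gumbel noise from the scaled distance, which is equivalent to keeping the k largest perturbed log-probabilities; it does not reproduce the modified KeOps kernel or its exact sign convention.

```python
import math
import random


def gumbel_knn(points, k, temperature=1.0, rng=None):
    """Sample k distinct neighbours per point with the Gumbel-top-k trick.

    Illustrative sketch: assumes log p_ij = -temperature * ||x_i - x_j||^2
    (up to a constant) and excludes self-loops; not the repository's kernel.
    """
    rng = rng if rng is not None else random.Random()
    edges = []
    for i, xi in enumerate(points):
        perturbed = []
        for j, xj in enumerate(points):
            if j == i:
                continue  # assumption: no self-loops in this sketch
            sq_dist = sum((a - b) ** 2 for a, b in zip(xi, xj))
            u = rng.random()
            while u == 0.0:  # random() may return 0.0, and log(0) is undefined
                u = rng.random()
            gumbel = -math.log(-math.log(u))  # standard Gumbel(0, 1) sample
            # k smallest of (t * d^2 - g) == k largest of (log p_ij + g)
            perturbed.append((temperature * sq_dist - gumbel, j))
        perturbed.sort()
        edges.extend((i, j) for _, j in perturbed[:k])
    return edges


if __name__ == "__main__":
    pts = [[0.0], [1.0], [10.0], [11.0]]
    print(gumbel_knn(pts, k=1, temperature=1000.0, rng=random.Random(0)))
    # → [(0, 1), (1, 0), (2, 3), (3, 2)]
```

With a large temperature the distance term dominates the noise and the sampler behaves like a deterministic k-nearest-neighbour graph; lowering the temperature lets more distant points be sampled, which is the stochasticity the modified argKmin is meant to support. The naive double loop is O(n^2) in memory-free Python, which is exactly the cost a KeOps reduction is designed to avoid on large point sets.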
