A light-weight PyTorch extension for efficient equivariant deep learning.
Torch-Gauge is a library to boost geometric learning on physical data structures with Lie-group symmetry and beyond two-body interactions. The library is designed to be specifically optimized for training and inference on GPUs, with mini-batch training utilities natively supported.
Torch-Gauge uses a padded Verlet list as its core data structure for manipulating relational data, which enables compact and highly customizable implementations of geometric learning models.
As an illustration, SchNet's interaction module can be concisely implemented:
import torch
from torch.nn import Linear, Parameter
from torch_gauge.verlet_list import VerletList
from torch_gauge.nn import SSP
class SchNetLayer(torch.nn.Module):
    """SchNet-style interaction layer operating on a padded Verlet list.

    Pairwise distances are expanded in 300 fixed Gaussian radial basis
    functions, mapped to per-edge filters by ``cfconv``, multiplied with the
    neighbor features, summed over the neighbor axis, and added back to the
    input features as a residual update.

    Args:
        num_features: Width of the node feature vectors consumed and produced
            by this layer.
    """

    def __init__(self, num_features):
        # Module.__init__ must run before any Parameter or submodule is
        # assigned; otherwise torch.nn.Module raises on attribute assignment.
        super().__init__()
        _nf = num_features
        # Width factor inside the Gaussian exponent exp(-gamma * (d - c)^2).
        self.gamma = 10.0
        # Fixed (non-trainable) RBF centers: 300 points evenly spaced on
        # [0.1, 30.1].
        self.rbf_centers = Parameter(
            torch.linspace(0.1, 30.1, 300), requires_grad=False
        )
        # Filter-generating network: 300 RBF values -> _nf filter channels.
        self.cfconv = torch.nn.Sequential(
            Linear(300, _nf, bias=False), SSP(), Linear(_nf, _nf, bias=False), SSP()
        )
        self.pre_conv = Linear(_nf, _nf)
        self.post_conv = torch.nn.Sequential(Linear(_nf, _nf), SSP(), Linear(_nf, _nf))

    def forward(self, vl: VerletList, l: int):
        """Apply one interaction step and store the result on ``vl``.

        Args:
            vl: Verlet list whose ``ndata`` holds ``"xyz"`` and
                ``f"atomic_{l}"`` entries.
            l: Index of the input feature entry; the output is written to
                ``vl.ndata[f"atomic_{l+1}"]``.

        Returns:
            The updated features, i.e. ``vl.ndata[f"atomic_{l+1}"]``.
        """
        xyz = vl.ndata["xyz"]
        pre_conv = self.pre_conv(vl.ndata[f"atomic_{l}"])
        # Distance between each node and its listed neighbors; the norm is
        # taken over dim 2 and kept as a size-1 axis for broadcasting against
        # the RBF centers.
        # NOTE(review): assumes query_src returns a
        # (nodes, neighbors, ...) view -- confirm against VerletList.
        d_ij = (vl.query_src(xyz) - xyz.unsqueeze(1)).norm(dim=2, keepdim=True)
        # Gaussian RBF expansion of the distances, fed to the filter network.
        filters = self.cfconv(
            torch.exp(-self.gamma * (d_ij - self.rbf_centers.view(1, 1, -1)).pow(2))
        )
        # edge_mask is broadcast over the feature axis before summing over
        # neighbors (presumably zeroing padded entries -- confirm).
        conv_out = (filters * vl.query_src(pre_conv) * vl.edge_mask.unsqueeze(2)).sum(1)
        # Residual update.
        vl.ndata[f"atomic_{l+1}"] = vl.ndata[f"atomic_{l}"] + self.post_conv(conv_out)
        return vl.ndata[f"atomic_{l+1}"]
With the SphericalTensor support, we can readily extend it to SE(3)-equivariant convolutions:
import torch
from torch_gauge.nn import SSP, IELin
from torch.nn import Linear, Parameter
from torch_gauge.verlet_list import VerletList
from torch_gauge.o3.spherical import SphericalTensor
from torch_gauge.o3.rsh import RSHxyz
from torch_gauge.geometric import poly_env
from torch_gauge.o3.clebsch_gordan import LeviCivitaCoupler
class SE3Layer(torch.nn.Module):
    """Example SE(3)-equivariant convolution layer built on SphericalTensor.

    Neighbor features are coupled with filters formed from real spherical
    harmonics (``max_l=1``) of the edge directions, scaled by a learned
    function of 16 sinusoidal radial basis values. The coupled messages are
    summed over the neighbor axis and added to the input as a residual update.

    Args:
        num_features: Number of feature channels per degree used throughout
            the layer.
    """

    def __init__(self, num_features):
        # Module.__init__ must run before any Parameter or submodule is
        # assigned; otherwise torch.nn.Module raises on attribute assignment.
        super().__init__()
        _nf = num_features
        # Fixed (non-trainable) integer frequencies 0..15 of the sine basis.
        self.rbf_freqs = Parameter(torch.arange(16), requires_grad=False)
        self.rsh_mod = RSHxyz(max_l=1)
        self.coupler = LeviCivitaCoupler(torch.LongTensor([_nf, _nf]))
        # Radial filter network: 16 basis values -> _nf channels.
        self.filter_gen = torch.nn.Sequential(
            Linear(16, _nf, bias=False), SSP(), Linear(_nf, _nf, bias=False)
        )
        self.pre_conv = IELin([_nf, _nf], [_nf, _nf])
        # Held in a ModuleList (not Sequential) because forward() applies the
        # three stages individually, to different views of the tensor.
        self.post_conv = torch.nn.ModuleList(
            [IELin([2*_nf, 3*_nf], [_nf, _nf]), SSP(), IELin([_nf, _nf], [_nf, _nf])]
        )

    def forward(self, vl: VerletList, l: int):
        """Apply one equivariant convolution step and store the result on ``vl``.

        Args:
            vl: Verlet list whose ``ndata`` holds ``"xyz"`` and
                ``f"atomic_{l}"`` entries; the latter is used as a
                SphericalTensor.
            l: Index of the input feature entry; the output is written to
                ``vl.ndata[f"atomic_{l+1}"]``.

        Returns:
            The updated features, i.e. ``vl.ndata[f"atomic_{l+1}"]``.
        """
        # Displacement, distance and unit direction for each listed neighbor.
        r_ij = vl.query_src(vl.ndata["xyz"]) - vl.ndata["xyz"].unsqueeze(1)
        d_ij = r_ij.norm(dim=2, keepdim=True)
        # NOTE(review): if any listed entry (e.g. padding) has d_ij == 0, this
        # division and the one below produce NaN/inf, which the later
        # multiplication by edge_mask would not clear -- confirm how
        # VerletList pads.
        r_ij = r_ij / d_ij
        feat_in: SphericalTensor = vl.ndata[f"atomic_{l}"]
        pre_conv = vl.query_src(self.pre_conv(feat_in))
        # Sinusoidal radial basis sin(n * d / 5) / d, scaled by poly_env(d / 5).
        filters_radial = self.filter_gen(
            torch.sin(d_ij / 5 * self.rbf_freqs.view(1, 1, -1)) / d_ij * poly_env(d_ij/5)
        )
        # Product of the spherical-harmonic components and the radial filter
        # channels, flattened and wrapped with pre_conv's layout via self_like.
        filters = pre_conv.self_like(self.rsh_mod(r_ij).ten.unsqueeze(-1).mul(filters_radial).flatten(2, 3))
        coupling_out = self.post_conv[0](self.coupler(pre_conv, filters, overlap_out=False))
        # Apply edge_mask, then aggregate messages over the neighbor axis.
        conv_out = feat_in.self_like(coupling_out.ten.mul(vl.edge_mask.unsqueeze(2)).sum(1))
        # Gate the tensor by the activation of its invariant() part.
        conv_out = conv_out.scalar_mul(self.post_conv[1](conv_out.invariant()))
        conv_out.ten = self.post_conv[2](conv_out).ten
        # Residual update.
        vl.ndata[f"atomic_{l+1}"] = feat_in + conv_out
        return vl.ndata[f"atomic_{l+1}"]
Torch-Gauge's Verlet list also supports generating higher-order views (triplets, quartets, etc.) to streamline the implementation of models with specialized beyond-two-body message-passing algorithms.
pip install torch-gauge
Once PyTorch is installed, install from source by running:
pip install versioneer && versioneer install
pip install -e .
- `torch_gauge`: High-level Python functions
- `torch_gauge/o3`: O(3) group algebra functionals and data structures
- `torch_gauge/nn.py`: Tensorial neural network building blocks
- `torch_gauge/verlet_list.py`: Verlet neighbor-list operations for representing relational data
- `torch_gauge/geometric.py`: Geometric and algebraic operations with autograd
- `torch_gauge/models`: Exemplary implementations of GNN variants and descriptors
- `torch_gauge/tests`: Pytest test suite
make test
Please submit a question/issue or email me.