Skip to content

Commit

Permalink
Geo gnn fixes (#284)
Browse files Browse the repository at this point in the history
* Fix PCQM4Mv2 data

To work with the 3D GNN models, we needed to add the `z` (atomic number) feature to the dataset.

* Add support for PNA

We needed `global_max_pool` (or some other pooling) on the output of the model
in order for it to work with PCQM4Mv2.

* Adjust dataset sizes and batch sizes for better GPU util

* Fix dataclasses for py3.11
  • Loading branch information
bouthilx authored Sep 13, 2024
1 parent 8793dd8 commit 7633e6b
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 55 deletions.
16 changes: 9 additions & 7 deletions benchmarks/geo_gnn/dev.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
dimenet:
pna:
inherits: _defaults
definition: .
install-variant: cuda
install_group: torch
plan:
method: per_gpu
argv:
--model: 'DimeNet'
--num-samples: 10000
--use3d: True
--model: 'PNA'
--num-samples: 100000
--batch-size: 4096

pna:
dimenet:
inherits: _defaults
definition: .
install-variant: cuda
install_group: torch
plan:
method: per_gpu
argv:
--model: 'PNA'
--num-samples: 10000
--model: 'DimeNet'
--num-samples: 10000
--use3d: True
--batch-size: 512
54 changes: 24 additions & 30 deletions benchmarks/geo_gnn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pcqm4m_subset import PCQM4Mv2Subset
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_max_pool

from benchmate.observer import BenchObserver

Expand Down Expand Up @@ -102,26 +103,25 @@ def main():
args = parser().parse_args()

def batch_size(x):
shape = x.y.shape
return shape[0]
# assert len(x.batch.unique()) == int(x.batch[-1] - x.batch[0] + 1)
return int(x.batch[-1] - x.batch[0] + 1)

observer = BenchObserver(batch_size_fn=batch_size)

# train_dataset = PCQM4Mv2Subset(args.num_samples, args.root)
train_dataset = QM9(args.root)
train_dataset = PCQM4Mv2Subset(args.num_samples, args.root)

sample = next(iter(train_dataset))

info = models[args.model](args,
sample=sample,
degree=lambda: train_degree(train_dataset),
info = models[args.model](
args,
sample=sample,
degree=lambda: train_degree(train_dataset),
)

TRAIN_mean, TRAIN_std = (
mean(train_dataset).item(),
std(train_dataset).item(),
)
print("Train mean: {}\tTrain std: {}".format(TRAIN_mean, TRAIN_std))

DataLoaderClass = DataLoader
dataloader_kwargs = {}
Expand All @@ -131,7 +131,7 @@ def batch_size(x):
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
**dataloader_kwargs
**dataloader_kwargs,
)

device = accelerator.fetch_device(0)
Expand All @@ -148,33 +148,26 @@ def batch_size(x):
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

num_batches = len(train_loader)
for epoch in range(1, args.epochs + 1):
model.train()
loader = observer.loader(train_loader)

for step, batch in enumerate(observer.iterate(train_loader)):
# QM9 => DataBatch(x=[290, 11], edge_index=[2, 602], edge_attr=[602, 4], y=[16, 19], pos=[290, 3], z=[290], smiles=[16], name=[16], idx=[16], batch=[290], ptr=[17])
# PCQM4Mv2Subset => DataBatch(x=[229, 9], edge_index=[2, 476], edge_attr=[476, 3], y=[16], pos=[229, 3], smiles=[16], batch=[229], ptr=[17])
model.train() # No eval ever.
for epoch in range(1, args.epochs + 1):
for step, batch in enumerate(loader):
batch = batch.to(device)

if args.use3d:

if hasattr(batch, "z"):
z = batch.z
else:
z = batch.batch

molecule_repr = model(z=z, pos=batch.pos, batch=batch.batch)
molecule_repr = model(z=batch.z, pos=batch.pos, batch=batch.batch)
else:
molecule_repr = model(x=batch.x, batch=batch.batch, edge_index=batch.edge_index, batch_size=batch_size(batch))
molecule_repr = model(
x=batch.x.type(torch.float),
batch=batch.batch,
edge_index=batch.edge_index,
batch_size=batch_size(batch),
)
molecule_repr = global_max_pool(molecule_repr, batch.batch)

pred = molecule_repr.squeeze()

# Dimenet : pred: torch.Size([ 16, 19])
# PNA : pred: torch.Size([292, 19]) <= (with x=batch.x) WTF !? 292 = batch.x.shape[0]
# batch : torch.Size([ 16, 19])
# print(molecule_repr.shape)
# print(batch.y.shape)

B = pred.size()[0]
y = batch.y.view(B, -1)
# normalize
Expand All @@ -192,7 +185,8 @@ def batch_size(x):

lr_scheduler.step()

print("Epoch: {}\nLoss: {}".format(epoch))
if loader.is_done():
break


if __name__ == "__main__":
Expand Down
7 changes: 5 additions & 2 deletions benchmarks/geo_gnn/pcqm4m_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
"smiles": str,
"pos": dict(dtype=torch.float32, size=(-1, 3)),
"y": float,
"z": dict(dtype=torch.long, size=(-1,)),
}

self.from_smiles = from_smiles or _from_smiles
Expand All @@ -49,12 +50,10 @@ def raw_file_names(self):
]

def download(self):
print(self.raw_paths)
if all(os.path.exists(path) for path in self.raw_paths):
return

# Download 2d graphs
print(self.raw_dir)
super().download()

# Download 3D coordinates
Expand All @@ -78,6 +77,9 @@ def process(self) -> None:
data.pos = torch.tensor(
extra.GetConformer().GetPositions(), dtype=torch.float
)
data.z = torch.tensor(
[atom.GetAtomicNum() for atom in extra.GetAtoms()], dtype=torch.long
)

data_list.append(data)
if (
Expand All @@ -104,4 +106,5 @@ def std(self):
def serialize(self, data: BaseData) -> Dict[str, Any]:
rval = super().serialize(data)
rval["pos"] = data.pos
rval["z"] = data.z
return rval
7 changes: 2 additions & 5 deletions benchmarks/geo_gnn/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def parser():
"--num-samples",
type=int,
help="Number of samples to process in the dataset",
default=10000,
default=100000,
)
parser.add_argument(
"--root",
Expand All @@ -26,7 +26,4 @@ def parser():
if __name__ == "__main__":
args, _ = parser().parse_known_args()

# TODO: Handle argument for the number of samples
train_dataset = QM9(args.root)
# dataset = PCQM4Mv2Subset(args.num_samples, root=args.root)

dataset = PCQM4Mv2Subset(args.num_samples, root=args.root)
23 changes: 12 additions & 11 deletions milabench/system.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import contextvars
import ipaddress
import os
import socket
from dataclasses import dataclass, field
import sys
import subprocess
import sys
from contextlib import contextmanager
import ipaddress
from dataclasses import dataclass, field

import psutil
import yaml
Expand Down Expand Up @@ -193,11 +193,11 @@ class Torchrun:

@dataclass
class Options:
sizer: SizerOptions = SizerOptions()
cpu: CPUOptions = CPUOptions()
dataset: DatasetConfig = DatasetConfig()
dirs: Dirs = Dirs()
torchrun: Torchrun = Torchrun()
sizer: SizerOptions = field(default_factory=SizerOptions)
cpu: CPUOptions = field(default_factory=CPUOptions)
dataset: DatasetConfig = field(default_factory=DatasetConfig)
dirs: Dirs = field(default_factory=Dirs)
torchrun: Torchrun = field(default_factory=Torchrun)


@dataclass
Expand Down Expand Up @@ -231,18 +231,19 @@ def default_device():
@dataclass
class SystemConfig:
"""This is meant to be an exhaustive list of all the environment overrides"""

arch: str = defaultfield("gpu.arch", str, default_device())
sshkey: str = defaultfield("ssh", str, "~/.ssh/id_rsa")
docker_image: str = None
nodes: list[Nodes] = field(default_factory=list)
gpu: GPUConfig = GPUConfig()
options: Options = Options()
gpu: GPUConfig = field(default_factory=GPUConfig)
options: Options = field(default_factory=Options)

base: str = defaultfield("base", str, None)
config: str = defaultfield("config", str, None)
dash: bool = defaultfield("dash", bool, 1)
noterm: bool = defaultfield("noterm", bool, 0)
github: Github = Github()
github: Github = field(default_factory=Github)


def check_node_config(nodes):
Expand Down

0 comments on commit 7633e6b

Please sign in to comment.