From 7633e6b1efaadbaf86fa7eb46ed6d6c8084e6000 Mon Sep 17 00:00:00 2001 From: Xavier Bouthillier Date: Fri, 13 Sep 2024 16:13:23 -0400 Subject: [PATCH] Geo gnn fixes (#284) * Fix PCQM4Mv2 data To work with 3d GNN, we needed to add the `z` feature. * Add support for PNA Needed global_max_pool or any kind of pooling on the output of the model in order to work with PCQM4Mv2. * Adjust dataset sizes and batch sizes for better GPU util * Fix dataclasses for py3.11 --- benchmarks/geo_gnn/dev.yaml | 16 +++++---- benchmarks/geo_gnn/main.py | 54 +++++++++++++---------------- benchmarks/geo_gnn/pcqm4m_subset.py | 7 ++-- benchmarks/geo_gnn/prepare.py | 7 ++-- milabench/system.py | 23 ++++++------ 5 files changed, 52 insertions(+), 55 deletions(-) diff --git a/benchmarks/geo_gnn/dev.yaml b/benchmarks/geo_gnn/dev.yaml index 7fadaea5f..6f261c895 100644 --- a/benchmarks/geo_gnn/dev.yaml +++ b/benchmarks/geo_gnn/dev.yaml @@ -1,4 +1,4 @@ -dimenet: +pna: inherits: _defaults definition: . install-variant: cuda @@ -6,11 +6,11 @@ dimenet: plan: method: per_gpu argv: - --model: 'DimeNet' - --num-samples: 10000 - --use3d: True + --model: 'PNA' + --num-samples: 100000 + --batch-size: 4096 -pna: +dimenet: inherits: _defaults definition: . install-variant: cuda @@ -18,5 +18,7 @@ pna: plan: method: per_gpu argv: - --model: 'PNA' - --num-samples: 10000 \ No newline at end of file + --model: 'DimeNet' + --num-samples: 10000 + --use3d: True + --batch-size: 512 \ No newline at end of file diff --git a/benchmarks/geo_gnn/main.py b/benchmarks/geo_gnn/main.py index 714707f65..71e1c8827 100644 --- a/benchmarks/geo_gnn/main.py +++ b/benchmarks/geo_gnn/main.py @@ -9,6 +9,7 @@ from pcqm4m_subset import PCQM4Mv2Subset from torch_geometric.datasets import QM9 from torch_geometric.loader import DataLoader +from torch_geometric.nn import global_max_pool from benchmate.observer import BenchObserver @@ -102,26 +103,25 @@ def main(): args = parser().parse_args() def batch_size(x): - shape = x.y.shape - return shape[0] + # assert len(x.batch.unique()) == int(x.batch[-1] - x.batch[0] + 1) + return int(x.batch[-1] - x.batch[0] + 1) observer = BenchObserver(batch_size_fn=batch_size) - # train_dataset = PCQM4Mv2Subset(args.num_samples, args.root) - train_dataset = QM9(args.root) + train_dataset = PCQM4Mv2Subset(args.num_samples, args.root) sample = next(iter(train_dataset)) - info = models[args.model](args, - sample=sample, - degree=lambda: train_degree(train_dataset), + info = models[args.model]( + args, + sample=sample, + degree=lambda: train_degree(train_dataset), ) TRAIN_mean, TRAIN_std = ( mean(train_dataset).item(), std(train_dataset).item(), ) - print("Train mean: {}\tTrain std: {}".format(TRAIN_mean, TRAIN_std)) DataLoaderClass = DataLoader dataloader_kwargs = {} @@ -131,7 +131,7 @@ def batch_size(x): batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, - **dataloader_kwargs + **dataloader_kwargs, ) device = accelerator.fetch_device(0) @@ -148,33 +148,26 @@ def batch_size(x): lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs) num_batches = len(train_loader) - for epoch in range(1, args.epochs + 1): - model.train() + loader = observer.loader(train_loader) - for step, batch in enumerate(observer.iterate(train_loader)): - # QM9 => DataBatch(x=[290, 11], edge_index=[2, 602], edge_attr=[602, 4], y=[16, 19], pos=[290, 3], z=[290], smiles=[16], name=[16], idx=[16], batch=[290], ptr=[17]) - # PCQM4Mv2Subset => DataBatch(x=[229, 9], edge_index=[2, 476], edge_attr=[476, 3], y=[16], pos=[229, 3], smiles=[16], batch=[229], ptr=[17]) + model.train() # No eval ever. + for epoch in range(1, args.epochs + 1): + for step, batch in enumerate(loader): batch = batch.to(device) - + if args.use3d: - - if hasattr(batch, "z"): - z = batch.z - else: - z = batch.batch - - molecule_repr = model(z=z, pos=batch.pos, batch=batch.batch) + molecule_repr = model(z=batch.z, pos=batch.pos, batch=batch.batch) else: - molecule_repr = model(x=batch.x, batch=batch.batch, edge_index=batch.edge_index, batch_size=batch_size(batch)) + molecule_repr = model( + x=batch.x.type(torch.float), + batch=batch.batch, + edge_index=batch.edge_index, + batch_size=batch_size(batch), + ) + molecule_repr = global_max_pool(molecule_repr, batch.batch) pred = molecule_repr.squeeze() - # Dimenet : pred: torch.Size([ 16, 19]) - # PNA : pred: torch.Size([292, 19]) <= (with x=batch.x) WTF !? 292 = batch.x.shape[0] - # batch : torch.Size([ 16, 19]) - # print(molecule_repr.shape) - # print(batch.y.shape) - B = pred.size()[0] y = batch.y.view(B, -1) # normalize @@ -192,7 +185,8 @@ def batch_size(x): lr_scheduler.step() - print("Epoch: {}\nLoss: {}".format(epoch)) + if loader.is_done(): + break if __name__ == "__main__": diff --git a/benchmarks/geo_gnn/pcqm4m_subset.py b/benchmarks/geo_gnn/pcqm4m_subset.py index 615aea2bb..2d6e0e2bd 100644 --- a/benchmarks/geo_gnn/pcqm4m_subset.py +++ b/benchmarks/geo_gnn/pcqm4m_subset.py @@ -35,6 +35,7 @@ def __init__( "smiles": str, "pos": dict(dtype=torch.float32, size=(-1, 3)), "y": float, + "z": dict(dtype=torch.long, size=(-1,)), } self.from_smiles = from_smiles or _from_smiles @@ -49,12 +50,10 @@ def raw_file_names(self): ] def download(self): - print(self.raw_paths) if all(os.path.exists(path) for path in self.raw_paths): return # Download 2d graphs - print(self.raw_dir) super().download() # Download 3D coordinates @@ -78,6 +77,9 @@ def process(self) -> None: data.pos = torch.tensor( extra.GetConformer().GetPositions(), dtype=torch.float ) + data.z = torch.tensor( + [atom.GetAtomicNum() for atom in extra.GetAtoms()], dtype=torch.long + ) data_list.append(data) if ( @@ -104,4 +106,5 @@ def std(self): def serialize(self, data: BaseData) -> Dict[str, Any]: rval = super().serialize(data) rval["pos"] = data.pos + rval["z"] = data.z return rval diff --git a/benchmarks/geo_gnn/prepare.py b/benchmarks/geo_gnn/prepare.py index 2b352f8ce..b3ac374b0 100755 --- a/benchmarks/geo_gnn/prepare.py +++ b/benchmarks/geo_gnn/prepare.py @@ -12,7 +12,7 @@ def parser(): "--num-samples", type=int, help="Number of samples to process in the dataset", - default=10000, + default=100000, ) parser.add_argument( "--root", @@ -26,7 +26,4 @@ def parser(): if __name__ == "__main__": args, _ = parser().parse_known_args() - # TODO: Handle argument for the number of samples - train_dataset = QM9(args.root) - # dataset = PCQM4Mv2Subset(args.num_samples, root=args.root) - + dataset = PCQM4Mv2Subset(args.num_samples, root=args.root) diff --git a/milabench/system.py b/milabench/system.py index d29f4cd27..c237baf2c 100644 --- a/milabench/system.py +++ b/milabench/system.py @@ -1,11 +1,11 @@ import contextvars +import ipaddress import os import socket -from dataclasses import dataclass, field -import sys import subprocess +import sys from contextlib import contextmanager -import ipaddress +from dataclasses import dataclass, field import psutil import yaml @@ -193,11 +193,11 @@ class Torchrun: @dataclass class Options: - sizer: SizerOptions = SizerOptions() - cpu: CPUOptions = CPUOptions() - dataset: DatasetConfig = DatasetConfig() - dirs: Dirs = Dirs() - torchrun: Torchrun = Torchrun() + sizer: SizerOptions = field(default_factory=SizerOptions) + cpu: CPUOptions = field(default_factory=CPUOptions) + dataset: DatasetConfig = field(default_factory=DatasetConfig) + dirs: Dirs = field(default_factory=Dirs) + torchrun: Torchrun = field(default_factory=Torchrun) @dataclass @@ -231,18 +231,19 @@ def default_device(): @dataclass class SystemConfig: """This is meant to be an exhaustive list of all the environment overrides""" + arch: str = defaultfield("gpu.arch", str, default_device()) sshkey: str = defaultfield("ssh", str, "~/.ssh/id_rsa") docker_image: str = None nodes: list[Nodes] = field(default_factory=list) - gpu: GPUConfig = GPUConfig() - options: Options = Options() + gpu: GPUConfig = field(default_factory=GPUConfig) + options: Options = field(default_factory=Options) base: str = defaultfield("base", str, None) config: str = defaultfield("config", str, None) dash: bool = defaultfield("dash", bool, 1) noterm: bool = defaultfield("noterm", bool, 0) - github: Github = Github() + github: Github = field(default_factory=Github) def check_node_config(nodes):