
Graphs_integration #161

Open · wants to merge 245 commits into base: main
Conversation

@EnricoTrizio (Collaborator) commented on Nov 13, 2024

General description

Add the code for CVs based on GNNs in the most organic way possible.
This largely inherits from Jintu's work (kudos! @jintuzhang), where all the code was based on a "library within the library".
Some functions were quite different from the rest of the code (e.g., all the code for GNNs and the GraphDataModule), others were mostly redundant (e.g., GraphDataset, CV base and specific classes).

It would be wise to reduce the code duplicates and redundancies and make the whole library more organic, while still including all the new functionalities.
SPOILER: this requires some thinking and some modifications here and there.

(We could also split the story into multiple PRs, in case.)


Point-by-point description

Data handling

Affecting --> mlcolvar.data, mlcolvar.utils.io

Overview

So far, we have a `DictDataset` (based on `torch.utils.data.Dataset`) and the corresponding `DictModule` (based on `lightning.LightningDataModule`).

For GNNs, there was a `GraphDataset` (based on lists) and the corresponding `GraphDataModule` (based on `lightning.LightningDataModule`).
Here, the data are handled using PyTorch Geometric for convenience.
There are also a bunch of auxiliary functions for neighborhoods and the handling of atom types, plus some utils to initialize the dataset easily from files.
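As a rough, self-contained illustration of what such a neighbor-list builder does (the actual code relies on matscipy and is certainly more efficient), a brute-force version might look like:

```python
from itertools import combinations

def neighbor_list(positions, cutoff):
    """Return directed (i, j) edge pairs for atoms closer than `cutoff`."""
    edges = []
    for i, j in combinations(range(len(positions)), 2):
        d2 = sum((a - b) ** 2 for a, b in zip(positions[i], positions[j]))
        if d2 < cutoff ** 2:
            # graph datasets usually store both edge directions
            edges.extend([(i, j), (j, i)])
    return edges
```

For example, three collinear atoms at 0, 1, and 5 Å with a 2 Å cutoff yield only the edge between the first two atoms (in both directions).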

Implemented solution

The two things are merged:

  1. A single `DictDataset` that can handle both types of data.
  • It also has a `metadata` attribute that stores general properties in a dict (e.g., `cutoff` and `atom_types`).
  • In the `__init__`, the user can specify the `data_type` (either `descriptors` (default) or `graphs`). This is then stored in `metadata` and is used in the `DictModule` to handle the data the right way (see below).
  • New utils have been added in `mlcolvar.data.utils`: `save_dataset`, `load_dataset`, and `save_dataset_configurations_as_extyz`.
  2. A single `DictModule` that can handle both types of data. Depending on the `metadata['data_type']` of the incoming dataset, it uses either our `DictLoader` or the `torch_geometric.DataLoader`.
  3. A new submodule `data.graph` containing:
  • `atomic.py` for the handling of atomic quantities, based on the data class `Configuration`
  • `neighborhood.py` for building neighbor lists using matscipy
  • `utils.py` to frame `Configuration`s into datasets and one-hot embeddings. It also contains `create_test_graph_input`, as creating inputs for testing here requires several lines of code.
  4. A new `create_dataset_from_trajectories` util in `mlcolvar.utils.io` that allows creating a dataset directly from trajectory files, providing topology files and using mdtraj.
  5. A single `create_timelagged_dataset` that can also create the time-lagged dataset starting from a `DictDataset` with `data_type == 'graphs'`.

NB For the graph datasets, the keys are the original ones:

  • `data_list`: all the graph data, e.g., edge src and dst, batch index... (this goes in `DictDataset`)
  • `z_table`: atomic numbers map (this goes in `DictDataset.metadata`)
  • `cutoff`: cutoff used in the graph (this goes in `DictDataset.metadata`)
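To make the metadata-based dispatch concrete, here is a minimal stand-in sketch (plain Python, not the actual mlcolvar classes) of how a `DictModule`-like object could pick a loader from `metadata['data_type']`:

```python
class DictDataset:
    """Toy stand-in: data plus a metadata dict holding general properties."""
    def __init__(self, data, data_type="descriptors", metadata=None):
        if data_type not in ("descriptors", "graphs"):
            raise ValueError(f"unknown data_type: {data_type!r}")
        self.data = data
        # cutoff, z_table, etc. travel with the dataset in `metadata`
        self.metadata = {"data_type": data_type, **(metadata or {})}

def choose_loader(dataset):
    """Toy stand-in for the DictModule dispatch on metadata['data_type']."""
    if dataset.metadata["data_type"] == "graphs":
        return "torch_geometric.DataLoader"
    return "DictLoader"
```

This way the user-facing API stays a single dataset/module pair, and the graph-specific machinery is an internal detail.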

Questions

  • Shall we keep these names for the keys?
  • Do we like the metadata thing?
  • Single DataModule?
  • Maybe make the overall structure smoother? i.e., not too many `utils.py` here and there and not too many submodules?

GNN models

Affecting --> mlcolvar.core.nn

Overview

Of course, they need to be implemented 😄 but we can inherit most of the code from Jintu.
As an overview, there is a BaseGNN parent class that implements the common features, and then each model (e.g., SchNet or GVP) is implemented on top of that.
There is also a radial.py that implements a bunch of tools for radial embeddings.
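As a generic illustration of the idea behind radial embeddings (not the code in `radial.py`), a Gaussian basis expansion of an interatomic distance can be sketched as:

```python
import math

def gaussian_rbf(r, cutoff, n_basis=8, sigma=None):
    """Embed a scalar distance r into n_basis Gaussian radial features."""
    # centers are spread uniformly from 0 to the cutoff
    centers = [i * cutoff / (n_basis - 1) for i in range(n_basis)]
    sigma = sigma if sigma is not None else cutoff / (n_basis - 1)
    return [math.exp(-((r - c) ** 2) / (2.0 * sigma ** 2)) for c in centers]
```

Each distance becomes a smooth, fixed-length feature vector that the message-passing layers can consume.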

Implemented solution

The GNN code is now implemented in mlcolvar.core.nn.graph.

  1. There is a `BaseGNN` class that serves as a template for the architecture-specific code. For example, it already has the methods for embedding edges and setting some common properties.
  2. The `Radial` module implements the tools for radial embeddings.
  3. The `SchNetModel` and `GVPModel` are implemented on top of `BaseGNN`.
  4. In `utils.py`, there is a function that creates data for the tests of this module. This should be replaced with the very similar function `mlcolvar.data.graph.utils.create_test_graph_input`, which is more general and also used for other things.
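The template pattern in point 1 can be sketched as follows (class and method names here are illustrative, not the actual `BaseGNN` API):

```python
from abc import ABC, abstractmethod

class BaseGNN(ABC):
    """Shared template: common properties and edge embedding live here."""
    def __init__(self, n_out, cutoff):
        self.n_out = n_out
        self.cutoff = cutoff  # common property every architecture needs

    def embed_edge(self, r):
        # common edge embedding, here a simple linear cutoff weight
        return max(0.0, 1.0 - r / self.cutoff)

    @abstractmethod
    def forward(self, r):
        """Architecture-specific part, implemented by each model."""

class ToyModel(BaseGNN):
    def forward(self, r):
        # builds on the shared embedding, as SchNetModel/GVPModel would
        return self.n_out * self.embed_edge(r)
```

Each concrete architecture only has to fill in its message-passing logic, while cutoffs and embeddings are inherited.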

CV models

Affecting --> mlcolvar.cvs, mlcolvar.core.loss

Overview

In Jintu's implementation, all the CV classes we tested were re-implemented, still using the original loss function code.
The point there is that the initialization of the underlying ML model (also in the current version of the library) is performed within the CV class.
We did it this way to keep things simple, and indeed it is simple for feed-forward networks, as they have very few things to set (i.e., layers, nodes, activations), and also because there were no alternatives at the time.
For GNNs, however, the initialization can vary a lot (i.e., different architectures and many parameters one could set).

I am afraid we can't cut corners here if we want to include everything; somewhere we need to add an extra layer of complexity, either to the workflow or to the CV models.

Implemented solution

We keep everything similar to what it used to be in the library, except for:

  1. We rename the `layers` keyword to the more general `model` in the init of the CV classes, which can accept:
  • A list of integers, as before. It works like the old `layers` keyword and initializes a `FeedForward` with it and all the `DEFAULT_BLOCKS` (see point 2), e.g., for `DeepLDA`: `['norm_in', 'nn', 'lda']`.
  • A `mlcolvar.core.nn.FeedForward` or `mlcolvar.core.nn.graph.BaseGNN` model initialized outside the CV class. This overrides the old default: one provides an external model, and the `MODEL_BLOCKS` are used, e.g., for `DeepLDA`: `['nn', 'lda']`. For example, the initialization can look like this:
```python
# for GNN-based
gnn_model = SchNet(...)
model = DeepTDA(..., model=gnn_model, ...)

# for FFNN-based, alternative 1
model = DeepTDA(..., model=[2, 3], ...)

# for FFNN-based, alternative 2
ff_model = FeedForward(layers=[2, 3])
model = DeepTDA(..., model=ff_model, ...)
```
  2. The `BLOCKS` of each CV model are duplicated into `DEFAULT_BLOCKS` and `MODEL_BLOCKS` to account for the different behaviors. This was a simple way to initialize everything in all cases (maybe not the best one, see questions).
  3. In the training step, the only change is a different setup of the data depending on the type of ML model being used; the rest is basically the same as before.
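A minimal sketch of the dispatch implied by points 1–2 (the `BLOCKS` contents shown are only an example, borrowed from the `DeepLDA` case above):

```python
# BLOCKS contents are only illustrative (taken from the DeepLDA example)
DEFAULT_BLOCKS = ["norm_in", "nn", "lda"]  # `model` given as a list of ints
MODEL_BLOCKS = ["nn", "lda"]               # `model` given as a pre-built net

def resolve_blocks(model):
    """Pick the blocks depending on how `model` was passed to the CV init."""
    if isinstance(model, list) and all(isinstance(n, int) for n in model):
        # old-style `layers` list: a FeedForward is built internally
        return DEFAULT_BLOCKS
    # external FeedForward / BaseGNN instance: skip the default preprocessing
    return MODEL_BLOCKS
```

This keeps backward compatibility for the list-of-integers case while letting external models bypass blocks they handle themselves.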

Things to note

  1. All the loss functions are untouched! Except for the `CommittorLoss`, as it does not depend only on the output space but also on the derivatives w.r.t. the input/positions.
  2. When an external GNN model is provided, checkpoints and logs still do not work. I left these things for the very end of the PR, focusing on making things work first.
  3. Autoencoder-based CVs only raise a `NotImplementedError`, as we do not have a stable GNN-AE for now. As a consequence, the `MultiTaskCV` also does not support GNN models, since, in the way we intend it, it wouldn't make much sense without a GNN-based AE.

Questions

  • What shall we do with the BLOCKS? Is it worth it to keep this thing?

TODOs

  • Make logger and checkpointing work with graph models 🗡️
  • Add autoencoders (in the future)

Explain module

Affecting --> mlcolvar.explain

Overview

There are some new explain functions for GNN models that we should add to the explain module.

Possible solution

Include the GNN code as it is, possibly with some revision, into a `mlcolvar.explain.graph` module, or even without the submodule, since there are no overlaps here, I think.

Questions

  • Do we need to create a submodule of explain?

General todos

  • Check everything 😄
  • Fix dependencies
  • Fix and clean imports
  • Fix init files
  • Remove commented lines
  • Tests !!!
  • DOCS !!!
  • Multitask tests!
  • Change the CV-from-scratch tutorial; currently it is patched badly just to make the tests pass

General questions

  • How many new dependencies do we want to keep? Can we make some of them optional?

Status

  • Ready to go

Labels: enhancement, help wanted