GNN cvs module merge #160

Closed · wants to merge 13 commits into from

Conversation

@EnricoTrizio (Collaborator) commented Nov 13, 2024

General description

Add the code for GNN-based CVs in the most organic way possible.
This largely inherits from Jintu's work (kudos!), where all the code was organized as a "library within the library".
Some parts were quite different from the rest of the code (e.g., all the GNN code and the GraphDataModule), while others were mostly redundant (e.g., GraphDataset, and the CV base and specific classes).

It would be wise to reduce code duplication and redundancy and make the whole library more organic, while still including all the new functionality.
SPOILER: this requires some thinking and some modifications here and there.

(We could also split the story into multiple PRs if needed.)

Point-by-point description

Data handling

Affecting --> mlcolvar.data, mlcolvar.utils.io

Overview

So far, we have a DictDataset (based on torch.utils.data.Dataset) and the corresponding DictModule (based on lightning.LightningDataModule).

For GNNs, there was a GraphDataset (based on lists) and the corresponding GraphDataModule (also based on lightning.LightningDataModule).
Here, the data are handled with PyTorch Geometric for convenience.
There are also a bunch of auxiliary functions for neighborhoods and the handling of atom types, plus some utils to initialize the dataset easily from files.

Possible solution

Merge the two dataset classes into DictDataset, as they basically do the same thing.
Add a metadata attribute to the dataset to store quantities that are not data (e.g., cutoff and atom_types).
Add some exceptions for the torch-geometric-like data (still easy thanks to the dictionary structure, i.e., they will always be found under the same keys).
All the other classes/functions that are different will go into a mlcolvar.data.graph module (almost) as they are. A rough sketch of the idea is below.
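
For illustration, a minimal sketch of the merged-dataset idea, assuming a dictionary-backed DictDataset; the metadata attribute and the key names are illustrative, not the final API:

import torch

class DictDataset(torch.utils.data.Dataset):
    """Dictionary of tensors/lists; `metadata` stores non-data quantities."""

    def __init__(self, dictionary, metadata=None):
        # e.g., {"data": X, "labels": y} for plain tensors; graph-like
        # entries would always be stored under the same dedicated key
        self._data = dictionary
        # e.g., {"cutoff": 5.0, "atom_types": [1, 6, 8]}
        self.metadata = metadata if metadata is not None else {}

    def __len__(self):
        return len(next(iter(self._data.values())))

    def __getitem__(self, index):
        return {key: value[index] for key, value in self._data.items()}

# usage: the metadata travels with the dataset instead of being data
dataset = DictDataset(
    {"data": torch.randn(10, 3), "labels": torch.zeros(10)},
    metadata={"cutoff": 5.0, "atom_types": [1, 6, 8]},
)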

Questions

  • Does it make sense?

GNN models

Affecting --> mlcolvar.core.nn

Overview

Of course, they need to be implemented 😄, but we can inherit most of the code from Jintu.
As an overview, there is a BaseGNN parent class that implements the common features, and each model (e.g., SchNet or GVP) is implemented on top of that.
There is also a radial.py that implements a bunch of tools for radial embeddings.

Possible solution

Include the GNN code, possibly with some revision, (almost) as it is into a mlcolvar.core.nn.gnn module. A possible skeleton of the parent class is sketched below.
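
For illustration, a minimal skeleton of the parent-class idea, assuming PyTorch Geometric batches as input; all names and defaults here are illustrative, not Jintu's actual code:

import torch
from torch import nn

class BaseGNN(nn.Module):
    """Shared machinery; subclasses implement the message passing."""

    def __init__(self, n_out, cutoff, atomic_numbers, n_bases=16):
        super().__init__()
        self.n_out = n_out      # output dimension (number of CVs)
        self.cutoff = cutoff    # neighborhood cutoff used to build edges
        self.n_bases = n_bases  # number of radial basis functions
        # species present in the dataset, used for one-hot node embeddings
        self.register_buffer(
            "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.long)
        )

    def embed_edge(self, edge_lengths):
        # common radial embedding (a SchNet-style Gaussian expansion),
        # the kind of tool radial.py provides
        centers = torch.linspace(0.0, self.cutoff, self.n_bases)
        width = centers[1] - centers[0]
        return torch.exp(-0.5 * ((edge_lengths[:, None] - centers) / width) ** 2)

    def forward(self, data):
        # data: a torch_geometric.data.Batch with pos, edge_index, ...
        raise NotImplementedError

class SchNet(BaseGNN):
    def forward(self, data):
        # continuous-filter message passing over data.edge_index using the
        # shared radial embeddings; omitted in this sketch
        ...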

Questions

  • Maybe we can try to rely on PyTorch Geometric here as well?
  • What about scripting, tracing, and other nasty things?

CV model

Affecting --> mlcolvar.cvs

Overview

In Jintu's implementation, all the CV classes we tested were re-implemented, still using the original loss function code.
The point there is that the initialization of the underlying ML model (also in the current version of the library) is performed within the CV class.
We did it this way to keep things simple, and indeed it is simple for feed-forward networks, as they have very few things to set (i.e., layers, nodes, activations), and also because there were no alternatives at the time.
For GNNs, however, the initialization can vary a lot (i.e., different architectures and many parameters one could set).

I am afraid we can't cut corners here: if we want to include everything, somewhere we need to add an extra layer of complexity, either to the workflow or to the CV models.

Possible solution(s)

  1. (my first implementation, now in the PR) We keep everything similar to how it is now. We add a gnn_model=None keyword to the init of the CV classes that can use GNNs (plus some error messages and checks here and there), so that you can pass a GNN model that you have initialized outside the CV class (see the sketch after this list).
# for GNN-based
gnn_model = SchNet(...)
model = DeepTDA(..., gnn_model=gnn_model, ...)

# for FFNN-based (gnn_model defaults to None)
model = DeepTDA(...)
  • if gnn_model is None: we use the feed-forward implementation as it used to be (read: if you don't know you can use GNNs, you can't mess anything up)
  • if gnn_model is a BaseGNN instance: we override under the hood the code that needs to be adapted (i.e., the blocks and the training_step); all the rest can stay the same, from what I've seen so far (only TDA)

PRO: only one base CV class (maybe with a few modifications) and only one specific CV class (with a decent amount of modifications); the user experience will not change much.
CONS: the whole mechanism may not be super clear and clean.

  2. The same as 1 but more general: we also take the feed-forward model initialization out of the CV models, and we add some signature to the different models (i.e., a model.model_type attribute that can be ff or gnn) so that the right code is used thereafter in the CV.
# for GNN-based
gnn_model = SchNet(...)
model = DeepTDA(..., model=gnn_model, ...)

# for FFNN-based
ff_model = FeedForward(...)
model = DeepTDA(..., model=ff_model, ...)

PRO: similar to 1 but more general, and maybe less confusing.
CONS: it always adds one more step to the workflow, which may sound more complicated than before (even if it's just one line).

  3. (Jintu's way) Keep two separate classes for graph-based and feed-forward-based CVs.
# for GNN-based
model = GraphDeepTDA(TDA_params..., GNN_params...)

# for FFNN-based
model = DeepTDA(TDA_params..., FFNN_params...)

PRO: no (potentially breaking) changes anywhere, and for sure a much lower activation energy for this PR.
CONS: quite a lot of redundant code, and the GNNs still require many parameters to be set.

  4. Strange things with classes of classes that @andrrizzi may suggest.
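
For illustration, a minimal sketch of the dispatch mechanism of option 1 (the one currently in the PR); FeedForward, BaseGNN, and the attribute names are simplified stand-ins for the real classes, not the final API:

class BaseGNN:  # stand-in for the (proposed) mlcolvar GNN parent class
    pass

class FeedForward:  # stand-in for the existing FFNN block
    def __init__(self, layers):
        self.layers = layers

class DeepTDA:
    """Sketch: only the model-selection logic of the CV class."""

    def __init__(self, layers=None, gnn_model=None):
        if gnn_model is None:
            # FFNN path: identical to the current behavior, so users who
            # never touch GNNs cannot mess anything up
            self.nn = FeedForward(layers)
        elif isinstance(gnn_model, BaseGNN):
            # GNN path: adopt the externally initialized model; the blocks
            # and the training_step are overridden under the hood
            self.nn = gnn_model
        else:
            raise TypeError(
                f"gnn_model must be None or a BaseGNN, got {type(gnn_model)}"
            )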

Questions

  • What do we prefer? User experience? Code conciseness? Fewer changes?

General todos

  • Check everything 😄

Status

  • Ready to go

@github-advanced-security bot left a comment

CodeQL found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

@@ -0,0 +1,135 @@
import torch
import lightning
from torch import nn
import numpy as np
import torch_geometric as tg
from typing import List, Dict, Tuple, Any

from mlcolvar.core.nn.graph.gnn import BaseGNN

from typing import List, Dict, Tuple

Code scanning / CodeQL: Unused import (Note), 5 occurrences:
  • Import of 'lightning' is not used.
  • Import of 'np' is not used.
  • Import of 'tg' is not used.
  • Import of 'Any' is not used.
  • Import of 'Tuple' is not used.
Comment on lines +186 to +196
# def __repr__(self) -> str:
#     result = ''
#     n_digits = len(str(self._n_total))
#     data_string_1 = '[ \033[32m{{:{:d}d}}\033[0m\033[36m 󰡷 \033[0m'
#     data_string_2 = '| \033[32m{{:{:d}d}}\033[0m\033[36m  \033[0m'
#     shuffle_string_1 = '|\033[36m  \033[0m ]'
#     shuffle_string_2 = '|\033[36m  \033[0m ]'

#     prefix = '\033[1m\033[34m BASEDATA \033[0m: '
#     result += (
#         prefix + self._dataset.__repr__().split('GRAPHDATASET ')[1] + '\n'

Code scanning / CodeQL: Commented-out code (Note). This comment appears to contain commented-out code.

Comment on lines +205 to +215
# if self.shuffle[0]:
#     result += shuffle_string_1
# else:
#     result += shuffle_string_2

# if self._n_validation > 0:
#     result += '\n'
#     prefix = '\033[1m\033[34m VALIDATION \033[0m: '
#     string = prefix + data_string_1.format(n_digits)
#     result += string.format(
#         self._n_validation, self._n_validation / self._n_total * 100

Code scanning / CodeQL: Commented-out code (Note). This comment appears to contain commented-out code.

Comment on lines +219 to +229
# if self.shuffle[1]:
#     result += shuffle_string_1
# else:
#     result += shuffle_string_2

# if self._n_test > 0:
#     result += '\n'
#     prefix = '\033[1m\033[34m TEST \033[0m: '
#     string = prefix + data_string_1.format(n_digits)
#     result += string.format(
#         self._n_test, self._n_test / self._n_total * 100

Code scanning / CodeQL: Commented-out code (Note). This comment appears to contain commented-out code.

Comment on lines +233 to +236
# if self.shuffle[2]:
#     result += shuffle_string_1
# else:
#     result += shuffle_string_2

Code scanning / CodeQL: Commented-out code (Note). This comment appears to contain commented-out code.
]


class GraphDataSet(list):

Code scanning / CodeQL: `__eq__` not overridden when adding attributes (Warning). The class 'GraphDataSet' does not override '__eq__', but adds the new attributes __atomic_numbers and __cutoff.
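
For reference, the warning means that two GraphDataSet instances differing only in cutoff or atomic numbers would still compare equal through list.__eq__. A minimal sketch of a possible fix, assuming the attribute names flagged above:

class GraphDataSet(list):
    def __init__(self, data, atomic_numbers, cutoff):
        super().__init__(data)
        self.__atomic_numbers = atomic_numbers
        self.__cutoff = cutoff

    def __eq__(self, other):
        # compare the added attributes as well as the list contents
        return (
            isinstance(other, GraphDataSet)
            and self.__atomic_numbers == other.__atomic_numbers
            and self.__cutoff == other.__cutoff
            and list.__eq__(self, other)
        )

    # defining __eq__ removes the default hash; keep it that way,
    # since a mutable container should not be hashable
    __hash__ = None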