-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GNN cvs module merge #160
GNN cvs module merge #160
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CodeQL found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
@@ -0,0 +1,135 @@ | |||
import torch | |||
import lightning |
Check notice
Code scanning / CodeQL
Unused import Note
import torch | ||
import lightning | ||
from torch import nn | ||
import numpy as np |
Check notice
Code scanning / CodeQL
Unused import Note
import lightning | ||
from torch import nn | ||
import numpy as np | ||
import torch_geometric as tg |
Check notice
Code scanning / CodeQL
Unused import Note
from torch import nn | ||
import numpy as np | ||
import torch_geometric as tg | ||
from typing import List, Dict, Tuple, Any |
Check notice
Code scanning / CodeQL
Unused import Note
|
||
from mlcolvar.core.nn.graph.gnn import BaseGNN | ||
|
||
from typing import List, Dict, Tuple |
Check notice
Code scanning / CodeQL
Unused import Note
# def __repr__(self) -> str: | ||
# result = '' | ||
# n_digits = len(str(self._n_total)) | ||
# data_string_1 = '[ \033[32m{{:{:d}d}}\033[0m\033[36m \033[0m' | ||
# data_string_2 = '| \033[32m{{:{:d}d}}\033[0m\033[36m \033[0m' | ||
# shuffle_string_1 = '|\033[36m \033[0m ]' | ||
# shuffle_string_2 = '|\033[36m \033[0m ]' | ||
|
||
# prefix = '\033[1m\033[34m BASEDATA \033[0m: ' | ||
# result += ( | ||
# prefix + self._dataset.__repr__().split('GRAPHDATASET ')[1] + '\n' |
Check notice
Code scanning / CodeQL
Commented-out code Note
# if self.shuffle[0]: | ||
# result += shuffle_string_1 | ||
# else: | ||
# result += shuffle_string_2 | ||
|
||
# if self._n_validation > 0: | ||
# result += '\n' | ||
# prefix = '\033[1m\033[34m VALIDATION \033[0m: ' | ||
# string = prefix + data_string_1.format(n_digits) | ||
# result += string.format( | ||
# self._n_validation, self._n_validation / self._n_total * 100 |
Check notice
Code scanning / CodeQL
Commented-out code Note
# if self.shuffle[1]: | ||
# result += shuffle_string_1 | ||
# else: | ||
# result += shuffle_string_2 | ||
|
||
# if self._n_test > 0: | ||
# result += '\n' | ||
# prefix = '\033[1m\033[34m TEST \033[0m: ' | ||
# string = prefix + data_string_1.format(n_digits) | ||
# result += string.format( | ||
# self._n_test, self._n_test / self._n_total * 100 |
Check notice
Code scanning / CodeQL
Commented-out code Note
# if self.shuffle[2]: | ||
# result += shuffle_string_1 | ||
# else: | ||
# result += shuffle_string_2 |
Check notice
Code scanning / CodeQL
Commented-out code Note
] | ||
|
||
|
||
class GraphDataSet(list): |
Check warning
Code scanning / CodeQL
`__eq__` not overridden when adding attributes Warning
'__eq__'
__atomic_numbers
The class 'GraphDataSet' does not override
General description
Add the code for CVs based on GNN in the most (possible) organic way.
This largely inherits from Jintu's work (kudos!), where all the code was based on a "library within the library".
Some functions were much different for the rest of the code (e.g., all the code for GNN and the GraphDatamodule), others were mostly redundant (e.g., GraphDataset, CV base and specific classes).
It would be wise to reduce the code duplicates and redundancies and make the whole library more organic, still including all the new functionalities.
SPOILER: this requires some thinking and some modifications here and there
(We could also split the story in more PR in case)
Point-by-point description
Data handling
Affecting -->
mlcolvar.data
,mlcolvar.utils.io
Overview
So far, we have a
DictDataset
(based ontorch.Dataset
) and the correspondingDictModule
(based onlightning.lightningDataModule
).For GNNs, there was a
GraphDataset
(based on lists) and the correspondingDictModule
(based onlightning.lightningDataModule
).Here, the data are handled using the
PyTorchGeometric
for convenience.There are also a bunch of auxiliary functions for neighborhoods and handling of atom types, plus some utils to initialize the dataset easily from files.
Possible solution
Merge the two dataset classes into
DictDataset
as they basically do the same thing.Add a
metadata
attribute to the dataset to store quantities that are not data (e.g., cutoff and atom_types).Add some exceptions for the torch-geometric-like data (still easy thanks to the dictionary structure, i.e., they will always be found with the same key).
All the other classes/functions that are different will go in a
mlcolvar.data.graph
module (almost) as they are.Questions
GNN models
Affecting -->
mlcolvar.core.nn
Overview
Of course, they need to be implemented 😄 but we can inherit most of the code from Jintu.
As an overview, there is a
BaseGNN
parent class that implements the common features, and then each model (e.g.,SchNet
or GVP) is implemented on top of that.There is also a
radial.py
that implements a bunch of tools for radial embeddings.Possible solution
Include the GNN codes as they are, eventually with some revision, into a
mlcolvar.core.nn.gnn
moduleQuestions
CV model
Affecting -->
mlcolvar.cvs
Overview
In Jintu's implementation, all the CV classes we tested were re-implemented, still using the original loss function code.
The point, there, is that the initialization of the underlying ML model (also in the current version of the library) is performed within the CV class.
We did it to make it simple, and indeed, it is for feed-forward networks, as they have very few things to set (i.e., layers, node, activations) and also because there were no alternatives at the time.
For GNNs, however, the initialization can vary a lot (i.e., different architectures and many parameters one could set).
I am afraid we can't cut corners here if we want to include everything and somewhere we need to add an extra layer of complexity to wither the workflow or the CV models,
Possible solution(s)
gnn_model=None
keyword to the init of the CV classes that can use GNNs (plus some error messages and checks here and there) so that you can pass GNN model that you had initialized outside the CV classgnn_model is None
: we use the feedforward implementation as it used to be (read: if you don't know you can use GNN, you don't mess up anything)gnn_model is BaseGNN
: we override under the hood the code that needs to be adapted (i.e., theblocks
and thetraining_step
), all the rest can be the same for what I've seen so far (only TDA)PRO: Only one Base CV class (maybe with a few modifications, only one specific CV class (with a decent amount of modifications), the user experience will not change much.
CONS: the whole mechanism may not be super clear and clean
model.model_type
that can beff
orgnn
) so that we can use the right code thereafter in the CV.PRO: similar to 1 but more general and maybe less confusing than 1
CONS: always adds one more step to the workflow it may sound more complicated than before (even if it's just one line)
PRO: no (eventually breaking) changes anywhere and for sure much lower activation energy for this PR.
CONS: quite redundant code, still the GNN require a lot of parameters to be set
Questions
General todos
General questions
Status