Graphs_integration #161
Code scanning / CodeQL annotations on the diff:

```python
# inherit right device
device = x.device
mass = mass.to(device)
dtype = x.dtype
```
Check notice — Code scanning / CodeQL: Unused local variable.
```python
# create and edit dataset
dataset = DictDataset({"data": X, "labels": y, "weights": w})
dataset = compute_committor_weights(dataset=dataset, bias=bias, data_groups=[0, 1, 2], beta=1.0)
```
Check notice — Code scanning / CodeQL: Unused local variable.
```python
# if isinstance(model, FeedForward):
#     # self.nn = model
# elif isinstance(model, BaseGNN):
#     # GNN models need to be scripted!
#     # self.nn = torch.jit.script_if_tracing(model)
#     # self.nn = model
#     self.in_features = None
#     self.out_features = model.out_features
```
Check notice — Code scanning / CodeQL: Commented-out code.
```python
# eval
model.eval()
with torch.no_grad():
    s = model(X).numpy()
```
Check warning — Code scanning / CodeQL: Variable defined multiple times (redefined).
```python
import torch
import numpy as np

from typing import Tuple, Dict, Optional, List
```
Check notice — Code scanning / CodeQL: Unused import. Import of 'Dict' is not used.
```python
# graph data
from mlcolvar.data.graph.utils import create_test_graph_input
dataset = create_test_graph_input('dataset')
lagged_dataset = create_timelagged_dataset(dataset, logweights=torch.randn(len(dataset)))
```
Check notice — Code scanning / CodeQL: Unused local variable.
```python
torch.set_default_dtype(dtype)

rbf
```
Check notice — Code scanning / CodeQL: Statement has no effect.
```python
torch.set_default_dtype(dtype)

cutoff_function
```
Check notice — Code scanning / CodeQL: Statement has no effect.
```diff
 # ===================loss=====================
 if self.training:
     loss, loss_var, loss_bound_A, loss_bound_B = self.loss_fn(
-        x, q, labels, weights
+        x, z, q, labels, weights
```
Check failure — Code scanning / CodeQL: Potentially uninitialized local variable.
```diff
     )
 else:
     loss, loss_var, loss_bound_A, loss_bound_B = self.loss_fn(
-        x, q, labels, weights
+        x, z, q, labels, weights
```
Check failure — Code scanning / CodeQL: Potentially uninitialized local variable.
# General description
Add the code for CVs based on GNNs in the most organic way possible.
This largely inherits from Jintu's work (kudos! @jintuzhang), where all the code was based on a "library within the library".
Some functions were much different from the rest of the code (e.g., all the code for GNNs and the GraphDatamodule), others were mostly redundant (e.g., GraphDataset, CV base and specific classes).
It would be wise to reduce the code duplicates and redundancies and make the whole library more organic, while still including all the new functionalities.
SPOILER: this requires some thinking and some modifications here and there.
(We could also split the story into more PRs, in case.)

# Point-by-point description

## Data handling
Affecting --> `mlcolvar.data`, `mlcolvar.utils.io`

### Overview
So far, we have a `DictDataset` (based on `torch.Dataset`) and the corresponding `DictModule` (based on `lightning.LightningDataModule`).
For GNNs, there was a `GraphDataset` (based on lists) and the corresponding `DictModule` (based on `lightning.LightningDataModule`).
Here, the data are handled using `PyTorchGeometric` for convenience.
There are also a bunch of auxiliary functions for neighborhoods and the handling of atom types, plus some utils to initialize the dataset easily from files.
### Implemented solution
The two things are merged (a usage sketch follows at the end of this subsection):
- A single `DictDataset` that can handle both types of data.
- It has a `metadata` attribute that stores general properties (e.g., cutoff and atom_types) in a `dict`.
- In the `__init__`, the user can specify the `data_type`, either `descriptors` (default) or `graphs`. This is then stored in `metadata` and is used in the `DictLoader` to handle the data the right way (see below).
- New utils in `mlcolvar.data.utils`: `save_dataset`, `load_dataset` and `save_dataset_configurations_as_extyz`.
- A single `DictModule` that can handle both types of data. Depending on the `metadata['data_type']` of the incoming dataset, it uses either our `DictLoader` or the `torch_geometric.DataLoader`.
- A new module `data.graph` containing:
  - `atomic.py` for the handling of atomic quantities, based on the data class `Configuration`
  - `neighborhood.py` for building neighbor lists using `matscipy`
  - `utils.py` to frame `Configuration`s into datasets and one-hot embeddings. It also contains `create_test_graph_input`, as creating inputs for testing here requires several lines of code.
- A `create_dataset_from_trajectories` util in `mlcolvar.utils.io` that allows creating a dataset directly from some trajectory files, providing topology files and using `mdtraj`.
- `create_timelagged_dataset` can now also create a time-lagged dataset starting from a `DictDataset` with `data_type=='graphs'`.

NB: for the graph datasets, the keys are the original ones:
- `data_list`: all the graph data, e.g., edge src and dst, batch index... (this goes in `DictDataset`)
- `z_table`: atomic numbers map (this goes in `DictDataset.metadata`)
- `cutoff`: cutoff used in the graph (this goes in `DictDataset.metadata`)
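A minimal usage sketch, assuming the merged API described above (the names follow this description; details may differ in the actual code):

```python
import torch
from mlcolvar.data import DictDataset, DictModule

# descriptor-based dataset (data_type='descriptors' is the default)
X = torch.randn(100, 3)
y = torch.randint(0, 2, (100,))
dataset = DictDataset({"data": X, "labels": y})

# general properties (e.g., data_type, cutoff) live in the metadata dict
print(dataset.metadata)

# DictModule dispatches to DictLoader or torch_geometric.DataLoader
# based on metadata['data_type']
datamodule = DictModule(dataset, lengths=[0.8, 0.2])
```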
### Questions
## GNN models
Affecting --> `mlcolvar.core.nn`

### Overview
Of course, they need to be implemented 😄 but we can inherit most of the code from Jintu.
As an overview, there is a `BaseGNN` parent class that implements the common features, and then each model (e.g., `SchNet` or GVP) is implemented on top of that.
There is also a `radial.py` that implements a bunch of tools for radial embeddings.

### Implemented solution
The GNN code is now implemented in `mlcolvar.core.nn.graph`:
- A `BaseGNN` class that is a template for the architecture-specific code. This, for example, already has the methods for embedding edges and setting some common properties.
- A `Radial` module that implements the tools for radial embeddings.
- `SchNetModel` and `GVPModel` are implemented based on `BaseGNN`.
- In `utils.py`, there is a function that creates data for the tests of this module. This should be replaced with the very similar function `mlcolvar.data.graph.utils.create_test_graph_input`, which is more general and also used for other things.
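A hypothetical initialization sketch (the parameter names here are assumptions for illustration, not the exact signature):

```python
from mlcolvar.core.nn.graph import SchNetModel

# hypothetical arguments: output dimension, graph cutoff, atomic species
model = SchNetModel(
    n_out=1,
    cutoff=5.0,
    atomic_numbers=[1, 8],
)
```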
## CV models
Affecting --> `mlcolvar.cvs`, `mlcolvar.core.loss`

### Overview
In Jintu's implementation, all the CV classes we tested were re-implemented, still using the original loss function code.
The point, there, is that the initialization of the underlying ML model (also in the current version of the library) is performed within the CV class.
We did it to make it simple, and indeed it is for feed-forward networks, as they have very few things to set (i.e., layers, nodes, activations) and also because there were no alternatives at the time.
For GNNs, however, the initialization can vary a lot (i.e., different architectures and many parameters one could set).
I am afraid we can't cut corners here if we want to include everything, and somewhere we need to add an extra layer of complexity to either the workflow or the CV models.

### Implemented solution
We keep everything similar to what it used to be in the library, except for:
- We change the `layers` keyword to the more general `model` in the init of the CV classes, which can accept:
  1. A `layers` keyword, which initializes a `FeedForward` with that and all the `DEFAULT_BLOCKS` (see point 2), e.g., for DeepLDA: `['norm_in', 'nn', 'lda']`.
  2. A `mlcolvar.core.nn.FeedForward` or `mlcolvar.core.nn.graph.BaseGNN` model that you had initialized outside the CV class. This way, one overrides the old default, provides an external model, and uses the `MODEL_BLOCKS`, e.g., for DeepLDA: `['nn', 'lda']`. For example, the initialization can be something like the sketch after this list.
- The `BLOCKS` of each CV model are duplicated into `DEFAULT_BLOCKS` and `MODEL_BLOCKS` to account for the different behaviors. This was a simple way to initialize everything in all the cases (maybe not the best one, see questions).
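A minimal sketch of the two initialization modes, using DeepLDA as an example (assuming the `model` keyword behaves as described above; the exact signature may differ):

```python
from mlcolvar.cvs import DeepLDA
from mlcolvar.core.nn import FeedForward

# 1. pass a list of layers: a FeedForward is built internally and the
#    DEFAULT_BLOCKS ['norm_in', 'nn', 'lda'] are used
cv = DeepLDA(model=[45, 30, 15, 2], n_states=2)

# 2. pass an externally initialized model: this overrides the default and
#    the MODEL_BLOCKS ['nn', 'lda'] are used
external_nn = FeedForward(layers=[45, 30, 15, 2])
cv = DeepLDA(model=external_nn, n_states=2)
```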
### Things to note
- The `CommittorLoss`: it does not depend only on the output space but also on the derivatives wrt the input/positions.
- AutoEncoder-based CVs with GNN models raise a `NotImplementedError`, as we do not have a stable GNN-AE for now. As a consequence, the `MultiTaskCV` also does not support GNN models: for the way we intend it, it wouldn't make much sense without a GNN-based AE.

### Questions
- The duplicated `BLOCKS`? Is it worth it to keep this thing?

### TODOs
## Explain module
Affecting --> `mlcolvar.explain`

### Overview
There are some new explain functions for GNN models that we should add to the explain module.

### Possible solution
Include the GNN codes as they are, eventually with some revision, into a `mlcolvar.explain.graph` module, or even without the submodule, as there are no overlaps here, I think.

### Questions
# General todos

# General questions

# Status