Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for ensemble models #308

Merged
merged 12 commits into from
Mar 19, 2024
83 changes: 52 additions & 31 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,38 +83,9 @@ This is a minimal example of a custom training loop:
optimizer.step()




Loading a model for inference
-----------------------------

Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example.

.. code:: python

import torch
from torchmdnet.models.model import load_model
checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt"
model = load_model(checkpoint, derivative=True)
# An arbitrary set of inputs for the model
n_atoms = 10
zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
z = zs[torch.randint(0, len(zs), (n_atoms,))]
pos = torch.randn(len(z), 3)
batch = torch.zeros(len(z), dtype=torch.long)

y, neg_dy = model(z, pos, batch)

.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference.

.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case.

.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.


.. _delta-learning:
Training on relative energies
-----------------------------
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

It might be useful to train the model on relative energies but then make the model produce total energies when running inference.
TorchMD-Net supports delta training via the :code:`remove_ref_energy` option. Passing this option when training (either via the :ref:`configuration-file` or using the :ref:`torchmd-train` command line interface) will subtract the reference energy from each atom in a sample before passing it to the model.
Expand All @@ -126,7 +97,7 @@ If :code:`remove_ref_energy` is turned on, the reference energy is stored in the
.. note:: The reference energies are stored as an :py:mod:`torchmdnet.priors.Atomref` prior with :code:`enable=False`.

Example
~~~~~~~
********

First we train a model with the :code:`remove_ref_energy` option turned on:

Expand All @@ -151,6 +122,56 @@ Then we load the model for inference:
batch = torch.zeros(len(z), dtype=torch.long)

y, neg_dy = model(z, pos, batch)


Loading a model for inference
-----------------------------

Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example.

.. code:: python

import torch
from torchmdnet.models.model import load_model
checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt"
model = load_model(checkpoint, derivative=True)
# An arbitrary set of inputs for the model
n_atoms = 10
zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
z = zs[torch.randint(0, len(zs), (n_atoms,))]
pos = torch.randn(len(z), 3)
batch = torch.zeros(len(z), dtype=torch.long)

y, neg_dy = model(z, pos, batch)

.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference.

.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case.

.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.


Model Ensembles
---------------
It is possible to create an ensemble of models by loading multiple checkpoints and averaging their predictions. The following example shows how to do this:

.. code:: python

import torch
from torchmdnet.models.model import load_model
checkpoints = ["/path/to/checkpoint/my_checkpoint1.ckpt", "/path/to/checkpoint/my_checkpoint2.ckpt"]
model_ensemble = load_model(checkpoints, return_std=True)
y_ensemble, neg_dy_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)


.. note:: :py:mod:`torchmdnet.models.model.load_model` will return an instance of :py:mod:`torchmdnet.models.model.Ensemble` if a list of checkpoints is passed. The :code:`return_std` option can be used to return the standard deviation of the predictions.



.. autoclass:: torchmdnet.models.model.Ensemble
:noindex:





Expand Down
57 changes: 45 additions & 12 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import lightning as pl
from torchmdnet import models
from torchmdnet.models.model import create_model
from torchmdnet.models.model import create_model, load_model
from torchmdnet.models import output_modules
from torchmdnet.models.utils import dtype_mapping

Expand All @@ -23,7 +23,9 @@
def test_forward(model_name, use_batch, explicit_q_s, precision):
z, pos, batch = create_example_batch()
pos = pos.to(dtype=dtype_mapping[precision])
model = create_model(load_example_args(model_name, prior_model=None, precision=precision))
model = create_model(
load_example_args(model_name, prior_model=None, precision=precision)
)
batch = batch if use_batch else None
if explicit_q_s:
model(z, pos, batch=batch, q=None, s=None)
Expand All @@ -33,10 +35,12 @@ def test_forward(model_name, use_batch, explicit_q_s, precision):

@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("output_model", output_modules.__all__)
@mark.parametrize("precision", [32,64])
@mark.parametrize("precision", [32, 64])
def test_forward_output_modules(model_name, output_model, precision):
z, pos, batch = create_example_batch()
args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision)
args = load_example_args(
model_name, remove_prior=True, output_model=output_model, precision=precision
)
model = create_model(args)
model(z, pos, batch=batch)

Expand All @@ -61,18 +65,25 @@ def test_torchscript(model_name, device):
grad_outputs=grad_outputs,
)[0]


def test_torchscript_output_modification():
model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True))
model = create_model(
load_example_args("tensornet", remove_prior=True, derivative=True)
)

class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.model = model

def forward(self, z, pos, batch):
y, neg_dy = self.model(z, pos, batch=batch)
# A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor]
return y, 2*neg_dy
return y, 2 * neg_dy

torch.jit.script(MyModel())


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("device", ["cpu", "cuda"])
def test_torchscript_dynamic_shapes(model_name, device):
Expand All @@ -84,11 +95,11 @@ def test_torchscript_dynamic_shapes(model_name, device):
model = torch.jit.script(
create_model(load_example_args(model_name, remove_prior=True, derivative=True))
).to(device=device)
#Repeat the input to make it dynamic
# Repeat the input to make it dynamic
for rep in range(0, 5):
print(rep)
zi = z.repeat_interleave(rep+1, dim=0).to(device=device)
posi = pos.repeat_interleave(rep+1, dim=0).to(device=device)
zi = z.repeat_interleave(rep + 1, dim=0).to(device=device)
posi = pos.repeat_interleave(rep + 1, dim=0).to(device=device)
batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device)
y, neg_dy = model(zi, posi, batch=batchi)
grad_outputs = [torch.ones_like(neg_dy)]
Expand All @@ -98,7 +109,8 @@ def test_torchscript_dynamic_shapes(model_name, device):
grad_outputs=grad_outputs,
)[0]

#Currently only tensornet is CUDA graph compatible

# Currently only tensornet is CUDA graph compatible
@mark.parametrize("model_name", ["tensornet"])
def test_cuda_graph_compatible(model_name):
if not torch.cuda.is_available():
Expand Down Expand Up @@ -142,6 +154,7 @@ def test_cuda_graph_compatible(model_name):
assert torch.allclose(y, y2)
assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5)


@mark.parametrize("model_name", models.__all_models__)
def test_seed(model_name):
args = load_example_args(model_name, remove_prior=True)
Expand All @@ -153,6 +166,7 @@ def test_seed(model_name):
for p1, p2 in zip(m1.parameters(), m2.parameters()):
assert (p1 == p2).all(), "Parameters don't match although using the same seed."


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize(
"output_model",
Expand Down Expand Up @@ -199,7 +213,9 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
), f"Set new reference outputs for {model_name} with output model {output_model}."

# compare actual ouput with reference
torch.testing.assert_close(pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5)
torch.testing.assert_close(
pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5
)
if derivative:
torch.testing.assert_close(
deriv, expected[model_name][output_model]["deriv"], atol=1e-5, rtol=1e-5
Expand All @@ -218,7 +234,7 @@ def test_gradients(model_name):
remove_prior=True,
output_model=output_model,
derivative=derivative,
precision=precision
precision=precision,
)
model = create_model(args)
z, pos, batch = create_example_batch(n_atoms=5)
Expand All @@ -227,3 +243,20 @@ def test_gradients(model_name):
torch.autograd.gradcheck(
model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
)


def test_ensemble():
ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3
model = load_model(ckpts[0])
ensemble_model = load_model(ckpts, return_std=True)
z, pos, batch = create_example_batch(n_atoms=5)

pred, deriv = model(z, pos, batch)
pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)

torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
assert y_std.shape == pred.shape
assert neg_dy_std.shape == deriv.shape
assert (y_std == 0).all()
assert (neg_dy_std == 0).all()
97 changes: 78 additions & 19 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,18 +139,25 @@ def create_model(args, prior_model=None, mean=None, std=None):
return model


def load_model(filepath, args=None, device="cpu", **kwargs):
def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
"""Load a model from a checkpoint file.

If a list of paths is given, an :py:mod:`Ensemble` model is returned.
Args:
filepath (str): Path to the checkpoint file.
filepath (str or list): Path to the checkpoint file or a list of paths.
args (dict, optional): Arguments for the model. Defaults to None.
device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False.
**kwargs: Extra keyword arguments for the model.

Returns:
nn.Module: An instance of the TorchMD_Net model.
"""
if isinstance(filepath, (list, tuple)):
return Ensemble(
[load_model(f, args=args, device=device, **kwargs) for f in filepath],
return_std=return_std,
)

ckpt = torch.load(filepath, map_location="cpu")
if args is None:
Expand Down Expand Up @@ -187,29 +194,32 @@ def create_prior_models(args, dataset=None):

1. A single prior model name and its arguments as a dictionary:

```python
args = {
"prior_model": "Atomref",
"prior_args": {"max_z": 100}
}
```
.. code:: python

args = {
"prior_model": "Atomref",
"prior_args": {"max_z": 100}
}


2. A list of prior model names and their arguments as a list of dictionaries:

```python
.. code:: python

args = {
"prior_model": ["Atomref", "D2"],
"prior_args": [{"max_z": 100}, {"max_z": 100}]
}

args = {
"prior_model": ["Atomref", "D2"],
"prior_args": [{"max_z": 100}, {"max_z": 100}]
}
```

3. A list of prior model names and their arguments as a dictionary:

```python
args = {
"prior_model": [{"Atomref": {"max_z": 100}}, {"D2": {"max_z": 100}}]
}
```
.. code:: python

args = {
"prior_model": [{"Atomref": {"max_z": 100}}, {"D2": {"max_z": 100}}]
}


Args:
args (dict): Arguments for the model.
Expand Down Expand Up @@ -426,3 +436,52 @@ def forward(
# Returning an empty tensor allows to decorate this method as always returning two tensors.
# This is required to overcome a TorchScript limitation, xref https://github.com/openmm/openmm-torch/issues/135
return y, torch.empty(0)


class Ensemble(torch.nn.ModuleList):
"""Average predictions over an ensemble of TorchMD-Net models.

This module behaves like a single TorchMD-Net model, but its forward method returns the average and standard deviation of the predictions over all models it was initialized with.

Args:
modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over.
return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If set to True, the model returns 4 arguments (mean_y, mean_neg_dy, std_y, std_neg_dy) instead of 2 (mean_y, mean_neg_dy).
"""

def __init__(self, modules: List[nn.Module], return_std: bool = False):
for module in modules:
assert isinstance(module, TorchMD_Net)
super().__init__(modules)
self.return_std = return_std

def forward(
self,
*args,
**kwargs,
):
"""Average predictions over all models in the ensemble.
The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble.
Args:
*args: Positional arguments to forward to the models.
**kwargs: Keyword arguments to forward to the models.
Returns:
Tuple[Tensor, Optional[Tensor]] or Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: The average and standard deviation of the predictions over all models in the ensemble. If return_std is False, the output is a tuple (mean_y, mean_neg_dy). If return_std is True, the output is a tuple (mean_y, mean_neg_dy, std_y, std_neg_dy).

"""
y = []
neg_dy = []
for model in self:
res = model(*args, **kwargs)
y.append(res[0])
neg_dy.append(res[1])
y = torch.stack(y)
neg_dy = torch.stack(neg_dy)
y_mean = torch.mean(y, axis=0)
neg_dy_mean = torch.mean(neg_dy, axis=0)
y_std = torch.std(y, axis=0)
neg_dy_std = torch.std(neg_dy, axis=0)

if self.return_std:
return y_mean, neg_dy_mean, y_std, neg_dy_std
else:
return y_mean, neg_dy_mean
RaulPPelaez marked this conversation as resolved.
Show resolved Hide resolved
Loading