Added support for ensemble models #308

Merged · 12 commits · Mar 19, 2024
99 changes: 67 additions & 32 deletions tests/test_model.py
@@ -9,7 +9,7 @@
import torch
import lightning as pl
from torchmdnet import models
from torchmdnet.models.model import create_model
from torchmdnet.models.model import create_model, load_model
from torchmdnet.models import output_modules
from torchmdnet.models.utils import dtype_mapping

@@ -23,7 +23,9 @@
def test_forward(model_name, use_batch, explicit_q_s, precision):
z, pos, batch = create_example_batch()
pos = pos.to(dtype=dtype_mapping[precision])
model = create_model(load_example_args(model_name, prior_model=None, precision=precision))
model = create_model(
load_example_args(model_name, prior_model=None, precision=precision)
)
batch = batch if use_batch else None
if explicit_q_s:
model(z, pos, batch=batch, q=None, s=None)
@@ -33,10 +35,12 @@ def test_forward(model_name, use_batch, explicit_q_s, precision):

@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("output_model", output_modules.__all__)
@mark.parametrize("precision", [32,64])
@mark.parametrize("precision", [32, 64])
def test_forward_output_modules(model_name, output_model, precision):
z, pos, batch = create_example_batch()
args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision)
args = load_example_args(
model_name, remove_prior=True, output_model=output_model, precision=precision
)
model = create_model(args)
model(z, pos, batch=batch)

@@ -61,18 +65,25 @@ def test_torchscript(model_name, device):
grad_outputs=grad_outputs,
)[0]


def test_torchscript_output_modification():
model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True))
model = create_model(
load_example_args("tensornet", remove_prior=True, derivative=True)
)

class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.model = model

def forward(self, z, pos, batch):
y, neg_dy = self.model(z, pos, batch=batch)
# A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor]
return y, 2*neg_dy
return y, 2 * neg_dy

torch.jit.script(MyModel())


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("device", ["cpu", "cuda"])
def test_torchscript_dynamic_shapes(model_name, device):
@@ -84,11 +95,11 @@ def test_torchscript_dynamic_shapes(model_name, device):
model = torch.jit.script(
create_model(load_example_args(model_name, remove_prior=True, derivative=True))
).to(device=device)
#Repeat the input to make it dynamic
# Repeat the input to make it dynamic
for rep in range(0, 5):
print(rep)
zi = z.repeat_interleave(rep+1, dim=0).to(device=device)
posi = pos.repeat_interleave(rep+1, dim=0).to(device=device)
zi = z.repeat_interleave(rep + 1, dim=0).to(device=device)
posi = pos.repeat_interleave(rep + 1, dim=0).to(device=device)
batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device)
y, neg_dy = model(zi, posi, batch=batchi)
grad_outputs = [torch.ones_like(neg_dy)]
@@ -98,32 +109,35 @@ def test_torchscript_dynamic_shapes(model_name, device):
grad_outputs=grad_outputs,
)[0]

#Currently only tensornet is CUDA graph compatible

# Currently only tensornet is CUDA graph compatible
@mark.parametrize("model_name", ["tensornet"])
def test_cuda_graph_compatible(model_name):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
z, pos, batch = create_example_batch()
args = {"model": model_name,
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"check_error": False,
"static_shapes": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32 }
args = {
"model": model_name,
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"check_error": False,
"static_shapes": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32,
}
model = create_model(args).to(device="cuda")
model.eval()
z = z.to("cuda")
@@ -142,6 +156,7 @@ def test_cuda_graph_compatible(model_name):
assert torch.allclose(y, y2)
assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5)


@mark.parametrize("model_name", models.__all_models__)
def test_seed(model_name):
args = load_example_args(model_name, remove_prior=True)
@@ -153,6 +168,7 @@ def test_seed(model_name):
for p1, p2 in zip(m1.parameters(), m2.parameters()):
assert (p1 == p2).all(), "Parameters don't match although using the same seed."


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize(
"output_model",
@@ -199,7 +215,9 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
), f"Set new reference outputs for {model_name} with output model {output_model}."

# compare actual output with reference
torch.testing.assert_close(pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5)
torch.testing.assert_close(
pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5
)
if derivative:
torch.testing.assert_close(
deriv, expected[model_name][output_model]["deriv"], atol=1e-5, rtol=1e-5
@@ -218,7 +236,7 @@ def test_gradients(model_name):
remove_prior=True,
output_model=output_model,
derivative=derivative,
precision=precision
precision=precision,
)
model = create_model(args)
z, pos, batch = create_example_batch(n_atoms=5)
@@ -227,3 +245,20 @@
torch.autograd.gradcheck(
model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
)


def test_ensemble():
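# Three copies of the same checkpoint: the ensemble mean must match the
# single model exactly and the per-member standard deviation must be zero.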
ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3
model = load_model(ckpts[0])
ensemble_model = load_model(ckpts, return_std=True)
z, pos, batch = create_example_batch(n_atoms=5)

pred, deriv = model(z, pos, batch)
pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)

torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
assert y_std.shape == pred.shape
assert neg_dy_std.shape == deriv.shape
assert (y_std == 0).all()
assert (neg_dy_std == 0).all()
59 changes: 57 additions & 2 deletions torchmdnet/models/model.py
@@ -139,18 +139,25 @@ def create_model(args, prior_model=None, mean=None, std=None):
return model


def load_model(filepath, args=None, device="cpu", **kwargs):
def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
"""Load a model from a checkpoint file.

If a list of paths is given, an :py:class:`Ensemble` model is returned.

Args:
filepath (str): Path to the checkpoint file.
filepath (str or list): Path to the checkpoint file or a list of paths.
args (dict, optional): Arguments for the model. Defaults to None.
device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False.
**kwargs: Extra keyword arguments for the model.

Returns:
nn.Module: An instance of the TorchMD_Net model.
"""
if isinstance(filepath, (list, tuple)):
return Ensemble(
[load_model(f, args=args, device=device, **kwargs) for f in filepath],
return_std=return_std,
)

ckpt = torch.load(filepath, map_location="cpu")
if args is None:
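
As a usage sketch (hypothetical checkpoint paths; assuming checkpoints saved with compatible settings), the new list form of load_model can be exercised like this:

from torchmdnet.models.model import load_model

# Hypothetical paths; any TorchMD-Net checkpoints would work here.
ckpts = ["run0/best.ckpt", "run1/best.ckpt", "run2/best.ckpt"]

# A single path still returns a plain TorchMD_Net model.
single = load_model(ckpts[0])

# A list of paths returns an Ensemble; with return_std=True the forward
# pass yields (mean_y, mean_neg_dy, std_y, std_neg_dy).
ensemble = load_model(ckpts, return_std=True)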
Expand Down Expand Up @@ -426,3 +433,51 @@ def forward(
# Returning an empty tensor allows to decorate this method as always returning two tensors.
# This is required to overcome a TorchScript limitation, xref https://github.com/openmm/openmm-torch/issues/135
return y, torch.empty(0)


class Ensemble(torch.nn.ModuleList):
"""Average predictions over an ensemble of TorchMD-Net models.

This module behaves like a single TorchMD-Net model, but its forward method returns the average of the predictions over all models it was initialized with (and, optionally, their standard deviation).

Args:
modules (List[nn.Module]): List of :py:class:`TorchMD_Net` models to average predictions over.
return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If True, the forward pass returns four values (mean_y, mean_neg_dy, std_y, std_neg_dy) instead of two (mean_y, mean_neg_dy).
"""

def __init__(self, modules: List[nn.Module], return_std: bool = False):
for module in modules:
assert isinstance(module, TorchMD_Net)
super().__init__(modules)
self.return_std = return_std

def forward(
self,
z: Tensor,
pos: Tensor,
batch: Optional[Tensor] = None,
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
Collaborator: It would be interesting to make this class take *args, **kwargs for future proofing. I do not think it is compatible with TorchScript, though.

Collaborator (Author): Yes, I definitely considered that. But if you say it's not compatible, we can leave it for now, I guess.

Collaborator: Given that we are not aiming for TorchScript anymore, I would make this function variadic so that it is immune to changes in the TorchMD_Net interface. This is particularly useful for when/if #306 is merged, which changes how we deal with charges, spins and such.
):
y = []
neg_dy = []
for model in self:
res = model(
z=z, pos=pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args
)
y.append(res[0])
neg_dy.append(res[1])

y = torch.stack(y)
neg_dy = torch.stack(neg_dy)
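# Stacked tensors have shape (n_models, ...); reduce over the ensemble dimension.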
y_mean = torch.mean(y, axis=0)
neg_dy_mean = torch.mean(neg_dy, axis=0)
y_std = torch.std(y, axis=0)
neg_dy_std = torch.std(neg_dy, axis=0)

if self.return_std:
return y_mean, neg_dy_mean, y_std, neg_dy_std
else:
return y_mean, neg_dy_mean
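
A minimal sketch of the variadic forward suggested in the review above (an illustration only, not what this PR merges; it would give up TorchScript compatibility):

def forward(self, *args, **kwargs):
    # Each member returns (y, neg_dy); stack across members and reduce.
    results = [model(*args, **kwargs) for model in self]
    y = torch.stack([r[0] for r in results])
    neg_dy = torch.stack([r[1] for r in results])
    y_mean, neg_dy_mean = y.mean(dim=0), neg_dy.mean(dim=0)
    if self.return_std:
        return y_mean, neg_dy_mean, y.std(dim=0), neg_dy.std(dim=0)
    return y_mean, neg_dy_mean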
RaulPPelaez marked this conversation as resolved.