-
Notifications
You must be signed in to change notification settings - Fork 77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added support for ensemble models #308
Conversation
Depending on your final use of this you could return a Dict[str, Tensor] or a Tuple[Tuple[Tensor, Optional[Tensor]], Tuple[Tensor, Optional[Tensor]]] or a List[Dict[str, Optional[Tensor]]] storing mean and std for each output |
torchmdnet/models/model.py
Outdated
z: Tensor, | ||
pos: Tensor, | ||
batch: Optional[Tensor] = None, | ||
box: Optional[Tensor] = None, | ||
q: Optional[Tensor] = None, | ||
s: Optional[Tensor] = None, | ||
extra_args: Optional[Dict[str, Tensor]] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be interesting to make this class take *args,**kwargs for future proofing. I do not think it is compatible with TorchScript, tough.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I definitely considered that. But if you say it's not compatible we can leave it for now I guess
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that we are not aiming for TorchScript anymore, I would make this function variadic so that it is immune to changes in the TorchMD_Net interface.
This is particularly useful for when/if #306 is merged, which changes how we deal with charges, spins and such.
I believe your implementation is clean enough for its purposes. Stacking the model outputs is wasteful, but the code is much cleaner and adaptable than a Welford-like an online computation. Lets settle on the output and we can write a documentation entry for it and some tests. Maybe below this entry? https://torchmd-net.readthedocs.io/en/latest/models.html#loading-a-model-for-inference |
I added a sanity check: def test_ensemble():
ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3
model = load_model(ckpts[0])
ensemble_model = load_model(ckpts)
z, pos, batch = create_example_batch(n_atoms=5)
pred, deriv = model(z, pos, batch)
pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)
torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
assert y_std.shape == pred.shape
assert neg_dy_std.shape == deriv.shape
assert (y_std == 0).all()
assert (neg_dy_std == 0).all() |
Nice sanity test, thanks! Okay so we can merge it for now and we see in the future if it needs adaptation for different model types |
I added a section on this to the docs and made Ensemble variadic, just relaying the arguments to TorchMD_Net.forward. This LGTM, feel free to merge if you agree. |
Looks lovely. Many thanks Raul! |
Works fine but I don't like so much that the same function can return a model with 4 outputs (energy/force mean/std) or with 2 outputs (just energy, force) depending on if you pass a single file or multiple.
Maybe it could become an argument to load_model to disable the returning of the std?