Added support for ensemble models #308

Merged
merged 12 commits into from
Mar 19, 2024

Conversation

stefdoerr
Collaborator

Works fine, but I don't much like that the same function can return a model with 4 outputs (energy/force mean/std) or with 2 outputs (just energy and force), depending on whether you pass a single file or multiple.
Maybe load_model could take an argument to disable returning the std?
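
A minimal sketch of what such a switch could look like. The `return_std` flag and the `combine` helper are hypothetical names used only for illustration, not part of the torchmd-net API. Plain Python floats stand in for tensors, and the spread is the population standard deviation, which may not be the estimator the PR uses.

```python
from statistics import fmean, pstdev
from typing import Sequence


def combine(
    energies: Sequence[float],
    forces: Sequence[Sequence[float]],
    return_std: bool = False,  # hypothetical flag, not an existing load_model argument
):
    """Average per-model predictions and optionally return their spread.

    energies: one scalar energy per ensemble member.
    forces: one flat list of force components per ensemble member.
    """
    energy_mean = fmean(energies)
    # zip(*forces) groups the same force component across all members.
    force_mean = [fmean(component) for component in zip(*forces)]
    if not return_std:
        # Same two outputs a single model would give.
        return energy_mean, force_mean
    energy_std = pstdev(energies)
    force_std = [pstdev(component) for component in zip(*forces)]
    return energy_mean, force_mean, energy_std, force_std
```

With the flag left at its default, the caller always unpacks two values regardless of how many checkpoints were loaded. The four-value form is opt-in.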

@stefdoerr stefdoerr requested a review from RaulPPelaez March 18, 2024 12:14
@stefdoerr stefdoerr self-assigned this Mar 18, 2024
@RaulPPelaez
Collaborator

Depending on your final use of this, you could return a Dict[str, Tensor], a Tuple[Tuple[Tensor, Optional[Tensor]], Tuple[Tensor, Optional[Tensor]]], or a List[Dict[str, Optional[Tensor]]] storing the mean and std for each output.
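
As a rough illustration of the last of these shapes, here is a sketch in which every output gets one dict and the std is simply `None` when only one model is present. Floats stand in for `Tensor`, and the `ensemble_outputs` name and the `"mean"`/`"std"` keys are assumptions made up for this example.

```python
from statistics import fmean, pstdev
from typing import Dict, List, Optional, Sequence


def ensemble_outputs(
    per_model: Sequence[Sequence[float]],
) -> List[Dict[str, Optional[float]]]:
    """per_model[i] holds the outputs of model i, e.g. [energy, force component]."""
    single = len(per_model) == 1
    results: List[Dict[str, Optional[float]]] = []
    # zip(*per_model) walks over outputs, collecting one value per model.
    for values in zip(*per_model):
        results.append(
            {
                "mean": fmean(values),
                # A single model has no spread to report.
                "std": None if single else pstdev(values),
            }
        )
    return results
```

The number and structure of the returned entries is the same for one checkpoint or many, so calling code does not need to branch on the ensemble size.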

Comment on lines 443 to 449
z: Tensor,
pos: Tensor,
batch: Optional[Tensor] = None,
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
Collaborator

It would be interesting to make this class take *args, **kwargs for future-proofing. I do not think it is compatible with TorchScript, though.

Collaborator Author

Yes, I definitely considered that. But if you say it's not compatible, we can leave it for now, I guess.

Collaborator

Given that we are not aiming for TorchScript anymore, I would make this function variadic so that it is immune to changes in the TorchMD_Net interface.
This is particularly useful for when/if #306 is merged, which changes how we deal with charges, spins and such.
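
A minimal sketch of the variadic idea, assuming nothing about the real class beyond what is said above. `EnsembleSketch` and `toy_model` are hypothetical stand-ins written in plain Python, not the PR's `Ensemble` or `TorchMD_Net`.

```python
from typing import Any, Callable, List, Sequence


class EnsembleSketch:
    """Relays whatever arguments it receives to every member model."""

    def __init__(self, models: Sequence[Callable[..., Any]]):
        self.models = list(models)

    def __call__(self, *args: Any, **kwargs: Any) -> List[Any]:
        # No argument names are hard-coded here, so a change in the members'
        # signature (e.g. new optional inputs) needs no change in this wrapper.
        return [model(*args, **kwargs) for model in self.models]


def toy_model(z, pos, batch=None, q=None):
    """Stand-in for a member model; returns a fake energy and echoes q."""
    return sum(z) + sum(pos), q


ensemble = EnsembleSketch([toy_model, toy_model])
outputs = ensemble([1, 2], [0.5, 0.5], q=3)  # q is passed through untouched
```

The trade-off is that the wrapper no longer documents or validates its inputs itself. Any mistake in the arguments surfaces inside the member model's call.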

@RaulPPelaez
Collaborator

I believe your implementation is clean enough for its purposes. Stacking the model outputs is wasteful, but the code is much cleaner and more adaptable than a Welford-like online computation.
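
For reference, the two approaches being compared can be sketched on scalar values as below. This is an illustration only: the function names are made up, floats stand in for tensors, and both use the population standard deviation, which is an assumption rather than what the PR necessarily computes.

```python
from statistics import fmean, pstdev
from typing import Iterable, Tuple


def stacked_mean_std(values: Iterable[float]) -> Tuple[float, float]:
    """Keep every prediction in memory, then reduce once."""
    stacked = list(values)
    return fmean(stacked), pstdev(stacked)


def welford_mean_std(values: Iterable[float]) -> Tuple[float, float]:
    """Welford's online update: a single pass with constant memory."""
    count, mean, m2 = 0, 0.0, 0.0
    for x in values:
        count += 1
        delta = x - mean
        mean += delta / count
        # m2 accumulates the sum of squared deviations from the running mean.
        m2 += delta * (x - mean)
    if count == 0:
        raise ValueError("need at least one value")
    return mean, (m2 / count) ** 0.5
```

The stacked version holds all member outputs at once, while the online version needs only three running quantities per output at the cost of a less obvious update rule.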

Let's settle on the output, and then we can write a documentation entry for it and some tests. Maybe below this entry? https://torchmd-net.readthedocs.io/en/latest/models.html#loading-a-model-for-inference

@RaulPPelaez
Collaborator

I added a sanity check:

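from os.path import dirname, join

import torch
from torchmdnet.models.model import load_model
from utils import create_example_batch  # assumed to be the helper in the repo's tests/utils.py


def test_ensemble():
    # Load the same checkpoint three times, so all ensemble members are identical.
    ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3
    model = load_model(ckpts[0])
    ensemble_model = load_model(ckpts)
    z, pos, batch = create_example_batch(n_atoms=5)

    pred, deriv = model(z, pos, batch)
    pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)

    # The ensemble mean must match the single model.
    torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
    torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
    assert y_std.shape == pred.shape
    assert neg_dy_std.shape == deriv.shape
    # Identical members must give zero spread.
    assert (y_std == 0).all()
    assert (neg_dy_std == 0).all()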

@stefdoerr
Collaborator Author

Nice sanity test, thanks! Okay, so we can merge it for now and see in the future whether it needs adapting for different model types.

@RaulPPelaez
Collaborator

I added a section on this to the docs and made Ensemble variadic, just relaying the arguments to TorchMD_Net.forward.

This LGTM, feel free to merge if you agree.

@stefdoerr
Collaborator Author

Looks lovely. Many thanks Raul!

@stefdoerr stefdoerr merged commit 8a1be71 into main Mar 19, 2024
2 checks passed