Skip to content

Commit

Permalink
Make Ensemble variadic
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Mar 19, 2024
1 parent 9b0a362 commit 5a60b40
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,23 +453,18 @@ def __init__(self, modules: List[nn.Module], return_std: bool = False):

def forward(
self,
z: Tensor,
pos: Tensor,
batch: Optional[Tensor] = None,
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
*args,
**kwargs,
):
"""Average predictions over all models in the ensemble.
The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble.
"""
y = []
neg_dy = []
for model in self:
res = model(
z=z, pos=pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args
)
res = model(*args, **kwargs)
y.append(res[0])
neg_dy.append(res[1])

y = torch.stack(y)
neg_dy = torch.stack(neg_dy)
y_mean = torch.mean(y, axis=0)
Expand Down

1 comment on commit 5a60b40

@stefdoerr
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, thanks

Please sign in to comment.