Skip to content

Commit

Permalink
modified to work with dictDatasets
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoTrizio committed Nov 13, 2024
1 parent 7ff4760 commit d405e5c
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions mlcolvar/cvs/supervised/deeptda_merged.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,15 @@ def training_step(self, train_batch, *args, **kwargs) -> torch.Tensor:
return_loss_terms=True
)
elif self.gnn_model._model_type is 'gnn':
data = train_batch.to_dict()
# data = train_batch.to_dict()
data = train_batch['data_list']
data['positions'].requires_grad_(True)
data['node_attrs'].requires_grad_(True)

output = self.forward(data)

loss, loss_centers, loss_sigmas = self.loss_fn(output,
train_batch.graph_labels.squeeze(),
data["graph_labels"].squeeze(),
return_loss_terms=True
)

Expand Down

0 comments on commit d405e5c

Please sign in to comment.