diff --git a/torchmdnet/datasets/maceoff.py b/torchmdnet/datasets/maceoff.py index 0326a01d..5725e3df 100644 --- a/torchmdnet/datasets/maceoff.py +++ b/torchmdnet/datasets/maceoff.py @@ -108,6 +108,10 @@ def sample_iter(self, mol_ids=False): if self.max_gradient: if data.neg_dy.norm(dim=1).max() > float(self.max_gradient): continue + if self.pre_filter is not None and not self.pre_filter(data): + continue + if self.pre_transform is not None: + data = self.pre_transform(data) yield data def download(self):