From 08903e310f197e8b0f64914ace19eaad014adc12 Mon Sep 17 00:00:00 2001 From: kaminow Date: Tue, 17 Oct 2023 16:26:22 -0400 Subject: [PATCH] Homogenize Model forward pass return signature with GroupedModel. --- mtenn/model.py | 6 +++--- mtenn/tests/test_combination.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mtenn/model.py b/mtenn/model.py index e42b5c8..4817091 100644 --- a/mtenn/model.py +++ b/mtenn/model.py @@ -56,9 +56,9 @@ def forward(self, comp, *parts): energy_val = self.strategy(complex_rep, *parts_rep) if self.readout: - return self.readout(energy_val) + return self.readout(energy_val), [energy_val] else: - return energy_val + return energy_val, [energy_val] def _fix_device(self, data): ## We'll call this on everything for uniformity, but if we fix_deivec is @@ -194,7 +194,7 @@ def forward(self, input_list): flush=True, ) # First get prediction - pred = super().forward(inp) + pred, _ = super().forward(inp) pred_list.append(pred.detach()) # Get gradient per sample diff --git a/mtenn/tests/test_combination.py b/mtenn/tests/test_combination.py index 031192c..4798588 100644 --- a/mtenn/tests/test_combination.py +++ b/mtenn/tests/test_combination.py @@ -37,7 +37,7 @@ def test_mean_combination(models_and_inputs): model_test, model_ref, inp_list, target, loss_func = models_and_inputs # Ref calc - pred_list = [model_ref(X) for X in inp_list] + pred_list = [model_ref(X)[0] for X in inp_list] pred_ref = torch.stack(pred_list).mean(axis=0) loss = loss_func(pred_ref, target) loss.backward() @@ -66,7 +66,7 @@ def test_max_combination(models_and_inputs): model_test, model_ref, inp_list, target, loss_func = models_and_inputs # Ref calc - pred_list = [model_ref(X) for X in inp_list] + pred_list = [model_ref(X)[0] for X in inp_list] pred = torch.logsumexp(torch.stack(pred_list), axis=0) loss = loss_func(pred, target) loss.backward() @@ -98,7 +98,7 @@ def test_boltzmann_combination(models_and_inputs): model_test, model_ref, inp_list, target, loss_func = models_and_inputs # Ref calc - pred_list = torch.stack([model_ref(X) for X in inp_list]) + pred_list = torch.stack([model_ref(X)[0] for X in inp_list]) w = torch.exp(-pred_list - torch.logsumexp(-pred_list, axis=0)) pred_ref = torch.dot(w.flatten(), pred_list.flatten()) loss = loss_func(pred_ref, target)