From 4023eb0bd7cb8c16f5f8879b1b1b6050b5702fd8 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Tue, 7 Nov 2023 12:22:57 +0100 Subject: [PATCH] fixing tests --- pina/callbacks/adaptive_refinment_callbacks.py | 3 ++- pina/label_tensor.py | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pina/callbacks/adaptive_refinment_callbacks.py b/pina/callbacks/adaptive_refinment_callbacks.py index 664c8be6..03ce5838 100644 --- a/pina/callbacks/adaptive_refinment_callbacks.py +++ b/pina/callbacks/adaptive_refinment_callbacks.py @@ -61,7 +61,7 @@ def _compute_residual(self, trainer): pts.retain_grad() # PINN loss: equation evaluated only on locations where sampling is needed target = condition.equation.residual(pts, solver.forward(pts)) - res_loss[location] = torch.abs(target) + res_loss[location] = torch.abs(target).as_subclass(torch.Tensor) tot_loss.append(torch.abs(target)) return torch.vstack(tot_loss), res_loss @@ -74,6 +74,7 @@ def _r3_routine(self, trainer): """ # compute residual (all device possible) tot_loss, res_loss = self._compute_residual(trainer) + tot_loss = tot_loss.as_subclass(torch.Tensor) # !!!!!! From now everything is performed on CPU !!!!!! diff --git a/pina/label_tensor.py b/pina/label_tensor.py index d612b47b..fe9c51b2 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -118,7 +118,7 @@ def clone(self, *args, **kwargs): :return: a copy of the tensor :rtype: LabelTensor """ - # used before merging + # # used before merging # try: # out = LabelTensor(super().clone(*args, **kwargs), self.labels) # except: @@ -229,8 +229,12 @@ def __getitem__(self, index): """ Return a copy of the selected tensor. """ - selected_lt = super(Tensor, self).__getitem__(index) + if isinstance(index, str) or (isinstance(index, (tuple, list))and all(isinstance(a, str) for a in index)): + return self.extract(index) + + selected_lt = super(Tensor, self).__getitem__(index) + try: len_index = len(index) except TypeError: