Skip to content

Commit

Permalink
fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Dario Coscia committed Nov 7, 2023
1 parent 05050f4 commit 4023eb0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pina/callbacks/adaptive_refinment_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _compute_residual(self, trainer):
pts.retain_grad()
# PINN loss: equation evaluated only on locations where sampling is needed
target = condition.equation.residual(pts, solver.forward(pts))
res_loss[location] = torch.abs(target)
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
tot_loss.append(torch.abs(target))

return torch.vstack(tot_loss), res_loss
Expand All @@ -74,6 +74,7 @@ def _r3_routine(self, trainer):
"""
# compute residual (all device possible)
tot_loss, res_loss = self._compute_residual(trainer)
tot_loss = tot_loss.as_subclass(torch.Tensor)

# !!!!!! From now everything is performed on CPU !!!!!!

Expand Down
8 changes: 6 additions & 2 deletions pina/label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def clone(self, *args, **kwargs):
:return: a copy of the tensor
:rtype: LabelTensor
"""
# used before merging
# # used before merging
# try:
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
# except:
Expand Down Expand Up @@ -229,8 +229,12 @@ def __getitem__(self, index):
"""
Return a copy of the selected tensor.
"""
selected_lt = super(Tensor, self).__getitem__(index)

if isinstance(index, str) or (isinstance(index, (tuple, list))and all(isinstance(a, str) for a in index)):
return self.extract(index)

selected_lt = super(Tensor, self).__getitem__(index)

try:
len_index = len(index)
except TypeError:
Expand Down

0 comments on commit 4023eb0

Please sign in to comment.