From 8bff3e9cc6a257862a7f18ffbc69308d60b4ac94 Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Wed, 18 Oct 2023 19:00:56 +0200 Subject: [PATCH] fixing cuda/cpu --- pina/label_tensor.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index a5fb9ba4..5645381f 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -132,20 +132,22 @@ def select(self, *args, **kwargs): return tmp def cuda(self, *args, **kwargs): - """ - Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`. - """ - tmp = super().cuda(*args, **kwargs) - tmp._labels = self._labels - return tmp + """ + Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`. + """ + tmp = super().cuda(*args, **kwargs) + new = self.__class__.clone(self) + new.data = tmp.data + return tmp def cpu(self, *args, **kwargs): - """ - Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`. - """ - tmp = super().cpu(*args, **kwargs) - tmp._labels = self._labels - return tmp + """ + Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`. + """ + tmp = super().cpu(*args, **kwargs) + new = self.__class__.clone(self) + new.data = tmp.data + return tmp def extract(self, label_to_extract): """