diff --git a/pina/label_tensor.py b/pina/label_tensor.py index a5fb9ba4..5645381f 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -132,20 +132,22 @@ def select(self, *args, **kwargs): return tmp def cuda(self, *args, **kwargs): - """ - Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`. - """ - tmp = super().cuda(*args, **kwargs) - tmp._labels = self._labels - return tmp + """ + Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`. + """ + tmp = super().cuda(*args, **kwargs) + new = self.__class__.clone(self) + new.data = tmp.data + return tmp def cpu(self, *args, **kwargs): - """ - Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`. - """ - tmp = super().cpu(*args, **kwargs) - tmp._labels = self._labels - return tmp + """ + Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`. + """ + tmp = super().cpu(*args, **kwargs) + new = self.__class__.clone(self) + new.data = tmp.data + return tmp def extract(self, label_to_extract): """