Skip to content

Commit

Permalink
fixing cuda/cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
dario-coscia authored Oct 18, 2023
1 parent 1f0e8c7 commit c8760d2
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions pina/label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,20 +132,22 @@ def select(self, *args, **kwargs):
return tmp

def cuda(self, *args, **kwargs):
"""
Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`.
"""
tmp = super().cuda(*args, **kwargs)
tmp._labels = self._labels
return tmp
"""
Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`.
"""
tmp = super().cuda(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return tmp

def cpu(self, *args, **kwargs):
"""
Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`.
"""
tmp = super().cpu(*args, **kwargs)
tmp._labels = self._labels
return tmp
"""
Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`.
"""
tmp = super().cpu(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return tmp

def extract(self, label_to_extract):
"""
Expand Down

0 comments on commit c8760d2

Please sign in to comment.