Skip to content

Commit

Permalink
Label Tensor update (#188)
Browse files Browse the repository at this point in the history
* Update test_label_tensor.py
* adding test

---------

Co-authored-by: Dario Coscia <[email protected]>
  • Loading branch information
dario-coscia and Dario Coscia authored Nov 2, 2023
1 parent b2e1b3f commit f9441c0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
38 changes: 34 additions & 4 deletions pina/label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def __init__(self, x, labels):
[1.0246e-01, 9.5179e-01, 3.7043e-02],
[9.6150e-01, 8.0656e-01, 8.3824e-01]])
>>> tensor.extract('a')
tensor([[0.0671],
[0.9239],
[0.8927],
...,
[0.5819],
[0.1025],
[0.9615]])
>>> tensor['a']
tensor([[0.0671],
[0.9239],
[0.8927],
Expand Down Expand Up @@ -69,7 +77,7 @@ def __init__(self, x, labels):
'the passed labels.'
)
self._labels = labels

@property
def labels(self):
"""Property decorator for labels
Expand Down Expand Up @@ -100,7 +108,7 @@ def clone(self, *args, **kwargs):
"""
try:
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
except:
except: # this is used when the tensor loose the labels, notice it will create a bug! Kept for compatibility with Lightining
out = super().clone(*args, **kwargs)

return out
Expand All @@ -123,6 +131,24 @@ def select(self, *args, **kwargs):
tmp._labels = self._labels
return tmp

def cuda(self, *args, **kwargs):
"""
Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`.
"""
tmp = super().cuda(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return tmp

def cpu(self, *args, **kwargs):
"""
Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`.
"""
tmp = super().cpu(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return tmp

def extract(self, label_to_extract):
"""
Extract the subset of the original tensor by returning all the columns
Expand All @@ -149,7 +175,7 @@ def extract(self, label_to_extract):
except ValueError:
raise ValueError(f'`{f}` not in the labels list')

new_data = super(Tensor, self.T).__getitem__(indeces).float().T
new_data = super(Tensor, self.T).__getitem__(indeces).T
new_labels = [self.labels[idx] for idx in indeces]

extracted_tensor = new_data.as_subclass(LabelTensor)
Expand Down Expand Up @@ -196,8 +222,12 @@ def __getitem__(self, index):
"""
Return a copy of the selected tensor.
"""
selected_lt = super(Tensor, self).__getitem__(index)

if isinstance(index, str) or (isinstance(index, (tuple, list))and all(isinstance(a, str) for a in index)):
return self.extract(index)

selected_lt = super(Tensor, self).__getitem__(index)

try:
len_index = len(index)
except TypeError:
Expand Down
16 changes: 13 additions & 3 deletions tests/test_label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def test_labels():
def test_extract():
label_to_extract = ['a', 'c']
tensor = LabelTensor(data, labels)
print(tensor)
new = tensor.extract(label_to_extract)
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
Expand Down Expand Up @@ -58,7 +57,6 @@ def test_extract_order():
expected = torch.cat(
(data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
dim=1)
print(expected)
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
assert torch.all(torch.isclose(expected, new))
Expand All @@ -83,6 +81,18 @@ def test_merge2():


def test_getitem():
tensor = LabelTensor(data, labels)
tensor_view = tensor['a']

assert tensor_view.labels == ['a']
assert torch.allclose(tensor_view.flatten(), data[:, 0])

tensor_view = tensor['a', 'c']

assert tensor_view.labels == ['a', 'c']
assert torch.allclose(tensor_view, data[:, 0::2])

def test_getitem2():
tensor = LabelTensor(data, labels)
tensor_view = tensor[:5]

Expand All @@ -101,4 +111,4 @@ def test_slice():

tensor_view3 = tensor[:, 2]
assert tensor_view3.labels == labels[2]
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))

0 comments on commit f9441c0

Please sign in to comment.