Skip to content

Commit

Permalink
Update test_label_tensor.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dario-coscia authored Oct 18, 2023
1 parent 7a86242 commit 56be4ce
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tests/test_label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def test_labels():
def test_extract():
label_to_extract = ['a', 'c']
tensor = LabelTensor(data, labels)
print(tensor)
new = tensor.extract(label_to_extract)
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
Expand Down Expand Up @@ -58,7 +57,6 @@ def test_extract_order():
expected = torch.cat(
(data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
dim=1)
print(expected)
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
assert torch.all(torch.isclose(expected, new))
Expand All @@ -83,6 +81,13 @@ def test_merge2():


def test_getitem():
tensor = LabelTensor(data, labels)
tensor_view = tensor['a']

assert tensor_view.labels == ['a']
assert torch.allclose(tensor_view.flatten(), data[:, 0])

def test_getitem2():
tensor = LabelTensor(data, labels)
tensor_view = tensor[:5]

Expand All @@ -101,4 +106,4 @@ def test_slice():

tensor_view3 = tensor[:, 2]
assert tensor_view3.labels == labels[2]
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))

0 comments on commit 56be4ce

Please sign in to comment.