Minor fix in labels tensor init
FilippoOlivo committed Nov 1, 2024
1 parent cbbf15d commit a0333cb
Showing 1 changed file with 55 additions and 34 deletions.
89 changes: 55 additions & 34 deletions pina/label_tensor.py
@@ -20,7 +20,9 @@ class LabelTensor(torch.Tensor):
 
     @staticmethod
     def __new__(cls, x, labels, *args, **kwargs):
+        full = kwargs.pop("full", False)
         if isinstance(x, LabelTensor):
+            x.full = full
             return x
         else:
             return super().__new__(cls, x, *args, **kwargs)
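
A minimal usage sketch of the new full flag (editor's illustration; the labels below are invented, not taken from the commit):

import torch
from pina import LabelTensor

# Label only dimension 1; with full=True the remaining dimensions are
# expected to be filled with default labels during init (see the next hunk)
x = torch.rand(4, 2)
lt = LabelTensor(x, {1: {'name': 'space', 'dof': ['u', 'v']}}, full=True)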
@@ -124,13 +126,12 @@ def _init_labels_from_dict(self, labels):
             does not match with tensor shape
         """
         tensor_shape = self.shape
-
         if hasattr(self, 'full') and self.full:
             labels = {
                 i: labels[i] if i in labels else {
-                    'name': i
+                    'name': i, 'dof': range(tensor_shape[i])
                 }
-                for i in labels.keys()
+                for i in range(len(tensor_shape))
             }
         for k, v in labels.items():
             # Init labels from str
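
With full=True, every unlabeled dimension now gets both a default 'name' and a 'dof' range; the old comprehension iterated only over labels.keys(), so its else branch was dead code and missing dimensions were never filled. The new comprehension replayed on plain data (editor's sketch):

tensor_shape = (3, 2)
labels = {1: {'name': 'space', 'dof': ['u', 'v']}}
filled = {
    i: labels[i] if i in labels else {'name': i, 'dof': range(tensor_shape[i])}
    for i in range(len(tensor_shape))
}
# filled -> {0: {'name': 0, 'dof': range(0, 3)},
#            1: {'name': 'space', 'dof': ['u', 'v']}}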
@@ -274,31 +275,34 @@ def cat(tensors, dim=0):
             return tensors[0]
         # Perform cat on tensors
         new_tensor = torch.cat(tensors, dim=dim)
-
+        new_tensor_shape = new_tensor.shape
         # Update labels
-        labels = LabelTensor.__create_labels_cat(tensors, dim)
+        labels = LabelTensor.__create_labels_cat(tensors, dim, new_tensor_shape)
 
         return LabelTensor.__internal_init__(new_tensor, labels,
                                              tensors[0].dim_names)
 
     @staticmethod
-    def __create_labels_cat(tensors, dim):
+    def __create_labels_cat(tensors, dim, tensor_shape):
         # Check if names and dof of the labels are the same in all dimensions
         # except in dim
         stored_labels = [tensor.stored_labels for tensor in tensors]
 
         # check if:
         # - labels dict have same keys
         # - all labels are the same except for dimension dim
-        if not all(
-                all(stored_labels[i][k] == stored_labels[0][k]
-                    for i in range(len(stored_labels)))
-                for k in stored_labels[0].keys() if k != dim):
+        if not all(all(stored_labels[i][k] == stored_labels[0][k]
+                       for i in range(len(stored_labels)))
+                   for k in stored_labels[0].keys() if k != dim):
             raise RuntimeError('tensors must have the same shape and dof')
 
         labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
         if dim in labels.keys():
-            last_dim_dof = [i for j in stored_labels for i in j[dim]['dof']]
+            labels_list = [j[dim]['dof'] for j in stored_labels]
+            if all(isinstance(j, range) for j in labels_list):
+                last_dim_dof = range(tensor_shape[dim])
+            else:
+                last_dim_dof = [i for j in labels_list for i in j]
            labels[dim]['dof'] = last_dim_dof
         return labels
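
When every concatenated tensor labels dim with a range, the merged dof now stays a single range over the output shape; the removed flattening also produced duplicated dof values. The branch replayed on plain data (editor's sketch):

labels_list = [range(3), range(2)]      # dofs of the two input tensors
cat_dim_size = 5                        # tensor_shape[dim] after the cat
if all(isinstance(j, range) for j in labels_list):
    last_dim_dof = range(cat_dim_size)  # range(0, 5)
else:
    last_dim_dof = [i for j in labels_list for i in j]
# the removed line would have produced [0, 1, 2, 0, 1] here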

@@ -329,7 +333,8 @@ def clone(self, *args, **kwargs):
         :return: A copy of the tensor.
         :rtype: LabelTensor
         """
-        labels = {k: copy(v) for k, v in self._labels.items()}
+        labels = {k: {sub_k: copy(sub_v) for sub_k, sub_v in v.items()}
+                  for k, v in self.stored_labels.items()}
         out = LabelTensor(super().clone(*args, **kwargs), labels)
         return out
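
copy(v) duplicated each per-dimension dict but still shared its 'dof' container with the original tensor; the nested comprehension copies one level deeper, so mutating a clone's dof list no longer leaks back. A before/after check (editor's sketch):

from copy import copy

stored_labels = {1: {'name': 'space', 'dof': ['u', 'v']}}
shallow = {k: copy(v) for k, v in stored_labels.items()}
deep = {k: {sub_k: copy(sub_v) for sub_k, sub_v in v.items()}
        for k, v in stored_labels.items()}
assert shallow[1]['dof'] is stored_labels[1]['dof']    # old: shared list
assert deep[1]['dof'] is not stored_labels[1]['dof']   # new: independent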

@@ -396,37 +401,47 @@ def vstack(label_tensors):
         """
         return LabelTensor.cat(label_tensors, dim=0)
 
+    @profile
     def __getitem__(self, index):
         """
         TODO: Complete docstring
         :param index:
         :return:
         """
-        if isinstance(index,
-                      str) or (isinstance(index, (tuple, list))
-                               and all(isinstance(a, str) for a in index)):
+        if isinstance(index, str) or (isinstance(index, (tuple, list))
+                                      and all(
+                                          isinstance(a, str) for a in index)):
             return self.extract(index)
 
-        if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
-            index = [index.nonzero().squeeze()]
         selected_lt = super().__getitem__(index)
 
+        if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
+            index = [index.nonzero().squeeze().tolist()]
+
         if isinstance(index, (int, slice)):
             index = [index]
 
         if index[0] == Ellipsis:
             index = [slice(None)] * (self.ndim - 1) + [index[1]]
 
-        if hasattr(self, "labels"):
-            labels = {k: copy(v) for k, v in self.stored_labels.items()}
+        try:
+            stored_labels = self.stored_labels
+            labels = {}
             for j, idx in enumerate(index):
 
                 if isinstance(idx, int):
                     selected_lt = selected_lt.unsqueeze(j)
-                if j in labels.keys() and idx != slice(None):
-                    self._update_single_label(labels, labels, idx, j)
+                if j in self.stored_labels.keys() and idx != slice(None):
+                    self._update_single_label(stored_labels, labels, idx, j)
+            labels.update(
+                {k: {sub_k: copy(sub_v) for sub_k, sub_v in v.items()}
+                 for k, v in stored_labels.items() if k not in labels})
             selected_lt = LabelTensor.__internal_init__(selected_lt, labels,
                                                         self.dim_names)
+        except AttributeError:
+            import warnings
+            warnings.warn('No attribute labels in LabelTensor')
         return selected_lt
 
     @staticmethod
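
Two behavioural notes on the rewritten __getitem__: boolean masks are converted to plain index lists only after the raw tensor slicing, and the hasattr guard became try/except, so indexing an unlabeled tensor still succeeds and merely warns. A usage sketch with assumed labels (editor's illustration, not code from the commit):

import torch
from pina import LabelTensor

lt = LabelTensor(torch.rand(10, 2), {1: {'name': 'vars', 'dof': ['x', 'y']}})
col = lt['x']              # string index dispatches to extract()
row = lt[3]                # int index: result is unsqueezed back to 2D
mask = torch.rand(10) > 0.5
sel = lt[mask]             # bool mask -> index list, labels updated per dim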
@@ -443,28 +458,34 @@ def _update_single_label(old_labels, to_update_labels, index, dim):
         old_dof = old_labels[dim]['dof']
         if isinstance(index, torch.Tensor) and index.ndim == 0:
             index = int(index)
-        if not isinstance(
-                index,
-                (int, slice)) and len(index) == len(old_dof) and isinstance(
-                    old_dof, range):
+        if (not isinstance(
+                index, (int, slice)) and len(index) == len(old_dof) and
+                isinstance(old_dof, range)):
             return
+
         if isinstance(index, torch.Tensor):
-            index = index.nonzero(
-                as_tuple=True
-            )[0] if index.dtype == torch.bool else index.tolist()
+            if isinstance(old_dof, range):
+                to_update_labels.update({
+                    dim: {
+                        'dof': index.tolist(),
+                        'name': old_labels[dim]['name']
+                    }
+                })
+                return
+            index = index.tolist()
         if isinstance(index, list):
             to_update_labels.update({
                 dim: {
                     'dof': [old_dof[i] for i in index],
                     'name': old_labels[dim]['name']
                 }
             })
-        else:
-            to_update_labels.update(
-                {dim: {
-                    'dof': old_dof[index],
-                    'name': old_labels[dim]['name']
-                }})
+            return
+        to_update_labels.update(
+            {dim: {
+                'dof': old_dof[index],
+                'name': old_labels[dim]['name']
+            }})
 
     def sort_labels(self, dim=None):

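The new branches short-circuit range-labelled dimensions: selecting every element of a range dof is a no-op, and a tensor index is recorded directly via tolist(). Replayed against the helper (editor's sketch; assumes the static method is called directly):

import torch
from pina import LabelTensor

old_labels = {0: {'name': 0, 'dof': range(10)}}
updated = {}
LabelTensor._update_single_label(old_labels, updated, torch.tensor([2, 5]), 0)
# updated -> {0: {'dof': [2, 5], 'name': 0}}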