From a0333cb4d6bdbd89be6889a86ddb06750738477d Mon Sep 17 00:00:00 2001
From: FilippoOlivo
Date: Thu, 31 Oct 2024 17:21:14 +0100
Subject: [PATCH] Minor fix in labels tensor init

---
 pina/label_tensor.py | 88 ++++++++++++++++++++++++++++----------------
 1 file changed, 54 insertions(+), 34 deletions(-)

diff --git a/pina/label_tensor.py b/pina/label_tensor.py
index f08f43fb..839c23e8 100644
--- a/pina/label_tensor.py
+++ b/pina/label_tensor.py
@@ -20,7 +20,9 @@ class LabelTensor(torch.Tensor):
 
     @staticmethod
     def __new__(cls, x, labels, *args, **kwargs):
+        full = kwargs.pop("full", False)
         if isinstance(x, LabelTensor):
+            x.full = full
             return x
         else:
             return super().__new__(cls, x, *args, **kwargs)
@@ -124,13 +126,12 @@ def _init_labels_from_dict(self, labels):
             does not match with tensor shape
         """
         tensor_shape = self.shape
-
         if hasattr(self, 'full') and self.full:
             labels = {
                 i: labels[i] if i in labels else {
-                    'name': i
+                    'name': i, 'dof': range(tensor_shape[i])
                 }
-                for i in labels.keys()
+                for i in range(len(tensor_shape))
             }
         for k, v in labels.items():
             # Init labels from str
@@ -274,15 +275,15 @@ def cat(tensors, dim=0):
             return tensors[0]
         # Perform cat on tensors
         new_tensor = torch.cat(tensors, dim=dim)
-
+        new_tensor_shape = new_tensor.shape
         # Update labels
-        labels = LabelTensor.__create_labels_cat(tensors, dim)
+        labels = LabelTensor.__create_labels_cat(tensors, dim, new_tensor_shape)
 
         return LabelTensor.__internal_init__(new_tensor, labels,
                                              tensors[0].dim_names)
 
     @staticmethod
-    def __create_labels_cat(tensors, dim):
+    def __create_labels_cat(tensors, dim, tensor_shape):
         # Check if names and dof of the labels are the same in all dimensions
         # except in dim
         stored_labels = [tensor.stored_labels for tensor in tensors]
@@ -290,15 +291,18 @@ def __create_labels_cat(tensors, dim):
 
         # check if:
         #  - labels dict have same keys
         #  - all labels are the same except for dimension dim
-        if not all(
-                all(stored_labels[i][k] == stored_labels[0][k]
-                    for i in range(len(stored_labels)))
-                for k in stored_labels[0].keys() if k != dim):
+        if not all(all(stored_labels[i][k] == stored_labels[0][k]
+                       for i in range(len(stored_labels)))
+                   for k in stored_labels[0].keys() if k != dim):
             raise RuntimeError('tensors must have the same shape and dof')
 
         labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
         if dim in labels.keys():
-            last_dim_dof = [i for j in stored_labels for i in j[dim]['dof']]
+            labels_list = [j[dim]['dof'] for j in stored_labels]
+            if all(isinstance(j, range) for j in labels_list):
+                last_dim_dof = range(tensor_shape[dim])
+            else:
+                last_dim_dof = [i for j in labels_list for i in j]
             labels[dim]['dof'] = last_dim_dof
         return labels
@@ -329,7 +333,8 @@ def clone(self, *args, **kwargs):
         :return: A copy of the tensor.
         :rtype: LabelTensor
         """
-        labels = {k: copy(v) for k, v in self._labels.items()}
+        labels = {k: {sub_k: copy(sub_v) for sub_k, sub_v in v.items()} for k, v
+                  in self.stored_labels.items()}
         out = LabelTensor(super().clone(*args, **kwargs), labels)
         return out
 
@@ -396,21 +401,22 @@ def vstack(label_tensors):
         """
         return LabelTensor.cat(label_tensors, dim=0)
 
     def __getitem__(self, index):
         """
         TODO: Complete docstring
         :param index:
         :return:
         """
-        if isinstance(index,
-                      str) or (isinstance(index, (tuple, list))
-                               and all(isinstance(a, str) for a in index)):
+        if isinstance(index, str) or (isinstance(index, (tuple, list))
+                                      and all(
+                                          isinstance(a, str) for a in index)):
             return self.extract(index)
+        if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
+            index = [index.nonzero().squeeze()]
         selected_lt = super().__getitem__(index)
 
-        if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
-            index = [index.nonzero().squeeze().tolist()]
+
         if isinstance(index, (int, slice)):
             index = [index]
 
         if isinstance(index, tuple):
@@ -418,15 +424,23 @@ def __getitem__(self, index):
             if index[0] == Ellipsis:
                 index = [slice(None)] * (self.ndim - 1) + [index[1]]
 
-        if hasattr(self, "labels"):
-            labels = {k: copy(v) for k, v in self.stored_labels.items()}
+        try:
+            stored_labels = self.stored_labels
+            labels = {}
             for j, idx in enumerate(index):
+
                 if isinstance(idx, int):
                     selected_lt = selected_lt.unsqueeze(j)
-                if j in labels.keys() and idx != slice(None):
-                    self._update_single_label(labels, labels, idx, j)
+                if j in self.stored_labels.keys() and idx != slice(None):
+                    self._update_single_label(stored_labels, labels, idx, j)
+            labels.update(
+                {k: {sub_k: copy(sub_v) for sub_k, sub_v in v.items()} for k, v
+                 in stored_labels.items() if k not in labels})
             selected_lt = LabelTensor.__internal_init__(selected_lt, labels,
                                                         self.dim_names)
+        except AttributeError:
+            import warnings
+            warnings.warn('No attribute labels in LabelTensor')
         return selected_lt
 
     @staticmethod
@@ -443,27 +457,33 @@ def _update_single_label(old_labels, to_update_labels, index, dim):
         old_dof = old_labels[dim]['dof']
         if isinstance(index, torch.Tensor) and index.ndim == 0:
             index = int(index)
-        if not isinstance(
-                index,
-                (int, slice)) and len(index) == len(old_dof) and isinstance(
-                    old_dof, range):
+        if (not isinstance(
+                index, (int, slice)) and len(index) == len(old_dof) and
+                isinstance(old_dof, range)):
             return
+
         if isinstance(index, torch.Tensor):
-            index = index.nonzero(
-                as_tuple=True
-            )[0] if index.dtype == torch.bool else index.tolist()
+            if isinstance(old_dof, range):
+                to_update_labels.update({
+                    dim: {
+                        'dof': index.tolist(),
+                        'name': old_labels[dim]['name']
+                    }
+                })
+                return
+            index = index.tolist()
         if isinstance(index, list):
             to_update_labels.update({
                 dim: {
                     'dof': [old_dof[i] for i in index],
                     'name': old_labels[dim]['name']
                 }
            })
-        else:
-            to_update_labels.update(
-                {dim: {
-                    'dof': old_dof[index],
-                    'name': old_labels[dim]['name']
-                }})
+            return
+        to_update_labels.update(
+            {dim: {
+                'dof': old_dof[index],
+                'name': old_labels[dim]['name']
+            }})
 
     def sort_labels(self, dim=None):
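
Below are a few usage sketches, appended for review purposes only; they are not part of the patch. They assume the dict-label format {dim: {'name': ..., 'dof': ...}} that _init_labels_from_dict handles above; all tensor values and label names ('outputs', 'samples', 'points') are invented for illustration.

First, the full=True initialization path: with the patched comprehension iterating over range(len(tensor_shape)), dimensions missing from the labels dict should be filled with a default name and a range dof instead of being dropped. This assumes the constructor forwards the full keyword, as the patched __new__ suggests:

    import torch
    from pina import LabelTensor

    x = torch.rand(4, 3)
    # Only dim 1 is labelled explicitly; with full=True, dim 0 is
    # expected to be auto-filled with {'name': 0, 'dof': range(4)}.
    lt = LabelTensor(x, {1: {'name': 'outputs', 'dof': ['u', 'v', 'p']}},
                     full=True)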
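
Second, the range-aware dof merge in __create_labels_cat: when every input labels the concatenation dimension with a range, the patched code rebuilds the merged dof as a single range over the output shape, whereas the removed list comprehension flattened the per-tensor ranges into one list:

    a = LabelTensor(torch.rand(2, 3),
                    {0: {'name': 'samples', 'dof': range(2)},
                     1: {'name': 'outputs', 'dof': ['u', 'v', 'p']}})
    b = LabelTensor(torch.rand(4, 3),
                    {0: {'name': 'samples', 'dof': range(4)},
                     1: {'name': 'outputs', 'dof': ['u', 'v', 'p']}})
    c = LabelTensor.cat([a, b], dim=0)
    # Both dim-0 dofs are ranges, so the merged dof becomes
    # range(c.shape[0]) == range(6); the old code produced the
    # duplicated list [0, 1, 0, 1, 2, 3] instead.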
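
Finally, boolean-mask indexing: __getitem__ now normalizes the mask to integer indices before calling super().__getitem__, and _update_single_label should then relabel a range dof with the explicit list of retained indices. This sketch leans on parts of LabelTensor not shown in the diff (stored_labels, __internal_init__), so treat it as expected behavior rather than a verified result:

    lt = LabelTensor(torch.rand(5, 2),
                     {0: {'name': 'points', 'dof': range(5)},
                      1: {'name': 'outputs', 'dof': ['u', 'v']}})
    mask = torch.tensor([True, False, True, False, True])
    sub = lt[mask]
    # The mask is normalized to [tensor([0, 2, 4])], so sub has shape
    # (3, 2) and its dim-0 label should carry 'dof': [0, 2, 4].
    # String indexing, e.g. lt['u'], still routes through extract().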