diff --git a/pina/__init__.py b/pina/__init__.py
index 3bc28ae6..8d56f235 100644
--- a/pina/__init__.py
+++ b/pina/__init__.py
@@ -1,17 +1,17 @@
 __all__ = [
-    "Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
-    "PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
+    "Trainer", "LabelTensor", "Plotter", "Condition",
+    "PinaDataModule", "TorchOptimizer", "Graph", "LabelParameter"
 ]
 
 from .meta import *
-from .label_tensor import LabelTensor
+from .label_tensor import LabelTensor, LabelParameter
 from .solvers.solver import SolverInterface
 from .trainer import Trainer
 from .plotter import Plotter
 from .condition.condition import Condition
-from .data import SamplePointDataset
+
 from .data import PinaDataModule
-from .data import PinaDataLoader
+
 from .optim import TorchOptimizer
 from .optim import TorchScheduler
 from .graph import Graph
diff --git a/pina/collector.py b/pina/collector.py
index 3219b2b6..1f0fb41d 100644
--- a/pina/collector.py
+++ b/pina/collector.py
@@ -1,3 +1,4 @@
+from .label_tensor import LabelTensor
 from .utils import check_consistency, merge_tensors
 
@@ -66,9 +67,12 @@ def store_sample_domains(self, n, mode, variables, sample_locations):
         for loc in sample_locations:
             # get condition
             condition = self.problem.conditions[loc]
+            condition_domain = condition.domain
+            if isinstance(condition_domain, str):
+                condition_domain = self.problem.domains[condition_domain]
             keys = ["input_points", "equation"]
             # if the condition is not ready, we get and store the data
-            if (not self._is_conditions_ready[loc]):
+            if not self._is_conditions_ready[loc]:
                 # if it is the first time we sample
                 if not self.data_collections[loc]:
                     already_sampled = []
@@ -84,10 +88,11 @@ def store_sample_domains(self, n, mode, variables, sample_locations):
                 # get the samples
                 samples = [
-                    condition.domain.sample(n=n, mode=mode, variables=variables)
-                ] + already_sampled
+                    condition_domain.sample(n=n, mode=mode,
+                                            variables=variables)
+                ] + already_sampled
                 pts = merge_tensors(samples)
-                if (set(pts.labels).issubset(sorted(self.problem.input_variables))):
+                if set(pts.labels).issubset(sorted(self.problem.input_variables)):
                     pts = pts.sort_labels()
                     if sorted(pts.labels) == sorted(self.problem.input_variables):
                         self._is_conditions_ready[loc] = True
@@ -110,5 +115,6 @@ def add_points(self, new_points_dict):
             if not self._is_conditions_ready[k]:
                 raise RuntimeError(
                     'Cannot add points on a non sampled condition')
-            self.data_collections[k]['input_points'] = self.data_collections[k][
-                'input_points'].vstack(v)
+            self.data_collections[k]['input_points'] = LabelTensor.vstack(
+                [self.data_collections[k]['input_points'],
+                 v])
diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py
index 05c543eb..2adf1c82 100644
--- a/pina/condition/data_condition.py
+++ b/pina/condition/data_condition.py
@@ -19,7 +19,7 @@ class DataConditionInterface(ConditionInterface):
 
     def __init__(self, input_points, conditional_variables=None):
         """
-        TODO
+        TODO: add docstring
         """
         super().__init__()
         self.input_points = input_points
diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py
index 53e07621..65095b05 100644
--- a/pina/condition/domain_equation_condition.py
+++ b/pina/condition/domain_equation_condition.py
@@ -16,7 +16,7 @@ class DomainEquationCondition(ConditionInterface):
     condition_type = ['physics']
 
     def __init__(self, domain, equation):
         """
-        TODO
+        TODO: add docstring
         """
         super().__init__()
         self.domain = domain
@@ -24,7 +24,7 @@ def __init__(self, domain, equation):
     def __setattr__(self, key, value):
         if key == 'domain':
-            check_consistency(value, (DomainInterface))
+            check_consistency(value, (DomainInterface, str))
             DomainEquationCondition.__dict__[key].__set__(self, value)
         elif key == 'equation':
             check_consistency(value, (EquationInterface))
diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py
index 2a7f4647..2c376a16 100644
--- a/pina/condition/input_equation_condition.py
+++ b/pina/condition/input_equation_condition.py
@@ -17,7 +17,7 @@ class InputPointsEquationCondition(ConditionInterface):
     condition_type = ['physics']
 
     def __init__(self, input_points, equation):
         """
-        TODO
+        TODO: add docstring
         """
         super().__init__()
         self.input_points = input_points
diff --git a/pina/condition/input_output_condition.py b/pina/condition/input_output_condition.py
index e9c34bea..bf7f8b92 100644
--- a/pina/condition/input_output_condition.py
+++ b/pina/condition/input_output_condition.py
@@ -1,4 +1,5 @@
 import torch
+import torch_geometric
 
 from .condition_interface import ConditionInterface
 from ..label_tensor import LabelTensor
@@ -16,7 +17,7 @@ class InputOutputPointsCondition(ConditionInterface):
     condition_type = ['supervised']
 
     def __init__(self, input_points, output_points):
         """
-        TODO
+        TODO: add docstring
         """
         super().__init__()
         self.input_points = input_points
@@ -24,7 +25,7 @@ def __init__(self, input_points, output_points):
     def __setattr__(self, key, value):
         if (key == 'input_points') or (key == 'output_points'):
-            check_consistency(value, (LabelTensor, Graph, torch.Tensor))
+            check_consistency(value, (LabelTensor, Graph, torch.Tensor,
+                                      torch_geometric.data.Data))
             InputOutputPointsCondition.__dict__[key].__set__(self, value)
         elif key in ('_problem', '_condition_type'):
             super().__setattr__(key, value)
diff --git a/pina/data/__init__.py b/pina/data/__init__.py
index 2b3a126a..292c9ed1 100644
--- a/pina/data/__init__.py
+++ b/pina/data/__init__.py
@@ -2,14 +2,11 @@
 Import data classes
 """
 __all__ = [
-    'PinaDataLoader', 'SupervisedDataset', 'SamplePointDataset',
-    'UnsupervisedDataset', 'Batch', 'PinaDataModule', 'BaseDataset'
+    'PinaDataModule',
+    'PinaDataset'
 ]
 
-from .pina_dataloader import PinaDataLoader
-from .supervised_dataset import SupervisedDataset
-from .sample_dataset import SamplePointDataset
-from .unsupervised_dataset import UnsupervisedDataset
-from .pina_batch import Batch
+
+
 from .data_module import PinaDataModule
-from .base_dataset import BaseDataset
+from .dataset import PinaDataset
diff --git a/pina/data/base_dataset.py b/pina/data/base_dataset.py
deleted file mode 100644
index d05784f8..00000000
--- a/pina/data/base_dataset.py
+++ /dev/null
@@ -1,156 +0,0 @@
-"""
-Basic data module implementation
-"""
-import torch
-import logging
-
-from torch.utils.data import Dataset
-from ..label_tensor import LabelTensor
-
-
-class BaseDataset(Dataset):
-    """
-    BaseDataset class, which handle initialization and data retrieval
-    :var condition_indices: List of indices
-    :var device: torch.device
-    """
-
-    def __new__(cls, problem=None, device=torch.device('cpu')):
-        """
-        Ensure correct definition of __slots__ before initialization
-        :param AbstractProblem problem: The formulation of the problem.
-        :param torch.device device: The device on which the
-            dataset will be loaded.
-        """
-        if cls is BaseDataset:
-            raise TypeError(
-                'BaseDataset cannot be instantiated directly. Use a subclass.')
-        if not hasattr(cls, '__slots__'):
-            raise TypeError(
-                'Something is wrong, __slots__ must be defined in subclasses.')
-        return object.__new__(cls)
-
-    def __init__(self, problem=None, device=torch.device('cpu')):
-        """"
-        Initialize the object based on __slots__
-        :param AbstractProblem problem: The formulation of the problem.
-        :param torch.device device: The device on which the
-            dataset will be loaded.
-        """
-        super().__init__()
-        self.empty = True
-        self.problem = problem
-        self.device = device
-        self.condition_indices = None
-        for slot in self.__slots__:
-            setattr(self, slot, [])
-        self.num_el_per_condition = []
-        self.conditions_idx = []
-        if self.problem is not None:
-            self._init_from_problem(self.problem.collector.data_collections)
-        self.initialized = False
-
-    def _init_from_problem(self, collector_dict):
-        """
-        TODO
-        """
-        for name, data in collector_dict.items():
-            keys = list(data.keys())
-            if set(self.__slots__) == set(keys):
-                self._populate_init_list(data)
-                idx = [
-                    key for key, val in
-                    self.problem.collector.conditions_name.items()
-                    if val == name
-                ]
-                self.conditions_idx.append(idx)
-        self.initialize()
-
-    def add_points(self, data_dict, condition_idx, batching_dim=0):
-        """
-        This method filled internal lists of data points
-        :param data_dict: dictionary containing data points
-        :param condition_idx: index of the condition to which the data points
-            belong to
-        :param batching_dim: dimension of the batching
-        :raises: ValueError if the dataset has already been initialized
-        """
-        if not self.initialized:
-            self._populate_init_list(data_dict, batching_dim)
-            self.conditions_idx.append(condition_idx)
-            self.empty = False
-        else:
-            raise ValueError('Dataset already initialized')
-
-    def _populate_init_list(self, data_dict, batching_dim=0):
-        current_cond_num_el = None
-        for slot in data_dict.keys():
-            slot_data = data_dict[slot]
-            if batching_dim != 0:
-                if isinstance(slot_data, (LabelTensor, torch.Tensor)):
-                    dims = len(slot_data.size())
-                    slot_data = slot_data.permute(
-                        [batching_dim] +
-                        [dim for dim in range(dims) if dim != batching_dim])
-            if current_cond_num_el is None:
-                current_cond_num_el = len(slot_data)
-            elif current_cond_num_el != len(slot_data):
-                raise ValueError('Different dimension in same condition')
-            current_list = getattr(self, slot)
-            current_list += [
-                slot_data
-            ] if not (isinstance(slot_data, list)) else slot_data
-        self.num_el_per_condition.append(current_cond_num_el)
-
-    def initialize(self):
-        """
-        Initialize the datasets tensors/LabelTensors/lists given the lists
-        already filled
-        """
-        logging.debug(f'Initialize dataset {self.__class__.__name__}')
-        if self.num_el_per_condition:
-            self.condition_indices = torch.cat([
-                torch.tensor(
-                    [self.conditions_idx[i]] * self.num_el_per_condition[i],
-                    dtype=torch.uint8)
-                for i in range(len(self.num_el_per_condition))
-            ],
-                dim=0)
-        for slot in self.__slots__:
-            current_attribute = getattr(self, slot)
-            if all(isinstance(a, LabelTensor) for a in current_attribute):
-                setattr(self, slot, LabelTensor.vstack(current_attribute))
-        self.initialized = True
-
-    def __len__(self):
-        """
-        :return: Number of elements in the dataset
-        """
-        return len(getattr(self, self.__slots__[0]))
-
-    def __getitem__(self, idx):
-        """
-        :param idx:
-        :return:
-        """
-        if not isinstance(idx, (tuple, list, slice, int)):
-            raise IndexError("Invalid index")
-        tensors = []
-        for attribute in self.__slots__:
-            tensor = getattr(self, attribute)
-            if isinstance(attribute, (LabelTensor, torch.Tensor)):
-                tensors.append(tensor.__getitem__(idx))
-            elif isinstance(attribute, list):
-                if isinstance(idx, (list, tuple)):
-                    tensor = [tensor[i] for i in idx]
-                tensors.append(tensor)
-        return tensors
-
-    def apply_shuffle(self, indices):
-        for slot in self.__slots__:
-            if slot != 'equation':
-                attribute = getattr(self, slot)
-                if isinstance(attribute, (LabelTensor, torch.Tensor)):
-                    setattr(self, 'slot', attribute[[indices]])
-                if isinstance(attribute, list):
-                    setattr(self, 'slot', [attribute[i] for i in indices])
diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index b09fb54a..45985615 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -1,16 +1,60 @@
-"""
-This module provide basic data management functionalities
-"""
-
+import logging
+from lightning.pytorch import LightningDataModule
 import math
 import torch
-import logging
-from pytorch_lightning import LightningDataModule
-from .sample_dataset import SamplePointDataset
-from .supervised_dataset import SupervisedDataset
-from .unsupervised_dataset import UnsupervisedDataset
-from .pina_dataloader import PinaDataLoader
-from .pina_subset import PinaSubset
+from ..label_tensor import LabelTensor
+from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
+    RandomSampler
+from torch.utils.data.distributed import DistributedSampler
+from functools import partial
+from .dataset import PinaDatasetFactory
+
+
+def collate_dummy(batch):
+    return batch[0]
+
+
+def collate_fn(batch, max_conditions_lengths):
+    """
+    Function used to collate the batch
+    """
+    batch_dict = {}
+    if isinstance(batch, dict):
+        return batch
+    conditions_names = batch[0].keys()
+
+    # Condition names
+    for condition_name in conditions_names:
+        single_cond_dict = {}
+        condition_args = batch[0][condition_name].keys()
+        for arg in condition_args:
+            data_list = [batch[idx][condition_name][arg] for idx in range(
+                min(len(batch), max_conditions_lengths[condition_name]))]
+            if isinstance(data_list[0], LabelTensor):
+                single_cond_dict[arg] = LabelTensor.stack(data_list)
+            elif isinstance(data_list[0], torch.Tensor):
+                single_cond_dict[arg] = torch.stack(data_list)
+            else:
+                raise NotImplementedError(
+                    f"Data type {type(data_list[0])} not supported")
+        batch_dict[condition_name] = single_cond_dict
+    return batch_dict
+
+
+class PinaBatchSampler(BatchSampler):
+    def __init__(self, dataset, batch_size, shuffle, sampler=None):
+        if sampler is None:
+            if (torch.distributed.is_available()
+                    and torch.distributed.is_initialized()):
+                rank = torch.distributed.get_rank()
+                world_size = torch.distributed.get_world_size()
+                sampler = DistributedSampler(dataset, shuffle=shuffle,
+                                             rank=rank,
+                                             num_replicas=world_size)
+            else:
+                if shuffle:
+                    sampler = RandomSampler(dataset)
+                else:
+                    sampler = SequentialSampler(dataset)
+        super().__init__(sampler=sampler, batch_size=batch_size,
+                         drop_last=False)
 
 
 class PinaDataModule(LightningDataModule):
@@ -20,194 +64,202 @@ class PinaDataModule(LightningDataModule):
     """
 
    def __init__(self,
-                 problem,
-                 device,
+                 collector,
                  train_size=.7,
                  test_size=.2,
                  val_size=.1,
                  predict_size=0.,
                  batch_size=None,
                  shuffle=True,
-                 datasets=None):
+                 repeat=False
+                 ):
        """
        Initialize the object, creating dataset based on input problem
-        :param AbstractProblem problem: PINA problem
-        :param device: Device used for training and testing
+        :param Collector collector: collector containing the problem data
        :param train_size: number/percentage of elements in train split
        :param test_size: number/percentage of elements in test split
-        :param eval_size: number/percentage of elements in evaluation split
+        :param val_size: number/percentage of elements in validation split
        :param batch_size: batch size used for training
-        :param datasets: list of datasets objects
        """
        logging.debug('Start initialization of Pina DataModule')
        logging.info('Start initialization of Pina DataModule')
        super().__init__()
-        self.problem = problem
-        self.device = device
-        self.dataset_classes = [
-            SupervisedDataset, UnsupervisedDataset, SamplePointDataset
-        ]
-        if datasets is None:
-            self.datasets = None
-        else:
-            self.datasets = datasets
-        self.split_length = []
-        self.split_names = []
-        self.loader_functions = {}
        self.batch_size = batch_size
-        self.condition_names = problem.collector.conditions_name
+        self.shuffle = shuffle
+        self.repeat = repeat
 
+        # Begin Data splitting
+        splits_dict = {}
        if train_size > 0:
-            self.split_names.append('train')
-            self.split_length.append(train_size)
+            splits_dict['train'] = train_size
+            self.train_dataset = None
        else:
            self.train_dataloader = super().train_dataloader
-
        if test_size > 0:
-            self.split_length.append(test_size)
-            self.split_names.append('test')
+            splits_dict['test'] = test_size
+            self.test_dataset = None
        else:
            self.test_dataloader = super().test_dataloader
-
        if val_size > 0:
-            self.split_length.append(val_size)
-            self.split_names.append('val')
+            splits_dict['val'] = val_size
+            self.val_dataset = None
        else:
            self.val_dataloader = super().val_dataloader
-
        if predict_size > 0:
-            self.split_length.append(predict_size)
-            self.split_names.append('predict')
+            splits_dict['predict'] = predict_size
+            self.predict_dataset = None
        else:
            self.predict_dataloader = super().predict_dataloader
-
-        self.splits = {k: {} for k in self.split_names}
-        self.shuffle = shuffle
-        self.has_setup_fit = False
-        self.has_setup_test = False
-
-    def prepare_data(self):
-        if self.datasets is None:
-            self._create_datasets()
+        self.collector_splits = self._create_splits(collector, splits_dict)
 
    def setup(self, stage=None):
        """
        Perform the splitting of the dataset
        """
        logging.debug('Start setup of Pina DataModule obj')
-        if self.datasets is None:
-            self._create_datasets()
+
        if stage == 'fit' or stage is None:
-            for dataset in self.datasets:
-                if len(dataset) > 0:
-                    splits = self.dataset_split(dataset,
-                                                self.split_length,
-                                                shuffle=self.shuffle)
-                    for i in range(len(self.split_length)):
-                        self.splits[self.split_names[i]][
-                            dataset.data_type] = splits[i]
-            self.has_setup_fit = True
+            self.train_dataset = PinaDatasetFactory(
+                self.collector_splits['train'],
+                max_conditions_lengths=self.find_max_conditions_lengths(
+                    'train'))
+            if 'val' in self.collector_splits.keys():
+                self.val_dataset = PinaDatasetFactory(
+                    self.collector_splits['val'],
+                    max_conditions_lengths=self.find_max_conditions_lengths(
+                        'val'))
        elif stage == 'test':
-            if self.has_setup_fit is False:
-                raise NotImplementedError(
-                    "You must call setup with stage='fit' "
-                    "first")
+            self.test_dataset = PinaDatasetFactory(
+                self.collector_splits['test'],
+                max_conditions_lengths=self.find_max_conditions_lengths(
+                    'test'))
+        elif stage == 'predict':
+            self.predict_dataset = PinaDatasetFactory(
+                self.collector_splits['predict'],
+                max_conditions_lengths=self.find_max_conditions_lengths(
+                    'predict'))
        else:
-            raise ValueError("stage must be either 'fit' or 'test'")
+            raise ValueError(
+                "stage must be either 'fit' or 'test' or 'predict'.")
 
    @staticmethod
-    def dataset_split(dataset, lengths, seed=None, shuffle=True):
-        """
-        Perform the splitting of the dataset
-        :param dataset: dataset object we wanted to split
-        :param lengths: lengths of elements in dataset
-        :param seed: random seed
-        :param shuffle: shuffle dataset
-        :return: split dataset
-        :rtype: PinaSubset
-        """
-        if sum(lengths) - 1 < 1e-3:
-            len_dataset = len(dataset)
-            lengths = [
-                int(math.floor(len_dataset * length)) for length in lengths
-            ]
-            remainder = len(dataset) - sum(lengths)
-            for i in range(remainder):
-                lengths[i % len(lengths)] += 1
-        elif sum(lengths) - 1 >= 1e-3:
-            raise ValueError(f"Sum of lengths is {sum(lengths)} less than 1")
-
-        if shuffle:
-            if seed is not None:
-                generator = torch.Generator()
-                generator.manual_seed(seed)
-                indices = torch.randperm(sum(lengths), generator=generator)
-            else:
-                indices = torch.randperm(sum(lengths))
-            dataset.apply_shuffle(indices)
+    def _split_condition(condition_dict, splits_dict):
+        len_condition = len(condition_dict['input_points'])
 
-        offsets = [
-            sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
-        ]
-        return [
-            PinaSubset(dataset, slice(offset, offset + length))
-            for offset, length in zip(offsets, lengths)
+        lengths = [
+            int(math.floor(len_condition * length)) for length in
+            splits_dict.values()
        ]
 
-    def _create_datasets(self):
+        remainder = len_condition - sum(lengths)
+        for i in range(remainder):
+            lengths[i % len(lengths)] += 1
+        splits_dict = {k: v for k, v in zip(splits_dict.keys(), lengths)}
+        to_return_dict = {}
+        offset = 0
+        for stage, stage_len in splits_dict.items():
+            to_return_dict[stage] = {k: v[offset:offset + stage_len]
+                                     for k, v in condition_dict.items()
+                                     if k != 'equation'
+                                     # Equations are NEVER dataloaded
+                                     }
+            offset += stage_len
+        return to_return_dict
+
+    def _create_splits(self, collector, splits_dict):
        """
-        Create the dataset objects putting data
+        Create the dataset splits from the collector data
        """
        logging.debug('Dataset creation in PinaDataModule obj')
-        collector = self.problem.collector
-        batching_dim = self.problem.batching_dimension
-        datasets_slots = [i.__slots__ for i in self.dataset_classes]
-        self.datasets = [
-            dataset(device=self.device) for dataset in self.dataset_classes
-        ]
-        logging.debug('Filling datasets in PinaDataModule obj')
-        for name, data in collector.data_collections.items():
-            keys = list(data.keys())
-            idx = [
-                key for key, val in collector.conditions_name.items()
-                if val == name
-            ]
-            for i, slot in enumerate(datasets_slots):
-                if slot == keys:
-                    self.datasets[i].add_points(data, idx[0], batching_dim)
-                    continue
-        datasets = []
-        for dataset in self.datasets:
-            if not dataset.empty:
-                dataset.initialize()
-                datasets.append(dataset)
-        self.datasets = datasets
+        split_names = list(splits_dict.keys())
+        dataset_dict = {name: {} for name in split_names}
+        for condition_name, condition_dict in collector.data_collections.items():
+            len_data = len(condition_dict['input_points'])
+            if self.shuffle:
+                idx = torch.randperm(len_data)
+                for k, v in condition_dict.items():
+                    if k == 'equation':
+                        continue
+                    if isinstance(v, list):
+                        condition_dict[k] = [v[i] for i in idx]
+                    elif isinstance(v, LabelTensor):
+                        condition_dict[k] = LabelTensor(v.tensor[[idx]],
+                                                        v.labels)
+                    elif isinstance(v, torch.Tensor):
+                        condition_dict[k] = v[idx]
+                    else:
+                        raise ValueError(f"Data type {type(v)} not supported")
+            for key, data in self._split_condition(condition_dict,
+                                                   splits_dict).items():
+                dataset_dict[key].update({condition_name: data})
+        return dataset_dict
+
+    def find_max_conditions_lengths(self, split):
+        max_conditions_lengths = {}
+        for k, v in self.collector_splits[split].items():
+            if self.batch_size is None:
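+                # batch_size is None means full-batch training: a single
+                # batch containing the whole condition is used, so the cap
+                # is the condition length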
+                max_conditions_lengths[k] = len(v['input_points'])
+            elif self.repeat:
+                max_conditions_lengths[k] = self.batch_size
+            else:
+                max_conditions_lengths[k] = min(len(v['input_points']),
+                                                self.batch_size)
+        return max_conditions_lengths
 
    def val_dataloader(self):
        """
        Create the validation dataloader
        """
-        return PinaDataLoader(self.splits['val'], self.batch_size,
-                              self.condition_names)
+        batch_size = self.batch_size if self.batch_size is not None else len(
+            self.val_dataset)
+
+        sampler = PinaBatchSampler(self.val_dataset, batch_size,
+                                   shuffle=False)
+        return DataLoader(self.val_dataset, sampler=sampler,
+                          collate_fn=collate_dummy)
 
    def train_dataloader(self):
        """
        Create the training dataloader
        """
-        return PinaDataLoader(self.splits['train'], self.batch_size,
-                              self.condition_names)
+        batch_size = self.batch_size if self.batch_size is not None else len(
+            self.train_dataset)
+
+        sampler = PinaBatchSampler(self.train_dataset, batch_size,
+                                   shuffle=False)
+        return DataLoader(self.train_dataset, sampler=sampler,
+                          collate_fn=collate_dummy)
 
    def test_dataloader(self):
        """
        Create the testing dataloader
        """
-        return PinaDataLoader(self.splits['test'], self.batch_size,
-                              self.condition_names)
+        max_conditions_lengths = self.find_max_conditions_lengths('test')
+        collate_fn_test = partial(collate_fn,
+                                  max_conditions_lengths=max_conditions_lengths)
+        return DataLoader(self.test_dataset, self.batch_size,
+                          collate_fn=collate_fn_test, shuffle=False)
 
    def predict_dataloader(self):
        """
        Create the prediction dataloader
        """
-        return PinaDataLoader(self.splits['predict'], self.batch_size,
-                              self.condition_names)
+        max_conditions_lengths = self.find_max_conditions_lengths('predict')
+        collate_fn_predict = partial(collate_fn,
+                                     max_conditions_lengths=max_conditions_lengths)
+        return DataLoader(self.predict_dataset, self.batch_size,
+                          collate_fn=collate_fn_predict, shuffle=False)
+
+    def transfer_batch_to_device(self, batch, device, dataloader_idx):
+        """
+        Transfer the batch to the device. Called automatically by
+        Lightning during the training loop.
+        """
+        batch = [
+            (k, super(LightningDataModule, self).transfer_batch_to_device(
+                v, device, dataloader_idx))
+            for k, v in batch.items()
+        ]
+        return batch
diff --git a/pina/data/dataset.py b/pina/data/dataset.py
new file mode 100644
index 00000000..0f7d9994
--- /dev/null
+++ b/pina/data/dataset.py
@@ -0,0 +1,91 @@
+"""
+This module provides basic data management functionalities
+"""
+import torch
+from torch.utils.data import Dataset
+from abc import abstractmethod
+
+from torch_geometric.data import Batch
+
+
+class PinaDataset(Dataset):
+    """
+    Base dataset class for PINA
+    """
+    def __init__(self, conditions_dict, max_conditions_lengths):
+        self.conditions_dict = conditions_dict
+        self.max_conditions_lengths = max_conditions_lengths
+        self.conditions_length = {k: len(v['input_points']) for k, v in
+                                  self.conditions_dict.items()}
+        self.length = max(self.conditions_length.values())
+
+    def _get_max_len(self):
+        max_len = 0
+        for condition in self.conditions_dict.values():
+            max_len = max(max_len, len(condition['input_points']))
+        return max_len
+
+    def __len__(self):
+        return self.length
+
+    @abstractmethod
+    def __getitem__(self, item):
+        pass
+
+
+class PinaDatasetFactory:
+    """
+    Factory class returning the proper dataset type
+    """
+    def __new__(cls, conditions_dict, **kwargs):
+        if len(conditions_dict) == 0:
+            raise ValueError('No conditions provided')
+        if all(isinstance(v['input_points'], torch.Tensor)
+               for v in conditions_dict.values()):
+            return PinaTensorDataset(conditions_dict, **kwargs)
+        if all(isinstance(v['input_points'], list)
+               for v in conditions_dict.values()):
+            return PinaGraphDataset(conditions_dict, **kwargs)
+        raise ValueError('Conditions input points must be either all '
+                         'tensors or all lists of graphs')
+
+
+class PinaTensorDataset(PinaDataset):
+    def __init__(self, conditions_dict, max_conditions_lengths):
+        super().__init__(conditions_dict, max_conditions_lengths)
+
+    def __getitem__(self, idx):
+        """
+        Getitem method for large batch size
+        """
+        to_return_dict = {}
+        for condition, data in self.conditions_dict.items():
+            cond_idx = idx[:self.max_conditions_lengths[condition]]
+            condition_len = self.conditions_length[condition]
+            if self.length > condition_len:
+                cond_idx = [i % condition_len for i in cond_idx]
+            to_return_dict[condition] = {k: v[cond_idx]
+                                         for k, v in data.items()}
+        return to_return_dict
+
+
+class PinaGraphDataset(PinaDataset):
+    def __init__(self, conditions_dict, max_conditions_lengths):
+        super().__init__(conditions_dict, max_conditions_lengths)
+
+    def __getitem__(self, idx):
+        """
+        Getitem method for large batch size
+        """
+        to_return_dict = {}
+        for condition, data in self.conditions_dict.items():
+            cond_idx = idx[:self.max_conditions_lengths[condition]]
+            condition_len = self.conditions_length[condition]
+            if self.length > condition_len:
+                cond_idx = [i % condition_len for i in cond_idx]
+            to_return_dict[condition] = {k: Batch.from_data_list(
+                [v[i] for i in cond_idx])
+                if isinstance(v, list) else v[[cond_idx]]
+                for k, v in data.items()
+                }
+        return to_return_dict
diff --git a/pina/data/pina_batch.py b/pina/data/pina_batch.py
deleted file mode 100644
index e43e1108..00000000
--- a/pina/data/pina_batch.py
+++ /dev/null
@@ -1,92 +0,0 @@
-"""
-Batch management module
-"""
-import torch
-from ..label_tensor import LabelTensor
-
-from .pina_subset import PinaSubset
-
-
-class Batch:
-    """
-    Implementation of the Batch class used during training to perform SGD
-    optimization.
-    """
-
-    def __init__(self, dataset_dict, idx_dict, require_grad=True):
-        self.attributes = []
-        for k, v in dataset_dict.items():
-            index = idx_dict[k]
-            if isinstance(v, PinaSubset):
-                dataset_index = v.indices
-                if isinstance(dataset_index, slice):
-                    index = slice(dataset_index.start + index.start,
-                                  min(dataset_index.start + index.stop,
-                                      dataset_index.stop))
-            setattr(self, k, PinaSubset(v.dataset, index,
-                                        require_grad=require_grad))
-            self.attributes.append(k)
-        self.require_grad = require_grad
-
-    def __len__(self):
-        """
-        Returns the number of elements in the batch
-        :return: number of elements in the batch
-        :rtype: int
-        """
-        length = 0
-        for dataset in self.attributes:
-            attribute = getattr(self, dataset)
-            length += len(attribute)
-        return length
-
-    def __getattr__(self, item):
-        if item == 'data' and len(self.attributes) == 1:
-            item = self.attributes[0]
-            return self.__getattribute__(item)
-        raise AttributeError(f"'Batch' object has no attribute '{item}'")
-
-    def get_data(self, batch_name=None):
-        """
-        # TODO
-        """
-        data = getattr(self, batch_name)
-        to_return_list = []
-        if isinstance(data, PinaSubset):
-            items = data.dataset.__slots__
-        else:
-            items = data.__slots__
-        indices = torch.unique(data.condition_indices).tolist()
-        condition_idx = data.condition_indices
-        for i in indices:
-            temp = []
-            for j in items:
-                var = getattr(data, j)
-                if isinstance(var, (torch.Tensor, LabelTensor)):
-                    temp.append(var[i == condition_idx])
-                if isinstance(var, list) and len(var) > 0:
-                    temp.append([var[k] for k in range(len(var)) if
-                                 i == condition_idx[k]])
-            temp.append(i)
-            to_return_list.append(temp)
-        return to_return_list
-
-    def get_supervised_data(self):
-        """
-        Get a subset of the batch
-        :param idx: indices of the subset
-        :type idx: slice
-        :return: subset of the batch
-        :rtype: Batch
-        """
-        return self.get_data(batch_name='supervised')
-
-    def get_physics_data(self):
-        """
-        Get a subset of the batch
-        :param idx: indices of the subset
-        :type idx: slice
-        :return: subset of the batch
-        :rtype: Batch
-        """
-        return self.get_data(batch_name='physics')
diff --git a/pina/data/pina_dataloader.py b/pina/data/pina_dataloader.py
deleted file mode 100644
index a28ca6c6..00000000
--- a/pina/data/pina_dataloader.py
+++ /dev/null
@@ -1,96 +0,0 @@
-"""
-This module is used to create an iterable object used during training
-"""
-import math
-
-from .pina_batch import Batch
-
-
-class PinaDataLoader:
-    """
-    This class is used to create a dataloader to use during the training.
-
-    :var condition_names: The names of the conditions. The order is consistent
-        with the condition indeces in the batches.
-    :vartype condition_names: list[str]
-    """
-
-    def __init__(self, dataset_dict, batch_size, condition_names) -> None:
-        """
-        Initialize local variables
-        :param dataset_dict: Dictionary of datasets
-        :type dataset_dict: dict
-        :param batch_size: Size of the batch
-        :type batch_size: int
-        :param condition_names: Names of the conditions
-        :type condition_names: list[str]
-        """
-        self.condition_names = condition_names
-        self.dataset_dict = dataset_dict
-        self.batch_size = batch_size
-        self._init_batches(batch_size)
-
-    def _init_batches(self, batch_size=None):
-        """
-        Create batches according to the batch_size provided in input.
-        """
-        self.batches = []
-        n_elements = sum(len(v) for v in self.dataset_dict.values())
-        if batch_size is None:
-            batch_size = n_elements
-            self.batch_size = n_elements
-        n_batches = int(math.ceil(n_elements / batch_size))
-        indexes_dict = {
-            k: math.floor(len(v) / n_batches) if n_batches != 1 else len(v)
-            for k, v in self.dataset_dict.items()}
-        dataset_names = list(self.dataset_dict.keys())
-        num_el_per_batch = [{i: indexes_dict[i] for i in dataset_names}
-                            for _ in range(n_batches - 1)] + [
-                               {i: 0 for i in dataset_names}]
-        reminders = {
-            i: len(self.dataset_dict[i]) - indexes_dict[i] * (n_batches - 1)
-            for i in dataset_names}
-        dataset_names = iter(dataset_names)
-        name = next(dataset_names, None)
-        for batch in num_el_per_batch:
-            tot_num_el = sum(batch.values())
-            batch_reminder = batch_size - tot_num_el
-            for _ in range(batch_reminder):
-                if name is None:
-                    break
-                if reminders[name] > 0:
-                    batch[name] += 1
-                    reminders[name] -= 1
-                else:
-                    name = next(dataset_names, None)
-                    if name is None:
-                        break
-                    batch[name] += 1
-                    reminders[name] -= 1
-
-        reminders, dataset_names, indexes_dict = None, None, None  # free memory
-        actual_indices = {k: 0 for k in self.dataset_dict.keys()}
-        for batch in num_el_per_batch:
-            temp_dict = {}
-            total_length = 0
-            for k, v in batch.items():
-                temp_dict[k] = slice(actual_indices[k], actual_indices[k] + v)
-                actual_indices[k] = actual_indices[k] + v
-                total_length += v
-            self.batches.append(
-                Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict))
-
-    def __iter__(self):
-        """
-        Makes dataloader object iterable
-        """
-        yield from self.batches
-
-    def __len__(self):
-        """
-        Return the number of batches.
-        :return: The number of batches.
-        :rtype: int
-        """
-        return len(self.batches)
diff --git a/pina/data/pina_subset.py b/pina/data/pina_subset.py
deleted file mode 100644
index b5b74a68..00000000
--- a/pina/data/pina_subset.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""
-Module for PinaSubset class
-"""
-from pina import LabelTensor
-from torch import Tensor, float32
-
-
-class PinaSubset:
-    """
-    TODO
-    """
-    __slots__ = ['dataset', 'indices', 'require_grad']
-
-    def __init__(self, dataset, indices, require_grad=False):
-        """
-        TODO
-        """
-        self.dataset = dataset
-        self.indices = indices
-        self.require_grad = require_grad
-
-    def __len__(self):
-        """
-        TODO
-        """
-        if isinstance(self.indices, slice):
-            return self.indices.stop - self.indices.start
-        return len(self.indices)
-
-    def __getattr__(self, name):
-        tensor = self.dataset.__getattribute__(name)
-        if isinstance(tensor, (LabelTensor, Tensor)):
-            if isinstance(self.indices, slice):
-                tensor = tensor[self.indices]
-                if (tensor.device != self.dataset.device
-                        and tensor.dtype == float32):
-                    tensor = tensor.to(self.dataset.device)
-            elif isinstance(self.indices, list):
-                tensor = tensor[[self.indices]].to(self.dataset.device)
-            else:
-                raise ValueError(f"Indices type {type(self.indices)} not "
-                                 f"supported")
-            return tensor.requires_grad_(
-                self.require_grad) if tensor.dtype == float32 else tensor
-        if isinstance(tensor, list):
-            if isinstance(self.indices, list):
-                return [tensor[i] for i in self.indices]
-            return tensor[self.indices]
-        raise AttributeError(f"No attribute named {name}")
diff --git a/pina/data/sample_dataset.py b/pina/data/sample_dataset.py
deleted file mode 100644
index bc3bca33..00000000
--- a/pina/data/sample_dataset.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""
-Sample dataset module
-"""
-from copy import deepcopy
-from .base_dataset import BaseDataset
-from ..condition import InputPointsEquationCondition
-
-
-class SamplePointDataset(BaseDataset):
-    """
-    This class extends the BaseDataset to handle physical datasets
-    composed of only input points.
-    """
-    data_type = 'physics'
-    __slots__ = InputPointsEquationCondition.__slots__
-
-    def add_points(self, data_dict, condition_idx, batching_dim=0):
-        data_dict = deepcopy(data_dict)
-        data_dict.pop('equation')
-        super().add_points(data_dict, condition_idx)
-
-    def _init_from_problem(self, collector_dict):
-        for name, data in collector_dict.items():
-            keys = list(data.keys())
-            if set(self.__slots__) == set(keys):
-                data = deepcopy(data)
-                data.pop('equation')
-                self._populate_init_list(data)
-                idx = [
-                    key for key, val in
-                    self.problem.collector.conditions_name.items()
-                    if val == name
-                ]
-                self.conditions_idx.append(idx)
-        self.initialize()
diff --git a/pina/data/supervised_dataset.py b/pina/data/supervised_dataset.py
deleted file mode 100644
index be601050..00000000
--- a/pina/data/supervised_dataset.py
+++ /dev/null
@@ -1,13 +0,0 @@
-"""
-Supervised dataset module
-"""
-from .base_dataset import BaseDataset
-
-
-class SupervisedDataset(BaseDataset):
-    """
-    This class extends the BaseDataset to handle datasets that consist of
-    input-output pairs.
-    """
-    data_type = 'supervised'
-    __slots__ = ['input_points', 'output_points']
diff --git a/pina/data/unsupervised_dataset.py b/pina/data/unsupervised_dataset.py
deleted file mode 100644
index 18cf296f..00000000
--- a/pina/data/unsupervised_dataset.py
+++ /dev/null
@@ -1,14 +0,0 @@
-"""
-Unsupervised dataset module
-"""
-from .base_dataset import BaseDataset
-
-
-class UnsupervisedDataset(BaseDataset):
-    """
-    This class extend BaseDataset class to handle
-    unsupervised dataset,composed of input points
-    and, optionally, conditional variables
-    """
-    data_type = 'unsupervised'
-    __slots__ = ['input_points', 'conditional_variables']
diff --git a/pina/graph.py b/pina/graph.py
index 97b2770e..bde5bbf5 100644
--- a/pina/graph.py
+++ b/pina/graph.py
@@ -93,8 +93,8 @@ def _build_radius(**kwargs):
     logging.debug(f"edge_index computed")
 
     return Data(
-        x=nodes_data,
-        pos=nodes_coordinates,
+        x=nodes_data.tensor,
+        pos=nodes_coordinates.tensor,
         edge_index=edge_index,
         edge_attr=edges_data,
     )
diff --git a/pina/label_tensor.py b/pina/label_tensor.py
index 58dc8b71..d4627e8a 100644
--- a/pina/label_tensor.py
+++ b/pina/label_tensor.py
@@ -4,21 +4,9 @@
 import torch
 from torch import Tensor
 
-full_labels = True
-MATH_MODULES = {torch.sin, torch.cos, torch.exp, torch.tan, torch.log,
-                torch.sqrt}
-
-
-def issubset(a, b):
-    """
-    Check if a is a subset of b.
-    """
-    if isinstance(a, list) and isinstance(b, list):
-        return set(a).issubset(set(b))
-    if isinstance(a, range) and isinstance(b, range):
-        return a.start <= b.start and a.stop >= b.stop
-    return False
+full_labels = True
+MATH_FUNCTIONS = {torch.sin, torch.cos}
 
 
 class LabelTensor(torch.Tensor):
     """Torch tensor with a label for any column."""
@@ -26,6 +14,7 @@ class LabelTensor(torch.Tensor):
     @staticmethod
     def __new__(cls, x, labels, *args, **kwargs):
         full = kwargs.pop("full", full_labels)
+
         if isinstance(x, LabelTensor):
             x.full = full
             return x
@@ -46,79 +35,84 @@ def __init__(self, x, labels, **kwargs):
             {1: {"name": "space"['a', 'b', 'c'])
 
         """
-        self.dim_names = None
         self.full = kwargs.get('full', full_labels)
-        self.labels = labels
+        if labels is not None:
+            self.labels = labels
+        else:
+            self._labels = {}
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
+        # TODO: complete
         if kwargs is None:
             kwargs = {}
-        if func in MATH_MODULES:
+        if func in MATH_FUNCTIONS:
             str_labels = func.__name__
-            labels = copy(args[0].stored_labels)
+
             lt = super().__torch_function__(func, types, args=args,
                                             kwargs=kwargs)
-            lt_shape = lt.shape
-
-            if len(lt_shape) - 1 in labels.keys():
-                labels.update({
-                    len(lt_shape) - 1: {
-                        'dof': [f'{str_labels}({i})' for i in
-                                labels[len(lt_shape) - 1]['dof']],
-                        'name': len(lt_shape) - 1
-                    }
-                })
-            lt._labels = labels
-            return lt
+            if hasattr(args[0], '_labels'):
+                labels = {k: copy(v) for k, v in args[0].stored_labels.items()}
+                lt._labels = labels
+
+                lt_shape = lt.shape
+
+                if len(lt_shape) - 1 in labels.keys():
+                    labels.update({
+                        len(lt_shape) - 1: {
+                            'dof': [f'{str_labels}({i})' for i in
+                                    labels[len(lt_shape) - 1]['dof']],
+                            'name': len(lt_shape) - 1
+                        }
+                    })
+                    lt._labels = labels
+
+            return lt
         return super().__torch_function__(func, types, args=args,
                                           kwargs=kwargs)
 
     def __mul__(self, other):
+        # TODO: improve
         lt = super().__mul__(other)
+        if not hasattr(self, '_labels'):
+            return lt
         if isinstance(other, (int, float)):
             if hasattr(self, '_labels'):
                 lt._labels = self._labels
+
         if isinstance(other, LabelTensor):
             lt_shape = lt.shape
-            labels = copy(self.stored_labels)
-            other_labels = other.stored_labels
+
             check = False
-            for (k, v), (ko, vo) in zip(sorted(labels.items()),
-                                        sorted(other_labels.items())):
-                if k != ko:
-                    raise ValueError('Labels must be the same')
-                if k != len(lt_shape) - 1:
-                    if vo != v:
+            if self.ndim in (0, 1):
+                labels = copy(other.stored_labels)
+            else:
+                labels = copy(self.stored_labels)
+                other_labels = copy(other.stored_labels)
+                for (k, v), (ko, vo) in zip(sorted(labels.items()),
+                                            sorted(other_labels.items())):
+                    if k != ko:
                         raise ValueError('Labels must be the same')
-                else:
-                    check = True
-            if check:
-                labels.update({
-                    len(lt_shape) - 1: {'dof': [f'{i}{j}' for i, j in
-                                                zip(self.stored_labels[
-                                                    len(lt_shape) - 1]['dof'],
-                                                    other.stored_labels[
-                                                    len(lt_shape) - 1]['dof'])],
-                                        'name': self.stored_labels[
-                                            len(lt_shape) - 1]['name']}
-                })
+                    if k != len(lt_shape) - 1:
+                        if vo != v:
+                            raise ValueError('Labels must be the same')
+                    else:
+                        check = True
+                if check:
+                    labels.update({
+                        len(lt_shape) - 1: {
+                            'dof': [f'{i}{j}' for i, j in
+                                    zip(self.stored_labels[
+                                        len(lt_shape) - 1]['dof'],
+                                        other.stored_labels[
+                                        len(lt_shape) - 1]['dof'])],
+                            'name': self.stored_labels[
+                                len(lt_shape) - 1]['name']}
+                    })
             lt._labels = labels
         return lt
 
-    @classmethod
-    def __internal_init__(cls,
-                          x,
-                          labels,
-                          dim_names,
-                          *args,
-                          **kwargs):
-        lt = cls.__new__(cls, x, labels, *args, **kwargs)
-        lt._labels = labels
-        lt.full = kwargs.get('full', full_labels)
-        lt.dim_names = dim_names
-        return lt
-
     @property
     def labels(self):
         """Property decorator for labels
@@ -166,14 +160,13 @@ def labels(self, labels):
         self._labels = {}
         if isinstance(labels, dict):
             self._init_labels_from_dict(labels)
-        elif isinstance(labels, list):
+        elif isinstance(labels, (list, range)):
             self._init_labels_from_list(labels)
         elif isinstance(labels, str):
             labels = [labels]
             self._init_labels_from_list(labels)
         else:
             raise ValueError("labels must be list, dict or string.")
-        self.set_names()
 
     def _init_labels_from_dict(self, labels):
         """
@@ -186,6 +179,8 @@ def _init_labels_from_dict(self, labels):
             does not match with tensor shape
         """
         tensor_shape = self.shape
+
+        # Set all labels if full_labels is True
         if hasattr(self, 'full') and self.full:
             labels = {
                 i: labels[i] if i in labels else {
@@ -193,27 +188,30 @@ def _init_labels_from_dict(self, labels):
                 }
                 for i in range(len(tensor_shape))
             }
+
         for k, v in labels.items():
+            # Init labels from str
             if isinstance(v, str):
                 v = {'name': v, 'dof': range(tensor_shape[k])}
+
             # Init labels from dict
-            elif isinstance(v, dict) and list(v.keys()) == ['name']:
-                # Init from dict with only name key
-                v['dof'] = range(tensor_shape[k])
-            # Init from dict with both name and dof keys
-            elif isinstance(v, dict) and sorted(list(
-                    v.keys())) == ['dof', 'name']:
-                dof_list = v['dof']
-                dof_len = len(dof_list)
-                if dof_len != len(set(dof_list)):
-                    raise ValueError("dof must be unique")
-                if dof_len != tensor_shape[k]:
-                    raise ValueError(
-                        'Number of dof does not match tensor shape')
+            elif isinstance(v, dict):
+                # Only name of the dimension is provided
+                if list(v.keys()) == ['name']:
+                    v['dof'] = range(tensor_shape[k])
+                # Both name and dof are provided
+                elif sorted(list(v.keys())) == ['dof', 'name']:
+                    dof_list = v['dof']
+                    dof_len = len(dof_list)
+                    if dof_len != len(set(dof_list)):
+                        raise ValueError("dof must be unique")
+                    if dof_len != tensor_shape[k]:
+                        raise ValueError(
+                            'Number of dof does not match tensor shape')
             else:
                 raise ValueError('Illegal labels initialization')
-            # Perform update
+
+            # Assign labels values
             self._labels[k] = v
 
     def _init_labels_from_list(self, labels):
@@ -233,72 +231,71 @@ def _init_labels_from_list(self, labels):
             }
         self._init_labels_from_list(last_dim_labels)
 
-    def set_names(self):
-        labels = self.stored_labels
-        self.dim_names = {}
-        for dim in labels.keys():
-            self.dim_names[labels[dim]['name']] = dim
-
     def extract(self, labels_to_extract):
         """
         Extract the subset of the original tensor by returning all the columns
         corresponding to the passed ``label_to_extract``.
 
-        :param label_to_extract: The label(s) to extract.
-        :type label_to_extract: str | list(str) | tuple(str)
+        :param labels_to_extract: The label(s) to extract.
+        :type labels_to_extract: str | list(str) | tuple(str)
         :raises TypeError: Labels are not ``str``.
         :raises ValueError: Label to extract is not in the labels ``list``.
         """
         # Convert str/int to string
+        def find_names(labels):
+            dim_names = {}
+            for dim in labels.keys():
+                dim_names[labels[dim]['name']] = dim
+            return dim_names
+
         if isinstance(labels_to_extract, (str, int)):
             labels_to_extract = [labels_to_extract]
 
         # Store useful variables
-        labels = self.stored_labels
+        labels = copy(self._labels)
         stored_keys = labels.keys()
-        dim_names = self.dim_names
+        dim_names = find_names(labels)
         ndim = len(super().shape)
 
-        # Convert tuple/list to dict
+
+        # Convert tuple/list to dict (having a list as input means that we
+        # want to extract values from the last dimension)
         if isinstance(labels_to_extract, (tuple, list)):
             if not ndim - 1 in stored_keys:
                 raise ValueError(
                     "LabelTensor does not have labels in last dimension")
-            name = labels[max(stored_keys)]['name']
+            name = labels[ndim - 1]['name']
             labels_to_extract = {name: list(labels_to_extract)}
 
         # If labels_to_extract is not dict then rise error
         if not isinstance(labels_to_extract, dict):
             raise ValueError('labels_to_extract must be str or list or dict')
 
-        # Make copy of labels (avoid issue in consistency)
-        updated_labels = {k: copy(v) for k, v in labels.items()}
-
         # Initialize list used to perform extraction
-        extractor = [slice(None) for _ in range(ndim)]
+        extractor = [slice(None)] * ndim
+
         # Loop over labels_to_extract dict
-        for k, v in labels_to_extract.items():
+        for dim_name, labels_te in labels_to_extract.items():
 
             # If label is not find raise value error
-            idx_dim = dim_names.get(k)
+            idx_dim = dim_names.get(dim_name, None)
             if idx_dim is None:
                 raise ValueError(
                     'Cannot extract label with is not in original labels')
 
             dim_labels = labels[idx_dim]['dof']
-            v = [v] if isinstance(v, (int, str)) else v
+            labels_te = [labels_te] if isinstance(labels_te, (int, str)) \
+                else labels_te
 
-            if not isinstance(v, range):
-                extractor[idx_dim] = [dim_labels.index(i)
-                                      for i in v] if len(v) > 1 else slice(
-                    dim_labels.index(v[0]),
-                    dim_labels.index(v[0]) + 1)
+            if not isinstance(labels_te, range):
+                # A slice (instead of a list) is used to keep the dimension
+                # when only one label is extracted
+                extractor[idx_dim] = [dim_labels.index(i) for i in labels_te] \
+                    if len(labels_te) > 1 \
+                    else slice(dim_labels.index(labels_te[0]),
+                               dim_labels.index(labels_te[0]) + 1)
             else:
-                extractor[idx_dim] = slice(v.start, v.stop)
+                extractor[idx_dim] = slice(labels_te.start, labels_te.stop)
 
-            updated_labels.update({idx_dim: {'dof': v, 'name': k}})
+            labels.update({idx_dim: {'dof': labels_te, 'name': dim_name}})
 
-        tensor = self.tensor
-        tensor = tensor[extractor]
-        return LabelTensor.__internal_init__(tensor, updated_labels, dim_names)
+        tensor = super().__getitem__(extractor).as_subclass(LabelTensor)
+        tensor._labels = labels
+        return tensor
 
     def __str__(self):
         """
@@ -330,41 +327,53 @@ def cat(tensors, dim=0):
             return []
         if len(tensors) == 1 or isinstance(tensors, LabelTensor):
             return tensors[0]
+
         # Perform cat on tensors
         new_tensor = torch.cat(tensors, dim=dim)
-        new_tensor_shape = new_tensor.shape
-        # Update labels
-        labels = LabelTensor.__create_labels_cat(tensors, dim, new_tensor_shape)
 
-        return LabelTensor.__internal_init__(new_tensor, labels,
-                                             tensors[0].dim_names)
+        # --------- Start definition auxiliary function ------
+        # Compute and update labels
+        def create_labels_cat(tensors, dim, tensor_shape):
+            stored_labels = [tensor.stored_labels for tensor in tensors]
+            keys = stored_labels[0].keys()
+
+            if any(not all(stored_labels[i][k] == stored_labels[0][k]
+                           for i in range(len(stored_labels)))
+                   for k in keys if k != dim):
+                raise RuntimeError('tensors must have the same shape and dof')
+
+            # Copy labels from the first tensor and update the 'dof' for
+            # dimension `dim`
+            labels = copy(stored_labels[0])
+            if dim in labels:
+                labels_list = [tensor[dim]['dof'] for tensor in stored_labels]
+                last_dim_dof = range(tensor_shape[dim]) \
+                    if all(isinstance(label, range)
+                           for label in labels_list) \
+                    else [i for dof in labels_list for i in dof]
+                labels[dim]['dof'] = last_dim_dof
+            return labels
+        # --------- End definition auxiliary function ------
 
-    @staticmethod
-    def __create_labels_cat(tensors, dim, tensor_shape):
-        # Check if names and dof of the labels are the same in all dimensions
-        # except in dim
-        stored_labels = [tensor.stored_labels for tensor in tensors]
-
-        # check if:
-        # - labels dict have same keys
-        # - all labels are the same expect for dimension dim
-        if not all(all(stored_labels[i][k] == stored_labels[0][k]
-                       for i in range(len(stored_labels)))
-                   for k in stored_labels[0].keys() if k != dim):
-            raise RuntimeError('tensors must have the same shape and dof')
+        # Update labels
+        if dim in tensors[0].stored_labels.keys():
+            new_tensor_shape = new_tensor.shape
+            labels = create_labels_cat(tensors, dim, new_tensor_shape)
+        else:
+            labels = tensors[0].stored_labels
+        new_tensor._labels = labels
+        return new_tensor
 
-        labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
-        if dim in labels.keys():
-            labels_list = [j[dim]['dof'] for j in stored_labels]
-            if all(isinstance(j, range) for j in labels_list):
-                last_dim_dof = range(tensor_shape[dim])
-            else:
-                last_dim_dof = [i for j in labels_list for i in j]
-            labels[dim]['dof'] = last_dim_dof
-        return labels
+    @staticmethod
+    def stack(tensors):
+        new_tensor = torch.stack(tensors)
+        labels = tensors[0]._labels
+        labels = {key + 1: value for key, value in labels.items()}
+        if full_labels:
+            new_tensor.labels = labels
+        else:
+            new_tensor._labels = labels
+        return new_tensor
 
     def requires_grad_(self, mode=True):
         lt = super().requires_grad_(mode)
-        lt.labels = self._labels
+        lt._labels = self._labels
         return lt
 
     @property
@@ -377,8 +386,8 @@ def to(self, *args, **kwargs):
             :meth:`torch.Tensor.to`.
         """
         lt = super().to(*args, **kwargs)
-        return LabelTensor.__internal_init__(lt,
-                                             self.stored_labels, self.dim_names)
+        lt._labels = self._labels
+        return lt
 
     def clone(self, *args, **kwargs):
         """
@@ -388,9 +397,7 @@ def clone(self, *args, **kwargs):
         :return: A copy of the tensor.
         :rtype: LabelTensor
         """
-        labels = {k: {sub_k: copy(sub_v) for sub_k, sub_v in v.items()}
-                  for k, v in self.stored_labels.items()}
-        out = LabelTensor(super().clone(*args, **kwargs), labels)
+        out = LabelTensor(super().clone(*args, **kwargs), deepcopy(self._labels))
         return out
 
     @staticmethod
@@ -408,7 +415,7 @@ def summation(tensors):
             raise RuntimeError('Tensors must have the same shape and labels')
 
         last_dim_labels = []
-        data = torch.zeros(tensors[0].tensor.shape)
+        data = torch.zeros(tensors[0].tensor.shape).to(tensors[0].device)
         for tensor in tensors:
             data += tensor.tensor
             last_dim_labels.append(tensor.labels)
@@ -456,102 +463,115 @@ def vstack(label_tensors):
         """
         return LabelTensor.cat(label_tensors, dim=0)
 
+    # ---------------------- Start auxiliary function definition -----
+    # This method is used to update labels
+    def _update_single_label(self, old_labels, to_update_labels, index, dim,
+                             to_update_dim):
+        """
+        TODO
+        :param old_labels: labels from which retrieve data
+        :param to_update_labels: labels to update
+        :param index: index of dof to retain
+        :param dim: label index to update
+        :param to_update_dim: dimension of the old labels to read from
+        :return:
+        """
+        old_dof = old_labels[to_update_dim]['dof']
+        if isinstance(index, slice):
+            to_update_labels.update({
+                dim: {
+                    'dof': old_dof[index],
+                    'name': old_labels[dim]['name']
+                }
+            })
+            return
+        if isinstance(index, int):
+            index = [index]
+        if isinstance(index, (list, torch.Tensor)):
+            to_update_labels.update({
+                dim: {
+                    'dof': [old_dof[i] for i in index]
+                    if isinstance(old_dof, list) else index,
+                    'name': old_labels[dim]['name']
+                }
+            })
+            return
+        raise NotImplementedError(f'Getitem not implemented for '
+                                  f'{type(index)} values')
+    # ---------------------- End auxiliary function definition -----
+
     def __getitem__(self, index):
         """
         TODO: Complete docstring
         :param index:
         :return:
         """
+        # Index is a str --> call extract
         if isinstance(index, str) or (isinstance(index, (tuple, list)) and all(
                 isinstance(a, str) for a in index)):
             return self.extract(index)
 
-        if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
-            index = [index.nonzero().squeeze()]
+        # Store important variables
         selected_lt = super().__getitem__(index)
+        stored_labels = self._labels
+        labels = copy(stored_labels)
 
-        if isinstance(index, (int, slice)):
+        # Lists of indices are handled first, as they are the most common
+        # case.
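+        # (A list selects rows along dimension 0, so only the labels of
+        # that dimension need to be updated.)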
+ # Used by DataLoader -> put here for efficiency purpose + if isinstance(index, list): + if 0 in labels.keys(): + self._update_single_label(stored_labels, labels, index, + 0, 0) + selected_lt._labels = labels + return selected_lt + + if isinstance(index, int): + labels.pop(0, None) + labels = {key - 1 if key > 0 else key: value for key, value in + labels.items()} + selected_lt._labels = labels + return selected_lt + + if not isinstance(index, (tuple, torch.Tensor)): index = [index] + # Ellipsis are used to perform operation on the last dimension if index[0] == Ellipsis: - index = [slice(None)] * (self.ndim - 1) + [index[1]] - - try: - stored_labels = self.stored_labels - labels = {} - for j, idx in enumerate(index): + if len(self.shape) in labels: + self._update_single_label(stored_labels, labels, index, 0, 0) + selected_lt._labels = labels + return selected_lt + + i = 0 + for j, idx in enumerate(index): + if j in self.stored_labels.keys(): if isinstance(idx, int) or ( isinstance(idx, torch.Tensor) and idx.ndim == 0): selected_lt = selected_lt.unsqueeze(j) - if j in self.stored_labels.keys() and idx != slice(None): - self._update_single_label(stored_labels, labels, idx, j) - labels.update( - {k: {sub_k: copy(sub_v) for sub_k, sub_v in v.items()} for k, v - in stored_labels.items() if k not in labels}) - selected_lt = LabelTensor.__internal_init__(selected_lt, labels, - self.dim_names) - except AttributeError: - warnings.warn('No attribute labels in LabelTensor') + if idx != slice(None): + self._update_single_label(stored_labels, labels, idx, j, i) + else: + if isinstance(idx, int): + labels = {key - 1 if key > j else key: + value for key, value in labels.items()} + continue + i += 1 + selected_lt._labels = labels return selected_lt - @staticmethod - def _update_single_label(old_labels, to_update_labels, index, dim): - """ - TODO - :param old_labels: labels from which retrieve data - :param to_update_labels: labels to update - :param index: index of dof to retain - :param dim: label index - :return: - """ - - old_dof = old_labels[dim]['dof'] - if isinstance(index, torch.Tensor) and index.ndim == 0: - index = int(index) - if (not isinstance( - index, (int, slice)) and len(index) == len(old_dof) and - isinstance(old_dof, range)): - return - - if isinstance(index, torch.Tensor): - if isinstance(old_dof, range): - to_update_labels.update({ - dim: { - 'dof': index.tolist(), - 'name': old_labels[dim]['name'] - } - }) - return - index = index.tolist() - if isinstance(index, list): - to_update_labels.update({ - dim: { - 'dof': [old_dof[i] for i in index], - 'name': old_labels[dim]['name'] - } - }) - return - to_update_labels.update( - {dim: { - 'dof': old_dof[index] if isinstance(old_dof[index], - (list, range)) else [ - old_dof[index]], - 'name': old_labels[dim]['name'] - }}) - def sort_labels(self, dim=None): - def arg_sort(lst): return sorted(range(len(lst)), key=lambda x: lst[x]) - if dim is None: dim = self.ndim - 1 + if self.shape[dim] == 1: + return self labels = self.stored_labels[dim]['dof'] sorted_index = arg_sort(labels) indexer = [slice(None)] * self.ndim indexer[dim] = sorted_index - return self.__getitem__(indexer) + return self.__getitem__(tuple(indexer)) def __deepcopy__(self, memo): cls = self.__class__ @@ -560,10 +580,28 @@ def __deepcopy__(self, memo): def permute(self, *dims): tensor = super().permute(*dims) - stored_labels = self.stored_labels + labels = self._labels keys_list = list(*dims) labels = { - keys_list.index(k): copy(stored_labels[k]) - for k in stored_labels.keys() 
+ keys_list.index(k): labels[k] + for k in labels.keys() } - return LabelTensor.__internal_init__(tensor, labels, self.dim_names) + tensor._labels = labels + return tensor + + def detach(self): + lt = super().detach() + lt._labels = self.stored_labels + return lt + + +class LabelParameter(torch.nn.Parameter, LabelTensor): + """A class that combines torch.nn.Parameter with LabelTensor behavior.""" + + def __new__(cls, x, labels=None, requires_grad=True): + instance = torch.nn.Parameter.__new__(cls, data=x, + requires_grad=requires_grad) + return instance + + def __init__(self, x, labels=None, requires_grad=True): + LabelTensor.__init__(self, x, labels) diff --git a/pina/model/network.py b/pina/model/network.py index 6fde8039..aed3dff3 100644 --- a/pina/model/network.py +++ b/pina/model/network.py @@ -29,7 +29,8 @@ class is used internally in PINA to convert # check model consistency check_consistency(model, nn.Module) check_consistency(input_variables, str) - check_consistency(output_variables, str) + if output_variables is not None: + check_consistency(output_variables, str) self._model = model self._input_variables = input_variables @@ -67,16 +68,15 @@ def forward(self, x): # in case `input_variables = []` all points are used if self._input_variables: x = x.extract(self._input_variables) - # extract features and append for feature in self._extra_features: x = x.append(feature(x)) # perform forward pass + converting to LabelTensor - output = self._model(x).as_subclass(LabelTensor) - - # set the labels for LabelTensor - output.labels = self._output_variables + x = x.as_subclass(torch.Tensor) + output = self._model(x) + if self._output_variables is not None: + output = LabelTensor(output, self._output_variables) return output @@ -97,15 +97,9 @@ def forward_map(self, x): This function does not extract the input variables, all the variables are used for both tensors. Output variables are correctly applied. 
""" - # convert LabelTensor s to torch.Tensor s - x = list(map(lambda x: x.as_subclass(torch.Tensor), x)) # perform forward pass (using torch.Tensor) + converting to LabelTensor - output = self._model(x).as_subclass(LabelTensor) - - # set the labels for LabelTensor - output.labels = self._output_variables - + output = LabelTensor(self._model(x.tensor), self._output_variables) return output @property diff --git a/pina/operators.py b/pina/operators.py index 0b306dfb..ef389a64 100644 --- a/pina/operators.py +++ b/pina/operators.py @@ -63,11 +63,9 @@ def grad_scalar_output(output_, input_, d): retain_graph=True, allow_unused=True, )[0] - - gradients.labels = input_.labels - gradients = gradients.extract(d) + gradients.labels = input_.stored_labels + gradients = gradients[..., [input_.labels.index(i) for i in d]] gradients.labels = [f"d{output_fieldname}d{i}" for i in d] - return gradients if not isinstance(input_, LabelTensor): @@ -190,7 +188,9 @@ def laplacian(output_, input_, components=None, d=None, method="std"): to_append_tensors = [] for i, label in enumerate(grad_output.labels): gg = grad(grad_output, input_, d=d, components=[label]) - to_append_tensors.append(gg.extract([gg.labels[i]])) + gg = gg.extract([gg.labels[i]]) + + to_append_tensors.append(gg) labels = [f"dd{components[0]}"] result = LabelTensor.summation(tensors=to_append_tensors) result.labels = labels diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index ee590e02..6b21b23a 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -90,10 +90,6 @@ def input_variables(self): variables += self.spatial_variables if hasattr(self, "temporal_variable"): variables += self.temporal_variable - #if hasattr(self, "unknown_parameters"): - # variables += self.unknown_parameters - if hasattr(self, "custom_variables"): - variables += self.custom_variables return variables diff --git a/pina/solvers/graph.py b/pina/solvers/graph.py index 9af04e76..57dcf615 100644 --- a/pina/solvers/graph.py +++ b/pina/solvers/graph.py @@ -8,27 +8,11 @@ def __init__( self, problem, model, - nodes_coordinates, - nodes_data, loss=None, optimizer=None, - scheduler=None): - super().__init__(problem, model, loss, optimizer, scheduler) - if isinstance(nodes_coordinates, str): - self._nodes_coordinates = [nodes_coordinates] - else: - self._nodes_coordinates = nodes_coordinates - if isinstance(nodes_data, str): - self._nodes_data = nodes_data - else: - self._nodes_data = nodes_data + scheduler=None, + use_lt=True,): + super().__init__(problem, model, loss, optimizer, scheduler, use_lt=use_lt) - def forward(self, input): - input_coords = input.extract(self._nodes_coordinates) - input_data = input.extract(self._nodes_data) - - if not isinstance(input, Graph): - input = Graph.build('radius', nodes_coordinates=input_coords, nodes_data=input_data, radius=0.2) - g = self.model(input.data, edge_index=input.data.edge_index) - g.labels = {1: {'name': 'output', 'dof': ['u']}} - return g + def forward(self, batch): + return self._model(batch.x, batch.edge_index, batch.batch) diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index 2bca1823..588d7314 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -100,7 +100,8 @@ def __init__( self._optimizer = self._pina_optimizers[0] self._scheduler = self._pina_schedulers[0] - def training_step(self, batch, _): + + def training_step(self, batch): """ The Physics Informed Solver Training Step. 
diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py
index 2bca1823..588d7314 100644
--- a/pina/solvers/pinns/basepinn.py
+++ b/pina/solvers/pinns/basepinn.py
@@ -100,7 +100,8 @@ def __init__(
         self._optimizer = self._pina_optimizers[0]
         self._scheduler = self._pina_schedulers[0]

-    def training_step(self, batch, _):
+
+    def training_step(self, batch):
         """
         The Physics Informed Solver Training Step. This function takes care
         of the physics informed training step, and it must not be override
@@ -113,46 +114,67 @@ def training_step(self, batch, _):
         :return: The sum of the loss functions.
         :rtype: LabelTensor
         """
-        condition_losses = []
-        batches = batch.get_supervised_data()
-        for points in batches:
-            input_pts, output_pts, condition_id = points
-            condition_name = self._dataloader.condition_names[condition_id]
-            self.__logged_metric = condition_name
-            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
-            condition_losses.append(loss_.as_subclass(torch.Tensor))
-
-        batches = batch.get_physics_data()
-        for points in batches:
-            input_pts, condition_id = points
-            condition_name = self._dataloader.condition_names[condition_id]
-            condition = self.problem.conditions[condition_name]
-            self.__logged_metric = condition_name
-            loss_ = self.loss_phys(input_pts, condition.equation)
-            # add condition losses for each epoch
-            condition_losses.append(loss_.as_subclass(torch.Tensor))
+        condition_loss = []
+        for condition_name, points in batch:
+            if 'output_points' in points:
+                input_pts, output_pts = points['input_points'], points['output_points']
+
+                loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
+                condition_loss.append(loss_.as_subclass(torch.Tensor))
+            else:
+                input_pts = points['input_points']
+
+                condition = self.problem.conditions[condition_name]
+
+                loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
+                condition_loss.append(loss_.as_subclass(torch.Tensor))
         # clamp unknown parameters in InverseProblem (if needed)
         self._clamp_params()
-        loss = sum(condition_losses)
+        loss = sum(condition_loss)
+        self.log('train_loss', loss, prog_bar=True, on_epoch=True,
+                 logger=True, batch_size=self.get_batch_size(batch),
+                 sync_dist=True)
         return loss

+    def validation_step(self, batch):
+        """
+        TODO: add docstring
+        """
+        condition_loss = []
+        for condition_name, points in batch:
+            if 'output_points' in points:
+                input_pts, output_pts = points['input_points'], points['output_points']
+                loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
+                condition_loss.append(loss_.as_subclass(torch.Tensor))
+            else:
+                input_pts = points['input_points']
+
+                condition = self.problem.conditions[condition_name]
+                with torch.set_grad_enabled(True):
+                    loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
+                condition_loss.append(loss_.as_subclass(torch.Tensor))
+
+        loss = sum(condition_loss)
+        self.log('val_loss', loss, on_epoch=True, prog_bar=True,
+                 logger=True, batch_size=self.get_batch_size(batch),
+                 sync_dist=True)
+
     def loss_data(self, input_pts, output_pts):
         """
         The data loss for the PINN solver. It computes the loss between the
         network output against the true solution. This function
         should not be override if not intentionally.

-        :param LabelTensor input_tensor: The input to the neural networks.
-        :param LabelTensor output_tensor: The true solution to compare the
+        :param LabelTensor input_pts: The input to the neural networks.
+        :param LabelTensor output_pts: The true solution to compare the
             network solution.
         :return: The residual loss averaged on the input coordinates
         :rtype: torch.Tensor
         """
         return self._loss(self.forward(input_pts), output_pts)

-    @abstractmethod
     def loss_phys(self, samples, equation):
         """
@@ -202,14 +227,17 @@ def store_log(self, loss_value):
         :param str name: The name of the loss.
         :param torch.Tensor loss_value: The value of the loss.
         """
+        batch_size = self.trainer.data_module.batch_size \
+            if self.trainer.data_module.batch_size is not None else 999  # placeholder for logging when batching is disabled
+
         self.log(
             self.__logged_metric + "_loss",
             loss_value,
             prog_bar=True,
             logger=True,
             on_epoch=True,
-            on_step=False,
-            batch_size=self._dataloader.batch_size,
+            on_step=True,
+            batch_size=batch_size,
         )
         self.__logged_res_losses.append(loss_value)
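The refactored training and validation steps assume the batch iterates as `(condition_name, points_dict)` pairs, where supervised conditions carry both `input_points` and `output_points` and physics conditions carry only `input_points`. A minimal illustration of that layout (condition names, labels, and shapes are made up for the example):

```python
import torch
from pina import LabelTensor

# supervised condition: inputs and targets;
# physics condition: collocation points only
batch = [
    ('data', {'input_points': LabelTensor(torch.rand(8, 2), ['x', 'y']),
              'output_points': LabelTensor(torch.rand(8, 1), ['u'])}),
    ('interior', {'input_points': LabelTensor(torch.rand(32, 2), ['x', 'y'])}),
]

for condition_name, points in batch:
    is_supervised = 'output_points' in points
    print(condition_name, 'supervised' if is_supervised else 'physics')
```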
diff --git a/pina/solvers/pinns/pinn.py b/pina/solvers/pinns/pinn.py
index 29949710..08882020 100644
--- a/pina/solvers/pinns/pinn.py
+++ b/pina/solvers/pinns/pinn.py
@@ -119,9 +119,8 @@ def loss_phys(self, samples, equation):
         """
         residual = self.compute_residual(samples=samples, equation=equation)
         loss_value = self.loss(
-            torch.zeros_like(residual, requires_grad=True), residual
+            torch.zeros_like(residual), residual
         )
-        self.store_log(loss_value=float(loss_value))
         return loss_value

     def configure_optimizers(self):
@@ -134,7 +133,18 @@ def configure_optimizers(self):
         """
         # if the problem is an InverseProblem, add the unknown parameters
         # to the parameters that the optimizer needs to optimize
+
+        self._optimizer.hook(self._model.parameters())
+        if isinstance(self.problem, InverseProblem):
+            self._optimizer.optimizer_instance.add_param_group(
+                {
+                    "params": [
+                        self._params[var]
+                        for var in self.problem.unknown_variables
+                    ]
+                }
+            )
         self._scheduler.hook(self._optimizer)
         return ([self._optimizer.optimizer_instance],
                 [self._scheduler.scheduler_instance])
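`add_param_group` is the standard `torch.optim` mechanism the new `configure_optimizers` relies on for inverse problems. In plain PyTorch terms (model, parameter, and learning rates are illustrative):

```python
import torch

model = torch.nn.Linear(2, 1)
# an extra trainable tensor outside the model, e.g. an unknown PDE parameter
mu = torch.nn.Parameter(torch.tensor(0.5))

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# registered after construction, optionally with its own hyperparameters
optimizer.add_param_group({'params': [mu], 'lr': 1e-2})

assert any(mu is p for g in optimizer.param_groups for p in g['params'])
```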
diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py
index fe9c897e..3a8f400c 100644
--- a/pina/solvers/solver.py
+++ b/pina/solvers/solver.py
@@ -2,7 +2,7 @@
 from abc import ABCMeta, abstractmethod
 from ..model.network import Network
-import pytorch_lightning
+import lightning
 from ..utils import check_consistency
 from ..problem import AbstractProblem
 from ..optim import Optimizer, Scheduler
@@ -10,7 +10,8 @@
 import sys


-class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
+
+class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
     """
     Solver base class. This class inherits is a wrapper of LightningModule
     class, inheriting all the
@@ -93,7 +94,7 @@ def forward(self, *args, **kwargs):
         pass

     @abstractmethod
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch):
         pass

     @abstractmethod
@@ -142,3 +143,11 @@ def _check_solver_consistency(self, problem):
             raise ValueError(
                 f'{self.__name__} dose not support condition '
                 f'{condition.condition_type}')
+
+    @staticmethod
+    def get_batch_size(batch):
+        # batch is a sequence of (condition_name, points_dict) pairs
+        batch_size = 0
+        for data in batch:
+            batch_size += len(data[1]['input_points'])
+        return batch_size
\ No newline at end of file
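A quick sanity check of `get_batch_size` against the batch layout shown earlier; being a static method, it can be called on the class directly (the tensor shapes are arbitrary):

```python
import torch
from pina.solvers.solver import SolverInterface

batch = [
    ('data', {'input_points': torch.rand(8, 2),
              'output_points': torch.rand(8, 1)}),
    ('interior', {'input_points': torch.rand(32, 2)}),
]

# sums the number of input points across all conditions: 8 + 32
assert SolverInterface.get_batch_size(batch) == 40
```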
""" - - batch = batch.supervised - condition_idx = batch.condition_indices - for i in range(condition_idx.min(), condition_idx.max() + 1): - condition_name = self.trainer.data_module.condition_names[i] - condition = self.problem.conditions[condition_name] - pts = batch.input_points - out = batch.output_points - if condition_name not in self.problem.conditions: - raise RuntimeError("Something wrong happened.") - - # for data driven mode - if not hasattr(condition, "output_points"): - raise NotImplementedError( - f"{type(self).__name__} works only in data-driven mode.") - - output_pts = out[condition_idx == i] - input_pts = pts[condition_idx == i] - + condition_loss = [] + for condition_name, points in batch: + input_pts, output_pts = points['input_points'], points['output_points'] loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts) - self.validation_condition_losses[condition_name]['loss'].append( - loss_) - self.validation_condition_losses[condition_name]['count'].append( - len(input_pts)) - - def on_validation_epoch_end(self): - """ - Solver validation epoch end. - """ - total_loss = [] - total_count = [] - for k, v in self.validation_condition_losses.items(): - local_counter = torch.tensor(v['count']).to(self.device) - n_elements = torch.sum(local_counter) - loss = torch.sum( - torch.stack(v['loss']) * local_counter) / n_elements - loss = loss.as_subclass(torch.Tensor) - total_loss.append(loss) - total_count.append(n_elements) - self.log( - k + "_loss", - loss, - prog_bar=True, - logger=True, - on_epoch=True, - on_step=False, - batch_size=self.trainer.data_module.batch_size, - ) - total_count = (torch.tensor(total_count, dtype=torch.float32). - to(self.device)) - mean_loss = (torch.sum(torch.stack(total_loss) * total_count) / - total_count) - self.log( - "val_loss", - mean_loss, - prog_bar=True, - logger=True, - on_epoch=True, - on_step=False, - batch_size=self.trainer.data_module.batch_size, - ) - for key in self.validation_condition_losses.keys(): - self.validation_condition_losses[key]['loss'] = [] - self.validation_condition_losses[key]['count'] = [] + condition_loss.append(loss_.as_subclass(torch.Tensor)) + loss = sum(condition_loss) + self.log('val_loss', loss, prog_bar=True, logger=True, + batch_size=self.get_batch_size(batch), sync_dist=True) + def test_step(self, batch, batch_idx) -> STEP_OUTPUT: """ diff --git a/pina/trainer.py b/pina/trainer.py index f5ea5513..b8cd9f14 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -1,13 +1,13 @@ """ Trainer module. """ import warnings import torch -import pytorch_lightning +import lightning from .utils import check_consistency from .data import PinaDataModule from .solvers.solver import SolverInterface -class Trainer(pytorch_lightning.Trainer): +class Trainer(lightning.pytorch.Trainer): def __init__(self, solver, @@ -31,8 +31,8 @@ def __init__(self, and can be choosen from the `pytorch-lightning Trainer API `_ """ - log_every_n_steps = kwargs.pop('log_every_n_steps', 0) - super().__init__(log_every_n_steps=log_every_n_steps, **kwargs) + + super().__init__(**kwargs) # check inheritance consistency for solver and batch size check_consistency(solver, SolverInterface) @@ -46,6 +46,7 @@ def __init__(self, self.batch_size = batch_size self._move_to_device() self.data_module = None + self._create_loader() def _move_to_device(self): device = self._accelerator_connector._parallel_devices[0] @@ -72,35 +73,20 @@ def _create_loader(self): raise RuntimeError('Cannot create Trainer if not all conditions ' 'are sampled. 
diff --git a/pina/trainer.py b/pina/trainer.py
index f5ea5513..b8cd9f14 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -1,13 +1,13 @@
 """ Trainer module. """
 import warnings
 import torch
-import pytorch_lightning
+import lightning
 from .utils import check_consistency
 from .data import PinaDataModule
 from .solvers.solver import SolverInterface


-class Trainer(pytorch_lightning.Trainer):
+class Trainer(lightning.pytorch.Trainer):

     def __init__(self,
                  solver,
@@ -31,8 +31,8 @@ def __init__(self,
             and can be choosen from the `pytorch-lightning Trainer API `_
         """
-        log_every_n_steps = kwargs.pop('log_every_n_steps', 0)
-        super().__init__(log_every_n_steps=log_every_n_steps, **kwargs)
+
+        super().__init__(**kwargs)

         # check inheritance consistency for solver and batch size
         check_consistency(solver, SolverInterface)
@@ -46,6 +46,7 @@ def __init__(self,
         self.batch_size = batch_size
         self._move_to_device()
         self.data_module = None
+        self._create_loader()

     def _move_to_device(self):
         device = self._accelerator_connector._parallel_devices[0]
@@ -72,35 +73,20 @@ def _create_loader(self):
             raise RuntimeError('Cannot create Trainer if not all conditions '
                                'are sampled. The Trainer got the following:\n'
                                f'{error_message}')
-        devices = self._accelerator_connector._parallel_devices
-
-        if len(devices) > 1:
-            raise RuntimeError("Parallel training is not supported yet.")
-
-        device = devices[0]
-
-        self.data_module = PinaDataModule(problem=self.solver.problem,
-                                          device=device,
+        self.data_module = PinaDataModule(collector=self.solver.problem.collector,
                                           train_size=self.train_size,
                                           test_size=self.test_size,
                                           val_size=self.val_size,
                                           predict_size=self.predict_size,
                                           batch_size=self.batch_size, )
-        self.data_module.setup()
+        if self.batch_size is None:
+            self.data_module.setup()

     def train(self, **kwargs):
         """
         Train the solver method.
         """
-        self._create_loader()
-        with warnings.catch_warnings():
-            warnings.filterwarnings(
-                "ignore",
-                message="You defined a `validation_step` but have no "
-                        "`val_dataloader`",
-                category=UserWarning
-            )
-            return super().fit(self.solver,
+        return super().fit(self.solver,
+                           datamodule=self.data_module,
                            **kwargs)
diff --git a/setup.py b/setup.py
index c44cacf7..04f7b2fc 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@
 KEYWORDS = 'physics-informed neural-network'

 REQUIRED = [
-    'numpy', 'matplotlib', 'torch', 'lightning', 'pytorch_lightning', 'torch_geometric', 'torch-cluster'
+    'numpy', 'matplotlib', 'torch', 'lightning', 'torch_geometric', 'torch-cluster'
]

 EXTRAS = {