From 30e2fa8ee7b61f09cf15d61eab8047b3040a6b17 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Thu, 31 Oct 2024 14:26:20 +0100 Subject: [PATCH 1/7] Bug fix in LabelTensor, Dataset and DataLoader, solve #377, first attempt with PINN --- pina/condition/data_condition.py | 2 +- pina/condition/domain_equation_condition.py | 3 +- pina/condition/input_equation_condition.py | 3 +- pina/condition/input_output_condition.py | 3 +- pina/data/base_dataset.py | 9 +- pina/data/data_module.py | 21 +- pina/data/pina_batch.py | 16 +- pina/data/pina_dataloader.py | 52 +++- pina/data/pina_subset.py | 19 +- pina/domain/cartesian.py | 6 +- pina/domain/operation_interface.py | 2 +- pina/domain/union_domain.py | 3 +- pina/label_tensor.py | 114 +++++--- pina/problem/abstract_problem.py | 5 +- pina/problem/inverse_problem.py | 3 +- pina/solvers/pinns/basepinn.py | 83 ++++-- pina/solvers/pinns/pinn.py | 37 +-- pina/solvers/solver.py | 6 +- pina/solvers/supervised.py | 6 +- pina/trainer.py | 24 +- tests/test_dataset.py | 70 ++--- .../test_label_tensor/test_label_tensor_01.py | 2 +- tests/test_solvers/test_pinn.py | 263 +++--------------- tests/test_solvers/test_supervised_solver.py | 2 +- tutorials/tutorial5/tutorial.ipynb | 154 +++++----- 25 files changed, 397 insertions(+), 511 deletions(-) diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index c6777231..05c543eb 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -15,6 +15,7 @@ class DataConditionInterface(ConditionInterface): """ __slots__ = ["input_points", "conditional_variables"] + condition_type = ['unsupervised'] def __init__(self, input_points, conditional_variables=None): """ @@ -23,7 +24,6 @@ def __init__(self, input_points, conditional_variables=None): super().__init__() self.input_points = input_points self.conditional_variables = conditional_variables - self._condition_type = 'unsupervised' def __setattr__(self, key, value): if (key == 'input_points') or (key == 'conditional_variables'): diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 58dca70b..53e07621 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -13,7 +13,7 @@ class DomainEquationCondition(ConditionInterface): """ __slots__ = ["domain", "equation"] - + condition_type = ['physics'] def __init__(self, domain, equation): """ TODO @@ -21,7 +21,6 @@ def __init__(self, domain, equation): super().__init__() self.domain = domain self.equation = equation - self._condition_type = 'physics' def __setattr__(self, key, value): if key == 'domain': diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index bf05130c..2a7f4647 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -14,7 +14,7 @@ class InputPointsEquationCondition(ConditionInterface): """ __slots__ = ["input_points", "equation"] - + condition_type = ['physics'] def __init__(self, input_points, equation): """ TODO @@ -22,7 +22,6 @@ def __init__(self, input_points, equation): super().__init__() self.input_points = input_points self.equation = equation - self._condition_type = 'physics' def __setattr__(self, key, value): if key == 'input_points': diff --git a/pina/condition/input_output_condition.py b/pina/condition/input_output_condition.py index 08ed21d9..e9c34bea 100644 --- a/pina/condition/input_output_condition.py +++ 
b/pina/condition/input_output_condition.py @@ -13,7 +13,7 @@ class InputOutputPointsCondition(ConditionInterface): """ __slots__ = ["input_points", "output_points"] - + condition_type = ['supervised'] def __init__(self, input_points, output_points): """ TODO @@ -21,7 +21,6 @@ def __init__(self, input_points, output_points): super().__init__() self.input_points = input_points self.output_points = output_points - self._condition_type = ['supervised', 'physics'] def __setattr__(self, key, value): if (key == 'input_points') or (key == 'output_points'): diff --git a/pina/data/base_dataset.py b/pina/data/base_dataset.py index 2c28ba30..d05784f8 100644 --- a/pina/data/base_dataset.py +++ b/pina/data/base_dataset.py @@ -5,7 +5,6 @@ import logging from torch.utils.data import Dataset - from ..label_tensor import LabelTensor @@ -109,14 +108,14 @@ def initialize(self): already filled """ logging.debug(f'Initialize dataset {self.__class__.__name__}') - if self.num_el_per_condition: self.condition_indices = torch.cat([ - torch.tensor([i] * self.num_el_per_condition[i], - dtype=torch.uint8) + torch.tensor( + [self.conditions_idx[i]] * self.num_el_per_condition[i], + dtype=torch.uint8) for i in range(len(self.num_el_per_condition)) ], - dim=0) + dim=0) for slot in self.__slots__: current_attribute = getattr(self, slot) if all(isinstance(a, LabelTensor) for a in current_attribute): diff --git a/pina/data/data_module.py b/pina/data/data_module.py index bd117b54..ea6a802c 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -23,8 +23,8 @@ def __init__(self, problem, device, train_size=.7, - test_size=.1, - val_size=.2, + test_size=.2, + val_size=.1, predict_size=0., batch_size=None, shuffle=True, @@ -61,28 +61,30 @@ def __init__(self, if train_size > 0: self.split_names.append('train') self.split_length.append(train_size) - self.loader_functions['train_dataloader'] = lambda: PinaDataLoader( - self.splits['train'], self.batch_size, self.condition_names) + self.loader_functions['train_dataloader'] = lambda \ + x: PinaDataLoader(self.splits['train'], self.batch_size, + self.condition_names) if test_size > 0: self.split_length.append(test_size) self.split_names.append('test') - self.loader_functions['test_dataloader'] = lambda: PinaDataLoader( + self.loader_functions['test_dataloader'] = lambda x: PinaDataLoader( self.splits['test'], self.batch_size, self.condition_names) if val_size > 0: self.split_length.append(val_size) self.split_names.append('val') - self.loader_functions['val_dataloader'] = lambda: PinaDataLoader( + self.loader_functions['val_dataloader'] = lambda x: PinaDataLoader( self.splits['val'], self.batch_size, self.condition_names) if predict_size > 0: self.split_length.append(predict_size) self.split_names.append('predict') - self.loader_functions['predict_dataloader'] = lambda: PinaDataLoader( + self.loader_functions[ + 'predict_dataloader'] = lambda x: PinaDataLoader( self.splits['predict'], self.batch_size, self.condition_names) self.splits = {k: {} for k in self.split_names} self.shuffle = shuffle for k, v in self.loader_functions.items(): - setattr(self, k, v) + setattr(self, k, v.__get__(self, PinaDataModule)) def prepare_data(self): if self.datasets is None: @@ -140,12 +142,11 @@ def dataset_split(dataset, lengths, seed=None, shuffle=True): indices = torch.randperm(sum(lengths)) dataset.apply_shuffle(indices) - indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist() offsets = [ sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths)) ] return [ - 
PinaSubset(dataset, indices[offset:offset + length]) + PinaSubset(dataset, slice(offset, offset + length)) for offset, length in zip(offsets, lengths) ] diff --git a/pina/data/pina_batch.py b/pina/data/pina_batch.py index c5d1b61d..79c076da 100644 --- a/pina/data/pina_batch.py +++ b/pina/data/pina_batch.py @@ -27,21 +27,27 @@ def __len__(self): :rtype: int """ length = 0 - for dataset in dir(self): + for dataset in self.attributes: attribute = getattr(self, dataset) - if isinstance(attribute, list): - length += len(getattr(self, dataset)) + length += len(attribute) return length def __getattribute__(self, item): if item in super().__getattribute__('attributes'): dataset = super().__getattribute__(item) index = super().__getattribute__(item + '_idx') - return PinaSubset(dataset.dataset, dataset.indices[index]) + if isinstance(dataset, PinaSubset): + dataset_index = dataset.indices + if isinstance(dataset_index, slice): + index = slice(dataset_index.start + index.start, + min(dataset_index.start + index.stop, + dataset_index.stop)) + return PinaSubset(dataset.dataset, index, + require_grad=self.require_grad) return super().__getattribute__(item) def __getattr__(self, item): if item == 'data' and len(self.attributes) == 1: item = self.attributes[0] - return super().__getattribute__(item) + return self.__getattribute__(item) raise AttributeError(f"'Batch' object has no attribute '{item}'") diff --git a/pina/data/pina_dataloader.py b/pina/data/pina_dataloader.py index e2d3fb76..a28ca6c6 100644 --- a/pina/data/pina_dataloader.py +++ b/pina/data/pina_dataloader.py @@ -2,6 +2,7 @@ This module is used to create an iterable object used during training """ import math + from .pina_batch import Batch @@ -26,6 +27,7 @@ def __init__(self, dataset_dict, batch_size, condition_names) -> None: """ self.condition_names = condition_names self.dataset_dict = dataset_dict + self.batch_size = batch_size self._init_batches(batch_size) def _init_batches(self, batch_size=None): @@ -36,20 +38,46 @@ def _init_batches(self, batch_size=None): n_elements = sum(len(v) for v in self.dataset_dict.values()) if batch_size is None: batch_size = n_elements - indexes_dict = {} + self.batch_size = n_elements n_batches = int(math.ceil(n_elements / batch_size)) - for k, v in self.dataset_dict.items(): - if n_batches != 1: - indexes_dict[k] = math.floor(len(v) / (n_batches - 1)) - else: - indexes_dict[k] = len(v) - for i in range(n_batches): - temp_dict = {} - for k, v in indexes_dict.items(): - if i != n_batches - 1: - temp_dict[k] = slice(i * v, (i + 1) * v) + indexes_dict = { + k: math.floor(len(v) / n_batches) if n_batches != 1 else len(v) for + k, v in self.dataset_dict.items()} + + dataset_names = list(self.dataset_dict.keys()) + num_el_per_batch = [{i: indexes_dict[i] for i in dataset_names} for _ + in range(n_batches - 1)] + [ + {i: 0 for i in dataset_names}] + reminders = { + i: len(self.dataset_dict[i]) - indexes_dict[i] * (n_batches - 1) for + i in dataset_names} + dataset_names = iter(dataset_names) + name = next(dataset_names, None) + for batch in num_el_per_batch: + tot_num_el = sum(batch.values()) + batch_reminder = batch_size - tot_num_el + for _ in range(batch_reminder): + if name is None: + break + if reminders[name] > 0: + batch[name] += 1 + reminders[name] -= 1 else: - temp_dict[k] = slice(i * v, len(self.dataset_dict[k])) + name = next(dataset_names, None) + if name is None: + break + batch[name] += 1 + reminders[name] -= 1 + + reminders, dataset_names, indexes_dict = None, None, None # free memory + 
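+        # turn the per-batch element counts into per-dataset slices by
+        # advancing a running offset (actual_indices) through each dataset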
actual_indices = {k: 0 for k in self.dataset_dict.keys()} + for batch in num_el_per_batch: + temp_dict = {} + total_length = 0 + for k, v in batch.items(): + temp_dict[k] = slice(actual_indices[k], actual_indices[k] + v) + actual_indices[k] = actual_indices[k] + v + total_length += v self.batches.append( Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict)) diff --git a/pina/data/pina_subset.py b/pina/data/pina_subset.py index 275541e9..b5b74a68 100644 --- a/pina/data/pina_subset.py +++ b/pina/data/pina_subset.py @@ -11,7 +11,7 @@ class PinaSubset: """ __slots__ = ['dataset', 'indices', 'require_grad'] - def __init__(self, dataset, indices, require_grad=True): + def __init__(self, dataset, indices, require_grad=False): """ TODO """ @@ -23,14 +23,27 @@ def __len__(self): """ TODO """ + if isinstance(self.indices, slice): + return self.indices.stop - self.indices.start return len(self.indices) def __getattr__(self, name): tensor = self.dataset.__getattribute__(name) if isinstance(tensor, (LabelTensor, Tensor)): - tensor = tensor[[self.indices]].to(self.dataset.device) + if isinstance(self.indices, slice): + tensor = tensor[self.indices] + if (tensor.device != self.dataset.device + and tensor.dtype == float32): + tensor = tensor.to(self.dataset.device) + elif isinstance(self.indices, list): + tensor = tensor[[self.indices]].to(self.dataset.device) + else: + raise ValueError(f"Indices type {type(self.indices)} not " + f"supported") return tensor.requires_grad_( self.require_grad) if tensor.dtype == float32 else tensor if isinstance(tensor, list): - return [tensor[i] for i in self.indices] + if isinstance(self.indices, list): + return [tensor[i] for i in self.indices] + return tensor[self.indices] raise AttributeError(f"No attribute named {name}") diff --git a/pina/domain/cartesian.py b/pina/domain/cartesian.py index 5fe99d64..a7351baf 100644 --- a/pina/domain/cartesian.py +++ b/pina/domain/cartesian.py @@ -236,15 +236,15 @@ def _single_points_sample(n, variables): return result if self.fixed_ and (not self.range_): - return _single_points_sample(n, variables) + return _single_points_sample(n, variables).sort_labels() if variables == "all": variables = list(self.range_.keys()) + list(self.fixed_.keys()) if mode in ["grid", "chebyshev"]: - return _1d_sampler(n, mode, variables) + return _1d_sampler(n, mode, variables).sort_labels() elif mode in ["random", "lh", "latin"]: - return _Nd_sampler(n, mode, variables) + return _Nd_sampler(n, mode, variables).sort_labels() else: raise ValueError(f"mode={mode} is not valid.") diff --git a/pina/domain/operation_interface.py b/pina/domain/operation_interface.py index 0300f524..abe3be2b 100644 --- a/pina/domain/operation_interface.py +++ b/pina/domain/operation_interface.py @@ -66,7 +66,7 @@ def _check_dimensions(self, geometries): :type geometries: list[Location] """ for geometry in geometries: - if geometry.variables != geometries[0].variables: + if sorted(geometry.variables) != sorted(geometries[0].variables): raise NotImplementedError( f"The geometries need to have same dimensions and labels." 
) diff --git a/pina/domain/union_domain.py b/pina/domain/union_domain.py index 0af8e1bd..f181c354 100644 --- a/pina/domain/union_domain.py +++ b/pina/domain/union_domain.py @@ -113,5 +113,4 @@ def sample(self, n, mode="random", variables="all"): # in case number of sampled points is smaller than the number of geometries if len(sampled_points) >= n: break - - return LabelTensor(torch.cat(sampled_points), labels=self.variables) + return LabelTensor.cat(sampled_points) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 719975c5..ad9034b2 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -1,8 +1,10 @@ """ Module for LabelTensor """ -from copy import copy, deepcopy +import warnings import torch +from copy import copy, deepcopy from torch import Tensor +full_labels = True def issubset(a, b): """ @@ -20,7 +22,9 @@ class LabelTensor(torch.Tensor): @staticmethod def __new__(cls, x, labels, *args, **kwargs): + full = kwargs.pop("full", full_labels) if isinstance(x, LabelTensor): + x.full = full return x else: return super().__new__(cls, x, *args, **kwargs) @@ -41,7 +45,7 @@ def __init__(self, x, labels, **kwargs): """ self.dim_names = None - self.full = kwargs.get('full', True) + self.full = kwargs.get('full', full_labels) self.labels = labels @classmethod @@ -53,7 +57,7 @@ def __internal_init__(cls, **kwargs): lt = cls.__new__(cls, x, labels, *args, **kwargs) lt._labels = labels - lt.full = kwargs.get('full', True) + lt.full = kwargs.get('full', full_labels) lt.dim_names = dim_names return lt @@ -124,13 +128,12 @@ def _init_labels_from_dict(self, labels): does not match with tensor shape """ tensor_shape = self.shape - if hasattr(self, 'full') and self.full: labels = { i: labels[i] if i in labels else { - 'name': i + 'name': i, 'dof': range(tensor_shape[i]) } - for i in labels.keys() + for i in range(len(tensor_shape)) } for k, v in labels.items(): # Init labels from str @@ -197,7 +200,6 @@ def extract(self, labels_to_extract): stored_keys = labels.keys() dim_names = self.dim_names ndim = len(super().shape) - # Convert tuple/list to dict if isinstance(labels_to_extract, (tuple, list)): if not ndim - 1 in stored_keys: @@ -215,7 +217,6 @@ def extract(self, labels_to_extract): # Initialize list used to perform extraction extractor = [slice(None) for _ in range(ndim)] - # Loop over labels_to_extract dict for k, v in labels_to_extract.items(): @@ -227,12 +228,11 @@ def extract(self, labels_to_extract): dim_labels = labels[idx_dim]['dof'] v = [v] if isinstance(v, (int, str)) else v - if not isinstance(v, range): extractor[idx_dim] = [dim_labels.index(i) for i in v] if len(v) > 1 else slice( - dim_labels.index(v[0]), - dim_labels.index(v[0]) + 1) + dim_labels.index(v[0]), + dim_labels.index(v[0]) + 1) else: extractor[idx_dim] = slice(v.start, v.stop) @@ -274,31 +274,33 @@ def cat(tensors, dim=0): return tensors[0] # Perform cat on tensors new_tensor = torch.cat(tensors, dim=dim) - + new_tensor_shape = new_tensor.shape # Update labels - labels = LabelTensor.__create_labels_cat(tensors, dim) + labels = LabelTensor.__create_labels_cat(tensors, dim, new_tensor_shape) return LabelTensor.__internal_init__(new_tensor, labels, tensors[0].dim_names) @staticmethod - def __create_labels_cat(tensors, dim): + def __create_labels_cat(tensors, dim, tensor_shape): # Check if names and dof of the labels are the same in all dimensions # except in dim stored_labels = [tensor.stored_labels for tensor in tensors] - # check if: # - labels dict have same keys # - all labels are the same expect for 
dimension dim - if not all( - all(stored_labels[i][k] == stored_labels[0][k] - for i in range(len(stored_labels))) - for k in stored_labels[0].keys() if k != dim): + if not all(all(stored_labels[i][k] == stored_labels[0][k] + for i in range(len(stored_labels))) + for k in stored_labels[0].keys() if k != dim): raise RuntimeError('tensors must have the same shape and dof') labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()} if dim in labels.keys(): - last_dim_dof = [i for j in stored_labels for i in j[dim]['dof']] + labels_list = [j[dim]['dof'] for j in stored_labels] + if all(isinstance(j, range) for j in labels_list): + last_dim_dof = range(tensor_shape[dim]) + else: + last_dim_dof = [i for j in labels_list for i in j] labels[dim]['dof'] = last_dim_dof return labels @@ -316,10 +318,9 @@ def to(self, *args, **kwargs): Performs Tensor dtype and/or device conversion. For more details, see :meth:`torch.Tensor.to`. """ - tmp = super().to(*args, **kwargs) - new = self.__class__.clone(self) - new.data = tmp.data - return new + lt = super().to(*args, **kwargs) + return LabelTensor.__internal_init__(lt, + self.stored_labels, self.dim_names) def clone(self, *args, **kwargs): """ @@ -329,7 +330,8 @@ def clone(self, *args, **kwargs): :return: A copy of the tensor. :rtype: LabelTensor """ - labels = {k: copy(v) for k, v in self._labels.items()} + labels = {k: {sub_k: copy(sub_v) for sub_k, sub_v in v.items()} for k, v + in self.stored_labels.items()} out = LabelTensor(super().clone(*args, **kwargs), labels) return out @@ -402,11 +404,13 @@ def __getitem__(self, index): :param index: :return: """ - if isinstance(index, - str) or (isinstance(index, (tuple, list)) - and all(isinstance(a, str) for a in index)): + if isinstance(index, str) or (isinstance(index, (tuple, list)) + and all( + isinstance(a, str) for a in index)): return self.extract(index) + if isinstance(index, torch.Tensor) and index.dtype == torch.bool: + index = [index.nonzero().squeeze()] selected_lt = super().__getitem__(index) if isinstance(index, (int, slice)): @@ -415,15 +419,22 @@ def __getitem__(self, index): if index[0] == Ellipsis: index = [slice(None)] * (self.ndim - 1) + [index[1]] - if hasattr(self, "labels"): - labels = {k: copy(v) for k, v in self.stored_labels.items()} + try: + stored_labels = self.stored_labels + labels = {} for j, idx in enumerate(index): - if isinstance(idx, int): + if isinstance(idx, int) or ( + isinstance(idx, torch.Tensor) and idx.ndim == 0): selected_lt = selected_lt.unsqueeze(j) - if j in labels.keys() and idx != slice(None): - self._update_single_label(labels, labels, idx, j) + if j in self.stored_labels.keys() and idx != slice(None): + self._update_single_label(stored_labels, labels, idx, j) + labels.update( + {k: {sub_k: copy(sub_v) for sub_k, sub_v in v.items()} for k, v + in stored_labels.items() if k not in labels}) selected_lt = LabelTensor.__internal_init__(selected_lt, labels, self.dim_names) + except AttributeError: + warnings.warn('No attribute labels in LabelTensor') return selected_lt @staticmethod @@ -436,16 +447,25 @@ def _update_single_label(old_labels, to_update_labels, index, dim): :param dim: label index :return: """ + old_dof = old_labels[dim]['dof'] - if not isinstance( - index, - (int, slice)) and len(index) == len(old_dof) and isinstance( - old_dof, range): + if isinstance(index, torch.Tensor) and index.ndim == 0: + index = int(index) + if (not isinstance( + index, (int, slice)) and len(index) == len(old_dof) and + isinstance(old_dof, range)): return + if 
isinstance(index, torch.Tensor): - index = index.nonzero( - as_tuple=True - )[0] if index.dtype == torch.bool else index.tolist() + if isinstance(old_dof, range): + to_update_labels.update({ + dim: { + 'dof': index.tolist(), + 'name': old_labels[dim]['name'] + } + }) + return + index = index.tolist() if isinstance(index, list): to_update_labels.update({ dim: { @@ -453,12 +473,14 @@ def _update_single_label(old_labels, to_update_labels, index, dim): 'name': old_labels[dim]['name'] } }) - else: - to_update_labels.update( - {dim: { - 'dof': old_dof[index], - 'name': old_labels[dim]['name'] - }}) + return + to_update_labels.update( + {dim: { + 'dof': old_dof[index] if isinstance(old_dof[index], + (list, range)) else [ + old_dof[index]], + 'name': old_labels[dim]['name'] + }}) def sort_labels(self, dim=None): diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index 6897fbb7..ee590e02 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -90,8 +90,8 @@ def input_variables(self): variables += self.spatial_variables if hasattr(self, "temporal_variable"): variables += self.temporal_variable - if hasattr(self, "unknown_parameters"): - variables += self.parameters + #if hasattr(self, "unknown_parameters"): + # variables += self.unknown_parameters if hasattr(self, "custom_variables"): variables += self.custom_variables @@ -170,7 +170,6 @@ def discretise_domain(self, f"Wrong variables for sampling. Variables ", f"should be in {self.input_variables}.", ) - # check correct location if locations == "all": locations = [ diff --git a/pina/problem/inverse_problem.py b/pina/problem/inverse_problem.py index 51cbd3ca..e54495a6 100644 --- a/pina/problem/inverse_problem.py +++ b/pina/problem/inverse_problem.py @@ -1,7 +1,6 @@ """Module for the ParametricProblem class""" - +import torch from abc import abstractmethod - from .abstract_problem import AbstractProblem diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index 543f823f..3762cc88 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -1,14 +1,15 @@ """ Module for PINN """ -import sys from abc import ABCMeta, abstractmethod import torch - -from ...solvers.solver import SolverInterface -from pina.utils import check_consistency -from pina.loss.loss_interface import LossInterface -from pina.problem import InverseProblem from torch.nn.modules.loss import _Loss +from ...condition import InputOutputPointsCondition +from ...solvers.solver import SolverInterface +from ...utils import check_consistency +from ...loss.loss_interface import LossInterface +from ...problem import InverseProblem +from ...condition import DomainEquationCondition +from ...optim import TorchOptimizer, TorchScheduler torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 @@ -25,13 +26,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): to the user to choose which problem the implemented solver inheriting from this class is suitable for. """ - + accepted_condition_types = [DomainEquationCondition.condition_type[0], + InputOutputPointsCondition.condition_type[0]] def __init__( self, models, problem, optimizers, - optimizers_kwargs, + schedulers, extra_features, loss, ): @@ -53,11 +55,20 @@ def __init__( :param torch.nn.Module loss: The loss function used as minimizer, default :class:`torch.nn.MSELoss`. 
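
        :Example: a minimal sketch of constructing a concrete subclass
            (``problem`` and ``model`` are assumed to be an already-defined
            ``AbstractProblem`` and ``torch.nn.Module``):

            >>> solver = PINN(problem=problem, model=model)
            >>> # optimizer, scheduler and loss left as None fall back to
            >>> # Adam (lr=0.001), ConstantLR and torch.nn.MSELoss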
""" + if optimizers is None: + optimizers = TorchOptimizer(torch.optim.Adam, lr=0.001) + + if schedulers is None: + schedulers = TorchScheduler(torch.optim.lr_scheduler.ConstantLR) + + if loss is None: + loss = torch.nn.MSELoss() + super().__init__( models=models, problem=problem, optimizers=optimizers, - optimizers_kwargs=optimizers_kwargs, + schedulers=schedulers, extra_features=extra_features, ) @@ -85,6 +96,10 @@ def __init__( # variable will be stored with name = self.__logged_metric self.__logged_metric = None + self._model = self._pina_models[0] + self._optimizer = self._pina_optimizers[0] + self._scheduler = self._pina_schedulers[0] + def training_step(self, batch, _): """ The Physics Informed Solver Training Step. This function takes care @@ -100,28 +115,43 @@ def training_step(self, batch, _): """ condition_losses = [] - condition_idx = batch["condition"] - for condition_id in range(condition_idx.min(), condition_idx.max() + 1): + physics = batch.physics + if hasattr(batch, 'supervised'): + supervised = batch.supervised + condition_idx = supervised.condition_indices + else: + condition_idx = torch.tensor([]) + + for condition_id in torch.unique(condition_idx).tolist(): + condition_name = self._dataloader.condition_names[condition_id] + condition = self.problem.conditions[condition_name] + self.__logged_metric = condition_name + pts = batch.supervised.input_points + out = batch.supervised.output_points + output_pts = out[condition_idx == condition_id] + input_pts = pts[condition_idx == condition_id] + + input_pts.labels = pts.labels + output_pts.labels = out.labels + + loss = self.loss_data(input_points=input_pts, output_points=output_pts) + loss = loss.as_subclass(torch.Tensor) + + condition_idx = physics.condition_indices + for condition_id in torch.unique(condition_idx).tolist(): condition_name = self._dataloader.condition_names[condition_id] condition = self.problem.conditions[condition_name] - pts = batch["pts"] - # condition name is logged (if logs enabled) self.__logged_metric = condition_name + pts = batch.physics.input_points + input_pts = pts[condition_idx == condition_id] - if len(batch) == 2: - samples = pts[condition_idx == condition_id] - loss = self.loss_phys(samples, condition.equation) - elif len(batch) == 3: - samples = pts[condition_idx == condition_id] - ground_truth = batch["output"][condition_idx == condition_id] - loss = self.loss_data(samples, ground_truth) - else: - raise ValueError("Batch size not supported") + input_pts.labels = pts.labels + loss = self.loss_phys(pts, condition.equation) # add condition losses for each epoch - condition_losses.append(loss * condition.data_weight) + condition_losses.append(loss) # clamp unknown parameters in InverseProblem (if needed) self._clamp_params() @@ -130,7 +160,7 @@ def training_step(self, batch, _): total_loss = sum(condition_losses) return total_loss.as_subclass(torch.Tensor) - def loss_data(self, input_tensor, output_tensor): + def loss_data(self, input_points, output_points): """ The data loss for the PINN solver. It computes the loss between the network output against the true solution. 
This function @@ -142,9 +172,9 @@ def loss_data(self, input_tensor, output_tensor): :return: The residual loss averaged on the input coordinates :rtype: torch.Tensor """ - loss_value = self.loss(self.forward(input_tensor), output_tensor) + loss_value = self.loss(self.forward(input_points), output_points) self.store_log(loss_value=float(loss_value)) - return self.loss(self.forward(input_tensor), output_tensor) + return loss_value @abstractmethod def loss_phys(self, samples, equation): @@ -202,6 +232,7 @@ def store_log(self, loss_value): logger=True, on_epoch=True, on_step=False, + batch_size=self._dataloader.batch_size, ) self.__logged_res_losses.append(loss_value) diff --git a/pina/solvers/pinns/pinn.py b/pina/solvers/pinns/pinn.py index 15f90818..29949710 100644 --- a/pina/solvers/pinns/pinn.py +++ b/pina/solvers/pinns/pinn.py @@ -9,10 +9,8 @@ _LRScheduler as LRScheduler, ) # torch < 2.0 -from torch.optim.lr_scheduler import ConstantLR from .basepinn import PINNInterface -from pina.utils import check_consistency from pina.problem import InverseProblem @@ -56,16 +54,16 @@ class PINN(PINNInterface): DOI: `10.1038 `_. """ + __name__ = 'PINN' + def __init__( self, problem, model, extra_features=None, - loss=torch.nn.MSELoss(), - optimizer=torch.optim.Adam, - optimizer_kwargs={"lr": 0.001}, - scheduler=ConstantLR, - scheduler_kwargs={"factor": 1, "total_iters": 0}, + loss=None, + optimizer=None, + scheduler=None, ): """ :param AbstractProblem problem: The formulation of the problem. @@ -82,20 +80,15 @@ def __init__( :param dict scheduler_kwargs: LR scheduler constructor keyword args. """ super().__init__( - models=[model], + models=model, problem=problem, - optimizers=[optimizer], - optimizers_kwargs=[optimizer_kwargs], + optimizers=optimizer, + schedulers=scheduler, extra_features=extra_features, loss=loss, ) - # check consistency - check_consistency(scheduler, LRScheduler, subclass=True) - check_consistency(scheduler_kwargs, dict) - # assign variables - self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs) self._neural_net = self.models[0] def forward(self, x): @@ -141,16 +134,10 @@ def configure_optimizers(self): """ # if the problem is an InverseProblem, add the unknown parameters # to the parameters that the optimizer needs to optimize - if isinstance(self.problem, InverseProblem): - self.optimizers[0].add_param_group( - { - "params": [ - self._params[var] - for var in self.problem.unknown_variables - ] - } - ) - return self.optimizers, [self.scheduler] + self._optimizer.hook(self._model.parameters()) + self._scheduler.hook(self._optimizer) + return ([self._optimizer.optimizer_instance], + [self._scheduler.scheduler_instance]) @property def scheduler(self): diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py index e00bc8d5..b622546e 100644 --- a/pina/solvers/solver.py +++ b/pina/solvers/solver.py @@ -138,8 +138,8 @@ def _check_solver_consistency(self, problem): TODO """ for _, condition in problem.conditions.items(): - if not set(self.accepted_condition_types).issubset( - condition.condition_type): + if not set(condition.condition_type).issubset( + set(self.accepted_condition_types)): raise ValueError( - f'{self.__name__} support only dose not support condition ' + f'{self.__name__} dose not support condition ' f'{condition.condition_type}') diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index 62fc9914..b9258f7d 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -7,6 +7,7 @@ from ..label_tensor import LabelTensor from 
..utils import check_consistency from ..loss.loss_interface import LossInterface +from ..condition import InputOutputPointsCondition class SupervisedSolver(SolverInterface): @@ -37,7 +38,7 @@ class SupervisedSolver(SolverInterface): we are seeking to approximate multiple (discretised) functions given multiple (discretised) input functions. """ - accepted_condition_types = ['supervised'] + accepted_condition_types = [InputOutputPointsCondition.condition_type[0]] __name__ = 'SupervisedSolver' def __init__(self, @@ -115,10 +116,9 @@ def training_step(self, batch, batch_idx): :return: The sum of the loss functions. :rtype: LabelTensor """ - condition_idx = batch.supervised.condition_indices + condition_idx = batch.supervised.condition_indices for condition_id in range(condition_idx.min(), condition_idx.max() + 1): - condition_name = self._dataloader.condition_names[condition_id] condition = self.problem.conditions[condition_name] pts = batch.supervised.input_points diff --git a/pina/trainer.py b/pina/trainer.py index 58c66f67..49461166 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -14,7 +14,7 @@ def __init__(self, batch_size=None, train_size=.7, test_size=.2, - eval_size=.1, + val_size=.1, **kwargs): """ PINA Trainer class for costumizing every aspect of training via flags. @@ -39,11 +39,12 @@ def __init__(self, check_consistency(batch_size, int) self.train_size = train_size self.test_size = test_size - self.eval_size = eval_size + self.val_size = val_size self.solver = solver self.batch_size = batch_size self._create_loader() self._move_to_device() + self.data_module = None def _move_to_device(self): device = self._accelerator_connector._parallel_devices[0] @@ -64,7 +65,7 @@ def _create_loader(self): if not self.solver.problem.collector.full: error_message = '\n'.join([ f"""{" " * 13} ---> Condition {key} {"sampled" if value else - "not sampled"}""" for key, value in + "not sampled"}""" for key, value in self._solver.problem.collector._is_conditions_ready.items() ]) raise RuntimeError('Cannot create Trainer if not all conditions ' @@ -77,20 +78,21 @@ def _create_loader(self): device = devices[0] - data_module = PinaDataModule(problem=self.solver.problem, - device=device, - train_size=self.train_size, - test_size=self.test_size, - val_size=self.eval_size) - data_module.setup() - self._loader = data_module.train_dataloader() + self.data_module = PinaDataModule(problem=self.solver.problem, + device=device, + train_size=self.train_size, + test_size=self.test_size, + val_size=self.val_size, + batch_size=self.batch_size, ) + self.data_module.setup() def train(self, **kwargs): """ Train the solver method. 
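
        :Example: a minimal sketch, assuming ``pinn`` is an already-built
            solver (mirrors the usage in ``tests/test_solvers/test_pinn.py``):

            >>> trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
            >>> trainer.train()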
""" + self._create_loader() return super().fit(self.solver, - train_dataloaders=self._loader, + datamodule=self.data_module, **kwargs) @property diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 87fd9a15..c8e29cec 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -32,49 +32,49 @@ class Poisson(SpatialProblem): conditions = { 'gamma1': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': 1 - }), - equation=FixedValue(0.0)), + Condition(domain=CartesianDomain({ + 'x': [0, 1], + 'y': 1 + }), + equation=FixedValue(0.0)), 'gamma2': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': 0 - }), - equation=FixedValue(0.0)), + Condition(domain=CartesianDomain({ + 'x': [0, 1], + 'y': 0 + }), + equation=FixedValue(0.0)), 'gamma3': - Condition(domain=CartesianDomain({ - 'x': 1, - 'y': [0, 1] - }), - equation=FixedValue(0.0)), + Condition(domain=CartesianDomain({ + 'x': 1, + 'y': [0, 1] + }), + equation=FixedValue(0.0)), 'gamma4': - Condition(domain=CartesianDomain({ - 'x': 0, - 'y': [0, 1] - }), - equation=FixedValue(0.0)), + Condition(domain=CartesianDomain({ + 'x': 0, + 'y': [0, 1] + }), + equation=FixedValue(0.0)), 'D': - Condition(input_points=LabelTensor(torch.rand(size=(100, 2)), - ['x', 'y']), - equation=my_laplace), + Condition(input_points=LabelTensor(torch.rand(size=(100, 2)), + ['x', 'y']), + equation=my_laplace), 'data': - Condition(input_points=in_, output_points=out_), + Condition(input_points=in_, output_points=out_), 'data2': - Condition(input_points=in2_, output_points=out2_), + Condition(input_points=in2_, output_points=out2_), 'unsupervised': - Condition( - input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']), - conditional_variables=LabelTensor(torch.ones(size=(45, 1)), - ['alpha']), - ), + Condition( + input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']), + conditional_variables=LabelTensor(torch.ones(size=(45, 1)), + ['alpha']), + ), 'unsupervised2': - Condition( - input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']), - conditional_variables=LabelTensor(torch.ones(size=(90, 1)), - ['alpha']), - ) + Condition( + input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']), + conditional_variables=LabelTensor(torch.ones(size=(90, 1)), + ['alpha']), + ) } diff --git a/tests/test_label_tensor/test_label_tensor_01.py b/tests/test_label_tensor/test_label_tensor_01.py index 57aafb8c..ea43307c 100644 --- a/tests/test_label_tensor/test_label_tensor_01.py +++ b/tests/test_label_tensor/test_label_tensor_01.py @@ -114,5 +114,5 @@ def test_slice(): assert torch.allclose(tensor_view2, data[3]) tensor_view3 = tensor[:, 2] - assert tensor_view3.labels == labels[2] + assert tensor_view3.labels == [labels[2]] assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1)) diff --git a/tests/test_solvers/test_pinn.py b/tests/test_solvers/test_pinn.py index 8ee9d612..72887a4f 100644 --- a/tests/test_solvers/test_pinn.py +++ b/tests/test_solvers/test_pinn.py @@ -1,5 +1,4 @@ import torch - from pina.problem import SpatialProblem, InverseProblem from pina.operators import laplacian from pina.domain import CartesianDomain @@ -9,7 +8,7 @@ from pina.model import FeedForward from pina.equation.equation import Equation from pina.equation.equation_factory import FixedValue -from pina.loss.loss_interface import LpLoss +from pina.loss import LpLoss def laplace_equation(input_, output_): @@ -54,22 +53,22 @@ def laplace_equation(input_, output_, params_): # define the conditions for the loss (boundary conditions, equation, data) conditions 
= { - 'gamma1': Condition(location=CartesianDomain({'x': [x_min, x_max], + 'gamma1': Condition(domain=CartesianDomain({'x': [x_min, x_max], 'y': y_max}), equation=FixedValue(0.0, components=['u'])), - 'gamma2': Condition(location=CartesianDomain( + 'gamma2': Condition(domain=CartesianDomain( {'x': [x_min, x_max], 'y': y_min }), equation=FixedValue(0.0, components=['u'])), - 'gamma3': Condition(location=CartesianDomain( + 'gamma3': Condition(domain=CartesianDomain( {'x': x_max, 'y': [y_min, y_max] }), equation=FixedValue(0.0, components=['u'])), - 'gamma4': Condition(location=CartesianDomain( + 'gamma4': Condition(domain=CartesianDomain( {'x': x_min, 'y': [y_min, y_max] }), equation=FixedValue(0.0, components=['u'])), - 'D': Condition(location=CartesianDomain( + 'D': Condition(domain=CartesianDomain( {'x': [x_min, x_max], 'y': [y_min, y_max] }), equation=Equation(laplace_equation)), @@ -84,16 +83,16 @@ class Poisson(SpatialProblem): conditions = { 'gamma1': Condition( - location=CartesianDomain({'x': [0, 1], 'y': 1}), + domain=CartesianDomain({'x': [0, 1], 'y': 1}), equation=FixedValue(0.0)), 'gamma2': Condition( - location=CartesianDomain({'x': [0, 1], 'y': 0}), + domain=CartesianDomain({'x': [0, 1], 'y': 0}), equation=FixedValue(0.0)), 'gamma3': Condition( - location=CartesianDomain({'x': 1, 'y': [0, 1]}), + domain=CartesianDomain({'x': 1, 'y': [0, 1]}), equation=FixedValue(0.0)), 'gamma4': Condition( - location=CartesianDomain({'x': 0, 'y': [0, 1]}), + domain=CartesianDomain({'x': 0, 'y': [0, 1]}), equation=FixedValue(0.0)), 'D': Condition( input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']), @@ -112,7 +111,6 @@ def poisson_sol(self, pts): truth_solution = poisson_sol - class myFeature(torch.nn.Module): """ Feature: sin(x) @@ -158,12 +156,10 @@ def test_train_cpu(): pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) trainer = Trainer(solver=pinn, max_epochs=1, - accelerator='cpu', batch_size=20) - trainer.train() - + accelerator='cpu', batch_size=20, val_size=0., train_size=1., test_size=0.) 
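+    trainer.train()
+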
-def test_train_restore(): - tmpdir = "tests/tmp_restore" +def test_train_load(): + tmpdir = "tests/tmp_load" poisson_problem = Poisson() boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] n = 10 @@ -173,20 +169,25 @@ def test_train_restore(): extra_features=None, loss=LpLoss()) trainer = Trainer(solver=pinn, - max_epochs=5, + max_epochs=15, accelerator='cpu', default_root_dir=tmpdir) trainer.train() - ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu') - t = ntrainer.train( - ckpt_path=f'{tmpdir}/lightning_logs/version_0/' - 'checkpoints/epoch=4-step=10.ckpt') + new_pinn = PINN.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt', + problem = poisson_problem, model=model) + test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10) + assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) + assert new_pinn.forward(test_pts).extract( + ['u']).shape == pinn.forward(test_pts).extract(['u']).shape + torch.testing.assert_close( + new_pinn.forward(test_pts).extract(['u']), + pinn.forward(test_pts).extract(['u'])) import shutil shutil.rmtree(tmpdir) - -def test_train_load(): - tmpdir = "tests/tmp_load" +def test_train_restore(): + tmpdir = "tests/tmp_restore" poisson_problem = Poisson() boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] n = 10 @@ -196,20 +197,14 @@ def test_train_load(): extra_features=None, loss=LpLoss()) trainer = Trainer(solver=pinn, - max_epochs=15, + max_epochs=5, accelerator='cpu', default_root_dir=tmpdir) trainer.train() - new_pinn = PINN.load_from_checkpoint( - f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt', - problem = poisson_problem, model=model) - test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10) - assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) - assert new_pinn.forward(test_pts).extract( - ['u']).shape == pinn.forward(test_pts).extract(['u']).shape - torch.testing.assert_close( - new_pinn.forward(test_pts).extract(['u']), - pinn.forward(test_pts).extract(['u'])) + ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu') + t = ntrainer.train( + ckpt_path=f'{tmpdir}/lightning_logs/version_0/' + 'checkpoints/epoch=4-step=5.ckpt') import shutil shutil.rmtree(tmpdir) @@ -217,36 +212,24 @@ def test_train_inverse_problem_cpu(): poisson_problem = InversePoisson() boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] n = 100 - poisson_problem.discretise_domain(n, 'random', locations=boundaries) + poisson_problem.discretise_domain(n, 'random', locations=boundaries, + variables=['x', 'y']) pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) trainer = Trainer(solver=pinn, max_epochs=1, accelerator='cpu', batch_size=20) trainer.train() - -# # TODO does not currently work -# def test_train_inverse_problem_restore(): -# tmpdir = "tests/tmp_restore_inv" -# poisson_problem = InversePoisson() -# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] -# n = 100 -# poisson_problem.discretise_domain(n, 'random', locations=boundaries) -# pinn = PINN(problem=poisson_problem, -# model=model, -# extra_features=None, -# loss=LpLoss()) -# trainer = Trainer(solver=pinn, -# max_epochs=5, -# accelerator='cpu', -# default_root_dir=tmpdir) -# trainer.train() -# ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') -# t = ntrainer.train( -# ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt') -# import shutil -# shutil.rmtree(tmpdir) - +def test_train_extra_feats_cpu(): + 
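+    # smoke test: the model is wrapped with the extra input feature
+    # (the sin-based ``myFeature`` defined above) before training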
poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model_extra_feats, + extra_features=extra_feats) + trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') + trainer.train() def test_train_inverse_problem_load(): tmpdir = "tests/tmp_load_inv" @@ -264,7 +247,7 @@ def test_train_inverse_problem_load(): default_root_dir=tmpdir) trainer.train() new_pinn = PINN.load_from_checkpoint( - f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt', + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt', problem = poisson_problem, model=model) test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10) assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) @@ -274,160 +257,4 @@ def test_train_inverse_problem_load(): new_pinn.forward(test_pts).extract(['u']), pinn.forward(test_pts).extract(['u'])) import shutil - shutil.rmtree(tmpdir) - -# # TODO fix asap. Basically sampling few variables -# # works only if both variables are in a range. -# # if one is fixed and the other not, this will -# # not work. This test also needs to be fixed and -# # insert in test problem not in test pinn. -# def test_train_cpu_sampling_few_vars(): -# poisson_problem = Poisson() -# boundaries = ['gamma1', 'gamma2', 'gamma3'] -# n = 10 -# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) -# poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['x']) -# poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['y']) -# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) -# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'}) -# trainer.train() - - -def test_train_extra_feats_cpu(): - poisson_problem = Poisson() - boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] - n = 10 - poisson_problem.discretise_domain(n, 'grid', locations=boundaries) - pinn = PINN(problem=poisson_problem, - model=model_extra_feats, - extra_features=extra_feats) - trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') - trainer.train() - - -# TODO, fix GitHub actions to run also on GPU -# def test_train_gpu(): -# poisson_problem = Poisson() -# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] -# n = 10 -# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) -# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) -# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) -# trainer.train() - -# def test_train_gpu(): #TODO fix ASAP -# poisson_problem = Poisson() -# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] -# n = 10 -# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) -# poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu -# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) -# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) -# trainer.train() - -# def test_train_2(): -# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] -# n = 10 -# expected_keys = [[], list(range(0, 50, 3))] -# param = [0, 3] -# for i, truth_key in zip(param, expected_keys): -# pinn = PINN(problem, model) -# pinn.discretise_domain(n, 'grid', locations=boundaries) -# pinn.discretise_domain(n, 'grid', locations=['D']) -# pinn.train(50, 
save_loss=i) -# assert list(pinn.history_loss.keys()) == truth_key - - -# def test_train_extra_feats(): -# pinn = PINN(problem, model_extra_feat, [myFeature()]) -# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] -# n = 10 -# pinn.discretise_domain(n, 'grid', locations=boundaries) -# pinn.discretise_domain(n, 'grid', locations=['D']) -# pinn.train(5) - - -# def test_train_2_extra_feats(): -# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] -# n = 10 -# expected_keys = [[], list(range(0, 50, 3))] -# param = [0, 3] -# for i, truth_key in zip(param, expected_keys): -# pinn = PINN(problem, model_extra_feat, [myFeature()]) -# pinn.discretise_domain(n, 'grid', locations=boundaries) -# pinn.discretise_domain(n, 'grid', locations=['D']) -# pinn.train(50, save_loss=i) -# assert list(pinn.history_loss.keys()) == truth_key - - -# def test_train_with_optimizer_kwargs(): -# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] -# n = 10 -# expected_keys = [[], list(range(0, 50, 3))] -# param = [0, 3] -# for i, truth_key in zip(param, expected_keys): -# pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3}) -# pinn.discretise_domain(n, 'grid', locations=boundaries) -# pinn.discretise_domain(n, 'grid', locations=['D']) -# pinn.train(50, save_loss=i) -# assert list(pinn.history_loss.keys()) == truth_key - - -# def test_train_with_lr_scheduler(): -# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] -# n = 10 -# expected_keys = [[], list(range(0, 50, 3))] -# param = [0, 3] -# for i, truth_key in zip(param, expected_keys): -# pinn = PINN( -# problem, -# model, -# lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR, -# lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False} -# ) -# pinn.discretise_domain(n, 'grid', locations=boundaries) -# pinn.discretise_domain(n, 'grid', locations=['D']) -# pinn.train(50, save_loss=i) -# assert list(pinn.history_loss.keys()) == truth_key - - -# # def test_train_batch(): -# # pinn = PINN(problem, model, batch_size=6) -# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] -# # n = 10 -# # pinn.discretise_domain(n, 'grid', locations=boundaries) -# # pinn.discretise_domain(n, 'grid', locations=['D']) -# # pinn.train(5) - - -# # def test_train_batch_2(): -# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] -# # n = 10 -# # expected_keys = [[], list(range(0, 50, 3))] -# # param = [0, 3] -# # for i, truth_key in zip(param, expected_keys): -# # pinn = PINN(problem, model, batch_size=6) -# # pinn.discretise_domain(n, 'grid', locations=boundaries) -# # pinn.discretise_domain(n, 'grid', locations=['D']) -# # pinn.train(50, save_loss=i) -# # assert list(pinn.history_loss.keys()) == truth_key - - -# if torch.cuda.is_available(): - -# # def test_gpu_train(): -# # pinn = PINN(problem, model, batch_size=20, device='cuda') -# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] -# # n = 100 -# # pinn.discretise_domain(n, 'grid', locations=boundaries) -# # pinn.discretise_domain(n, 'grid', locations=['D']) -# # pinn.train(5) - -# def test_gpu_train_nobatch(): -# pinn = PINN(problem, model, batch_size=None, device='cuda') -# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] -# n = 100 -# pinn.discretise_domain(n, 'grid', locations=boundaries) -# pinn.discretise_domain(n, 'grid', locations=['D']) -# pinn.train(5) - + shutil.rmtree(tmpdir) \ No newline at end of file diff --git a/tests/test_solvers/test_supervised_solver.py b/tests/test_solvers/test_supervised_solver.py index 8ceadcd9..ebe8179e 100644 --- 
a/tests/test_solvers/test_supervised_solver.py +++ b/tests/test_solvers/test_supervised_solver.py @@ -121,7 +121,7 @@ def test_train_cpu(): batch_size=5, train_size=1, test_size=0., - eval_size=0.) + val_size=0.) trainer.train() test_train_cpu() diff --git a/tutorials/tutorial5/tutorial.ipynb b/tutorials/tutorial5/tutorial.ipynb index 64032d31..d8e98546 100644 --- a/tutorials/tutorial5/tutorial.ipynb +++ b/tutorials/tutorial5/tutorial.ipynb @@ -22,8 +22,8 @@ "id": "5f2744dc", "metadata": { "ExecuteTime": { - "end_time": "2024-09-19T13:35:28.837348Z", - "start_time": "2024-09-19T13:35:27.611334Z" + "end_time": "2024-10-31T13:30:49.886194Z", + "start_time": "2024-10-31T13:30:48.624214Z" } }, "source": [ @@ -61,8 +61,8 @@ "id": "2ffb8a4c", "metadata": { "ExecuteTime": { - "end_time": "2024-09-19T13:35:28.989631Z", - "start_time": "2024-09-19T13:35:28.952744Z" + "end_time": "2024-10-31T13:30:49.920642Z", + "start_time": "2024-10-31T13:30:49.888843Z" } }, "source": [ @@ -70,14 +70,18 @@ "data = io.loadmat(\"Data_Darcy.mat\")\n", "\n", "# extract data (we use only 100 data for train)\n", - "k_train = LabelTensor(torch.tensor(data['k_train'], dtype=torch.float).unsqueeze(-1), \n", - " labels={3:{'dof': ['u0'], 'name': 'k_train'}})\n", - "u_train = LabelTensor(torch.tensor(data['u_train'], dtype=torch.float).unsqueeze(-1),\n", - " labels={3:{'dof': ['u'], 'name': 'u_train'}})\n", - "k_test = LabelTensor(torch.tensor(data['k_test'], dtype=torch.float).unsqueeze(-1),\n", - " labels={3:{'dof': ['u0'], 'name': 'k_test'}})\n", - "u_test= LabelTensor(torch.tensor(data['u_test'], dtype=torch.float).unsqueeze(-1),\n", - " labels={3:{'dof': ['u'], 'name': 'u_test'}})\n", + "k_train = LabelTensor(\n", + " torch.tensor(data['k_train'], dtype=torch.float).unsqueeze(-1),\n", + " labels={3: {'dof': ['u0'], 'name': 'k_train'}})\n", + "u_train = LabelTensor(\n", + " torch.tensor(data['u_train'], dtype=torch.float).unsqueeze(-1),\n", + " labels={3: {'dof': ['u'], 'name': 'u_train'}})\n", + "k_test = LabelTensor(\n", + " torch.tensor(data['k_test'], dtype=torch.float).unsqueeze(-1),\n", + " labels={3: {'dof': ['u0'], 'name': 'k_test'}})\n", + "u_test = LabelTensor(\n", + " torch.tensor(data['u_test'], dtype=torch.float).unsqueeze(-1),\n", + " labels={3: {'dof': ['u'], 'name': 'u_test'}})\n", "x = torch.tensor(data['x'], dtype=torch.float)[0]\n", "y = torch.tensor(data['y'], dtype=torch.float)[0]" ], @@ -97,8 +101,8 @@ "id": "c8501b6f", "metadata": { "ExecuteTime": { - "end_time": "2024-09-19T13:35:29.108381Z", - "start_time": "2024-09-19T13:35:29.031076Z" + "end_time": "2024-10-31T13:30:50.101674Z", + "start_time": "2024-10-31T13:30:50.034859Z" } }, "source": [ @@ -124,32 +128,6 @@ ], "execution_count": 3 }, - { - "cell_type": "code", - "id": "082ab7a8-22e0-498b-b138-158dc9f2658f", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-19T13:35:29.122858Z", - "start_time": "2024-09-19T13:35:29.119985Z" - } - }, - "source": [ - "u_train.labels[3]['dof']" - ], - "outputs": [ - { - "data": { - "text/plain": [ - "['u']" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 4 - }, { "cell_type": "markdown", "id": "89a77ff1", @@ -163,33 +141,23 @@ "id": "8b27d283", "metadata": { "ExecuteTime": { - "end_time": "2024-09-19T13:35:29.136572Z", - "start_time": "2024-09-19T13:35:29.134124Z" + "end_time": "2024-10-31T13:30:50.191764Z", + "start_time": "2024-10-31T13:30:50.189977Z" } }, "source": [ "class NeuralOperatorSolver(AbstractProblem):\n", - " 
input_variables = k_train.labels[3]['dof']\n", - " output_variables = u_train.labels[3]['dof']\n", - " domains = {\n", - " 'pts': k_train\n", - " }\n", - " conditions = {'data' : Condition(domain='pts', \n", - " output_points=u_train)}\n", + " input_variables = k_train.labels\n", + " output_variables = u_train.labels\n", + " conditions = {'data': Condition(input_points=k_train,\n", + " output_points=u_train)}\n", + "\n", "\n", "# make problem\n", "problem = NeuralOperatorSolver()" ], - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "execution_count": 5 + "outputs": [], + "execution_count": 4 }, { "cell_type": "markdown", @@ -206,20 +174,21 @@ "id": "e34f18b0", "metadata": { "ExecuteTime": { - "end_time": "2024-09-19T13:35:31.245429Z", - "start_time": "2024-09-19T13:35:29.154937Z" + "end_time": "2024-10-31T13:30:52.528635Z", + "start_time": "2024-10-31T13:30:50.225049Z" } }, "source": [ "# make model\n", "model = FeedForward(input_dimensions=1, output_dimensions=1)\n", "\n", - "\n", "# make solver\n", "solver = SupervisedSolver(problem=problem, model=model)\n", "\n", "# make the trainer and train\n", - "trainer = Trainer(solver=solver, max_epochs=10, accelerator='cpu', enable_model_summary=False, batch_size=10) \n", + "trainer = Trainer(solver=solver, max_epochs=10, accelerator='cpu',\n", + " enable_model_summary=False, batch_size=10, train_size=1.,\n", + " test_size=0., val_size=0.)\n", "# We train on CPU and avoid model summary at the beginning of training (optional)\n", "trainer.train()" ], @@ -231,14 +200,15 @@ "GPU available: True (mps), used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", - "/Users/filippoolivo/miniconda3/envs/PINAv0.2/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n" + "/Users/filippoolivo/miniconda3/envs/PINAv0.2-test/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n", + "/Users/filippoolivo/miniconda3/envs/PINAv0.2-test/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. 
Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 9: 100%|██████████| 100/100 [00:00<00:00, 552.80it/s, v_num=18, mean_loss=0.113]" + "Epoch 9: 100%|██████████| 100/100 [00:00<00:00, 459.79it/s, v_num=58, mean_loss=0.108]" ] }, { @@ -252,11 +222,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 9: 100%|██████████| 100/100 [00:00<00:00, 547.37it/s, v_num=18, mean_loss=0.113]\n" + "Epoch 9: 100%|██████████| 100/100 [00:00<00:00, 456.20it/s, v_num=58, mean_loss=0.108]\n" ] } ], - "execution_count": 6 + "execution_count": 5 }, { "cell_type": "markdown", @@ -271,8 +241,8 @@ "id": "0e2a6aa4", "metadata": { "ExecuteTime": { - "end_time": "2024-09-19T13:35:31.295336Z", - "start_time": "2024-09-19T13:35:31.256308Z" + "end_time": "2024-10-31T13:30:52.600181Z", + "start_time": "2024-10-31T13:30:52.564456Z" } }, "source": [ @@ -282,10 +252,12 @@ "metric_err = LpLoss(relative=True)\n", "\n", "model = solver.models[0]\n", - "err = float(metric_err(u_train.squeeze(-1), model(k_train).squeeze(-1)).mean())*100\n", + "err = float(\n", + " metric_err(u_train.squeeze(-1), model(k_train).squeeze(-1)).mean()) * 100\n", "print(f'Final error training {err:.2f}%')\n", "\n", - "err = float(metric_err(u_test.squeeze(-1), model(k_test).squeeze(-1)).mean())*100\n", + "err = float(\n", + " metric_err(u_test.squeeze(-1), model(k_test).squeeze(-1)).mean()) * 100\n", "print(f'Final error testing {err:.2f}%')" ], "outputs": [ @@ -294,11 +266,11 @@ "output_type": "stream", "text": [ "Final error training 56.05%\n", - "Final error testing 55.95%\n" + "Final error testing 56.02%\n" ] } ], - "execution_count": 7 + "execution_count": 6 }, { "cell_type": "markdown", @@ -315,8 +287,8 @@ "id": "9af523a5", "metadata": { "ExecuteTime": { - "end_time": "2024-09-19T13:35:44.717807Z", - "start_time": "2024-09-19T13:35:31.306689Z" + "end_time": "2024-10-31T13:31:06.537460Z", + "start_time": "2024-10-31T13:30:52.694424Z" } }, "source": [ @@ -330,12 +302,14 @@ " inner_size=24,\n", " padding=8)\n", "\n", - "\n", "# make solver\n", "solver = SupervisedSolver(problem=problem, model=model)\n", "\n", "# make the trainer and train\n", - "trainer = Trainer(solver=solver, max_epochs=10, accelerator='cpu', enable_model_summary=False, batch_size=10) # we train on CPU and avoid model summary at beginning of training (optional)\n", + "trainer = Trainer(solver=solver, max_epochs=10, accelerator='cpu',\n", + " enable_model_summary=False, batch_size=10, train_size=1.,\n", + " test_size=0.,\n", + " val_size=0.) 
# we train on CPU and avoid model summary at beginning of training (optional)\n", "trainer.train()" ], "outputs": [ @@ -352,7 +326,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 9: 100%|██████████| 100/100 [00:01<00:00, 73.04it/s, v_num=19, mean_loss=0.00215]" + "Epoch 9: 100%|██████████| 100/100 [00:01<00:00, 72.88it/s, v_num=59, mean_loss=0.00291]" ] }, { @@ -366,11 +340,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 9: 100%|██████████| 100/100 [00:01<00:00, 72.84it/s, v_num=19, mean_loss=0.00215]\n" + "Epoch 9: 100%|██████████| 100/100 [00:01<00:00, 72.69it/s, v_num=59, mean_loss=0.00291]\n" ] } ], - "execution_count": 8 + "execution_count": 7 }, { "cell_type": "markdown", @@ -385,17 +359,19 @@ "id": "58e2db89", "metadata": { "ExecuteTime": { - "end_time": "2024-09-19T13:35:45.259819Z", - "start_time": "2024-09-19T13:35:44.729042Z" + "end_time": "2024-10-31T13:31:07.134298Z", + "start_time": "2024-10-31T13:31:06.657198Z" } }, "source": [ "model = solver.models[0]\n", "\n", - "err = float(metric_err(u_train.squeeze(-1), model(k_train).squeeze(-1)).mean())*100\n", + "err = float(\n", + " metric_err(u_train.squeeze(-1), model(k_train).squeeze(-1)).mean()) * 100\n", "print(f'Final error training {err:.2f}%')\n", "\n", - "err = float(metric_err(u_test.squeeze(-1), model(k_test).squeeze(-1)).mean())*100\n", + "err = float(\n", + " metric_err(u_test.squeeze(-1), model(k_test).squeeze(-1)).mean()) * 100\n", "print(f'Final error testing {err:.2f}%')" ], "outputs": [ @@ -403,12 +379,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "Final error training 7.48%\n", - "Final error testing 7.73%\n" + "Final error training 7.21%\n", + "Final error testing 7.39%\n" ] } ], - "execution_count": 9 + "execution_count": 8 }, { "cell_type": "markdown", @@ -431,8 +407,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-09-19T13:08:35.195331Z", - "start_time": "2024-09-19T13:08:35.193830Z" + "end_time": "2024-10-31T13:31:07.231206Z", + "start_time": "2024-10-31T13:31:07.230017Z" } }, "cell_type": "code", From 457820d8fd4dbfad13489395f99267a852dd827e Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Sat, 9 Nov 2024 14:07:37 +0100 Subject: [PATCH 2/7] Bug fix in SupervisedSolver and improve LabelTensor with override of __torch_functions__ and __mul__ --- pina/label_tensor.py | 56 +++++++++++++++++++++++++++++++++- pina/solvers/pinns/basepinn.py | 24 ++++++--------- pina/solvers/supervised.py | 9 +++--- 3 files changed, 68 insertions(+), 21 deletions(-) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index ad9034b2..a9c9da4e 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -15,7 +15,8 @@ def issubset(a, b): if isinstance(a, range) and isinstance(b, range): return a.start <= b.start and a.stop >= b.stop return False - +MATH_MODULES = {torch.sin, torch.cos, torch.exp, torch.tan, torch.log, + torch.sqrt} class LabelTensor(torch.Tensor): """Torch tensor with a label for any column.""" @@ -48,6 +49,59 @@ def __init__(self, x, labels, **kwargs): self.full = kwargs.get('full', full_labels) self.labels = labels + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func in MATH_MODULES: + str_labels = func.__name__ + labels = copy(args[0].stored_labels) + lt = super().__torch_function__(func, types, args=args, + kwargs=kwargs) + lt_shape = lt.shape + + if len(lt_shape) - 1 in labels.keys(): + labels.update({ + len(lt_shape) - 1: { + 'dof': [f'{str_labels}({i})' for i in + 
labels[len(lt_shape) - 1]['dof']], + 'name': len(lt_shape) - 1 + } + }) + lt._labels = labels + return lt + return super().__torch_function__(func, types, args=args, kwargs=kwargs) + + def __mul__(self, other): + lt = super().__mul__(other) + if isinstance(other, (int, float)): + if hasattr(self, '_labels'): + lt._labels = self._labels + if isinstance(other, LabelTensor): + lt_shape = lt.shape + labels = copy(self.stored_labels) + other_labels = other.stored_labels + check = False + for (k, v), (ko, vo) in zip(sorted(labels.items()), + sorted(other_labels.items())): + if k != ko: + raise ValueError('Labels must be the same') + if k != len(lt_shape) - 1: + if vo != v: + raise ValueError('Labels must be the same') + else: + check = True + if check: + labels.update({ + len(lt_shape) - 1: {'dof': [f'{i}{j}' for i, j in + zip(self.stored_labels[len(lt_shape) - 1]['dof'], + other.stored_labels[len(lt_shape) - 1]['dof'])], + 'name': self.stored_labels[len(lt_shape) - 1]['name']} + }) + + lt._labels = labels + return lt + @classmethod def __internal_init__(cls, x, diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index 3762cc88..fbed0bc5 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -122,7 +122,7 @@ def training_step(self, batch, _): condition_idx = supervised.condition_indices else: condition_idx = torch.tensor([]) - + loss = torch.tensor(0, dtype=torch.float32) for condition_id in torch.unique(condition_idx).tolist(): condition_name = self._dataloader.condition_names[condition_id] condition = self.problem.conditions[condition_name] @@ -132,11 +132,8 @@ def training_step(self, batch, _): output_pts = out[condition_idx == condition_id] input_pts = pts[condition_idx == condition_id] - input_pts.labels = pts.labels - output_pts.labels = out.labels - - loss = self.loss_data(input_points=input_pts, output_points=output_pts) - loss = loss.as_subclass(torch.Tensor) + loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts) + loss += loss_.as_subclass(torch.Tensor) condition_idx = physics.condition_indices for condition_id in torch.unique(condition_idx).tolist(): @@ -147,20 +144,18 @@ def training_step(self, batch, _): pts = batch.physics.input_points input_pts = pts[condition_idx == condition_id] - input_pts.labels = pts.labels - loss = self.loss_phys(pts, condition.equation) + loss_ = self.loss_phys(input_pts, condition.equation) # add condition losses for each epoch - condition_losses.append(loss) + loss += loss_.as_subclass(torch.Tensor) # clamp unknown parameters in InverseProblem (if needed) self._clamp_params() # total loss (must be a torch.Tensor) - total_loss = sum(condition_losses) - return total_loss.as_subclass(torch.Tensor) + return loss - def loss_data(self, input_points, output_points): + def loss_data(self, input_pts, output_pts): """ The data loss for the PINN solver. It computes the loss between the network output against the true solution. 
This function @@ -172,9 +167,8 @@ def loss_data(self, input_points, output_points): :return: The residual loss averaged on the input coordinates :rtype: torch.Tensor """ - loss_value = self.loss(self.forward(input_points), output_points) - self.store_log(loss_value=float(loss_value)) - return loss_value + return self._loss(self.forward(input_pts), output_pts) + @abstractmethod def loss_phys(self, samples, equation): diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index b9258f7d..a2be1102 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -1,5 +1,4 @@ """ Module for SupervisedSolver """ - import torch from torch.nn.modules.loss import _Loss from ..optim import TorchOptimizer, TorchScheduler @@ -118,6 +117,7 @@ def training_step(self, batch, batch_idx): """ condition_idx = batch.supervised.condition_indices + loss = torch.tensor(0, dtype=torch.float32) for condition_id in range(condition_idx.min(), condition_idx.max() + 1): condition_name = self._dataloader.condition_names[condition_id] condition = self.problem.conditions[condition_name] @@ -130,14 +130,13 @@ def training_step(self, batch, batch_idx): if not hasattr(condition, "output_points"): raise NotImplementedError( f"{type(self).__name__} works only in data-driven mode.") + output_pts = out[condition_idx == condition_id] input_pts = pts[condition_idx == condition_id] - input_pts.labels = pts.labels - output_pts.labels = out.labels - loss = self.loss_data(input_pts=input_pts, output_pts=output_pts) - loss = loss.as_subclass(torch.Tensor) + loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts) + loss += loss_.as_subclass(torch.Tensor) self.log("mean_loss", float(loss), prog_bar=True, logger=True) return loss From dfc6bf8d9e45af6aa19f9a27357f092a911ef02d Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Sat, 9 Nov 2024 14:49:16 +0100 Subject: [PATCH 3/7] Improve PinaBatch --- pina/data/pina_batch.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/pina/data/pina_batch.py b/pina/data/pina_batch.py index 79c076da..82d9e51e 100644 --- a/pina/data/pina_batch.py +++ b/pina/data/pina_batch.py @@ -1,6 +1,8 @@ """ Batch management module """ +from attr import attributes + from .pina_subset import PinaSubset @@ -13,11 +15,16 @@ class Batch: def __init__(self, dataset_dict, idx_dict, require_grad=True): self.attributes = [] for k, v in dataset_dict.items(): - setattr(self, k, v) + index = idx_dict[k] + if isinstance(v, PinaSubset): + dataset_index = v.indices + if isinstance(dataset_index, slice): + index = slice(dataset_index.start + index.start, + min(dataset_index.start + index.stop, + dataset_index.stop)) + setattr(self, k, PinaSubset(v.dataset, index, + require_grad=require_grad)) self.attributes.append(k) - - for k, v in idx_dict.items(): - setattr(self, k + '_idx', v) self.require_grad = require_grad def __len__(self): @@ -32,20 +39,6 @@ def __len__(self): length += len(attribute) return length - def __getattribute__(self, item): - if item in super().__getattribute__('attributes'): - dataset = super().__getattribute__(item) - index = super().__getattribute__(item + '_idx') - if isinstance(dataset, PinaSubset): - dataset_index = dataset.indices - if isinstance(dataset_index, slice): - index = slice(dataset_index.start + index.start, - min(dataset_index.start + index.stop, - dataset_index.stop)) - return PinaSubset(dataset.dataset, index, - require_grad=self.require_grad) - return super().__getattribute__(item) - def __getattr__(self, 
item): if item == 'data' and len(self.attributes) == 1: item = self.attributes[0] From 06ac8ad9172a33cc4b5cc04e99e4dc97d481c3ab Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Sat, 9 Nov 2024 14:59:37 +0100 Subject: [PATCH 4/7] Codacy warning correction --- pina/data/pina_batch.py | 4 +--- pina/label_tensor.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/pina/data/pina_batch.py b/pina/data/pina_batch.py index 82d9e51e..48bb0924 100644 --- a/pina/data/pina_batch.py +++ b/pina/data/pina_batch.py @@ -1,8 +1,6 @@ """ Batch management module """ -from attr import attributes - from .pina_subset import PinaSubset @@ -23,7 +21,7 @@ def __init__(self, dataset_dict, idx_dict, require_grad=True): min(dataset_index.start + index.stop, dataset_index.stop)) setattr(self, k, PinaSubset(v.dataset, index, - require_grad=require_grad)) + require_grad=require_grad)) self.attributes.append(k) self.require_grad = require_grad diff --git a/pina/label_tensor.py b/pina/label_tensor.py index a9c9da4e..58dc8b71 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -1,10 +1,13 @@ """ Module for LabelTensor """ import warnings -import torch from copy import copy, deepcopy +import torch from torch import Tensor full_labels = True +MATH_MODULES = {torch.sin, torch.cos, torch.exp, torch.tan, torch.log, + torch.sqrt} + def issubset(a, b): """ @@ -15,8 +18,7 @@ def issubset(a, b): if isinstance(a, range) and isinstance(b, range): return a.start <= b.start and a.stop >= b.stop return False -MATH_MODULES = {torch.sin, torch.cos, torch.exp, torch.tan, torch.log, - torch.sqrt} + class LabelTensor(torch.Tensor): """Torch tensor with a label for any column.""" @@ -27,8 +29,7 @@ def __new__(cls, x, labels, *args, **kwargs): if isinstance(x, LabelTensor): x.full = full return x - else: - return super().__new__(cls, x, *args, **kwargs) + return super().__new__(cls, x, *args, **kwargs) @property def tensor(self): @@ -94,9 +95,12 @@ def __mul__(self, other): if check: labels.update({ len(lt_shape) - 1: {'dof': [f'{i}{j}' for i, j in - zip(self.stored_labels[len(lt_shape) - 1]['dof'], - other.stored_labels[len(lt_shape) - 1]['dof'])], - 'name': self.stored_labels[len(lt_shape) - 1]['name']} + zip(self.stored_labels[ + len(lt_shape) - 1]['dof'], + other.stored_labels[ + len(lt_shape) - 1]['dof'])], + 'name': self.stored_labels[ + len(lt_shape) - 1]['name']} }) lt._labels = labels From d8292694d782ae6994b2918a10a7411a8bc5f1f3 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 11 Nov 2024 09:58:54 +0100 Subject: [PATCH 5/7] Add validation_step and improve data management --- pina/data/data_module.py | 61 ++++++++++++++++------ pina/solvers/solver.py | 1 - pina/solvers/supervised.py | 100 ++++++++++++++++++++++++++++++++++--- pina/trainer.py | 31 +++++++++--- 4 files changed, 163 insertions(+), 30 deletions(-) diff --git a/pina/data/data_module.py b/pina/data/data_module.py index ea6a802c..b09fb54a 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -61,30 +61,31 @@ def __init__(self, if train_size > 0: self.split_names.append('train') self.split_length.append(train_size) - self.loader_functions['train_dataloader'] = lambda \ - x: PinaDataLoader(self.splits['train'], self.batch_size, - self.condition_names) + else: + self.train_dataloader = super().train_dataloader + if test_size > 0: self.split_length.append(test_size) self.split_names.append('test') - self.loader_functions['test_dataloader'] = lambda x: PinaDataLoader( - self.splits['test'], 
self.batch_size, self.condition_names) + else: + self.test_dataloader = super().test_dataloader + if val_size > 0: self.split_length.append(val_size) self.split_names.append('val') - self.loader_functions['val_dataloader'] = lambda x: PinaDataLoader( - self.splits['val'], self.batch_size, self.condition_names) + else: + self.val_dataloader = super().val_dataloader + if predict_size > 0: self.split_length.append(predict_size) self.split_names.append('predict') - self.loader_functions[ - 'predict_dataloader'] = lambda x: PinaDataLoader( - self.splits['predict'], self.batch_size, self.condition_names) + else: + self.predict_dataloader = super().predict_dataloader + self.splits = {k: {} for k in self.split_names} self.shuffle = shuffle - - for k, v in self.loader_functions.items(): - setattr(self, k, v.__get__(self, PinaDataModule)) + self.has_setup_fit = False + self.has_setup_test = False def prepare_data(self): if self.datasets is None: @@ -106,8 +107,12 @@ def setup(self, stage=None): for i in range(len(self.split_length)): self.splits[self.split_names[i]][ dataset.data_type] = splits[i] + self.has_setup_fit = True elif stage == 'test': - raise NotImplementedError("Testing pipeline not implemented yet") + if self.has_setup_fit is False: + raise NotImplementedError( + "You must call setup with stage='fit' " + "first") else: raise ValueError("stage must be either 'fit' or 'test'") @@ -178,3 +183,31 @@ def _create_datasets(self): dataset.initialize() datasets.append(dataset) self.datasets = datasets + + def val_dataloader(self): + """ + Create the validation dataloader + """ + return PinaDataLoader(self.splits['val'], self.batch_size, + self.condition_names) + + def train_dataloader(self): + """ + Create the training dataloader + """ + return PinaDataLoader(self.splits['train'], self.batch_size, + self.condition_names) + + def test_dataloader(self): + """ + Create the testing dataloader + """ + return PinaDataLoader(self.splits['test'], self.batch_size, + self.condition_names) + + def predict_dataloader(self): + """ + Create the prediction dataloader + """ + return PinaDataLoader(self.splits['predict'], self.batch_size, + self.condition_names) diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py index b622546e..fe9c897e 100644 --- a/pina/solvers/solver.py +++ b/pina/solvers/solver.py @@ -83,7 +83,6 @@ def __init__(self, " optimizers.") # extra features handling - self._pina_models = models self._pina_optimizers = optimizers self._pina_schedulers = schedulers diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index a2be1102..ff4153a6 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -1,5 +1,6 @@ """ Module for SupervisedSolver """ import torch +from pytorch_lightning.utilities.types import STEP_OUTPUT from torch.nn.modules.loss import _Loss from ..optim import TorchOptimizer, TorchScheduler from .solver import SolverInterface @@ -75,11 +76,15 @@ def __init__(self, extra_features=extra_features) # check consistency - check_consistency(loss, (LossInterface, _Loss), subclass=False) + check_consistency(loss, (LossInterface, _Loss), + subclass=False) self._loss = loss self._model = self._pina_models[0] self._optimizer = self._pina_optimizers[0] self._scheduler = self._pina_schedulers[0] + self.validation_condition_losses = { + k: {'loss': [], + 'count': []} for k in self.problem.conditions.keys()} def forward(self, x): """Forward pass implementation for the solver. 
@@ -105,7 +110,7 @@ def configure_optimizers(self):
         return ([self._optimizer.optimizer_instance],
                 [self._scheduler.scheduler_instance])
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch):
         """Solver training step.
 
         :param batch: The batch element in the dataloader.
@@ -117,12 +122,14 @@
         """
 
         condition_idx = batch.supervised.condition_indices
-        loss = torch.tensor(0, dtype=torch.float32)
+        loss = torch.tensor(0, dtype=torch.float32).to(self.device)
+        batch = batch.supervised
         for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
-            condition_name = self._dataloader.condition_names[condition_id]
+            condition_name = self.trainer.data_module.condition_names[
+                condition_id]
             condition = self.problem.conditions[condition_name]
-            pts = batch.supervised.input_points
-            out = batch.supervised.output_points
+            pts = batch.input_points
+            out = batch.output_points
             if condition_name not in self.problem.conditions:
                 raise RuntimeError("Something wrong happened.")
 
@@ -134,13 +141,90 @@ def training_step(self, batch):
             output_pts = out[condition_idx == condition_id]
             input_pts = pts[condition_idx == condition_id]
 
             loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
             loss += loss_.as_subclass(torch.Tensor)
 
-        self.log("mean_loss", float(loss), prog_bar=True, logger=True)
+        self.log("mean_loss", float(loss), prog_bar=True, logger=True,
+                 on_epoch=True,
+                 on_step=False, batch_size=self.trainer.data_module.batch_size)
         return loss
 
+    def validation_step(self, batch):
+        """
+        Solver validation step.
+        """
+
+        batch = batch.supervised
+        condition_idx = batch.condition_indices
+        for i in range(condition_idx.min(), condition_idx.max() + 1):
+            condition_name = self.trainer.data_module.condition_names[i]
+            if condition_name not in self.problem.conditions:
+                raise RuntimeError(
+                    f"Condition {condition_name} not found among the "
+                    "problem conditions.")
+            condition = self.problem.conditions[condition_name]
+            pts = batch.input_points
+            out = batch.output_points
+
+            # for data driven mode
+            if not hasattr(condition, "output_points"):
+                raise NotImplementedError(
+                    f"{type(self).__name__} works only in data-driven mode.")
+
+            output_pts = out[condition_idx == i]
+            input_pts = pts[condition_idx == i]
+
+            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
+            self.validation_condition_losses[condition_name]['loss'].append(
+                loss_)
+            self.validation_condition_losses[condition_name]['count'].append(
+                len(input_pts))
+
+    def on_validation_epoch_end(self):
+        """
+        Solver validation epoch end.
+        """
+        total_loss = []
+        total_count = []
+        for k, v in self.validation_condition_losses.items():
+            local_counter = torch.tensor(v['count']).to(self.device)
+            n_elements = torch.sum(local_counter)
+            loss = torch.sum(
+                torch.stack(v['loss']) * local_counter) / n_elements
+            loss = loss.as_subclass(torch.Tensor)
+            total_loss.append(loss)
+            total_count.append(n_elements)
+            self.log(
+                k + "_loss",
+                loss,
+                prog_bar=True,
+                logger=True,
+                on_epoch=True,
+                on_step=False,
+                batch_size=self.trainer.data_module.batch_size,
+            )
+        total_count = (torch.tensor(total_count, dtype=torch.float32).
+                       to(self.device))
+        mean_loss = (torch.sum(torch.stack(total_loss) * total_count) /
+                     torch.sum(total_count))
+        self.log(
+            "val_loss",
+            mean_loss,
+            prog_bar=True,
+            logger=True,
+            on_epoch=True,
+            on_step=False,
+            batch_size=self.trainer.data_module.batch_size,
+        )
+        for key in self.validation_condition_losses.keys():
+            self.validation_condition_losses[key]['loss'] = []
+            self.validation_condition_losses[key]['count'] = []
+
+    def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
+        """
+        Solver test step.
+        """
+
+        raise NotImplementedError("Test step not implemented.")
+
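The epoch-end aggregation above is a weighted mean: each condition's average loss is weighted by the number of validation points behind it, so `val_loss` cannot be dominated by a condition that contributed only a handful of points. A minimal standalone sketch of the same arithmetic (the losses and counts below are hypothetical):

    import torch

    # Per-condition mean losses and the number of points behind each one,
    # standing in for the entries of validation_condition_losses.
    losses = torch.tensor([0.12, 0.48, 0.30])
    counts = torch.tensor([100., 20., 50.])

    # Weighted mean, sum(loss_i * count_i) / sum(count_i): conditions with
    # more points weigh more in the aggregate.
    val_loss = torch.sum(losses * counts) / torch.sum(counts)
    print(f"val_loss = {val_loss:.4f}")
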
     def loss_data(self, input_pts, output_pts):
         """
         The data loss for the Supervised solver. It computes the loss between
diff --git a/pina/trainer.py b/pina/trainer.py
index 49461166..46d26471 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -1,5 +1,5 @@
 """ Trainer module. """
-
+import warnings
 import torch
 import pytorch_lightning
 from .utils import check_consistency
@@ -15,6 +15,7 @@ def __init__(self,
                  train_size=.7,
                  test_size=.2,
                  val_size=.1,
+                 predict_size=.0,
                  **kwargs):
         """
         PINA Trainer class for customizing every aspect of training via flags.
@@ -30,8 +31,8 @@ def __init__(self,
         and can be chosen from the `pytorch-lightning Trainer API
         `_
         """
-
-        super().__init__(**kwargs)
+        log_every_n_steps = kwargs.get('log_every_n_steps', 0)
+        super().__init__(log_every_n_steps=log_every_n_steps, **kwargs)
 
         # check inheritance consistency for solver and batch size
         check_consistency(solver, SolverInterface)
@@ -40,9 +41,9 @@ def __init__(self,
         self.train_size = train_size
         self.test_size = test_size
         self.val_size = val_size
+        self.predict_size = predict_size
         self.solver = solver
         self.batch_size = batch_size
-        self._create_loader()
         self._move_to_device()
         self.data_module = None
@@ -83,6 +84,7 @@ def _create_loader(self):
             train_size=self.train_size,
             test_size=self.test_size,
             val_size=self.val_size,
+            predict_size=self.predict_size,
             batch_size=self.batch_size,
         )
         self.data_module.setup()
@@ -91,9 +93,24 @@ def train(self, **kwargs):
         """
         Train the solver method.
         """
         self._create_loader()
-        return super().fit(self.solver,
-                           datamodule=self.data_module,
-                           **kwargs)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                message="You defined a `validation_step` but have no "
+                        "`val_dataloader`",
+                category=UserWarning
+            )
+            return super().fit(self.solver,
+                               datamodule=self.data_module,
+                               **kwargs)
+
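The `train` wrapper narrows the suppression to the single Lightning message about a missing `val_dataloader`, and `warnings.catch_warnings` restores the previous filter state on exit, so the filter cannot leak past the call. The pattern in isolation (the warning is raised by hand here as a stand-in for Lightning's):

    import warnings

    with warnings.catch_warnings():
        # The message argument is a regex matched against the start of the
        # warning text, as in train() above.
        warnings.filterwarnings("ignore",
                                message="You defined a `validation_step`",
                                category=UserWarning)
        warnings.warn("You defined a `validation_step` but have no "
                      "`val_dataloader`", UserWarning)  # suppressed
    warnings.warn("other warnings still surface", UserWarning)  # printed
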
+ """ + return super().test(self.solver, + datamodule=self.data_module, + **kwargs) @property def solver(self): From e851c333ac501e472dcb3c2a6740c3aa1d671c20 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 11 Nov 2024 13:59:54 +0100 Subject: [PATCH 6/7] Improve training loop in for PINN and supervised solver --- pina/data/pina_batch.py | 48 ++++++++++++++++++++++++++++++++++ pina/solvers/pinns/basepinn.py | 35 +++++++------------------ pina/solvers/supervised.py | 31 +++++----------------- pina/trainer.py | 2 +- 4 files changed, 65 insertions(+), 51 deletions(-) diff --git a/pina/data/pina_batch.py b/pina/data/pina_batch.py index 48bb0924..e43e1108 100644 --- a/pina/data/pina_batch.py +++ b/pina/data/pina_batch.py @@ -1,6 +1,9 @@ """ Batch management module """ +import torch +from ..label_tensor import LabelTensor + from .pina_subset import PinaSubset @@ -42,3 +45,48 @@ def __getattr__(self, item): item = self.attributes[0] return self.__getattribute__(item) raise AttributeError(f"'Batch' object has no attribute '{item}'") + + def get_data(self, batch_name=None): + """ + # TODO + """ + data = getattr(self, batch_name) + to_return_list = [] + if isinstance(data, PinaSubset): + items = data.dataset.__slots__ + else: + items = data.__slots__ + indices = torch.unique(data.condition_indices).tolist() + condition_idx = data.condition_indices + for i in indices: + temp = [] + for j in items: + var = getattr(data, j) + if isinstance(var, (torch.Tensor, LabelTensor)): + temp.append(var[i == condition_idx]) + if isinstance(var, list) and len(var) > 0: + temp.append([var[k] for k in range(len(var)) if + i == condition_idx[k]]) + temp.append(i) + to_return_list.append(temp) + return to_return_list + + def get_supervised_data(self): + """ + Get a subset of the batch + :param idx: indices of the subset + :type idx: slice + :return: subset of the batch + :rtype: Batch + """ + return self.get_data(batch_name='supervised') + + def get_physics_data(self): + """ + Get a subset of the batch + :param idx: indices of the subset + :type idx: slice + :return: subset of the batch + :rtype: Batch + """ + return self.get_data(batch_name='physics') diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index fbed0bc5..2bca1823 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -113,46 +113,29 @@ def training_step(self, batch, _): :return: The sum of the loss functions. 
:rtype: LabelTensor """ - condition_losses = [] - - physics = batch.physics - if hasattr(batch, 'supervised'): - supervised = batch.supervised - condition_idx = supervised.condition_indices - else: - condition_idx = torch.tensor([]) - loss = torch.tensor(0, dtype=torch.float32) - for condition_id in torch.unique(condition_idx).tolist(): + batches = batch.get_supervised_data() + for points in batches: + input_pts, output_pts, condition_id = points condition_name = self._dataloader.condition_names[condition_id] - condition = self.problem.conditions[condition_name] self.__logged_metric = condition_name - pts = batch.supervised.input_points - out = batch.supervised.output_points - output_pts = out[condition_idx == condition_id] - input_pts = pts[condition_idx == condition_id] - loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts) - loss += loss_.as_subclass(torch.Tensor) - - condition_idx = physics.condition_indices - for condition_id in torch.unique(condition_idx).tolist(): + condition_losses.append(loss_.as_subclass(torch.Tensor)) + batches = batch.get_physics_data() + for points in batches: + input_pts, condition_id = points condition_name = self._dataloader.condition_names[condition_id] condition = self.problem.conditions[condition_name] self.__logged_metric = condition_name - pts = batch.physics.input_points - input_pts = pts[condition_idx == condition_id] - loss_ = self.loss_phys(input_pts, condition.equation) - # add condition losses for each epoch - loss += loss_.as_subclass(torch.Tensor) + condition_losses.append(loss_.as_subclass(torch.Tensor)) # clamp unknown parameters in InverseProblem (if needed) self._clamp_params() + loss = sum(condition_losses) - # total loss (must be a torch.Tensor) return loss def loss_data(self, input_pts, output_pts): diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index ff4153a6..049518f1 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -120,30 +120,13 @@ def training_step(self, batch): :return: The sum of the loss functions. :rtype: LabelTensor """ - - condition_idx = batch.supervised.condition_indices - loss = torch.tensor(0, dtype=torch.float32).to(self.device) - batch = batch.supervised - for condition_id in range(condition_idx.min(), condition_idx.max() + 1): - condition_name = self.trainer.data_module.condition_names[ - condition_id] - condition = self.problem.conditions[condition_name] - pts = batch.input_points - out = batch.output_points - if condition_name not in self.problem.conditions: - raise RuntimeError("Something wrong happened.") - - # for data driven mode - if not hasattr(condition, "output_points"): - raise NotImplementedError( - f"{type(self).__name__} works only in data-driven mode.") - - output_pts = out[condition_idx == condition_id] - input_pts = pts[condition_idx == condition_id] - + condition_loss = [] + batches = batch.get_supervised_data() + for points in batches: + input_pts, output_pts, _ = points loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts) - loss += loss_.as_subclass(torch.Tensor) - + condition_loss.append(loss_.as_subclass(torch.Tensor)) + loss = sum(condition_loss) self.log("mean_loss", float(loss), prog_bar=True, logger=True, on_epoch=True, on_step=False, batch_size=self.trainer.data_module.batch_size) @@ -223,7 +206,7 @@ def test_step(self, batch, batch_idx) -> STEP_OUTPUT: Solver test step. 
""" - raise NotImplementedError("Test step not implemented.") + raise NotImplementedError("Test step not implemented yet.") def loss_data(self, input_pts, output_pts): """ diff --git a/pina/trainer.py b/pina/trainer.py index 46d26471..f5ea5513 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -31,7 +31,7 @@ def __init__(self, and can be choosen from the `pytorch-lightning Trainer API `_ """ - log_every_n_steps = kwargs.get('log_every_n_steps', 0) + log_every_n_steps = kwargs.pop('log_every_n_steps', 0) super().__init__(log_every_n_steps=log_every_n_steps, **kwargs) # check inheritance consistency for solver and batch size From fdfd5733f158225c3ee8e35456d33321ed150f0f Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 13 Nov 2024 14:39:39 +0100 Subject: [PATCH 7/7] Reimplementation of data management classes, fix bugs and improve efficiency of LabelTensor --- pina/__init__.py | 10 +- pina/collector.py | 4 +- pina/condition/data_condition.py | 2 +- pina/condition/domain_equation_condition.py | 2 +- pina/condition/input_equation_condition.py | 2 +- pina/condition/input_output_condition.py | 2 +- pina/data/__init__.py | 13 +- pina/data/base_dataset.py | 156 ---------- pina/data/data_module.py | 299 +++++++++++--------- pina/data/pina_batch.py | 92 ------ pina/data/pina_dataloader.py | 96 ------- pina/data/pina_subset.py | 49 ---- pina/data/sample_dataset.py | 35 --- pina/data/supervised_dataset.py | 13 - pina/data/unsupervised_dataset.py | 14 - pina/label_tensor.py | 134 ++++++--- pina/model/network.py | 18 +- pina/operators.py | 10 +- pina/solvers/pinns/basepinn.py | 71 +++-- pina/solvers/pinns/pinn.py | 3 +- pina/solvers/solver.py | 61 +++- pina/solvers/supervised.py | 87 +----- pina/trainer.py | 21 +- setup.py | 2 +- 24 files changed, 416 insertions(+), 780 deletions(-) delete mode 100644 pina/data/base_dataset.py delete mode 100644 pina/data/pina_batch.py delete mode 100644 pina/data/pina_dataloader.py delete mode 100644 pina/data/pina_subset.py delete mode 100644 pina/data/sample_dataset.py delete mode 100644 pina/data/supervised_dataset.py delete mode 100644 pina/data/unsupervised_dataset.py diff --git a/pina/__init__.py b/pina/__init__.py index 3bc28ae6..8d56f235 100644 --- a/pina/__init__.py +++ b/pina/__init__.py @@ -1,17 +1,17 @@ __all__ = [ - "Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset", - "PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph' + "Trainer", "LabelTensor", "Plotter", "Condition", + "PinaDataModule", 'TorchOptimizer', 'Graph', 'LabelParameter' ] from .meta import * -from .label_tensor import LabelTensor +from .label_tensor import LabelTensor, LabelParameter from .solvers.solver import SolverInterface from .trainer import Trainer from .plotter import Plotter from .condition.condition import Condition -from .data import SamplePointDataset + from .data import PinaDataModule -from .data import PinaDataLoader + from .optim import TorchOptimizer from .optim import TorchScheduler from .graph import Graph diff --git a/pina/collector.py b/pina/collector.py index 3219b2b6..77daad14 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -68,7 +68,7 @@ def store_sample_domains(self, n, mode, variables, sample_locations): condition = self.problem.conditions[loc] keys = ["input_points", "equation"] # if the condition is not ready, we get and store the data - if (not self._is_conditions_ready[loc]): + if not self._is_conditions_ready[loc]: # if it is the first time we sample if not self.data_collections[loc]: already_sampled = [] @@ -87,7 
+87,7 @@ def store_sample_domains(self, n, mode, variables, sample_locations): condition.domain.sample(n=n, mode=mode, variables=variables) ] + already_sampled pts = merge_tensors(samples) - if (set(pts.labels).issubset(sorted(self.problem.input_variables))): + if set(pts.labels).issubset(sorted(self.problem.input_variables)): pts = pts.sort_labels() if sorted(pts.labels) == sorted(self.problem.input_variables): self._is_conditions_ready[loc] = True diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 05c543eb..2adf1c82 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -19,7 +19,7 @@ class DataConditionInterface(ConditionInterface): def __init__(self, input_points, conditional_variables=None): """ - TODO + TODO : add docstring """ super().__init__() self.input_points = input_points diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 53e07621..e0e7f916 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -16,7 +16,7 @@ class DomainEquationCondition(ConditionInterface): condition_type = ['physics'] def __init__(self, domain, equation): """ - TODO + TODO : add docstring """ super().__init__() self.domain = domain diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index 2a7f4647..2c376a16 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -17,7 +17,7 @@ class InputPointsEquationCondition(ConditionInterface): condition_type = ['physics'] def __init__(self, input_points, equation): """ - TODO + TODO : add docstring """ super().__init__() self.input_points = input_points diff --git a/pina/condition/input_output_condition.py b/pina/condition/input_output_condition.py index e9c34bea..de92926d 100644 --- a/pina/condition/input_output_condition.py +++ b/pina/condition/input_output_condition.py @@ -16,7 +16,7 @@ class InputOutputPointsCondition(ConditionInterface): condition_type = ['supervised'] def __init__(self, input_points, output_points): """ - TODO + TODO : add docstring """ super().__init__() self.input_points = input_points diff --git a/pina/data/__init__.py b/pina/data/__init__.py index 2b3a126a..a2adcff4 100644 --- a/pina/data/__init__.py +++ b/pina/data/__init__.py @@ -2,14 +2,11 @@ Import data classes """ __all__ = [ - 'PinaDataLoader', 'SupervisedDataset', 'SamplePointDataset', - 'UnsupervisedDataset', 'Batch', 'PinaDataModule', 'BaseDataset' + 'PinaDataModule', + 'PinaDataset' ] -from .pina_dataloader import PinaDataLoader -from .supervised_dataset import SupervisedDataset -from .sample_dataset import SamplePointDataset -from .unsupervised_dataset import UnsupervisedDataset -from .pina_batch import Batch + + from .data_module import PinaDataModule -from .base_dataset import BaseDataset +from .data_module import PinaDataset diff --git a/pina/data/base_dataset.py b/pina/data/base_dataset.py deleted file mode 100644 index d05784f8..00000000 --- a/pina/data/base_dataset.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -Basic data module implementation -""" -import torch -import logging - -from torch.utils.data import Dataset -from ..label_tensor import LabelTensor - - -class BaseDataset(Dataset): - """ - BaseDataset class, which handle initialization and data retrieval - :var condition_indices: List of indices - :var device: torch.device - """ - - def __new__(cls, problem=None, device=torch.device('cpu')): - """ 
- Ensure correct definition of __slots__ before initialization - :param AbstractProblem problem: The formulation of the problem. - :param torch.device device: The device on which the - dataset will be loaded. - """ - if cls is BaseDataset: - raise TypeError( - 'BaseDataset cannot be instantiated directly. Use a subclass.') - if not hasattr(cls, '__slots__'): - raise TypeError( - 'Something is wrong, __slots__ must be defined in subclasses.') - return object.__new__(cls) - - def __init__(self, problem=None, device=torch.device('cpu')): - """" - Initialize the object based on __slots__ - :param AbstractProblem problem: The formulation of the problem. - :param torch.device device: The device on which the - dataset will be loaded. - """ - super().__init__() - self.empty = True - self.problem = problem - self.device = device - self.condition_indices = None - for slot in self.__slots__: - setattr(self, slot, []) - self.num_el_per_condition = [] - self.conditions_idx = [] - if self.problem is not None: - self._init_from_problem(self.problem.collector.data_collections) - self.initialized = False - - def _init_from_problem(self, collector_dict): - """ - TODO - """ - for name, data in collector_dict.items(): - keys = list(data.keys()) - if set(self.__slots__) == set(keys): - self._populate_init_list(data) - idx = [ - key for key, val in - self.problem.collector.conditions_name.items() - if val == name - ] - self.conditions_idx.append(idx) - self.initialize() - - def add_points(self, data_dict, condition_idx, batching_dim=0): - """ - This method filled internal lists of data points - :param data_dict: dictionary containing data points - :param condition_idx: index of the condition to which the data points - belong to - :param batching_dim: dimension of the batching - :raises: ValueError if the dataset has already been initialized - """ - if not self.initialized: - self._populate_init_list(data_dict, batching_dim) - self.conditions_idx.append(condition_idx) - self.empty = False - else: - raise ValueError('Dataset already initialized') - - def _populate_init_list(self, data_dict, batching_dim=0): - current_cond_num_el = None - for slot in data_dict.keys(): - slot_data = data_dict[slot] - if batching_dim != 0: - if isinstance(slot_data, (LabelTensor, torch.Tensor)): - dims = len(slot_data.size()) - slot_data = slot_data.permute( - [batching_dim] + - [dim for dim in range(dims) if dim != batching_dim]) - if current_cond_num_el is None: - current_cond_num_el = len(slot_data) - elif current_cond_num_el != len(slot_data): - raise ValueError('Different dimension in same condition') - current_list = getattr(self, slot) - current_list += [ - slot_data - ] if not (isinstance(slot_data, list)) else slot_data - self.num_el_per_condition.append(current_cond_num_el) - - def initialize(self): - """ - Initialize the datasets tensors/LabelTensors/lists given the lists - already filled - """ - logging.debug(f'Initialize dataset {self.__class__.__name__}') - if self.num_el_per_condition: - self.condition_indices = torch.cat([ - torch.tensor( - [self.conditions_idx[i]] * self.num_el_per_condition[i], - dtype=torch.uint8) - for i in range(len(self.num_el_per_condition)) - ], - dim=0) - for slot in self.__slots__: - current_attribute = getattr(self, slot) - if all(isinstance(a, LabelTensor) for a in current_attribute): - setattr(self, slot, LabelTensor.vstack(current_attribute)) - self.initialized = True - - def __len__(self): - """ - :return: Number of elements in the dataset - """ - return len(getattr(self, 
-                           self.__slots__[0]))
-
-    def __getitem__(self, idx):
-        """
-        :param idx:
-        :return:
-        """
-        if not isinstance(idx, (tuple, list, slice, int)):
-            raise IndexError("Invalid index")
-        tensors = []
-        for attribute in self.__slots__:
-            tensor = getattr(self, attribute)
-            if isinstance(attribute, (LabelTensor, torch.Tensor)):
-                tensors.append(tensor.__getitem__(idx))
-            elif isinstance(attribute, list):
-                if isinstance(idx, (list, tuple)):
-                    tensor = [tensor[i] for i in idx]
-                tensors.append(tensor)
-        return tensors
-
-    def apply_shuffle(self, indices):
-        for slot in self.__slots__:
-            if slot != 'equation':
-                attribute = getattr(self, slot)
-                if isinstance(attribute, (LabelTensor, torch.Tensor)):
-                    setattr(self, 'slot', attribute[[indices]])
-                if isinstance(attribute, list):
-                    setattr(self, 'slot', [attribute[i] for i in indices])
diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index b09fb54a..44461dce 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -2,15 +2,61 @@
 This module provides basic data management functionalities
 """
-import math
-import torch
 import logging
-from pytorch_lightning import LightningDataModule
-from .sample_dataset import SamplePointDataset
-from .supervised_dataset import SupervisedDataset
-from .unsupervised_dataset import UnsupervisedDataset
-from .pina_dataloader import PinaDataLoader
-from .pina_subset import PinaSubset
+import math
+from functools import partial
+
+import torch
+from lightning.pytorch import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+
+from ..label_tensor import LabelTensor
+
+
+class PinaDataset(Dataset):
+    """Dataset serving the per-condition data dictionaries."""
+
+    def __init__(self, conditions_dict):
+        self.conditions_dict = conditions_dict
+        self.length = self._get_max_len()
+
+    def _get_max_len(self):
+        max_len = 0
+        for condition in self.conditions_dict.values():
+            max_len = max(max_len, len(condition['input_points']))
+        return max_len
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        # Conditions shorter than the longest one wrap around (index modulo
+        # their own length), so every condition contributes to every item.
+        return {
+            k: {k_data: v[k_data][idx % len(v['input_points'])] for k_data
+                in v.keys()} for k, v in self.conditions_dict.items()
+        }
+
+
+def collate_fn(batch, max_conditions_lengths):
+    """
+    Merge a list of per-condition sample dicts into a single batch dict,
+    truncating each condition at its maximum allowed length.
+    """
+    if isinstance(batch, dict):
+        return batch
+    batch_dict = {}
+    conditions_names = batch[0].keys()
+
+    for condition_name in conditions_names:
+        single_cond_dict = {}
+        condition_args = batch[0][condition_name].keys()
+        for arg in condition_args:
+            data_list = [batch[idx][condition_name][arg] for idx in range(
+                min(len(batch), max_conditions_lengths[condition_name]))]
+            if isinstance(data_list[0], LabelTensor):
+                single_cond_dict[arg] = LabelTensor.cat(data_list)
+            elif isinstance(data_list[0], torch.Tensor):
+                single_cond_dict[arg] = torch.stack(data_list)
+            else:
+                raise NotImplementedError(
+                    f"Data type {type(data_list[0])} not supported")
+        batch_dict[condition_name] = single_cond_dict
+    return batch_dict
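`collate_fn` receives the list of per-condition sample dictionaries produced by `PinaDataset.__getitem__` and stacks each condition's fields, truncating every condition at its entry in `max_conditions_lengths`. A toy version of the same mechanics with plain tensors (condition names and sizes here are made up):

    import torch
    from functools import partial
    from torch.utils.data import DataLoader

    # Each item carries one sample per condition, as in PinaDataset.
    data = [{'gamma1': {'input_points': torch.rand(2)},
             'data': {'input_points': torch.rand(2),
                      'output_points': torch.rand(1)}} for _ in range(8)]

    def toy_collate(batch, max_conditions_lengths):
        out = {}
        for name in batch[0]:
            # Truncate the batch at the per-condition cap before stacking.
            n = min(len(batch), max_conditions_lengths[name])
            out[name] = {arg: torch.stack([batch[i][name][arg]
                                           for i in range(n)])
                         for arg in batch[0][name]}
        return out

    loader = DataLoader(data, batch_size=4,
                        collate_fn=partial(toy_collate,
                                           max_conditions_lengths={'gamma1': 4,
                                                                   'data': 2}))
    batch = next(iter(loader))
    print(batch['gamma1']['input_points'].shape)  # torch.Size([4, 2])
    print(batch['data']['output_points'].shape)   # torch.Size([2, 1])
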
 
 
 class PinaDataModule(LightningDataModule):
@@ -20,194 +66,183 @@ class PinaDataModule(LightningDataModule):
     """
 
     def __init__(self,
-                 problem,
-                 device,
+                 collector,
                  train_size=.7,
                  test_size=.2,
                  val_size=.1,
                  predict_size=0.,
                  batch_size=None,
                  shuffle=True,
-                 datasets=None):
+                 repeat=False):
         """
         Initialize the object, creating dataset based on input problem
-        :param AbstractProblem problem: PINA problem
-        :param device: Device used for training and testing
+        :param Collector collector: collector containing the problem data
         :param train_size: number/percentage of elements in train split
         :param test_size: number/percentage of elements in test split
-        :param eval_size: number/percentage of elements in evaluation split
+        :param val_size: number/percentage of elements in validation split
+        :param repeat: if True, conditions smaller than the batch size are
+            cycled so that each of them fills every batch
         :param batch_size: batch size used for training
-        :param datasets: list of datasets objects
         """
         logging.debug('Start initialization of Pina DataModule')
         logging.info('Start initialization of Pina DataModule')
         super().__init__()
-        self.problem = problem
-        self.device = device
-        self.dataset_classes = [
-            SupervisedDataset, UnsupervisedDataset, SamplePointDataset
-        ]
-        if datasets is None:
-            self.datasets = None
-        else:
-            self.datasets = datasets
-        self.split_length = []
-        self.split_names = []
-        self.loader_functions = {}
         self.batch_size = batch_size
-        self.condition_names = problem.collector.conditions_name
+        self.shuffle = shuffle
+        self.repeat = repeat
 
+        # Begin Data splitting
+        splits_dict = {}
         if train_size > 0:
-            self.split_names.append('train')
-            self.split_length.append(train_size)
+            splits_dict['train'] = train_size
+            self.train_dataset = None
         else:
             self.train_dataloader = super().train_dataloader
 
         if test_size > 0:
-            self.split_length.append(test_size)
-            self.split_names.append('test')
+            splits_dict['test'] = test_size
+            self.test_dataset = None
         else:
             self.test_dataloader = super().test_dataloader
 
         if val_size > 0:
-            self.split_length.append(val_size)
-            self.split_names.append('val')
+            splits_dict['val'] = val_size
+            self.val_dataset = None
         else:
             self.val_dataloader = super().val_dataloader
 
         if predict_size > 0:
-            self.split_length.append(predict_size)
-            self.split_names.append('predict')
+            splits_dict['predict'] = predict_size
+            self.predict_dataset = None
         else:
             self.predict_dataloader = super().predict_dataloader
 
+        self.collector_splits = self._create_splits(collector, splits_dict)
 
-        self.splits = {k: {} for k in self.split_names}
-        self.shuffle = shuffle
-        self.has_setup_fit = False
-        self.has_setup_test = False
-
-    def prepare_data(self):
-        if self.datasets is None:
-            self._create_datasets()
 
     def setup(self, stage=None):
         """
         Perform the splitting of the dataset
         """
         logging.debug('Start setup of Pina DataModule obj')
-        if self.datasets is None:
-            self._create_datasets()
+
         if stage == 'fit' or stage is None:
-            for dataset in self.datasets:
-                if len(dataset) > 0:
-                    splits = self.dataset_split(dataset,
-                                                self.split_length,
-                                                shuffle=self.shuffle)
-                    for i in range(len(self.split_length)):
-                        self.splits[self.split_names[i]][
-                            dataset.data_type] = splits[i]
-            self.has_setup_fit = True
+            self.train_dataset = PinaDataset(self.collector_splits['train'])
+            if 'val' in self.collector_splits.keys():
+                self.val_dataset = PinaDataset(self.collector_splits['val'])
         elif stage == 'test':
-            if self.has_setup_fit is False:
-                raise NotImplementedError(
-                    "You must call setup with stage='fit' "
-                    "first")
+            self.test_dataset = PinaDataset(self.collector_splits['test'])
+        elif stage == 'predict':
+            self.predict_dataset = PinaDataset(
+                self.collector_splits['predict'])
         else:
-            raise ValueError("stage must be either 'fit' or 'test'")
+            raise ValueError("stage must be 'fit', 'test' or 'predict'.")
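When a split size is zero, `__init__` rebinds the corresponding dataloader hook to the parent's default implementation on the instance, so Lightning behaves as if no custom dataloader were defined for that split. This relies on ordinary instance-attribute shadowing of a class method; a minimal sketch of the mechanism (class names are made up):

    class Base:
        def hook(self):
            return "base default"

    class Child(Base):
        def __init__(self, use_custom):
            if not use_custom:
                # The instance attribute shadows the method defined below.
                self.hook = super().hook

        def hook(self):
            return "custom implementation"

    print(Child(use_custom=True).hook())   # custom implementation
    print(Child(use_custom=False).hook())  # base default
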
 
     @staticmethod
-    def dataset_split(dataset, lengths, seed=None, shuffle=True):
-        """
-        Perform the splitting of the dataset
-        :param dataset: dataset object we wanted to split
-        :param lengths: lengths of elements in dataset
-        :param seed: random seed
-        :param shuffle: shuffle dataset
-        :return: split dataset
-        :rtype: PinaSubset
-        """
-        if sum(lengths) - 1 < 1e-3:
-            len_dataset = len(dataset)
-            lengths = [
-                int(math.floor(len_dataset * length)) for length in lengths
-            ]
-            remainder = len(dataset) - sum(lengths)
-            for i in range(remainder):
-                lengths[i % len(lengths)] += 1
-        elif sum(lengths) - 1 >= 1e-3:
-            raise ValueError(f"Sum of lengths is {sum(lengths)} less than 1")
-
-        if shuffle:
-            if seed is not None:
-                generator = torch.Generator()
-                generator.manual_seed(seed)
-                indices = torch.randperm(sum(lengths), generator=generator)
-            else:
-                indices = torch.randperm(sum(lengths))
-            dataset.apply_shuffle(indices)
-
-        offsets = [
-            sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
-        ]
-        return [
-            PinaSubset(dataset, slice(offset, offset + length))
-            for offset, length in zip(offsets, lengths)
-        ]
-
-    def _create_datasets(self):
+    def _split_condition(condition_dict, splits_dict):
+        """
+        Split a single condition into the requested stages.
+        """
+        len_condition = len(condition_dict['input_points'])
+        lengths = [
+            int(math.floor(len_condition * length)) for length in
+            splits_dict.values()
+        ]
+        remainder = len_condition - sum(lengths)
+        for i in range(remainder):
+            lengths[i % len(lengths)] += 1
+        splits_dict = {k: v for k, v in zip(splits_dict.keys(), lengths)}
+
+        to_return_dict = {}
+        offset = 0
+        for stage, stage_len in splits_dict.items():
+            # Equations are NEVER dataloaded, so they are left out here.
+            to_return_dict[stage] = {k: v[offset:offset + stage_len]
+                                     for k, v in condition_dict.items()
+                                     if k != 'equation'}
+            offset += stage_len
+        return to_return_dict
+
+    def _create_splits(self, collector, splits_dict):
         """
         Create the dataset objects putting data
         """
         logging.debug('Dataset creation in PinaDataModule obj')
-        collector = self.problem.collector
-        batching_dim = self.problem.batching_dimension
-        datasets_slots = [i.__slots__ for i in self.dataset_classes]
-        self.datasets = [
-            dataset(device=self.device) for dataset in self.dataset_classes
-        ]
-        logging.debug('Filling datasets in PinaDataModule obj')
-        for name, data in collector.data_collections.items():
-            keys = list(data.keys())
-            idx = [
-                key for key, val in collector.conditions_name.items()
-                if val == name
-            ]
-            for i, slot in enumerate(datasets_slots):
-                if slot == keys:
-                    self.datasets[i].add_points(data, idx[0], batching_dim)
-                    continue
-        datasets = []
-        for dataset in self.datasets:
-            if not dataset.empty:
-                dataset.initialize()
-                datasets.append(dataset)
-        self.datasets = datasets
+        split_names = list(splits_dict.keys())
+        dataset_dict = {name: {} for name in split_names}
+        for condition_name, condition_dict in collector.data_collections.items():
+            len_data = len(condition_dict['input_points'])
+            if self.shuffle:
+                idx = torch.randperm(len_data)
+                for k, v in condition_dict.items():
+                    if k == 'equation':
+                        continue
+                    if isinstance(v, list):
+                        condition_dict[k] = [v[i] for i in idx]
+                    elif isinstance(v, (torch.Tensor, LabelTensor)):
+                        condition_dict[k] = v[[idx]]
+                    else:
+                        raise ValueError(f"Data type {type(v)} not supported")
+            for key, data in self._split_condition(condition_dict,
+                                                   splits_dict).items():
+                dataset_dict[key].update({condition_name: data})
+        return dataset_dict
+
+    def find_max_conditions_lengths(self, split):
+        """
+        Return the number of samples each condition can contribute per batch.
+        """
+        max_conditions_lengths = {}
+        for k, v in self.collector_splits[split].items():
+            if self.batch_size is None:
+                max_conditions_lengths[k] = len(v['input_points'])
+            elif self.repeat:
+                max_conditions_lengths[k] = self.batch_size
+            else:
+                max_conditions_lengths[k] = min(len(v['input_points']),
+                                                self.batch_size)
+        return max_conditions_lengths
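`_split_condition` turns the fractional split sizes into integer lengths by flooring each share and handing out the leftover samples round-robin, so the lengths always sum to the condition's size. The arithmetic in isolation:

    import math

    def split_lengths(n, fractions):
        # Floor each share, then distribute the remainder round-robin,
        # mirroring _split_condition above.
        lengths = [int(math.floor(n * f)) for f in fractions]
        for i in range(n - sum(lengths)):
            lengths[i % len(lengths)] += 1
        return lengths

    print(split_lengths(10, [.7, .2, .1]))  # [7, 2, 1]
    print(split_lengths(11, [.7, .2, .1]))  # [8, 2, 1], remainder to train
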
 
     def val_dataloader(self):
         """
         Create the validation dataloader
         """
-        return PinaDataLoader(self.splits['val'], self.batch_size,
-                              self.condition_names)
+        max_conditions_lengths = self.find_max_conditions_lengths('val')
+        collate_fn_val = partial(collate_fn,
+                                 max_conditions_lengths=max_conditions_lengths)
+        # Data are already shuffled once in _create_splits, keep order here.
+        return DataLoader(self.val_dataset, self.batch_size,
+                          collate_fn=collate_fn_val, shuffle=False)
 
     def train_dataloader(self):
         """
         Create the training dataloader
         """
-        return PinaDataLoader(self.splits['train'], self.batch_size,
-                              self.condition_names)
+        max_conditions_lengths = self.find_max_conditions_lengths('train')
+        collate_fn_train = partial(collate_fn,
+                                   max_conditions_lengths=max_conditions_lengths)
+        return DataLoader(self.train_dataset, self.batch_size,
+                          collate_fn=collate_fn_train, shuffle=False)
 
     def test_dataloader(self):
         """
         Create the testing dataloader
         """
-        return PinaDataLoader(self.splits['test'], self.batch_size,
-                              self.condition_names)
+        max_conditions_lengths = self.find_max_conditions_lengths('test')
+        collate_fn_test = partial(collate_fn,
+                                  max_conditions_lengths=max_conditions_lengths)
+        return DataLoader(self.test_dataset, self.batch_size,
+                          collate_fn=collate_fn_test, shuffle=False)
 
     def predict_dataloader(self):
         """
         Create the prediction dataloader
         """
-        return PinaDataLoader(self.splits['predict'], self.batch_size,
-                              self.condition_names)
+        max_conditions_lengths = self.find_max_conditions_lengths('predict')
+        collate_fn_predict = partial(collate_fn,
+                                     max_conditions_lengths=max_conditions_lengths)
+        return DataLoader(self.predict_dataset, self.batch_size,
+                          collate_fn=collate_fn_predict, shuffle=False)
+
+    def transfer_batch_to_device(self, batch, device, dataloader_idx):
+        """
+        Transfer the batch to the device; called by Lightning inside the
+        training loop.
+        """
+        # Delegate each condition's sub-dict to the default implementation.
+        batch = {k: super(LightningDataModule,
+                          self).transfer_batch_to_device(v, device,
+                                                         dataloader_idx)
+                 for k, v in batch.items()}
+        return batch
diff --git a/pina/data/pina_batch.py b/pina/data/pina_batch.py
deleted file mode 100644
index e43e1108..00000000
--- a/pina/data/pina_batch.py
+++ /dev/null
@@ -1,92 +0,0 @@
-"""
-Batch management module
-"""
-import torch
-from ..label_tensor import LabelTensor
-
-from .pina_subset import PinaSubset
-
-
-class Batch:
-    """
-    Implementation of the Batch class used during training to perform SGD
-    optimization.
- """ - - def __init__(self, dataset_dict, idx_dict, require_grad=True): - self.attributes = [] - for k, v in dataset_dict.items(): - index = idx_dict[k] - if isinstance(v, PinaSubset): - dataset_index = v.indices - if isinstance(dataset_index, slice): - index = slice(dataset_index.start + index.start, - min(dataset_index.start + index.stop, - dataset_index.stop)) - setattr(self, k, PinaSubset(v.dataset, index, - require_grad=require_grad)) - self.attributes.append(k) - self.require_grad = require_grad - - def __len__(self): - """ - Returns the number of elements in the batch - :return: number of elements in the batch - :rtype: int - """ - length = 0 - for dataset in self.attributes: - attribute = getattr(self, dataset) - length += len(attribute) - return length - - def __getattr__(self, item): - if item == 'data' and len(self.attributes) == 1: - item = self.attributes[0] - return self.__getattribute__(item) - raise AttributeError(f"'Batch' object has no attribute '{item}'") - - def get_data(self, batch_name=None): - """ - # TODO - """ - data = getattr(self, batch_name) - to_return_list = [] - if isinstance(data, PinaSubset): - items = data.dataset.__slots__ - else: - items = data.__slots__ - indices = torch.unique(data.condition_indices).tolist() - condition_idx = data.condition_indices - for i in indices: - temp = [] - for j in items: - var = getattr(data, j) - if isinstance(var, (torch.Tensor, LabelTensor)): - temp.append(var[i == condition_idx]) - if isinstance(var, list) and len(var) > 0: - temp.append([var[k] for k in range(len(var)) if - i == condition_idx[k]]) - temp.append(i) - to_return_list.append(temp) - return to_return_list - - def get_supervised_data(self): - """ - Get a subset of the batch - :param idx: indices of the subset - :type idx: slice - :return: subset of the batch - :rtype: Batch - """ - return self.get_data(batch_name='supervised') - - def get_physics_data(self): - """ - Get a subset of the batch - :param idx: indices of the subset - :type idx: slice - :return: subset of the batch - :rtype: Batch - """ - return self.get_data(batch_name='physics') diff --git a/pina/data/pina_dataloader.py b/pina/data/pina_dataloader.py deleted file mode 100644 index a28ca6c6..00000000 --- a/pina/data/pina_dataloader.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -This module is used to create an iterable object used during training -""" -import math - -from .pina_batch import Batch - - -class PinaDataLoader: - """ - This class is used to create a dataloader to use during the training. - - :var condition_names: The names of the conditions. The order is consistent - with the condition indeces in the batches. - :vartype condition_names: list[str] - """ - - def __init__(self, dataset_dict, batch_size, condition_names) -> None: - """ - Initialize local variables - :param dataset_dict: Dictionary of datasets - :type dataset_dict: dict - :param batch_size: Size of the batch - :type batch_size: int - :param condition_names: Names of the conditions - :type condition_names: list[str] - """ - self.condition_names = condition_names - self.dataset_dict = dataset_dict - self.batch_size = batch_size - self._init_batches(batch_size) - - def _init_batches(self, batch_size=None): - """ - Create batches according to the batch_size provided in input. 
- """ - self.batches = [] - n_elements = sum(len(v) for v in self.dataset_dict.values()) - if batch_size is None: - batch_size = n_elements - self.batch_size = n_elements - n_batches = int(math.ceil(n_elements / batch_size)) - indexes_dict = { - k: math.floor(len(v) / n_batches) if n_batches != 1 else len(v) for - k, v in self.dataset_dict.items()} - - dataset_names = list(self.dataset_dict.keys()) - num_el_per_batch = [{i: indexes_dict[i] for i in dataset_names} for _ - in range(n_batches - 1)] + [ - {i: 0 for i in dataset_names}] - reminders = { - i: len(self.dataset_dict[i]) - indexes_dict[i] * (n_batches - 1) for - i in dataset_names} - dataset_names = iter(dataset_names) - name = next(dataset_names, None) - for batch in num_el_per_batch: - tot_num_el = sum(batch.values()) - batch_reminder = batch_size - tot_num_el - for _ in range(batch_reminder): - if name is None: - break - if reminders[name] > 0: - batch[name] += 1 - reminders[name] -= 1 - else: - name = next(dataset_names, None) - if name is None: - break - batch[name] += 1 - reminders[name] -= 1 - - reminders, dataset_names, indexes_dict = None, None, None # free memory - actual_indices = {k: 0 for k in self.dataset_dict.keys()} - for batch in num_el_per_batch: - temp_dict = {} - total_length = 0 - for k, v in batch.items(): - temp_dict[k] = slice(actual_indices[k], actual_indices[k] + v) - actual_indices[k] = actual_indices[k] + v - total_length += v - self.batches.append( - Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict)) - - def __iter__(self): - """ - Makes dataloader object iterable - """ - yield from self.batches - - def __len__(self): - """ - Return the number of batches. - :return: The number of batches. - :rtype: int - """ - return len(self.batches) diff --git a/pina/data/pina_subset.py b/pina/data/pina_subset.py deleted file mode 100644 index b5b74a68..00000000 --- a/pina/data/pina_subset.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -Module for PinaSubset class -""" -from pina import LabelTensor -from torch import Tensor, float32 - - -class PinaSubset: - """ - TODO - """ - __slots__ = ['dataset', 'indices', 'require_grad'] - - def __init__(self, dataset, indices, require_grad=False): - """ - TODO - """ - self.dataset = dataset - self.indices = indices - self.require_grad = require_grad - - def __len__(self): - """ - TODO - """ - if isinstance(self.indices, slice): - return self.indices.stop - self.indices.start - return len(self.indices) - - def __getattr__(self, name): - tensor = self.dataset.__getattribute__(name) - if isinstance(tensor, (LabelTensor, Tensor)): - if isinstance(self.indices, slice): - tensor = tensor[self.indices] - if (tensor.device != self.dataset.device - and tensor.dtype == float32): - tensor = tensor.to(self.dataset.device) - elif isinstance(self.indices, list): - tensor = tensor[[self.indices]].to(self.dataset.device) - else: - raise ValueError(f"Indices type {type(self.indices)} not " - f"supported") - return tensor.requires_grad_( - self.require_grad) if tensor.dtype == float32 else tensor - if isinstance(tensor, list): - if isinstance(self.indices, list): - return [tensor[i] for i in self.indices] - return tensor[self.indices] - raise AttributeError(f"No attribute named {name}") diff --git a/pina/data/sample_dataset.py b/pina/data/sample_dataset.py deleted file mode 100644 index bc3bca33..00000000 --- a/pina/data/sample_dataset.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Sample dataset module -""" -from copy import deepcopy -from .base_dataset import BaseDataset -from ..condition import 
InputPointsEquationCondition - - -class SamplePointDataset(BaseDataset): - """ - This class extends the BaseDataset to handle physical datasets - composed of only input points. - """ - data_type = 'physics' - __slots__ = InputPointsEquationCondition.__slots__ - - def add_points(self, data_dict, condition_idx, batching_dim=0): - data_dict = deepcopy(data_dict) - data_dict.pop('equation') - super().add_points(data_dict, condition_idx) - - def _init_from_problem(self, collector_dict): - for name, data in collector_dict.items(): - keys = list(data.keys()) - if set(self.__slots__) == set(keys): - data = deepcopy(data) - data.pop('equation') - self._populate_init_list(data) - idx = [ - key for key, val in - self.problem.collector.conditions_name.items() - if val == name - ] - self.conditions_idx.append(idx) - self.initialize() diff --git a/pina/data/supervised_dataset.py b/pina/data/supervised_dataset.py deleted file mode 100644 index be601050..00000000 --- a/pina/data/supervised_dataset.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -Supervised dataset module -""" -from .base_dataset import BaseDataset - - -class SupervisedDataset(BaseDataset): - """ - This class extends the BaseDataset to handle datasets that consist of - input-output pairs. - """ - data_type = 'supervised' - __slots__ = ['input_points', 'output_points'] diff --git a/pina/data/unsupervised_dataset.py b/pina/data/unsupervised_dataset.py deleted file mode 100644 index 18cf296f..00000000 --- a/pina/data/unsupervised_dataset.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Unsupervised dataset module -""" -from .base_dataset import BaseDataset - - -class UnsupervisedDataset(BaseDataset): - """ - This class extend BaseDataset class to handle - unsupervised dataset,composed of input points - and, optionally, conditional variables - """ - data_type = 'unsupervised' - __slots__ = ['input_points', 'conditional_variables'] diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 58dc8b71..a55a3cd5 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -4,9 +4,9 @@ import torch from torch import Tensor -full_labels = True -MATH_MODULES = {torch.sin, torch.cos, torch.exp, torch.tan, torch.log, - torch.sqrt} +full_labels = False +MATH_FUNCTIONS = {torch.sin, torch.cos} +GRAD_FUNCTIONS = {torch.autograd.grad} def issubset(a, b): @@ -26,6 +26,7 @@ class LabelTensor(torch.Tensor): @staticmethod def __new__(cls, x, labels, *args, **kwargs): full = kwargs.pop("full", full_labels) + if isinstance(x, LabelTensor): x.full = full return x @@ -48,62 +49,87 @@ def __init__(self, x, labels, **kwargs): """ self.dim_names = None self.full = kwargs.get('full', full_labels) - self.labels = labels + if labels is not None: + self.labels = labels + else: + self._labels = {} @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - if func in MATH_MODULES: + if func in MATH_FUNCTIONS: str_labels = func.__name__ - labels = copy(args[0].stored_labels) + lt = super().__torch_function__(func, types, args=args, kwargs=kwargs) - lt_shape = lt.shape - - if len(lt_shape) - 1 in labels.keys(): - labels.update({ - len(lt_shape) - 1: { - 'dof': [f'{str_labels}({i})' for i in - labels[len(lt_shape) - 1]['dof']], - 'name': len(lt_shape) - 1 - } - }) - lt._labels = labels - return lt + if hasattr(args[0], '_labels'): + labels = {k: copy(v) for k, v in args[0].stored_labels.items()} + lt._labels = labels + lt.dim_names = args[0].dim_names + + lt_shape = lt.shape + + if len(lt_shape) - 1 in labels.keys(): + labels.update({ + 
len(lt_shape) - 1: { + 'dof': [f'{str_labels}({i})' for i in + labels[len(lt_shape) - 1]['dof']], + 'name': len(lt_shape) - 1 + } + }) + lt._labels = labels + + return lt + if func in GRAD_FUNCTIONS: + # TODO: Implement the gradient of the LabelTensor + pass return super().__torch_function__(func, types, args=args, kwargs=kwargs) def __mul__(self, other): + lt = super().__mul__(other) + if not hasattr(self, '_labels'): + return lt if isinstance(other, (int, float)): if hasattr(self, '_labels'): lt._labels = self._labels + lt.dim_names = self.dim_names + if isinstance(other, LabelTensor): lt_shape = lt.shape - labels = copy(self.stored_labels) - other_labels = other.stored_labels + check = False - for (k, v), (ko, vo) in zip(sorted(labels.items()), - sorted(other_labels.items())): - if k != ko: - raise ValueError('Labels must be the same') - if k != len(lt_shape) - 1: - if vo != v: + if self.ndim in (0, 1): + labels = copy(other.stored_labels) + else: + labels = copy(self.stored_labels) + other_labels = copy(other.stored_labels) + for (k, v), (ko, vo) in zip(sorted(labels.items()), + sorted(other_labels.items())): + if k != ko: raise ValueError('Labels must be the same') - else: - check = True - if check: - labels.update({ - len(lt_shape) - 1: {'dof': [f'{i}{j}' for i, j in - zip(self.stored_labels[ - len(lt_shape) - 1]['dof'], - other.stored_labels[ - len(lt_shape) - 1]['dof'])], - 'name': self.stored_labels[ - len(lt_shape) - 1]['name']} - }) + if k != len(lt_shape) - 1: + if vo != v: + raise ValueError('Labels must be the same') + else: + check = True + if check: + labels.update({ + len(lt_shape) - 1: {'dof': [f'{i}{j}' for i, j in + zip(self.stored_labels[ + len(lt_shape) - 1][ + 'dof'], + other.stored_labels[ + len(lt_shape) - 1][ + 'dof'])], + 'name': self.stored_labels[ + len(lt_shape) - 1]['name']} + }) lt._labels = labels + lt.dim_names = self.dim_names + return lt @classmethod @@ -166,7 +192,7 @@ def labels(self, labels): self._labels = {} if isinstance(labels, dict): self._init_labels_from_dict(labels) - elif isinstance(labels, list): + elif isinstance(labels, (list, range)): self._init_labels_from_list(labels) elif isinstance(labels, str): labels = [labels] @@ -377,9 +403,11 @@ def to(self, *args, **kwargs): :meth:`torch.Tensor.to`. """ lt = super().to(*args, **kwargs) + return LabelTensor.__internal_init__(lt, self.stored_labels, self.dim_names) + def clone(self, *args, **kwargs): """ Clone the LabelTensor. 
For more details, see
@@ -408,7 +436,7 @@ def summation(tensors):
             raise RuntimeError('Tensors must have the same shape and labels')
 
         last_dim_labels = []
-        data = torch.zeros(tensors[0].tensor.shape)
+        data = torch.zeros(tensors[0].tensor.shape).to(tensors[0].device)
         for tensor in tensors:
             data += tensor.tensor
             last_dim_labels.append(tensor.labels)
@@ -509,16 +538,19 @@ def _update_single_label(old_labels, to_update_labels, index, dim):
         old_dof = old_labels[dim]['dof']
         if isinstance(index, torch.Tensor) and index.ndim == 0:
             index = int(index)
-        if (not isinstance(
-                index,
-                (int, slice)) and len(index) == len(old_dof) and isinstance(
-            old_dof, range)):
-            return
-
         if isinstance(index, torch.Tensor):
             if isinstance(old_dof, range):
                 to_update_labels.update({
                     dim: {
-                        'dof': index.tolist(),
+                        'dof': index.tolist() if not (
+                                torch.diff(index) == 1).all() else
+                        range(old_dof[index[0]], old_dof[index[-1]] + 1),
                         'name': old_labels[dim]['name']
                     }
                 })
@@ -567,3 +599,21 @@ def permute(self, *dims):
             for k in stored_labels.keys()
         }
         return LabelTensor.__internal_init__(tensor, labels, self.dim_names)
+
+    def detach(self):
+        lt = super().detach()
+        lt._labels = self.stored_labels
+        lt.dim_names = self.dim_names
+        return lt
+
+
+class LabelParameter(torch.nn.Parameter, LabelTensor):
+    """A class that combines torch.nn.Parameter with LabelTensor behavior."""
+
+    def __new__(cls, x, labels=None, requires_grad=True):
+        instance = torch.nn.Parameter.__new__(cls, data=x,
+                                              requires_grad=requires_grad)
+        return instance
+
+    def __init__(self, x, labels=None, requires_grad=True):
+        LabelTensor.__init__(self, x, labels)
diff --git a/pina/model/network.py b/pina/model/network.py
index 6fde8039..aec4f016 100644
--- a/pina/model/network.py
+++ b/pina/model/network.py
@@ -29,7 +29,8 @@ class is used internally in PINA to convert
         # check model consistency
         check_consistency(model, nn.Module)
         check_consistency(input_variables, str)
-        check_consistency(output_variables, str)
+        if output_variables is not None:
+            check_consistency(output_variables, str)
 
         self._model = model
         self._input_variables = input_variables
@@ -67,16 +68,15 @@ def forward(self, x):
         # in case `input_variables = []` all points are used
         if self._input_variables:
             x = x.extract(self._input_variables)
-
         # extract features and append
         for feature in self._extra_features:
             x = x.append(feature(x))
 
         # perform forward pass + converting to LabelTensor
-        output = self._model(x).as_subclass(LabelTensor)
-
-        # set the labels for LabelTensor
-        output.labels = self._output_variables
+        output = self._model(x.as_subclass(torch.Tensor))
+        if self._output_variables is not None:
+            output = LabelTensor(output, self._output_variables)
 
         return output
 
@@ -97,15 +97,9 @@ def forward_map(self, x):
         This function does not extract the input variables, all the variables
         are used for both tensors. Output variables are correctly applied.
""" - # convert LabelTensor s to torch.Tensor s - x = list(map(lambda x: x.as_subclass(torch.Tensor), x)) # perform forward pass (using torch.Tensor) + converting to LabelTensor - output = self._model(x).as_subclass(LabelTensor) - - # set the labels for LabelTensor - output.labels = self._output_variables - + output = LabelTensor(self._model(x.tensor), self._output_variables) return output @property diff --git a/pina/operators.py b/pina/operators.py index 0b306dfb..ef389a64 100644 --- a/pina/operators.py +++ b/pina/operators.py @@ -63,11 +63,9 @@ def grad_scalar_output(output_, input_, d): retain_graph=True, allow_unused=True, )[0] - - gradients.labels = input_.labels - gradients = gradients.extract(d) + gradients.labels = input_.stored_labels + gradients = gradients[..., [input_.labels.index(i) for i in d]] gradients.labels = [f"d{output_fieldname}d{i}" for i in d] - return gradients if not isinstance(input_, LabelTensor): @@ -190,7 +188,9 @@ def laplacian(output_, input_, components=None, d=None, method="std"): to_append_tensors = [] for i, label in enumerate(grad_output.labels): gg = grad(grad_output, input_, d=d, components=[label]) - to_append_tensors.append(gg.extract([gg.labels[i]])) + gg = gg.extract([gg.labels[i]]) + + to_append_tensors.append(gg) labels = [f"dd{components[0]}"] result = LabelTensor.summation(tensors=to_append_tensors) result.labels = labels diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index 2bca1823..6d63b17d 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -100,7 +100,7 @@ def __init__( self._optimizer = self._pina_optimizers[0] self._scheduler = self._pina_schedulers[0] - def training_step(self, batch, _): + def training_step(self, batch): """ The Physics Informed Solver Training Step. This function takes care of the physics informed training step, and it must not be override @@ -113,46 +113,64 @@ def training_step(self, batch, _): :return: The sum of the loss functions. 
:rtype: LabelTensor
         """
-        condition_losses = []
-        batches = batch.get_supervised_data()
-        for points in batches:
-            input_pts, output_pts, condition_id = points
-            condition_name = self._dataloader.condition_names[condition_id]
-            self.__logged_metric = condition_name
-            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
-            condition_losses.append(loss_.as_subclass(torch.Tensor))
-
-        batches = batch.get_physics_data()
-        for points in batches:
-            input_pts, condition_id = points
-            condition_name = self._dataloader.condition_names[condition_id]
-            condition = self.problem.conditions[condition_name]
-            self.__logged_metric = condition_name
-            loss_ = self.loss_phys(input_pts, condition.equation)
-            # add condition losses for each epoch
-            condition_losses.append(loss_.as_subclass(torch.Tensor))
+        condition_loss = []
+        for condition_name, points in batch.items():
+            if 'output_points' in points:
+                # supervised condition: compare against the true solution
+                input_pts = points['input_points']
+                output_pts = points['output_points']
+                loss_ = self.loss_data(input_pts=input_pts,
+                                       output_pts=output_pts)
+            else:
+                # physics condition: evaluate the equation residual
+                input_pts = points['input_points']
+                condition = self.problem.conditions[condition_name]
+                loss_ = self.loss_phys(input_pts.requires_grad_(),
+                                       condition.equation)
+            condition_loss.append(loss_.as_subclass(torch.Tensor))
         # clamp unknown parameters in InverseProblem (if needed)
         self._clamp_params()
-        loss = sum(condition_losses)
-
+        loss = sum(condition_loss)
+        self.log('train_loss', loss, on_step=True, on_epoch=True,
+                 prog_bar=True, logger=True)
         return loss
 
+    def validation_step(self, batch):
+        """
+        The Physics Informed Solver validation step. It mirrors the
+        training step, re-enabling gradients for the physics losses since
+        Lightning disables them during validation.
+
+        :param dict batch: The batch of points, grouped by condition name.
+        """
+        condition_loss = []
+        for condition_name, points in batch.items():
+            if 'output_points' in points:
+                input_pts = points['input_points']
+                output_pts = points['output_points']
+                loss_ = self.loss_data(input_pts=input_pts,
+                                       output_pts=output_pts)
+            else:
+                input_pts = points['input_points']
+                condition = self.problem.conditions[condition_name]
+                with torch.set_grad_enabled(True):
+                    loss_ = self.loss_phys(input_pts.requires_grad_(),
+                                           condition.equation)
+            condition_loss.append(loss_.as_subclass(torch.Tensor))
+        loss = sum(condition_loss)
+        self.log('val_loss', loss, on_step=True, on_epoch=True,
+                 prog_bar=True, logger=True)
 
     def loss_data(self, input_pts, output_pts):
         """
         The data loss for the PINN solver. It computes the loss between
         the network output against the true solution. This function
         should not be override if not intentionally.
 
-        :param LabelTensor input_tensor: The input to the neural networks.
-        :param LabelTensor output_tensor: The true solution to compare the
+        :param LabelTensor input_pts: The input to the neural networks.
+        :param LabelTensor output_pts: The true solution to compare the
             network solution.
         :return: The residual loss averaged on the input coordinates
         :rtype: torch.Tensor
         """
         return self._loss(self.forward(input_pts), output_pts)
 
-    @abstractmethod
     def loss_phys(self, samples, equation):
         """
@@ -202,6 +220,9 @@ def store_log(self, loss_value):
         :param str name: The name of the loss.
         :param torch.Tensor loss_value: The value of the loss.
         """
+        # When training is full-batch (batch_size=None) there is no real
+        # batch size to report to Lightning; fall back to a large dummy
+        # value that is only used for logging aggregation.
+        batch_size = self.trainer.data_module.batch_size \
+            if self.trainer.data_module.batch_size is not None else 999
+
         self.log(
             self.__logged_metric + "_loss",
             loss_value,
@@ -209,7 +230,7 @@
             logger=True,
             on_epoch=True,
             on_step=False,
-            batch_size=self._dataloader.batch_size,
+            batch_size=batch_size,
         )
         self.__logged_res_losses.append(loss_value)
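Lightning runs validation_step under torch.no_grad(), which would silently disable the autograd machinery that the physics residuals rely on; the torch.set_grad_enabled(True) block in the new validation_step above re-enables it locally. A standalone sketch of that pattern, using a toy scalar function rather than PINA code:

    import torch

    x = torch.rand(4, 1)

    with torch.no_grad():                      # what Lightning wraps validation in
        with torch.set_grad_enabled(True):     # what validation_step re-enables
            x = x.requires_grad_()
            u = (x ** 2).sum()                 # stand-in for a network output
            (du,) = torch.autograd.grad(u, x)  # derivatives are available again
    print(du.shape)                            # torch.Size([4, 1])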
""" + batch_size = self.trainer.data_module.batch_size \ + if self.trainer.data_module.batch_size is not None else 999 + self.log( self.__logged_metric + "_loss", loss_value, @@ -209,7 +230,7 @@ def store_log(self, loss_value): logger=True, on_epoch=True, on_step=False, - batch_size=self._dataloader.batch_size, + batch_size=batch_size, ) self.__logged_res_losses.append(loss_value) diff --git a/pina/solvers/pinns/pinn.py b/pina/solvers/pinns/pinn.py index 29949710..03fb7ac1 100644 --- a/pina/solvers/pinns/pinn.py +++ b/pina/solvers/pinns/pinn.py @@ -119,9 +119,8 @@ def loss_phys(self, samples, equation): """ residual = self.compute_residual(samples=samples, equation=equation) loss_value = self.loss( - torch.zeros_like(residual, requires_grad=True), residual + torch.zeros_like(residual), residual ) - self.store_log(loss_value=float(loss_value)) return loss_value def configure_optimizers(self): diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py index fe9c897e..0fa8bc22 100644 --- a/pina/solvers/solver.py +++ b/pina/solvers/solver.py @@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod from ..model.network import Network -import pytorch_lightning +import lightning from ..utils import check_consistency from ..problem import AbstractProblem from ..optim import Optimizer, Scheduler @@ -10,7 +10,8 @@ import sys -class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta): + +class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): """ Solver base class. This class inherits is a wrapper of LightningModule class, inheriting all the @@ -88,12 +89,20 @@ def __init__(self, self._pina_schedulers = schedulers self._pina_problem = problem + self.validation_condition_losses = { + k: {'loss': [], + 'count': []} for k in self.problem.conditions.keys()} + self.train_condition_losses = { + k: {'loss': [], + 'count': []} for k in self.problem.conditions.keys()} + + @abstractmethod def forward(self, *args, **kwargs): pass @abstractmethod - def training_step(self, batch, batch_idx): + def training_step(self, batch): pass @abstractmethod @@ -142,3 +151,49 @@ def _check_solver_consistency(self, problem): raise ValueError( f'{self.__name__} dose not support condition ' f'{condition.condition_type}') + + def epoch_logger(self, name): + if name == 'train': + losses_dict = self.train_condition_losses + elif name == 'val': + losses_dict = self.validation_condition_losses + total_loss = [] + total_count = [] + for k, v in losses_dict.items(): + local_counter = torch.tensor(v['count']).to(self.device) + n_elements = torch.sum(local_counter) + loss = torch.sum( + torch.stack(v['loss']) * local_counter) / n_elements + loss = loss.as_subclass(torch.Tensor) + total_loss.append(loss) + total_count.append(n_elements) + self.log( + k + f"_{name}_loss", + loss, + prog_bar=True, + logger=True, + on_epoch=True, + on_step=False, + batch_size=self.trainer.data_module.batch_size, + ) + total_count = (torch.tensor(total_count, dtype=torch.float32). 
diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py
index 049518f1..0ff8f416 100644
--- a/pina/solvers/supervised.py
+++ b/pina/solvers/supervised.py
@@ -47,7 +48,8 @@ def __init__(self,
                  loss=None,
                  optimizer=None,
                  scheduler=None,
-                 extra_features=None):
+                 extra_features=None,
+                 use_lt=True):
         """
         :param AbstractProblem problem: The formulation of the problem.
         :param torch.nn.Module model: The neural network model to use.
@@ -73,10 +75,11 @@ def __init__(self,
             problem=problem,
             optimizers=optimizer,
             schedulers=scheduler,
-            extra_features=extra_features)
+            extra_features=extra_features,
+            use_lt=use_lt)
 
         # check consistency
-        check_consistency(loss, (LossInterface, _Loss),
+        check_consistency(loss, (LossInterface, _Loss, torch.nn.Module),
                           subclass=False)
         self._loss = loss
         self._model = self._pina_models[0]
@@ -121,85 +124,25 @@ def training_step(self, batch):
         :rtype: LabelTensor
         """
         condition_loss = []
-        batches = batch.get_supervised_data()
-        for points in batches:
-            input_pts, output_pts, _ = points
+        for condition_name, points in batch.items():
+            input_pts = points['input_points']
+            output_pts = points['output_points']
             loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
             condition_loss.append(loss_.as_subclass(torch.Tensor))
         loss = sum(condition_loss)
-        self.log("mean_loss", float(loss), prog_bar=True, logger=True,
-                 on_epoch=True,
-                 on_step=False, batch_size=self.trainer.data_module.batch_size)
+        self.log('train_loss', loss, on_step=False, on_epoch=True,
+                 prog_bar=True, logger=True)
         return loss
 
     def validation_step(self, batch):
        """
        Solver validation step.
        """
-
-        batch = batch.supervised
-        condition_idx = batch.condition_indices
-        for i in range(condition_idx.min(), condition_idx.max() + 1):
-            condition_name = self.trainer.data_module.condition_names[i]
-            condition = self.problem.conditions[condition_name]
-            pts = batch.input_points
-            out = batch.output_points
-            if condition_name not in self.problem.conditions:
-                raise RuntimeError("Something wrong happened.")
-
-            # for data driven mode
-            if not hasattr(condition, "output_points"):
-                raise NotImplementedError(
-                    f"{type(self).__name__} works only in data-driven mode.")
-
-            output_pts = out[condition_idx == i]
-            input_pts = pts[condition_idx == i]
-
+        condition_loss = []
+        for condition_name, points in batch.items():
+            input_pts = points['input_points']
+            output_pts = points['output_points']
             loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
-            self.validation_condition_losses[condition_name]['loss'].append(
-                loss_)
-            self.validation_condition_losses[condition_name]['count'].append(
-                len(input_pts))
-
-    def on_validation_epoch_end(self):
-        """
-        Solver validation epoch end.
-        """
-        total_loss = []
-        total_count = []
-        for k, v in self.validation_condition_losses.items():
-            local_counter = torch.tensor(v['count']).to(self.device)
-            n_elements = torch.sum(local_counter)
-            loss = torch.sum(
-                torch.stack(v['loss']) * local_counter) / n_elements
-            loss = loss.as_subclass(torch.Tensor)
-            total_loss.append(loss)
-            total_count.append(n_elements)
-            self.log(
-                k + "_loss",
-                loss,
-                prog_bar=True,
-                logger=True,
-                on_epoch=True,
-                on_step=False,
-                batch_size=self.trainer.data_module.batch_size,
-            )
-        total_count = (torch.tensor(total_count, dtype=torch.float32).
-                       to(self.device))
-        mean_loss = (torch.sum(torch.stack(total_loss) * total_count) /
-                     total_count)
-        self.log(
-            "val_loss",
-            mean_loss,
-            prog_bar=True,
-            logger=True,
-            on_epoch=True,
-            on_step=False,
-            batch_size=self.trainer.data_module.batch_size,
-        )
-        for key in self.validation_condition_losses.keys():
-            self.validation_condition_losses[key]['loss'] = []
-            self.validation_condition_losses[key]['count'] = []
+            condition_loss.append(loss_.as_subclass(torch.Tensor))
+        loss = sum(condition_loss)
+        self.log('val_loss', loss, on_step=False, on_epoch=True,
+                 prog_bar=True, logger=True)
 
     def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
         """
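Both the SupervisedSolver steps above and the PINN steps in basepinn.py now iterate the same batch layout: a plain dict keyed by condition name, whose values hold 'input_points' and, for supervised conditions, 'output_points'. A runnable sketch of that structure, with condition names and tensors invented purely for illustration:

    import torch
    from pina import LabelTensor

    batch = {
        'data': {
            'input_points': LabelTensor(torch.rand(8, 2), ['x', 'y']),
            'output_points': LabelTensor(torch.rand(8, 1), ['u']),
        },
        'gamma1': {
            'input_points': LabelTensor(torch.rand(8, 2), ['x', 'y']),
        },
    }

    for condition_name, points in batch.items():
        # SupervisedSolver assumes targets are always present; the PINN
        # solvers branch on this key instead.
        kind = 'supervised' if 'output_points' in points else 'physics'
        print(condition_name, kind)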
""" - - batch = batch.supervised - condition_idx = batch.condition_indices - for i in range(condition_idx.min(), condition_idx.max() + 1): - condition_name = self.trainer.data_module.condition_names[i] - condition = self.problem.conditions[condition_name] - pts = batch.input_points - out = batch.output_points - if condition_name not in self.problem.conditions: - raise RuntimeError("Something wrong happened.") - - # for data driven mode - if not hasattr(condition, "output_points"): - raise NotImplementedError( - f"{type(self).__name__} works only in data-driven mode.") - - output_pts = out[condition_idx == i] - input_pts = pts[condition_idx == i] - + condition_loss = [] + for condition_name, points in batch.items(): + input_pts, output_pts = points['input_points'], points['output_points'] loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts) - self.validation_condition_losses[condition_name]['loss'].append( - loss_) - self.validation_condition_losses[condition_name]['count'].append( - len(input_pts)) - - def on_validation_epoch_end(self): - """ - Solver validation epoch end. - """ - total_loss = [] - total_count = [] - for k, v in self.validation_condition_losses.items(): - local_counter = torch.tensor(v['count']).to(self.device) - n_elements = torch.sum(local_counter) - loss = torch.sum( - torch.stack(v['loss']) * local_counter) / n_elements - loss = loss.as_subclass(torch.Tensor) - total_loss.append(loss) - total_count.append(n_elements) - self.log( - k + "_loss", - loss, - prog_bar=True, - logger=True, - on_epoch=True, - on_step=False, - batch_size=self.trainer.data_module.batch_size, - ) - total_count = (torch.tensor(total_count, dtype=torch.float32). - to(self.device)) - mean_loss = (torch.sum(torch.stack(total_loss) * total_count) / - total_count) - self.log( - "val_loss", - mean_loss, - prog_bar=True, - logger=True, - on_epoch=True, - on_step=False, - batch_size=self.trainer.data_module.batch_size, - ) - for key in self.validation_condition_losses.keys(): - self.validation_condition_losses[key]['loss'] = [] - self.validation_condition_losses[key]['count'] = [] + condition_loss.append(loss_.as_subclass(torch.Tensor)) + loss = sum(condition_loss) + self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) def test_step(self, batch, batch_idx) -> STEP_OUTPUT: """ diff --git a/pina/trainer.py b/pina/trainer.py index f5ea5513..1de93abb 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -1,13 +1,13 @@ """ Trainer module. """ import warnings import torch -import pytorch_lightning +import lightning from .utils import check_consistency from .data import PinaDataModule from .solvers.solver import SolverInterface -class Trainer(pytorch_lightning.Trainer): +class Trainer(lightning.pytorch.Trainer): def __init__(self, solver, @@ -34,6 +34,9 @@ def __init__(self, log_every_n_steps = kwargs.pop('log_every_n_steps', 0) super().__init__(log_every_n_steps=log_every_n_steps, **kwargs) + strategy = kwargs.get('strategy', None) + + # check inheritance consistency for solver and batch size check_consistency(solver, SolverInterface) if batch_size is not None: @@ -72,21 +75,14 @@ def _create_loader(self): raise RuntimeError('Cannot create Trainer if not all conditions ' 'are sampled. 
The Trainer got the following:\n' f'{error_message}') - devices = self._accelerator_connector._parallel_devices - - if len(devices) > 1: - raise RuntimeError("Parallel training is not supported yet.") - - device = devices[0] - - self.data_module = PinaDataModule(problem=self.solver.problem, - device=device, + self.data_module = PinaDataModule(collector=self.solver.problem.collector, train_size=self.train_size, test_size=self.test_size, val_size=self.val_size, predict_size=self.predict_size, batch_size=self.batch_size, ) - self.data_module.setup() + if self.batch_size is None: + self.data_module.setup() def train(self, **kwargs): """ @@ -100,6 +96,7 @@ def train(self, **kwargs): "`val_dataloader`", category=UserWarning ) + return super().fit(self.solver, datamodule=self.data_module, **kwargs) diff --git a/setup.py b/setup.py index c44cacf7..04f7b2fc 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ KEYWORDS = 'physics-informed neural-network' REQUIRED = [ - 'numpy', 'matplotlib', 'torch', 'lightning', 'pytorch_lightning', 'torch_geometric', 'torch-cluster' + 'numpy', 'matplotlib', 'torch', 'lightning', 'torch_geometric', 'torch-cluster' ] EXTRAS = {