diff --git a/pina/__init__.py b/pina/__init__.py
index 793cf342..60801b67 100644
--- a/pina/__init__.py
+++ b/pina/__init__.py
@@ -15,4 +15,7 @@
 from .plotter import Plotter
 from .condition import Condition
 from .geometry import Location
-from .geometry import CartesianDomain
\ No newline at end of file
+from .geometry import CartesianDomain
+
+from .dataset import SamplePointDataset
+from .dataset import SamplePointLoader
\ No newline at end of file
diff --git a/pina/callbacks/adaptive_refinment_callbacks.py b/pina/callbacks/adaptive_refinment_callbacks.py
index 93a6ac84..664c8be6 100644
--- a/pina/callbacks/adaptive_refinment_callbacks.py
+++ b/pina/callbacks/adaptive_refinment_callbacks.py
@@ -1,6 +1,7 @@
 '''PINA Callbacks Implementations'''
 
-from lightning.pytorch.callbacks import Callback
+# from lightning.pytorch.callbacks import Callback
+from pytorch_lightning.callbacks import Callback
 import torch
 from ..utils import check_consistency
 
diff --git a/pina/callbacks/optimizer_callbacks.py b/pina/callbacks/optimizer_callbacks.py
index 4027b854..0f375f14 100644
--- a/pina/callbacks/optimizer_callbacks.py
+++ b/pina/callbacks/optimizer_callbacks.py
@@ -1,6 +1,6 @@
 '''PINA Callbacks Implementations'''
 
-from lightning.pytorch.callbacks import Callback
+from pytorch_lightning.callbacks import Callback
 import torch
 from ..utils import check_consistency
 
+ """ + super().__init__() + pts_list = [] + self.condition_names = [] + + for name, condition in problem.conditions.items(): + if not hasattr(condition, 'output_points'): + pts_list.append(problem.input_pts[name]) + self.condition_names.append(name) + + self.pts = LabelTensor.vstack(pts_list) + + if self.pts != []: + self.condition_indeces = torch.cat([ + torch.tensor([i]*len(pts_list[i])) + for i in range(len(self.condition_names)) + ], dim=0) + else: # if there are no sample points + self.condition_indeces = torch.tensor([]) + self.pts = torch.tensor([]) + + self.pts = self.pts.to(device) + self.condition_indeces = self.condition_indeces.to(device) + + def __len__(self): + return self.pts.shape[0] + + +class DataPointDataset(Dataset): + + def __init__(self, problem, device) -> None: + super().__init__() + input_list = [] + output_list = [] + self.condition_names = [] + + for name, condition in problem.conditions.items(): + if hasattr(condition, 'output_points'): + input_list.append(problem.conditions[name].input_points) + output_list.append(problem.conditions[name].output_points) + self.condition_names.append(name) + + self.input_pts = LabelTensor.vstack(input_list) + self.output_pts = LabelTensor.vstack(output_list) + + if self.input_pts != []: + self.condition_indeces = torch.cat([ + torch.tensor([i]*len(input_list[i])) + for i in range(len(self.condition_names)) + ], dim=0) + else: # if there are no data points + self.condition_indeces = torch.tensor([]) + self.input_pts = torch.tensor([]) + self.output_pts = torch.tensor([]) + + self.input_pts = self.input_pts.to(device) + self.output_pts = self.output_pts.to(device) + self.condition_indeces = self.condition_indeces.to(device) + + def __len__(self): + return self.input_pts.shape[0] + + +class SamplePointLoader: + """ + This class is used to create a dataloader to use during the training. + + :var condition_names: The names of the conditions. The order is consistent + with the condition indeces in the batches. + :vartype condition_names: list[str] + """ + + def __init__(self, sample_dataset, data_dataset, batch_size=None, shuffle=True) -> None: + """ + Constructor. + + :param SamplePointDataset sample_pts: The sample points dataset. + :param int batch_size: The batch size. If ``None``, the batch size is + set to the number of sample points. Default is ``None``. + :param bool shuffle: If ``True``, the sample points are shuffled. + Default is ``True``. 
+ """ + if not isinstance(sample_dataset, SamplePointDataset): + raise TypeError(f'Expected SamplePointDataset, got {type(sample_dataset)}') + if not isinstance(data_dataset, DataPointDataset): + raise TypeError(f'Expected DataPointDataset, got {type(data_dataset)}') + + self.n_data_conditions = len(data_dataset.condition_names) + self.n_phys_conditions = len(sample_dataset.condition_names) + data_dataset.condition_indeces += self.n_phys_conditions + + self._prepare_sample_dataset(sample_dataset, batch_size, shuffle) + self._prepare_data_dataset(data_dataset, batch_size, shuffle) + + self.condition_names = ( + sample_dataset.condition_names + data_dataset.condition_names) + + self.batch_list = [] + for i in range(len(self.batch_sample_pts)): + self.batch_list.append( + ('sample', i) + ) - @property - def dataloader(self): - return self._create_dataloader() + for i in range(len(self.batch_input_pts)): + self.batch_list.append( + ('data', i) + ) - @property - def dataset(self): - return [self.SampleDataset(key, val) - for key, val in self.input_pts.items()] + if shuffle: + self.random_idx = torch.randperm(len(self.batch_list)) + else: + self.random_idx = torch.arange(len(self.batch_list)) - def _create_dataloader(self): - """Private method for creating dataloader - :return: dataloader - :rtype: torch.utils.data.DataLoader + def _prepare_data_dataset(self, dataset, batch_size, shuffle): """ - if self.pinn.batch_size is None: - return {key: [{key: val}] for key, val in self.pinn.input_pts.items()} - - def custom_collate(batch): - # extracting pts labels - _, pts = list(batch[0].items())[0] - labels = pts.labels - # calling default torch collate - collate_res = default_collate(batch) - # save collate result in dict - res = {} - for key, val in collate_res.items(): - val.labels = labels - res[key] = val - def __getitem__(self, index): - tensor = self._tensor.select(0, index) - return {self._location: tensor} - - def __len__(self): - return self._len - - - -# TODO: working also for datapoints -class DummyLoader: - - def __init__(self, data, device) -> None: - - # TODO: We need to make a dataset somehow - # and the PINADataset needs to have a method - # to send points to device - # now we simply do it here - # send data to device - def convert_tensors(pts, device): - pts = pts.to(device) - pts.requires_grad_(True) - pts.retain_grad() - return pts - - for location, pts in data.items(): - if isinstance(pts, (tuple, list)): - pts = tuple(map(functools.partial(convert_tensors, device=device),pts)) - else: - pts = pts.to(device) - pts = pts.requires_grad_(True) - pts.retain_grad() - - data[location] = pts + Prepare the dataset for data points. - # iterator - self.data = [data] + :param SamplePointDataset dataset: The dataset. + :param int batch_size: The batch size. + :param bool shuffle: If ``True``, the sample points are shuffled. 
+ """ + self.sample_dataset = dataset + + if len(dataset) == 0: + self.batch_data_conditions = [] + self.batch_input_pts = [] + self.batch_output_pts = [] + return + + if batch_size is None: + batch_size = len(dataset) + batch_num = len(dataset) // batch_size + if len(dataset) % batch_size != 0: + batch_num += 1 + + output_labels = dataset.output_pts.labels + input_labels = dataset.input_pts.labels + self.tensor_conditions = dataset.condition_indeces + + if shuffle: + idx = torch.randperm(dataset.input_pts.shape[0]) + self.input_pts = dataset.input_pts[idx] + self.output_pts = dataset.output_pts[idx] + self.tensor_conditions = dataset.condition_indeces[idx] + + self.batch_input_pts = torch.tensor_split( + dataset.input_pts, batch_num) + self.batch_output_pts = torch.tensor_split( + dataset.output_pts, batch_num) + + for i in range(len(self.batch_input_pts)): + self.batch_input_pts[i].labels = input_labels + self.batch_output_pts[i].labels = output_labels + + self.batch_data_conditions = torch.tensor_split( + self.tensor_conditions, batch_num) + + def _prepare_sample_dataset(self, dataset, batch_size, shuffle): + """ + Prepare the dataset for sample points. + + :param DataPointDataset dataset: The dataset. + :param int batch_size: The batch size. + :param bool shuffle: If ``True``, the sample points are shuffled. + """ + + self.sample_dataset = dataset + if len(dataset) == 0: + self.batch_sample_conditions = [] + self.batch_sample_pts = [] + return + + if batch_size is None: + batch_size = len(dataset) + + batch_num = len(dataset) // batch_size + if len(dataset) % batch_size != 0: + batch_num += 1 + + self.tensor_pts = dataset.pts + self.tensor_conditions = dataset.condition_indeces + + # if shuffle: + # idx = torch.randperm(self.tensor_pts.shape[0]) + # self.tensor_pts = self.tensor_pts[idx] + # self.tensor_conditions = self.tensor_conditions[idx] + + self.batch_sample_pts = torch.tensor_split(self.tensor_pts, batch_num) + for i in range(len(self.batch_sample_pts)): + self.batch_sample_pts[i].labels = dataset.pts.labels + + self.batch_sample_conditions = torch.tensor_split( + self.tensor_conditions, batch_num) def __iter__(self): - return iter(self.data) + """ + Return an iterator over the points. Any element of the iterator is a + dictionary with the following keys: + - ``pts``: The input sample points. It is a LabelTensor with the + shape ``(batch_size, input_dimension)``. + - ``output``: The output sample points. This key is present only + if data conditions are present. It is a LabelTensor with the + shape ``(batch_size, output_dimension)``. + - ``condition``: The integer condition indeces. It is a tensor + with the shape ``(batch_size, )`` of type ``torch.int64`` and + indicates for any ``pts`` the corresponding problem condition. + + :return: An iterator over the points. 
diff --git a/pina/geometry/simplex.py b/pina/geometry/simplex.py
index ac1e9c0e..c371aec3 100644
--- a/pina/geometry/simplex.py
+++ b/pina/geometry/simplex.py
@@ -55,13 +55,15 @@ def __init__(self, simplex_matrix, sample_surface=False):
             raise ValueError("An n-dimensional simplex is composed by n + 1 tensors of dimension n.")
 
         # creating vertices matrix
-        self._vertices_matrix = torch.cat(simplex_matrix)
-        self._vertices_matrix.labels = matrix_labels
+        self._vertices_matrix = LabelTensor.vstack(simplex_matrix)
 
         # creating basis vectors for simplex
-        self._vectors_shifted = (
-            (self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
-        )
+        # self._vectors_shifted = (
+        #     (self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
+        # )  ### TODO: Remove after checking
+
+        vert = self._vertices_matrix
+        self._vectors_shifted = (vert[:-1] - vert[-1]).T
 
         # build cartesian_bound
         self._cartesian_bound = self._build_cartesian(self._vertices_matrix)
@@ -114,8 +116,8 @@ def is_inside(self, point, check_border=False):
                 f" expected {self.variables}."
             )
 
-        # shift point
-        point_shift = point.T - (self._vertices_matrix.T)[:, None, -1]
+        point_shift = point - self._vertices_matrix[-1]
+        point_shift = point_shift.tensor.reshape(-1, 1)
 
         # compute barycentric coordinates
         lambda_ = torch.linalg.solve(self._vectors_shifted * 1.0, point_shift * 1.0)
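The rewritten shift makes the barycentric-coordinate computation easier to read: `lambda_` solves `V @ lambda_ = p - v_last`, where the columns of `V` are the vertices shifted by the last one, and the point lies inside when all coordinates are non-negative and sum to at most one. A self-contained numeric sketch with a hypothetical triangle:

```python
import torch

# triangle vertices v0, v1, v2 (hypothetical values)
vertices = torch.tensor([[0., 0.], [1., 0.], [0., 1.]])
point = torch.tensor([0.25, 0.25])

# columns of V are (v0 - v2, v1 - v2), mirroring `_vectors_shifted`
V = (vertices[:-1] - vertices[-1]).T
rhs = (point - vertices[-1]).reshape(-1, 1)

lambda_ = torch.linalg.solve(V, rhs).flatten()
inside = bool((lambda_ >= 0).all() and lambda_.sum() <= 1)
print(lambda_, inside)  # tensor([0.5000, 0.2500]) True
```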
diff --git a/pina/label_tensor.py b/pina/label_tensor.py
index 5645381f..df6a6792 100644
--- a/pina/label_tensor.py
+++ b/pina/label_tensor.py
@@ -96,6 +96,28 @@ def labels(self, labels):
 
         self._labels = labels  # assign the label
 
+    @staticmethod
+    def vstack(label_tensors):
+        """
+        Stack tensors vertically. For more details, see
+        :meth:`torch.vstack`.
+
+        :param list(LabelTensor) label_tensors: the tensors to stack. They need
+            to have equal labels.
+        :return: the stacked tensor
+        :rtype: LabelTensor
+        """
+        if len(label_tensors) == 0:
+            return []
+
+        all_labels = [label for lt in label_tensors for label in lt.labels]
+        if set(all_labels) != set(label_tensors[0].labels):
+            raise RuntimeError('The tensors to stack have different labels')
+
+        labels = label_tensors[0].labels
+        tensors = [lt.extract(labels) for lt in label_tensors]
+        return LabelTensor(torch.vstack(tensors), labels)
+
     # TODO remove try/ except thing IMPORTANT
     # make the label None of default
     def clone(self, *args, **kwargs):
@@ -183,6 +205,18 @@ def extract(self, label_to_extract):
 
         return extracted_tensor
 
+    def detach(self):
+        detached = super().detach()
+        if hasattr(self, '_labels'):
+            detached._labels = self._labels
+        return detached
+
+    def requires_grad_(self, mode=True) -> Tensor:
+        lt = super().requires_grad_(mode)
+        lt.labels = self.labels
+        return lt
+
     def append(self, lt, mode='std'):
         """
         Return a copy of the merged tensors.
@@ -232,7 +266,7 @@ def __getitem__(self, index):
             len_index = len(index)
         except TypeError:
             len_index = 1
-
+
         if isinstance(index, int) or len_index == 1:
             if selected_lt.ndim == 1:
                 selected_lt = selected_lt.reshape(1, -1)
@@ -246,8 +280,14 @@ def __getitem__(self, index):
                 selected_lt.labels = [self.labels[i] for i in index[1]]
             else:
                 selected_lt.labels = self.labels[index[1]]
+        else:
+            selected_lt.labels = self.labels
 
         return selected_lt
+
+    @property
+    def tensor(self):
+        return self.as_subclass(Tensor)
 
     def __len__(self) -> int:
         return super().__len__()
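Unlike a plain `torch.vstack`, the new helper keeps labels attached and uses `extract` to re-order columns, so tensors whose labels merely appear in a different order still stack correctly. A minimal sketch:

```python
import torch
from pina import LabelTensor

a = LabelTensor(torch.rand(3, 2), ['x', 'y'])
b = LabelTensor(torch.rand(5, 2), ['y', 'x'])   # same labels, other order

stacked = LabelTensor.vstack([a, b])
print(stacked.shape)   # torch.Size([8, 2])
print(stacked.labels)  # ['x', 'y'] -- b's columns were re-ordered to match

LabelTensor.vstack([])  # empty input returns [] instead of raising
```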
+ """ + optimizer = self.optimizer_generator + + generated_snapshots = self.generator(parameters) + + # generator loss + r_loss = self._loss(snapshots, generated_snapshots) + d_fake = self.discriminator([generated_snapshots, parameters]) + g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss + + # backward step + g_loss.backward() + optimizer.step() + + return r_loss, g_loss + + def _train_discriminator(self, parameters, snapshots): + """ + Private method to train the discriminator network. + """ + optimizer = self.optimizer_discriminator + optimizer.zero_grad() + + # Generate a batch of images + generated_snapshots = self.generator(parameters) + + # Discriminator pass + d_real = self.discriminator([snapshots, parameters]) + d_fake = self.discriminator([generated_snapshots, parameters]) + + # evaluate loss + d_loss_real = self._loss(d_real, snapshots) + d_loss_fake = self._loss(d_fake, generated_snapshots.detach()) + d_loss = d_loss_real - self.k * d_loss_fake + + # backward step + d_loss.backward(retain_graph=True) + optimizer.step() + + return d_loss_real, d_loss_fake, d_loss + + def _update_weights(self, d_loss_real, d_loss_fake): + """ + Private method to Update the weights of the generator and discriminator + networks. + """ + + diff = torch.mean(self.gamma * d_loss_real - d_loss_fake) + + # Update weight term for fake samples + self.k += self.lambda_k * diff.item() + self.k = min(max(self.k, 0), 1) # Constraint to interval [0, 1] + return diff def training_step(self, batch, batch_idx): """PINN solver training step. @@ -169,77 +229,39 @@ def training_step(self, batch, batch_idx): :rtype: LabelTensor """ - for condition_name, samples in batch.items(): + dataloader = self.trainer.train_dataloader + condition_idx = batch['condition'] - if condition_name not in self.problem.conditions: - raise RuntimeError('Something wrong happened.') + for condition_id in range(condition_idx.min(), condition_idx.max()+1): + condition_name = dataloader.condition_names[condition_id] condition = self.problem.conditions[condition_name] + pts = batch['pts'] + out = batch['output'] - # for data driven mode - if hasattr(condition, 'output_points'): + if condition_name not in self.problem.conditions: + raise RuntimeError('Something wrong happened.') - # get data - parameters, input_pts = samples + # for data driven mode + if not hasattr(condition, 'output_points'): + raise NotImplementedError('GAROM works only in data-driven mode.') - # get optimizers - opt_gen, opt_disc = self.optimizers + # get data + snapshots = out[condition_idx == condition_id] + parameters = pts[condition_idx == condition_id] - # --------------------- - # Train Discriminator - # --------------------- - opt_disc.zero_grad() + d_loss_real, d_loss_fake, d_loss = self._train_discriminator( + parameters, snapshots) - # Generate a batch of images - gen_imgs = self.generator(parameters) + r_loss, g_loss = self._train_generator(parameters, snapshots) - # Discriminator pass - d_real = self.discriminator([input_pts, parameters]) - d_fake = self.discriminator([gen_imgs.detach(), parameters]) - - # evaluate loss - d_loss_real = self._loss(d_real, input_pts) - d_loss_fake = self._loss(d_fake, gen_imgs.detach()) - d_loss = d_loss_real - self.k * d_loss_fake - - # backward step - d_loss.backward() - opt_disc.step() - - # ----------------- - # Train Generator - # ----------------- - opt_gen.zero_grad() - - # Generate a batch of images - gen_imgs = self.generator(parameters) - - # generator loss - r_loss = self._loss(input_pts, 
diff --git a/pina/solvers/garom.py b/pina/solvers/garom.py
index cfb96656..f09e700a 100644
--- a/pina/solvers/garom.py
+++ b/pina/solvers/garom.py
@@ -115,12 +115,15 @@ def __init__(self,
         check_consistency(lambda_k, float)
         check_consistency(regularizer, bool)
 
-        # assign schedulers
-        self._schedulers = [scheduler_generator(self.optimizers[0],
-                                                **scheduler_generator_kwargs),
-                            scheduler_discriminator(self.optimizers[1],
-                                                    **scheduler_discriminator_kwargs)]
+        self._schedulers = [
+            scheduler_generator(
+                self.optimizers[0], **scheduler_generator_kwargs),
+            scheduler_discriminator(
+                self.optimizers[1],
+                **scheduler_discriminator_kwargs)
+        ]
 
+        # loss and writer
         self._loss = loss
 
@@ -157,6 +160,63 @@ def configure_optimizers(self):
     def sample(self, x):
         # sampling
         return self.generator(x)
+
+    def _train_generator(self, parameters, snapshots):
+        """
+        Private method to train the generator network.
+        """
+        optimizer = self.optimizer_generator
+        optimizer.zero_grad()
+
+        generated_snapshots = self.generator(parameters)
+
+        # generator loss
+        r_loss = self._loss(snapshots, generated_snapshots)
+        d_fake = self.discriminator([generated_snapshots, parameters])
+        g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
+
+        # backward step
+        g_loss.backward()
+        optimizer.step()
+
+        return r_loss, g_loss
+
+    def _train_discriminator(self, parameters, snapshots):
+        """
+        Private method to train the discriminator network.
+        """
+        optimizer = self.optimizer_discriminator
+        optimizer.zero_grad()
+
+        # Generate a batch of images
+        generated_snapshots = self.generator(parameters)
+
+        # Discriminator pass
+        d_real = self.discriminator([snapshots, parameters])
+        d_fake = self.discriminator([generated_snapshots, parameters])
+
+        # evaluate loss
+        d_loss_real = self._loss(d_real, snapshots)
+        d_loss_fake = self._loss(d_fake, generated_snapshots.detach())
+        d_loss = d_loss_real - self.k * d_loss_fake
+
+        # backward step
+        d_loss.backward(retain_graph=True)
+        optimizer.step()
+
+        return d_loss_real, d_loss_fake, d_loss
+
+    def _update_weights(self, d_loss_real, d_loss_fake):
+        """
+        Private method to update the weights of the generator and
+        discriminator networks.
+        """
+        diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
+
+        # Update weight term for fake samples
+        self.k += self.lambda_k * diff.item()
+        self.k = min(max(self.k, 0), 1)  # Constraint to interval [0, 1]
+        return diff
 
     def training_step(self, batch, batch_idx):
         """PINN solver training step.
@@ -169,77 +229,39 @@ def training_step(self, batch, batch_idx):
         :rtype: LabelTensor
         """
-        for condition_name, samples in batch.items():
+        dataloader = self.trainer.train_dataloader
+        condition_idx = batch['condition']
 
-            if condition_name not in self.problem.conditions:
-                raise RuntimeError('Something wrong happened.')
+        for condition_id in range(condition_idx.min(), condition_idx.max()+1):
 
+            condition_name = dataloader.condition_names[condition_id]
             condition = self.problem.conditions[condition_name]
+            pts = batch['pts']
+            out = batch['output']
 
-            # for data driven mode
-            if hasattr(condition, 'output_points'):
+            if condition_name not in self.problem.conditions:
+                raise RuntimeError('Something wrong happened.')
 
-                # get data
-                parameters, input_pts = samples
+            # for data driven mode
+            if not hasattr(condition, 'output_points'):
+                raise NotImplementedError('GAROM works only in data-driven mode.')
 
-                # get optimizers
-                opt_gen, opt_disc = self.optimizers
+            # get data
+            snapshots = out[condition_idx == condition_id]
+            parameters = pts[condition_idx == condition_id]
 
-                # ---------------------
-                #  Train Discriminator
-                # ---------------------
-                opt_disc.zero_grad()
+            d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
+                parameters, snapshots)
 
-                # Generate a batch of images
-                gen_imgs = self.generator(parameters)
+            r_loss, g_loss = self._train_generator(parameters, snapshots)
 
-                # Discriminator pass
-                d_real = self.discriminator([input_pts, parameters])
-                d_fake = self.discriminator([gen_imgs.detach(), parameters])
-
-                # evaluate loss
-                d_loss_real = self._loss(d_real, input_pts)
-                d_loss_fake = self._loss(d_fake, gen_imgs.detach())
-                d_loss = d_loss_real - self.k * d_loss_fake
-
-                # backward step
-                d_loss.backward()
-                opt_disc.step()
-
-                # -----------------
-                #  Train Generator
-                # -----------------
-                opt_gen.zero_grad()
-
-                # Generate a batch of images
-                gen_imgs = self.generator(parameters)
-
-                # generator loss
-                r_loss = self._loss(input_pts,
-                                    gen_imgs)
-                d_fake = self.discriminator([gen_imgs, parameters])
-                g_loss = self._loss(d_fake, gen_imgs) + self.regularizer * r_loss
-
-                # backward step
-                g_loss.backward()
-                opt_gen.step()
-
-                # ----------------
-                # Update weights
-                # ----------------
-                diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
-
-                # Update weight term for fake samples
-                self.k += self.lambda_k * diff.item()
-                self.k = min(max(self.k, 0), 1)  # Constraint to interval [0, 1]
-
-                # logging
-                self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
-                self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
-                self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
-                self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
-
-            else:
-                raise NotImplementedError('GAROM works only in data-driven mode.')
+            diff = self._update_weights(d_loss_real, d_loss_fake)
+
+            # logging
+            self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
+            self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
+            self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
+            self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
 
         return
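The three new private helpers factor the old inline loop into the usual GAN structure; `_update_weights` is the BEGAN-style equilibrium step that was previously buried in `training_step`. Its arithmetic, spelled out with hypothetical loss values:

```python
# BEGAN-style control update, as in _update_weights (hypothetical numbers)
gamma, lambda_k, k = 0.3, 0.001, 0.0

d_loss_real, d_loss_fake = 0.8, 0.1        # stand-ins for the batch losses
diff = gamma * d_loss_real - d_loss_fake   # equilibrium error
k = min(max(k + lambda_k * diff, 0), 1)    # keep k within [0, 1]
stability = d_loss_real + abs(diff)        # logged as 'stability_metric'
```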
diff --git a/pina/solvers/pinn.py b/pina/solvers/pinn.py
index 04d2dcab..55d5f25d 100644
--- a/pina/solvers/pinn.py
+++ b/pina/solvers/pinn.py
@@ -97,6 +97,15 @@ def configure_optimizers(self):
         """
         return self.optimizers, [self.scheduler]
 
+    def _loss_data(self, input, output):
+        return self.loss(self.forward(input), output)
+
+    def _loss_phys(self, samples, equation):
+        residual = equation.residual(samples, self.forward(samples))
+        return self.loss(torch.zeros_like(residual, requires_grad=True), residual)
+
     def training_step(self, batch, batch_idx):
         """PINN solver training step.
@@ -108,25 +117,29 @@ def training_step(self, batch, batch_idx):
         :rtype: LabelTensor
         """
+        dataloader = self.trainer.train_dataloader
         condition_losses = []
-        condition_names = []
 
-        for condition_name, samples in batch.items():
+        condition_idx = batch['condition']
 
-            if condition_name not in self.problem.conditions:
-                raise RuntimeError('Something wrong happened.')
+        for condition_id in range(condition_idx.min(), condition_idx.max()+1):
 
-            condition_names.append(condition_name)
+            condition_name = dataloader.condition_names[condition_id]
             condition = self.problem.conditions[condition_name]
+            pts = batch['pts']
+
+            if len(batch) == 2:
+                samples = pts[condition_idx == condition_id]
+                loss = self._loss_phys(samples, condition.equation)
+            elif len(batch) == 3:
+                samples = pts[condition_idx == condition_id]
+                ground_truth = batch['output'][condition_idx == condition_id]
+                loss = self._loss_data(samples, ground_truth)
+            else:
+                raise ValueError('Batch format not supported')
 
-            # PINN loss: equation evaluated on location or input_points
-            if hasattr(condition, 'equation'):
-                target = condition.equation.residual(samples, self.forward(samples))
-                loss = self.loss(torch.zeros_like(target), target)
-            # PINN loss: evaluate model(input_points) vs output_points
-            elif hasattr(condition, 'output_points'):
-                input_pts, output_pts = samples
-                loss = self.loss(self.forward(input_pts), output_pts)
+            loss = loss.as_subclass(torch.Tensor)
 
             condition_losses.append(loss * condition.data_weight)
 
@@ -135,8 +148,8 @@ def training_step(self, batch, batch_idx):
         total_loss = sum(condition_losses)
 
         self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
-        for condition_loss, loss in zip(condition_names, condition_losses):
-            self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
+        # for condition_loss, loss in zip(condition_names, condition_losses):
+        #     self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
 
         return total_loss
 
     @property
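The training step now dispatches on the shape of the batch dictionary rather than on condition attributes: two keys mean a physics batch (residual loss via `_loss_phys`), three keys mean a data batch (supervised loss via `_loss_data`). A standalone sketch of the masking logic, with hypothetical values:

```python
import torch

batch = {
    'pts': torch.rand(8, 2, requires_grad=True),
    'condition': torch.tensor([0, 0, 0, 1, 1, 1, 1, 1]),
}
is_physics = len(batch) == 2   # no 'output' key present

for condition_id in range(int(batch['condition'].min()),
                          int(batch['condition'].max()) + 1):
    mask = batch['condition'] == condition_id
    samples = batch['pts'][mask]   # the rows belonging to this condition
    # physics batch: loss from equation.residual(samples, forward(samples));
    # data batch: loss from forward(samples) vs batch['output'][mask]
```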
diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py
index fbe4553a..c2058401 100644
--- a/pina/solvers/solver.py
+++ b/pina/solvers/solver.py
@@ -2,7 +2,7 @@
 from abc import ABCMeta, abstractmethod
 from ..model.network import Network
-import lightning.pytorch as pl
+import pytorch_lightning as pl
 from ..utils import check_consistency
 from ..problem import AbstractProblem
 import torch
diff --git a/pina/trainer.py b/pina/trainer.py
index be220d24..fc2a9382 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -1,18 +1,19 @@
 """
 Solver module.
 """
-import lightning.pytorch as pl
+from pytorch_lightning import Trainer as LightningTrainer
 from .utils import check_consistency
-from .dataset import DummyLoader
+from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
 from .solvers.solver import SolverInterface
 
-class Trainer(pl.Trainer):
+class Trainer(LightningTrainer):
 
-    def __init__(self, solver, **kwargs):
+    def __init__(self, solver, batch_size=None, **kwargs):
         super().__init__(**kwargs)
 
         # check inheritance consistency for solver
         check_consistency(solver, SolverInterface)
         self._model = solver
+        self.batch_size = batch_size
 
         # create dataloader
         if solver.problem.have_sampled_points is False:
@@ -22,19 +23,31 @@ def __init__(self, solver, **kwargs):
             'discretise_domain function before train '
             'in the provided locations.')
 
-        # TODO: make a better dataloader for train
         self._create_or_update_loader()
 
-    # this method is used here because is resampling is needed
-    # during training, there is no need to define to touch the
-    # trainer dataloader, just call the method.
    def _create_or_update_loader(self):
-        # get accellerator
-        device = self._accelerator_connector._accelerator_flag
-        self._loader = DummyLoader(self._model.problem.input_pts, device)
+        """
+        This method is defined here because, if resampling is needed during
+        training, there is no need to touch the trainer dataloader: simply
+        call this method again.
+        """
+        devices = self._accelerator_connector._parallel_devices
+
+        if len(devices) > 1:
+            raise RuntimeError('Parallel training is not supported yet.')
 
-    def train(self, **kwargs): # TODO add kwargs and lightining capabilities
-        return super().fit(self._model, self._loader, **kwargs)
+        device = devices[0]
+        dataset_phys = SamplePointDataset(self._model.problem, device)
+        dataset_data = DataPointDataset(self._model.problem, device)
+        self._loader = SamplePointLoader(
+            dataset_phys, dataset_data, batch_size=self.batch_size,
+            shuffle=True)
+
+    def train(self, **kwargs):
+        """
+        Train the solver.
+        """
+        return super().fit(self._model, train_dataloaders=self._loader, **kwargs)
 
     @property
     def solver(self):
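Downstream, the only user-facing change is the new `batch_size` argument to `Trainer`; a typical run now looks like this (problem and model construction elided, names as in `tests/test_solvers/test_pinn.py`):

```python
from pina.trainer import Trainer
from pina.solvers import PINN

# `poisson_problem` and `model` as built in the test suite
pinn = PINN(problem=poisson_problem, model=model)
trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu',
                  batch_size=20)  # batch_size=None -> one batch with everything
trainer.train()
```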
+ """ + return super().fit(self._model, train_dataloaders=self._loader, **kwargs) @property def solver(self): diff --git a/setup.py b/setup.py index b61c8880..0e5cf670 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ KEYWORDS = 'physics-informed neural-network' REQUIRED = [ - 'numpy', 'matplotlib', 'torch', 'lightning' + 'numpy', 'matplotlib', 'torch', 'lightning', 'pytorch_lightning' ] EXTRAS = { diff --git a/tests/test_callbacks/test_adaptive_refinment_callbacks.py b/tests/test_callbacks/test_adaptive_refinment_callbacks.py index cb1f85bd..11440250 100644 --- a/tests/test_callbacks/test_adaptive_refinment_callbacks.py +++ b/tests/test_callbacks/test_adaptive_refinment_callbacks.py @@ -44,9 +44,9 @@ class Poisson(SpatialProblem): 'D': Condition( input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']), equation=my_laplace), - 'data': Condition( - input_points=in_, - output_points=out_) + # 'data': Condition( + # input_points=in_, + # output_points=out_) } diff --git a/tests/test_callbacks/test_optimizer_callbacks.py b/tests/test_callbacks/test_optimizer_callbacks.py index d385a67b..9250ae10 100644 --- a/tests/test_callbacks/test_optimizer_callbacks.py +++ b/tests/test_callbacks/test_optimizer_callbacks.py @@ -44,9 +44,9 @@ class Poisson(SpatialProblem): 'D': Condition( input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']), equation=my_laplace), - 'data': Condition( - input_points=in_, - output_points=out_) + # 'data': Condition( + # input_points=in_, + # output_points=out_) } diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 00000000..ff1b6c22 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,122 @@ +import torch +import pytest + +from pina.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset +from pina import LabelTensor, Condition +from pina.equation import Equation +from pina.geometry import CartesianDomain +from pina.problem import SpatialProblem +from pina.model import FeedForward +from pina.operators import laplacian +from pina.equation.equation_factory import FixedValue + + +def laplace_equation(input_, output_): + force_term = (torch.sin(input_.extract(['x'])*torch.pi) * + torch.sin(input_.extract(['y'])*torch.pi)) + delta_u = laplacian(output_.extract(['u']), input_) + return delta_u - force_term + +my_laplace = Equation(laplace_equation) +in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y']) +out_ = LabelTensor(torch.tensor([[0.]]), ['u']) +in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y']) +out2_ = LabelTensor(torch.rand(60, 1), ['u']) + +class Poisson(SpatialProblem): + output_variables = ['u'] + spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) + + conditions = { + 'gamma1': Condition( + location=CartesianDomain({'x': [0, 1], 'y': 1}), + equation=FixedValue(0.0)), + 'gamma2': Condition( + location=CartesianDomain({'x': [0, 1], 'y': 0}), + equation=FixedValue(0.0)), + 'gamma3': Condition( + location=CartesianDomain({'x': 1, 'y': [0, 1]}), + equation=FixedValue(0.0)), + 'gamma4': Condition( + location=CartesianDomain({'x': 0, 'y': [0, 1]}), + equation=FixedValue(0.0)), + 'D': Condition( + input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']), + equation=my_laplace), + 'data': Condition( + input_points=in_, + output_points=out_), + 'data2': Condition( + input_points=in2_, + output_points=out2_) + } + +boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +poisson = Poisson() +poisson.discretise_domain(10, 'grid', locations=boundaries) + +def test_sample(): + sample_dataset = 
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
new file mode 100644
index 00000000..ff1b6c22
--- /dev/null
+++ b/tests/test_dataset.py
@@ -0,0 +1,122 @@
+import torch
+import pytest
+
+from pina.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
+from pina import LabelTensor, Condition
+from pina.equation import Equation
+from pina.geometry import CartesianDomain
+from pina.problem import SpatialProblem
+from pina.model import FeedForward
+from pina.operators import laplacian
+from pina.equation.equation_factory import FixedValue
+
+
+def laplace_equation(input_, output_):
+    force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
+                  torch.sin(input_.extract(['y'])*torch.pi))
+    delta_u = laplacian(output_.extract(['u']), input_)
+    return delta_u - force_term
+
+
+my_laplace = Equation(laplace_equation)
+in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
+out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
+in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
+out2_ = LabelTensor(torch.rand(60, 1), ['u'])
+
+
+class Poisson(SpatialProblem):
+    output_variables = ['u']
+    spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
+
+    conditions = {
+        'gamma1': Condition(
+            location=CartesianDomain({'x': [0, 1], 'y': 1}),
+            equation=FixedValue(0.0)),
+        'gamma2': Condition(
+            location=CartesianDomain({'x': [0, 1], 'y': 0}),
+            equation=FixedValue(0.0)),
+        'gamma3': Condition(
+            location=CartesianDomain({'x': 1, 'y': [0, 1]}),
+            equation=FixedValue(0.0)),
+        'gamma4': Condition(
+            location=CartesianDomain({'x': 0, 'y': [0, 1]}),
+            equation=FixedValue(0.0)),
+        'D': Condition(
+            input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
+            equation=my_laplace),
+        'data': Condition(
+            input_points=in_,
+            output_points=out_),
+        'data2': Condition(
+            input_points=in2_,
+            output_points=out2_)
+    }
+
+
+boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+poisson = Poisson()
+poisson.discretise_domain(10, 'grid', locations=boundaries)
+
+
+def test_sample():
+    sample_dataset = SamplePointDataset(poisson, device='cpu')
+    assert len(sample_dataset) == 140
+    assert sample_dataset.pts.shape == (140, 2)
+    assert sample_dataset.pts.labels == ['x', 'y']
+    assert sample_dataset.condition_indeces.dtype == torch.int64
+    assert sample_dataset.condition_indeces.max() == torch.tensor(4)
+    assert sample_dataset.condition_indeces.min() == torch.tensor(0)
+
+
+def test_data():
+    dataset = DataPointDataset(poisson, device='cpu')
+    assert len(dataset) == 61
+    assert dataset.input_pts.shape == (61, 2)
+    assert dataset.input_pts.labels == ['x', 'y']
+    assert dataset.output_pts.shape == (61, 1)
+    assert dataset.output_pts.labels == ['u']
+    assert dataset.condition_indeces.dtype == torch.int64
+    assert dataset.condition_indeces.max() == torch.tensor(1)
+    assert dataset.condition_indeces.min() == torch.tensor(0)
+
+
+def test_loader():
+    sample_dataset = SamplePointDataset(poisson, device='cpu')
+    data_dataset = DataPointDataset(poisson, device='cpu')
+    loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
+
+    for batch in loader:
+        assert len(batch) in [2, 3]
+        assert batch['pts'].shape[0] <= 10
+        assert batch['pts'].requires_grad == True
+        assert batch['pts'].labels == ['x', 'y']
+
+    loader2 = SamplePointLoader(sample_dataset, data_dataset, batch_size=None)
+    assert len(list(loader2)) == 2
+
+
+def test_loader2():
+    # remove both data conditions, so only physics conditions remain
+    del poisson.conditions['data']
+    del poisson.conditions['data2']
+    poisson.discretise_domain(10, 'grid', locations=boundaries)
+    sample_dataset = SamplePointDataset(poisson, device='cpu')
+    data_dataset = DataPointDataset(poisson, device='cpu')
+    loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
+
+    for batch in loader:
+        assert len(batch) == 2  # only physics conditions
+        assert batch['pts'].shape[0] <= 10
+        assert batch['pts'].requires_grad == True
+        assert batch['pts'].labels == ['x', 'y']
+
+
+def test_loader3():
+    # remove the remaining physics conditions as well: the loader over an
+    # empty problem should simply yield no batches
+    del poisson.conditions['gamma1']
+    del poisson.conditions['gamma2']
+    del poisson.conditions['gamma3']
+    del poisson.conditions['gamma4']
+    del poisson.conditions['D']
+    sample_dataset = SamplePointDataset(poisson, device='cpu')
+    data_dataset = DataPointDataset(poisson, device='cpu')
+    loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
+
+    assert len(list(loader)) == 0
diff --git a/tests/test_label_tensor.py b/tests/test_label_tensor.py
index 4e6a3024..1365aec0 100644
--- a/tests/test_label_tensor.py
+++ b/tests/test_label_tensor.py
@@ -95,10 +95,14 @@ def test_getitem():
 def test_getitem2():
     tensor = LabelTensor(data, labels)
     tensor_view = tensor[:5]
-
     assert tensor_view.labels == labels
     assert torch.allclose(tensor_view, data[:5])
 
+    idx = torch.randperm(tensor.shape[0])
+    tensor_view = tensor[idx]
+    assert tensor_view.labels == labels
+
+
 def test_slice():
     tensor = LabelTensor(data, labels)
     tensor_view = tensor[:5, :2]
diff --git a/tests/test_solvers/test_garom.py b/tests/test_solvers/test_garom.py
index 754dfd76..3087d7d5 100644
--- a/tests/test_solvers/test_garom.py
+++ b/tests/test_solvers/test_garom.py
@@ -134,7 +134,7 @@ def test_train_cpu():
             hidden_dimension=64)
     )
 
-    trainer = Trainer(solver=solver, max_epochs=4, accelerator='cpu')
+    trainer = Trainer(solver=solver, max_epochs=4, accelerator='cpu', batch_size=20)
     trainer.train()
 
 def test_sample():
diff --git a/tests/test_solvers/test_pinn.py b/tests/test_solvers/test_pinn.py
index 3b8d0a06..47570ac4 100644
--- a/tests/test_solvers/test_pinn.py
+++ b/tests/test_solvers/test_pinn.py
@@ -22,6 +22,8 @@ def laplace_equation(input_, output_):
 my_laplace = Equation(laplace_equation)
 in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
 out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
+in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
+out2_ = LabelTensor(torch.rand(60, 1), ['u'])
 
 class Poisson(SpatialProblem):
     output_variables = ['u']
@@ -45,7 +47,10 @@ class Poisson(SpatialProblem):
             equation=my_laplace),
         'data': Condition(
             input_points=in_,
-            output_points=out_)
+            output_points=out_),
+        'data2': Condition(
+            input_points=in2_,
+            output_points=out2_)
     }
 
     def poisson_sol(self, pts):
@@ -92,7 +97,7 @@ def test_train_cpu():
     n = 10
     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
     pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
-    trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
+    trainer = Trainer(solver=pinn, max_epochs=1, accelerator='cpu', batch_size=20)
     trainer.train()
 
 def test_train_restore():
@@ -106,7 +111,7 @@ def test_train_restore():
     trainer.train()
     ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu')
     t = ntrainer.train(
-        ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt')
+        ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt')
     import shutil
     shutil.rmtree(tmpdir)
 
@@ -121,7 +126,7 @@ def test_train_load():
                       default_root_dir=tmpdir)
     trainer.train()
     new_pinn = PINN.load_from_checkpoint(
-        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
+        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
         problem = poisson_problem, model=model)
     test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
     assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)