From 05050f4b0967d3c7c79a9b48498e51c5213d841d Mon Sep 17 00:00:00 2001
From: Dario Coscia
Date: Tue, 7 Nov 2023 12:13:29 +0100
Subject: [PATCH] Modifying solvers to log every epoch correctly

* add `on_epoch` flag to logger
* fix bug in `pinn.py`: `pts -> samples` in `_loss_phys`
* add `optimizer.zero_grad()` in garom generator training loop
* modify imports in `callbacks.py`
---
 pina/callbacks/processing_callbacks.py |  2 +-
 pina/label_tensor.py                   | 49 +++++---------------
 pina/solvers/garom.py                  |  9 ++---
 pina/solvers/pinn.py                   | 15 ++++----
 4 files changed, 23 insertions(+), 52 deletions(-)

diff --git a/pina/callbacks/processing_callbacks.py b/pina/callbacks/processing_callbacks.py
index c382c6c0..74ccc44e 100644
--- a/pina/callbacks/processing_callbacks.py
+++ b/pina/callbacks/processing_callbacks.py
@@ -1,6 +1,6 @@
 '''PINA Callbacks Implementations'''
 
-from lightning.pytorch.callbacks import Callback
+from pytorch_lightning.callbacks import Callback
 import torch
 import copy
 
diff --git a/pina/label_tensor.py b/pina/label_tensor.py
index df6a6792..d612b47b 100644
--- a/pina/label_tensor.py
+++ b/pina/label_tensor.py
@@ -33,14 +33,6 @@ def __init__(self, x, labels):
                 [1.0246e-01, 9.5179e-01, 3.7043e-02],
                 [9.6150e-01, 8.0656e-01, 8.3824e-01]])
         >>> tensor.extract('a')
-        tensor([[0.0671],
-                [0.9239],
-                [0.8927],
-                ...,
-                [0.5819],
-                [0.1025],
-                [0.9615]])
-        >>> tensor['a']
         tensor([[0.0671],
                 [0.9239],
                 [0.8927],
                 ...,
                 [0.5819],
                 [0.1025],
                 [0.9615]])
@@ -77,7 +69,7 @@ def __init__(self, x, labels):
                 'the passed labels.'
             )
         self._labels = labels
-
+    
     @property
     def labels(self):
         """Property decorator for labels
@@ -118,8 +110,6 @@ def vstack(label_tensors):
         tensors = [lt.extract(labels) for lt in label_tensors]
         return LabelTensor(torch.vstack(tensors), labels)
 
-    # TODO remove try/ except thing IMPORTANT
-    # make the label None of default
    def clone(self, *args, **kwargs):
         """
         Clone the LabelTensor. For more details, see
@@ -128,11 +118,12 @@ def clone(self, *args, **kwargs):
         :return: a copy of the tensor
         :rtype: LabelTensor
         """
-        try:
-            out = LabelTensor(super().clone(*args, **kwargs), self.labels)
-        except:  # this is used when the tensor loose the labels, notice it will create a bug! Kept for compatibility with Lightining
-            out = super().clone(*args, **kwargs)
-
+        # used before merging
+        # try:
+        #     out = LabelTensor(super().clone(*args, **kwargs), self.labels)
+        # except:
+        #     out = super().clone(*args, **kwargs)
+        out = LabelTensor(super().clone(*args, **kwargs), self.labels)
         return out
 
     def to(self, *args, **kwargs):
@@ -153,24 +144,6 @@ def select(self, *args, **kwargs):
         tmp._labels = self._labels
         return tmp
 
-    def cuda(self, *args, **kwargs):
-        """
-        Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`.
-        """
-        tmp = super().cuda(*args, **kwargs)
-        new = self.__class__.clone(self)
-        new.data = tmp.data
-        return tmp
-
-    def cpu(self, *args, **kwargs):
-        """
-        Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`.
-        """
-        tmp = super().cpu(*args, **kwargs)
-        new = self.__class__.clone(self)
-        new.data = tmp.data
-        return tmp
-
     def extract(self, label_to_extract):
         """
         Extract the subset of the original tensor by returning all the columns
@@ -197,7 +170,7 @@ def extract(self, label_to_extract):
             except ValueError:
                 raise ValueError(f'`{f}` not in the labels list')
 
-        new_data = super(Tensor, self.T).__getitem__(indeces).T
+        new_data = super(Tensor, self.T).__getitem__(indeces).float().T
         new_labels = [self.labels[idx] for idx in indeces]
 
         extracted_tensor = new_data.as_subclass(LabelTensor)
@@ -256,12 +229,8 @@ def __getitem__(self, index):
         """
         Return a copy of the selected tensor.
         """
-
-        if isinstance(index, str) or (isinstance(index, (tuple, list))and all(isinstance(a, str) for a in index)):
-            return self.extract(index)
-
         selected_lt = super(Tensor, self).__getitem__(index)
-
+    
         try:
             len_index = len(index)
         except TypeError:
diff --git a/pina/solvers/garom.py b/pina/solvers/garom.py
index f09e700a..ddc16a37 100644
--- a/pina/solvers/garom.py
+++ b/pina/solvers/garom.py
@@ -166,6 +166,7 @@ def _train_generator(self, parameters, snapshots):
         Private method to train the generator network.
         """
         optimizer = self.optimizer_generator
+        optimizer.zero_grad()
 
         generated_snapshots = self.generator(parameters)
 
@@ -258,10 +259,10 @@ def training_step(self, batch, batch_idx):
         diff = self._update_weights(d_loss_real, d_loss_fake)
 
         # logging
-        self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
-        self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
-        self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
-        self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
+        self.log('mean_loss', float(r_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+        self.log('d_loss', float(d_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+        self.log('g_loss', float(g_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+        self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True, on_epoch=True, on_step=False)
 
         return
 
diff --git a/pina/solvers/pinn.py b/pina/solvers/pinn.py
index 55d5f25d..e986bb63 100644
--- a/pina/solvers/pinn.py
+++ b/pina/solvers/pinn.py
@@ -130,27 +130,28 @@ def training_step(self, batch, batch_idx):
 
             if len(batch) == 2:
                 samples = pts[condition_idx == condition_id]
-                loss = self._loss_phys(pts, condition.equation)
+                loss = self._loss_phys(samples, condition.equation)
             elif len(batch) == 3:
                 samples = pts[condition_idx == condition_id]
                 ground_truth = batch['output'][condition_idx == condition_id]
                 loss = self._loss_data(samples, ground_truth)
             else:
                 raise ValueError("Batch size not supported")
 
+            # TODO: this is hard for users to remember when creating a new solver; fix in a smarter way
             loss = loss.as_subclass(torch.Tensor)
-            loss = loss
 
+            # add condition loss and accumulate logging for each epoch
             condition_losses.append(loss * condition.data_weight)
+            self.log(condition_name + '_loss', float(loss),
+                     prog_bar=True, logger=True, on_epoch=True, on_step=False)
 
-        # TODO Fix the bug, tot_loss is a label tensor without labels
-        # we need to pass it as a torch tensor to make everything work
+        # add to total loss and accumulate logging for each epoch
         total_loss = sum(condition_losses)
+        self.log('mean_loss', float(total_loss / len(condition_losses)),
+                 prog_bar=True, logger=True, on_epoch=True, on_step=False)
 
-        self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
-        # for condition_loss, loss in zip(condition_names, condition_losses):
-        #     self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
 
         return total_loss
 
     @property
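
Note (aside, not part of the patch): the behavioural change in the solvers is the
`on_epoch=True, on_step=False` pair passed to `self.log`. Below is a minimal,
self-contained sketch of what that pair does; `ToySolver`, the random data, and
all hyperparameters are made up for illustration and are not PINA code.

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset

    class ToySolver(pl.LightningModule):
        """Hypothetical stand-in for a PINA solver; only the logging matters."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(1, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            # on_epoch=True, on_step=False: Lightning accumulates the metric
            # over all batches and logs one epoch-level average, instead of
            # one (noisy) value per optimization step.
            self.log('mean_loss', float(loss), prog_bar=True, logger=True,
                     on_epoch=True, on_step=False)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    x = torch.rand(64, 1)
    loader = DataLoader(TensorDataset(x, 2 * x), batch_size=8)
    trainer = pl.Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
    trainer.fit(ToySolver(), loader)
    print(trainer.callback_metrics['mean_loss'])  # one aggregated value per epoch

This flag pair is what the patch adds to every `self.log` call in `garom.py` and
`pinn.py`, so `mean_loss`, `d_loss`, `g_loss`, `stability_metric`, and the
per-condition losses are all logged as epoch-level averages.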