From c3a80e29be66cba68645608d95c052dca9d2cf79 Mon Sep 17 00:00:00 2001
From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
Date: Wed, 8 Nov 2023 14:10:23 +0100
Subject: [PATCH] Solvers logging (#202)

* Modifying solvers to log every epoch correctly

* add `on_epoch` flag to logger

* fix bug in `pinn.py` `pts -> samples` in `_loss_phys`

* add `optimizer_zero_grad()` in garom generator training loop

* modify imports in `callbacks.py`

* fixing tests

---------

Co-authored-by: Dario Coscia
---
 pina/callbacks/adaptive_refinment_callbacks.py | 14 +++++++-------
 pina/callbacks/processing_callbacks.py         |  2 +-
 pina/label_tensor.py                           | 15 +++++++--------
 pina/solvers/garom.py                          |  9 +++++----
 pina/solvers/pinn.py                           | 15 ++++++++-------
 5 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/pina/callbacks/adaptive_refinment_callbacks.py b/pina/callbacks/adaptive_refinment_callbacks.py
index 664c8be6..1be904d5 100644
--- a/pina/callbacks/adaptive_refinment_callbacks.py
+++ b/pina/callbacks/adaptive_refinment_callbacks.py
@@ -61,7 +61,7 @@ def _compute_residual(self, trainer):
             pts.retain_grad()
             # PINN loss: equation evaluated only on locations where sampling is needed
             target = condition.equation.residual(pts, solver.forward(pts))
-            res_loss[location] = torch.abs(target)
+            res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
             tot_loss.append(torch.abs(target))
 
         return torch.vstack(tot_loss), res_loss
@@ -74,6 +74,7 @@ def _r3_routine(self, trainer):
         """
         # compute residual (all device possible)
         tot_loss, res_loss = self._compute_residual(trainer)
+        tot_loss = tot_loss.as_subclass(torch.Tensor)
 
         # !!!!!! From now everything is performed on CPU !!!!!!
 
@@ -89,12 +90,11 @@ def _r3_routine(self, trainer):
             pts = pts.cpu().detach()
             residuals = res_loss[location].cpu()
             mask = (residuals > avg).flatten()
-            # TODO masking remove labels
-            pts = pts[mask]
-            pts.labels = labels
-            ####
-            old_pts[location] = pts
-            tot_points += len(pts)
+            if any(mask):  # if there are residuals greater than average we append them
+                pts = pts[mask]  # TODO masking remove labels
+                pts.labels = labels
+                old_pts[location] = pts
+                tot_points += len(pts)
         # extract new points to sample uniformally for each location
         n_points = (self._tot_pop_numb - tot_points ) // len(self._sampling_locations)
diff --git a/pina/callbacks/processing_callbacks.py b/pina/callbacks/processing_callbacks.py
index c382c6c0..74ccc44e 100644
--- a/pina/callbacks/processing_callbacks.py
+++ b/pina/callbacks/processing_callbacks.py
@@ -1,6 +1,6 @@
 '''PINA Callbacks Implementations'''
 
-from lightning.pytorch.callbacks import Callback
+from pytorch_lightning.callbacks import Callback
 import torch
 import copy
 
diff --git a/pina/label_tensor.py b/pina/label_tensor.py
index df6a6792..c11f7a18 100644
--- a/pina/label_tensor.py
+++ b/pina/label_tensor.py
@@ -118,8 +118,6 @@ def vstack(label_tensors):
         tensors = [lt.extract(labels) for lt in label_tensors]
         return LabelTensor(torch.vstack(tensors), labels)
 
-    # TODO remove try/ except thing IMPORTANT
-    # make the label None of default
     def clone(self, *args, **kwargs):
         """
         Clone the LabelTensor. For more details, see
@@ -128,11 +126,12 @@ def clone(self, *args, **kwargs):
         :return: a copy of the tensor
         :rtype: LabelTensor
         """
-        try:
-            out = LabelTensor(super().clone(*args, **kwargs), self.labels)
-        except:  # this is used when the tensor loose the labels, notice it will create a bug! Kept for compatibility with Lightining
-            out = super().clone(*args, **kwargs)
-
+        # # used before merging
+        # try:
+        #     out = LabelTensor(super().clone(*args, **kwargs), self.labels)
+        # except:
+        #     out = super().clone(*args, **kwargs)
+        out = LabelTensor(super().clone(*args, **kwargs), self.labels)
         return out
 
     def to(self, *args, **kwargs):
@@ -298,4 +297,4 @@ def __str__(self):
         else:
             s = 'no labels\n'
         s += super().__str__()
-        return s
+        return s
\ No newline at end of file
diff --git a/pina/solvers/garom.py b/pina/solvers/garom.py
index f09e700a..ddc16a37 100644
--- a/pina/solvers/garom.py
+++ b/pina/solvers/garom.py
@@ -166,6 +166,7 @@ def _train_generator(self, parameters, snapshots):
         Private method to train the generator network.
         """
         optimizer = self.optimizer_generator
+        optimizer.zero_grad()
 
         generated_snapshots = self.generator(parameters)
 
@@ -258,10 +259,10 @@ def training_step(self, batch, batch_idx):
         diff = self._update_weights(d_loss_real, d_loss_fake)
 
         # logging
-        self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
-        self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
-        self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
-        self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
+        self.log('mean_loss', float(r_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+        self.log('d_loss', float(d_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+        self.log('g_loss', float(g_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+        self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True, on_epoch=True, on_step=False)
 
         return
 
diff --git a/pina/solvers/pinn.py b/pina/solvers/pinn.py
index 55d5f25d..e986bb63 100644
--- a/pina/solvers/pinn.py
+++ b/pina/solvers/pinn.py
@@ -130,7 +130,7 @@ def training_step(self, batch, batch_idx):
 
             if len(batch) == 2:
                 samples = pts[condition_idx == condition_id]
-                loss = self._loss_phys(pts, condition.equation)
+                loss = self._loss_phys(samples, condition.equation)
             elif len(batch) == 3:
                 samples = pts[condition_idx == condition_id]
                 ground_truth = batch['output'][condition_idx == condition_id]
@@ -138,18 +138,19 @@ def training_step(self, batch, batch_idx):
             else:
                 raise ValueError("Batch size not supported")
 
+            # TODO: for users this is hard to remember when creating a new solver; fix in a smarter way
             loss = loss.as_subclass(torch.Tensor)
-            loss = loss
 
+            # add condition loss and accumulate logging for each epoch
            condition_losses.append(loss * condition.data_weight)
+            self.log(condition_name + '_loss', float(loss),
+                     prog_bar=True, logger=True, on_epoch=True, on_step=False)
 
-        # TODO Fix the bug, tot_loss is a label tensor without labels
-        # we need to pass it as a torch tensor to make everything work
+        # add to total loss and accumulate logging for each epoch
         total_loss = sum(condition_losses)
+        self.log('mean_loss', float(total_loss / len(condition_losses)),
+                 prog_bar=True, logger=True, on_epoch=True, on_step=False)
 
-        self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
-        # for condition_loss, loss in zip(condition_names, condition_losses):
-        #     self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
 
         return total_loss
 
     @property
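
For context (not part of the patch): a minimal, illustrative sketch of the per-epoch logging pattern the patch adopts, assuming a standard `pytorch_lightning.LightningModule`. The toy module, data, and optimizer below are placeholders and not PINA code; only the `self.log(..., on_epoch=True, on_step=False)` call mirrors the patch.

# Illustrative sketch only (assumed toy module, not PINA code): with
# on_epoch=True and on_step=False, Lightning accumulates the logged value
# across all training steps and reports a single epoch-level average.
import torch
import pytorch_lightning as pl


class ToyRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # same flags used throughout the patch: shown in the progress bar,
        # sent to the logger, averaged over the epoch rather than per step
        self.log('mean_loss', float(loss), prog_bar=True, logger=True,
                 on_epoch=True, on_step=False)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)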