Solvers logging (#202)
* Modifying solvers to log every epoch correctly
* add `on_epoch` flag to logger
* fix bug in `pinn.py` `pts -> samples` in `_loss_phys`
* add `optimizer_zero_grad()` in garom generator training loop
* modify imports in `callbacks.py`
* fixing tests

---------

Co-authored-by: Dario Coscia <[email protected]>
dario-coscia and Dario Coscia authored Nov 8, 2023
1 parent 9353834 commit c3a80e2
Showing 5 changed files with 28 additions and 27 deletions.
14 changes: 7 additions & 7 deletions pina/callbacks/adaptive_refinment_callbacks.py
@@ -61,7 +61,7 @@ def _compute_residual(self, trainer):
pts.retain_grad()
# PINN loss: equation evaluated only on locations where sampling is needed
target = condition.equation.residual(pts, solver.forward(pts))
- res_loss[location] = torch.abs(target)
+ res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
tot_loss.append(torch.abs(target))

return torch.vstack(tot_loss), res_loss
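For context, `Tensor.as_subclass` is the mechanism these lines rely on to hand the logger a plain tensor. A minimal sketch of what it does, using a stand-in subclass called `TaggedTensor` rather than PINA's actual `LabelTensor`:

```python
import torch

class TaggedTensor(torch.Tensor):
    """Stand-in for a tensor subclass that carries extra metadata (like LabelTensor)."""
    pass

t = torch.rand(3, 2).as_subclass(TaggedTensor)
plain = t.as_subclass(torch.Tensor)  # view the same data as a plain torch.Tensor
print(type(t).__name__, type(plain).__name__)  # TaggedTensor Tensor
```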
@@ -74,6 +74,7 @@ def _r3_routine(self, trainer):
"""
# compute residual (all device possible)
tot_loss, res_loss = self._compute_residual(trainer)
+ tot_loss = tot_loss.as_subclass(torch.Tensor)

# !!!!!! From now everything is performed on CPU !!!!!!

@@ -89,12 +90,11 @@
pts = pts.cpu().detach()
residuals = res_loss[location].cpu()
mask = (residuals > avg).flatten()
- # TODO masking remove labels
- pts = pts[mask]
- pts.labels = labels
- ####
- old_pts[location] = pts
- tot_points += len(pts)
+ if any(mask):  # if there are residuals greater than the average we append them
+     pts = pts[mask]  # TODO masking remove labels
+     pts.labels = labels
+     old_pts[location] = pts
+     tot_points += len(pts)

# extract new points to sample uniformly for each location
n_points = (self._tot_pop_numb - tot_points) // len(self._sampling_locations)
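A rough sketch of the selection logic in this routine: keep the collocation points whose residual exceeds the average, then refill the population with fresh samples. The helper name `r3_select` and the uniform sampling in the unit square are illustrative assumptions, not the callback's actual API:

```python
import torch

def r3_select(pts, residuals, total_population):
    # keep collocation points whose residual is above the mean
    avg = residuals.mean()
    mask = (residuals > avg).flatten()
    kept = pts[mask] if mask.any() else pts[:0]
    # top the population back up with fresh uniform samples in [0, 1)^d
    new_pts = torch.rand(total_population - len(kept), pts.shape[1])
    return torch.vstack([kept, new_pts])

pts, residuals = torch.rand(10, 2), torch.rand(10)
print(r3_select(pts, residuals, total_population=10).shape)  # torch.Size([10, 2])
```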
2 changes: 1 addition & 1 deletion pina/callbacks/processing_callbacks.py
@@ -1,6 +1,6 @@
'''PINA Callbacks Implementations'''

- from lightning.pytorch.callbacks import Callback
+ from pytorch_lightning.callbacks import Callback
import torch
import copy

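If both packaging styles ever need to be supported, a guarded import is a common pattern. This is a hedged sketch of that alternative, not what the commit does (the commit simply switches to the `pytorch_lightning` namespace):

```python
# Guarded import: use the unified `lightning` package if it is installed,
# otherwise fall back to the standalone `pytorch_lightning` distribution.
try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    from pytorch_lightning.callbacks import Callback
```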
15 changes: 7 additions & 8 deletions pina/label_tensor.py
@@ -118,8 +118,6 @@ def vstack(label_tensors):
tensors = [lt.extract(labels) for lt in label_tensors]
return LabelTensor(torch.vstack(tensors), labels)

- # TODO remove try/except thing IMPORTANT
- # make the label None by default
def clone(self, *args, **kwargs):
"""
Clone the LabelTensor. For more details, see
@@ -128,11 +126,12 @@
:return: a copy of the tensor
:rtype: LabelTensor
"""
- try:
-     out = LabelTensor(super().clone(*args, **kwargs), self.labels)
- except:  # this is used when the tensor loses the labels, notice it will create a bug! Kept for compatibility with Lightning
-     out = super().clone(*args, **kwargs)
+ # # used before merging
+ # try:
+ #     out = LabelTensor(super().clone(*args, **kwargs), self.labels)
+ # except:
+ #     out = super().clone(*args, **kwargs)
+ out = LabelTensor(super().clone(*args, **kwargs), self.labels)
return out

def to(self, *args, **kwargs):
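As background on why `clone` re-wraps the result explicitly: cloning a Tensor subclass preserves the type, but not Python-level attributes such as `labels`. A minimal sketch with a stand-in class `MiniLabelTensor`, not PINA's real implementation:

```python
import torch

class MiniLabelTensor(torch.Tensor):
    """Stand-in subclass carrying a `labels` attribute, like PINA's LabelTensor."""
    def clone(self, *args, **kwargs):
        out = super().clone(*args, **kwargs).as_subclass(MiniLabelTensor)
        out.labels = self.labels  # re-attach the metadata the base clone drops
        return out

x = torch.rand(4, 2).as_subclass(MiniLabelTensor)
x.labels = ['u', 'v']
print(x.clone().labels)  # ['u', 'v']
```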
@@ -298,4 +297,4 @@ def __str__(self):
else:
s = 'no labels\n'
s += super().__str__()
- return s
\ No newline at end of file
+ return s
9 changes: 5 additions & 4 deletions pina/solvers/garom.py
@@ -166,6 +166,7 @@ def _train_generator(self, parameters, snapshots):
Private method to train the generator network.
"""
optimizer = self.optimizer_generator
+ optimizer.zero_grad()

generated_snapshots = self.generator(parameters)

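The added `optimizer.zero_grad()` matters because the generator is trained with a manually driven optimizer. A hedged, plain-PyTorch sketch of the pattern, using a stand-in `Linear` generator and an MSE placeholder loss rather than PINA's actual GAROM losses:

```python
import torch

generator = torch.nn.Linear(4, 8)  # stand-in generator network
optimizer = torch.optim.Adam(generator.parameters(), lr=1e-3)

def train_generator_step(parameters, snapshots):
    optimizer.zero_grad()  # clear gradients left over from the previous step
    generated = generator(parameters)
    loss = torch.nn.functional.mse_loss(generated, snapshots)
    loss.backward()
    optimizer.step()
    return loss.item()

print(train_generator_step(torch.rand(16, 4), torch.rand(16, 8)))
```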
@@ -258,10 +259,10 @@ def training_step(self, batch, batch_idx):
diff = self._update_weights(d_loss_real, d_loss_fake)

# logging
- self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
- self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
- self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
- self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
+ self.log('mean_loss', float(r_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+ self.log('d_loss', float(d_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+ self.log('g_loss', float(g_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+ self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True, on_epoch=True, on_step=False)

return

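What the added flags change: with `on_epoch=True, on_step=False`, Lightning accumulates the metric across the epoch and logs one averaged value per epoch instead of one value per training step. A minimal stand-alone sketch with a made-up `TinySolver` module (not the GAROM solver itself), assuming the `pytorch_lightning` namespace used elsewhere in this commit:

```python
import pytorch_lightning as pl
import torch

class TinySolver(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).pow(2).mean()
        # averaged over the epoch and reported once per epoch in the bar / logger
        self.log('mean_loss', float(loss), prog_bar=True, logger=True,
                 on_epoch=True, on_step=False)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```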
15 changes: 8 additions & 7 deletions pina/solvers/pinn.py
@@ -130,26 +130,27 @@ def training_step(self, batch, batch_idx):

if len(batch) == 2:
samples = pts[condition_idx == condition_id]
- loss = self._loss_phys(pts, condition.equation)
+ loss = self._loss_phys(samples, condition.equation)
elif len(batch) == 3:
samples = pts[condition_idx == condition_id]
ground_truth = batch['output'][condition_idx == condition_id]
loss = self._loss_data(samples, ground_truth)
else:
raise ValueError("Batch size not supported")

# TODO for users this is hard to remember when creating a new solver, to fix in a smarter way
loss = loss.as_subclass(torch.Tensor)
- loss = loss

+ # add condition losses and accumulate logging for each epoch
condition_losses.append(loss * condition.data_weight)
+ self.log(condition_name + '_loss', float(loss),
+          prog_bar=True, logger=True, on_epoch=True, on_step=False)

- # TODO Fix the bug, tot_loss is a label tensor without labels
- # we need to pass it as a torch tensor to make everything work
+ # add to tot loss and accumulate logging for each epoch
total_loss = sum(condition_losses)
+ self.log('mean_loss', float(total_loss / len(condition_losses)),
+          prog_bar=True, logger=True, on_epoch=True, on_step=False)

- self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
- # for condition_loss, loss in zip(condition_names, condition_losses):
- #     self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
return total_loss

@property
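For the `pts -> samples` fix above: the physics loss has to be evaluated only on the points that belong to the current condition, not on the whole batch. A toy illustration of the masking, with made-up values:

```python
import torch

pts = torch.rand(6, 2)                            # all collocation points in the batch
condition_idx = torch.tensor([0, 0, 1, 1, 1, 0])  # which condition each point belongs to
condition_id = 1

samples = pts[condition_idx == condition_id]      # only this condition's points
print(samples.shape)                              # torch.Size([3, 2])
```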
