Modifying solvers to log every epoch correctly
* add `on_epoch` flag to logger
* fix bug in `pinn.py` `pts -> samples` in `_loss_phys`
* add `optimizer_zero_grad()` in garom generator training loop
* modify imports in `callbacks.py`
Dario Coscia committed Nov 7, 2023
1 parent 9353834 commit 05050f4
Showing 4 changed files with 23 additions and 52 deletions.
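The core of the commit is switching the solvers' metrics from per-step to per-epoch logging. In Lightning, self.log called inside training_step logs once per step by default; passing on_epoch=True together with on_step=False makes Lightning accumulate the value across the epoch and log a single mean-reduced value per epoch. A minimal, self-contained sketch of that pattern follows; the module, data, and metric name are illustrative stand-ins, not PINA code.

    import torch
    import pytorch_lightning as pl  # or: import lightning.pytorch as pl, depending on the installed package


    class ToyModule(pl.LightningModule):
        """Toy module illustrating epoch-level logging, not a PINA solver."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(1, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            # on_epoch=True, on_step=False: accumulate across the epoch and log
            # one mean-reduced value per epoch instead of one value per step
            self.log('mean_loss', float(loss), prog_bar=True, logger=True,
                     on_epoch=True, on_step=False)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)
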
2 changes: 1 addition & 1 deletion pina/callbacks/processing_callbacks.py
@@ -1,6 +1,6 @@
'''PINA Callbacks Implementations'''

from lightning.pytorch.callbacks import Callback
from pytorch_lightning.callbacks import Callback
import torch
import copy

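The only change in this file swaps the Callback import path. Both pytorch_lightning.callbacks.Callback and lightning.pytorch.callbacks.Callback expose the same base class; which import resolves depends on whether the standalone pytorch_lightning package or the unified lightning package is installed. A minimal callback sketch under that assumption (illustrative only, not the PINA callback):

    from pytorch_lightning.callbacks import Callback  # or: from lightning.pytorch.callbacks import Callback


    class PrintEpochCallback(Callback):
        """Toy callback: report the epoch index at the end of every training epoch."""

        def on_train_epoch_end(self, trainer, pl_module):
            print(f'finished epoch {trainer.current_epoch}')
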
49 changes: 9 additions & 40 deletions pina/label_tensor.py
@@ -33,14 +33,6 @@ def __init__(self, x, labels):
[1.0246e-01, 9.5179e-01, 3.7043e-02],
[9.6150e-01, 8.0656e-01, 8.3824e-01]])
>>> tensor.extract('a')
tensor([[0.0671],
[0.9239],
[0.8927],
...,
[0.5819],
[0.1025],
[0.9615]])
>>> tensor['a']
tensor([[0.0671],
[0.9239],
[0.8927],
@@ -77,7 +69,7 @@ def __init__(self, x, labels):
'the passed labels.'
)
self._labels = labels

@property
def labels(self):
"""Property decorator for labels
@@ -118,8 +110,6 @@ def vstack(label_tensors):
tensors = [lt.extract(labels) for lt in label_tensors]
return LabelTensor(torch.vstack(tensors), labels)

# TODO remove try/except thing IMPORTANT
# make the label None by default
def clone(self, *args, **kwargs):
"""
Clone the LabelTensor. For more details, see
@@ -128,11 +118,12 @@ def clone(self, *args, **kwargs):
:return: a copy of the tensor
:rtype: LabelTensor
"""
try:
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
except:  # this is used when the tensor loses the labels, note it will create a bug! Kept for compatibility with Lightning
out = super().clone(*args, **kwargs)

# used before merging
# try:
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
# except:
# out = super().clone(*args, **kwargs)
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
return out

def to(self, *args, **kwargs):
@@ -153,24 +144,6 @@ def select(self, *args, **kwargs):
tmp._labels = self._labels
return tmp

def cuda(self, *args, **kwargs):
"""
Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`.
"""
tmp = super().cuda(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return tmp

def cpu(self, *args, **kwargs):
"""
Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`.
"""
tmp = super().cpu(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return tmp

def extract(self, label_to_extract):
"""
Extract the subset of the original tensor by returning all the columns
@@ -197,7 +170,7 @@ def extract(self, label_to_extract):
except ValueError:
raise ValueError(f'`{f}` not in the labels list')

new_data = super(Tensor, self.T).__getitem__(indeces).T
new_data = super(Tensor, self.T).__getitem__(indeces).float().T
new_labels = [self.labels[idx] for idx in indeces]

extracted_tensor = new_data.as_subclass(LabelTensor)
@@ -256,12 +229,8 @@ def __getitem__(self, index):
"""
Return a copy of the selected tensor.
"""

if isinstance(index, str) or (isinstance(index, (tuple, list)) and all(isinstance(a, str) for a in index)):
return self.extract(index)

selected_lt = super(Tensor, self).__getitem__(index)

try:
len_index = len(index)
except TypeError:
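The clone() change above drops the silent try/except fallback: cloning now always rebuilds a LabelTensor and re-attaches its labels. A self-contained sketch of that pattern for a generic torch.Tensor subclass; TaggedTensor and its tag helper are hypothetical stand-ins, not the actual LabelTensor implementation:

    import torch


    class TaggedTensor(torch.Tensor):
        """Toy Tensor subclass carrying column labels, standing in for LabelTensor."""

        @staticmethod
        def tag(x, labels):
            out = x.as_subclass(TaggedTensor)  # view the data as the subclass
            out.labels = labels
            return out

        def clone(self, *args, **kwargs):
            # always rebuild the subclass and re-attach the labels,
            # instead of silently falling back to a bare torch.Tensor
            return TaggedTensor.tag(super().clone(*args, **kwargs), self.labels)


    t = TaggedTensor.tag(torch.rand(3, 2), ['a', 'b'])
    print(t.clone().labels)  # ['a', 'b']
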
9 changes: 5 additions & 4 deletions pina/solvers/garom.py
@@ -166,6 +166,7 @@ def _train_generator(self, parameters, snapshots):
Private method to train the generator network.
"""
optimizer = self.optimizer_generator
optimizer.zero_grad()

generated_snapshots = self.generator(parameters)

@@ -258,10 +259,10 @@ def training_step(self, batch, batch_idx):
diff = self._update_weights(d_loss_real, d_loss_fake)

# logging
self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
self.log('mean_loss', float(r_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
self.log('d_loss', float(d_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
self.log('g_loss', float(g_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True, on_epoch=True, on_step=False)

return

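Two things change in garom.py: the generator step now clears gradients before backpropagating, and training_step logs per epoch. The zero_grad call matters because the generator and discriminator are updated in alternation, and without clearing, gradients from a previous update would accumulate into the next one. A generic sketch of the pattern, with a toy network standing in for GAROM's generator (names, shapes, and the loss are illustrative only):

    import torch

    generator = torch.nn.Linear(4, 8)  # stand-in for the generator network
    opt_gen = torch.optim.Adam(generator.parameters(), lr=1e-3)


    def train_generator(parameters, snapshots):
        # clear stale gradients first, otherwise they accumulate across
        # the alternating generator/discriminator updates
        opt_gen.zero_grad()
        generated_snapshots = generator(parameters)
        loss = torch.nn.functional.mse_loss(generated_snapshots, snapshots)
        loss.backward()
        opt_gen.step()
        return float(loss)


    print(train_generator(torch.rand(16, 4), torch.rand(16, 8)))
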
15 changes: 8 additions & 7 deletions pina/solvers/pinn.py
@@ -130,26 +130,27 @@ def training_step(self, batch, batch_idx):

if len(batch) == 2:
samples = pts[condition_idx == condition_id]
loss = self._loss_phys(pts, condition.equation)
loss = self._loss_phys(samples, condition.equation)
elif len(batch) == 3:
samples = pts[condition_idx == condition_id]
ground_truth = batch['output'][condition_idx == condition_id]
loss = self._loss_data(samples, ground_truth)
else:
raise ValueError("Batch size not supported")

# TODO for users this is hard to remember when creating a new solver, to fix in a smarter way
loss = loss.as_subclass(torch.Tensor)
loss = loss

# add condition losses and accumulate logging for each epoch
condition_losses.append(loss * condition.data_weight)
self.log(condition_name + '_loss', float(loss),
prog_bar=True, logger=True, on_epoch=True, on_step=False)

# TODO Fix the bug, tot_loss is a label tensor without labels
# we need to pass it as a torch tensor to make everything work
# add to tot loss and accumulate logging for each epoch
total_loss = sum(condition_losses)
self.log('mean_loss', float(total_loss / len(condition_losses)),
prog_bar=True, logger=True, on_epoch=True, on_step=False)

self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
# for condition_loss, loss in zip(condition_names, condition_losses):
# self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
return total_loss

@property
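The pts -> samples fix in _loss_phys means the physics residual for a condition is evaluated only on the points masked for that condition, not on the whole batch. A small sketch of the masking step (the tensors and the condition id are made-up values for illustration):

    import torch

    pts = torch.rand(6, 2)                            # all collocation points in the batch
    condition_idx = torch.tensor([0, 0, 1, 1, 1, 0])  # which condition each point belongs to
    condition_id = 1

    samples = pts[condition_idx == condition_id]      # only this condition's points
    print(samples.shape)  # torch.Size([3, 2])
    # the physics loss for this condition must be computed on `samples`, not `pts`
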
