Commit
Add validation_step and improve data management
FilippoOlivo committed Nov 11, 2024
1 parent 06ac8ad commit 6fabbff
Showing 4 changed files with 152 additions and 28 deletions.
52 changes: 38 additions & 14 deletions pina/data/data_module.py
@@ -6,6 +6,9 @@
 import torch
 import logging
 from pytorch_lightning import LightningDataModule
+from pytorch_lightning.utilities.types import EVAL_DATALOADERS, \
+    TRAIN_DATALOADERS
+
 from .sample_dataset import SamplePointDataset
 from .supervised_dataset import SupervisedDataset
 from .unsupervised_dataset import UnsupervisedDataset
@@ -61,30 +64,31 @@ def __init__(self,
         if train_size > 0:
             self.split_names.append('train')
             self.split_length.append(train_size)
-            self.loader_functions['train_dataloader'] = lambda \
-                x: PinaDataLoader(self.splits['train'], self.batch_size,
-                                  self.condition_names)
+        else:
+            self.train_dataloader = super().train_dataloader

         if test_size > 0:
             self.split_length.append(test_size)
             self.split_names.append('test')
-            self.loader_functions['test_dataloader'] = lambda x: PinaDataLoader(
-                self.splits['test'], self.batch_size, self.condition_names)
+        else:
+            self.test_dataloader = super().test_dataloader

         if val_size > 0:
             self.split_length.append(val_size)
             self.split_names.append('val')
-            self.loader_functions['val_dataloader'] = lambda x: PinaDataLoader(
-                self.splits['val'], self.batch_size, self.condition_names)
+        else:
+            self.val_dataloader = super().val_dataloader

         if predict_size > 0:
             self.split_length.append(predict_size)
             self.split_names.append('predict')
-            self.loader_functions[
-                'predict_dataloader'] = lambda x: PinaDataLoader(
-                    self.splits['predict'], self.batch_size, self.condition_names)
+        else:
+            self.predict_dataloader = super().predict_dataloader

         self.splits = {k: {} for k in self.split_names}
         self.shuffle = shuffle

-        for k, v in self.loader_functions.items():
-            setattr(self, k, v.__get__(self, PinaDataModule))
+        self.has_setup_fit = False
+        self.has_setup_test = False

     def prepare_data(self):
         if self.datasets is None:
@@ -106,8 +110,12 @@ def setup(self, stage=None):
                 for i in range(len(self.split_length)):
                     self.splits[self.split_names[i]][
                         dataset.data_type] = splits[i]
+            self.has_setup_fit = True
         elif stage == 'test':
-            raise NotImplementedError("Testing pipeline not implemented yet")
+            if self.has_setup_fit is False:
+                raise NotImplementedError(
+                    "You must call setup with stage='fit' first")
         else:
             raise ValueError("stage must be either 'fit' or 'test'")

@@ -178,3 +186,19 @@ def _create_datasets(self):
             dataset.initialize()
             datasets.append(dataset)
         self.datasets = datasets
+
+    def val_dataloader(self) -> EVAL_DATALOADERS:
+        return PinaDataLoader(self.splits['val'], self.batch_size,
+                              self.condition_names)
+
+    def train_dataloader(self) -> TRAIN_DATALOADERS:
+        return PinaDataLoader(self.splits['train'], self.batch_size,
+                              self.condition_names)
+
+    def test_dataloader(self) -> EVAL_DATALOADERS:
+        return PinaDataLoader(self.splits['test'], self.batch_size,
+                              self.condition_names)
+
+    def predict_dataloader(self) -> EVAL_DATALOADERS:
+        return PinaDataLoader(self.splits['predict'], self.batch_size,
+                              self.condition_names)
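
For orientation, here is a minimal usage sketch of the reworked data module. It is hypothetical: the first constructor argument and the problem object are assumptions inferred from Trainer._create_loader, which is not shown in this hunk. A split is registered only when its size is positive; setup('fit') partitions each dataset, and the new *_dataloader methods wrap the matching split in a PinaDataLoader:

    # hypothetical sketch, not part of this commit
    from pina.data import PinaDataModule

    problem = ...  # assumed: a PINA problem supplying datasets and conditions
    dm = PinaDataModule(problem,
                        train_size=0.7, test_size=0.2, val_size=0.1,
                        predict_size=0.0,  # size 0: the method falls back to
                                           # LightningDataModule's default loader
                        batch_size=8)
    dm.setup(stage='fit')             # fills dm.splits['train'/'test'/'val']
    loader = dm.train_dataloader()    # PinaDataLoader over the 'train' split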
1 change: 0 additions & 1 deletion pina/solvers/solver.py
@@ -83,7 +83,6 @@ def __init__(self,
                 " optimizers.")

         # extra features handling
-
         self._pina_models = models
         self._pina_optimizers = optimizers
         self._pina_schedulers = schedulers
98 changes: 91 additions & 7 deletions pina/solvers/supervised.py
@@ -1,5 +1,6 @@
 """ Module for SupervisedSolver """
 import torch
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 from torch.nn.modules.loss import _Loss
 from ..optim import TorchOptimizer, TorchScheduler
 from .solver import SolverInterface
@@ -75,11 +76,15 @@ def __init__(self,
             extra_features=extra_features)

         # check consistency
-        check_consistency(loss, (LossInterface, _Loss), subclass=False)
+        check_consistency(loss, (LossInterface, _Loss),
+                          subclass=False)
         self._loss = loss
         self._model = self._pina_models[0]
         self._optimizer = self._pina_optimizers[0]
         self._scheduler = self._pina_schedulers[0]
+        self.validation_condition_losses = {
+            k: {'loss': [],
+                'count': []} for k in self.problem.conditions.keys()}

     def forward(self, x):
         """Forward pass implementation for the solver.
@@ -117,12 +122,14 @@ def training_step(self, batch, batch_idx):
         """

         condition_idx = batch.supervised.condition_indices
-        loss = torch.tensor(0, dtype=torch.float32)
+        loss = torch.tensor(0, dtype=torch.float32).to(self.device)
+        batch = batch.supervised
         for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
-            condition_name = self._dataloader.condition_names[condition_id]
+            condition_name = self.trainer.data_module.condition_names[
+                condition_id]
             condition = self.problem.conditions[condition_name]
-            pts = batch.supervised.input_points
-            out = batch.supervised.output_points
+            pts = batch.input_points
+            out = batch.output_points
             if condition_name not in self.problem.conditions:
                 raise RuntimeError("Something wrong happened.")
@@ -134,13 +141,90 @@ def training_step(self, batch, batch_idx):
             output_pts = out[condition_idx == condition_id]
             input_pts = pts[condition_idx == condition_id]

-
             loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
             loss += loss_.as_subclass(torch.Tensor)

-        self.log("mean_loss", float(loss), prog_bar=True, logger=True)
+        self.log("mean_loss", float(loss), prog_bar=True, logger=True,
+                 on_epoch=True, on_step=False,
+                 batch_size=self.trainer.data_module.batch_size)
         return loss

+    def validation_step(self, batch, batch_idx):
+        """
+        Solver validation step.
+        """
+
+        batch = batch.supervised
+        condition_idx = batch.condition_indices
+        for i in range(condition_idx.min(), condition_idx.max() + 1):
+            condition_name = self.trainer.data_module.condition_names[i]
+            condition = self.problem.conditions[condition_name]
+            pts = batch.input_points
+            out = batch.output_points
+            if condition_name not in self.problem.conditions:
+                raise RuntimeError("Something wrong happened.")
+
+            # this solver works only in data-driven mode
+            if not hasattr(condition, "output_points"):
+                raise NotImplementedError(
+                    f"{type(self).__name__} works only in data-driven mode.")
+
+            output_pts = out[condition_idx == i]
+            input_pts = pts[condition_idx == i]
+
+            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
+            self.validation_condition_losses[condition_name]['loss'].append(
+                loss_)
+            self.validation_condition_losses[condition_name]['count'].append(
+                len(input_pts))
+
+    def on_validation_epoch_end(self):
+        """
+        Solver validation epoch end: log a count-weighted loss per condition
+        and a count-weighted mean over all conditions.
+        """
+        total_loss = []
+        total_count = []
+        for k, v in self.validation_condition_losses.items():
+            local_counter = torch.tensor(v['count']).to(self.device)
+            n_elements = torch.sum(local_counter)
+            loss = torch.sum(
+                torch.stack(v['loss']) * local_counter) / n_elements
+            loss = loss.as_subclass(torch.Tensor)
+            total_loss.append(loss)
+            total_count.append(n_elements)
+            self.log(
+                k + "_loss",
+                loss,
+                prog_bar=True,
+                logger=True,
+                on_epoch=True,
+                on_step=False,
+                batch_size=self.trainer.data_module.batch_size,
+            )
+        total_count = torch.tensor(total_count, dtype=torch.float32).to(
+            self.device)
+        mean_loss = torch.sum(
+            torch.stack(total_loss) * total_count) / torch.sum(total_count)
+        self.log(
+            "val_loss",
+            mean_loss,
+            prog_bar=True,
+            logger=True,
+            on_epoch=True,
+            on_step=False,
+            batch_size=self.trainer.data_module.batch_size,
+        )
+        for key in self.validation_condition_losses.keys():
+            self.validation_condition_losses[key]['loss'] = []
+            self.validation_condition_losses[key]['count'] = []
+
+    def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
+        """
+        Solver test step.
+        """
+
+        raise NotImplementedError("Test step not implemented.")
+
     def loss_data(self, input_pts, output_pts):
         """
         The data loss for the Supervised solver. It computes the loss between
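The epoch-end reduction above is a count-weighted mean: each condition's batch losses are averaged weighted by batch size, then the per-condition losses are combined weighted by their total sample counts (note the final division by the summed counts). A standalone sketch of the same arithmetic with plain tensors, detached from Lightning; the condition names and numbers are made up:

    import torch

    # (per-batch losses, per-batch sample counts) for two made-up conditions
    recorded = {'cond_a': ([0.5, 0.3], [10, 30]),
                'cond_b': ([0.2], [20])}

    per_cond, counts = [], []
    for name, (loss_list, count_list) in recorded.items():
        c = torch.tensor(count_list, dtype=torch.float32)
        l = torch.tensor(loss_list)
        per_cond.append(torch.sum(l * c) / c.sum())  # 0.35 and 0.20
        counts.append(c.sum())

    counts = torch.stack(counts)
    val_loss = torch.sum(torch.stack(per_cond) * counts) / counts.sum()
    print(float(val_loss))  # 0.3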
29 changes: 23 additions & 6 deletions pina/trainer.py
@@ -2,6 +2,7 @@

 import torch
 import pytorch_lightning
+import warnings
 from .utils import check_consistency
 from .data import PinaDataModule
 from .solvers.solver import SolverInterface
@@ -15,6 +16,7 @@ def __init__(self,
                  train_size=.7,
                  test_size=.2,
                  val_size=.1,
+                 predict_size=.0,
                  **kwargs):
         """
         PINA Trainer class for customizing every aspect of training via flags.
@@ -30,8 +32,8 @@ def __init__(self,
             and can be chosen from the `pytorch-lightning
             Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
         """
-
-        super().__init__(**kwargs)
+        log_every_n_steps = kwargs.pop('log_every_n_steps', 0)
+        super().__init__(log_every_n_steps=log_every_n_steps, **kwargs)

         # check inheritance consistency for solver and batch size
         check_consistency(solver, SolverInterface)
@@ -40,9 +42,9 @@ def __init__(self,
         self.train_size = train_size
         self.test_size = test_size
         self.val_size = val_size
+        self.predict_size = predict_size
         self.solver = solver
         self.batch_size = batch_size
-        self._create_loader()
         self._move_to_device()
         self.data_module = None
@@ -83,6 +85,7 @@ def _create_loader(self):
             train_size=self.train_size,
             test_size=self.test_size,
             val_size=self.val_size,
+            predict_size=self.predict_size,
             batch_size=self.batch_size, )
         self.data_module.setup()
@@ -91,9 +94,23 @@ def train(self, **kwargs):
         Train the solver method.
         """
         self._create_loader()
-        return super().fit(self.solver,
-                           datamodule=self.data_module,
-                           **kwargs)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                message="You defined a `validation_step` but have no `val_dataloader`",
+                category=UserWarning
+            )
+            return super().fit(self.solver,
+                               datamodule=self.data_module,
+                               **kwargs)
+
+    def test(self, **kwargs):
+        """
+        Test the solver method.
+        """
+        return super().test(self.solver,
+                            datamodule=self.data_module,
+                            **kwargs)

     @property
     def solver(self):
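A hedged sketch of the new trainer flags in action (solver construction is assumed and elided; SupervisedSolver.test_step is still a stub that raises NotImplementedError, so test() is shown only for the API shape):

    # hypothetical sketch, not part of this commit
    from pina.trainer import Trainer

    solver = ...  # assumed: a SolverInterface instance, e.g. SupervisedSolver
    trainer = Trainer(solver,
                      batch_size=8,
                      train_size=0.7, test_size=0.2, val_size=0.1,
                      predict_size=0.0,   # new flag in this commit
                      max_epochs=100)     # any pytorch-lightning Trainer kwarg
    trainer.train()   # fit; logs <condition>_loss and val_loss per epoch
    trainer.test()    # delegates to lightning's test with the same data module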
