GAROM solver loss update
Changing from `LpLoss` to `PowerLoss`
dario-coscia authored Oct 6, 2023
1 parent 5806403 commit d8153d7
Showing 1 changed file with 6 additions and 5 deletions.
pina/solvers/garom.py: 6 additions, 5 deletions

@@ -1,4 +1,5 @@
-""" Module for PINN """
+""" Module for GAROM """
+
 import torch
 try:
     from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
@@ -8,7 +9,7 @@
 from torch.optim.lr_scheduler import ConstantLR
 from .solver import SolverInterface
 from ..utils import check_consistency
-from ..loss import LossInterface, LpLoss
+from ..loss import LossInterface, PowerLoss
 from torch.nn.modules.loss import _Loss


@@ -58,7 +59,7 @@ def __init__(self,
             extra features for each.
         :param torch.nn.Module loss: The loss function used as minimizer,
             default ``None``. If ``loss`` is ``None`` the default
-            ``LpLoss(p=1)`` is used, as in the original paper.
+            ``PowerLoss(p=1)`` is used, as in the original paper.
         :param torch.optim.Optimizer optimizer_generator: The neural
             network optimizer to use for the generator network,
             default is `torch.optim.Adam`.
@@ -102,7 +103,7 @@ def __init__(self,

         # set loss
         if loss is None:
-            loss = LpLoss(p=1)
+            loss = PowerLoss(p=1)

         # check consistency
         check_consistency(scheduler_generator, LRScheduler, subclass=True)
@@ -264,4 +265,4 @@ def scheduler_generator(self):

     @property
     def scheduler_discriminator(self):
-        return self._schedulers[1]
\ No newline at end of file
+        return self._schedulers[1]
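For reference, a minimal sketch of how the new default loss can be constructed and evaluated outside the solver. This assumes `PowerLoss` follows the usual `loss(input, target)` calling convention of `torch.nn` losses; the tensors below are hypothetical stand-ins for generator output and snapshot data:

```python
import torch
from pina.loss import PowerLoss

# The solver's new default, as set in this commit.
loss = PowerLoss(p=1)

# Hypothetical tensors standing in for generated samples and snapshots.
generated = torch.rand(16, 100)
snapshots = torch.rand(16, 100)

# With p=1 and the default reduction this should behave like a
# mean absolute error over the batch.
print(loss(generated, snapshots))
```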
