* fix doc loss + adding PowerLoss
* adding loss tests folder

Co-authored-by: Dario Coscia <[email protected]>
1 parent 4c256f1 · commit 18065eb
Showing 7 changed files with 163 additions and 9 deletions.
New file, LpLoss documentation page (@@ -0,0 +1,10 @@):
LpLoss
======

.. currentmodule:: pina.loss

.. automodule:: pina.loss

.. autoclass:: LpLoss
   :members:
   :private-members:
   :show-inheritance:
New file, PowerLoss documentation page (@@ -0,0 +1,10 @@):
PowerLoss
=========

.. currentmodule:: pina.loss

.. automodule:: pina.loss

.. autoclass:: PowerLoss
   :members:
   :private-members:
   :show-inheritance:
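Both pages document classes exposed from pina.loss. As a quick orientation, a minimal usage sketch follows; the PowerLoss constructor arguments (p, reduction, relative) are inferred from the test file added later in this commit, so treat the exact signature as an assumption rather than documented API:

import torch
from pina.loss import PowerLoss

# constructor signature inferred from the tests below, not from the docs
pred = torch.tensor([[3.], [1.], [-8.]])
true = torch.tensor([[6.], [4.], [2.]])

loss_fn = PowerLoss(p=2, reduction='mean')  # mean of |pred - true| ** 2
print(loss_fn(pred, true))  # behaves like a mean squared error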
File renamed without changes.
New file, PowerLoss tests (@@ -0,0 +1,49 @@):
import torch

from pina.loss import PowerLoss

input = torch.tensor([[3.], [1.], [-8.]])
target = torch.tensor([[6.], [4.], [2.]])
# valid reductions follow the PyTorch convention
available_reductions = ['sum', 'mean', 'none']


def test_PowerLoss_constructor():
    # test reduction
    for reduction in available_reductions:
        PowerLoss(reduction=reduction)
    # test p
    for p in [float('inf'), -float('inf'), 1, 10, -8]:
        PowerLoss(p=p)


def test_PowerLoss_forward():
    # l2 loss
    loss = PowerLoss(p=2, reduction='mean')
    l2_loss = torch.mean((input - target).pow(2))
    assert loss(input, target) == l2_loss
    # l1 loss
    loss = PowerLoss(p=1, reduction='sum')
    l1_loss = torch.sum(torch.abs(input - target))
    assert loss(input, target) == l1_loss


def test_PowerLoss_relative_constructor():
    # test reduction
    for reduction in available_reductions:
        PowerLoss(reduction=reduction, relative=True)
    # test p
    for p in [float('inf'), -float('inf'), 1, 10, -8]:
        PowerLoss(p=p, relative=True)


def test_PowerLoss_relative_forward():
    # l2 relative loss
    loss = PowerLoss(p=2, reduction='mean', relative=True)
    l2_loss = (input - target).pow(2) / input.pow(2)
    assert loss(input, target) == torch.mean(l2_loss)
    # l1 relative loss
    loss = PowerLoss(p=1, reduction='sum', relative=True)
    l1_loss = torch.abs(input - target) / torch.abs(input)
    assert loss(input, target) == torch.sum(l1_loss)
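The assertions above fully determine the expected forward computation. For readers of the tests, here is a standalone reference sketch consistent with those assertions; it is an illustrative assumption, not PINA's actual implementation:

import torch

def power_loss_reference(input, target, p=2, reduction='mean', relative=False):
    # semantics implied by the tests above: elementwise |input - target| ** p ...
    loss = torch.abs(input - target).pow(p)
    if relative:
        # ... optionally divided by |input| ** p when relative=True ...
        loss = loss / torch.abs(input).pow(p)
    # ... then reduced according to the PyTorch-style reduction flag.
    if reduction == 'mean':
        return torch.mean(loss)
    if reduction == 'sum':
        return torch.sum(loss)
    return loss  # 'none' leaves the tensor unreduced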