Commit message:
* add spline model
* add tests for splines
* rst files for splines

Co-authored-by: AleDinve <[email protected]>
Co-authored-by: dario-coscia <[email protected]>
1 parent: fefba81
Commit: 4c5cb8f
Showing 6 changed files with 255 additions and 1 deletion.
@@ -0,0 +1,7 @@
Spline
========
.. currentmodule:: pina.model.spline

.. autoclass:: Spline
   :members:
   :show-inheritance:
@@ -0,0 +1,166 @@
"""Module for Spline model"""

import torch
import torch.nn as nn
from ..utils import check_consistency


class Spline(torch.nn.Module):

    def __init__(self, order=4, knots=None, control_points=None) -> None:
        """
        Spline model.

        :param int order: the order of the spline.
        :param torch.Tensor knots: the knot vector.
        :param torch.Tensor control_points: the control points.
        """
        super().__init__()

        check_consistency(order, int)

        if order < 0:
            raise ValueError("Spline order cannot be negative.")
        if knots is None and control_points is None:
            raise ValueError("Knots and control points cannot be both None.")

        self.order = order
        self.k = order - 1

        if knots is not None and control_points is not None:
            self.knots = knots
            self.control_points = control_points

        elif knots is not None:
            print('Warning: control points will be initialized automatically.')
            print('         experimental feature')

            self.knots = knots
            n = len(knots) - order
            self.control_points = torch.nn.Parameter(
                torch.zeros(n), requires_grad=True)

        elif control_points is not None:
            print('Warning: knots will be initialized automatically.')
            print('         experimental feature')

            self.control_points = control_points

            n = len(self.control_points) - 1
            self.knots = {
                'type': 'auto',
                'min': 0,
                'max': 1,
                'n': n + 2 + self.order}

        else:
            raise ValueError(
                "Knots and control points cannot be both None."
            )

        if self.knots.ndim != 1:
            raise ValueError("Knot vector must be one-dimensional.")

    def basis(self, x, k, i, t):
        '''
        Recursive function to compute the basis functions of the spline.

        :param torch.Tensor x: points to be evaluated.
        :param int k: spline degree
        :param int i: the index of the interval
        :param torch.Tensor t: vector of knots
        :return: the basis functions evaluated at x
        :rtype: torch.Tensor
        '''

        if k == 0:
            a = torch.where(torch.logical_and(t[i] <= x, x < t[i+1]), 1.0, 0.0)
            # make the last non-degenerate interval right-closed,
            # so that x == t[-1] is included
            if i == len(t) - self.order - 1:
                a = torch.where(x == t[-1], 1.0, a)
            a.requires_grad_(True)
            return a

        if t[i+k] == t[i]:
            c1 = torch.tensor([0.0]*len(x), requires_grad=True)
        else:
            c1 = (x - t[i])/(t[i+k] - t[i]) * self.basis(x, k-1, i, t)

        if t[i+k+1] == t[i+1]:
            c2 = torch.tensor([0.0]*len(x), requires_grad=True)
        else:
            c2 = (t[i+k+1] - x)/(t[i+k+1] - t[i+1]) * self.basis(x, k-1, i+1, t)

        return c1 + c2
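    # The recursion implemented by ``basis`` is the Cox-de Boor formula:
    #   B_{i,0}(x) = 1 if t_i <= x < t_{i+1}, else 0
    #   B_{i,k}(x) = (x - t_i) / (t_{i+k} - t_i) * B_{i,k-1}(x)
    #              + (t_{i+k+1} - x) / (t_{i+k+1} - t_{i+1}) * B_{i+1,k-1}(x)
    # where any term with a zero denominator is treated as zero.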

    @property
    def control_points(self):
        return self._control_points

    @control_points.setter
    def control_points(self, value):
        if isinstance(value, dict):
            if 'n' not in value:
                raise ValueError('Invalid value for control_points')
            n = value['n']
            dim = value.get('dim', 1)
            value = torch.zeros(n, dim)

        if not isinstance(value, torch.Tensor):
            raise ValueError('Invalid value for control_points')
        self._control_points = torch.nn.Parameter(value, requires_grad=True)

    @property
    def knots(self):
        return self._knots

    @knots.setter
    def knots(self, value):
        if isinstance(value, dict):
            type_ = value.get('type', 'auto')
            min_ = value.get('min', 0)
            max_ = value.get('max', 1)
            n = value.get('n', 10)

            if type_ == 'uniform':
                value = torch.linspace(min_, max_, n + self.k + 1)
            elif type_ == 'auto':
                initial_knots = torch.ones(self.order+1)*min_
                final_knots = torch.ones(self.order+1)*max_

                if n < self.order + 1:
                    value = torch.concatenate((initial_knots, final_knots))
                elif n - 2*self.order + 1 == 1:
                    value = torch.Tensor([(max_ + min_)/2])
                else:
                    value = torch.linspace(min_, max_, n - 2*self.order - 1)

                value = torch.concatenate(
                    (
                        initial_knots, value, final_knots
                    )
                )

        if not isinstance(value, torch.Tensor):
            raise ValueError('Invalid value for knots')

        self._knots = value

    def forward(self, x_):
        """
        Forward pass of the spline model.

        :param torch.Tensor x_: points to be evaluated.
        :return: the spline evaluated at x_
        :rtype: torch.Tensor
        """
        t = self.knots
        k = self.k
        c = self.control_points

        # evaluate each basis function at x_ and weight it by its control point
        basis = map(lambda i: self.basis(x_, k, i, t)[:, None], range(len(c)))
        y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)

        return y
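For context, a minimal usage sketch of the class introduced above. The import path and the clamped knot layout follow the tests below; the specific knot and control-point values are illustrative.

import torch
from pina.model import Spline

# clamped cubic spline (order 4 = degree 3): knots repeated `order` times at the ends
knots = torch.tensor([0., 0., 0., 0., 1., 2., 3., 3., 3., 3.])
control_points = torch.tensor([0., 1., 2., 1., 0., 0.])   # len(knots) - order = 6
spline = Spline(order=4, knots=knots, control_points=control_points)

x = torch.linspace(0., 3., 50)
y = spline(x)          # forward() sums the control-point-weighted basis functions
print(y.shape)         # torch.Size([50])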
@@ -0,0 +1,74 @@
import torch
import pytest

from pina.model import Spline

data = torch.rand((20, 3))
input_vars = 3
output_vars = 4

valid_args = [
    {
        'knots': torch.tensor([0., 0., 0., 1., 2., 3., 3., 3.]),
        'control_points': torch.tensor([0., 0., 1., 0., 0.]),
        'order': 3
    },
    {
        'knots': torch.tensor([-2., -2., -2., -2., -1., 0., 1., 2., 2., 2., 2.]),
        'control_points': torch.tensor([0., 0., 0., 6., 0., 0., 0.]),
        'order': 4
    },
    # {'control_points': {'n': 5, 'dim': 1}, 'order': 2},
    # {'control_points': {'n': 7, 'dim': 1}, 'order': 3}
]

def scipy_check(model, x, y):
    from scipy.interpolate._bsplines import BSpline
    import numpy as np
    spline = BSpline(
        t=model.knots.detach().numpy(),
        c=model.control_points.detach().numpy(),
        k=model.order-1
    )
    y_scipy = spline(x).flatten()
    y = y.detach().numpy()
    np.testing.assert_allclose(y, y_scipy, atol=1e-5)

@pytest.mark.parametrize("args", valid_args)
def test_constructor(args):
    Spline(**args)

def test_constructor_wrong():
    with pytest.raises(ValueError):
        Spline()

@pytest.mark.parametrize("args", valid_args)
def test_forward(args):
    min_x = args['knots'][0]
    max_x = args['knots'][-1]
    xi = torch.linspace(min_x, max_x, 1000)
    model = Spline(**args)
    yi = model(xi).squeeze()
    scipy_check(model, xi, yi)
    return

@pytest.mark.parametrize("args", valid_args)
def test_backward(args):
    min_x = args['knots'][0]
    max_x = args['knots'][-1]
    xi = torch.linspace(min_x, max_x, 100)
    model = Spline(**args)
    yi = model(xi)
    fake_loss = torch.sum(yi)
    assert model.control_points.grad is None
    fake_loss.backward()
    assert model.control_points.grad is not None

# dim_in, dim_out = 3, 2
# fnn = FeedForward(dim_in, dim_out)
# data.requires_grad = True
# output_ = fnn(data)
# l = torch.mean(output_)
# l.backward()
# assert data._grad.shape == torch.Size([20, 3])
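As a complement to test_backward, a short sketch of what the trainable control points allow: fitting the spline to a target curve by gradient descent. The target function, learning rate, and iteration count are illustrative choices, not part of the commit.

import torch
from pina.model import Spline

# clamped cubic spline with auto-initialized (zero) control points
knots = torch.tensor([0., 0., 0., 0., 1., 2., 3., 3., 3., 3.])
model = Spline(order=4, knots=knots)   # experimental auto-initialization branch

x = torch.linspace(0., 3., 200)
target = torch.sin(x)                  # illustrative target curve

optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
for _ in range(200):
    optimizer.zero_grad()
    loss = torch.mean((model(x) - target) ** 2)   # mean squared error
    loss.backward()                               # gradients flow to the control points
    optimizer.step()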