Skip to content

Commit

Permalink
add spline model (#321)
Browse files Browse the repository at this point in the history
* add spline model
* add tests for splines
* rst files for splines

---------

Co-authored-by: AleDinve <[email protected]>
Co-authored-by: dario-coscia <[email protected]>
  • Loading branch information
3 people authored Sep 27, 2024
1 parent fefba81 commit 4c5cb8f
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Models
FeedForward <models/fnn.rst>
MultiFeedForward <models/multifeedforward.rst>
ResidualFeedForward <models/fnn_residual.rst>
Spline <models/spline.rst>
DeepONet <models/deeponet.rst>
MIONet <models/mionet.rst>
FourierIntegralKernel <models/fourier_kernel.rst>
Expand Down
7 changes: 7 additions & 0 deletions docs/source/_rst/models/spline.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Spline
========
.. currentmodule:: pina.model.spline

.. autoclass:: Spline
:members:
:show-inheritance:
2 changes: 2 additions & 0 deletions pina/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"KernelNeuralOperator",
"AveragingNeuralOperator",
"LowRankNeuralOperator",
"Spline",
]

from .feed_forward import FeedForward, ResidualFeedForward
Expand All @@ -18,3 +19,4 @@
from .base_no import KernelNeuralOperator
from .avno import AveragingNeuralOperator
from .lno import LowRankNeuralOperator
from .spline import Spline
166 changes: 166 additions & 0 deletions pina/model/spline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""Module for Spline model"""

import torch
import torch.nn as nn
from ..utils import check_consistency

class Spline(torch.nn.Module):

def __init__(self, order=4, knots=None, control_points=None) -> None:
"""
Spline model.
:param int order: the order of the spline.
:param torch.Tensor knots: the knot vector.
:param torch.Tensor control_points: the control points.
"""
super().__init__()

check_consistency(order, int)

if order < 0:
raise ValueError("Spline order cannot be negative.")
if knots is None and control_points is None:
raise ValueError("Knots and control points cannot be both None.")

self.order = order
self.k = order - 1

if knots is not None and control_points is not None:
self.knots = knots
self.control_points = control_points

elif knots is not None:
print('Warning: control points will be initialized automatically.')
print(' experimental feature')

self.knots = knots
n = len(knots) - order
self.control_points = torch.nn.Parameter(
torch.zeros(n), requires_grad=True)

elif control_points is not None:
print('Warning: knots will be initialized automatically.')
print(' experimental feature')

self.control_points = control_points

n = len(self.control_points)-1
self.knots = {
'type': 'auto',
'min': 0,
'max': 1,
'n': n+2+self.order}

else:
raise ValueError(
"Knots and control points cannot be both None."
)


if self.knots.ndim != 1:
raise ValueError("Knot vector must be one-dimensional.")

def basis(self, x, k, i, t):
'''
Recursive function to compute the basis functions of the spline.
:param torch.Tensor x: points to be evaluated.
:param int k: spline degree
:param int i: the index of the interval
:param torch.Tensor t: vector of knots
:return: the basis functions evaluated at x
:rtype: torch.Tensor
'''

if k == 0:
a = torch.where(torch.logical_and(t[i] <= x, x < t[i+1]), 1.0, 0.0)
if i == len(t) - self.order - 1:
a = torch.where(x == t[-1], 1.0, a)
a.requires_grad_(True)
return a


if t[i+k] == t[i]:
c1 = torch.tensor([0.0]*len(x), requires_grad=True)
else:
c1 = (x - t[i])/(t[i+k] - t[i]) * self.basis(x, k-1, i, t)

if t[i+k+1] == t[i+1]:
c2 = torch.tensor([0.0]*len(x), requires_grad=True)
else:
c2 = (t[i+k+1] - x)/(t[i+k+1] - t[i+1]) * self.basis(x, k-1, i+1, t)

return c1 + c2


@property
def control_points(self):
return self._control_points

@control_points.setter
def control_points(self, value):
if isinstance(value, dict):
if 'n' not in value:
raise ValueError('Invalid value for control_points')
n = value['n']
dim = value.get('dim', 1)
value = torch.zeros(n, dim)

if not isinstance(value, torch.Tensor):
raise ValueError('Invalid value for control_points')
self._control_points = torch.nn.Parameter(value, requires_grad=True)

@property
def knots(self):
return self._knots

@knots.setter
def knots(self, value):
if isinstance(value, dict):

type_ = value.get('type', 'auto')
min_ = value.get('min', 0)
max_ = value.get('max', 1)
n = value.get('n', 10)

if type_ == 'uniform':
value = torch.linspace(min_, max_, n + self.k + 1)
elif type_ == 'auto':
initial_knots = torch.ones(self.order+1)*min_
final_knots = torch.ones(self.order+1)*max_

if n < self.order + 1:
value = torch.concatenate((initial_knots, final_knots))
elif n - 2*self.order + 1 == 1:
value = torch.Tensor([(max_ + min_)/2])
else:
value = torch.linspace(min_, max_, n - 2*self.order - 1)

value = torch.concatenate(
(
initial_knots, value, final_knots
)
)

if not isinstance(value, torch.Tensor):
raise ValueError('Invalid value for knots')

self._knots = value

def forward(self, x_):
"""
Forward pass of the spline model.
:param torch.Tensor x_: points to be evaluated.
:return: the spline evaluated at x_
:rtype: torch.Tensor
"""
t = self.knots
k = self.k
c = self.control_points

basis = map(lambda i: self.basis(x_, k, i, t)[:, None], range(len(c)))
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)

return y
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
'sphinx_design',
'pydata_sphinx_theme'
],
'test': ['pytest', 'pytest-cov'],
'test': [
'pytest',
'pytest-cov',
'scipy'
],
}

LDESCRIPTION = (
Expand Down
74 changes: 74 additions & 0 deletions tests/test_model/test_spline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import torch
import pytest

from pina.model import Spline

data = torch.rand((20, 3))
input_vars = 3
output_vars = 4

valid_args = [
{
'knots': torch.tensor([0., 0., 0., 1., 2., 3., 3., 3.]),
'control_points': torch.tensor([0., 0., 1., 0., 0.]),
'order': 3
},
{
'knots': torch.tensor([-2., -2., -2., -2., -1., 0., 1., 2., 2., 2., 2.]),
'control_points': torch.tensor([0., 0., 0., 6., 0., 0., 0.]),
'order': 4
},
# {'control_points': {'n': 5, 'dim': 1}, 'order': 2},
# {'control_points': {'n': 7, 'dim': 1}, 'order': 3}
]

def scipy_check(model, x, y):
from scipy.interpolate._bsplines import BSpline
import numpy as np
spline = BSpline(
t=model.knots.detach().numpy(),
c=model.control_points.detach().numpy(),
k=model.order-1
)
y_scipy = spline(x).flatten()
y = y.detach().numpy()
np.testing.assert_allclose(y, y_scipy, atol=1e-5)

@pytest.mark.parametrize("args", valid_args)
def test_constructor(args):
Spline(**args)

def test_constructor_wrong():
with pytest.raises(ValueError):
Spline()

@pytest.mark.parametrize("args", valid_args)
def test_forward(args):
min_x = args['knots'][0]
max_x = args['knots'][-1]
xi = torch.linspace(min_x, max_x, 1000)
model = Spline(**args)
yi = model(xi).squeeze()
scipy_check(model, xi, yi)
return


@pytest.mark.parametrize("args", valid_args)
def test_backward(args):
min_x = args['knots'][0]
max_x = args['knots'][-1]
xi = torch.linspace(min_x, max_x, 100)
model = Spline(**args)
yi = model(xi)
fake_loss = torch.sum(yi)
assert model.control_points.grad is None
fake_loss.backward()
assert model.control_points.grad is not None

# dim_in, dim_out = 3, 2
# fnn = FeedForward(dim_in, dim_out)
# data.requires_grad = True
# output_ = fnn(data)
# l=torch.mean(output_)
# l.backward()
# assert data._grad.shape == torch.Size([20,3])

0 comments on commit 4c5cb8f

Please sign in to comment.