diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 8e7b31f7..d1c40062 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -56,6 +56,7 @@ Models MIONet FourierIntegralKernel FNO + AveragingNeuralOperator <models/avno.rst> Layers ------------- @@ -67,10 +68,10 @@ Layers EnhancedLinear layer Spectral convolution Fourier layers + Averaging layer <layers/avno_layer.rst> Continuous convolution Proper Orthogonal Decomposition Periodic Boundary Condition embeddings - Equations and Operators ------------------------- diff --git a/docs/source/_rst/layers/avno_layer.rst b/docs/source/_rst/layers/avno_layer.rst new file mode 100644 index 00000000..38d7ccbe --- /dev/null +++ b/docs/source/_rst/layers/avno_layer.rst @@ -0,0 +1,8 @@ +Averaging layers +==================== +.. currentmodule:: pina.model.layers.avno_layer + +.. autoclass:: AVNOBlock + :members: + :show-inheritance: + :noindex: diff --git a/docs/source/_rst/models/avno.rst b/docs/source/_rst/models/avno.rst new file mode 100644 index 00000000..a083f6fd --- /dev/null +++ b/docs/source/_rst/models/avno.rst @@ -0,0 +1,7 @@ +Averaging Neural Operator +============================== +.. currentmodule:: pina.model.avno + +.. 
autoclass:: AveragingNeuralOperator + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 869a4365..b0849887 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -7,6 +7,7 @@ "FNO", "FourierIntegralKernel", "KernelNeuralOperator", + "AveragingNeuralOperator", ] from .feed_forward import FeedForward, ResidualFeedForward @@ -14,3 +15,4 @@ from .deeponet import DeepONet, MIONet from .fno import FNO, FourierIntegralKernel from .base_no import KernelNeuralOperator +from .avno import AveragingNeuralOperator diff --git a/pina/model/avno.py b/pina/model/avno.py new file mode 100644 index 00000000..b85695ca --- /dev/null +++ b/pina/model/avno.py @@ -0,0 +1,104 @@ +"""Module Averaging Neural Operator.""" + +from torch import nn, concatenate +from . import FeedForward +from .layers import AVNOBlock +from .base_no import KernelNeuralOperator +from pina.utils import check_consistency + + +class AveragingNeuralOperator(KernelNeuralOperator): + """ + Implementation of Averaging Neural Operator. + + Averaging Neural Operator is a general architecture for + learning operators. Unlike traditional machine learning methods, + AveragingNeuralOperator is designed to map entire functions + to other functions. It can be trained with supervised learning strategies. + AveragingNeuralOperator does convolution by performing a field average. + + .. seealso:: + + **Original reference**: Lanthaler, S., Li, Z., Kovachki, + Stuart, A. (2023). *The Nonlocal Neural Operator: + Universal Approximation*. + DOI: `arXiv preprint arXiv:2304.13221. + <https://arxiv.org/abs/2304.13221>`_ + """ + + def __init__( + self, + input_numb_fields, + output_numb_fields, + field_indices, + coordinates_indices, + dimension=3, + inner_size=100, + n_layers=4, + func=nn.GELU, + ): + """ + :param int input_numb_fields: The number of input components + of the model. + :param int output_numb_fields: The number of output components + of the model. 
+ :param int dimension: the dimension of the domain of the functions. + :param int inner_size: number of neurons in the hidden layer(s). + Defaults to 100. + :param int n_layers: number of hidden layers. Defaults to 4. + :param func: the activation function to use. Defaults to nn.GELU. + :param list[str] field_indices: the labels of the fields + in the input tensor. + :param list[str] coordinates_indices: the labels of the + coordinates in the input tensor. + """ + + # check consistency + check_consistency(input_numb_fields, int) + check_consistency(output_numb_fields, int) + check_consistency(field_indices, str) + check_consistency(coordinates_indices, str) + check_consistency(dimension, int) + check_consistency(inner_size, int) + check_consistency(n_layers, int) + check_consistency(func, nn.Module, subclass=True) + + # assign + self.input_numb_fields = input_numb_fields + self.output_numb_fields = output_numb_fields + self.dimension = dimension + self.coordinates_indices = coordinates_indices + self.field_indices = field_indices + integral_net = nn.Sequential( + *[AVNOBlock(inner_size, func) for _ in range(n_layers)]) + lifting_net = FeedForward(dimension + input_numb_fields, inner_size, + inner_size, n_layers, func) + projection_net = FeedForward(inner_size + dimension, output_numb_fields, + inner_size, n_layers, func) + super().__init__(lifting_net, integral_net, projection_net) + + def forward(self, x): + r""" + Forward computation for Averaging Neural Operator. It performs a + lifting of the input by the ``lifting_net``. Then different layers + of Averaging Neural Operator Blocks are applied. + Finally the output is projected to the final dimensionality + by the ``projection_net``. + + :param torch.Tensor x: The input tensor for the operator, + depending on ``dimension`` in the initialization. 
It expects + a tensor :math:`B \times N \times D`, + where :math:`B` is the batch_size, :math:`N` the number of points + in the mesh, :math:`D` the dimension of the problem, i.e. the sum + of ``len(coordinates_indices)+len(field_indices)``. + :return: The output tensor obtained from Averaging Neural Operator. + :rtype: torch.Tensor + """ + points_tmp = x.extract(self.coordinates_indices) + features_tmp = x.extract(self.field_indices) + new_batch = concatenate((features_tmp, points_tmp), dim=2) + new_batch = self._lifting_operator(new_batch) + new_batch = self._integral_kernels(new_batch) + new_batch = concatenate((new_batch, points_tmp), dim=2) + new_batch = self._projection_operator(new_batch) + return new_batch diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index 77ee587a..2086e7a3 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -10,6 +10,7 @@ "FourierBlock3D", "PODBlock", "PeriodicBoundaryEmbedding", + "AVNOBlock", ] from .convolution_2d import ContinuousConvBlock @@ -22,3 +23,4 @@ from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D from .pod import PODBlock from .embedding import PeriodicBoundaryEmbedding +from .avno_layer import AVNOBlock diff --git a/pina/model/layers/avno_layer.py b/pina/model/layers/avno_layer.py new file mode 100644 index 00000000..9e91c616 --- /dev/null +++ b/pina/model/layers/avno_layer.py @@ -0,0 +1,67 @@ +""" Module for Averaging Neural Operator Layer class. """ + +from torch import nn, mean +from pina.utils import check_consistency + + +class AVNOBlock(nn.Module): + r""" + The PINA implementation of the inner layer of the Averaging Neural Operator. + + The operator layer performs an affine transformation where the convolution + is approximated with a local average. Given the input function + :math:`v(x)\in\mathbb{R}^{\rm{emb}}` the layer computes + the operator update :math:`K(v)` as: + + .. 
math:: + K(v) = \sigma\left(Wv(x) + b + \frac{1}{|\mathcal{A}|}\int v(y)dy\right) + + where: + + * :math:`\mathbb{R}^{\rm{emb}}` is the embedding (hidden) size + corresponding to the ``hidden_size`` object + * :math:`\sigma` is a non-linear activation, corresponding to the + ``func`` object + * :math:`W\in\mathbb{R}^{\rm{emb}\times\rm{emb}}` is a tunable matrix. + * :math:`b\in\mathbb{R}^{\rm{emb}}` is a tunable bias. + + .. seealso:: + + **Original reference**: Lanthaler, S., Li, Z., Kovachki, + Stuart, A. (2023). *The Nonlocal Neural Operator: Universal + Approximation*. + DOI: `arXiv preprint arXiv:2304.13221. + <https://arxiv.org/abs/2304.13221>`_ + + """ + + def __init__(self, hidden_size=100, func=nn.GELU): + """ + :param int hidden_size: Size of the hidden layer, defaults to 100. + :param func: The activation function, defaults to nn.GELU. + """ + super().__init__() + + # Check type consistency + check_consistency(hidden_size, int) + check_consistency(func, nn.Module, subclass=True) + # Assignment + self._nn = nn.Linear(hidden_size, hidden_size) + self._func = func() + + def forward(self, x): + r""" + Forward pass of the layer: it performs a sum of local average + and an affine transformation of the field. + + :param torch.Tensor x: The input tensor for performing the + computation. It expects a tensor :math:`B \times N \times D`, + where :math:`B` is the batch_size, :math:`N` the number of points + in the mesh, :math:`D` the dimension of the problem. In particular + :math:`D` is the dimension of the codomain of the function :math:`v`. For example + a scalar function has :math:`D=1`, a 4-dimensional vector function + :math:`D=4`. + :return: The output tensor obtained from Averaging Neural Operator Block. 
+ :rtype: torch.Tensor + """ + return self._func(self._nn(x) + mean(x, dim=1, keepdim=True)) diff --git a/tests/test_model/test_avno.py b/tests/test_model/test_avno.py new file mode 100644 index 00000000..a08f02c0 --- /dev/null +++ b/tests/test_model/test_avno.py @@ -0,0 +1,62 @@ +import torch +from pina.model import AveragingNeuralOperator +from pina import LabelTensor + +output_numb_fields = 5 +batch_size = 15 + + +def test_constructor(): + input_numb_fields = 1 + output_numb_fields = 1 + #minimum constructor + AveragingNeuralOperator(input_numb_fields, + output_numb_fields, + coordinates_indices=['p'], + field_indices=['v']) + + #all constructor + AveragingNeuralOperator(input_numb_fields, + output_numb_fields, + inner_size=5, + n_layers=5, + func=torch.nn.ReLU, + coordinates_indices=['p'], + field_indices=['v']) + + +def test_forward(): + input_numb_fields = 1 + output_numb_fields = 1 + dimension = 1 + input_ = LabelTensor( + torch.rand(batch_size, 1000, input_numb_fields + dimension), ['p', 'v']) + ano = AveragingNeuralOperator(input_numb_fields, + output_numb_fields, + dimension=dimension, + coordinates_indices=['p'], + field_indices=['v']) + out = ano(input_) + assert out.shape == torch.Size( + [batch_size, input_.shape[1], output_numb_fields]) + + +def test_backward(): + input_numb_fields = 1 + dimension = 1 + output_numb_fields = 1 + input_ = LabelTensor( + torch.rand(batch_size, 1000, dimension + input_numb_fields), + ['p', 'v']) + input_ = input_.requires_grad_() + avno = AveragingNeuralOperator(input_numb_fields, + output_numb_fields, + dimension=dimension, + coordinates_indices=['p'], + field_indices=['v']) + out = avno(input_) + tmp = torch.linalg.norm(out) + tmp.backward() + grad = input_.grad + assert grad.shape == torch.Size( + [batch_size, input_.shape[1], dimension + input_numb_fields])