Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Averaging Neural Operator with tests and a tutorial #230

Merged
merged 29 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9b7a843
Added Averaging Neural Operator with tests and a tutorial
guglielmopadula Mar 4, 2024
3471d58
trying to fix codacy issues
guglielmopadula Feb 8, 2024
8b0f7a4
fixed refactoring error
guglielmopadula Feb 8, 2024
9f1f9ec
pep8 everywhere
guglielmopadula Feb 9, 2024
4301294
fixing codacy
guglielmopadula Mar 4, 2024
0cd36e0
added backward test
guglielmopadula Feb 9, 2024
2ac5724
added backward test
guglielmopadula Feb 9, 2024
8010a42
changed structure to one similar to deeponet
guglielmopadula Feb 10, 2024
33fe941
fixing codacy
guglielmopadula Feb 17, 2024
734f3ff
fixing property
guglielmopadula Feb 17, 2024
3aafb1d
codacy issues, converted avno_layer to dataclass
guglielmopadula Feb 17, 2024
2e9d2ad
reverting dataclass as only worsens things
guglielmopadula Feb 17, 2024
d5d6a2e
added func in avno layer
guglielmopadula Feb 17, 2024
4a79fc0
trying to fix last codacy error
guglielmopadula Feb 17, 2024
0e9416d
deleted another trailing whitespace
guglielmopadula Feb 17, 2024
b7f5684
Grammatic fixes
guglielmopadula Feb 17, 2024
d051161
other grammatic fixes
guglielmopadula Feb 17, 2024
9cc389b
fixed typos and adapted AVNO to KernelNO
guglielmopadula Feb 21, 2024
ca77b47
other codacy fixes
guglielmopadula Feb 21, 2024
1701920
pep8
guglielmopadula Feb 21, 2024
7cd694f
Testing a maybe fake cyclic import
guglielmopadula Feb 21, 2024
2133836
various fixes
guglielmopadula Mar 4, 2024
d60f7b2
various fixes
guglielmopadula Mar 4, 2024
5453e83
fixed typo
guglielmopadula Mar 4, 2024
fe318df
fixing the fixable codacy
guglielmopadula Mar 4, 2024
cefc4f4
removed tutorial, modified AVNO name and added docs
guglielmopadula Mar 4, 2024
1de2b0d
fixed init
guglielmopadula Mar 4, 2024
0b187f8
minor changes
Mar 5, 2024
494af4f
doc addition
Mar 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Models
MIONet <models/mionet.rst>
FourierIntegralKernel <models/fourier_kernel.rst>
FNO <models/fno.rst>
AveragingNeuralOperator <models/avno.rst>

Layers
-------------
Expand All @@ -67,10 +68,10 @@ Layers
EnhancedLinear layer <layers/enhanced_linear.rst>
Spectral convolution <layers/spectral.rst>
Fourier layers <layers/fourier.rst>
Averaging layer <layers/avno_layer.rst>
Continuous convolution <layers/convolution.rst>
Proper Orthogonal Decomposition <layers/pod.rst>
Periodic Boundary Condition embeddings <layers/embedding.rst>


Equations and Operators
-------------------------
Expand Down
8 changes: 8 additions & 0 deletions docs/source/_rst/layers/avno_layer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Averaging layers
====================
.. currentmodule:: pina.model.layers.avno_layer

.. autoclass:: AVNOBlock
:members:
:show-inheritance:
:noindex:
7 changes: 7 additions & 0 deletions docs/source/_rst/models/avno.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Averaging Neural Operator
==============================
.. currentmodule:: pina.model.avno

.. autoclass:: AveragingNeuralOperator
:members:
:show-inheritance:
2 changes: 2 additions & 0 deletions pina/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
"FNO",
"FourierIntegralKernel",
"KernelNeuralOperator",
"AveragingNeuralOperator",
]

from .feed_forward import FeedForward, ResidualFeedForward
from .multi_feed_forward import MultiFeedForward
from .deeponet import DeepONet, MIONet
from .fno import FNO, FourierIntegralKernel
from .base_no import KernelNeuralOperator
from .avno import AveragingNeuralOperator
104 changes: 104 additions & 0 deletions pina/model/avno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Module Averaging Neural Operator."""

from torch import nn, concatenate
from . import FeedForward
from .layers import AVNOBlock
from .base_no import KernelNeuralOperator
from pina.utils import check_consistency


class AveragingNeuralOperator(KernelNeuralOperator):
"""
Implementation of Averaging Neural Operator.

Averaging Neural Operator is a general architecture for
learning Operators. Unlike traditional machine learning methods
AveragingNeuralOperator is designed to map entire functions
to other functions. It can be trained with Supervised learning strategies.
AveragingNeuralOperator does convolution by performing a field average.

.. seealso::

**Original reference**: Lanthaler S. Li, Z., Kovachki,
Stuart, A. (2020). *The Nonlocal Neural Operator:
Universal Approximation*.
DOI: `arXiv preprint arXiv:2304.13221.
<https://arxiv.org/abs/2304.13221>`_
"""

def __init__(
self,
input_numb_fields,
output_numb_fields,
field_indices,
coordinates_indices,
dimension=3,
inner_size=100,
n_layers=4,
func=nn.GELU,
):
"""
:param int input_numb_fields: The number of input components
of the model.
:param int output_numb_fields: The number of output components
of the model.
:param int dimension: the dimension of the domain of the functions.
:param int inner_size: number of neurons in the hidden layer(s).
Defaults to 100.
:param int n_layers: number of hidden layers. Default is 4.
:param func: the activation function to use. Default to nn.GELU.
:param list[str] field_indices: the label of the fields
in the input tensor.
:param list[str] coordinates_indices: the label of the
coordinates in the input tensor.
"""

# check consistency
check_consistency(input_numb_fields, int)
check_consistency(output_numb_fields, int)
check_consistency(field_indices, str)
check_consistency(coordinates_indices, str)
check_consistency(dimension, int)
check_consistency(inner_size, int)
check_consistency(n_layers, int)
check_consistency(func, nn.Module, subclass=True)

# assign
self.input_numb_fields = input_numb_fields
self.output_numb_fields = output_numb_fields
self.dimension = dimension
self.coordinates_indices = coordinates_indices
self.field_indices = field_indices
integral_net = nn.Sequential(
*[AVNOBlock(inner_size, func) for _ in range(n_layers)])
lifting_net = FeedForward(dimension + input_numb_fields, inner_size,
inner_size, n_layers, func)
projection_net = FeedForward(inner_size + dimension, output_numb_fields,
inner_size, n_layers, func)
super().__init__(lifting_net, integral_net, projection_net)

def forward(self, x):
r"""
Forward computation for Averaging Neural Operator. It performs a
lifting of the input by the ``lifting_net``. Then different layers
of Averaging Neural Operator Blocks are applied.
Finally the output is projected to the final dimensionality
by the ``projecting_net``.

:param torch.Tensor x: The input tensor for fourier block,
depending on ``dimension`` in the initialization. It expects
a tensor :math:`B \times N \times D`,
where :math:`B` is the batch_size, :math:`N` the number of points
in the mesh, :math:`D` the dimension of the problem, i.e. the sum
of ``len(coordinates_indices)+len(field_indices)``.
:return: The output tensor obtained from Average Neural Operator.
:rtype: torch.Tensor
"""
points_tmp = x.extract(self.coordinates_indices)
features_tmp = x.extract(self.field_indices)
new_batch = concatenate((features_tmp, points_tmp), dim=2)
new_batch = self._lifting_operator(new_batch)
new_batch = self._integral_kernels(new_batch)
new_batch = concatenate((new_batch, points_tmp), dim=2)
new_batch = self._projection_operator(new_batch)
return new_batch
2 changes: 2 additions & 0 deletions pina/model/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"FourierBlock3D",
"PODBlock",
"PeriodicBoundaryEmbedding",
"AVNOBlock",
]

from .convolution_2d import ContinuousConvBlock
Expand All @@ -22,3 +23,4 @@
from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
from .pod import PODBlock
from .embedding import PeriodicBoundaryEmbedding
from .avno_layer import AVNOBlock
67 changes: 67 additions & 0 deletions pina/model/layers/avno_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
""" Module for Averaging Neural Operator Layer class. """

from torch import nn, mean
from pina.utils import check_consistency


class AVNOBlock(nn.Module):
r"""
The PINA implementation of the inner layer of the Averaging Neural Operator.

The operator layer performs an affine transformation where the convolution
is approximated with a local average. Given the input function
:math:`v(x)\in\mathbb{R}^{\rm{emb}}` the layer computes
the operator update :math:`K(v)` as:

.. math::
K(v) = \sigma\left(Wv(x) + b + \frac{1}{|\mathcal{A}|}\int v(y)dy\right)

where:

* :math:`\mathbb{R}^{\rm{emb}}` is the embedding (hidden) size
corresponding to the ``hidden_size`` object
* :math:`\sigma` is a non-linear activation, corresponding to the
``func`` object
* :math:`W\in\mathbb{R}^{\rm{emb}\times\rm{emb}}` is a tunable matrix.
* :math:`b\in\mathbb{R}^{\rm{emb}}` is a tunable bias.

.. seealso::

**Original reference**: Lanthaler S. Li, Z., Kovachki,
Stuart, A. (2020). *The Nonlocal Neural Operator: Universal
Approximation*.
DOI: `arXiv preprint arXiv:2304.13221.
<https://arxiv.org/abs/2304.13221>`_

"""

def __init__(self, hidden_size=100, func=nn.GELU):
"""
:param int hidden_size: Size of the hidden layer, defaults to 100.
:param func: The activation function, default to nn.GELU.
"""
super().__init__()

# Check type consistency
check_consistency(hidden_size, int)
check_consistency(func, nn.Module, subclass=True)
# Assignment
self._nn = nn.Linear(hidden_size, hidden_size)
self._func = func()

def forward(self, x):
r"""
Forward pass of the layer, it performs a sum of local average
and an affine transformation of the field.

:param torch.Tensor x: The input tensor for performing the
computation. It expects a tensor :math:`B \times N \times D`,
where :math:`B` is the batch_size, :math:`N` the number of points
in the mesh, :math:`D` the dimension of the problem. In particular
:math:`D` is the codomain of the function :math:`v`. For example
a scalar function has :math:`D=1`, a 4-dimensional vector function
:math:`D=4`.
:return: The output tensor obtained from Average Neural Operator Block.
:rtype: torch.Tensor
"""
return self._func(self._nn(x) + mean(x, dim=1, keepdim=True))
62 changes: 62 additions & 0 deletions tests/test_model/test_avno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch
from pina.model import AveragingNeuralOperator
from pina import LabelTensor

output_numb_fields = 5
batch_size = 15


def test_constructor():
input_numb_fields = 1
output_numb_fields = 1
#minimuum constructor
AveragingNeuralOperator(input_numb_fields,
output_numb_fields,
coordinates_indices=['p'],
field_indices=['v'])

#all constructor
AveragingNeuralOperator(input_numb_fields,
output_numb_fields,
inner_size=5,
n_layers=5,
func=torch.nn.ReLU,
coordinates_indices=['p'],
field_indices=['v'])

dario-coscia marked this conversation as resolved.
Show resolved Hide resolved

def test_forward():
input_numb_fields = 1
output_numb_fields = 1
dimension = 1
input_ = LabelTensor(
torch.rand(batch_size, 1000, input_numb_fields + dimension), ['p', 'v'])
ano = AveragingNeuralOperator(input_numb_fields,
output_numb_fields,
dimension=dimension,
coordinates_indices=['p'],
field_indices=['v'])
out = ano(input_)
assert out.shape == torch.Size(
[batch_size, input_.shape[1], output_numb_fields])


def test_backward():
input_numb_fields = 1
dimension = 1
output_numb_fields = 1
input_ = LabelTensor(
torch.rand(batch_size, 1000, dimension + input_numb_fields),
['p', 'v'])
input_ = input_.requires_grad_()
avno = AveragingNeuralOperator(input_numb_fields,
output_numb_fields,
dimension=dimension,
coordinates_indices=['p'],
field_indices=['v'])
out = avno(input_)
tmp = torch.linalg.norm(out)
tmp.backward()
grad = input_.grad
assert grad.shape == torch.Size(
[batch_size, input_.shape[1], dimension + input_numb_fields])
Loading