From 9b7a84371124a36397d1cdbd95669b5e520abd0c Mon Sep 17 00:00:00 2001 From: cyberguli Date: Mon, 4 Mar 2024 09:22:59 +0100 Subject: [PATCH 01/29] Added Averaging Neural Operator with tests and a tutorial --- pina/model/__init__.py | 2 + pina/model/avno.py | 58 ++++++++++++++++++++++ tests/test_model/test_avno.py | 28 +++++++++++ tutorials/tutorial10/tutorial.py | 84 ++++++++++++++++++++++++++++++++ 4 files changed, 172 insertions(+) create mode 100644 pina/model/avno.py create mode 100644 tests/test_model/test_avno.py create mode 100644 tutorials/tutorial10/tutorial.py diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 869a4365..14c9f2b3 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -7,6 +7,7 @@ "FNO", "FourierIntegralKernel", "KernelNeuralOperator", + "AVNO", ] from .feed_forward import FeedForward, ResidualFeedForward @@ -14,3 +15,4 @@ from .deeponet import DeepONet, MIONet from .fno import FNO, FourierIntegralKernel from .base_no import KernelNeuralOperator +from .avno import AVNO diff --git a/pina/model/avno.py b/pina/model/avno.py new file mode 100644 index 00000000..8e1536f3 --- /dev/null +++ b/pina/model/avno.py @@ -0,0 +1,58 @@ +import torch +from . import FeedForward +from torch import nn + + + +class AVNOLayer(nn.Module): + def __init__(self,hidden_size,func): + super().__init__() + self.nn=nn.Linear(hidden_size,hidden_size) + self.func=func + + def forward(self,batch): + return self.func()(self.nn(batch)+torch.mean(batch,dim=1).unsqueeze(1)) + +class AVNO(nn.Module): + def __init__(self, + input_features, + output_features, + points, + inner_size=100, + n_layers=4, + func=nn.GELU, + ): + + super().__init__() + self.input_features=input_features + self.output_features=output_features + self.num_points=points.shape[0] + self.points_size=points.shape[1] + self.lifting=FeedForward(input_features+self.points_size,inner_size,inner_size,n_layers,func) + self.nn=nn.Sequential(*[AVNOLayer(inner_size,func) for _ in range(n_layers)]) + self.projection=FeedForward(inner_size+self.points_size,output_features,output_features,n_layers,func) + self.points=points + + def forward(self, batch): + points_tmp=self.points.unsqueeze(0).repeat(batch.shape[0],1,1) + new_batch=torch.concatenate((batch,points_tmp),dim=2) + new_batch=self.lifting(new_batch) + new_batch=self.nn(new_batch) + new_batch=torch.concatenate((new_batch,points_tmp),dim=2) + new_batch=self.projection(new_batch) + return new_batch + + def forward_eval(self,batch,points): + points_tmp=points.unsqueeze(0).repeat(batch.shape[0],1,1) + new_batch=torch.concatenate((batch,points_tmp),dim=2) + new_batch=self.lifting(new_batch) + new_batch=self.nn(new_batch) + new_batch=torch.concatenate((new_batch,points_tmp),dim=2) + new_batch=self.projection(new_batch) + return new_batch + + + + + + diff --git a/tests/test_model/test_avno.py b/tests/test_model/test_avno.py new file mode 100644 index 00000000..9d8c7b90 --- /dev/null +++ b/tests/test_model/test_avno.py @@ -0,0 +1,28 @@ +import torch +from pina.model import AVNO + +output_channels = 5 +batch_size = 15 + + +def test_constructor(): + input_channels = 1 + output_channels = 1 + #minimuum constructor + AVNO(input_channels, output_channels, torch.rand(10000, 2)) + + #all constructor + AVNO(input_channels, output_channels, torch.rand(100, 2), inner_size=5,n_layers=5,func=torch.nn.ReLU) + + + +def test_forward(): + input_channels = 1 + output_channels = 1 + input_ = torch.rand(batch_size, 1000, input_channels) + points=torch.rand(1000,2) + ano = AVNO(input_channels, output_channels, points) + out = ano(input_) + assert out.shape == torch.Size([batch_size, points.shape[0], output_channels]) + + diff --git a/tutorials/tutorial10/tutorial.py b/tutorials/tutorial10/tutorial.py new file mode 100644 index 00000000..7f5b41cd --- /dev/null +++ b/tutorials/tutorial10/tutorial.py @@ -0,0 +1,84 @@ +import torch +from time import time +#Data generation + +torch.manual_seed(0) + +def sample_unit_circle(num_points): + radius=torch.rand(num_points,1) + angle=torch.rand(num_points,1)*2*torch.pi + x=radius*torch.cos(angle) + y=radius*torch.sin(angle) + data=torch.cat((x,y),dim=1) + return data + +#sin(a*x+b*y) +def compute_input(data,theta): + data=data.reshape(1,-1,2) + z=torch.sin(theta[:,:,0]*data[:,:,0]+theta[:,:,1]*data[:,:,1]) + return z + +#1+convolution of sin(a*x+b*y) with sin(x) over [0,2pi]x[0,2pi ] +def compute_output(data,theta): + data=data.reshape(1,-1,2) + z=1-4*torch.sin(torch.pi*theta[:,:,0])*torch.sin(torch.pi*theta[:,:,1])*torch.cos(theta[:,:,0]*(torch.pi*data[:,:,0])+theta[:,:,1]*(torch.pi*data[:,:,1]))/((theta[:,:,0]**2-1)*theta[:,:,1]) + return z + + + +theta=1+0.01*torch.rand(300,1,2) +data_coarse=sample_unit_circle(1000) +output_coarse=compute_output(data_coarse,theta).unsqueeze(-1) +input_coarse=compute_input(data_coarse,theta).unsqueeze(-1) +data_dense=sample_unit_circle(1000) +output_dense=compute_output(data_dense,theta).unsqueeze(-1) +input_dense=compute_input(data_dense,theta).unsqueeze(-1) + +from pina.model import AVNO +from pina import Condition,LabelTensor +from pina.problem import AbstractProblem +from pina.solvers import SupervisedSolver +from pina.trainer import Trainer + +model=AVNO(1,1,data_coarse,inner_size=500,n_layers=4) +class ANOSolver(AbstractProblem): + input_variables=['input'] + input_points=LabelTensor(input_coarse,input_variables) + output_variables=['output'] + output_points=LabelTensor(output_coarse,output_variables) + conditions={"data":Condition(input_points=input_points,output_points=output_points)} + +batch_size=1 +problem=ANOSolver() +solver=SupervisedSolver(problem,model,optimizer_kwargs={'lr':1e-3},optimizer=torch.optim.AdamW) +trainer=Trainer(solver=solver,max_epochs=5,accelerator='cpu',enable_model_summary=False,batch_size=batch_size) +from pina.loss import LpLoss +loss=LpLoss(2,relative=True) + +start_time=time() +trainer.train() +end_time=time() +print(end_time-start_time) +solver.neural_net=solver.neural_net.eval() +loss=torch.nn.MSELoss() +num_batches=len(input_coarse)//batch_size +num=0 +dem=0 +for i in range(num_batches): + input_variables=['input'] + myinput=LabelTensor(input_coarse[i].unsqueeze(0),input_variables) + tmp=model(myinput).detach().squeeze(0) + num=num+torch.linalg.norm(tmp-output_coarse[i])**2 + dem=dem+torch.linalg.norm(output_coarse[i])**2 +print("Training mse loss is", torch.sqrt(num/dem)) + + +num=0 +dem=0 +for i in range(num_batches): + input_variables=['input'] + myinput=LabelTensor(input_dense[i].unsqueeze(0),input_variables) + tmp=model.forward_eval(myinput,data_dense).detach().squeeze(0) + num=num+torch.linalg.norm(tmp-output_dense[i])**2 + dem=dem+torch.linalg.norm(output_dense[i])**2 +print("Super Resolution mse loss is", torch.sqrt(num/dem)) From 3471d58d46c86923f74f8f74968bc6d64ed58679 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Thu, 8 Feb 2024 20:47:28 +0100 Subject: [PATCH 02/29] trying to fix codacy issues --- pina/model/avno.py | 70 +++++++++++++++++++++++++++++++++------------- 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index 8e1536f3..eb97d2df 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -1,19 +1,38 @@ -import torch -from . import FeedForward -from torch import nn +"""Module Averaging Neural Operator.""" +from torch import nn, mean, concatenate +from . import FeedForward class AVNOLayer(nn.Module): + """ + The PINA implementation of the inner layer of the Averaging Neural Operator . + + :param int hidden_size: size of the layer. + :param func: the activation function to use. + """ + def __init__(self,hidden_size,func): super().__init__() self.nn=nn.Linear(hidden_size,hidden_size) self.func=func - + def forward(self,batch): - return self.func()(self.nn(batch)+torch.mean(batch,dim=1).unsqueeze(1)) + return self.func()(self.nn(batch)+mean(batch,dim=1).unsqueeze(1)) class AVNO(nn.Module): + """ + The PINA implementation of the inner layer of the Averaging Neural Operator. + + :param int input_features: The number of input components of the model. + :param int output_features: The number of output components of the model. + :param ndarray points: the points in the training set. + :param int inner_size: number of neurons in the hidden layer(s). Default is 100. + :param int n_layers: number of hidden layers. Default is 4. + :param func: the activation function to use. Default to nn.GELU. + + """ + def __init__(self, input_features, output_features, @@ -28,31 +47,44 @@ def __init__(self, self.output_features=output_features self.num_points=points.shape[0] self.points_size=points.shape[1] - self.lifting=FeedForward(input_features+self.points_size,inner_size,inner_size,n_layers,func) - self.nn=nn.Sequential(*[AVNOLayer(inner_size,func) for _ in range(n_layers)]) - self.projection=FeedForward(inner_size+self.points_size,output_features,output_features,n_layers,func) + self.lifting=FeedForward(input_features+self. + points_size, + inner_size, + inner_size,n_layers,func) + self.nn=nn.Sequential(*[AVNOLayer(inner_size,func) + for _ in range(n_layers)]) + self.projection=FeedForward(inner_size+self.points_size, + inner_size, + output_features,n_layers,func) self.points=points def forward(self, batch): + """ + Computes the forward pass of the model with the points specified in init. + + :param torch.Tensor batch: the input tensor. + + """ points_tmp=self.points.unsqueeze(0).repeat(batch.shape[0],1,1) - new_batch=torch.concatenate((batch,points_tmp),dim=2) + new_batch=concatenate((batch,points_tmp),dim=2) new_batch=self.lifting(new_batch) new_batch=self.nn(new_batch) - new_batch=torch.concatenate((new_batch,points_tmp),dim=2) + new_batch=concatenate((new_batch,points_tmp),dim=2) new_batch=self.projection(new_batch) return new_batch def forward_eval(self,batch,points): + """ + Computes the forward pass of the model with the points specified when calling the function. + + :param torch.Tensor batch: the input tensor. + :param torch.Tensor points: the points tensor. + + """ points_tmp=points.unsqueeze(0).repeat(batch.shape[0],1,1) - new_batch=torch.concatenate((batch,points_tmp),dim=2) + new_batch=concatenate((batch,points_tmp),dim=2) new_batch=self.lifting(new_batch) new_batch=self.nn(new_batch) - new_batch=torch.concatenate((new_batch,points_tmp),dim=2) + new_batch=concatenate((new_batch,points_tmp),dim=2) new_batch=self.projection(new_batch) - return new_batch - - - - - - + return new_batch \ No newline at end of file From 8b0f7a4d3f4bbed8a5391e87bd9b9427b9beee82 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Thu, 8 Feb 2024 21:00:05 +0100 Subject: [PATCH 03/29] fixed refactoring error --- pina/model/avno.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index eb97d2df..74a7ecae 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -54,8 +54,8 @@ def __init__(self, self.nn=nn.Sequential(*[AVNOLayer(inner_size,func) for _ in range(n_layers)]) self.projection=FeedForward(inner_size+self.points_size, - inner_size, - output_features,n_layers,func) + output_features, + inner_size,n_layers,func) self.points=points def forward(self, batch): From 9f1f9ec7204fff8d454a1b511835a7323e5b7c5c Mon Sep 17 00:00:00 2001 From: cyberguli Date: Fri, 9 Feb 2024 16:30:01 +0100 Subject: [PATCH 04/29] pep8 everywhere --- pina/model/avno.py | 92 ++++++++++++++------------------- pina/model/layers/__init__.py | 3 ++ pina/model/layers/avno_layer.py | 20 +++++++ 3 files changed, 62 insertions(+), 53 deletions(-) create mode 100644 pina/model/layers/avno_layer.py diff --git a/pina/model/avno.py b/pina/model/avno.py index 74a7ecae..3cde5fa4 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -2,23 +2,10 @@ from torch import nn, mean, concatenate from . import FeedForward +from .layers import AVNOLayer -class AVNOLayer(nn.Module): - """ - The PINA implementation of the inner layer of the Averaging Neural Operator . - - :param int hidden_size: size of the layer. - :param func: the activation function to use. - """ - - def __init__(self,hidden_size,func): - super().__init__() - self.nn=nn.Linear(hidden_size,hidden_size) - self.func=func - def forward(self,batch): - return self.func()(self.nn(batch)+mean(batch,dim=1).unsqueeze(1)) class AVNO(nn.Module): """ @@ -33,31 +20,30 @@ class AVNO(nn.Module): """ - def __init__(self, - input_features, - output_features, - points, - inner_size=100, - n_layers=4, - func=nn.GELU, - ): - + def __init__( + self, + input_features, + output_features, + points, + inner_size=100, + n_layers=4, + func=nn.GELU, + ): + super().__init__() - self.input_features=input_features - self.output_features=output_features - self.num_points=points.shape[0] - self.points_size=points.shape[1] - self.lifting=FeedForward(input_features+self. - points_size, - inner_size, - inner_size,n_layers,func) - self.nn=nn.Sequential(*[AVNOLayer(inner_size,func) - for _ in range(n_layers)]) - self.projection=FeedForward(inner_size+self.points_size, - output_features, - inner_size,n_layers,func) - self.points=points - + self.input_features = input_features + self.output_features = output_features + self.num_points = points.shape[0] + self.points_size = points.shape[1] + self.lifting = FeedForward(input_features + self.points_size, + inner_size, inner_size, n_layers, func) + self.nn = nn.Sequential( + *[AVNOLayer(inner_size, func) for _ in range(n_layers)]) + self.projection = FeedForward(inner_size + self.points_size, + output_features, inner_size, n_layers, + func) + self.points = points + def forward(self, batch): """ Computes the forward pass of the model with the points specified in init. @@ -65,15 +51,15 @@ def forward(self, batch): :param torch.Tensor batch: the input tensor. """ - points_tmp=self.points.unsqueeze(0).repeat(batch.shape[0],1,1) - new_batch=concatenate((batch,points_tmp),dim=2) - new_batch=self.lifting(new_batch) - new_batch=self.nn(new_batch) - new_batch=concatenate((new_batch,points_tmp),dim=2) - new_batch=self.projection(new_batch) + points_tmp = self.points.unsqueeze(0).repeat(batch.shape[0], 1, 1) + new_batch = concatenate((batch, points_tmp), dim=2) + new_batch = self.lifting(new_batch) + new_batch = self.nn(new_batch) + new_batch = concatenate((new_batch, points_tmp), dim=2) + new_batch = self.projection(new_batch) return new_batch - - def forward_eval(self,batch,points): + + def forward_eval(self, batch, points): """ Computes the forward pass of the model with the points specified when calling the function. @@ -81,10 +67,10 @@ def forward_eval(self,batch,points): :param torch.Tensor points: the points tensor. """ - points_tmp=points.unsqueeze(0).repeat(batch.shape[0],1,1) - new_batch=concatenate((batch,points_tmp),dim=2) - new_batch=self.lifting(new_batch) - new_batch=self.nn(new_batch) - new_batch=concatenate((new_batch,points_tmp),dim=2) - new_batch=self.projection(new_batch) - return new_batch \ No newline at end of file + points_tmp = points.unsqueeze(0).repeat(batch.shape[0], 1, 1) + new_batch = concatenate((batch, points_tmp), dim=2) + new_batch = self.lifting(new_batch) + new_batch = self.nn(new_batch) + new_batch = concatenate((new_batch, points_tmp), dim=2) + new_batch = self.projection(new_batch) + return new_batch diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index 77ee587a..fdf7c714 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -10,6 +10,7 @@ "FourierBlock3D", "PODBlock", "PeriodicBoundaryEmbedding", + "AVNOLayer", ] from .convolution_2d import ContinuousConvBlock @@ -22,3 +23,5 @@ from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D from .pod import PODBlock from .embedding import PeriodicBoundaryEmbedding +from .avno_layer import AVNOLayer + diff --git a/pina/model/layers/avno_layer.py b/pina/model/layers/avno_layer.py new file mode 100644 index 00000000..5d919359 --- /dev/null +++ b/pina/model/layers/avno_layer.py @@ -0,0 +1,20 @@ +"""Module for Averaging Neural Operator class.""" +from torch import nn, mean + + +class AVNOLayer(nn.Module): + """ + The PINA implementation of the inner layer of the Averaging Neural Operator . + + :param int hidden_size: size of the layer. + :param func: the activation function to use. + + """ + + def __init__(self, hidden_size, func): + super().__init__() + self.nn = nn.Linear(hidden_size, hidden_size) + self.func = func + + def forward(self, batch): + return self.func()(self.nn(batch) + mean(batch, dim=1).unsqueeze(1)) From 43012943ec01db970f728c1702cb526aa96684cb Mon Sep 17 00:00:00 2001 From: cyberguli Date: Mon, 4 Mar 2024 09:29:45 +0100 Subject: [PATCH 05/29] fixing codacy --- pina/model/avno.py | 2 +- pina/model/layers/avno_layer.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index 3cde5fa4..5bf6da3f 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -1,6 +1,6 @@ """Module Averaging Neural Operator.""" -from torch import nn, mean, concatenate +from torch import nn, concatenate from . import FeedForward from .layers import AVNOLayer diff --git a/pina/model/layers/avno_layer.py b/pina/model/layers/avno_layer.py index 5d919359..0ba1a6c5 100644 --- a/pina/model/layers/avno_layer.py +++ b/pina/model/layers/avno_layer.py @@ -1,4 +1,4 @@ -"""Module for Averaging Neural Operator class.""" +"""Module for Averaging Neural Operator Layer class.""" from torch import nn, mean @@ -17,4 +17,5 @@ def __init__(self, hidden_size, func): self.func = func def forward(self, batch): + """Forward pass of the layer.""" return self.func()(self.nn(batch) + mean(batch, dim=1).unsqueeze(1)) From 0cd36e07c0b88c745c44cbb8a29b5591b603d673 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Fri, 9 Feb 2024 17:11:32 +0100 Subject: [PATCH 06/29] added backward test --- tests/test_model/test_avno.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/test_model/test_avno.py b/tests/test_model/test_avno.py index 9d8c7b90..10c8f941 100644 --- a/tests/test_model/test_avno.py +++ b/tests/test_model/test_avno.py @@ -12,17 +12,35 @@ def test_constructor(): AVNO(input_channels, output_channels, torch.rand(10000, 2)) #all constructor - AVNO(input_channels, output_channels, torch.rand(100, 2), inner_size=5,n_layers=5,func=torch.nn.ReLU) - + AVNO(input_channels, + output_channels, + torch.rand(100, 2), + inner_size=5, + n_layers=5, + func=torch.nn.ReLU) def test_forward(): input_channels = 1 output_channels = 1 input_ = torch.rand(batch_size, 1000, input_channels) - points=torch.rand(1000,2) + points = torch.rand(1000, 2) ano = AVNO(input_channels, output_channels, points) out = ano(input_) - assert out.shape == torch.Size([batch_size, points.shape[0], output_channels]) + assert out.shape == torch.Size( + [batch_size, points.shape[0], output_channels]) +def test_backward(): + input_channels = 2 + output_channels = 1 + input_ = torch.rand(batch_size, 1000, input_channels) + input_ = input_.requires_grad_() + points = torch.rand(1000, 2) + ano = AVNO(input_channels, output_channels, points) + out = ano(input_) + tmp = torch.linalg.norm(out) + tmp.backward() + grad = input_.grad + assert grad.shape == torch.Size( + [batch_size, points.shape[0], input_channels]) From 2ac572424143603b59c6220b957b427b6c44bab0 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Fri, 9 Feb 2024 17:15:29 +0100 Subject: [PATCH 07/29] added backward test --- tests/test_model/test_avno.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_model/test_avno.py b/tests/test_model/test_avno.py index 10c8f941..024870e8 100644 --- a/tests/test_model/test_avno.py +++ b/tests/test_model/test_avno.py @@ -37,6 +37,7 @@ def test_backward(): input_ = torch.rand(batch_size, 1000, input_channels) input_ = input_.requires_grad_() points = torch.rand(1000, 2) + points = points.requires_grad_() ano = AVNO(input_channels, output_channels, points) out = ano(input_) tmp = torch.linalg.norm(out) @@ -44,3 +45,7 @@ def test_backward(): grad = input_.grad assert grad.shape == torch.Size( [batch_size, points.shape[0], input_channels]) + + grad = points.grad + print(grad.shape) + assert grad.shape == torch.Size([points.shape[0], 2]) From 8010a42e3586f61d988b47d7656d2ad4862e35d9 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Sat, 10 Feb 2024 14:47:03 +0100 Subject: [PATCH 08/29] changed structure to one similar to deeponet --- pina/model/avno.py | 45 ++++++++++++++++------------------- tests/test_model/test_avno.py | 28 +++++++++++----------- 2 files changed, 34 insertions(+), 39 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index 5bf6da3f..4877741a 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -5,26 +5,27 @@ from .layers import AVNOLayer - - class AVNO(nn.Module): """ The PINA implementation of the inner layer of the Averaging Neural Operator. :param int input_features: The number of input components of the model. :param int output_features: The number of output components of the model. - :param ndarray points: the points in the training set. + :param int points_size: the dimension of the domain of the functions. :param int inner_size: number of neurons in the hidden layer(s). Default is 100. :param int n_layers: number of hidden layers. Default is 4. :param func: the activation function to use. Default to nn.GELU. - + :param str features_label: the label of the features in the input tensor. Default to 'v'. + :param str points_label: the label of the points in the input tensor. Default to 'p'. """ def __init__( self, input_features, output_features, - points, + features_label='v', + points_label='p', + points_size=3, inner_size=100, n_layers=4, func=nn.GELU, @@ -33,8 +34,9 @@ def __init__( super().__init__() self.input_features = input_features self.output_features = output_features - self.num_points = points.shape[0] - self.points_size = points.shape[1] + self.points_size = points_size + self.points_label = points_label + self.features_label = features_label self.lifting = FeedForward(input_features + self.points_size, inner_size, inner_size, n_layers, func) self.nn = nn.Sequential( @@ -42,24 +44,8 @@ def __init__( self.projection = FeedForward(inner_size + self.points_size, output_features, inner_size, n_layers, func) - self.points = points def forward(self, batch): - """ - Computes the forward pass of the model with the points specified in init. - - :param torch.Tensor batch: the input tensor. - - """ - points_tmp = self.points.unsqueeze(0).repeat(batch.shape[0], 1, 1) - new_batch = concatenate((batch, points_tmp), dim=2) - new_batch = self.lifting(new_batch) - new_batch = self.nn(new_batch) - new_batch = concatenate((new_batch, points_tmp), dim=2) - new_batch = self.projection(new_batch) - return new_batch - - def forward_eval(self, batch, points): """ Computes the forward pass of the model with the points specified when calling the function. @@ -67,8 +53,17 @@ def forward_eval(self, batch, points): :param torch.Tensor points: the points tensor. """ - points_tmp = points.unsqueeze(0).repeat(batch.shape[0], 1, 1) - new_batch = concatenate((batch, points_tmp), dim=2) + points_tmp = concatenate([ + batch.extract(self.points_label + "_{}".format(i)) + for i in range(self.points_size) + ], + axis=2) + features_tmp = concatenate([ + batch.extract(self.features_label + "_{}".format(i)) + for i in range(self.input_features) + ], + axis=2) + new_batch = concatenate((features_tmp, points_tmp), dim=2) new_batch = self.lifting(new_batch) new_batch = self.nn(new_batch) new_batch = concatenate((new_batch, points_tmp), dim=2) diff --git a/tests/test_model/test_avno.py b/tests/test_model/test_avno.py index 024870e8..4c516c66 100644 --- a/tests/test_model/test_avno.py +++ b/tests/test_model/test_avno.py @@ -1,5 +1,6 @@ import torch from pina.model import AVNO +from pina import LabelTensor output_channels = 5 batch_size = 15 @@ -23,29 +24,28 @@ def test_constructor(): def test_forward(): input_channels = 1 output_channels = 1 - input_ = torch.rand(batch_size, 1000, input_channels) - points = torch.rand(1000, 2) - ano = AVNO(input_channels, output_channels, points) + points_size = 1 + input_ = LabelTensor( + torch.rand(batch_size, 1000, input_channels + points_size), + ['p_0', 'v_0']) + ano = AVNO(input_channels, output_channels, points_size=points_size) out = ano(input_) assert out.shape == torch.Size( - [batch_size, points.shape[0], output_channels]) + [batch_size, input_.shape[1], output_channels]) def test_backward(): - input_channels = 2 + input_channels = 1 + points_size = 1 output_channels = 1 - input_ = torch.rand(batch_size, 1000, input_channels) + input_ = LabelTensor( + torch.rand(batch_size, 1000, input_channels + points_size), + ['p_0', 'v_0']) input_ = input_.requires_grad_() - points = torch.rand(1000, 2) - points = points.requires_grad_() - ano = AVNO(input_channels, output_channels, points) + ano = AVNO(input_channels, output_channels, points_size=points_size) out = ano(input_) tmp = torch.linalg.norm(out) tmp.backward() grad = input_.grad assert grad.shape == torch.Size( - [batch_size, points.shape[0], input_channels]) - - grad = points.grad - print(grad.shape) - assert grad.shape == torch.Size([points.shape[0], 2]) + [batch_size, input_.shape[1], input_channels + points_size]) From 33fe9410e2dc300fc78e10fed16598f9534af9a6 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Sat, 17 Feb 2024 16:57:25 +0100 Subject: [PATCH 09/29] fixing codacy --- pina/model/avno.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index 4877741a..a90891cb 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -7,16 +7,20 @@ class AVNO(nn.Module): """ - The PINA implementation of the inner layer of the Averaging Neural Operator. + The PINA implementation of the inner layer + of the Averaging Neural Operator. :param int input_features: The number of input components of the model. :param int output_features: The number of output components of the model. :param int points_size: the dimension of the domain of the functions. - :param int inner_size: number of neurons in the hidden layer(s). Default is 100. + :param int inner_size: number of neurons in the hidden layer(s). + Default is 100. :param int n_layers: number of hidden layers. Default is 4. :param func: the activation function to use. Default to nn.GELU. - :param str features_label: the label of the features in the input tensor. Default to 'v'. - :param str points_label: the label of the points in the input tensor. Default to 'p'. + :param str features_label: the label of the features in the input tensor. + Default to 'v'. + :param str points_label: the label of the points in the input tensor. + Default to 'p'. """ def __init__( @@ -47,19 +51,20 @@ def __init__( def forward(self, batch): """ - Computes the forward pass of the model with the points specified when calling the function. + Computes the forward pass of the model with the points + specified when calling the function. :param torch.Tensor batch: the input tensor. :param torch.Tensor points: the points tensor. """ points_tmp = concatenate([ - batch.extract(self.points_label + "_{}".format(i)) + batch.extract(f"{self.points_label}_{i}") for i in range(self.points_size) ], axis=2) features_tmp = concatenate([ - batch.extract(self.features_label + "_{}".format(i)) + batch.extract(f"{self.features_label}_{i}") for i in range(self.input_features) ], axis=2) @@ -69,3 +74,15 @@ def forward(self, batch): new_batch = concatenate((new_batch, points_tmp), dim=2) new_batch = self.projection(new_batch) return new_batch + + @property + def lifting(self): + return self.lifting + + @property + def nn(self): + return self.nn + + @property + def projection(self): + return self.projection \ No newline at end of file From 734f3ff3b02de9212230ef23f27e1370d7040fc4 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Sat, 17 Feb 2024 17:00:41 +0100 Subject: [PATCH 10/29] fixing property --- pina/model/avno.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index a90891cb..daadb0a3 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -41,11 +41,11 @@ def __init__( self.points_size = points_size self.points_label = points_label self.features_label = features_label - self.lifting = FeedForward(input_features + self.points_size, + self._lifting = FeedForward(input_features + self.points_size, inner_size, inner_size, n_layers, func) - self.nn = nn.Sequential( + self._nn = nn.Sequential( *[AVNOLayer(inner_size, func) for _ in range(n_layers)]) - self.projection = FeedForward(inner_size + self.points_size, + self._projection = FeedForward(inner_size + self.points_size, output_features, inner_size, n_layers, func) @@ -69,20 +69,20 @@ def forward(self, batch): ], axis=2) new_batch = concatenate((features_tmp, points_tmp), dim=2) - new_batch = self.lifting(new_batch) - new_batch = self.nn(new_batch) + new_batch = self._lifting(new_batch) + new_batch = self._nn(new_batch) new_batch = concatenate((new_batch, points_tmp), dim=2) - new_batch = self.projection(new_batch) + new_batch = self._projection(new_batch) return new_batch @property def lifting(self): - return self.lifting + return self._lifting @property def nn(self): - return self.nn + return self._nn @property def projection(self): - return self.projection \ No newline at end of file + return self._projection \ No newline at end of file From 3aafb1d3ce174a33f6974cd987086aa76f917dec Mon Sep 17 00:00:00 2001 From: cyberguli Date: Sat, 17 Feb 2024 17:25:02 +0100 Subject: [PATCH 11/29] codacy issues, converted avno_layer to dataclass --- pina/model/avno.py | 7 +++++-- pina/model/layers/__init__.py | 1 - pina/model/layers/avno_layer.py | 17 +++++++++++------ 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index daadb0a3..ec2289e0 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -64,7 +64,7 @@ def forward(self, batch): ], axis=2) features_tmp = concatenate([ - batch.extract(f"{self.features_label}_{i}") + batch.extract(f"{self.features_label}_{i}") for i in range(self.input_features) ], axis=2) @@ -77,12 +77,15 @@ def forward(self, batch): @property def lifting(self): + "Lifting operator of the AVNO" return self._lifting @property def nn(self): + "Integral operator of the AVNO" return self._nn @property def projection(self): - return self._projection \ No newline at end of file + "Projection operator of the AVNO" + return self._projection diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index fdf7c714..2b3b5210 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -24,4 +24,3 @@ from .pod import PODBlock from .embedding import PeriodicBoundaryEmbedding from .avno_layer import AVNOLayer - diff --git a/pina/model/layers/avno_layer.py b/pina/model/layers/avno_layer.py index 0ba1a6c5..996b61e4 100644 --- a/pina/model/layers/avno_layer.py +++ b/pina/model/layers/avno_layer.py @@ -1,20 +1,25 @@ """Module for Averaging Neural Operator Layer class.""" from torch import nn, mean +import dataclasses +from collections.abc import Callable - +@dataclasses.dataclass class AVNOLayer(nn.Module): """ - The PINA implementation of the inner layer of the Averaging Neural Operator . + The PINA implementation of the inner layer + of the Averaging Neural Operator . :param int hidden_size: size of the layer. :param func: the activation function to use. - """ - def __init__(self, hidden_size, func): + hidden_size: int + func: Callable + + def __post_init__(self): super().__init__() - self.nn = nn.Linear(hidden_size, hidden_size) - self.func = func + self.nn = nn.Linear(self.hidden_size, self.hidden_size) + def forward(self, batch): """Forward pass of the layer.""" From 2e9d2adc3157720680deb6114cefb44a585b9283 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Sat, 17 Feb 2024 17:32:02 +0100 Subject: [PATCH 12/29] reverting dataclass as only worsens things --- pina/model/layers/avno_layer.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/pina/model/layers/avno_layer.py b/pina/model/layers/avno_layer.py index 996b61e4..161d0e5a 100644 --- a/pina/model/layers/avno_layer.py +++ b/pina/model/layers/avno_layer.py @@ -1,9 +1,6 @@ """Module for Averaging Neural Operator Layer class.""" from torch import nn, mean -import dataclasses -from collections.abc import Callable -@dataclasses.dataclass class AVNOLayer(nn.Module): """ The PINA implementation of the inner layer @@ -12,15 +9,16 @@ class AVNOLayer(nn.Module): :param int hidden_size: size of the layer. :param func: the activation function to use. """ - - hidden_size: int - func: Callable - - def __post_init__(self): + def __init__(self, hidden_size: int, func=nn.GELU): super().__init__() + self.hidden_size = hidden_size self.nn = nn.Linear(self.hidden_size, self.hidden_size) def forward(self, batch): """Forward pass of the layer.""" return self.func()(self.nn(batch) + mean(batch, dim=1).unsqueeze(1)) + + def linear_component(self, batch): + """Linear component of the layer.""" + return self.nn(batch) From d5d6a2e5495c66621852a05f953f1c317dbcc49a Mon Sep 17 00:00:00 2001 From: cyberguli Date: Sat, 17 Feb 2024 17:35:07 +0100 Subject: [PATCH 13/29] added func in avno layer --- pina/model/layers/avno_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pina/model/layers/avno_layer.py b/pina/model/layers/avno_layer.py index 161d0e5a..cef9c9ee 100644 --- a/pina/model/layers/avno_layer.py +++ b/pina/model/layers/avno_layer.py @@ -13,7 +13,7 @@ def __init__(self, hidden_size: int, func=nn.GELU): super().__init__() self.hidden_size = hidden_size self.nn = nn.Linear(self.hidden_size, self.hidden_size) - + self.func = func def forward(self, batch): """Forward pass of the layer.""" From 4a79fc0ef03ed81c38b2a6c8547738b6a9da4d58 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Sat, 17 Feb 2024 17:43:48 +0100 Subject: [PATCH 14/29] trying to fix last codacy error --- pina/model/avno.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index ec2289e0..650972b3 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -79,7 +79,7 @@ def forward(self, batch): def lifting(self): "Lifting operator of the AVNO" return self._lifting - + @property def nn(self): "Integral operator of the AVNO" From 0e9416dcb5386d8dd53d2993a97ee9a61a42dc76 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Sat, 17 Feb 2024 17:47:48 +0100 Subject: [PATCH 15/29] deleted another trailing whitespace --- pina/model/avno.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index 650972b3..127a989e 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -84,7 +84,7 @@ def lifting(self): def nn(self): "Integral operator of the AVNO" return self._nn - + @property def projection(self): "Projection operator of the AVNO" From b7f5684e79e682ccdad3587a0774c75e3d354958 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Sat, 17 Feb 2024 17:54:43 +0100 Subject: [PATCH 16/29] Grammatic fixes --- pina/model/avno.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index 127a989e..26f5d0d1 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -14,13 +14,13 @@ class AVNO(nn.Module): :param int output_features: The number of output components of the model. :param int points_size: the dimension of the domain of the functions. :param int inner_size: number of neurons in the hidden layer(s). - Default is 100. + Defaults to 100. :param int n_layers: number of hidden layers. Default is 4. :param func: the activation function to use. Default to nn.GELU. :param str features_label: the label of the features in the input tensor. - Default to 'v'. + Defaults to 'v'. :param str points_label: the label of the points in the input tensor. - Default to 'p'. + Defaults to 'p'. """ def __init__( From d051161aee184bc39e99cb384a5c39cb712f1d10 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Sat, 17 Feb 2024 18:00:05 +0100 Subject: [PATCH 17/29] other grammatic fixes --- pina/model/layers/avno_layer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pina/model/layers/avno_layer.py b/pina/model/layers/avno_layer.py index cef9c9ee..ed4e73e5 100644 --- a/pina/model/layers/avno_layer.py +++ b/pina/model/layers/avno_layer.py @@ -7,9 +7,12 @@ class AVNOLayer(nn.Module): of the Averaging Neural Operator . :param int hidden_size: size of the layer. + Defaults to 100. :param func: the activation function to use. + Default to nn.GELU. + """ - def __init__(self, hidden_size: int, func=nn.GELU): + def __init__(self, hidden_size = 100, func = nn.GELU): super().__init__() self.hidden_size = hidden_size self.nn = nn.Linear(self.hidden_size, self.hidden_size) From 9cc389b8698fc40556251fe9edae80217017f6b1 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Wed, 21 Feb 2024 12:24:22 +0100 Subject: [PATCH 18/29] fixed typos and adapted AVNO to KernelNO --- pina/model/avno.py | 100 ++++++++++--------- pina/model/layers/__init__.py | 4 +- pina/model/layers/{avno_layer.py => avno.py} | 2 +- 3 files changed, 54 insertions(+), 52 deletions(-) rename pina/model/layers/{avno_layer.py => avno.py} (96%) diff --git a/pina/model/avno.py b/pina/model/avno.py index 26f5d0d1..a6d3a057 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -2,10 +2,11 @@ from torch import nn, concatenate from . import FeedForward -from .layers import AVNOLayer +from .layers import AVNOBlock +from .base_no import KernelNeuralOperator -class AVNO(nn.Module): +class AVNO(KernelNeuralOperator): """ The PINA implementation of the inner layer of the Averaging Neural Operator. @@ -35,57 +36,58 @@ def __init__( func=nn.GELU, ): - super().__init__() self.input_features = input_features self.output_features = output_features self.points_size = points_size self.points_label = points_label self.features_label = features_label - self._lifting = FeedForward(input_features + self.points_size, - inner_size, inner_size, n_layers, func) - self._nn = nn.Sequential( - *[AVNOLayer(inner_size, func) for _ in range(n_layers)]) - self._projection = FeedForward(inner_size + self.points_size, - output_features, inner_size, n_layers, - func) - def forward(self, batch): - """ - Computes the forward pass of the model with the points - specified when calling the function. + class Lifting_Net(nn.Module): + def __init__(self): + super(Lifting_Net, self).__init__() + self._lifting = FeedForward(input_features + points_size, + inner_size, inner_size, n_layers, func) + + def forward(self, batch): + points_tmp = concatenate([ + batch.extract(f"{points_label}_{i}") + for i in range(points_size) + ], + axis=2) + features_tmp = concatenate([ + batch.extract(f"{features_label}_{i}") + for i in range(input_features) + ], + axis=2) + new_batch = concatenate((features_tmp, points_tmp), dim=2) + new_batch = self._lifting(new_batch) + return [new_batch, points_tmp] + + class NN_Net(nn.Module): + def __init__(self): + super(NN_Net, self).__init__() + self._nn = nn.Sequential( + *[AVNOBlock(inner_size, func) for _ in range(n_layers)]) + + def forward(self, batch): + new_batch, points_tmp = batch + new_batch = self._nn(new_batch) + return [new_batch, points_tmp] + + class Projection_Net(nn.Module): + def __init__(self): + super(Projection_Net, self).__init__() + self._projection = FeedForward(inner_size + points_size, + output_features, inner_size, n_layers, + func) + + def forward(self, batch): + new_batch, points_tmp = batch + new_batch = concatenate((new_batch, points_tmp), dim=2) + new_batch = self._projection(new_batch) + return new_batch - :param torch.Tensor batch: the input tensor. - :param torch.Tensor points: the points tensor. - - """ - points_tmp = concatenate([ - batch.extract(f"{self.points_label}_{i}") - for i in range(self.points_size) - ], - axis=2) - features_tmp = concatenate([ - batch.extract(f"{self.features_label}_{i}") - for i in range(self.input_features) - ], - axis=2) - new_batch = concatenate((features_tmp, points_tmp), dim=2) - new_batch = self._lifting(new_batch) - new_batch = self._nn(new_batch) - new_batch = concatenate((new_batch, points_tmp), dim=2) - new_batch = self._projection(new_batch) - return new_batch - - @property - def lifting(self): - "Lifting operator of the AVNO" - return self._lifting - - @property - def nn(self): - "Integral operator of the AVNO" - return self._nn - - @property - def projection(self): - "Projection operator of the AVNO" - return self._projection + nn_net=NN_Net() + lifting_net=Lifting_Net() + projection_net=Projection_Net() + super().__init__(lifting_net, nn_net, projection_net) diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index 2b3b5210..7aa51b77 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -10,7 +10,7 @@ "FourierBlock3D", "PODBlock", "PeriodicBoundaryEmbedding", - "AVNOLayer", + "AVNOBlock", ] from .convolution_2d import ContinuousConvBlock @@ -23,4 +23,4 @@ from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D from .pod import PODBlock from .embedding import PeriodicBoundaryEmbedding -from .avno_layer import AVNOLayer +from .avno import AVNOBlock diff --git a/pina/model/layers/avno_layer.py b/pina/model/layers/avno.py similarity index 96% rename from pina/model/layers/avno_layer.py rename to pina/model/layers/avno.py index ed4e73e5..6342503d 100644 --- a/pina/model/layers/avno_layer.py +++ b/pina/model/layers/avno.py @@ -1,7 +1,7 @@ """Module for Averaging Neural Operator Layer class.""" from torch import nn, mean -class AVNOLayer(nn.Module): +class AVNOBlock(nn.Module): """ The PINA implementation of the inner layer of the Averaging Neural Operator . From ca77b47256441524d2cbc50c602286d1ff79f154 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Wed, 21 Feb 2024 12:47:59 +0100 Subject: [PATCH 19/29] other codacy fixes --- pina/model/avno.py | 151 +++++++++++++++++++++++++++++++-------------- 1 file changed, 103 insertions(+), 48 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index a6d3a057..c577fd10 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -5,6 +5,98 @@ from .layers import AVNOBlock from .base_no import KernelNeuralOperator +class Lifting_Net(nn.Module): + """ + The PINA implementation of the lifting layer of the AVNO + + :param int input_features: The number of input components of the model. + :param int points_size: the dimension of the domain of the functions. + :param int inner_size: number of neurons in the hidden layer(s). + :param int n_layers: number of hidden layers. Default is 4. + :param func: the activation function to use. Default to nn.GELU. + :param str features_label: the label of the features in the input tensor. + :param str points_label: the label of the points in the input tensor. + + """ + + def __init__(self, input_features, points_size, inner_size,n_layers, func, points_label, features_label): + super(Lifting_Net, self).__init__() + self._lifting = FeedForward(input_features + points_size, + inner_size, inner_size, n_layers, func) + self.points_size = points_size + self.inner_size = inner_size + self.input_features = input_features + self.points_label = points_label + self.features_label = features_label + + def forward(self, batch): + """Forward pass of the layer.""" + points_tmp = concatenate([ + batch.extract(f"{self.points_label}_{i}") + for i in range(self.points_size) + ], + axis=2) + features_tmp = concatenate([ + batch.extract(f"{self.features_label}_{i}") + for i in range(self.input_features) + ], + axis=2) + new_batch = concatenate((features_tmp, points_tmp), dim=2) + new_batch = self._lifting(new_batch) + return [new_batch, points_tmp] + + def net(self): + """Returns the net""" + return self._lifting + + +class Integral_Net(nn.Module): + def __init__(self, inner_size, n_layers, func): + super(Integral_Net, self).__init__() + self._nn = nn.Sequential( + *[AVNOBlock(inner_size, func) for _ in range(n_layers)]) + + def forward(self, batch): + """Forward pass of the layer.""" + + new_batch, points_tmp = batch + new_batch = self._nn(new_batch) + return [new_batch, points_tmp] + + def net(self): + """Returns the net""" + return self._nn + +class Projection_Net(nn.Module): + """ + The PINA implementation of the projection + layer of the AVNO. + + :param int inner_size: number of neurons in the hidden layer(s). + :param int points_size: the dimension of the domain of the functions. + :param int output_features: The number of output components of the model. + :param int n_layers: number of hidden layers. Default is 4. + :param func: the activation function to use. Default to nn.GELU. + + """ + + def __init__(self, inner_size, points_size, output_features, n_layers, func): + super(Projection_Net, self).__init__() + self._projection = FeedForward(inner_size + points_size, + output_features, inner_size, n_layers, + func) + + def forward(self, batch): + """Forward pass of the layer.""" + new_batch, points_tmp = batch + new_batch = concatenate((new_batch, points_tmp), dim=2) + new_batch = self._projection(new_batch) + return new_batch + + def net(self): + """Returns the net""" + return self._projection + class AVNO(KernelNeuralOperator): """ @@ -42,52 +134,15 @@ def __init__( self.points_label = points_label self.features_label = features_label - class Lifting_Net(nn.Module): - def __init__(self): - super(Lifting_Net, self).__init__() - self._lifting = FeedForward(input_features + points_size, - inner_size, inner_size, n_layers, func) - - def forward(self, batch): - points_tmp = concatenate([ - batch.extract(f"{points_label}_{i}") - for i in range(points_size) - ], - axis=2) - features_tmp = concatenate([ - batch.extract(f"{features_label}_{i}") - for i in range(input_features) - ], - axis=2) - new_batch = concatenate((features_tmp, points_tmp), dim=2) - new_batch = self._lifting(new_batch) - return [new_batch, points_tmp] - - class NN_Net(nn.Module): - def __init__(self): - super(NN_Net, self).__init__() - self._nn = nn.Sequential( - *[AVNOBlock(inner_size, func) for _ in range(n_layers)]) - - def forward(self, batch): - new_batch, points_tmp = batch - new_batch = self._nn(new_batch) - return [new_batch, points_tmp] - - class Projection_Net(nn.Module): - def __init__(self): - super(Projection_Net, self).__init__() - self._projection = FeedForward(inner_size + points_size, - output_features, inner_size, n_layers, - func) - - def forward(self, batch): - new_batch, points_tmp = batch - new_batch = concatenate((new_batch, points_tmp), dim=2) - new_batch = self._projection(new_batch) - return new_batch - - nn_net=NN_Net() - lifting_net=Lifting_Net() - projection_net=Projection_Net() + nn_net=Integral_Net(inner_size=inner_size, n_layers=n_layers, func=func) + lifting_net=Lifting_Net(input_features=input_features, + points_size=points_size, + inner_size=inner_size, + n_layers=n_layers, func=func, + points_label=points_label, + features_label=features_label) + projection_net=Projection_Net(inner_size=inner_size, + points_size=points_size, + output_features=output_features, + n_layers=n_layers, func=func) super().__init__(lifting_net, nn_net, projection_net) From 17019206873a73c171decfa53ee17ad5b84faa34 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Wed, 21 Feb 2024 12:53:56 +0100 Subject: [PATCH 20/29] pep8 --- pina/model/avno.py | 70 ++++++++++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index c577fd10..0b112c6c 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -5,6 +5,7 @@ from .layers import AVNOBlock from .base_no import KernelNeuralOperator + class Lifting_Net(nn.Module): """ The PINA implementation of the lifting layer of the AVNO @@ -18,11 +19,12 @@ class Lifting_Net(nn.Module): :param str points_label: the label of the points in the input tensor. """ - - def __init__(self, input_features, points_size, inner_size,n_layers, func, points_label, features_label): - super(Lifting_Net, self).__init__() - self._lifting = FeedForward(input_features + points_size, - inner_size, inner_size, n_layers, func) + + def __init__(self, input_features, points_size, inner_size, n_layers, func, + points_label, features_label): + super().__init__() + self._lifting = FeedForward(input_features + points_size, inner_size, + inner_size, n_layers, func) self.points_size = points_size self.inner_size = inner_size self.input_features = input_features @@ -35,12 +37,12 @@ def forward(self, batch): batch.extract(f"{self.points_label}_{i}") for i in range(self.points_size) ], - axis=2) + axis=2) features_tmp = concatenate([ batch.extract(f"{self.features_label}_{i}") for i in range(self.input_features) ], - axis=2) + axis=2) new_batch = concatenate((features_tmp, points_tmp), dim=2) new_batch = self._lifting(new_batch) return [new_batch, points_tmp] @@ -49,10 +51,19 @@ def net(self): """Returns the net""" return self._lifting - + class Integral_Net(nn.Module): + """ + The PINA implementation of the integral layer of the AVNO. + + :param int inner_size: number of neurons in the hidden layer(s). + :param int n_layers: number of hidden layers. Default is 4. + :param func: the activation function to use. Default to nn.GELU. + + """ + def __init__(self, inner_size, n_layers, func): - super(Integral_Net, self).__init__() + super().__init__() self._nn = nn.Sequential( *[AVNOBlock(inner_size, func) for _ in range(n_layers)]) @@ -66,7 +77,8 @@ def forward(self, batch): def net(self): """Returns the net""" return self._nn - + + class Projection_Net(nn.Module): """ The PINA implementation of the projection @@ -79,12 +91,13 @@ class Projection_Net(nn.Module): :param func: the activation function to use. Default to nn.GELU. """ - - def __init__(self, inner_size, points_size, output_features, n_layers, func): - super(Projection_Net, self).__init__() + + def __init__(self, inner_size, points_size, output_features, n_layers, + func): + super().__init__() self._projection = FeedForward(inner_size + points_size, - output_features, inner_size, n_layers, - func) + output_features, inner_size, n_layers, + func) def forward(self, batch): """Forward pass of the layer.""" @@ -133,16 +146,19 @@ def __init__( self.points_size = points_size self.points_label = points_label self.features_label = features_label - - nn_net=Integral_Net(inner_size=inner_size, n_layers=n_layers, func=func) - lifting_net=Lifting_Net(input_features=input_features, - points_size=points_size, - inner_size=inner_size, - n_layers=n_layers, func=func, - points_label=points_label, - features_label=features_label) - projection_net=Projection_Net(inner_size=inner_size, - points_size=points_size, - output_features=output_features, - n_layers=n_layers, func=func) + nn_net = Integral_Net(inner_size=inner_size, + n_layers=n_layers, + func=func) + lifting_net = Lifting_Net(input_features=input_features, + points_size=points_size, + inner_size=inner_size, + n_layers=n_layers, + func=func, + points_label=points_label, + features_label=features_label) + projection_net = Projection_Net(inner_size=inner_size, + points_size=points_size, + output_features=output_features, + n_layers=n_layers, + func=func) super().__init__(lifting_net, nn_net, projection_net) From 7cd694fb696e7997d87846435a38eab6d3b1b6de Mon Sep 17 00:00:00 2001 From: cyberguli Date: Wed, 21 Feb 2024 13:06:00 +0100 Subject: [PATCH 21/29] Testing a maybe fake cyclic import --- pina/model/layers/avno.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pina/model/layers/avno.py b/pina/model/layers/avno.py index 6342503d..b9f8a6d4 100644 --- a/pina/model/layers/avno.py +++ b/pina/model/layers/avno.py @@ -1,4 +1,5 @@ """Module for Averaging Neural Operator Layer class.""" + from torch import nn, mean class AVNOBlock(nn.Module): From 213383648897150640a81a3d0090f7170f6f6768 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Mon, 4 Mar 2024 10:15:51 +0100 Subject: [PATCH 22/29] various fixes --- pina/model/avno.py | 154 ++++++----------------------- pina/model/layers/avno.py | 11 ++- tests/test_model/test_avno.py | 28 ++++-- tutorials/tutorial10/tutorial.py | 160 ++++++++++++++++++------------- 4 files changed, 149 insertions(+), 204 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index 0b112c6c..98312ccb 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -6,111 +6,6 @@ from .base_no import KernelNeuralOperator -class Lifting_Net(nn.Module): - """ - The PINA implementation of the lifting layer of the AVNO - - :param int input_features: The number of input components of the model. - :param int points_size: the dimension of the domain of the functions. - :param int inner_size: number of neurons in the hidden layer(s). - :param int n_layers: number of hidden layers. Default is 4. - :param func: the activation function to use. Default to nn.GELU. - :param str features_label: the label of the features in the input tensor. - :param str points_label: the label of the points in the input tensor. - - """ - - def __init__(self, input_features, points_size, inner_size, n_layers, func, - points_label, features_label): - super().__init__() - self._lifting = FeedForward(input_features + points_size, inner_size, - inner_size, n_layers, func) - self.points_size = points_size - self.inner_size = inner_size - self.input_features = input_features - self.points_label = points_label - self.features_label = features_label - - def forward(self, batch): - """Forward pass of the layer.""" - points_tmp = concatenate([ - batch.extract(f"{self.points_label}_{i}") - for i in range(self.points_size) - ], - axis=2) - features_tmp = concatenate([ - batch.extract(f"{self.features_label}_{i}") - for i in range(self.input_features) - ], - axis=2) - new_batch = concatenate((features_tmp, points_tmp), dim=2) - new_batch = self._lifting(new_batch) - return [new_batch, points_tmp] - - def net(self): - """Returns the net""" - return self._lifting - - -class Integral_Net(nn.Module): - """ - The PINA implementation of the integral layer of the AVNO. - - :param int inner_size: number of neurons in the hidden layer(s). - :param int n_layers: number of hidden layers. Default is 4. - :param func: the activation function to use. Default to nn.GELU. - - """ - - def __init__(self, inner_size, n_layers, func): - super().__init__() - self._nn = nn.Sequential( - *[AVNOBlock(inner_size, func) for _ in range(n_layers)]) - - def forward(self, batch): - """Forward pass of the layer.""" - - new_batch, points_tmp = batch - new_batch = self._nn(new_batch) - return [new_batch, points_tmp] - - def net(self): - """Returns the net""" - return self._nn - - -class Projection_Net(nn.Module): - """ - The PINA implementation of the projection - layer of the AVNO. - - :param int inner_size: number of neurons in the hidden layer(s). - :param int points_size: the dimension of the domain of the functions. - :param int output_features: The number of output components of the model. - :param int n_layers: number of hidden layers. Default is 4. - :param func: the activation function to use. Default to nn.GELU. - - """ - - def __init__(self, inner_size, points_size, output_features, n_layers, - func): - super().__init__() - self._projection = FeedForward(inner_size + points_size, - output_features, inner_size, n_layers, - func) - - def forward(self, batch): - """Forward pass of the layer.""" - new_batch, points_tmp = batch - new_batch = concatenate((new_batch, points_tmp), dim=2) - new_batch = self._projection(new_batch) - return new_batch - - def net(self): - """Returns the net""" - return self._projection - - class AVNO(KernelNeuralOperator): """ The PINA implementation of the inner layer @@ -133,8 +28,8 @@ def __init__( self, input_features, output_features, - features_label='v', - points_label='p', + field_indices, + coordinates_indices, points_size=3, inner_size=100, n_layers=4, @@ -144,21 +39,30 @@ def __init__( self.input_features = input_features self.output_features = output_features self.points_size = points_size - self.points_label = points_label - self.features_label = features_label - nn_net = Integral_Net(inner_size=inner_size, - n_layers=n_layers, - func=func) - lifting_net = Lifting_Net(input_features=input_features, - points_size=points_size, - inner_size=inner_size, - n_layers=n_layers, - func=func, - points_label=points_label, - features_label=features_label) - projection_net = Projection_Net(inner_size=inner_size, - points_size=points_size, - output_features=output_features, - n_layers=n_layers, - func=func) - super().__init__(lifting_net, nn_net, projection_net) + self.coordinates_indices = coordinates_indices + self.field_indices = field_indices + integral_net = nn.Sequential( + *[AVNOBlock(inner_size, func) for _ in range(n_layers)]) + lifting_net = FeedForward(input_features + points_size, inner_size, + inner_size, n_layers, func) + projection_net = FeedForward(inner_size + points_size, output_features, + inner_size, n_layers, func) + super().__init__(lifting_net, integral_net, projection_net) + + def forward(self, batch): + points_tmp = concatenate([ + batch.extract(f"{self.coordinates_indices}_{i}") + for i in range(self.points_size) + ], + axis=2) + features_tmp = concatenate([ + batch.extract(f"{self.field_indices}_{i}") + for i in range(self.input_features) + ], + axis=2) + new_batch = concatenate((features_tmp, points_tmp), dim=2) + new_batch = self._lifting_operator(new_batch) + new_batch = self._integral_kernels(new_batch) + new_batch = concatenate((new_batch, points_tmp), dim=2) + new_batch = self._projection_operator(new_batch) + return new_batch diff --git a/pina/model/layers/avno.py b/pina/model/layers/avno.py index b9f8a6d4..0695fdbe 100644 --- a/pina/model/layers/avno.py +++ b/pina/model/layers/avno.py @@ -2,6 +2,7 @@ from torch import nn, mean + class AVNOBlock(nn.Module): """ The PINA implementation of the inner layer @@ -13,15 +14,15 @@ class AVNOBlock(nn.Module): Default to nn.GELU. """ - def __init__(self, hidden_size = 100, func = nn.GELU): + + def __init__(self, hidden_size=100, func=nn.GELU): super().__init__() - self.hidden_size = hidden_size - self.nn = nn.Linear(self.hidden_size, self.hidden_size) - self.func = func + self.nn = nn.Linear(hidden_size, hidden_size) + self.func = func() def forward(self, batch): """Forward pass of the layer.""" - return self.func()(self.nn(batch) + mean(batch, dim=1).unsqueeze(1)) + return self.func(self.nn(batch) + mean(batch, dim=1, keepdim=True)) def linear_component(self, batch): """Linear component of the layer.""" diff --git a/tests/test_model/test_avno.py b/tests/test_model/test_avno.py index 4c516c66..a30c8cd5 100644 --- a/tests/test_model/test_avno.py +++ b/tests/test_model/test_avno.py @@ -10,15 +10,19 @@ def test_constructor(): input_channels = 1 output_channels = 1 #minimuum constructor - AVNO(input_channels, output_channels, torch.rand(10000, 2)) + AVNO(input_channels, + output_channels, + coordinates_indices='p', + field_indices='v') #all constructor AVNO(input_channels, output_channels, - torch.rand(100, 2), inner_size=5, n_layers=5, - func=torch.nn.ReLU) + func=torch.nn.ReLU, + coordinates_indices='p', + field_indices='v') def test_forward(): @@ -28,7 +32,11 @@ def test_forward(): input_ = LabelTensor( torch.rand(batch_size, 1000, input_channels + points_size), ['p_0', 'v_0']) - ano = AVNO(input_channels, output_channels, points_size=points_size) + ano = AVNO(input_channels, + output_channels, + points_size=points_size, + coordinates_indices='p', + field_indices='v') out = ano(input_) assert out.shape == torch.Size( [batch_size, input_.shape[1], output_channels]) @@ -39,13 +47,17 @@ def test_backward(): points_size = 1 output_channels = 1 input_ = LabelTensor( - torch.rand(batch_size, 1000, input_channels + points_size), + torch.rand(batch_size, 1000, points_size + input_channels), ['p_0', 'v_0']) input_ = input_.requires_grad_() - ano = AVNO(input_channels, output_channels, points_size=points_size) - out = ano(input_) + avno = AVNO(input_channels, + output_channels, + points_size=points_size, + coordinates_indices='p', + field_indices='v') + out = avno(input_) tmp = torch.linalg.norm(out) tmp.backward() grad = input_.grad assert grad.shape == torch.Size( - [batch_size, input_.shape[1], input_channels + points_size]) + [batch_size, input_.shape[1], points_size + input_channels]) diff --git a/tutorials/tutorial10/tutorial.py b/tutorials/tutorial10/tutorial.py index 7f5b41cd..80dfaaeb 100644 --- a/tutorials/tutorial10/tutorial.py +++ b/tutorials/tutorial10/tutorial.py @@ -1,84 +1,112 @@ -import torch +"""Tutorial of the Averaging Neural Operator.""" + from time import time +import torch +from pina.model import AVNO +from pina import Condition, LabelTensor +from pina.problem import AbstractProblem +from pina.solvers import SupervisedSolver +from pina.trainer import Trainer + #Data generation torch.manual_seed(0) + def sample_unit_circle(num_points): - radius=torch.rand(num_points,1) - angle=torch.rand(num_points,1)*2*torch.pi - x=radius*torch.cos(angle) - y=radius*torch.sin(angle) - data=torch.cat((x,y),dim=1) + radius = torch.rand(num_points, 1) + angle = torch.rand(num_points, 1) * 2 * torch.pi + x = radius * torch.cos(angle) + y = radius * torch.sin(angle) + data = torch.cat((x, y), dim=1) return data + #sin(a*x+b*y) -def compute_input(data,theta): - data=data.reshape(1,-1,2) - z=torch.sin(theta[:,:,0]*data[:,:,0]+theta[:,:,1]*data[:,:,1]) +def compute_input(data, theta): + data = data.reshape(1, -1, 2) + z = torch.sin(theta[:, :, 0] * data[:, :, 0] + + theta[:, :, 1] * data[:, :, 1]) return z -#1+convolution of sin(a*x+b*y) with sin(x) over [0,2pi]x[0,2pi ] -def compute_output(data,theta): - data=data.reshape(1,-1,2) - z=1-4*torch.sin(torch.pi*theta[:,:,0])*torch.sin(torch.pi*theta[:,:,1])*torch.cos(theta[:,:,0]*(torch.pi*data[:,:,0])+theta[:,:,1]*(torch.pi*data[:,:,1]))/((theta[:,:,0]**2-1)*theta[:,:,1]) + +#1+convolution of sin(a*x+b*y) with sin(x) over [0,2pi]x[0,2pi] +def compute_output(data, theta): + data = data.reshape(1, -1, 2) + s = torch.cos(theta[:, :, 0] * (torch.pi * data[:, :, 0]) + theta[:, :, 1] * + (torch.pi * data[:, :, 1])) + z = 1 - 4 * torch.sin(torch.pi * theta[:, :, 0]) * torch.sin( + torch.pi * theta[:, :, 1]) * s / ( + (theta[:, :, 0]**2 - 1) * theta[:, :, 1]) return z +theta_dataset = 1 + 0.01 * torch.rand(300, 1, 2) +data_coarse = sample_unit_circle(1000) +output_coarse = compute_output(data_coarse, theta_dataset).unsqueeze(-1) +input_coarse = compute_input(data_coarse, theta_dataset).unsqueeze(-1) +data_dense = sample_unit_circle(1000) +output_dense = compute_output(data_dense, theta_dataset).unsqueeze(-1) +input_dense = compute_input(data_dense, theta_dataset).unsqueeze(-1) -theta=1+0.01*torch.rand(300,1,2) -data_coarse=sample_unit_circle(1000) -output_coarse=compute_output(data_coarse,theta).unsqueeze(-1) -input_coarse=compute_input(data_coarse,theta).unsqueeze(-1) -data_dense=sample_unit_circle(1000) -output_dense=compute_output(data_dense,theta).unsqueeze(-1) -input_dense=compute_input(data_dense,theta).unsqueeze(-1) +data_coarse = data_coarse.unsqueeze(0).repeat(300, 1, 1) +data_dense = data_dense.unsqueeze(0).repeat(300, 1, 1) +x_coarse = LabelTensor(torch.concatenate((data_coarse, input_coarse), axis=2), + ['p_0', 'p_1', 'v_0']) +x_dense = LabelTensor(torch.concatenate((data_dense, input_dense), axis=2), + ['p_0', 'p_1', 'v_0']) -from pina.model import AVNO -from pina import Condition,LabelTensor -from pina.problem import AbstractProblem -from pina.solvers import SupervisedSolver -from pina.trainer import Trainer +print(x_coarse.shape) +model = AVNO(input_features=1, + output_features=1, + inner_size=500, + n_layers=4, + points_size=2,field_indices='v', coordinates_indices='p') -model=AVNO(1,1,data_coarse,inner_size=500,n_layers=4) -class ANOSolver(AbstractProblem): - input_variables=['input'] - input_points=LabelTensor(input_coarse,input_variables) - output_variables=['output'] - output_points=LabelTensor(output_coarse,output_variables) - conditions={"data":Condition(input_points=input_points,output_points=output_points)} - -batch_size=1 -problem=ANOSolver() -solver=SupervisedSolver(problem,model,optimizer_kwargs={'lr':1e-3},optimizer=torch.optim.AdamW) -trainer=Trainer(solver=solver,max_epochs=5,accelerator='cpu',enable_model_summary=False,batch_size=batch_size) -from pina.loss import LpLoss -loss=LpLoss(2,relative=True) - -start_time=time() + +class ANOProblem(AbstractProblem): + input_variables = ['p_0', 'p_1', 'v_0'] + input_points = x_coarse + output_variables = ['output'] + output_points = LabelTensor(output_coarse, output_variables) + conditions = { + "data": Condition(input_points=input_points, + output_points=output_points) + } + +batch_size = 1 +problem = ANOProblem() +solver = SupervisedSolver(problem, + model, + optimizer_kwargs={'lr': 1e-3}, + optimizer=torch.optim.AdamW) +trainer = Trainer(solver=solver, + max_epochs=5, + accelerator='cpu', + enable_model_summary=False, + batch_size=batch_size) + +start_time = time() trainer.train() -end_time=time() -print(end_time-start_time) -solver.neural_net=solver.neural_net.eval() -loss=torch.nn.MSELoss() -num_batches=len(input_coarse)//batch_size -num=0 -dem=0 -for i in range(num_batches): - input_variables=['input'] - myinput=LabelTensor(input_coarse[i].unsqueeze(0),input_variables) - tmp=model(myinput).detach().squeeze(0) - num=num+torch.linalg.norm(tmp-output_coarse[i])**2 - dem=dem+torch.linalg.norm(output_coarse[i])**2 -print("Training mse loss is", torch.sqrt(num/dem)) - - -num=0 -dem=0 -for i in range(num_batches): - input_variables=['input'] - myinput=LabelTensor(input_dense[i].unsqueeze(0),input_variables) - tmp=model.forward_eval(myinput,data_dense).detach().squeeze(0) - num=num+torch.linalg.norm(tmp-output_dense[i])**2 - dem=dem+torch.linalg.norm(output_dense[i])**2 -print("Super Resolution mse loss is", torch.sqrt(num/dem)) +end_time = time() +print(end_time - start_time) +solver.neural_net = solver.neural_net.eval() +num_batches = len(input_coarse) // batch_size +num = 0 +dem = 0 + +for i in range(300): + input_tmp = LabelTensor(x_coarse[i].unsqueeze(0), ['p_0', 'p_1', 'v_0']) + tmp = model(input_tmp).detach().squeeze(0) + num = num + torch.linalg.norm(tmp - output_coarse[i])**2 + dem = dem + torch.linalg.norm(output_coarse[i])**2 +print("Training mse loss is", torch.sqrt(num / dem)) + +num = 0 +dem = 0 +for i in range(300): + input_tmp = LabelTensor(x_coarse[i].unsqueeze(0), ['p_0', 'p_1', 'v_0']) + tmp = model(input_tmp).detach().squeeze(0) + num = num + torch.linalg.norm(tmp - output_dense[i])**2 + dem = dem + torch.linalg.norm(output_dense[i])**2 +print("Super Resolution mse loss is", torch.sqrt(num / dem)) \ No newline at end of file From d60f7b25f7b56141236eabb73098ca5c78411582 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Mon, 4 Mar 2024 11:34:14 +0100 Subject: [PATCH 23/29] various fixes --- pina/model/avno.py | 67 ++++++++++---------- pina/model/layers/__init__.py | 2 +- pina/model/layers/{avno.py => avno_layer.py} | 6 +- tests/test_model/test_avno.py | 64 +++++++++---------- tutorials/tutorial10/tutorial.py | 21 +++--- 5 files changed, 80 insertions(+), 80 deletions(-) rename pina/model/layers/{avno.py => avno_layer.py} (84%) diff --git a/pina/model/avno.py b/pina/model/avno.py index 98312ccb..a2f596ce 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -8,58 +8,61 @@ class AVNO(KernelNeuralOperator): """ - The PINA implementation of the inner layer - of the Averaging Neural Operator. + Implementation of Averaging Neural Operator. - :param int input_features: The number of input components of the model. - :param int output_features: The number of output components of the model. - :param int points_size: the dimension of the domain of the functions. - :param int inner_size: number of neurons in the hidden layer(s). - Defaults to 100. - :param int n_layers: number of hidden layers. Default is 4. - :param func: the activation function to use. Default to nn.GELU. - :param str features_label: the label of the features in the input tensor. - Defaults to 'v'. - :param str points_label: the label of the points in the input tensor. - Defaults to 'p'. + This class implements the Averaging Neural Operator. + + .. seealso:: + + **Original reference**: Lanthaler S. Li, Z., Kovachki, + Stuart, A. + (2020). *The Nonlocal Neural Operator: + Universal Approximation*. + DOI: `arXiv preprint arXiv:2304.13221. + `_ """ def __init__( self, - input_features, - output_features, + input_numb_fields, + output_numb_fields, field_indices, coordinates_indices, - points_size=3, + dimension=3, inner_size=100, n_layers=4, func=nn.GELU, ): + """ + :param int input_features: The number of input components of the model. + :param int output_features: The number of output components of the model. + :param int points_size: the dimension of the domain of the functions. + :param int inner_size: number of neurons in the hidden layer(s). + Defaults to 100. + :param int n_layers: number of hidden layers. Default is 4. + :param func: the activation function to use. Default to nn.GELU. + :param str features_label: the label of the features in the input tensor. + Defaults to 'v'. + :param str points_label: the label of the points in the input tensor. + Defaults to 'p'. + """ - self.input_features = input_features - self.output_features = output_features - self.points_size = points_size + self.input_numb_fields = input_numb_fields + self.output_numb_fields = output_numb_fields + self.dimension = dimension self.coordinates_indices = coordinates_indices self.field_indices = field_indices integral_net = nn.Sequential( *[AVNOBlock(inner_size, func) for _ in range(n_layers)]) - lifting_net = FeedForward(input_features + points_size, inner_size, + lifting_net = FeedForward(dimension + input_numb_fields, inner_size, inner_size, n_layers, func) - projection_net = FeedForward(inner_size + points_size, output_features, + projection_net = FeedForward(inner_size + dimension, output_numb_fields, inner_size, n_layers, func) super().__init__(lifting_net, integral_net, projection_net) - def forward(self, batch): - points_tmp = concatenate([ - batch.extract(f"{self.coordinates_indices}_{i}") - for i in range(self.points_size) - ], - axis=2) - features_tmp = concatenate([ - batch.extract(f"{self.field_indices}_{i}") - for i in range(self.input_features) - ], - axis=2) + def forward(self, x): + points_tmp = x.extract(self.coordinates_indices) + features_tmp = x.extract(self.field_indices) new_batch = concatenate((features_tmp, points_tmp), dim=2) new_batch = self._lifting_operator(new_batch) new_batch = self._integral_kernels(new_batch) diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index 7aa51b77..0ebe2f91 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -23,4 +23,4 @@ from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D from .pod import PODBlock from .embedding import PeriodicBoundaryEmbedding -from .avno import AVNOBlock +from .avno_block import AVNOBlock diff --git a/pina/model/layers/avno.py b/pina/model/layers/avno_layer.py similarity index 84% rename from pina/model/layers/avno.py rename to pina/model/layers/avno_layer.py index 0695fdbe..7b098d45 100644 --- a/pina/model/layers/avno.py +++ b/pina/model/layers/avno_layer.py @@ -22,8 +22,4 @@ def __init__(self, hidden_size=100, func=nn.GELU): def forward(self, batch): """Forward pass of the layer.""" - return self.func(self.nn(batch) + mean(batch, dim=1, keepdim=True)) - - def linear_component(self, batch): - """Linear component of the layer.""" - return self.nn(batch) + return self.func(self.nn(batch) + mean(batch, dim=1, keepdim=True)) \ No newline at end of file diff --git a/tests/test_model/test_avno.py b/tests/test_model/test_avno.py index a30c8cd5..10302d2b 100644 --- a/tests/test_model/test_avno.py +++ b/tests/test_model/test_avno.py @@ -2,62 +2,60 @@ from pina.model import AVNO from pina import LabelTensor -output_channels = 5 +output_numb_fields = 5 batch_size = 15 def test_constructor(): - input_channels = 1 - output_channels = 1 + input_numb_fields = 1 + output_numb_fields = 1 #minimuum constructor - AVNO(input_channels, - output_channels, - coordinates_indices='p', - field_indices='v') + AVNO(input_numb_fields, + output_numb_fields, + coordinates_indices=['p'], + field_indices=['v']) #all constructor - AVNO(input_channels, - output_channels, + AVNO(input_numb_fields, + output_numb_fields, inner_size=5, n_layers=5, func=torch.nn.ReLU, - coordinates_indices='p', - field_indices='v') + coordinates_indices=['p'], + field_indices=['v']) def test_forward(): - input_channels = 1 - output_channels = 1 - points_size = 1 + input_numb_fields = 1 + output_numb_fields = 1 + dimension = 1 input_ = LabelTensor( - torch.rand(batch_size, 1000, input_channels + points_size), - ['p_0', 'v_0']) - ano = AVNO(input_channels, - output_channels, - points_size=points_size, - coordinates_indices='p', - field_indices='v') + torch.rand(batch_size, 1000, input_numb_fields + dimension), ['p', 'v']) + ano = AVNO(input_numb_fields, + output_numb_fields, + dimension=dimension, + coordinates_indices=['p'], + field_indices=['v']) out = ano(input_) assert out.shape == torch.Size( - [batch_size, input_.shape[1], output_channels]) + [batch_size, input_.shape[1], output_numb_fields]) def test_backward(): - input_channels = 1 - points_size = 1 - output_channels = 1 + input_numb_fields = 1 + dimension = 1 + output_numb_fields = 1 input_ = LabelTensor( - torch.rand(batch_size, 1000, points_size + input_channels), - ['p_0', 'v_0']) + torch.rand(batch_size, 1000, dimension + input_numb_fields), ['p', 'v']) input_ = input_.requires_grad_() - avno = AVNO(input_channels, - output_channels, - points_size=points_size, - coordinates_indices='p', - field_indices='v') + avno = AVNO(input_numb_fields, + output_numb_fields, + dimension=dimension, + coordinates_indices=['p'], + field_indices=['v']) out = avno(input_) tmp = torch.linalg.norm(out) tmp.backward() grad = input_.grad assert grad.shape == torch.Size( - [batch_size, input_.shape[1], points_size + input_channels]) + [batch_size, input_.shape[1], dimension + input_numb_fields]) diff --git a/tutorials/tutorial10/tutorial.py b/tutorials/tutorial10/tutorial.py index 80dfaaeb..a46ac751 100644 --- a/tutorials/tutorial10/tutorial.py +++ b/tutorials/tutorial10/tutorial.py @@ -52,20 +52,22 @@ def compute_output(data, theta): data_coarse = data_coarse.unsqueeze(0).repeat(300, 1, 1) data_dense = data_dense.unsqueeze(0).repeat(300, 1, 1) x_coarse = LabelTensor(torch.concatenate((data_coarse, input_coarse), axis=2), - ['p_0', 'p_1', 'v_0']) + ['x', 'y', 'v']) x_dense = LabelTensor(torch.concatenate((data_dense, input_dense), axis=2), - ['p_0', 'p_1', 'v_0']) + ['x', 'y', 'v']) print(x_coarse.shape) -model = AVNO(input_features=1, - output_features=1, +model = AVNO(input_numb_fields=1, + output_numb_fields=1, inner_size=500, n_layers=4, - points_size=2,field_indices='v', coordinates_indices='p') + dimension=2, + field_indices=['v'], + coordinates_indices=['x', 'y']) class ANOProblem(AbstractProblem): - input_variables = ['p_0', 'p_1', 'v_0'] + input_variables = ['x', 'y', 'v'] input_points = x_coarse output_variables = ['output'] output_points = LabelTensor(output_coarse, output_variables) @@ -74,6 +76,7 @@ class ANOProblem(AbstractProblem): output_points=output_points) } + batch_size = 1 problem = ANOProblem() solver = SupervisedSolver(problem, @@ -96,7 +99,7 @@ class ANOProblem(AbstractProblem): dem = 0 for i in range(300): - input_tmp = LabelTensor(x_coarse[i].unsqueeze(0), ['p_0', 'p_1', 'v_0']) + input_tmp = LabelTensor(x_coarse[i].unsqueeze(0), ['x', 'y', 'v']) tmp = model(input_tmp).detach().squeeze(0) num = num + torch.linalg.norm(tmp - output_coarse[i])**2 dem = dem + torch.linalg.norm(output_coarse[i])**2 @@ -105,8 +108,8 @@ class ANOProblem(AbstractProblem): num = 0 dem = 0 for i in range(300): - input_tmp = LabelTensor(x_coarse[i].unsqueeze(0), ['p_0', 'p_1', 'v_0']) + input_tmp = LabelTensor(x_coarse[i].unsqueeze(0), ['x', 'y', 'v']) tmp = model(input_tmp).detach().squeeze(0) num = num + torch.linalg.norm(tmp - output_dense[i])**2 dem = dem + torch.linalg.norm(output_dense[i])**2 -print("Super Resolution mse loss is", torch.sqrt(num / dem)) \ No newline at end of file +print("Super Resolution mse loss is", torch.sqrt(num / dem)) From 5453e83f192fb6a4fdea0bdbc3e7469d5b538a26 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Mon, 4 Mar 2024 11:36:32 +0100 Subject: [PATCH 24/29] fixed typo --- pina/model/layers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index 0ebe2f91..2086e7a3 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -23,4 +23,4 @@ from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D from .pod import PODBlock from .embedding import PeriodicBoundaryEmbedding -from .avno_block import AVNOBlock +from .avno_layer import AVNOBlock From fe318dff1f2bb107d84d531e90b0137ac412d811 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Mon, 4 Mar 2024 11:41:55 +0100 Subject: [PATCH 25/29] fixing the fixable codacy --- pina/model/avno.py | 6 ++++-- pina/model/layers/avno_layer.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index a2f596ce..822d66e2 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -35,13 +35,15 @@ def __init__( ): """ :param int input_features: The number of input components of the model. - :param int output_features: The number of output components of the model. + :param int output_features: The number of output components + of the model. :param int points_size: the dimension of the domain of the functions. :param int inner_size: number of neurons in the hidden layer(s). Defaults to 100. :param int n_layers: number of hidden layers. Default is 4. :param func: the activation function to use. Default to nn.GELU. - :param str features_label: the label of the features in the input tensor. + :param str features_label: the label of the features + in the input tensor. Defaults to 'v'. :param str points_label: the label of the points in the input tensor. Defaults to 'p'. diff --git a/pina/model/layers/avno_layer.py b/pina/model/layers/avno_layer.py index 7b098d45..2a17a8f9 100644 --- a/pina/model/layers/avno_layer.py +++ b/pina/model/layers/avno_layer.py @@ -17,9 +17,9 @@ class AVNOBlock(nn.Module): def __init__(self, hidden_size=100, func=nn.GELU): super().__init__() - self.nn = nn.Linear(hidden_size, hidden_size) - self.func = func() + self._nn = nn.Linear(hidden_size, hidden_size) + self._func = func() def forward(self, batch): """Forward pass of the layer.""" - return self.func(self.nn(batch) + mean(batch, dim=1, keepdim=True)) \ No newline at end of file + return self._func(self._nn(batch) + mean(batch, dim=1, keepdim=True)) From cefc4f497a8acd0ecd8e8419a7c8863f6df9f519 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Mon, 4 Mar 2024 21:32:26 +0100 Subject: [PATCH 26/29] removed tutorial, modified AVNO name and added docs --- docs/source/_rst/_code.rst | 3 +- docs/source/_rst/layers/avno_layer.rst | 8 ++ docs/source/_rst/models/avno.rst | 7 ++ pina/model/__init__.py | 2 +- pina/model/avno.py | 16 ++-- tests/test_model/test_avno.py | 13 +-- tutorials/tutorial10/tutorial.py | 115 ------------------------- 7 files changed, 33 insertions(+), 131 deletions(-) create mode 100644 docs/source/_rst/layers/avno_layer.rst create mode 100644 docs/source/_rst/models/avno.rst delete mode 100644 tutorials/tutorial10/tutorial.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 8e7b31f7..067ce244 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -56,6 +56,7 @@ Models MIONet FourierIntegralKernel FNO + AveragingNeuralOperator Layers ------------- @@ -70,7 +71,7 @@ Layers Continuous convolution Proper Orthogonal Decomposition Periodic Boundary Condition embeddings - + Averaging Neural Operator block Equations and Operators ------------------------- diff --git a/docs/source/_rst/layers/avno_layer.rst b/docs/source/_rst/layers/avno_layer.rst new file mode 100644 index 00000000..cbc64bdb --- /dev/null +++ b/docs/source/_rst/layers/avno_layer.rst @@ -0,0 +1,8 @@ +Averaging Neural Operator block +========================= +.. currentmodule:: pina.model.layers.avno_layer + +.. autoclass:: AVNOBlock + :members: + :show-inheritance: + :noindex: diff --git a/docs/source/_rst/models/avno.rst b/docs/source/_rst/models/avno.rst new file mode 100644 index 00000000..bb0406aa --- /dev/null +++ b/docs/source/_rst/models/avno.rst @@ -0,0 +1,7 @@ +Averaging Neural Operator +=========== +.. currentmodule:: pina.model.avno + +.. autoclass:: AveragingNeuralOperator + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 14c9f2b3..f69d3570 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -15,4 +15,4 @@ from .deeponet import DeepONet, MIONet from .fno import FNO, FourierIntegralKernel from .base_no import KernelNeuralOperator -from .avno import AVNO +from .avno import AveragingNeuralOperator diff --git a/pina/model/avno.py b/pina/model/avno.py index 822d66e2..359efa09 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -6,7 +6,7 @@ from .base_no import KernelNeuralOperator -class AVNO(KernelNeuralOperator): +class AveragingNeuralOperator(KernelNeuralOperator): """ Implementation of Averaging Neural Operator. @@ -34,19 +34,19 @@ def __init__( func=nn.GELU, ): """ - :param int input_features: The number of input components of the model. - :param int output_features: The number of output components + :param int input_numb_fields: The number of input components of the model. - :param int points_size: the dimension of the domain of the functions. + :param int output_numb_fields: The number of output components + of the model. + :param int dimension: the dimension of the domain of the functions. :param int inner_size: number of neurons in the hidden layer(s). Defaults to 100. :param int n_layers: number of hidden layers. Default is 4. :param func: the activation function to use. Default to nn.GELU. - :param str features_label: the label of the features + :param str field_indices: the label of the fields in the input tensor. - Defaults to 'v'. - :param str points_label: the label of the points in the input tensor. - Defaults to 'p'. + :param str coordinates_indices: the label of the + coordinates in the input tensor. """ self.input_numb_fields = input_numb_fields diff --git a/tests/test_model/test_avno.py b/tests/test_model/test_avno.py index 10302d2b..a08f02c0 100644 --- a/tests/test_model/test_avno.py +++ b/tests/test_model/test_avno.py @@ -1,5 +1,5 @@ import torch -from pina.model import AVNO +from pina.model import AveragingNeuralOperator from pina import LabelTensor output_numb_fields = 5 @@ -10,13 +10,13 @@ def test_constructor(): input_numb_fields = 1 output_numb_fields = 1 #minimuum constructor - AVNO(input_numb_fields, + AveragingNeuralOperator(input_numb_fields, output_numb_fields, coordinates_indices=['p'], field_indices=['v']) #all constructor - AVNO(input_numb_fields, + AveragingNeuralOperator(input_numb_fields, output_numb_fields, inner_size=5, n_layers=5, @@ -31,7 +31,7 @@ def test_forward(): dimension = 1 input_ = LabelTensor( torch.rand(batch_size, 1000, input_numb_fields + dimension), ['p', 'v']) - ano = AVNO(input_numb_fields, + ano = AveragingNeuralOperator(input_numb_fields, output_numb_fields, dimension=dimension, coordinates_indices=['p'], @@ -46,9 +46,10 @@ def test_backward(): dimension = 1 output_numb_fields = 1 input_ = LabelTensor( - torch.rand(batch_size, 1000, dimension + input_numb_fields), ['p', 'v']) + torch.rand(batch_size, 1000, dimension + input_numb_fields), + ['p', 'v']) input_ = input_.requires_grad_() - avno = AVNO(input_numb_fields, + avno = AveragingNeuralOperator(input_numb_fields, output_numb_fields, dimension=dimension, coordinates_indices=['p'], diff --git a/tutorials/tutorial10/tutorial.py b/tutorials/tutorial10/tutorial.py deleted file mode 100644 index a46ac751..00000000 --- a/tutorials/tutorial10/tutorial.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Tutorial of the Averaging Neural Operator.""" - -from time import time -import torch -from pina.model import AVNO -from pina import Condition, LabelTensor -from pina.problem import AbstractProblem -from pina.solvers import SupervisedSolver -from pina.trainer import Trainer - -#Data generation - -torch.manual_seed(0) - - -def sample_unit_circle(num_points): - radius = torch.rand(num_points, 1) - angle = torch.rand(num_points, 1) * 2 * torch.pi - x = radius * torch.cos(angle) - y = radius * torch.sin(angle) - data = torch.cat((x, y), dim=1) - return data - - -#sin(a*x+b*y) -def compute_input(data, theta): - data = data.reshape(1, -1, 2) - z = torch.sin(theta[:, :, 0] * data[:, :, 0] + - theta[:, :, 1] * data[:, :, 1]) - return z - - -#1+convolution of sin(a*x+b*y) with sin(x) over [0,2pi]x[0,2pi] -def compute_output(data, theta): - data = data.reshape(1, -1, 2) - s = torch.cos(theta[:, :, 0] * (torch.pi * data[:, :, 0]) + theta[:, :, 1] * - (torch.pi * data[:, :, 1])) - z = 1 - 4 * torch.sin(torch.pi * theta[:, :, 0]) * torch.sin( - torch.pi * theta[:, :, 1]) * s / ( - (theta[:, :, 0]**2 - 1) * theta[:, :, 1]) - return z - - -theta_dataset = 1 + 0.01 * torch.rand(300, 1, 2) -data_coarse = sample_unit_circle(1000) -output_coarse = compute_output(data_coarse, theta_dataset).unsqueeze(-1) -input_coarse = compute_input(data_coarse, theta_dataset).unsqueeze(-1) -data_dense = sample_unit_circle(1000) -output_dense = compute_output(data_dense, theta_dataset).unsqueeze(-1) -input_dense = compute_input(data_dense, theta_dataset).unsqueeze(-1) - -data_coarse = data_coarse.unsqueeze(0).repeat(300, 1, 1) -data_dense = data_dense.unsqueeze(0).repeat(300, 1, 1) -x_coarse = LabelTensor(torch.concatenate((data_coarse, input_coarse), axis=2), - ['x', 'y', 'v']) -x_dense = LabelTensor(torch.concatenate((data_dense, input_dense), axis=2), - ['x', 'y', 'v']) - -print(x_coarse.shape) -model = AVNO(input_numb_fields=1, - output_numb_fields=1, - inner_size=500, - n_layers=4, - dimension=2, - field_indices=['v'], - coordinates_indices=['x', 'y']) - - -class ANOProblem(AbstractProblem): - input_variables = ['x', 'y', 'v'] - input_points = x_coarse - output_variables = ['output'] - output_points = LabelTensor(output_coarse, output_variables) - conditions = { - "data": Condition(input_points=input_points, - output_points=output_points) - } - - -batch_size = 1 -problem = ANOProblem() -solver = SupervisedSolver(problem, - model, - optimizer_kwargs={'lr': 1e-3}, - optimizer=torch.optim.AdamW) -trainer = Trainer(solver=solver, - max_epochs=5, - accelerator='cpu', - enable_model_summary=False, - batch_size=batch_size) - -start_time = time() -trainer.train() -end_time = time() -print(end_time - start_time) -solver.neural_net = solver.neural_net.eval() -num_batches = len(input_coarse) // batch_size -num = 0 -dem = 0 - -for i in range(300): - input_tmp = LabelTensor(x_coarse[i].unsqueeze(0), ['x', 'y', 'v']) - tmp = model(input_tmp).detach().squeeze(0) - num = num + torch.linalg.norm(tmp - output_coarse[i])**2 - dem = dem + torch.linalg.norm(output_coarse[i])**2 -print("Training mse loss is", torch.sqrt(num / dem)) - -num = 0 -dem = 0 -for i in range(300): - input_tmp = LabelTensor(x_coarse[i].unsqueeze(0), ['x', 'y', 'v']) - tmp = model(input_tmp).detach().squeeze(0) - num = num + torch.linalg.norm(tmp - output_dense[i])**2 - dem = dem + torch.linalg.norm(output_dense[i])**2 -print("Super Resolution mse loss is", torch.sqrt(num / dem)) From 1de2b0d4594c20ddfd1c8a3ebee59fdfa17cc38d Mon Sep 17 00:00:00 2001 From: cyberguli Date: Mon, 4 Mar 2024 21:36:37 +0100 Subject: [PATCH 27/29] fixed init --- pina/model/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pina/model/__init__.py b/pina/model/__init__.py index f69d3570..b0849887 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -7,7 +7,7 @@ "FNO", "FourierIntegralKernel", "KernelNeuralOperator", - "AVNO", + "AveragingNeuralOperator", ] from .feed_forward import FeedForward, ResidualFeedForward From 0b187f89fdd609a60757b5f5d5aac1853e8ad3be Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Tue, 5 Mar 2024 11:00:55 +0100 Subject: [PATCH 28/29] minor changes --- docs/source/_rst/_code.rst | 2 +- docs/source/_rst/layers/avno_layer.rst | 4 +- docs/source/_rst/models/avno.rst | 2 +- pina/model/avno.py | 41 +++++++++++++-- pina/model/layers/avno_layer.py | 70 +++++++++++++++++++++----- 5 files changed, 97 insertions(+), 22 deletions(-) diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 067ce244..d1c40062 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -68,10 +68,10 @@ Layers EnhancedLinear layer Spectral convolution Fourier layers + Averaging layer Continuous convolution Proper Orthogonal Decomposition Periodic Boundary Condition embeddings - Averaging Neural Operator block Equations and Operators ------------------------- diff --git a/docs/source/_rst/layers/avno_layer.rst b/docs/source/_rst/layers/avno_layer.rst index cbc64bdb..38d7ccbe 100644 --- a/docs/source/_rst/layers/avno_layer.rst +++ b/docs/source/_rst/layers/avno_layer.rst @@ -1,5 +1,5 @@ -Averaging Neural Operator block -========================= +Averaging layers +==================== .. currentmodule:: pina.model.layers.avno_layer .. autoclass:: AVNOBlock diff --git a/docs/source/_rst/models/avno.rst b/docs/source/_rst/models/avno.rst index bb0406aa..a083f6fd 100644 --- a/docs/source/_rst/models/avno.rst +++ b/docs/source/_rst/models/avno.rst @@ -1,5 +1,5 @@ Averaging Neural Operator -=========== +============================== .. currentmodule:: pina.model.avno .. autoclass:: AveragingNeuralOperator diff --git a/pina/model/avno.py b/pina/model/avno.py index 359efa09..0c05adf0 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -4,16 +4,20 @@ from . import FeedForward from .layers import AVNOBlock from .base_no import KernelNeuralOperator +from pina.utils import check_consistency class AveragingNeuralOperator(KernelNeuralOperator): """ - Implementation of Averaging Neural Operator. + Implementation of Averaging Neural Operator. - This class implements the Averaging Neural Operator. + Averaging Neural Operator is a general architecture for + learning Operators. Unlike traditional machine learning methods + AveragingNeuralOperator is designed to map entire functions + to other functions. It can be trained with Supervised learning strategies. + AveragingNeuralOperator does convolution by performing a field average. .. seealso:: - **Original reference**: Lanthaler S. Li, Z., Kovachki, Stuart, A. (2020). *The Nonlocal Neural Operator: @@ -43,12 +47,23 @@ def __init__( Defaults to 100. :param int n_layers: number of hidden layers. Default is 4. :param func: the activation function to use. Default to nn.GELU. - :param str field_indices: the label of the fields + :param list[str] field_indices: the label of the fields in the input tensor. - :param str coordinates_indices: the label of the + :param list[str] coordinates_indices: the label of the coordinates in the input tensor. """ + # check consistency + check_consistency(input_numb_fields, int) + check_consistency(output_numb_fields, int) + check_consistency(field_indices, str) + check_consistency(coordinates_indices, str) + check_consistency(dimension, int) + check_consistency(inner_size, int) + check_consistency(n_layers, int) + check_consistency(func, nn.Module, subclass=True) + + # assign self.input_numb_fields = input_numb_fields self.output_numb_fields = output_numb_fields self.dimension = dimension @@ -63,6 +78,22 @@ def __init__( super().__init__(lifting_net, integral_net, projection_net) def forward(self, x): + r""" + Forward computation for Averaging Neural Operator. It performs a + lifting of the input by the ``lifting_net``. Then different layers + of Averaging Neural Operator Blocks are applied. + Finally the output is projected to the final dimensionality + by the ``projecting_net``. + + :param torch.Tensor x: The input tensor for fourier block, + depending on ``dimension`` in the initialization. It expects + a tensor :math:`B \times N \times D`, + where :math:`B` is the batch_size, :math:`N` the number of points + in the mesh, :math:`D` the dimension of the problem, i.e. the sum + of ``len(coordinates_indices)+len(field_indices)``. + :return: The output tensor obtained from Average Neural Operator. + :rtype: torch.Tensor + """ points_tmp = x.extract(self.coordinates_indices) features_tmp = x.extract(self.field_indices) new_batch = concatenate((features_tmp, points_tmp), dim=2) diff --git a/pina/model/layers/avno_layer.py b/pina/model/layers/avno_layer.py index 2a17a8f9..29bfef1f 100644 --- a/pina/model/layers/avno_layer.py +++ b/pina/model/layers/avno_layer.py @@ -1,25 +1,69 @@ -"""Module for Averaging Neural Operator Layer class.""" +""" Module for Averaging Neural Operator Layer class. """ from torch import nn, mean +from pina.utils import check_consistency class AVNOBlock(nn.Module): - """ - The PINA implementation of the inner layer - of the Averaging Neural Operator . - - :param int hidden_size: size of the layer. - Defaults to 100. - :param func: the activation function to use. - Default to nn.GELU. - + r""" + The PINA implementation of the inner layer of the Averaging Neural Operator. + + The operator layer performs an affine transformation where the convolution + is approximated with a local average. Given the input function + :math:`v(x)\in\mathbb{R}^{\rm{emb}` the :meth:`AVNOBlock` computes + the operator update :math:`K(v)` as: + + .. math:: + K(v) = \sigma\left(Wv(x) + b + \frac{1}{|\mathcal{A}|}\int v(y)dy\right) + + where: + + * :math:`\mathbb{R}^{\rm{emb}}` is the embedding (hidden) size + corresponding to the ``hidden_size`` object + * :math:`\sigma` is a non-linear activation, corresponding to the + ``func`` object + * :math:`W\in\mathbb{R}^{\rm{emb}\times\rm{emb}}` is a tunable matrix. + * :math:`b\in\mathbb{R}^{\rm{emb}}` is a tunable bias. + + .. seealso:: + + **Original reference**: Lanthaler S. Li, Z., Kovachki, + Stuart, A. (2020). *The Nonlocal Neural Operator: Universal + Approximation*. + DOI: `arXiv preprint arXiv:2304.13221. + `_ + """ def __init__(self, hidden_size=100, func=nn.GELU): + """ + Initialize AVNOBlock. + + :param int hidden_size: Size of the hidden layer, defaults to 100. + :param func: The activation function, default to nn.GELU. + """ + super().__init__() + # Check type consistency + check_consistency(hidden_size, int) + check_consistency(func, nn.Module, subclass=True) + # Assignment self._nn = nn.Linear(hidden_size, hidden_size) self._func = func() - def forward(self, batch): - """Forward pass of the layer.""" - return self._func(self._nn(batch) + mean(batch, dim=1, keepdim=True)) + def forward(self, x): + """ + Forward pass of the layer, it performs a sum of local average + and an affine transformation of the field. + + :param torch.Tensor x: The input tensor for performing the + computation. It expects a tensor :math:`B \times N \times D`, + where :math:`B` is the batch_size, :math:`N` the number of points + in the mesh, :math:`D` the dimension of the problem. In particular + :math:`D` is the codomain of the function :math:`v`. For example + a scalar function has :math:`D=1`, a 4-dimensional vector function + :math:`D=4`. + :return: The output tensor obtained from the :meth:`AVNOBlock`. + :rtype: torch.Tensor + """ + return self._func(self._nn(x) + mean(x, dim=1, keepdim=True)) From 494af4fedca2c2aea0295f6cac76f5b4703af86d Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Tue, 5 Mar 2024 11:07:25 +0100 Subject: [PATCH 29/29] doc addition --- pina/model/avno.py | 6 +++--- pina/model/layers/avno_layer.py | 10 ++++------ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pina/model/avno.py b/pina/model/avno.py index 0c05adf0..b85695ca 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -18,10 +18,10 @@ class AveragingNeuralOperator(KernelNeuralOperator): AveragingNeuralOperator does convolution by performing a field average. .. seealso:: + **Original reference**: Lanthaler S. Li, Z., Kovachki, - Stuart, A. - (2020). *The Nonlocal Neural Operator: - Universal Approximation*. + Stuart, A. (2020). *The Nonlocal Neural Operator: + Universal Approximation*. DOI: `arXiv preprint arXiv:2304.13221. `_ """ diff --git a/pina/model/layers/avno_layer.py b/pina/model/layers/avno_layer.py index 29bfef1f..9e91c616 100644 --- a/pina/model/layers/avno_layer.py +++ b/pina/model/layers/avno_layer.py @@ -10,7 +10,7 @@ class AVNOBlock(nn.Module): The operator layer performs an affine transformation where the convolution is approximated with a local average. Given the input function - :math:`v(x)\in\mathbb{R}^{\rm{emb}` the :meth:`AVNOBlock` computes + :math:`v(x)\in\mathbb{R}^{\rm{emb}}` the layer computes the operator update :math:`K(v)` as: .. math:: @@ -37,13 +37,11 @@ class AVNOBlock(nn.Module): def __init__(self, hidden_size=100, func=nn.GELU): """ - Initialize AVNOBlock. - :param int hidden_size: Size of the hidden layer, defaults to 100. :param func: The activation function, default to nn.GELU. """ - super().__init__() + # Check type consistency check_consistency(hidden_size, int) check_consistency(func, nn.Module, subclass=True) @@ -52,7 +50,7 @@ def __init__(self, hidden_size=100, func=nn.GELU): self._func = func() def forward(self, x): - """ + r""" Forward pass of the layer, it performs a sum of local average and an affine transformation of the field. @@ -63,7 +61,7 @@ def forward(self, x): :math:`D` is the codomain of the function :math:`v`. For example a scalar function has :math:`D=1`, a 4-dimensional vector function :math:`D=4`. - :return: The output tensor obtained from the :meth:`AVNOBlock`. + :return: The output tensor obtained from Average Neural Operator Block. :rtype: torch.Tensor """ return self._func(self._nn(x) + mean(x, dim=1, keepdim=True))