From 81e66419903f5dc6652fa6d2a47847bd790cfc3c Mon Sep 17 00:00:00 2001
From: stillerpatrick
Date: Thu, 29 Jul 2021 15:49:38 +0200
Subject: [PATCH] Adding logging features to fork
---
PINNFramework/BoundaryCondition.py | 85 ++++-
PINNFramework/HPMLoss.py | 4 +-
PINNFramework/InitalCondition.py | 6 +-
PINNFramework/JoinedDataset.py | 24 +-
PINNFramework/Logger_Interface.py | 45 +++
PINNFramework/LossTerm.py | 4 +-
PINNFramework/PDELoss.py | 5 +-
PINNFramework/PINN.py | 340 ++++++++++++++---
PINNFramework/WandB_Logger.py | 70 ++++
PINNFramework/__init__.py | 12 +-
PINNFramework/callbacks/Callback.py | 48 +++
PINNFramework/callbacks/__init__.py | 6 +
PINNFramework/models/Finger_Net.py | 115 ++++++
PINNFramework/models/__init__.py | 16 +-
PINNFramework/models/activations/__init__.py | 5 +
PINNFramework/models/activations/snake.py | 20 +
.../models/{moe.py => distributed_moe.py} | 51 ++-
PINNFramework/models/mlp.py | 14 +-
PINNFramework/models/moe_finger.py | 350 +++++++++++++++++
PINNFramework/models/moe_finger_test.py | 32 ++
PINNFramework/models/moe_mlp.py | 361 ++++++++++++++++++
PINNFramework/models/moe_mlp_test.py | 29 ++
PINNFramework/models/snake_mlp.py | 39 ++
benchmarks/hvd_test.py | 6 +
benchmarks/moe_benchmark.py | 49 +++
25 files changed, 1633 insertions(+), 103 deletions(-)
create mode 100644 PINNFramework/Logger_Interface.py
create mode 100644 PINNFramework/WandB_Logger.py
create mode 100644 PINNFramework/callbacks/Callback.py
create mode 100644 PINNFramework/callbacks/__init__.py
create mode 100644 PINNFramework/models/Finger_Net.py
create mode 100644 PINNFramework/models/activations/__init__.py
create mode 100644 PINNFramework/models/activations/snake.py
rename PINNFramework/models/{moe.py => distributed_moe.py} (86%)
create mode 100644 PINNFramework/models/moe_finger.py
create mode 100644 PINNFramework/models/moe_finger_test.py
create mode 100644 PINNFramework/models/moe_mlp.py
create mode 100644 PINNFramework/models/moe_mlp_test.py
create mode 100644 PINNFramework/models/snake_mlp.py
create mode 100644 benchmarks/hvd_test.py
create mode 100644 benchmarks/moe_benchmark.py
diff --git a/PINNFramework/BoundaryCondition.py b/PINNFramework/BoundaryCondition.py
index 55f5107..fdb748b 100644
--- a/PINNFramework/BoundaryCondition.py
+++ b/PINNFramework/BoundaryCondition.py
@@ -4,9 +4,8 @@
class BoundaryCondition(LossTerm):
- def __init__(self, name, dataset, norm='L2', weight=1.):
- self.name = name
- super(BoundaryCondition, self).__init__(dataset, norm, weight)
+ def __init__(self, dataset, name, norm='L2', weight=1.):
+ super(BoundaryCondition, self).__init__(dataset, name, norm, weight)
def __call__(self, *args, **kwargs):
raise NotImplementedError("The call function of the Boundary Condition has to be implemented")
@@ -18,7 +17,7 @@ class DirichletBC(BoundaryCondition):
"""
def __init__(self, func, dataset, name, norm='L2',weight=1.):
- super(DirichletBC, self).__init__(name, dataset, norm, weight)
+ super(DirichletBC, self).__init__(dataset, name, norm, weight)
self.func = func
def __call__(self, x, model):
@@ -29,19 +28,37 @@ def __call__(self, x, model):
class NeumannBC(BoundaryCondition):
"""
Neumann boundary conditions: dy/dn(x) = func(x).
+
+    with dy/dn(x) = <∇y, n>, the derivative along the normal vector n
"""
- def __init__(self, func, dataset, input_dimension, output_dimension, name, norm='L2',weight=1.):
- super(NeumannBC, self).__init__(name, dataset, norm, weight)
+ def __init__(self, func, dataset, normal_vector, begin, end, output_dimension, name, norm='L2', weight=1.):
+ """
+ Args:
+ func: scalar but vectorized function f(x)
+ normal_vector: normal vector for the face
+ name: identifier of the boundary condition
+ weight: weighting of the boundary condition
+            begin: defines the start index of the spatial variables in x
+            end: defines the end index of the spatial variables in x
+            output_dimension: defines on which dimension of the output the boundary condition is performed
+ """
+ super(NeumannBC, self).__init__(dataset, name, norm, weight)
self.func = func
- self.input_dimension = input_dimension
+ self.normal_vector = normal_vector
+ self.begin = begin
+ self.end = end
self.output_dimension = output_dimension
def __call__(self, x, model):
- grads = ones(x.shape, device=model.device)
- y = model(x)[:, self.output_dimension]
+ x.requires_grad = True
+ y = model(x)
+ y = y[:, self.output_dimension]
+ grads = ones(y.shape, device=y.device)
grad_y = grad(y, x, create_graph=True, grad_outputs=grads)[0]
- y_dn = grad_y[:, self.input_dimension]
+ grad_y = grad_y[:,self.begin:self.end]
+        self.normal_vector = self.normal_vector.to(y.device)  # move normal vector to the correct device
+ y_dn = grad_y @ self.normal_vector
return self.weight * self.norm(y_dn, self.func(x))
@@ -50,17 +67,34 @@ class RobinBC(BoundaryCondition):
Robin boundary conditions: dy/dn(x) = func(x, y).
"""
- def __init__(self, func, dataset, input_dimension, output_dimension, name, norm='L2', weight=1.):
- super(RobinBC, self).__init__(name, dataset, norm, weight)
+ def __init__(self, func, dataset, normal_vector, begin, end, output_dimension, name, norm='L2', weight=1.):
+ """
+ Args:
+ func: scalar but vectorized function f(x,y)
+ normal_vector: normal vector for the face
+ name: identifier of the boundary condition
+ weight: weighting of the boundary condition
+            begin: defines the start index of the spatial variables in x
+            end: defines the end index of the spatial variables in x
+            output_dimension: defines on which dimension of the output the boundary condition is performed
+ """
+
+ super(RobinBC, self).__init__(dataset, name, norm, weight)
self.func = func
- self.input_dimension = input_dimension
+ self.begin = begin
+ self.end = end
+ self.normal_vector = normal_vector
self.output_dimension = output_dimension
def __call__(self, x, y, model):
- y = model(x)[:, self.output_dimension]
+ x.requires_grad = True
+ y = model(x)
+ y = y[:, self.output_dimension]
grads = ones(y.shape, device=y.device)
grad_y = grad(y, x, create_graph=True, grad_outputs=grads)[0]
- y_dn = grad_y[:, self.input_dimension]
+ grad_y = grad_y[:, self.begin:self.end]
+        self.normal_vector = self.normal_vector.to(y.device)  # move normal vector to the correct device
+ y_dn = grad_y @ self.normal_vector
return self.weight * self.norm(y_dn, self.func(x, y))
@@ -70,7 +104,7 @@ class PeriodicBC(BoundaryCondition):
"""
def __init__(self, dataset, output_dimension, name, degree=None, input_dimension=None, norm='L2', weight=1.):
- super(PeriodicBC, self).__init__(name, dataset, norm, weight)
+ super(PeriodicBC, self).__init__(dataset, name, norm, weight)
if degree is not None and input_dimension is None:
raise ValueError("If the degree of the boundary condition is defined the input dimension for the "
"derivative has to be defined too ")
@@ -95,3 +129,22 @@ def __call__(self, x_lb, x_ub, model):
else:
raise NotImplementedError("Periodic Boundary Condition for a higher degree than one is not supported")
+
+
+class TimeDerivativeBC(BoundaryCondition):
+ """
+    For hyperbolic systems it may be necessary to initialize the time derivative. This boundary condition
+    initializes the time derivative in a data-driven way.
+
+ """
+ def __init__(self, dataset, name, norm='L2', weight=1):
+ super(TimeDerivativeBC, self).__init__(dataset, name, norm, weight)
+
+ def __call__(self, x, dt_y, model):
+ x.requires_grad = True
+ pred = model(x)
+ grads = ones(pred.shape, device=pred.device)
+ pred_dt = grad(pred, x, create_graph=True, grad_outputs=grads)[0][:, -1]
+ pred_dt = pred_dt.reshape(-1,1)
+ return self.weight * self.norm(pred_dt, dt_y)
+
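+# Example usage of TimeDerivativeBC (illustrative sketch only; `dt_dataset`, `x`, `dt_y`
+# and `model` are hypothetical placeholders, not part of this patch):
+#
+#   dt_bc = TimeDerivativeBC(dataset=dt_dataset, name="time_derivative", norm='L2', weight=1.)
+#   loss = dt_bc(x, dt_y, model)  # x: input coordinates, dt_y: ground truth time derivative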
diff --git a/PINNFramework/HPMLoss.py b/PINNFramework/HPMLoss.py
index 318ea29..c29b36f 100644
--- a/PINNFramework/HPMLoss.py
+++ b/PINNFramework/HPMLoss.py
@@ -1,7 +1,7 @@
from .PDELoss import PDELoss
class HPMLoss(PDELoss):
- def __init__(self, dataset, hpm_input, hpm_model, norm='L2', weight=1.):
+ def __init__(self, dataset, name, hpm_input, hpm_model, norm='L2', weight=1.):
"""
Constructor of the HPM loss
@@ -13,7 +13,7 @@ def __init__(self, dataset, hpm_input, hpm_model, norm='L2', weight=1.):
norm: Norm used for calculation PDE loss
weight: Weighting for the loss term
"""
- super(HPMLoss, self).__init__(dataset, None, norm, weight)
+        super(HPMLoss, self).__init__(dataset, None, name, norm, weight)
self.hpm_input = hpm_input
self.hpm_model = hpm_model
diff --git a/PINNFramework/InitalCondition.py b/PINNFramework/InitalCondition.py
index 502855e..8031083 100644
--- a/PINNFramework/InitalCondition.py
+++ b/PINNFramework/InitalCondition.py
@@ -4,16 +4,16 @@
class InitialCondition(LossTerm):
- def __init__(self, dataset, norm='L2', weight=1.):
+ def __init__(self, dataset, name, norm='L2', weight=1.):
"""
- Constructor for the Intial condition
+ Constructor for the Initial condition
Args:
dataset (torch.utils.Dataset): dataset that provides the residual points
norm: Norm used for calculation PDE loss
weight: Weighting for the loss term
"""
- super(InitialCondition, self).__init__(dataset, norm, weight)
+ super(InitialCondition, self).__init__(dataset, name, norm, weight)
def __call__(self, x: Tensor, model: Module, gt_y: Tensor):
"""
diff --git a/PINNFramework/JoinedDataset.py b/PINNFramework/JoinedDataset.py
index 4d9a33e..0b40cf1 100644
--- a/PINNFramework/JoinedDataset.py
+++ b/PINNFramework/JoinedDataset.py
@@ -19,12 +19,30 @@ def min_length(datasets):
minimum = length
return minimum
- def __init__(self, datasets):
+ @staticmethod
+ def max_length(datasets):
+ """
+        Calculates the maximum dataset length of a list of datasets
+
+ datasets (Map): Map of datasets to be concatenated
+ """
+ maximum = -1 * float("inf")
+ for key in datasets.keys():
+ length = len(datasets[key])
+ if length > maximum:
+ maximum = length
+ return maximum
+
+ def __init__(self, datasets, mode='min'):
super(JoinedDataset, self).__init__()
self.datasets = datasets
+ self.mode = mode
def __len__(self):
- return self.min_length(self.datasets)
+ if self.mode =='min':
+ return self.min_length(self.datasets)
+ if self.mode =='max':
+ return self.max_length(self.datasets)
def __getitem__(self, idx):
if idx < 0:
@@ -32,6 +50,8 @@ def __getitem__(self, idx):
raise ValueError("absolute value of index should not exceed dataset length")
combined_item = {}
for key in self.datasets.keys():
+ if self.mode == 'max':
+ idx = idx % len(self.datasets[key])
item = self.datasets[key][idx]
combined_item[key] = item
return combined_item
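+
+# Example usage of the 'max' mode (illustrative sketch; the member datasets are hypothetical placeholders):
+#
+#   joined = JoinedDataset({"IC": ic_dataset, "PDE": pde_dataset}, mode='max')
+#   len(joined)   # length of the largest member dataset
+#   joined[idx]   # shorter datasets wrap around via idx % len(dataset)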
diff --git a/PINNFramework/Logger_Interface.py b/PINNFramework/Logger_Interface.py
new file mode 100644
index 0000000..1523697
--- /dev/null
+++ b/PINNFramework/Logger_Interface.py
@@ -0,0 +1,45 @@
+from abc import ABC, abstractmethod
+
+
+class LoggerInterface(ABC):
+
+ @abstractmethod
+ def log_scalar(self, scalar, name, epoch):
+ """
+ Method that defines how scalars are logged
+
+ Args:
+ scalar: scalar to be logged
+ name: name of the scalar
+ epoch: epoch in the training loop
+
+ """
+ pass
+
+ @abstractmethod
+ def log_image(self, image, name, epoch):
+ """
+ Method that defines how images are logged
+
+ Args:
+ image: image to be logged
+ name: name of the image
+ epoch: epoch in the training loop
+
+ """
+ pass
+
+ @abstractmethod
+ def log_histogram(self, histogram, name, epoch):
+ """
+        Method that defines how histograms are logged
+
+ Args:
+ histogram: histogram to be logged
+ name: name of the histogram
+ epoch: epoch in the training loop
+
+ """
+ pass
+
+
diff --git a/PINNFramework/LossTerm.py b/PINNFramework/LossTerm.py
index 10f84c5..6b5654d 100644
--- a/PINNFramework/LossTerm.py
+++ b/PINNFramework/LossTerm.py
@@ -5,13 +5,12 @@ class LossTerm:
"""
Defines the main structure of a loss term
"""
- def __init__(self, dataset, norm='L2', weight=1.):
+ def __init__(self, dataset, name, norm='L2', weight=1.):
"""
Constructor of a loss term
Args:
dataset (torch.utils.Dataset): dataset that provides the residual points
- pde (function): function that represents residual of the PDE
norm: Norm used for calculation PDE loss
weight: Weighting for the loss term
"""
@@ -24,4 +23,5 @@ def __init__(self, dataset, norm='L2', weight=1.):
# Case for self implemented norms
self.norm = norm
self.dataset = dataset
+ self.name = name
self.weight = weight
diff --git a/PINNFramework/PDELoss.py b/PINNFramework/PDELoss.py
index 0b4f26b..4ca8fc2 100644
--- a/PINNFramework/PDELoss.py
+++ b/PINNFramework/PDELoss.py
@@ -1,12 +1,11 @@
import torch
from torch import Tensor as Tensor
from torch.nn import Module as Module
-from torch.nn import MSELoss, L1Loss
from .LossTerm import LossTerm
class PDELoss(LossTerm):
- def __init__(self, dataset, pde, norm='L2', weight=1.):
+ def __init__(self, dataset, pde, name, norm='L2', weight=1.):
"""
Constructor of the PDE Loss
@@ -16,7 +15,7 @@ def __init__(self, dataset, pde, norm='L2', weight=1.):
norm: Norm used for calculation PDE loss
weight: Weighting for the loss term
"""
- super(PDELoss, self).__init__(dataset, norm, weight)
+ super(PDELoss, self).__init__(dataset, name, norm, weight)
self.dataset = dataset
self.pde = pde
diff --git a/PINNFramework/PINN.py b/PINNFramework/PINN.py
index aa5fcf9..086aa7f 100644
--- a/PINNFramework/PINN.py
+++ b/PINNFramework/PINN.py
@@ -1,23 +1,35 @@
import torch
import torch.nn as nn
+import numpy as np
+from os.path import exists
from itertools import chain
from torch.utils.data import DataLoader
from .InitalCondition import InitialCondition
-from .BoundaryCondition import BoundaryCondition, PeriodicBC, DirichletBC, NeumannBC, RobinBC
+from .BoundaryCondition import BoundaryCondition, PeriodicBC, DirichletBC, NeumannBC, RobinBC, TimeDerivativeBC
from .PDELoss import PDELoss
from .JoinedDataset import JoinedDataset
from .HPMLoss import HPMLoss
+from torch.autograd import grad as grad
+from PINNFramework.callbacks import CallbackList
try:
import horovod.torch as hvd
except:
print("Was not able to import Horovod. Thus Horovod support is not enabled")
+# set initial seed for torch and numpy
+torch.manual_seed(42)
+np.random.seed(42)
+
+
+def worker_init_fn(worker_id):
+ np.random.seed(np.random.get_state()[1][0] + worker_id)
+
class PINN(nn.Module):
def __init__(self, model: torch.nn.Module, input_dimension: int, output_dimension: int,
pde_loss: PDELoss, initial_condition: InitialCondition, boundary_condition,
- use_gpu=True, use_horovod=False):
+ use_gpu=True, use_horovod=False,dataset_mode='min'):
"""
Initializes an physics-informed neural network (PINN). A PINN consists of a model which represents the solution
of the underlying partial differential equation(PDE) u, three loss terms representing initial (IC) and boundary
@@ -33,21 +45,25 @@ def __init__(self, model: torch.nn.Module, input_dimension: int, output_dimensio
of the BoundaryCondition class
use_gpu: enables gpu usage
use_horovod: enables horovod support
+            dataset_mode: defines the behavior of the joined dataset. The 'min'-mode sets the length of the joined
+                          dataset to the minimum length of its member datasets, while the 'max'-mode sets it to the
+                          maximum length and wraps the shorter datasets around
"""
-
super(PINN, self).__init__()
# checking if the model is a torch module more model checking should be possible
self.use_gpu = use_gpu
self.use_horovod = use_horovod
- self.rank = 0 # initialize rank 0 by default in order to make the fit method more flexible
+ self.rank = 0 # initialize rank 0 by default in order to make the fit method more flexible
+
if self.use_horovod:
+
# Initialize Horovod
hvd.init()
# Pin GPU to be used to process local rank (one GPU per process)
torch.cuda.set_device(hvd.local_rank())
self.rank = hvd.rank()
-
+ if self.rank == 0:
+ self.loss_log = {}
if isinstance(model, nn.Module):
self.model = model
if self.use_gpu:
@@ -80,7 +96,7 @@ def __init__(self, model: torch.nn.Module, input_dimension: int, output_dimensio
else:
raise TypeError("PDE loss has to be an instance of a PDE Loss class")
- if isinstance(pde_loss,HPMLoss):
+ if isinstance(pde_loss, HPMLoss):
self.is_hpm = True
if isinstance(initial_condition, InitialCondition):
@@ -88,26 +104,45 @@ def __init__(self, model: torch.nn.Module, input_dimension: int, output_dimensio
else:
raise TypeError("Initial condition has to be an instance of the InitialCondition class")
- joined_datasets = {"Initial_Condition": initial_condition.dataset, "PDE": pde_loss.dataset}
+ joined_datasets = {
+ initial_condition.name: initial_condition.dataset,
+ pde_loss.name: pde_loss.dataset
+ }
+ if self.rank == 0:
+ self.loss_log[initial_condition.name] = float(0.0) # adding initial condition to the loss_log
+ self.loss_log[pde_loss.name] = float(0.0)
+            self.loss_log["model_loss_pinn"] = float(0.0)
+            self.loss_log["model_loss_hpm"] = float(0.0)  # pre-initialized to avoid a KeyError when HPM model losses are logged
+
if not self.is_hpm:
if type(boundary_condition) is list:
for bc in boundary_condition:
if not isinstance(bc, BoundaryCondition):
raise TypeError("Boundary Condition has to be an instance of the BoundaryCondition class ")
- self.boundary_condition = boundary_condition
joined_datasets[bc.name] = bc.dataset
-
+ if self.rank == 0:
+ self.loss_log[bc.name] = float(0.0)
+ self.boundary_condition = boundary_condition
else:
if isinstance(boundary_condition, BoundaryCondition):
self.boundary_condition = boundary_condition
+ joined_datasets[boundary_condition.name] = boundary_condition.dataset
else:
raise TypeError("Boundary Condition has to be an instance of the BoundaryCondition class"
"or a list of instances of the BoundaryCondition class")
- self.dataset = JoinedDataset(joined_datasets)
+ self.dataset = JoinedDataset(joined_datasets, dataset_mode)
+
+ def loss_grad_std_wn(self, loss):
+ device = torch.device("cuda" if self.use_gpu else "cpu")
+ grad_ = torch.zeros((0), dtype=torch.float32, device=device)
+ model_grads = grad(loss, self.model.parameters(), allow_unused=True, retain_graph=True)
+ for elem in model_grads:
+ if elem is not None:
+ grad_ = torch.cat((grad_, elem.view(-1)))
+ return torch.std(grad_)
def forward(self, x):
"""
- Predicting the solution at given position x
+        Predicting the solution at a given position x
"""
return self.model(x)
@@ -162,7 +197,7 @@ def calculate_boundary_condition(self, boundary_condition: BoundaryCondition, tr
else:
raise ValueError(
"The boundary condition {} has to be tuple of coordinates for lower and upper bound".
- format(boundary_condition.name))
+ format(boundary_condition.name))
else:
raise ValueError("The boundary condition {} has to be tuple of coordinates for lower and upper bound".
format(boundary_condition.name))
@@ -190,12 +225,27 @@ def calculate_boundary_condition(self, boundary_condition: BoundaryCondition, tr
else:
raise ValueError(
"The boundary condition {} has to be tuple of coordinates for lower and upper bound".
- format(boundary_condition.name))
+ format(boundary_condition.name))
else:
raise ValueError("The boundary condition {} has to be tuple of coordinates for lower and upper bound".
format(boundary_condition.name))
- def pinn_loss(self, training_data):
+ if isinstance(boundary_condition, TimeDerivativeBC):
+            # Time Derivative Boundary Condition
+ if isinstance(training_data, list):
+ if len(training_data) == 2:
+ return boundary_condition(training_data[0][0].type(self.dtype),
+ training_data[1][0].type(self.dtype),
+ self.model)
+ else:
+ raise ValueError(
+ "The boundary condition {} has to be tuple of coordinates for input data and gt time derivative".
+ format(boundary_condition.name))
+ else:
+            raise ValueError("The boundary condition {} has to be a tuple of input data and gt time derivative".
+ format(boundary_condition.name))
+
+ def pinn_loss(self, training_data, annealing=False):
"""
Function for calculating the PINN loss. The PINN Loss is a weighted sum of losses for initial and boundary
condition and the residual of the PDE
@@ -205,37 +255,120 @@ def pinn_loss(self, training_data):
dictionary holds the training data for initial condition at the key "Initial_Condition" training data for
the PDE at the key "PDE" and the data for the boundary condition under the name of the boundary condition
"""
-
pinn_loss = 0
# unpack training data
- if type(training_data["Initial_Condition"]) is list:
+        # ============== PDE LOSS ==============
+ if type(training_data[self.pde_loss.name]) is not list:
+ pde_loss = self.pde_loss(training_data[self.pde_loss.name][0].type(self.dtype), self.model)
+ if annealing:
+ std_pde = self.loss_grad_std_wn(pde_loss)
+ pinn_loss = pinn_loss + pde_loss
+ if self.rank == 0:
+                self.loss_log[self.pde_loss.name] = self.loss_log[self.pde_loss.name] + pde_loss / self.pde_loss.weight
+ else:
+            raise ValueError("Training data for the PDE loss has to be a single tensor consisting of residual points")
+
+        # ============== INITIAL CONDITION ==============
+ if type(training_data[self.initial_condition.name]) is list:
# initial condition loss
- if len(training_data["Initial_Condition"]) == 2:
- pinn_loss = pinn_loss + self.initial_condition(training_data["Initial_Condition"][0][0].type(self.dtype),
- self.model,
- training_data["Initial_Condition"][1][0].type(self.dtype))
+ if len(training_data[self.initial_condition.name]) == 2:
+ ic_loss = self.initial_condition(
+ training_data[self.initial_condition.name][0][0].type(self.dtype),
+ self.model,
+ training_data[self.initial_condition.name][1][0].type(self.dtype)
+ )
+ if self.rank == 0:
+ self.loss_log[self.initial_condition.name] = self.loss_log[self.initial_condition.name] +\
+ ic_loss / self.initial_condition.weight
+ if annealing:
+ std_ic = self.loss_grad_std_wn(ic_loss)
+ lambda_hat = std_pde / std_ic
+ self.initial_condition.weight = (1 - 0.5) * self.initial_condition.weight + 0.5 * lambda_hat
+ pinn_loss = pinn_loss + ic_loss
else:
raise ValueError("Training Data for initial condition is a tuple (x,y) with x the input coordinates"
" and ground truth values y")
else:
- raise ValueError("Training Data for initial condition is a tuple (x,y) with x the input coordinates"
+ raise ValueError("Training Data for initial condition is a tuple (x,y) with x the input coordinates"
" and ground truth values y")
- if type(training_data["PDE"]) is not list:
- pinn_loss = pinn_loss + self.pde_loss(training_data["PDE"][0].type(self.dtype), self.model)
- else:
- raise ValueError("Training Data for PDE data is a single tensor consists of residual points ")
+        # ============== BOUNDARY CONDITION ==============
if not self.is_hpm:
if isinstance(self.boundary_condition, list):
for bc in self.boundary_condition:
- pinn_loss = pinn_loss + self.calculate_boundary_condition(bc, training_data[bc.name])
+ bc_loss = self.calculate_boundary_condition(bc, training_data[bc.name])
+ if self.rank == 0:
+ self.loss_log[bc.name] = self.loss_log[bc.name] + bc_loss / bc.weight
+ if annealing:
+ std_bc = self.loss_grad_std_wn(bc_loss)
+ lambda_hat = std_pde / std_bc
+ bc.weight = (1 - 0.5) * bc.weight + 0.5 * lambda_hat
+ pinn_loss = pinn_loss + bc_loss
else:
- pinn_loss = pinn_loss + self.calculate_boundary_condition(self.boundary_condition,
- training_data[self.boundary_condition.name])
+ bc_loss = self.calculate_boundary_condition(self.boundary_condition,
+ training_data[self.boundary_condition.name])
+ if self.rank == 0:
+ self.loss_log[self.boundary_condition.name] = self.loss_log[self.boundary_condition.name] +\
+ bc_loss / self.boundary_condition.weight
+ if annealing:
+ std_bc = self.loss_grad_std_wn(bc_loss)
+ lambda_hat = std_pde / std_bc
+ self.boundary_condition.weight = (1 - 0.5) * self.boundary_condition.weight + 0.5 * lambda_hat
+ pinn_loss = pinn_loss + bc_loss
+
+        # ============== Model specific losses ==============
+ if hasattr(self.model, 'loss'):
+ pinn_loss = pinn_loss + self.model.loss
+ if self.rank == 0:
+ self.loss_log["model_loss_pinn"] = self.loss_log["model_loss_pinn"] + self.model.loss
+ if self.is_hpm:
+ if hasattr(self.pde_loss.hpm_model, 'loss'):
+ pinn_loss = pinn_loss + self.pde_loss.hpm_model.loss
+ if self.rank == 0:
+ self.loss_log["model_loss_hpm"] = self.loss_log["model_loss_hpm"] + self.pde_loss.hpm_model.loss
+
return pinn_loss
- def fit(self, epochs, optimizer='Adam', learning_rate=1e-3, lbfgs_finetuning=True,
- writing_cylcle= 30, save_model=True, pinn_path='best_model_pinn.pt', hpm_path='best_model_hpm.pt'):
+ def write_checkpoint(self, checkpoint_path, epoch, pretraining, minimum_pinn_loss, optimizer):
+ checkpoint = {}
+ checkpoint["epoch"] = epoch
+ checkpoint["pretraining"] = pretraining
+ checkpoint["minimum_pinn_loss"] = minimum_pinn_loss
+ checkpoint["optimizer"] = optimizer.state_dict()
+        checkpoint["weight_" + self.initial_condition.name] = self.initial_condition.weight
+        checkpoint["weight_" + self.pde_loss.name] = self.pde_loss.weight
+ checkpoint["pinn_model"] = self.model.state_dict()
+ if isinstance(self.boundary_condition, list):
+ for bc in self.boundary_condition:
+ checkpoint["weight_"+ bc.name] = bc.weight
+ else:
+ checkpoint["weight_" + self.boundary_condition.name] = self.boundary_condition.weight
+
+ if self.is_hpm:
+ checkpoint["hpm_model"] = self.pde_loss.hpm_model.state_dict()
+ checkpoint_path = checkpoint_path + '_' + str(epoch)
+ torch.save(checkpoint, checkpoint_path)
+
+
+
+ def fit(self,
+ epochs,
+ checkpoint_path=None,
+ restart=False,
+ optimizer='Adam',
+ learning_rate=1e-3,
+ pretraining=False,
+ epochs_pt=100,
+ lbfgs_finetuning=True,
+ writing_cycle=30,
+ writing_cycle_pt=10,
+ save_model=True,
+ pinn_path='best_model_pinn.pt',
+ hpm_path='best_model_hpm.pt',
+ logger=None,
+ activate_annealing=False,
+ annealing_cycle=100,
+ callbacks=None):
"""
Function for optimizing the parameters of the PINN-Model
@@ -244,17 +377,33 @@ def fit(self, epochs, optimizer='Adam', learning_rate=1e-3, lbfgs_finetuning=Tru
optimizer (String, torch.optim.Optimizer) : Optimizer used for training. At the moment only ADAM and LBFGS
are supported by string command. It is also possible to give instances of torch optimizers as a parameter
learning_rate: The learning rate of the optimizer
+            pretraining: Activates separate training on the initial condition at the beginning
+ epochs_pt: defines the number of epochs for the pretraining
lbfgs_finetuning: Enables LBFGS finetuning after main training
writing_cylcle: defines the cylcus of model writing
save_model: enables or disables checkpointing
pinn_path: defines the path where the pinn get stored
hpm_path: defines the path where the hpm get stored
+ logger (Logger): tracks the convergence of all loss terms
+ activate_annealing (Boolean): enables annealing
+ annealing_cycle (int): defines the periodicity of using annealing
+ callbacks (CallbackList): is a list of callbacks which are called at the end of a writing cycle. Can be used
+ for different purposes e.g. early stopping, visualization, model state logging etc.
+ checkpoint_path (string) : path to the checkpoint
+            restart (bool): defines whether an existing checkpoint is loaded (False) or overwritten (True)
+
"""
+        # checking if callbacks are an instance of CallbackList
+ if callbacks is not None:
+ if not isinstance(callbacks, CallbackList):
+                raise ValueError("Callbacks have to be an instance of CallbackList but type {} was found".
+ format(type(callbacks)))
+
if isinstance(self.pde_loss, HPMLoss):
params = list(self.model.parameters()) + list(self.pde_loss.hpm_model.parameters())
- named_parameters = chain(self.model.named_parameters(),self.pde_loss.hpm_model.named_parameters())
- if self.use_horovod and lbfgs_finetuning:
+ named_parameters = chain(self.model.named_parameters(), self.pde_loss.hpm_model.named_parameters())
+ if self.use_horovod and lbfgs_finetuning:
raise ValueError("LBFGS Finetuning is not possible with horovod")
if optimizer == 'Adam':
optim = torch.optim.Adam(params, lr=learning_rate)
@@ -268,6 +417,7 @@ def fit(self, epochs, optimizer='Adam', learning_rate=1e-3, lbfgs_finetuning=Tru
if lbfgs_finetuning and not self.use_horovod:
lbfgs_optim = torch.optim.LBFGS(params, lr=0.9)
+
def closure():
lbfgs_optim.zero_grad()
pinn_loss = self.pinn_loss(training_data)
@@ -295,33 +445,127 @@ def closure():
if self.use_horovod:
# Partition dataset among workers using DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(
- self.dataset, num_replicas=hvd.size(), rank=hvd.rank())
- data_loader = DataLoader(self.dataset, batch_size=1,sampler=train_sampler)
+ self.dataset, num_replicas=hvd.size(), rank=hvd.rank()
+ )
+ data_loader = DataLoader(self.dataset, batch_size=1, sampler=train_sampler, worker_init_fn=worker_init_fn)
optim = hvd.DistributedOptimizer(optim, named_parameters=named_parameters)
+ if pretraining:
+ train_sampler_pt = torch.utils.data.distributed.DistributedSampler(
+ self.initial_condition.dataset, num_replicas=hvd.size(), rank=hvd.rank()
+ )
+ data_loader_pt = DataLoader(self.initial_condition.dataset,
+ batch_size=None,
+ sampler=train_sampler_pt,
+ worker_init_fn=worker_init_fn)
# Broadcast parameters from rank 0 to all other processes.
hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
if isinstance(self.pde_loss, HPMLoss):
- hvd.broadcast_parameters(self.pinn_loss.hpm_model.state_dict(), root_rank=0)
+ hvd.broadcast_parameters(self.pde_loss.hpm_model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optim, root_rank=0)
else:
- data_loader = DataLoader(self.dataset, batch_size=1)
+ data_loader = DataLoader(self.dataset, batch_size=1, worker_init_fn=worker_init_fn)
+ data_loader_pt = DataLoader(self.initial_condition.dataset, batch_size=None, worker_init_fn=worker_init_fn)
+
+ start_epoch = 0
+
+        # load the checkpoint if a checkpoint path is set and it is not allowed to overwrite the checkpoint
+ if checkpoint_path is not None:
+ if not exists(checkpoint_path) and not restart:
+ raise FileNotFoundError(
+                "Checkpoint path {} does not exist. Please change the path to an existing checkpoint "
+                "or set the restart flag to True in order to create a new checkpoint"
+ .format(checkpoint_path))
+ if checkpoint_path is not None and restart == 0:
+ checkpoint = torch.load(checkpoint_path)
+ start_epoch = checkpoint["epoch"]
+ pretraining = checkpoint["pretraining"]
+ self.initial_condition.weight = checkpoint["weight_" + self.initial_condition.name]
+ self.pde_loss.weight = checkpoint["weight_" + self.pde_loss.name]
+ if isinstance(self.boundary_condition, list):
+ for bc in self.boundary_condition:
+ bc.weight = checkpoint["weight_" + bc.name]
+ else:
+ self.boundary_condition.weight = checkpoint["weight_" + self.boundary_condition.name]
+
+ self.model.load_state_dict(checkpoint["pinn_model"])
+ if self.is_hpm:
+ self.pde_loss.hpm_model.load_state_dict(checkpoint["hpm_model"])
- for epoch in range(epochs):
- for training_data in data_loader:
- training_data = training_data
+ optim.load_state_dict(checkpoint['optimizer'])
+ minimum_pinn_loss = checkpoint["minimum_pinn_loss"]
+ print("Checkpoint Loaded", flush=True)
+ else:
+ print("Checkpoint not loaded", flush=True)
+
+ print("===== Pretraining =====")
+ if pretraining:
+ for epoch in range(start_epoch, epochs_pt):
+ for x, y in data_loader_pt:
+ optim.zero_grad()
+ ic_loss = self.initial_condition(model=self.model, x=x.type(self.dtype), gt_y=y.type(self.dtype))
+ ic_loss.backward()
+ optim.step()
+ if not self.rank and not (epoch + 1) % writing_cycle_pt and checkpoint_path is not None:
+ self.write_checkpoint(checkpoint_path, epoch, True, minimum_pinn_loss, optim)
+ if not self.rank:
+ print("IC Loss {} Epoch {} from {}".format(ic_loss, epoch+1, epochs_pt))
+ print("===== Main training =====")
+ for epoch in range(start_epoch, epochs):
+ # for parallel training the rank should also define the seed
+ np.random.seed(42 + epoch + self.rank)
+ batch_counter = 0.
+ pinn_loss_sum = 0.
+ for idx, training_data in enumerate(data_loader):
+ do_annealing = activate_annealing and (idx == 0) and not (epoch + 1) % annealing_cycle
optim.zero_grad()
- pinn_loss = self.pinn_loss(training_data)
+ pinn_loss = self.pinn_loss(training_data, do_annealing)
pinn_loss.backward()
- if not self.rank:
- print("PINN Loss {} Epoch {} from {}".format(pinn_loss, epoch, epochs))
optim.step()
- if (pinn_loss < minimum_pinn_loss) and not (epoch % writing_cylcle) and save_model and not self.rank:
- self.save_model(pinn_path, hpm_path)
- minimum_pinn_loss = pinn_loss
-
+ pinn_loss_sum = pinn_loss_sum + pinn_loss
+ batch_counter += 1
+ del pinn_loss
+
+ if not self.rank:
+ print("PINN Loss {} Epoch {} from {}".format(pinn_loss_sum / batch_counter, epoch+1, epochs), flush=True)
+ if logger is not None and not (epoch+1) % writing_cycle:
+                    logger.log_scalar(scalar=pinn_loss_sum / batch_counter, name=" Weighted PINN Loss", epoch=epoch+1)
+ logger.log_scalar(scalar=sum(self.loss_log.values())/batch_counter,
+ name=" Non-Weighted PINN Loss", epoch=epoch+1)
+ # Log values of the loss terms
+ for key, value in self.loss_log.items():
+ logger.log_scalar(scalar=value / batch_counter, name=key, epoch=epoch+1)
+
+ # Log weights of loss terms separately
+ logger.log_scalar(scalar=self.initial_condition.weight,
+ name=self.initial_condition.name + "_weight",
+ epoch=epoch+1)
+ if not self.is_hpm:
+ if isinstance(self.boundary_condition, list):
+ for bc in self.boundary_condition:
+ logger.log_scalar(scalar=bc.weight,
+ name=bc.name + "_weight",
+ epoch=epoch+1)
+ else:
+ logger.log_scalar(scalar=self.boundary_condition.weight,
+ name=self.boundary_condition.name + "_weight",
+ epoch=epoch+1)
+ if callbacks is not None and not (epoch+1) % writing_cycle:
+ callbacks(epoch=epoch+1)
+ # saving routine
+            if (pinn_loss_sum / batch_counter < minimum_pinn_loss) and save_model and not self.rank:
+ self.save_model(pinn_path, hpm_path)
+ minimum_pinn_loss = pinn_loss_sum / batch_counter
+
+ # reset loss log after the end of the epoch
+            if not self.rank:
+                for key in self.loss_log.keys():
+                    self.loss_log[key] = float(0)
+
+ # writing checkpoint
+ if not (epoch + 1) % writing_cycle and checkpoint_path is not None:
+ self.write_checkpoint(checkpoint_path, epoch, False, minimum_pinn_loss, optim)
if lbfgs_finetuning:
lbfgs_optim.step(closure)
- print("After LBFGS-B: PINN Loss {} Epoch {} from {}".format(pinn_loss, epoch, epochs))
- if (pinn_loss < minimum_pinn_loss) and not (epoch % writing_cylcle) and save_model:
+            pinn_loss = pinn_loss_sum / batch_counter  # pinn_loss is deleted inside the loop; report the epoch average
+            print("After LBFGS-B: PINN Loss {} Epoch {} from {}".format(pinn_loss, epoch+1, epochs))
+            if (pinn_loss < minimum_pinn_loss) and not (epoch % writing_cycle) and save_model:
self.save_model(pinn_path, hpm_path)
diff --git a/PINNFramework/WandB_Logger.py b/PINNFramework/WandB_Logger.py
new file mode 100644
index 0000000..8414948
--- /dev/null
+++ b/PINNFramework/WandB_Logger.py
@@ -0,0 +1,70 @@
+from .Logger_Interface import LoggerInterface
+import wandb
+
+
+class WandbLogger(LoggerInterface):
+
+ def __init__(self, project, args, entity=None, group=None):
+ """
+ Initialize wandb instance and connect to the server
+
+ Args:
+ project: name of the project
+            args: hyperparameters used for this run
+            entity: account or group id used for that run
+            group: optional group name used to cluster related runs
+ """
+ wandb.init(project=project, entity=entity, group=group)
+ wandb.config.update(args) # adds all of the arguments as config variable
+
+ def log_scalar(self, scalar, name, epoch):
+ """
+ Logs a scalar to wandb
+
+ Args:
+ scalar: the scalar to be logged
+            name: name of the scalar
+ epoch: epoch in the training loop
+ """
+ wandb.log({name: scalar}, step=epoch)
+
+ def log_image(self, image, name, epoch):
+ """
+        Logs an image to wandb
+
+ Args:
+ image (Image) : the image to be logged
+ name (String) : name of the image
+ epoch (Integer) : epoch in the training loop
+
+ """
+ wandb.log({name: [wandb.Image(image, caption=name)]}, step=epoch)
+
+ def log_plot(self, plot, name, epoch):
+ """
+ Logs a plot to wandb
+
+ Args:
+ plot (plot) : the plot to be logged
+ name (String) : name of the plot
+ epoch (Integer) : epoch in the training loop
+
+ """
+ wandb.log({name: plot}, step=epoch)
+
+    def log_histogram(self, histogram, name, epoch):
+ """
+ Logs a histogram to wandb
+
+ Args:
+ histogram (histogram) : the histogram to be logged
+ name (String) : name of the histogram
+ epoch (Integer) : epoch in the training loop
+
+ """
+ wandb.log({name: wandb.Histogram(histogram)}, step=epoch)
+
+
+
+
+
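+# Example usage (illustrative sketch; project, entity and the logged values are hypothetical):
+#
+#   logger = WandbLogger(project="pinn-experiments", args={"lr": 1e-3}, entity="my-team")
+#   logger.log_scalar(scalar=0.1, name="pinn_loss", epoch=1)
+#   logger.log_histogram(histogram=weights, name="layer_1_weights", epoch=1)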
diff --git a/PINNFramework/__init__.py b/PINNFramework/__init__.py
index eb42dbb..b54ca88 100644
--- a/PINNFramework/__init__.py
+++ b/PINNFramework/__init__.py
@@ -3,11 +3,16 @@
from .BoundaryCondition import PeriodicBC
from .BoundaryCondition import DirichletBC
from .BoundaryCondition import RobinBC
+from .BoundaryCondition import TimeDerivativeBC
from .BoundaryCondition import NeumannBC
from .PDELoss import PDELoss
+from .Logger_Interface import LoggerInterface
+from .WandB_Logger import WandbLogger
from .PINN import PINN
import PINNFramework.models
+import PINNFramework.callbacks
+
__all__ = [
'InitialCondition',
@@ -15,8 +20,11 @@
'DirichletBC',
'RobinBC',
'NeumannBC',
+ 'TimeDerivativeBC',
'PDELoss',
'HPMLoss',
'PINN',
- 'models'
-]
+ 'models',
+ 'LoggerInterface',
+ 'WandbLogger',
+ 'callbacks']
diff --git a/PINNFramework/callbacks/Callback.py b/PINNFramework/callbacks/Callback.py
new file mode 100644
index 0000000..daa4c5a
--- /dev/null
+++ b/PINNFramework/callbacks/Callback.py
@@ -0,0 +1,48 @@
+from torch.nn import Module
+from PINNFramework.Logger_Interface import LoggerInterface
+
+
+class Callback:
+ def __init__(self):
+ self.model = None
+ self.logger = None
+
+ def set_model(self, model):
+ if isinstance(model, Module):
+ self.model = model
+ else:
+            raise ValueError("Model is not of type torch.nn.Module but model of type {} was found"
+ .format(type(model)))
+
+ def set_logger(self, logger):
+        if isinstance(logger, LoggerInterface):
+            self.logger = logger
+        else:
+            raise ValueError("Logger is not of type LoggerInterface but logger of type {} was found"
+ .format(type(logger)))
+
+ def __call__(self, epoch):
+ raise NotImplementedError("method __call__() of the callback is not implemented")
+
+
+class CallbackList:
+ def __init__(self, callbacks):
+ if isinstance(callbacks, list):
+ for cb in callbacks:
+ if not isinstance(cb, Callback):
+                    raise ValueError("Callback has to be of type Callback but type {} was found"
+ .format(type(cb)))
+ self.callbacks = callbacks
+ else:
+            raise ValueError("Callbacks have to be a list of Callback instances but type {} was found"
+ .format(type(callbacks)))
+
+ def __call__(self, epoch):
+ for cb in self.callbacks:
+ cb(epoch)
+
+
+
+
+
+
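+# Example of a custom callback (illustrative sketch; the plotting helpers are hypothetical):
+#
+#   class VisualizationCallback(Callback):
+#       def __call__(self, epoch):
+#           prediction = self.model(grid)                      # model is injected via set_model()
+#           self.logger.log_image(to_image(prediction), "solution", epoch)
+#
+#   callbacks = CallbackList([VisualizationCallback()])
+#   callbacks(epoch=10)  # PINN.fit() calls this at the end of every writing cycle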
diff --git a/PINNFramework/callbacks/__init__.py b/PINNFramework/callbacks/__init__.py
new file mode 100644
index 0000000..ceb4290
--- /dev/null
+++ b/PINNFramework/callbacks/__init__.py
@@ -0,0 +1,6 @@
+from .Callback import Callback
+from .Callback import CallbackList
+
+__all__ = ["Callback",
+ "CallbackList"
+ ]
\ No newline at end of file
diff --git a/PINNFramework/models/Finger_Net.py b/PINNFramework/models/Finger_Net.py
new file mode 100644
index 0000000..797e98f
--- /dev/null
+++ b/PINNFramework/models/Finger_Net.py
@@ -0,0 +1,115 @@
+import torch
+import torch.nn as nn
+
+class FingerNet(nn.Module):
+ def __init__(self, lb, ub, numFeatures = 500, numLayers = 8, activation = torch.relu, normalize=True, scaling=1.):
+ torch.manual_seed(1234)
+ super(FingerNet, self).__init__()
+
+ self.numFeatures = numFeatures
+ self.numLayers = numLayers
+ self.lin_layers = nn.ModuleList()
+ self.lb = torch.tensor(lb).float()
+ self.ub = torch.tensor(ub).float()
+ self.activation = activation
+ self.normalize = normalize
+ self.scaling = scaling
+ self.init_layers()
+
+
+ def init_layers(self):
+ """
+ This function creates the torch layers and initialize them with xavier
+ :param self:
+ :return:
+ """
+ self.in_x = nn.ModuleList()
+ self.in_y = nn.ModuleList()
+ self.in_z = nn.ModuleList()
+ self.in_t = nn.ModuleList()
+ lenInput = 1
+ noInLayers = 3
+
+ self.in_x.append(nn.Linear(lenInput,self.numFeatures))
+ for _ in range(noInLayers):
+ self.in_x.append(nn.Linear(self.numFeatures, self.numFeatures))
+
+ self.in_y.append(nn.Linear(lenInput,self.numFeatures))
+ for _ in range(noInLayers):
+ self.in_y.append(nn.Linear(self.numFeatures, self.numFeatures))
+
+ self.in_z.append(nn.Linear(lenInput,self.numFeatures))
+ for _ in range(noInLayers):
+ self.in_z.append(nn.Linear(self.numFeatures, self.numFeatures))
+
+ self.in_t.append(nn.Linear(1,self.numFeatures))
+ for _ in range(noInLayers):
+ self.in_t.append(nn.Linear(self.numFeatures, self.numFeatures))
+
+        for branch in [self.in_x, self.in_y, self.in_z, self.in_t]:
+            # iterate over the layers inside each input branch
+            for m in branch:
+                if isinstance(m, nn.Linear):
+                    nn.init.xavier_uniform_(m.weight)
+                    #nn.init.constant_(m.bias, 0)
+
+ self.lin_layers.append(nn.Linear(4 * self.numFeatures, self.numFeatures))
+ for i in range(self.numLayers):
+ inFeatures = self.numFeatures
+ self.lin_layers.append(nn.Linear(inFeatures,self.numFeatures))
+ inFeatures = self.numFeatures
+ self.lin_layers.append(nn.Linear(inFeatures,1))
+ for m in self.lin_layers:
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x_in):
+ if self.normalize:
+            x_in = 2.0 * (x_in - self.lb) / (self.ub - self.lb) - 1.0  # use the normalized input for the branches below
+
+ x_inx = x_in[:,0].view(-1,1)
+ x_iny = x_in[:,1].view(-1,1)
+ x_inz = x_in[:,2].view(-1,1)
+ x_int = x_in[:,3].view(-1,1)
+
+ for i in range(0,len(self.in_x)):
+ x_inx = self.in_x[i](x_inx)
+ x_inx = self.activation(x_inx)
+
+
+ for i in range(0,len(self.in_y)):
+ x_iny = self.in_y[i](x_iny)
+ x_iny = self.activation(x_iny)
+
+ for i in range(0,len(self.in_z)):
+ x_inz = self.in_z[i](x_inz)
+ x_inz = self.activation(x_inz)
+
+ for i in range(0,len(self.in_t)):
+ x_int = self.in_t[i](x_int)
+ x_int = self.activation(x_int)
+
+
+ x = torch.cat([x_inx,x_iny,x_inz,x_int],1)
+
+
+ for i in range(0,len(self.lin_layers)-1):
+ x = self.lin_layers[i](x)
+ x = self.activation(x)
+ x = self.lin_layers[-1](x)
+
+ return self.scaling * x
+
+ def cuda(self):
+ super(FingerNet, self).cuda()
+ self.lb = self.lb.cuda()
+ self.ub = self.ub.cuda()
+
+ def cpu(self):
+ super(FingerNet, self).cpu()
+ self.lb = self.lb.cpu()
+ self.ub = self.ub.cpu()
+
+ def to(self, device):
+ super(FingerNet, self).to(device)
+        self.lb = self.lb.to(device)
+        self.ub = self.ub.to(device)
\ No newline at end of file
diff --git a/PINNFramework/models/__init__.py b/PINNFramework/models/__init__.py
index 00d6803..5550f1a 100644
--- a/PINNFramework/models/__init__.py
+++ b/PINNFramework/models/__init__.py
@@ -1,5 +1,17 @@
from .mlp import MLP
-
+from .distributed_moe import MoE as distMoe
+from .moe_mlp import MoE as MoE
+from .snake_mlp import SnakeMLP
+from .Finger_Net import FingerNet
+from .moe_finger import MoE as FingerMoE
+from . import activations
__all__ = [
- 'MLP'
+ 'MLP',
+ 'MoE',
+ 'distMoe',
+ 'SnakeMLP',
+ 'FingerNet',
+ 'FingerMoE',
+ 'activations'
+
]
diff --git a/PINNFramework/models/activations/__init__.py b/PINNFramework/models/activations/__init__.py
new file mode 100644
index 0000000..9977c26
--- /dev/null
+++ b/PINNFramework/models/activations/__init__.py
@@ -0,0 +1,5 @@
+from .snake import Snake
+
+__all__ = [
+ 'Snake'
+]
diff --git a/PINNFramework/models/activations/snake.py b/PINNFramework/models/activations/snake.py
new file mode 100644
index 0000000..ca819a1
--- /dev/null
+++ b/PINNFramework/models/activations/snake.py
@@ -0,0 +1,20 @@
+import torch
+import torch.nn as nn
+
+
+class Snake(nn.Module):
+ """ Implementation of the snake activation function as a torch nn module
+    The result of the activation function is calculated by a(x) = x + sin^2(a * x) / a,
+    where a is a trainable frequency parameter.
+ """
+
+    def __init__(self, frequency=10):
+        """Constructor function that initializes the torch module
+ """
+ super(Snake, self).__init__()
+
+        # making the frequency parameter a trainable by activating gradient calculation
+ self.a = nn.Parameter(torch.tensor([float(frequency)], requires_grad=True))
+
+ def forward(self, x):
+        return x + ((torch.sin(self.a * x)) ** 2) / self.a
\ No newline at end of file
diff --git a/PINNFramework/models/moe.py b/PINNFramework/models/distributed_moe.py
similarity index 86%
rename from PINNFramework/models/moe.py
rename to PINNFramework/models/distributed_moe.py
index 7a3f9c8..c7420d5 100644
--- a/PINNFramework/models/moe.py
+++ b/PINNFramework/models/distributed_moe.py
@@ -132,7 +132,9 @@ class MoE(nn.Module):
k: an integer - how many experts to use for each batch element
"""
- def __init__(self, input_size, output_size, num_experts, hidden_size, num_hidden, activation=torch.tanh, non_linear=False,noisy_gating=False, k=4, device = "cpu"):
+ def __init__(self, input_size, output_size, num_experts,
+ hidden_size, num_hidden, lb, ub, activation=torch.tanh,
+ non_linear=False, noisy_gating=False, k=1, device = "cpu"):
super(MoE, self).__init__()
self.noisy_gating = noisy_gating
self.num_experts = num_experts
@@ -142,12 +144,16 @@ def __init__(self, input_size, output_size, num_experts, hidden_size, num_hidden
self.device = device
self.k = k
self.loss = 0
-
- # instantiate experts
- self.experts = nn.ModuleList([MLP(input_size, output_size, hidden_size, num_hidden, activation) for i in range(self.num_experts)])
-
- self.w_gate = nn.Parameter(torch.randn(input_size, num_experts), requires_grad=True)
- self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)
+ self.num_devices = torch.cuda.device_count() - 1 # cuda:0 is for handling data
+ print("The model runs on {} devices".format(self.num_devices))
+ # instantiate experts on the needed GPUs
+ self.experts = nn.ModuleList([
+ MLP(input_size, output_size, hidden_size, num_hidden, lb, ub, activation,
+ device='cuda:{}'.format((i % self.num_devices)+1))
+            .to('cuda:{}'.format((i % self.num_devices)+1)) for i in range(self.num_experts)
+ ])
+ self.w_gate = nn.Parameter(torch.randn(input_size, num_experts, device=self.device), requires_grad=True)
+ self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts, device=self.device), requires_grad=True)
self.softplus = nn.Softplus()
self.softmax = nn.Softmax(1)
@@ -155,7 +161,8 @@ def __init__(self, input_size, output_size, num_experts, hidden_size, num_hidden
self.non_linear = non_linear
if self.non_linear:
- self.gating_network = MLP(input_size,num_experts,num_experts*2,1,activation=F.relu)
+ self.gating_network = MLP(input_size, num_experts, num_experts*2, 1, lb, ub,activation=F.relu,
+ device=self.device).to(self.device)
assert(self.k <= self.num_experts)
@@ -186,9 +193,6 @@ def _gates_to_load(self, gates):
"""
return (gates > 0).sum(0)
-
-
-
def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
"""Helper function to NoisyTopKGating.
Computes the probability that value is in top k, given different random noise.
@@ -210,11 +214,11 @@ def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_val
batch = clean_values.size(0)
m = noisy_top_values.size(1)
top_values_flat = noisy_top_values.flatten()
- threshold_positions_if_in = (torch.arange(batch) * m + self.k).to(self.device)#.cuda()
+ threshold_positions_if_in = (torch.arange(batch) * m + self.k).to(self.device)
threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
is_in = torch.gt(noisy_values, threshold_if_in)
- threshold_positions_if_out = (threshold_positions_if_in - 1).to(self.device)#.cuda()
- threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat,0 , threshold_positions_if_out), 1)
+ threshold_positions_if_out = (threshold_positions_if_in - 1).to(self.device)
+ threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
# is each value currently in the top k.
prob_if_in = self.normal.cdf((clean_values - threshold_if_in)/noise_stddev)
prob_if_out = self.normal.cdf((clean_values - threshold_if_out)/noise_stddev)
@@ -222,7 +226,7 @@ def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_val
return prob
- def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
+ def noisy_top_k_gating(self, x, train, noise_epsilon=1e-1):
"""Noisy top-k gating.
See paper: https://arxiv.org/abs/1701.06538.
Args:
@@ -243,7 +247,7 @@ def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
if(self.k > 1):
raw_noise_stddev = self.softplus(raw_noise_stddev)
noise_stddev = ((raw_noise_stddev + noise_epsilon) * train)
- noisy_logits = clean_logits + ( torch.randn_like(clean_logits) * noise_stddev)
+ noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
logits = noisy_logits
else:
logits = clean_logits
@@ -290,9 +294,18 @@ def forward(self, x, train=True, loss_coef=1e-2):
self.loss = loss
- dispatcher = SparseDispatcher(self.num_experts, gates)
+ dispatcher = SparseDispatcher(self.num_experts, gates, device=self.device)
expert_inputs = dispatcher.dispatch(x)
gates = dispatcher.expert_to_gates()
- expert_outputs = [self.experts[i](expert_inputs[i].unsqueeze(1)) for i in range(self.num_experts)]
+        # a loop is needed here for asynchronous calls of the GPUs
+ expert_outputs = []
+ for i in range(self.num_experts):
+ # move data to device
+ exp_input = expert_inputs[i].to(self.experts[i].device)
+ expert_output = self.experts[i](exp_input)
+ # move expert output back to device
+ expert_output = expert_output.to(self.device)
+            # append it for stitching
+ expert_outputs.append(expert_output)
y = dispatcher.combine(expert_outputs)
- return y #, loss
\ No newline at end of file
+ return y
diff --git a/PINNFramework/models/mlp.py b/PINNFramework/models/mlp.py
index 5b074f5..9981b75 100644
--- a/PINNFramework/models/mlp.py
+++ b/PINNFramework/models/mlp.py
@@ -3,13 +3,14 @@
class MLP(nn.Module):
- def __init__(self, input_size, output_size, hidden_size, num_hidden, lb, ub, activation=torch.tanh):
+ def __init__(self, input_size, output_size, hidden_size, num_hidden, lb, ub, activation=torch.tanh, normalize=True):
super(MLP, self).__init__()
self.linear_layers = nn.ModuleList()
self.activation = activation
self.init_layers(input_size, output_size, hidden_size,num_hidden)
self.lb = torch.Tensor(lb).float()
self.ub = torch.Tensor(ub).float()
+ self.normalize = normalize
def init_layers(self, input_size, output_size, hidden_size, num_hidden):
self.linear_layers.append(nn.Linear(input_size, hidden_size))
@@ -23,7 +24,8 @@ def init_layers(self, input_size, output_size, hidden_size, num_hidden):
nn.init.constant_(m.bias, 0)
def forward(self, x):
- x = 2.0*(x - self.lb)/(self.ub - self.lb) - 1.0
+ if self.normalize:
+ x = 2.0*(x - self.lb)/(self.ub - self.lb) - 1.0
for i in range(len(self.linear_layers) - 1):
x = self.linear_layers[i](x)
x = self.activation(x)
@@ -36,7 +38,11 @@ def cuda(self):
self.ub = self.ub.cuda()
def cpu(self):
- super(MLP, self).cuda()
+ super(MLP, self).cpu()
self.lb = self.lb.cpu()
self.ub = self.ub.cpu()
-
+
+ def to(self, device):
+        super(MLP, self).to(device)
+        self.lb = self.lb.to(device)
+        self.ub = self.ub.to(device)
+        return self
diff --git a/PINNFramework/models/moe_finger.py b/PINNFramework/models/moe_finger.py
new file mode 100644
index 0000000..2ed6659
--- /dev/null
+++ b/PINNFramework/models/moe_finger.py
@@ -0,0 +1,350 @@
+# Sparsely-Gated Mixture-of-Experts Layers.
+# See "Outrageously Large Neural Networks"
+# https://arxiv.org/abs/1701.06538
+#
+# Author: David Rau
+#
+# The code is based on the TensorFlow implementation:
+# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py
+
+import time
+import torch
+import torch.nn as nn
+from torch.distributions.normal import Normal
+import numpy as np
+import torch.nn.functional as F
+from PINNFramework.models import FingerNet
+from PINNFramework.models import MLP
+
+class SparseDispatcher(object):
+ """Helper for implementing a mixture of experts.
+ The purpose of this class is to create input minibatches for the
+ experts and to combine the results of the experts to form a unified
+ output tensor.
+ There are two functions:
+ dispatch - take an input Tensor and create input Tensors for each expert.
+ combine - take output Tensors from each expert and form a combined output
+ Tensor. Outputs from different experts for the same batch element are
+ summed together, weighted by the provided "gates".
+ The class is initialized with a "gates" Tensor, which specifies which
+ batch elements go to which experts, and the weights to use when combining
+ the outputs. Batch element b is sent to expert e iff gates[b, e] != 0.
+ The inputs and outputs are all two-dimensional [batch, depth].
+ Caller is responsible for collapsing additional dimensions prior to
+ calling this class and reshaping the output to the original shape.
+ See common_layers.reshape_like().
+ Example use:
+ gates: a float32 `Tensor` with shape `[batch_size, num_experts]`
+ inputs: a float32 `Tensor` with shape `[batch_size, input_size]`
+ experts: a list of length `num_experts` containing sub-networks.
+ dispatcher = SparseDispatcher(num_experts, gates)
+ expert_inputs = dispatcher.dispatch(inputs)
+ expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
+ outputs = dispatcher.combine(expert_outputs)
+ The preceding code sets the output for a particular example b to:
+ output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))
+ This class takes advantage of sparsity in the gate matrix by including in the
+ `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`.
+ """
+
+ def __init__(self, num_experts, gates, use_gpu=False):
+ """Create a SparseDispatcher."""
+ self.use_gpu=use_gpu
+ self._gates = gates
+ self._num_experts = num_experts
+ # sort experts
+ sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
+ # drop indices
+ _, self._expert_index = sorted_experts.split(1, dim=1)
+ # get according batch index for each expert
+ self._batch_index = sorted_experts[index_sorted_experts[:, 1],0]
+ # calculate num samples that each expert gets
+ if self.use_gpu:
+ self._part_sizes = list((gates > 0).sum(0).cuda())
+ else:
+ self._part_sizes = list((gates > 0).sum(0))
+
+ # expand gates to match with self._batch_index
+ gates_exp = gates[self._batch_index.flatten()]
+ self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)
+
+ def dispatch(self, inp):
+ """Create one input Tensor for each expert.
+ The `Tensor` for a expert `i` contains the slices of `inp` corresponding
+ to the batch elements `b` where `gates[b, i] > 0`.
+ Args:
+            inp: a `Tensor` of shape `[batch_size, ]`
+ Returns:
+ a list of `num_experts` `Tensor`s with shapes
+ `[expert_batch_size_i, ]`.
+ """
+
+ # assigns samples to experts whose gate is nonzero
+
+ # expand according to batch index so we can just split by _part_sizes
+ inp_exp = inp[self._batch_index].squeeze(1)
+ return torch.split(inp_exp, self._part_sizes, dim=0)
+
+
+ def combine(self, expert_out, multiply_by_gates=True):
+ """Sum together the expert output, weighted by the gates.
+ The slice corresponding to a particular batch element `b` is computed
+ as the sum over all experts `i` of the expert output, weighted by the
+ corresponding gate values. If `multiply_by_gates` is set to False, the
+ gate values are ignored.
+ Args:
+ expert_out: a list of `num_experts` `Tensor`s, each with shape
+ `[expert_batch_size_i, ]`.
+ multiply_by_gates: a boolean
+ Returns:
+ a `Tensor` with shape `[batch_size, ]`.
+ """
+        # apply exp to expert outputs, so we are no longer in log space
+ stitched = torch.cat(expert_out, 0).exp()
+
+ if multiply_by_gates:
+ stitched = stitched.mul(self._nonzero_gates)
+ zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True)
+ if self.use_gpu:
+ zeros = zeros.cuda()
+ # combine samples that have been processed by the same k experts
+ combined = zeros.index_add(0, self._batch_index, stitched.float())
+ # add eps to all zero values in order to avoid nans when going back to log space
+ combined[combined == 0] = np.finfo(float).eps
+ # back to log space
+ return combined.log()
+
+
+ def expert_to_gates(self):
+ """Gate values corresponding to the examples in the per-expert `Tensor`s.
+ Returns:
+ a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32`
+ and shapes `[expert_batch_size_i]`
+ """
+ # split nonzero gates for each expert
+ return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
+
+class MoE(nn.Module):
+
+ """Call a Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
+ Args:
+ input_size: integer - size of the input
+    output_size: integer - size of the output
+ num_experts: an integer - number of experts
+ hidden_size: an integer - hidden size of the experts
+ noisy_gating: a boolean
+ k: an integer - how many experts to use for each batch element
+ """
+
+ def __init__(self, input_size, output_size, num_experts,
+ hidden_size, num_hidden, lb, ub, activation=torch.tanh,
+ non_linear=False, noisy_gating=False, k=1, scaling_factor=1.):
+ super(MoE, self).__init__()
+ self.noisy_gating = noisy_gating
+ self.num_experts = num_experts
+ self.output_size = output_size
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.use_gpu = False
+ self.k = k
+ self.loss = 0
+ self.lb = torch.Tensor(lb).float()
+ self.ub = torch.Tensor(ub).float()
+ self.scaling_factor = scaling_factor
+
+ # instantiate experts
+        # normalization of the experts is disabled because the gating network performs the normalization
+ self.experts = nn.ModuleList([
+ FingerNet(lb, ub, hidden_size, num_hidden, activation, False)
+ for _ in range(self.num_experts)
+ ])
+
+ self.w_gate = nn.Parameter(torch.randn(input_size, num_experts), requires_grad=True)
+ self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)
+
+ self.softplus = nn.Softplus()
+ self.softmax = nn.Softmax(1)
+ if self.use_gpu:
+ self.normal = Normal(torch.tensor([0.0]).cuda(), torch.tensor([1.0]).cuda())
+ else:
+            self.normal = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
+
+ self.non_linear = non_linear
+ if self.non_linear:
+ self.gating_network = MLP(input_size,
+ num_experts,
+ num_experts*2,
+ 1,
+ lb,
+ ub,
+ activation=F.relu,
+ normalize=False)
+
+ assert(self.k <= self.num_experts)
+
+ def cv_squared(self, x):
+ """The squared coefficient of variation of a sample.
+ Useful as a loss to encourage a positive distribution to be more uniform.
+ Epsilons added for numerical stability.
+ Returns 0 for an empty Tensor.
+ Args:
+ x: a `Tensor`.
+ Returns:
+ a `Scalar`.
+ """
+ eps = 1e-10
+        # a single value (e.g. num_experts = 1) has no meaningful variation, so return 0
+ if x.shape[0] == 1:
+ return torch.Tensor([0])
+ return x.float().var() / (x.float().mean()**2 + eps)
+
+
+ def _gates_to_load(self, gates):
+ """Compute the true load per expert, given the gates.
+ The load is the number of examples for which the corresponding gate is >0.
+ Args:
+ gates: a `Tensor` of shape [batch_size, n]
+ Returns:
+ a float32 `Tensor` of shape [n]
+ """
+ return (gates > 0).sum(0)
+
+ def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
+ """Helper function to NoisyTopKGating.
+ Computes the probability that value is in top k, given different random noise.
+ This gives us a way of backpropagating from a loss that balances the number
+ of times each expert is in the top k experts per example.
+ In the case of no noise, pass in None for noise_stddev, and the result will
+ not be differentiable.
+ Args:
+ clean_values: a `Tensor` of shape [batch, n].
+ noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
+ normally distributed noise with standard deviation noise_stddev.
+ noise_stddev: a `Tensor` of shape [batch, n], or None
+          noisy_top_values: a `Tensor` of shape [batch, m].
+             the top `m` noisy logits per example (output of `topk`), with m >= k+1
+ Returns:
+ a `Tensor` of shape [batch, n].
+ """
+
+ batch = clean_values.size(0)
+ m = noisy_top_values.size(1)
+ top_values_flat = noisy_top_values.flatten()
+ if self.use_gpu:
+ threshold_positions_if_in = (torch.arange(batch) * m + self.k).cuda()
+ else:
+            threshold_positions_if_in = torch.arange(batch) * m + self.k
+ threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
+ is_in = torch.gt(noisy_values, threshold_if_in)
+ if self.use_gpu:
+ threshold_positions_if_out = (threshold_positions_if_in - 1).cuda()
+ else:
+            threshold_positions_if_out = threshold_positions_if_in - 1
+
+ threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
+ # is each value currently in the top k.
+ prob_if_in = self.normal.cdf((clean_values - threshold_if_in)/noise_stddev)
+ prob_if_out = self.normal.cdf((clean_values - threshold_if_out)/noise_stddev)
+ prob = torch.where(is_in, prob_if_in, prob_if_out)
+ return prob
+
+
+ def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
+ """Noisy top-k gating.
+ See paper: https://arxiv.org/abs/1701.06538.
+ Args:
+ x: input Tensor with shape [batch_size, input_size]
+ train: a boolean - we only add noise at training time.
+ noise_epsilon: a float
+ Returns:
+ gates: a Tensor with shape [batch_size, num_experts]
+ load: a Tensor with shape [num_experts]
+ """
+
+ if self.non_linear:
+ clean_logits = self.gating_network(x)
+ else:
+ clean_logits = x @ self.w_gate
+ if self.noisy_gating:
+ raw_noise_stddev = x @ self.w_noise
+ if(self.k > 1):
+ raw_noise_stddev = self.softplus(raw_noise_stddev)
+ noise_stddev = ((raw_noise_stddev + noise_epsilon) * train)
+ noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
+ logits = noisy_logits
+ else:
+ logits = clean_logits
+
+ # calculate topk + 1 that will be needed for the noisy gates
+ top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
+ top_k_logits = top_logits[:, :self.k]
+ top_k_indices = top_indices[:, :self.k]
+ if(self.k > 1):
+ top_k_gates = self.softmax(top_k_logits)
+ else:
+ top_k_gates = torch.sigmoid(top_k_logits)
+ zeros = torch.zeros_like(logits, requires_grad=True)
+ gates = zeros.scatter(1, top_k_indices, top_k_gates)
+
+ if self.noisy_gating and self.k < self.num_experts:
+ load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0)
+ else:
+ load = self._gates_to_load(gates)
+ return gates, load
+
+
+ def get_utilisation_loss(self):
+ return self.loss
+
+ def forward(self, x, train=True, loss_coef=1e-2):
+ """Args:
+ x: tensor shape [batch_size, input_size]
+ train: a boolean scalar.
+ loss_coef: a scalar - multiplier on load-balancing losses
+ Returns:
+ y: a tensor with shape [batch_size, output_size].
+        The load-balancing loss is not returned; it is stored in `self.loss` and
+        can be retrieved via `get_utilisation_loss()`. Adding it to the overall
+        training loss encourages all experts to be used approximately equally
+        across a batch.
+ """
+ # normalization is performed here for better convergence of the gating network
+ x = 2.0*(x - self.lb)/(self.ub - self.lb) - 1.0
+
+ gates, load = self.noisy_top_k_gating(x, train)
+ # calculate importance loss
+ importance = gates.sum(0)
+        # load-balancing loss: encourage uniform importance and load across experts
+ loss = self.cv_squared(importance) + self.cv_squared(load)
+ loss *= loss_coef
+ self.loss = loss
+
+ dispatcher = SparseDispatcher(self.num_experts, gates, self.use_gpu)
+ expert_inputs = dispatcher.dispatch(x)
+ gates = dispatcher.expert_to_gates()
+ expert_outputs = []
+ for i in range(self.num_experts):
+ if expert_inputs[i] is not None:
+ expert_outputs.append(self.experts[i](expert_inputs[i]))
+ y = dispatcher.combine(expert_outputs)
+ return self.scaling_factor * y
+
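+    # A minimal usage sketch (assumed example values, not part of the framework API):
+    #   moe = MoE(input_size=4, output_size=3, num_experts=5,
+    #             hidden_size=100, num_hidden=2, lb=[0.]*4, ub=[1.]*4, k=1)
+    #   y = moe(x)
+    #   total_loss = data_loss + moe.get_utilisation_loss()
+    #   total_loss.backward()
+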
+ def cuda(self):
+ super(MoE, self).cuda()
+ self.use_gpu = True
+        # iterate over all experts and move them to the GPU
+ for i in range(self.num_experts):
+ self.experts[i].cuda()
+ self.lb = self.lb.cuda()
+ self.ub = self.ub.cuda()
+ if self.non_linear:
+ self.gating_network.cuda()
+
+ def cpu(self):
+ super(MoE, self).cpu()
+ self.use_gpu = False
+ for i in range(self.num_experts):
+ self.experts[i].cpu()
+ self.lb = self.lb.cpu()
+ self.ub = self.ub.cpu()
+ if self.non_linear:
+ self.gating_network.cpu()
diff --git a/PINNFramework/models/moe_finger_test.py b/PINNFramework/models/moe_finger_test.py
new file mode 100644
index 0000000..a8eb2e6
--- /dev/null
+++ b/PINNFramework/models/moe_finger_test.py
@@ -0,0 +1,32 @@
+from PINNFramework.models import FingerMoE
+import numpy as np
+import torch
+
+if __name__ == "__main__":
+ lb = np.array([0, 0, 0, 0])
+ ub = np.array([1, 1, 1, 1])
+
+ # finger MoE gpu
+ moe = FingerMoE(4, 3, 5, 100, 2, lb, ub)
+ moe.cuda()
+ x_gpu = torch.randn(3, 4).cuda()
+ y_gpu = moe(x_gpu)
+ print(y_gpu)
+
+ # finger MoE cpu
+ moe.cpu()
+ x_cpu = torch.randn(3, 4)
+ y_cpu = moe(x_cpu)
+ print(y_cpu)
+
+    # non-linear gating test
+ moe = FingerMoE(4, 3, 5, 100, 2, lb, ub, non_linear=True)
+ moe.cuda()
+ x_gpu = torch.randn(3, 4).cuda()
+ y_gpu = moe(x_gpu)
+ print(y_gpu)
+ moe.cpu()
+ x_cpu = torch.randn(3, 4)
+ y_cpu = moe(x_cpu)
+ print(y_cpu)
+
diff --git a/PINNFramework/models/moe_mlp.py b/PINNFramework/models/moe_mlp.py
new file mode 100644
index 0000000..93c4491
--- /dev/null
+++ b/PINNFramework/models/moe_mlp.py
@@ -0,0 +1,361 @@
+# Sparsely-Gated Mixture-of-Experts Layers.
+# See "Outrageously Large Neural Networks"
+# https://arxiv.org/abs/1701.06538
+#
+# Author: David Rau
+#
+# The code is based on the TensorFlow implementation:
+# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py
+
+import time
+import torch
+import torch.nn as nn
+from torch.distributions.normal import Normal
+import numpy as np
+import torch.nn.functional as F
+from PINNFramework.models import MLP
+
+class SparseDispatcher(object):
+ """Helper for implementing a mixture of experts.
+ The purpose of this class is to create input minibatches for the
+ experts and to combine the results of the experts to form a unified
+ output tensor.
+ There are two functions:
+ dispatch - take an input Tensor and create input Tensors for each expert.
+ combine - take output Tensors from each expert and form a combined output
+ Tensor. Outputs from different experts for the same batch element are
+ summed together, weighted by the provided "gates".
+ The class is initialized with a "gates" Tensor, which specifies which
+ batch elements go to which experts, and the weights to use when combining
+ the outputs. Batch element b is sent to expert e iff gates[b, e] != 0.
+ The inputs and outputs are all two-dimensional [batch, depth].
+ Caller is responsible for collapsing additional dimensions prior to
+ calling this class and reshaping the output to the original shape.
+ See common_layers.reshape_like().
+ Example use:
+ gates: a float32 `Tensor` with shape `[batch_size, num_experts]`
+ inputs: a float32 `Tensor` with shape `[batch_size, input_size]`
+ experts: a list of length `num_experts` containing sub-networks.
+ dispatcher = SparseDispatcher(num_experts, gates)
+ expert_inputs = dispatcher.dispatch(inputs)
+ expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
+ outputs = dispatcher.combine(expert_outputs)
+ The preceding code sets the output for a particular example b to:
+ output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))
+ This class takes advantage of sparsity in the gate matrix by including in the
+ `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`.
+ """
+
+ def __init__(self, num_experts, gates, use_gpu=True):
+ """Create a SparseDispatcher."""
+ self.use_gpu = use_gpu
+ self._gates = gates
+ self._num_experts = num_experts
+ # sort experts
+ sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
+ # drop indices
+ _, self._expert_index = sorted_experts.split(1, dim=1)
+ # get according batch index for each expert
+ self._batch_index = sorted_experts[index_sorted_experts[:, 1],0]
+ # calculate num samples that each expert gets
+ if self.use_gpu:
+ self._part_sizes = list((gates > 0).sum(0).cuda())
+ else:
+ self._part_sizes = list((gates > 0).sum(0))
+
+ # expand gates to match with self._batch_index
+ gates_exp = gates[self._batch_index.flatten()]
+ self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)
+
+ def dispatch(self, inp):
+ """Create one input Tensor for each expert.
+        The `Tensor` for an expert `i` contains the slices of `inp` corresponding
+        to the batch elements `b` where `gates[b, i] > 0`.
+        Args:
+          inp: a `Tensor` of shape `[batch_size, ]`
+ Returns:
+ a list of `num_experts` `Tensor`s with shapes
+ `[expert_batch_size_i, ]`.
+ """
+
+ # assigns samples to experts whose gate is nonzero
+
+ # expand according to batch index so we can just split by _part_sizes
+ inp_exp = inp[self._batch_index].squeeze(1)
+ return torch.split(inp_exp, self._part_sizes, dim=0)
+
+
+ def combine(self, expert_out, multiply_by_gates=True):
+ """Sum together the expert output, weighted by the gates.
+ The slice corresponding to a particular batch element `b` is computed
+ as the sum over all experts `i` of the expert output, weighted by the
+ corresponding gate values. If `multiply_by_gates` is set to False, the
+ gate values are ignored.
+ Args:
+ expert_out: a list of `num_experts` `Tensor`s, each with shape
+ `[expert_batch_size_i, ]`.
+ multiply_by_gates: a boolean
+ Returns:
+ a `Tensor` with shape `[batch_size, ]`.
+ """
+        # apply exp to the expert outputs, so we are no longer in log space
+ stitched = torch.cat(expert_out, 0).exp()
+
+ if multiply_by_gates:
+ stitched = stitched.mul(self._nonzero_gates)
+ if self.use_gpu:
+ zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True).cuda()
+ else:
+ zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True)
+
+ # combine samples that have been processed by the same k experts
+ combined = zeros.index_add(0, self._batch_index, stitched.float())
+ # add eps to all zero values in order to avoid nans when going back to log space
+ combined[combined == 0] = np.finfo(float).eps
+ # back to log space
+ return combined.log()
+
+
+ def expert_to_gates(self):
+ """Gate values corresponding to the examples in the per-expert `Tensor`s.
+ Returns:
+          a list of `num_experts` one-dimensional `Tensor`s with type `torch.float32`
+ and shapes `[expert_batch_size_i]`
+ """
+ # split nonzero gates for each expert
+ return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
+
+class MoE(nn.Module):
+
+ """Call a Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
+ Args:
+ input_size: integer - size of the input
+        output_size: integer - size of the output
+ num_experts: an integer - number of experts
+ hidden_size: an integer - hidden size of the experts
+ noisy_gating: a boolean
+ k: an integer - how many experts to use for each batch element
+ """
+
+ def __init__(self, input_size, output_size, num_experts,
+ hidden_size, num_hidden, lb, ub, activation=torch.tanh,
+                 non_linear=False, noisy_gating=False, k=1):
+ super(MoE, self).__init__()
+ self.noisy_gating = noisy_gating
+ self.num_experts = num_experts
+ self.output_size = output_size
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.use_gpu = False
+ self.k = k
+ self.loss = 0
+ self.lb = torch.Tensor(lb).float()
+ self.ub = torch.Tensor(ub).float()
+
+ # instantiate experts
+        # normalization inside the expert MLPs is disabled because the gating network performs the normalization
+ self.experts = nn.ModuleList([
+ MLP(input_size, output_size, hidden_size, num_hidden, lb, ub, activation)
+ for _ in range(self.num_experts)
+ ])
+
+ self.w_gate = torch.randn(input_size, num_experts, requires_grad=True)
+ self.w_noise = torch.zeros(input_size, num_experts, requires_grad=True)
+
+ if self.use_gpu:
+ self.w_gate = self.w_gate.cuda()
+ self.w_noise = self.w_noise.cuda()
+
+ self.w_gate = torch.nn.Parameter(self.w_gate)
+ self.w_noise = torch.nn.Parameter(self.w_noise)
+
+ self.softplus = nn.Softplus()
+ self.softmax = nn.Softmax(1)
+
+ if self.use_gpu:
+ self.normal = Normal(torch.tensor([0.0]).cuda(), torch.tensor([1.0]).cuda())
+ else:
+ self.normal = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
+
+ self.non_linear = non_linear
+ if self.non_linear:
+ self.gating_network = MLP(input_size,
+ num_experts,
+ num_experts*2,
+ 1,
+ lb,
+ ub,
+ activation=F.relu,
+ normalize=False)
+ if self.use_gpu:
+ self.gating_network.cuda()
+
+ assert(self.k <= self.num_experts)
+
+ def cv_squared(self, x):
+ """The squared coefficient of variation of a sample.
+ Useful as a loss to encourage a positive distribution to be more uniform.
+ Epsilons added for numerical stability.
+ Returns 0 for an empty Tensor.
+ Args:
+ x: a `Tensor`.
+ Returns:
+ a `Scalar`.
+ """
+ eps = 1e-10
+        # a single value (e.g. num_experts = 1) has no meaningful variation, so return 0
+ if x.shape[0] == 1:
+ return torch.Tensor([0])
+ return x.float().var() / (x.float().mean()**2 + eps)
+
+
+ def _gates_to_load(self, gates):
+ """Compute the true load per expert, given the gates.
+ The load is the number of examples for which the corresponding gate is >0.
+ Args:
+ gates: a `Tensor` of shape [batch_size, n]
+ Returns:
+ a float32 `Tensor` of shape [n]
+ """
+ return (gates > 0).sum(0)
+
+ def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
+ """Helper function to NoisyTopKGating.
+ Computes the probability that value is in top k, given different random noise.
+ This gives us a way of backpropagating from a loss that balances the number
+ of times each expert is in the top k experts per example.
+ In the case of no noise, pass in None for noise_stddev, and the result will
+ not be differentiable.
+ Args:
+ clean_values: a `Tensor` of shape [batch, n].
+ noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
+ normally distributed noise with standard deviation noise_stddev.
+ noise_stddev: a `Tensor` of shape [batch, n], or None
+          noisy_top_values: a `Tensor` of shape [batch, m].
+             the top `m` noisy logits per example (output of `topk`), with m >= k+1
+ Returns:
+ a `Tensor` of shape [batch, n].
+ """
+
+ batch = clean_values.size(0)
+ m = noisy_top_values.size(1)
+ top_values_flat = noisy_top_values.flatten()
+ if self.use_gpu:
+ threshold_positions_if_in = (torch.arange(batch) * m + self.k).cuda()
+ else:
+ threshold_positions_if_in = (torch.arange(batch) * m + self.k)
+
+ threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
+ is_in = torch.gt(noisy_values, threshold_if_in)
+ if self.use_gpu:
+ threshold_positions_if_out = (threshold_positions_if_in - 1).cuda()
+ else:
+ threshold_positions_if_out = (threshold_positions_if_in - 1)
+ threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
+ # is each value currently in the top k.
+ prob_if_in = self.normal.cdf((clean_values - threshold_if_in)/noise_stddev)
+ prob_if_out = self.normal.cdf((clean_values - threshold_if_out)/noise_stddev)
+ prob = torch.where(is_in, prob_if_in, prob_if_out)
+ return prob
+
+
+ def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
+ """Noisy top-k gating.
+ See paper: https://arxiv.org/abs/1701.06538.
+ Args:
+ x: input Tensor with shape [batch_size, input_size]
+ train: a boolean - we only add noise at training time.
+ noise_epsilon: a float
+ Returns:
+ gates: a Tensor with shape [batch_size, num_experts]
+ load: a Tensor with shape [num_experts]
+ """
+
+ if self.non_linear:
+ clean_logits = self.gating_network(x)
+ else:
+ clean_logits = x @ self.w_gate
+ if self.noisy_gating:
+ raw_noise_stddev = x @ self.w_noise
+ if(self.k > 1):
+ raw_noise_stddev = self.softplus(raw_noise_stddev)
+ noise_stddev = ((raw_noise_stddev + noise_epsilon) * train)
+ noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
+ logits = noisy_logits
+ else:
+ logits = clean_logits
+
+ # calculate topk + 1 that will be needed for the noisy gates
+ top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
+ top_k_logits = top_logits[:, :self.k]
+ top_k_indices = top_indices[:, :self.k]
+ if(self.k > 1):
+ top_k_gates = self.softmax(top_k_logits)
+ else:
+ top_k_gates = torch.sigmoid(top_k_logits)
+ zeros = torch.zeros_like(logits, requires_grad=True)
+ gates = zeros.scatter(1, top_k_indices, top_k_gates)
+
+ if self.noisy_gating and self.k < self.num_experts:
+ load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0)
+ else:
+ load = self._gates_to_load(gates)
+ return gates, load
+
+
+ def get_utilisation_loss(self):
+ return self.loss
+
+ def forward(self, x, train=True, loss_coef=1e-2):
+ """Args:
+ x: tensor shape [batch_size, input_size]
+ train: a boolean scalar.
+ loss_coef: a scalar - multiplier on load-balancing losses
+ Returns:
+ y: a tensor with shape [batch_size, output_size].
+        The load-balancing loss is not returned; it is stored in `self.loss` and
+        can be retrieved via `get_utilisation_loss()`. Adding it to the overall
+        training loss encourages all experts to be used approximately equally
+        across a batch.
+ """
+ # normalization is performed here for better convergence of the gating network
+ x = 2.0*(x - self.lb)/(self.ub - self.lb) - 1.0
+ gates, load = self.noisy_top_k_gating(x, train)
+ # calculate importance loss
+ importance = gates.sum(0)
+        # load-balancing loss: encourage uniform importance and load across experts
+ loss = self.cv_squared(importance) + self.cv_squared(load)
+ loss *= loss_coef
+ self.loss = loss
+
+ dispatcher = SparseDispatcher(self.num_experts, gates, self.use_gpu)
+ expert_inputs = dispatcher.dispatch(x)
+ gates = dispatcher.expert_to_gates()
+ expert_outputs = []
+ for i in range(self.num_experts):
+ if expert_inputs[i] is not None:
+ expert_outputs.append(self.experts[i](expert_inputs[i]))
+ y = dispatcher.combine(expert_outputs)
+ return y
+
+ def cuda(self):
+ super(MoE, self).cuda()
+ self.use_gpu = True
+ for i in range(self.num_experts):
+ self.experts[i].cuda()
+ self.lb = self.lb.cuda()
+ self.ub = self.ub.cuda()
+ if self.non_linear:
+ self.gating_network.cuda()
+
+ def cpu(self):
+ super(MoE, self).cpu()
+ self.use_gpu = False
+ for i in range(self.num_experts):
+ self.experts[i].cpu()
+ self.lb = self.lb.cpu()
+ self.ub = self.ub.cpu()
+ if self.non_linear:
+ self.gating_network.cpu()
+
+
+
diff --git a/PINNFramework/models/moe_mlp_test.py b/PINNFramework/models/moe_mlp_test.py
new file mode 100644
index 0000000..3d66740
--- /dev/null
+++ b/PINNFramework/models/moe_mlp_test.py
@@ -0,0 +1,29 @@
+from PINNFramework.models import MoE
+import numpy as np
+import torch
+
+if __name__ == "__main__":
+ lb = np.array([0, 0, 0])
+ ub = np.array([1, 1, 1])
+
+ # linear gating test
+ moe = MoE(3, 3, 5, 100, 2, lb, ub)
+ moe.cuda()
+ x_gpu = torch.randn(3, 3).cuda()
+ y_gpu = moe(x_gpu)
+ print(y_gpu)
+ moe.cpu()
+ x_cpu = torch.randn(3, 3)
+ y_cpu = moe(x_cpu)
+ print(y_cpu)
+
+    # non-linear gating test
+ moe = MoE(3, 3, 5, 100, 2, lb, ub, non_linear=True)
+ moe.cuda()
+ x_gpu = torch.randn(3, 3).cuda()
+ y_gpu = moe(x_gpu)
+ print(y_gpu)
+ moe.cpu()
+ x_cpu = torch.randn(3, 3)
+ y_cpu = moe(x_cpu)
+ print(y_cpu)
diff --git a/PINNFramework/models/snake_mlp.py b/PINNFramework/models/snake_mlp.py
new file mode 100644
index 0000000..5131631
--- /dev/null
+++ b/PINNFramework/models/snake_mlp.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+import math
+from .mlp import MLP
+from .activations.snake import Snake
+
+class SnakeMLP(MLP):
+ def __init__(self, input_size, output_size, hidden_size, num_hidden, lb, ub, frequency, normalize=True):
+        # call nn.Module.__init__ via super(MLP, self) on purpose: this skips
+        # MLP's own layer construction, since SnakeMLP builds its layers below
+        super(MLP, self).__init__()
+ self.linear_layers = nn.ModuleList()
+ self.activation = nn.ModuleList()
+ self.init_layers(input_size, output_size, hidden_size, num_hidden, frequency)
+ self.lb = torch.tensor(lb).float()
+ self.ub = torch.tensor(ub).float()
+ self.normalize = normalize
+
+
+ def init_layers(self, input_size, output_size, hidden_size, num_hidden, frequency):
+ self.linear_layers.append(nn.Linear(input_size, hidden_size))
+ self.activation.append(Snake(frequency=frequency))
+ for _ in range(num_hidden):
+ self.linear_layers.append(nn.Linear(hidden_size, hidden_size))
+ self.activation.append(Snake(frequency=frequency))
+ self.linear_layers.append(nn.Linear(hidden_size, output_size))
+
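+        # Uniform init with bound sqrt(3 / fan): Var(U(-b, b)) = b^2 / 3, so each
+        # weight gets variance 1 / fan; note that fan is taken from weight.size(0),
+        # i.e. the layer's output dimension, in this implementation.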
+ for m in self.linear_layers:
+ if isinstance(m, nn.Linear):
+ bound = math.sqrt(3 / m.weight.size()[0])
+ torch.nn.init.uniform_(m.weight, a=-bound, b=bound)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ if self.normalize:
+ x = 2.0*(x - self.lb)/(self.ub - self.lb) - 1.0
+ for i in range(len(self.linear_layers) - 1):
+ x = self.linear_layers[i](x)
+ x = self.activation[i](x)
+ x = self.linear_layers[-1](x)
+ return x
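+
+# A minimal usage sketch (assumed example values, not part of the framework API):
+#   model = SnakeMLP(input_size=3, output_size=1, hidden_size=50, num_hidden=4,
+#                    lb=[0., 0., 0.], ub=[1., 1., 1.], frequency=1.0)
+#   y = model(torch.randn(16, 3))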
diff --git a/benchmarks/hvd_test.py b/benchmarks/hvd_test.py
new file mode 100644
index 0000000..5e55834
--- /dev/null
+++ b/benchmarks/hvd_test.py
@@ -0,0 +1,6 @@
+import horovod.torch as hvd
+
+if __name__ == "__main__":
+    # test whether Horovod is correctly installed
+    hvd.init()
+    print("Hello from rank {}".format(hvd.rank()))
\ No newline at end of file
diff --git a/benchmarks/moe_benchmark.py b/benchmarks/moe_benchmark.py
new file mode 100644
index 0000000..e33299f
--- /dev/null
+++ b/benchmarks/moe_benchmark.py
@@ -0,0 +1,49 @@
+import torch
+torch.manual_seed(0)
+import time
+import numpy as np
+np.random.seed(0)
+from argparse import ArgumentParser
+import matplotlib.pyplot as plt
+import sys
+
+sys.path.append('..')
+parser = ArgumentParser()
+parser.add_argument("--distributed", dest="distributed", type=int, default=0)
+args = parser.parse_args()
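+
+# Example invocation (assumed): `python moe_benchmark.py --distributed 0` benchmarks the
+# single-process MoE defined in moe_mlp.py; `--distributed 1` uses the distributed variant.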
+
+if args.distributed:
+ from PINNFramework.models.distributed_moe import MoE
+else:
+ from PINNFramework.models.moe_mlp import MoE
+
+if __name__ == "__main__":
+ model = MoE(3, 3, 7, 300, 5, lb=[0, 0, 0], ub=[1, 1, 1], device='cuda:0', k=1).eval()
+ times = []
+    for i in range(100000, 10100000, 100000):
+ x = torch.randn((i, 3)).cuda()
+ torch.cuda.synchronize()
+ begin_time = time.time()
+ model.forward(x, train=False)
+ torch.cuda.synchronize()
+ end_time = time.time()
+ run_time = (end_time - begin_time)
+ print("For {} samples: {} sec".format(i, run_time))
+        times.append([i, run_time])
+ del x
+ torch.cuda.empty_cache()
+ if args.distributed:
+ np.save("dist_run_time", times)
+ else:
+ np.save("non_dist_run_time", times)
+
+
+ """
+ times = np.array(times)
+ plt.scatter(times[1:, 0],times[1:, 1], s=1)
+ plt.xlabel("Number of input samples")
+ plt.ylabel("Inference time")
+ plt.show()
+ """
+