diff --git a/PINNFramework/BoundaryCondition.py b/PINNFramework/BoundaryCondition.py index 55f5107..fdb748b 100644 --- a/PINNFramework/BoundaryCondition.py +++ b/PINNFramework/BoundaryCondition.py @@ -4,9 +4,8 @@ class BoundaryCondition(LossTerm): - def __init__(self, name, dataset, norm='L2', weight=1.): - self.name = name - super(BoundaryCondition, self).__init__(dataset, norm, weight) + def __init__(self, dataset, name, norm='L2', weight=1.): + super(BoundaryCondition, self).__init__(dataset, name, norm, weight) def __call__(self, *args, **kwargs): raise NotImplementedError("The call function of the Boundary Condition has to be implemented") @@ -18,7 +17,7 @@ class DirichletBC(BoundaryCondition): """ def __init__(self, func, dataset, name, norm='L2',weight=1.): - super(DirichletBC, self).__init__(name, dataset, norm, weight) + super(DirichletBC, self).__init__(dataset, name, norm, weight) self.func = func def __call__(self, x, model): @@ -29,19 +28,37 @@ def __call__(self, x, model): class NeumannBC(BoundaryCondition): """ Neumann boundary conditions: dy/dn(x) = func(x). + + With dy/dn(x) = <∇y,n> """ - def __init__(self, func, dataset, input_dimension, output_dimension, name, norm='L2',weight=1.): - super(NeumannBC, self).__init__(name, dataset, norm, weight) + def __init__(self, func, dataset, normal_vector, begin, end, output_dimension, name, norm='L2', weight=1.): + """ + Args: + func: scalar but vectorized function f(x) + normal_vector: normal vector for the face + name: identifier of the boundary condition + weight: weighting of the boundary condition + begin: defines the begin of spatial variables in x + end: defines the end of the spatial domain in x + output_dimension defines on which dimension of the output the boundary condition performed + """ + super(NeumannBC, self).__init__(dataset, name, norm, weight) self.func = func - self.input_dimension = input_dimension + self.normal_vector = normal_vector + self.begin = begin + self.end = end self.output_dimension = output_dimension def __call__(self, x, model): - grads = ones(x.shape, device=model.device) - y = model(x)[:, self.output_dimension] + x.requires_grad = True + y = model(x) + y = y[:, self.output_dimension] + grads = ones(y.shape, device=y.device) grad_y = grad(y, x, create_graph=True, grad_outputs=grads)[0] - y_dn = grad_y[:, self.input_dimension] + grad_y = grad_y[:,self.begin:self.end] + self.normal_vector.to(y.device) # move normal vector to the correct device + y_dn = grad_y @ self.normal_vector return self.weight * self.norm(y_dn, self.func(x)) @@ -50,17 +67,34 @@ class RobinBC(BoundaryCondition): Robin boundary conditions: dy/dn(x) = func(x, y). """ - def __init__(self, func, dataset, input_dimension, output_dimension, name, norm='L2', weight=1.): - super(RobinBC, self).__init__(name, dataset, norm, weight) + def __init__(self, func, dataset, normal_vector, begin, end, output_dimension, name, norm='L2', weight=1.): + """ + Args: + func: scalar but vectorized function f(x,y) + normal_vector: normal vector for the face + name: identifier of the boundary condition + weight: weighting of the boundary condition + begin: defines the begin of spatial variables in x + end: defines the end of the spatial domain in x + output_dimension defines on which dimension of the output the boundary condition performed + """ + + super(RobinBC, self).__init__(dataset, name, norm, weight) self.func = func - self.input_dimension = input_dimension + self.begin = begin + self.end = end + self.normal_vector = normal_vector self.output_dimension = output_dimension def __call__(self, x, y, model): - y = model(x)[:, self.output_dimension] + x.requires_grad = True + y = model(x) + y = y[:, self.output_dimension] grads = ones(y.shape, device=y.device) grad_y = grad(y, x, create_graph=True, grad_outputs=grads)[0] - y_dn = grad_y[:, self.input_dimension] + grad_y = grad_y[:, self.begin:self.end] + self.normal_vector.to(y.device) # move normal vector to the correct device + y_dn = grad_y @ self.normal_vector return self.weight * self.norm(y_dn, self.func(x, y)) @@ -70,7 +104,7 @@ class PeriodicBC(BoundaryCondition): """ def __init__(self, dataset, output_dimension, name, degree=None, input_dimension=None, norm='L2', weight=1.): - super(PeriodicBC, self).__init__(name, dataset, norm, weight) + super(PeriodicBC, self).__init__(dataset, name, norm, weight) if degree is not None and input_dimension is None: raise ValueError("If the degree of the boundary condition is defined the input dimension for the " "derivative has to be defined too ") @@ -95,3 +129,22 @@ def __call__(self, x_lb, x_ub, model): else: raise NotImplementedError("Periodic Boundary Condition for a higher degree than one is not supported") + + +class TimeDerivativeBC(BoundaryCondition): + """ + For hyperbolic systems it may be needed to initialize the time derivative. This boundary condition intializes + the time derivative in a data driven way. + + """ + def __init__(self, dataset, name, norm='L2', weight=1): + super(TimeDerivativeBC, self).__init__(dataset, name, norm, weight) + + def __call__(self, x, dt_y, model): + x.requires_grad = True + pred = model(x) + grads = ones(pred.shape, device=pred.device) + pred_dt = grad(pred, x, create_graph=True, grad_outputs=grads)[0][:, -1] + pred_dt = pred_dt.reshape(-1,1) + return self.weight * self.norm(pred_dt, dt_y) + diff --git a/PINNFramework/HPMLoss.py b/PINNFramework/HPMLoss.py index 318ea29..c29b36f 100644 --- a/PINNFramework/HPMLoss.py +++ b/PINNFramework/HPMLoss.py @@ -1,7 +1,7 @@ from .PDELoss import PDELoss class HPMLoss(PDELoss): - def __init__(self, dataset, hpm_input, hpm_model, norm='L2', weight=1.): + def __init__(self, dataset, name, hpm_input, hpm_model, norm='L2', weight=1.): """ Constructor of the HPM loss @@ -13,7 +13,7 @@ def __init__(self, dataset, hpm_input, hpm_model, norm='L2', weight=1.): norm: Norm used for calculation PDE loss weight: Weighting for the loss term """ - super(HPMLoss, self).__init__(dataset, None, norm, weight) + super(HPMLoss, self).__init__(dataset, None, name, norm='L2', weight=1.) self.hpm_input = hpm_input self.hpm_model = hpm_model diff --git a/PINNFramework/InitalCondition.py b/PINNFramework/InitalCondition.py index 502855e..8031083 100644 --- a/PINNFramework/InitalCondition.py +++ b/PINNFramework/InitalCondition.py @@ -4,16 +4,16 @@ class InitialCondition(LossTerm): - def __init__(self, dataset, norm='L2', weight=1.): + def __init__(self, dataset, name, norm='L2', weight=1.): """ - Constructor for the Intial condition + Constructor for the Initial condition Args: dataset (torch.utils.Dataset): dataset that provides the residual points norm: Norm used for calculation PDE loss weight: Weighting for the loss term """ - super(InitialCondition, self).__init__(dataset, norm, weight) + super(InitialCondition, self).__init__(dataset, name, norm, weight) def __call__(self, x: Tensor, model: Module, gt_y: Tensor): """ diff --git a/PINNFramework/JoinedDataset.py b/PINNFramework/JoinedDataset.py index 4d9a33e..0b40cf1 100644 --- a/PINNFramework/JoinedDataset.py +++ b/PINNFramework/JoinedDataset.py @@ -19,12 +19,30 @@ def min_length(datasets): minimum = length return minimum - def __init__(self, datasets): + @staticmethod + def max_length(datasets): + """ + Calculates the minimum dataset length of a list of datasets + + datasets (Map): Map of datasets to be concatenated + """ + maximum = -1 * float("inf") + for key in datasets.keys(): + length = len(datasets[key]) + if length > maximum: + maximum = length + return maximum + + def __init__(self, datasets, mode='min'): super(JoinedDataset, self).__init__() self.datasets = datasets + self.mode = mode def __len__(self): - return self.min_length(self.datasets) + if self.mode =='min': + return self.min_length(self.datasets) + if self.mode =='max': + return self.max_length(self.datasets) def __getitem__(self, idx): if idx < 0: @@ -32,6 +50,8 @@ def __getitem__(self, idx): raise ValueError("absolute value of index should not exceed dataset length") combined_item = {} for key in self.datasets.keys(): + if self.mode == 'max': + idx = idx % len(self.datasets[key]) item = self.datasets[key][idx] combined_item[key] = item return combined_item diff --git a/PINNFramework/Logger_Interface.py b/PINNFramework/Logger_Interface.py new file mode 100644 index 0000000..1523697 --- /dev/null +++ b/PINNFramework/Logger_Interface.py @@ -0,0 +1,45 @@ +from abc import ABC, abstractmethod + + +class LoggerInterface(ABC): + + @abstractmethod + def log_scalar(self, scalar, name, epoch): + """ + Method that defines how scalars are logged + + Args: + scalar: scalar to be logged + name: name of the scalar + epoch: epoch in the training loop + + """ + pass + + @abstractmethod + def log_image(self, image, name, epoch): + """ + Method that defines how images are logged + + Args: + image: image to be logged + name: name of the image + epoch: epoch in the training loop + + """ + pass + + @abstractmethod + def log_histogram(self, histogram, name, epoch): + """ + Method that defines how images are logged + + Args: + histogram: histogram to be logged + name: name of the histogram + epoch: epoch in the training loop + + """ + pass + + diff --git a/PINNFramework/LossTerm.py b/PINNFramework/LossTerm.py index 10f84c5..6b5654d 100644 --- a/PINNFramework/LossTerm.py +++ b/PINNFramework/LossTerm.py @@ -5,13 +5,12 @@ class LossTerm: """ Defines the main structure of a loss term """ - def __init__(self, dataset, norm='L2', weight=1.): + def __init__(self, dataset, name, norm='L2', weight=1.): """ Constructor of a loss term Args: dataset (torch.utils.Dataset): dataset that provides the residual points - pde (function): function that represents residual of the PDE norm: Norm used for calculation PDE loss weight: Weighting for the loss term """ @@ -24,4 +23,5 @@ def __init__(self, dataset, norm='L2', weight=1.): # Case for self implemented norms self.norm = norm self.dataset = dataset + self.name = name self.weight = weight diff --git a/PINNFramework/PDELoss.py b/PINNFramework/PDELoss.py index 0b4f26b..4ca8fc2 100644 --- a/PINNFramework/PDELoss.py +++ b/PINNFramework/PDELoss.py @@ -1,12 +1,11 @@ import torch from torch import Tensor as Tensor from torch.nn import Module as Module -from torch.nn import MSELoss, L1Loss from .LossTerm import LossTerm class PDELoss(LossTerm): - def __init__(self, dataset, pde, norm='L2', weight=1.): + def __init__(self, dataset, pde, name, norm='L2', weight=1.): """ Constructor of the PDE Loss @@ -16,7 +15,7 @@ def __init__(self, dataset, pde, norm='L2', weight=1.): norm: Norm used for calculation PDE loss weight: Weighting for the loss term """ - super(PDELoss, self).__init__(dataset, norm, weight) + super(PDELoss, self).__init__(dataset, name, norm, weight) self.dataset = dataset self.pde = pde diff --git a/PINNFramework/PINN.py b/PINNFramework/PINN.py index aa5fcf9..086aa7f 100644 --- a/PINNFramework/PINN.py +++ b/PINNFramework/PINN.py @@ -1,23 +1,35 @@ import torch import torch.nn as nn +import numpy as np +from os.path import exists from itertools import chain from torch.utils.data import DataLoader from .InitalCondition import InitialCondition -from .BoundaryCondition import BoundaryCondition, PeriodicBC, DirichletBC, NeumannBC, RobinBC +from .BoundaryCondition import BoundaryCondition, PeriodicBC, DirichletBC, NeumannBC, RobinBC, TimeDerivativeBC from .PDELoss import PDELoss from .JoinedDataset import JoinedDataset from .HPMLoss import HPMLoss +from torch.autograd import grad as grad +from PINNFramework.callbacks import CallbackList try: import horovod.torch as hvd except: print("Was not able to import Horovod. Thus Horovod support is not enabled") +# set initial seed for torch and numpy +torch.manual_seed(42) +np.random.seed(42) + + +def worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0] + worker_id) + class PINN(nn.Module): def __init__(self, model: torch.nn.Module, input_dimension: int, output_dimension: int, pde_loss: PDELoss, initial_condition: InitialCondition, boundary_condition, - use_gpu=True, use_horovod=False): + use_gpu=True, use_horovod=False,dataset_mode='min'): """ Initializes an physics-informed neural network (PINN). A PINN consists of a model which represents the solution of the underlying partial differential equation(PDE) u, three loss terms representing initial (IC) and boundary @@ -33,21 +45,25 @@ def __init__(self, model: torch.nn.Module, input_dimension: int, output_dimensio of the BoundaryCondition class use_gpu: enables gpu usage use_horovod: enables horovod support + dataset_mode: defines the behavior of the joined dataset. The 'min'-mode sets the length of the dataset to + the minimum of the """ - super(PINN, self).__init__() # checking if the model is a torch module more model checking should be possible self.use_gpu = use_gpu self.use_horovod = use_horovod - self.rank = 0 # initialize rank 0 by default in order to make the fit method more flexible + self.rank = 0 # initialize rank 0 by default in order to make the fit method more flexible + if self.use_horovod: + # Initialize Horovod hvd.init() # Pin GPU to be used to process local rank (one GPU per process) torch.cuda.set_device(hvd.local_rank()) self.rank = hvd.rank() - + if self.rank == 0: + self.loss_log = {} if isinstance(model, nn.Module): self.model = model if self.use_gpu: @@ -80,7 +96,7 @@ def __init__(self, model: torch.nn.Module, input_dimension: int, output_dimensio else: raise TypeError("PDE loss has to be an instance of a PDE Loss class") - if isinstance(pde_loss,HPMLoss): + if isinstance(pde_loss, HPMLoss): self.is_hpm = True if isinstance(initial_condition, InitialCondition): @@ -88,26 +104,45 @@ def __init__(self, model: torch.nn.Module, input_dimension: int, output_dimensio else: raise TypeError("Initial condition has to be an instance of the InitialCondition class") - joined_datasets = {"Initial_Condition": initial_condition.dataset, "PDE": pde_loss.dataset} + joined_datasets = { + initial_condition.name: initial_condition.dataset, + pde_loss.name: pde_loss.dataset + } + if self.rank == 0: + self.loss_log[initial_condition.name] = float(0.0) # adding initial condition to the loss_log + self.loss_log[pde_loss.name] = float(0.0) + self.loss_log["model_loss_pinn"] = float(0.0) + if not self.is_hpm: if type(boundary_condition) is list: for bc in boundary_condition: if not isinstance(bc, BoundaryCondition): raise TypeError("Boundary Condition has to be an instance of the BoundaryCondition class ") - self.boundary_condition = boundary_condition joined_datasets[bc.name] = bc.dataset - + if self.rank == 0: + self.loss_log[bc.name] = float(0.0) + self.boundary_condition = boundary_condition else: if isinstance(boundary_condition, BoundaryCondition): self.boundary_condition = boundary_condition + joined_datasets[boundary_condition.name] = boundary_condition.dataset else: raise TypeError("Boundary Condition has to be an instance of the BoundaryCondition class" "or a list of instances of the BoundaryCondition class") - self.dataset = JoinedDataset(joined_datasets) + self.dataset = JoinedDataset(joined_datasets, dataset_mode) + + def loss_grad_std_wn(self, loss): + device = torch.device("cuda" if self.use_gpu else "cpu") + grad_ = torch.zeros((0), dtype=torch.float32, device=device) + model_grads = grad(loss, self.model.parameters(), allow_unused=True, retain_graph=True) + for elem in model_grads: + if elem is not None: + grad_ = torch.cat((grad_, elem.view(-1))) + return torch.std(grad_) def forward(self, x): """ - Predicting the solution at given position x + Predicting the solution at given pos """ return self.model(x) @@ -162,7 +197,7 @@ def calculate_boundary_condition(self, boundary_condition: BoundaryCondition, tr else: raise ValueError( "The boundary condition {} has to be tuple of coordinates for lower and upper bound". - format(boundary_condition.name)) + format(boundary_condition.name)) else: raise ValueError("The boundary condition {} has to be tuple of coordinates for lower and upper bound". format(boundary_condition.name)) @@ -190,12 +225,27 @@ def calculate_boundary_condition(self, boundary_condition: BoundaryCondition, tr else: raise ValueError( "The boundary condition {} has to be tuple of coordinates for lower and upper bound". - format(boundary_condition.name)) + format(boundary_condition.name)) else: raise ValueError("The boundary condition {} has to be tuple of coordinates for lower and upper bound". format(boundary_condition.name)) - def pinn_loss(self, training_data): + if isinstance(boundary_condition, TimeDerivativeBC): + # Robin Boundary Condition + if isinstance(training_data, list): + if len(training_data) == 2: + return boundary_condition(training_data[0][0].type(self.dtype), + training_data[1][0].type(self.dtype), + self.model) + else: + raise ValueError( + "The boundary condition {} has to be tuple of coordinates for input data and gt time derivative". + format(boundary_condition.name)) + else: + raise ValueError("The boundary condition {} has to be tuple of coordinates for lower and upper bound". + format(boundary_condition.name)) + + def pinn_loss(self, training_data, annealing=False): """ Function for calculating the PINN loss. The PINN Loss is a weighted sum of losses for initial and boundary condition and the residual of the PDE @@ -205,37 +255,120 @@ def pinn_loss(self, training_data): dictionary holds the training data for initial condition at the key "Initial_Condition" training data for the PDE at the key "PDE" and the data for the boundary condition under the name of the boundary condition """ - pinn_loss = 0 # unpack training data - if type(training_data["Initial_Condition"]) is list: + # ============== PDE LOSS ============== " + if type(training_data[self.pde_loss.name]) is not list: + pde_loss = self.pde_loss(training_data[self.pde_loss.name][0].type(self.dtype), self.model) + if annealing: + std_pde = self.loss_grad_std_wn(pde_loss) + pinn_loss = pinn_loss + pde_loss + if self.rank == 0: + self.loss_log[self.pde_loss.name] = pde_loss + self.loss_log[self.pde_loss.name] / self.pde_loss.weight + else: + raise ValueError("Training Data for PDE data is a single tensor consists of residual points ") + + # ============== INITIAL CONDITION ============== " + if type(training_data[self.initial_condition.name]) is list: # initial condition loss - if len(training_data["Initial_Condition"]) == 2: - pinn_loss = pinn_loss + self.initial_condition(training_data["Initial_Condition"][0][0].type(self.dtype), - self.model, - training_data["Initial_Condition"][1][0].type(self.dtype)) + if len(training_data[self.initial_condition.name]) == 2: + ic_loss = self.initial_condition( + training_data[self.initial_condition.name][0][0].type(self.dtype), + self.model, + training_data[self.initial_condition.name][1][0].type(self.dtype) + ) + if self.rank == 0: + self.loss_log[self.initial_condition.name] = self.loss_log[self.initial_condition.name] +\ + ic_loss / self.initial_condition.weight + if annealing: + std_ic = self.loss_grad_std_wn(ic_loss) + lambda_hat = std_pde / std_ic + self.initial_condition.weight = (1 - 0.5) * self.initial_condition.weight + 0.5 * lambda_hat + pinn_loss = pinn_loss + ic_loss else: raise ValueError("Training Data for initial condition is a tuple (x,y) with x the input coordinates" " and ground truth values y") else: - raise ValueError("Training Data for initial condition is a tuple (x,y) with x the input coordinates" + raise ValueError("Training Data for initial condition is a tuple (x,y) with x the input coordinates" " and ground truth values y") - if type(training_data["PDE"]) is not list: - pinn_loss = pinn_loss + self.pde_loss(training_data["PDE"][0].type(self.dtype), self.model) - else: - raise ValueError("Training Data for PDE data is a single tensor consists of residual points ") + # ============== BOUNDARY CONDITION ============== " if not self.is_hpm: if isinstance(self.boundary_condition, list): for bc in self.boundary_condition: - pinn_loss = pinn_loss + self.calculate_boundary_condition(bc, training_data[bc.name]) + bc_loss = self.calculate_boundary_condition(bc, training_data[bc.name]) + if self.rank == 0: + self.loss_log[bc.name] = self.loss_log[bc.name] + bc_loss / bc.weight + if annealing: + std_bc = self.loss_grad_std_wn(bc_loss) + lambda_hat = std_pde / std_bc + bc.weight = (1 - 0.5) * bc.weight + 0.5 * lambda_hat + pinn_loss = pinn_loss + bc_loss else: - pinn_loss = pinn_loss + self.calculate_boundary_condition(self.boundary_condition, - training_data[self.boundary_condition.name]) + bc_loss = self.calculate_boundary_condition(self.boundary_condition, + training_data[self.boundary_condition.name]) + if self.rank == 0: + self.loss_log[self.boundary_condition.name] = self.loss_log[self.boundary_condition.name] +\ + bc_loss / self.boundary_condition.weight + if annealing: + std_bc = self.loss_grad_std_wn(bc_loss) + lambda_hat = std_pde / std_bc + self.boundary_condition.weight = (1 - 0.5) * self.boundary_condition.weight + 0.5 * lambda_hat + pinn_loss = pinn_loss + bc_loss + + # ============== Model specific losses ============== " + if hasattr(self.model, 'loss'): + pinn_loss = pinn_loss + self.model.loss + if self.rank == 0: + self.loss_log["model_loss_pinn"] = self.loss_log["model_loss_pinn"] + self.model.loss + if self.is_hpm: + if hasattr(self.pde_loss.hpm_model, 'loss'): + pinn_loss = pinn_loss + self.pde_loss.hpm_model.loss + if self.rank == 0: + self.loss_log["model_loss_hpm"] = self.loss_log["model_loss_hpm"] + self.pde_loss.hpm_model.loss + return pinn_loss - def fit(self, epochs, optimizer='Adam', learning_rate=1e-3, lbfgs_finetuning=True, - writing_cylcle= 30, save_model=True, pinn_path='best_model_pinn.pt', hpm_path='best_model_hpm.pt'): + def write_checkpoint(self, checkpoint_path, epoch, pretraining, minimum_pinn_loss, optimizer): + checkpoint = {} + checkpoint["epoch"] = epoch + checkpoint["pretraining"] = pretraining + checkpoint["minimum_pinn_loss"] = minimum_pinn_loss + checkpoint["optimizer"] = optimizer.state_dict() + checkpoint["weight_"+ self.initial_condition.name] = self.initial_condition.weight + checkpoint["weight_" + self.pde_loss.name] = self.initial_condition.weight + checkpoint["pinn_model"] = self.model.state_dict() + if isinstance(self.boundary_condition, list): + for bc in self.boundary_condition: + checkpoint["weight_"+ bc.name] = bc.weight + else: + checkpoint["weight_" + self.boundary_condition.name] = self.boundary_condition.weight + + if self.is_hpm: + checkpoint["hpm_model"] = self.pde_loss.hpm_model.state_dict() + checkpoint_path = checkpoint_path + '_' + str(epoch) + torch.save(checkpoint, checkpoint_path) + + + + def fit(self, + epochs, + checkpoint_path=None, + restart=False, + optimizer='Adam', + learning_rate=1e-3, + pretraining=False, + epochs_pt=100, + lbfgs_finetuning=True, + writing_cycle=30, + writing_cycle_pt=10, + save_model=True, + pinn_path='best_model_pinn.pt', + hpm_path='best_model_hpm.pt', + logger=None, + activate_annealing=False, + annealing_cycle=100, + callbacks=None): """ Function for optimizing the parameters of the PINN-Model @@ -244,17 +377,33 @@ def fit(self, epochs, optimizer='Adam', learning_rate=1e-3, lbfgs_finetuning=Tru optimizer (String, torch.optim.Optimizer) : Optimizer used for training. At the moment only ADAM and LBFGS are supported by string command. It is also possible to give instances of torch optimizers as a parameter learning_rate: The learning rate of the optimizer + pretraining: Activates seperate training on the initial condition at the beginning + epochs_pt: defines the number of epochs for the pretraining lbfgs_finetuning: Enables LBFGS finetuning after main training writing_cylcle: defines the cylcus of model writing save_model: enables or disables checkpointing pinn_path: defines the path where the pinn get stored hpm_path: defines the path where the hpm get stored + logger (Logger): tracks the convergence of all loss terms + activate_annealing (Boolean): enables annealing + annealing_cycle (int): defines the periodicity of using annealing + callbacks (CallbackList): is a list of callbacks which are called at the end of a writing cycle. Can be used + for different purposes e.g. early stopping, visualization, model state logging etc. + checkpoint_path (string) : path to the checkpoint + restart (integer) : defines if checkpoint will be used (False) or will be overwritten (True) + """ + # checking if callbacks are a instance of CallbackList + if callbacks is not None: + if not isinstance(callbacks, CallbackList): + raise ValueError("Callbacks has to be a instance of CallbackList but type {} was found". + format(type(callbacks))) + if isinstance(self.pde_loss, HPMLoss): params = list(self.model.parameters()) + list(self.pde_loss.hpm_model.parameters()) - named_parameters = chain(self.model.named_parameters(),self.pde_loss.hpm_model.named_parameters()) - if self.use_horovod and lbfgs_finetuning: + named_parameters = chain(self.model.named_parameters(), self.pde_loss.hpm_model.named_parameters()) + if self.use_horovod and lbfgs_finetuning: raise ValueError("LBFGS Finetuning is not possible with horovod") if optimizer == 'Adam': optim = torch.optim.Adam(params, lr=learning_rate) @@ -268,6 +417,7 @@ def fit(self, epochs, optimizer='Adam', learning_rate=1e-3, lbfgs_finetuning=Tru if lbfgs_finetuning and not self.use_horovod: lbfgs_optim = torch.optim.LBFGS(params, lr=0.9) + def closure(): lbfgs_optim.zero_grad() pinn_loss = self.pinn_loss(training_data) @@ -295,33 +445,127 @@ def closure(): if self.use_horovod: # Partition dataset among workers using DistributedSampler train_sampler = torch.utils.data.distributed.DistributedSampler( - self.dataset, num_replicas=hvd.size(), rank=hvd.rank()) - data_loader = DataLoader(self.dataset, batch_size=1,sampler=train_sampler) + self.dataset, num_replicas=hvd.size(), rank=hvd.rank() + ) + data_loader = DataLoader(self.dataset, batch_size=1, sampler=train_sampler, worker_init_fn=worker_init_fn) optim = hvd.DistributedOptimizer(optim, named_parameters=named_parameters) + if pretraining: + train_sampler_pt = torch.utils.data.distributed.DistributedSampler( + self.initial_condition.dataset, num_replicas=hvd.size(), rank=hvd.rank() + ) + data_loader_pt = DataLoader(self.initial_condition.dataset, + batch_size=None, + sampler=train_sampler_pt, + worker_init_fn=worker_init_fn) # Broadcast parameters from rank 0 to all other processes. hvd.broadcast_parameters(self.model.state_dict(), root_rank=0) if isinstance(self.pde_loss, HPMLoss): - hvd.broadcast_parameters(self.pinn_loss.hpm_model.state_dict(), root_rank=0) + hvd.broadcast_parameters(self.pde_loss.hpm_model.state_dict(), root_rank=0) hvd.broadcast_optimizer_state(optim, root_rank=0) else: - data_loader = DataLoader(self.dataset, batch_size=1) + data_loader = DataLoader(self.dataset, batch_size=1, worker_init_fn=worker_init_fn) + data_loader_pt = DataLoader(self.initial_condition.dataset, batch_size=None, worker_init_fn=worker_init_fn) + + start_epoch = 0 + + # load checkpoint routine if a checkpoint path is set and its allowed to not overwrite the checkpoint + if checkpoint_path is not None: + if not exists(checkpoint_path) and not restart: + raise FileNotFoundError( + "Checkpoint path {} do not exists. Please change the path to a existing checkpoint" + "or change the restart flag to true in order to create a new checkpoint" + .format(checkpoint_path)) + if checkpoint_path is not None and restart == 0: + checkpoint = torch.load(checkpoint_path) + start_epoch = checkpoint["epoch"] + pretraining = checkpoint["pretraining"] + self.initial_condition.weight = checkpoint["weight_" + self.initial_condition.name] + self.pde_loss.weight = checkpoint["weight_" + self.pde_loss.name] + if isinstance(self.boundary_condition, list): + for bc in self.boundary_condition: + bc.weight = checkpoint["weight_" + bc.name] + else: + self.boundary_condition.weight = checkpoint["weight_" + self.boundary_condition.name] + + self.model.load_state_dict(checkpoint["pinn_model"]) + if self.is_hpm: + self.pde_loss.hpm_model.load_state_dict(checkpoint["hpm_model"]) - for epoch in range(epochs): - for training_data in data_loader: - training_data = training_data + optim.load_state_dict(checkpoint['optimizer']) + minimum_pinn_loss = checkpoint["minimum_pinn_loss"] + print("Checkpoint Loaded", flush=True) + else: + print("Checkpoint not loaded", flush=True) + + print("===== Pretraining =====") + if pretraining: + for epoch in range(start_epoch, epochs_pt): + for x, y in data_loader_pt: + optim.zero_grad() + ic_loss = self.initial_condition(model=self.model, x=x.type(self.dtype), gt_y=y.type(self.dtype)) + ic_loss.backward() + optim.step() + if not self.rank and not (epoch + 1) % writing_cycle_pt and checkpoint_path is not None: + self.write_checkpoint(checkpoint_path, epoch, True, minimum_pinn_loss, optim) + if not self.rank: + print("IC Loss {} Epoch {} from {}".format(ic_loss, epoch+1, epochs_pt)) + print("===== Main training =====") + for epoch in range(start_epoch, epochs): + # for parallel training the rank should also define the seed + np.random.seed(42 + epoch + self.rank) + batch_counter = 0. + pinn_loss_sum = 0. + for idx, training_data in enumerate(data_loader): + do_annealing = activate_annealing and (idx == 0) and not (epoch + 1) % annealing_cycle optim.zero_grad() - pinn_loss = self.pinn_loss(training_data) + pinn_loss = self.pinn_loss(training_data, do_annealing) pinn_loss.backward() - if not self.rank: - print("PINN Loss {} Epoch {} from {}".format(pinn_loss, epoch, epochs)) optim.step() - if (pinn_loss < minimum_pinn_loss) and not (epoch % writing_cylcle) and save_model and not self.rank: - self.save_model(pinn_path, hpm_path) - minimum_pinn_loss = pinn_loss - + pinn_loss_sum = pinn_loss_sum + pinn_loss + batch_counter += 1 + del pinn_loss + + if not self.rank: + print("PINN Loss {} Epoch {} from {}".format(pinn_loss_sum / batch_counter, epoch+1, epochs), flush=True) + if logger is not None and not (epoch+1) % writing_cycle: + logger.log_scalar(scalar=pinn_loss_sum / batch_counter, name=" Weighted PINN Loss", epoch=epoch) + logger.log_scalar(scalar=sum(self.loss_log.values())/batch_counter, + name=" Non-Weighted PINN Loss", epoch=epoch+1) + # Log values of the loss terms + for key, value in self.loss_log.items(): + logger.log_scalar(scalar=value / batch_counter, name=key, epoch=epoch+1) + + # Log weights of loss terms separately + logger.log_scalar(scalar=self.initial_condition.weight, + name=self.initial_condition.name + "_weight", + epoch=epoch+1) + if not self.is_hpm: + if isinstance(self.boundary_condition, list): + for bc in self.boundary_condition: + logger.log_scalar(scalar=bc.weight, + name=bc.name + "_weight", + epoch=epoch+1) + else: + logger.log_scalar(scalar=self.boundary_condition.weight, + name=self.boundary_condition.name + "_weight", + epoch=epoch+1) + if callbacks is not None and not (epoch+1) % writing_cycle: + callbacks(epoch=epoch+1) + # saving routine + if (pinn_loss_sum / batch_counter < minimum_pinn_loss) and save_model: + self.save_model(pinn_path, hpm_path) + minimum_pinn_loss = pinn_loss_sum / batch_counter + + # reset loss log after the end of the epoch + for key in self.loss_log.keys(): + self.loss_log[key] = float(0) + + # writing checkpoint + if not (epoch + 1) % writing_cycle and checkpoint_path is not None: + self.write_checkpoint(checkpoint_path, epoch, False, minimum_pinn_loss, optim) if lbfgs_finetuning: lbfgs_optim.step(closure) - print("After LBFGS-B: PINN Loss {} Epoch {} from {}".format(pinn_loss, epoch, epochs)) - if (pinn_loss < minimum_pinn_loss) and not (epoch % writing_cylcle) and save_model: + print("After LBFGS-B: PINN Loss {} Epoch {} from {}".format(pinn_loss, epoch+1, epochs)) + if (pinn_loss < minimum_pinn_loss) and not (epoch % writing_cycle) and save_model: self.save_model(pinn_path, hpm_path) diff --git a/PINNFramework/WandB_Logger.py b/PINNFramework/WandB_Logger.py new file mode 100644 index 0000000..8414948 --- /dev/null +++ b/PINNFramework/WandB_Logger.py @@ -0,0 +1,70 @@ +from .Logger_Interface import LoggerInterface +import wandb + + +class WandbLogger(LoggerInterface): + + def __init__(self, project, args, entity=None, group=None): + """ + Initialize wandb instance and connect to the server + + Args: + project: name of the project + args: hyperparameters used for this runs + writing_cycle: defines the writing period + entity: account or group id used for that run + """ + wandb.init(project=project, entity=entity, group=group) + wandb.config.update(args) # adds all of the arguments as config variable + + def log_scalar(self, scalar, name, epoch): + """ + Logs a scalar to wandb + + Args: + scalar: the scalar to be logged + name: name of the sclar + epoch: epoch in the training loop + """ + wandb.log({name: scalar}, step=epoch) + + def log_image(self, image, name, epoch): + """ + Logs a image to wandb + + Args: + image (Image) : the image to be logged + name (String) : name of the image + epoch (Integer) : epoch in the training loop + + """ + wandb.log({name: [wandb.Image(image, caption=name)]}, step=epoch) + + def log_plot(self, plot, name, epoch): + """ + Logs a plot to wandb + + Args: + plot (plot) : the plot to be logged + name (String) : name of the plot + epoch (Integer) : epoch in the training loop + + """ + wandb.log({name: plot}, step=epoch) + + def log_histogram(self, histogram,name, epoch): + """ + Logs a histogram to wandb + + Args: + histogram (histogram) : the histogram to be logged + name (String) : name of the histogram + epoch (Integer) : epoch in the training loop + + """ + wandb.log({name: wandb.Histogram(histogram)}, step=epoch) + + + + + diff --git a/PINNFramework/__init__.py b/PINNFramework/__init__.py index eb42dbb..b54ca88 100644 --- a/PINNFramework/__init__.py +++ b/PINNFramework/__init__.py @@ -3,11 +3,16 @@ from .BoundaryCondition import PeriodicBC from .BoundaryCondition import DirichletBC from .BoundaryCondition import RobinBC +from .BoundaryCondition import TimeDerivativeBC from .BoundaryCondition import NeumannBC from .PDELoss import PDELoss +from .Logger_Interface import LoggerInterface +from .WandB_Logger import WandbLogger from .PINN import PINN import PINNFramework.models +import PINNFramework.callbacks + __all__ = [ 'InitialCondition', @@ -15,8 +20,11 @@ 'DirichletBC', 'RobinBC', 'NeumannBC', + 'TimeDerivativeBC', 'PDELoss', 'HPMLoss', 'PINN', - 'models' -] + 'models', + 'LoggerInterface', + 'WandbLogger', + 'callbacks'] diff --git a/PINNFramework/callbacks/Callback.py b/PINNFramework/callbacks/Callback.py new file mode 100644 index 0000000..daa4c5a --- /dev/null +++ b/PINNFramework/callbacks/Callback.py @@ -0,0 +1,48 @@ +from torch.nn import Module +from PINNFramework.Logger_Interface import LoggerInterface + + +class Callback: + def __init__(self): + self.model = None + self.logger = None + + def set_model(self, model): + if isinstance(model, Module): + self.model = model + else: + raise ValueError("Model is not of type but model of type {} was found" + .format(type(model))) + + def set_logger(self, logger): + if isinstance(LoggerInterface): + self.logger + else: + raise ValueError("Logger is not of type but logger of type {} was found" + .format(type(logger))) + + def __call__(self, epoch): + raise NotImplementedError("method __call__() of the callback is not implemented") + + +class CallbackList: + def __init__(self, callbacks): + if isinstance(callbacks, list): + for cb in callbacks: + if not isinstance(cb, Callback): + raise ValueError("Callback has to be of type but type {} was found" + .format(type(cb))) + self.callbacks = callbacks + else: + raise ValueError("Callback has to be of type but type {} was found" + .format(type(callbacks))) + + def __call__(self, epoch): + for cb in self.callbacks: + cb(epoch) + + + + + + diff --git a/PINNFramework/callbacks/__init__.py b/PINNFramework/callbacks/__init__.py new file mode 100644 index 0000000..ceb4290 --- /dev/null +++ b/PINNFramework/callbacks/__init__.py @@ -0,0 +1,6 @@ +from .Callback import Callback +from .Callback import CallbackList + +___all__ = ["Callback", + "CallbackList" + ] \ No newline at end of file diff --git a/PINNFramework/models/Finger_Net.py b/PINNFramework/models/Finger_Net.py new file mode 100644 index 0000000..797e98f --- /dev/null +++ b/PINNFramework/models/Finger_Net.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn + +class FingerNet(nn.Module): + def __init__(self, lb, ub, numFeatures = 500, numLayers = 8, activation = torch.relu, normalize=True, scaling=1.): + torch.manual_seed(1234) + super(FingerNet, self).__init__() + + self.numFeatures = numFeatures + self.numLayers = numLayers + self.lin_layers = nn.ModuleList() + self.lb = torch.tensor(lb).float() + self.ub = torch.tensor(ub).float() + self.activation = activation + self.normalize = normalize + self.scaling = scaling + self.init_layers() + + + def init_layers(self): + """ + This function creates the torch layers and initialize them with xavier + :param self: + :return: + """ + self.in_x = nn.ModuleList() + self.in_y = nn.ModuleList() + self.in_z = nn.ModuleList() + self.in_t = nn.ModuleList() + lenInput = 1 + noInLayers = 3 + + self.in_x.append(nn.Linear(lenInput,self.numFeatures)) + for _ in range(noInLayers): + self.in_x.append(nn.Linear(self.numFeatures, self.numFeatures)) + + self.in_y.append(nn.Linear(lenInput,self.numFeatures)) + for _ in range(noInLayers): + self.in_y.append(nn.Linear(self.numFeatures, self.numFeatures)) + + self.in_z.append(nn.Linear(lenInput,self.numFeatures)) + for _ in range(noInLayers): + self.in_z.append(nn.Linear(self.numFeatures, self.numFeatures)) + + self.in_t.append(nn.Linear(1,self.numFeatures)) + for _ in range(noInLayers): + self.in_t.append(nn.Linear(self.numFeatures, self.numFeatures)) + + for m in [self.in_x,self.in_y,self.in_z,self.in_t]: + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + #nn.init.constant_(m.bias, 0) + + self.lin_layers.append(nn.Linear(4 * self.numFeatures, self.numFeatures)) + for i in range(self.numLayers): + inFeatures = self.numFeatures + self.lin_layers.append(nn.Linear(inFeatures,self.numFeatures)) + inFeatures = self.numFeatures + self.lin_layers.append(nn.Linear(inFeatures,1)) + for m in self.lin_layers: + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.constant_(m.bias, 0) + + def forward(self, x_in): + if self.normalize: + x = 2.0 * (x_in - self.lb) / (self.ub - self.lb) - 1.0 + + x_inx = x_in[:,0].view(-1,1) + x_iny = x_in[:,1].view(-1,1) + x_inz = x_in[:,2].view(-1,1) + x_int = x_in[:,3].view(-1,1) + + for i in range(0,len(self.in_x)): + x_inx = self.in_x[i](x_inx) + x_inx = self.activation(x_inx) + + + for i in range(0,len(self.in_y)): + x_iny = self.in_y[i](x_iny) + x_iny = self.activation(x_iny) + + for i in range(0,len(self.in_z)): + x_inz = self.in_z[i](x_inz) + x_inz = self.activation(x_inz) + + for i in range(0,len(self.in_t)): + x_int = self.in_t[i](x_int) + x_int = self.activation(x_int) + + + x = torch.cat([x_inx,x_iny,x_inz,x_int],1) + + + for i in range(0,len(self.lin_layers)-1): + x = self.lin_layers[i](x) + x = self.activation(x) + x = self.lin_layers[-1](x) + + return self.scaling * x + + def cuda(self): + super(FingerNet, self).cuda() + self.lb = self.lb.cuda() + self.ub = self.ub.cuda() + + def cpu(self): + super(FingerNet, self).cpu() + self.lb = self.lb.cpu() + self.ub = self.ub.cpu() + + def to(self, device): + super(FingerNet, self).to(device) + self.lb.to(device) + self.ub.to(device) \ No newline at end of file diff --git a/PINNFramework/models/__init__.py b/PINNFramework/models/__init__.py index 00d6803..5550f1a 100644 --- a/PINNFramework/models/__init__.py +++ b/PINNFramework/models/__init__.py @@ -1,5 +1,17 @@ from .mlp import MLP - +from .distributed_moe import MoE as distMoe +from .moe_mlp import MoE as MoE +from .snake_mlp import SnakeMLP +from .Finger_Net import FingerNet +from .moe_finger import MoE as FingerMoE +from . import activations __all__ = [ - 'MLP' + 'MLP', + 'MoE', + 'distMoe', + 'SnakeMLP', + 'FingerNet', + 'FingerMoE', + 'activations' + ] diff --git a/PINNFramework/models/activations/__init__.py b/PINNFramework/models/activations/__init__.py new file mode 100644 index 0000000..9977c26 --- /dev/null +++ b/PINNFramework/models/activations/__init__.py @@ -0,0 +1,5 @@ +from .snake import Snake + +__all__ = [ + 'Snake' +] diff --git a/PINNFramework/models/activations/snake.py b/PINNFramework/models/activations/snake.py new file mode 100644 index 0000000..ca819a1 --- /dev/null +++ b/PINNFramework/models/activations/snake.py @@ -0,0 +1,20 @@ +import torch +import torch.nn as nn + + +class Snake(nn.Module,): + """ Implementation of the snake activation function as a torch nn module + The result of the activation function a(x) is calculated by a(x) = x + sin^2(x) + With alpha is a trainab + """ + + def __init__(self,frequency=10): + """Constructor function that initialize the torch module + """ + super(Snake, self).__init__() + + # making beta trainable by activating gradient calculation + self.a = nn.Parameter(torch.tensor([float(frequency)], requires_grad=True)) + + def forward(self, x): + return x + ((torch.sin(self.a* x)) ** 2) / self.a \ No newline at end of file diff --git a/PINNFramework/models/moe.py b/PINNFramework/models/distributed_moe.py similarity index 86% rename from PINNFramework/models/moe.py rename to PINNFramework/models/distributed_moe.py index 7a3f9c8..c7420d5 100644 --- a/PINNFramework/models/moe.py +++ b/PINNFramework/models/distributed_moe.py @@ -132,7 +132,9 @@ class MoE(nn.Module): k: an integer - how many experts to use for each batch element """ - def __init__(self, input_size, output_size, num_experts, hidden_size, num_hidden, activation=torch.tanh, non_linear=False,noisy_gating=False, k=4, device = "cpu"): + def __init__(self, input_size, output_size, num_experts, + hidden_size, num_hidden, lb, ub, activation=torch.tanh, + non_linear=False, noisy_gating=False, k=1, device = "cpu"): super(MoE, self).__init__() self.noisy_gating = noisy_gating self.num_experts = num_experts @@ -142,12 +144,16 @@ def __init__(self, input_size, output_size, num_experts, hidden_size, num_hidden self.device = device self.k = k self.loss = 0 - - # instantiate experts - self.experts = nn.ModuleList([MLP(input_size, output_size, hidden_size, num_hidden, activation) for i in range(self.num_experts)]) - - self.w_gate = nn.Parameter(torch.randn(input_size, num_experts), requires_grad=True) - self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) + self.num_devices = torch.cuda.device_count() - 1 # cuda:0 is for handling data + print("The model runs on {} devices".format(self.num_devices)) + # instantiate experts on the needed GPUs + self.experts = nn.ModuleList([ + MLP(input_size, output_size, hidden_size, num_hidden, lb, ub, activation, + device='cuda:{}'.format((i % self.num_devices)+1)) + .to('cuda:{}'.format((i % self.num_devices)+1))for i in range(self.num_experts) + ]) + self.w_gate = nn.Parameter(torch.randn(input_size, num_experts, device=self.device), requires_grad=True) + self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts, device=self.device), requires_grad=True) self.softplus = nn.Softplus() self.softmax = nn.Softmax(1) @@ -155,7 +161,8 @@ def __init__(self, input_size, output_size, num_experts, hidden_size, num_hidden self.non_linear = non_linear if self.non_linear: - self.gating_network = MLP(input_size,num_experts,num_experts*2,1,activation=F.relu) + self.gating_network = MLP(input_size, num_experts, num_experts*2, 1, lb, ub,activation=F.relu, + device=self.device).to(self.device) assert(self.k <= self.num_experts) @@ -186,9 +193,6 @@ def _gates_to_load(self, gates): """ return (gates > 0).sum(0) - - - def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values): """Helper function to NoisyTopKGating. Computes the probability that value is in top k, given different random noise. @@ -210,11 +214,11 @@ def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_val batch = clean_values.size(0) m = noisy_top_values.size(1) top_values_flat = noisy_top_values.flatten() - threshold_positions_if_in = (torch.arange(batch) * m + self.k).to(self.device)#.cuda() + threshold_positions_if_in = (torch.arange(batch) * m + self.k).to(self.device) threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) is_in = torch.gt(noisy_values, threshold_if_in) - threshold_positions_if_out = (threshold_positions_if_in - 1).to(self.device)#.cuda() - threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat,0 , threshold_positions_if_out), 1) + threshold_positions_if_out = (threshold_positions_if_in - 1).to(self.device) + threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) # is each value currently in the top k. prob_if_in = self.normal.cdf((clean_values - threshold_if_in)/noise_stddev) prob_if_out = self.normal.cdf((clean_values - threshold_if_out)/noise_stddev) @@ -222,7 +226,7 @@ def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_val return prob - def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2): + def noisy_top_k_gating(self, x, train, noise_epsilon=1e-1): """Noisy top-k gating. See paper: https://arxiv.org/abs/1701.06538. Args: @@ -243,7 +247,7 @@ def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2): if(self.k > 1): raw_noise_stddev = self.softplus(raw_noise_stddev) noise_stddev = ((raw_noise_stddev + noise_epsilon) * train) - noisy_logits = clean_logits + ( torch.randn_like(clean_logits) * noise_stddev) + noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) logits = noisy_logits else: logits = clean_logits @@ -290,9 +294,18 @@ def forward(self, x, train=True, loss_coef=1e-2): self.loss = loss - dispatcher = SparseDispatcher(self.num_experts, gates) + dispatcher = SparseDispatcher(self.num_experts, gates, device=self.device) expert_inputs = dispatcher.dispatch(x) gates = dispatcher.expert_to_gates() - expert_outputs = [self.experts[i](expert_inputs[i].unsqueeze(1)) for i in range(self.num_experts)] + # Here is a loop needed for asynchonous calls of the GPUs + expert_outputs = [] + for i in range(self.num_experts): + # move data to device + exp_input = expert_inputs[i].to(self.experts[i].device) + expert_output = self.experts[i](exp_input) + # move expert output back to device + expert_output = expert_output.to(self.device) + # append it for stiching + expert_outputs.append(expert_output) y = dispatcher.combine(expert_outputs) - return y #, loss \ No newline at end of file + return y diff --git a/PINNFramework/models/mlp.py b/PINNFramework/models/mlp.py index 5b074f5..9981b75 100644 --- a/PINNFramework/models/mlp.py +++ b/PINNFramework/models/mlp.py @@ -3,13 +3,14 @@ class MLP(nn.Module): - def __init__(self, input_size, output_size, hidden_size, num_hidden, lb, ub, activation=torch.tanh): + def __init__(self, input_size, output_size, hidden_size, num_hidden, lb, ub, activation=torch.tanh, normalize=True): super(MLP, self).__init__() self.linear_layers = nn.ModuleList() self.activation = activation self.init_layers(input_size, output_size, hidden_size,num_hidden) self.lb = torch.Tensor(lb).float() self.ub = torch.Tensor(ub).float() + self.normalize = normalize def init_layers(self, input_size, output_size, hidden_size, num_hidden): self.linear_layers.append(nn.Linear(input_size, hidden_size)) @@ -23,7 +24,8 @@ def init_layers(self, input_size, output_size, hidden_size, num_hidden): nn.init.constant_(m.bias, 0) def forward(self, x): - x = 2.0*(x - self.lb)/(self.ub - self.lb) - 1.0 + if self.normalize: + x = 2.0*(x - self.lb)/(self.ub - self.lb) - 1.0 for i in range(len(self.linear_layers) - 1): x = self.linear_layers[i](x) x = self.activation(x) @@ -36,7 +38,11 @@ def cuda(self): self.ub = self.ub.cuda() def cpu(self): - super(MLP, self).cuda() + super(MLP, self).cpu() self.lb = self.lb.cpu() self.ub = self.ub.cpu() - + + def to(self, device): + super(MLP,self).to(device) + self.lb.to(device) + self.ub.to(device) diff --git a/PINNFramework/models/moe_finger.py b/PINNFramework/models/moe_finger.py new file mode 100644 index 0000000..2ed6659 --- /dev/null +++ b/PINNFramework/models/moe_finger.py @@ -0,0 +1,350 @@ +# Sparsely-Gated Mixture-of-Experts Layers. +# See "Outrageously Large Neural Networks" +# https://arxiv.org/abs/1701.06538 +# +# Author: David Rau +# +# The code is based on the TensorFlow implementation: +# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py + +import time +import torch +import torch.nn as nn +from torch.distributions.normal import Normal +import numpy as np +import torch.nn.functional as F +from PINNFramework.models import FingerNet +from PINNFramework.models import MLP + +class SparseDispatcher(object): + """Helper for implementing a mixture of experts. + The purpose of this class is to create input minibatches for the + experts and to combine the results of the experts to form a unified + output tensor. + There are two functions: + dispatch - take an input Tensor and create input Tensors for each expert. + combine - take output Tensors from each expert and form a combined output + Tensor. Outputs from different experts for the same batch element are + summed together, weighted by the provided "gates". + The class is initialized with a "gates" Tensor, which specifies which + batch elements go to which experts, and the weights to use when combining + the outputs. Batch element b is sent to expert e iff gates[b, e] != 0. + The inputs and outputs are all two-dimensional [batch, depth]. + Caller is responsible for collapsing additional dimensions prior to + calling this class and reshaping the output to the original shape. + See common_layers.reshape_like(). + Example use: + gates: a float32 `Tensor` with shape `[batch_size, num_experts]` + inputs: a float32 `Tensor` with shape `[batch_size, input_size]` + experts: a list of length `num_experts` containing sub-networks. + dispatcher = SparseDispatcher(num_experts, gates) + expert_inputs = dispatcher.dispatch(inputs) + expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)] + outputs = dispatcher.combine(expert_outputs) + The preceding code sets the output for a particular example b to: + output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])) + This class takes advantage of sparsity in the gate matrix by including in the + `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. + """ + + def __init__(self, num_experts, gates, use_gpu=False): + """Create a SparseDispatcher.""" + self.use_gpu=use_gpu + self._gates = gates + self._num_experts = num_experts + # sort experts + sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0) + # drop indices + _, self._expert_index = sorted_experts.split(1, dim=1) + # get according batch index for each expert + self._batch_index = sorted_experts[index_sorted_experts[:, 1],0] + # calculate num samples that each expert gets + if self.use_gpu: + self._part_sizes = list((gates > 0).sum(0).cuda()) + else: + self._part_sizes = list((gates > 0).sum(0)) + + # expand gates to match with self._batch_index + gates_exp = gates[self._batch_index.flatten()] + self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index) + + def dispatch(self, inp): + """Create one input Tensor for each expert. + The `Tensor` for a expert `i` contains the slices of `inp` corresponding + to the batch elements `b` where `gates[b, i] > 0`. + Args: + inp: a `Tensor` of shape "[batch_size, ]` + Returns: + a list of `num_experts` `Tensor`s with shapes + `[expert_batch_size_i, ]`. + """ + + # assigns samples to experts whose gate is nonzero + + # expand according to batch index so we can just split by _part_sizes + inp_exp = inp[self._batch_index].squeeze(1) + return torch.split(inp_exp, self._part_sizes, dim=0) + + + def combine(self, expert_out, multiply_by_gates=True): + """Sum together the expert output, weighted by the gates. + The slice corresponding to a particular batch element `b` is computed + as the sum over all experts `i` of the expert output, weighted by the + corresponding gate values. If `multiply_by_gates` is set to False, the + gate values are ignored. + Args: + expert_out: a list of `num_experts` `Tensor`s, each with shape + `[expert_batch_size_i, ]`. + multiply_by_gates: a boolean + Returns: + a `Tensor` with shape `[batch_size, ]`. + """ + # apply exp to expert outputs, so we are not longer in log space + stitched = torch.cat(expert_out, 0).exp() + + if multiply_by_gates: + stitched = stitched.mul(self._nonzero_gates) + zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True) + if self.use_gpu: + zeros = zeros.cuda() + # combine samples that have been processed by the same k experts + combined = zeros.index_add(0, self._batch_index, stitched.float()) + # add eps to all zero values in order to avoid nans when going back to log space + combined[combined == 0] = np.finfo(float).eps + # back to log space + return combined.log() + + + def expert_to_gates(self): + """Gate values corresponding to the examples in the per-expert `Tensor`s. + Returns: + a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32` + and shapes `[expert_batch_size_i]` + """ + # split nonzero gates for each expert + return torch.split(self._nonzero_gates, self._part_sizes, dim=0) + +class MoE(nn.Module): + + """Call a Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts. + Args: + input_size: integer - size of the input + output_size: integer - size of the input + num_experts: an integer - number of experts + hidden_size: an integer - hidden size of the experts + noisy_gating: a boolean + k: an integer - how many experts to use for each batch element + """ + + def __init__(self, input_size, output_size, num_experts, + hidden_size, num_hidden, lb, ub, activation=torch.tanh, + non_linear=False, noisy_gating=False, k=1, scaling_factor=1.): + super(MoE, self).__init__() + self.noisy_gating = noisy_gating + self.num_experts = num_experts + self.output_size = output_size + self.input_size = input_size + self.hidden_size = hidden_size + self.use_gpu = False + self.k = k + self.loss = 0 + self.lb = torch.Tensor(lb).float() + self.ub = torch.Tensor(ub).float() + self.scaling_factor = scaling_factor + + # instantiate experts + # normalization of the MLPs is disabled cause the Gating Network performs the normalization + self.experts = nn.ModuleList([ + FingerNet(lb, ub, hidden_size, num_hidden, activation, False) + for _ in range(self.num_experts) + ]) + + self.w_gate = nn.Parameter(torch.randn(input_size, num_experts), requires_grad=True) + self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) + + self.softplus = nn.Softplus() + self.softmax = nn.Softmax(1) + if self.use_gpu: + self.normal = Normal(torch.tensor([0.0]).cuda(), torch.tensor([1.0]).cuda()) + else: + self.normal = Normal(torch.tensor([0.0]).cuda(), torch.tensor([1.0]).cuda()) + + self.non_linear = non_linear + if self.non_linear: + self.gating_network = MLP(input_size, + num_experts, + num_experts*2, + 1, + lb, + ub, + activation=F.relu, + normalize=False) + + assert(self.k <= self.num_experts) + + def cv_squared(self, x): + """The squared coefficient of variation of a sample. + Useful as a loss to encourage a positive distribution to be more uniform. + Epsilons added for numerical stability. + Returns 0 for an empty Tensor. + Args: + x: a `Tensor`. + Returns: + a `Scalar`. + """ + eps = 1e-10 + # if only num_experts = 1 + if x.shape[0] == 1: + return torch.Tensor([0]) + return x.float().var() / (x.float().mean()**2 + eps) + + + def _gates_to_load(self, gates): + """Compute the true load per expert, given the gates. + The load is the number of examples for which the corresponding gate is >0. + Args: + gates: a `Tensor` of shape [batch_size, n] + Returns: + a float32 `Tensor` of shape [n] + """ + return (gates > 0).sum(0) + + def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values): + """Helper function to NoisyTopKGating. + Computes the probability that value is in top k, given different random noise. + This gives us a way of backpropagating from a loss that balances the number + of times each expert is in the top k experts per example. + In the case of no noise, pass in None for noise_stddev, and the result will + not be differentiable. + Args: + clean_values: a `Tensor` of shape [batch, n]. + noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus + normally distributed noise with standard deviation noise_stddev. + noise_stddev: a `Tensor` of shape [batch, n], or None + noisy_top_values: a `Tensor` of shape [batch, m]. + "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 + Returns: + a `Tensor` of shape [batch, n]. + """ + + batch = clean_values.size(0) + m = noisy_top_values.size(1) + top_values_flat = noisy_top_values.flatten() + if self.use_gpu: + threshold_positions_if_in = (torch.arange(batch) * m + self.k).cuda() + else: + threshold_positions_if_in = (torch.arange(batch) * m + self.k).cuda() + threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) + is_in = torch.gt(noisy_values, threshold_if_in) + if self.use_gpu: + threshold_positions_if_out = (threshold_positions_if_in - 1).cuda() + else: + threshold_positions_if_out = (threshold_positions_if_in - 1).cuda() + + threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) + # is each value currently in the top k. + prob_if_in = self.normal.cdf((clean_values - threshold_if_in)/noise_stddev) + prob_if_out = self.normal.cdf((clean_values - threshold_if_out)/noise_stddev) + prob = torch.where(is_in, prob_if_in, prob_if_out) + return prob + + + def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2): + """Noisy top-k gating. + See paper: https://arxiv.org/abs/1701.06538. + Args: + x: input Tensor with shape [batch_size, input_size] + train: a boolean - we only add noise at training time. + noise_epsilon: a float + Returns: + gates: a Tensor with shape [batch_size, num_experts] + load: a Tensor with shape [num_experts] + """ + + if self.non_linear: + clean_logits = self.gating_network(x) + else: + clean_logits = x @ self.w_gate + if self.noisy_gating: + raw_noise_stddev = x @ self.w_noise + if(self.k > 1): + raw_noise_stddev = self.softplus(raw_noise_stddev) + noise_stddev = ((raw_noise_stddev + noise_epsilon) * train) + noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) + logits = noisy_logits + else: + logits = clean_logits + + # calculate topk + 1 that will be needed for the noisy gates + top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1) + top_k_logits = top_logits[:, :self.k] + top_k_indices = top_indices[:, :self.k] + if(self.k > 1): + top_k_gates = self.softmax(top_k_logits) + else: + top_k_gates = torch.sigmoid(top_k_logits) + zeros = torch.zeros_like(logits, requires_grad=True) + gates = zeros.scatter(1, top_k_indices, top_k_gates) + + if self.noisy_gating and self.k < self.num_experts: + load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0) + else: + load = self._gates_to_load(gates) + return gates, load + + + def get_utilisation_loss(self): + return self.loss + + def forward(self, x, train=True, loss_coef=1e-2): + """Args: + x: tensor shape [batch_size, input_size] + train: a boolean scalar. + loss_coef: a scalar - multiplier on load-balancing losses + Returns: + y: a tensor with shape [batch_size, output_size]. + extra_training_loss: a scalar. This should be added into the overall + training loss of the model. The backpropagation of this loss + encourages all experts to be approximately equally used across a batch. + """ + # normalization is performed here for better convergence of the gating network + x = 2.0*(x - self.lb)/(self.ub - self.lb) - 1.0 + + gates, load = self.noisy_top_k_gating(x, train) + # calculate importance loss + importance = gates.sum(0) + # + loss = self.cv_squared(importance) + self.cv_squared(load) + loss *= loss_coef + self.loss = loss + + dispatcher = SparseDispatcher(self.num_experts, gates, self.use_gpu) + expert_inputs = dispatcher.dispatch(x) + gates = dispatcher.expert_to_gates() + expert_outputs = [] + for i in range(self.num_experts): + if expert_inputs[i] is not None: + expert_outputs.append(self.experts[i](expert_inputs[i])) + y = dispatcher.combine(expert_outputs) + return self.scaling_factor * y + + def cuda(self): + super(MoE, self).cuda() + self.use_gpu = True + #iterate over all experts and move them to gpu + for i in range(self.num_experts): + self.experts[i].cuda() + self.lb = self.lb.cuda() + self.ub = self.ub.cuda() + if self.non_linear: + self.gating_network.cuda() + + def cpu(self): + super(MoE, self).cpu() + self.use_gpu = False + for i in range(self.num_experts): + self.experts[i].cpu() + self.lb = self.lb.cpu() + self.ub = self.ub.cpu() + if self.non_linear: + self.gating_network.cpu() diff --git a/PINNFramework/models/moe_finger_test.py b/PINNFramework/models/moe_finger_test.py new file mode 100644 index 0000000..a8eb2e6 --- /dev/null +++ b/PINNFramework/models/moe_finger_test.py @@ -0,0 +1,32 @@ +from PINNFramework.models import FingerMoE +import numpy as np +import torch + +if __name__ == "__main__": + lb = np.array([0, 0, 0, 0]) + ub = np.array([1, 1, 1, 1]) + + # finger MoE gpu + moe = FingerMoE(4, 3, 5, 100, 2, lb, ub) + moe.cuda() + x_gpu = torch.randn(3, 4).cuda() + y_gpu = moe(x_gpu) + print(y_gpu) + + # finger MoE cpu + moe.cpu() + x_cpu = torch.randn(3, 4) + y_cpu = moe(x_cpu) + print(y_cpu) + + # non linear gating test + moe = FingerMoE(4, 3, 5, 100, 2, lb, ub, non_linear=True) + moe.cuda() + x_gpu = torch.randn(3, 4).cuda() + y_gpu = moe(x_gpu) + print(y_gpu) + moe.cpu() + x_cpu = torch.randn(3, 4) + y_cpu = moe(x_cpu) + print(y_cpu) + diff --git a/PINNFramework/models/moe_mlp.py b/PINNFramework/models/moe_mlp.py new file mode 100644 index 0000000..93c4491 --- /dev/null +++ b/PINNFramework/models/moe_mlp.py @@ -0,0 +1,361 @@ +# Sparsely-Gated Mixture-of-Experts Layers. +# See "Outrageously Large Neural Networks" +# https://arxiv.org/abs/1701.06538 +# +# Author: David Rau +# +# The code is based on the TensorFlow implementation: +# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py + +import time +import torch +import torch.nn as nn +from torch.distributions.normal import Normal +import numpy as np +import torch.nn.functional as F +from PINNFramework.models import MLP + +class SparseDispatcher(object): + """Helper for implementing a mixture of experts. + The purpose of this class is to create input minibatches for the + experts and to combine the results of the experts to form a unified + output tensor. + There are two functions: + dispatch - take an input Tensor and create input Tensors for each expert. + combine - take output Tensors from each expert and form a combined output + Tensor. Outputs from different experts for the same batch element are + summed together, weighted by the provided "gates". + The class is initialized with a "gates" Tensor, which specifies which + batch elements go to which experts, and the weights to use when combining + the outputs. Batch element b is sent to expert e iff gates[b, e] != 0. + The inputs and outputs are all two-dimensional [batch, depth]. + Caller is responsible for collapsing additional dimensions prior to + calling this class and reshaping the output to the original shape. + See common_layers.reshape_like(). + Example use: + gates: a float32 `Tensor` with shape `[batch_size, num_experts]` + inputs: a float32 `Tensor` with shape `[batch_size, input_size]` + experts: a list of length `num_experts` containing sub-networks. + dispatcher = SparseDispatcher(num_experts, gates) + expert_inputs = dispatcher.dispatch(inputs) + expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)] + outputs = dispatcher.combine(expert_outputs) + The preceding code sets the output for a particular example b to: + output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])) + This class takes advantage of sparsity in the gate matrix by including in the + `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. + """ + + def __init__(self, num_experts, gates, use_gpu=True): + """Create a SparseDispatcher.""" + self.use_gpu = use_gpu + self._gates = gates + self._num_experts = num_experts + # sort experts + sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0) + # drop indices + _, self._expert_index = sorted_experts.split(1, dim=1) + # get according batch index for each expert + self._batch_index = sorted_experts[index_sorted_experts[:, 1],0] + # calculate num samples that each expert gets + if self.use_gpu: + self._part_sizes = list((gates > 0).sum(0).cuda()) + else: + self._part_sizes = list((gates > 0).sum(0)) + + # expand gates to match with self._batch_index + gates_exp = gates[self._batch_index.flatten()] + self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index) + + def dispatch(self, inp): + """Create one input Tensor for each expert. + The `Tensor` for a expert `i` contains the slices of `inp` corresponding + to the batch elements `b` where `gates[b, i] > 0`. + Args: + inp: a `Tensor` of shape "[batch_size, ]` + Returns: + a list of `num_experts` `Tensor`s with shapes + `[expert_batch_size_i, ]`. + """ + + # assigns samples to experts whose gate is nonzero + + # expand according to batch index so we can just split by _part_sizes + inp_exp = inp[self._batch_index].squeeze(1) + return torch.split(inp_exp, self._part_sizes, dim=0) + + + def combine(self, expert_out, multiply_by_gates=True): + """Sum together the expert output, weighted by the gates. + The slice corresponding to a particular batch element `b` is computed + as the sum over all experts `i` of the expert output, weighted by the + corresponding gate values. If `multiply_by_gates` is set to False, the + gate values are ignored. + Args: + expert_out: a list of `num_experts` `Tensor`s, each with shape + `[expert_batch_size_i, ]`. + multiply_by_gates: a boolean + Returns: + a `Tensor` with shape `[batch_size, ]`. + """ + # apply exp to expert outputs, so we are not longer in log space + stitched = torch.cat(expert_out, 0).exp() + + if multiply_by_gates: + stitched = stitched.mul(self._nonzero_gates) + if self.use_gpu: + zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True).cuda() + else: + zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True) + + # combine samples that have been processed by the same k experts + combined = zeros.index_add(0, self._batch_index, stitched.float()) + # add eps to all zero values in order to avoid nans when going back to log space + combined[combined == 0] = np.finfo(float).eps + # back to log space + return combined.log() + + + def expert_to_gates(self): + """Gate values corresponding to the examples in the per-expert `Tensor`s. + Returns: + a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32` + and shapes `[expert_batch_size_i]` + """ + # split nonzero gates for each expert + return torch.split(self._nonzero_gates, self._part_sizes, dim=0) + +class MoE(nn.Module): + + """Call a Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts. + Args: + input_size: integer - size of the input + output_size: integer - size of the input + num_experts: an integer - number of experts + hidden_size: an integer - hidden size of the experts + noisy_gating: a boolean + k: an integer - how many experts to use for each batch element + """ + + def __init__(self, input_size, output_size, num_experts, + hidden_size, num_hidden, lb, ub, activation=torch.tanh, + non_linear=False, noisy_gating=False, k=1,): + super(MoE, self).__init__() + self.noisy_gating = noisy_gating + self.num_experts = num_experts + self.output_size = output_size + self.input_size = input_size + self.hidden_size = hidden_size + self.use_gpu = False + self.k = k + self.loss = 0 + self.lb = torch.Tensor(lb).float() + self.ub = torch.Tensor(ub).float() + + # instantiate experts + # normalization of the MLPs is disabled cause the Gating Network performs the normalization + self.experts = nn.ModuleList([ + MLP(input_size, output_size, hidden_size, num_hidden, lb, ub, activation) + for _ in range(self.num_experts) + ]) + + self.w_gate = torch.randn(input_size, num_experts, requires_grad=True) + self.w_noise = torch.zeros(input_size, num_experts, requires_grad=True) + + if self.use_gpu: + self.w_gate = self.w_gate.cuda() + self.w_noise = self.w_noise.cuda() + + self.w_gate = torch.nn.Parameter(self.w_gate) + self.w_noise = torch.nn.Parameter(self.w_noise) + + self.softplus = nn.Softplus() + self.softmax = nn.Softmax(1) + + if self.use_gpu: + self.normal = Normal(torch.tensor([0.0]).cuda(), torch.tensor([1.0]).cuda()) + else: + self.normal = Normal(torch.tensor([0.0]), torch.tensor([1.0])) + + self.non_linear = non_linear + if self.non_linear: + self.gating_network = MLP(input_size, + num_experts, + num_experts*2, + 1, + lb, + ub, + activation=F.relu, + normalize=False) + if self.use_gpu: + self.gating_network.cuda() + + assert(self.k <= self.num_experts) + + def cv_squared(self, x): + """The squared coefficient of variation of a sample. + Useful as a loss to encourage a positive distribution to be more uniform. + Epsilons added for numerical stability. + Returns 0 for an empty Tensor. + Args: + x: a `Tensor`. + Returns: + a `Scalar`. + """ + eps = 1e-10 + # if only num_experts = 1 + if x.shape[0] == 1: + return torch.Tensor([0]) + return x.float().var() / (x.float().mean()**2 + eps) + + + def _gates_to_load(self, gates): + """Compute the true load per expert, given the gates. + The load is the number of examples for which the corresponding gate is >0. + Args: + gates: a `Tensor` of shape [batch_size, n] + Returns: + a float32 `Tensor` of shape [n] + """ + return (gates > 0).sum(0) + + def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values): + """Helper function to NoisyTopKGating. + Computes the probability that value is in top k, given different random noise. + This gives us a way of backpropagating from a loss that balances the number + of times each expert is in the top k experts per example. + In the case of no noise, pass in None for noise_stddev, and the result will + not be differentiable. + Args: + clean_values: a `Tensor` of shape [batch, n]. + noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus + normally distributed noise with standard deviation noise_stddev. + noise_stddev: a `Tensor` of shape [batch, n], or None + noisy_top_values: a `Tensor` of shape [batch, m]. + "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 + Returns: + a `Tensor` of shape [batch, n]. + """ + + batch = clean_values.size(0) + m = noisy_top_values.size(1) + top_values_flat = noisy_top_values.flatten() + if self.use_gpu: + threshold_positions_if_in = (torch.arange(batch) * m + self.k).cuda() + else: + threshold_positions_if_in = (torch.arange(batch) * m + self.k) + + threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) + is_in = torch.gt(noisy_values, threshold_if_in) + if self.use_gpu: + threshold_positions_if_out = (threshold_positions_if_in - 1).cuda() + else: + threshold_positions_if_out = (threshold_positions_if_in - 1) + threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) + # is each value currently in the top k. + prob_if_in = self.normal.cdf((clean_values - threshold_if_in)/noise_stddev) + prob_if_out = self.normal.cdf((clean_values - threshold_if_out)/noise_stddev) + prob = torch.where(is_in, prob_if_in, prob_if_out) + return prob + + + def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2): + """Noisy top-k gating. + See paper: https://arxiv.org/abs/1701.06538. + Args: + x: input Tensor with shape [batch_size, input_size] + train: a boolean - we only add noise at training time. + noise_epsilon: a float + Returns: + gates: a Tensor with shape [batch_size, num_experts] + load: a Tensor with shape [num_experts] + """ + + if self.non_linear: + clean_logits = self.gating_network(x) + else: + clean_logits = x @ self.w_gate + if self.noisy_gating: + raw_noise_stddev = x @ self.w_noise + if(self.k > 1): + raw_noise_stddev = self.softplus(raw_noise_stddev) + noise_stddev = ((raw_noise_stddev + noise_epsilon) * train) + noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) + logits = noisy_logits + else: + logits = clean_logits + + # calculate topk + 1 that will be needed for the noisy gates + top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1) + top_k_logits = top_logits[:, :self.k] + top_k_indices = top_indices[:, :self.k] + if(self.k > 1): + top_k_gates = self.softmax(top_k_logits) + else: + top_k_gates = torch.sigmoid(top_k_logits) + zeros = torch.zeros_like(logits, requires_grad=True) + gates = zeros.scatter(1, top_k_indices, top_k_gates) + + if self.noisy_gating and self.k < self.num_experts: + load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0) + else: + load = self._gates_to_load(gates) + return gates, load + + + def get_utilisation_loss(self): + return self.loss + + def forward(self, x, train=True, loss_coef=1e-2): + """Args: + x: tensor shape [batch_size, input_size] + train: a boolean scalar. + loss_coef: a scalar - multiplier on load-balancing losses + Returns: + y: a tensor with shape [batch_size, output_size]. + extra_training_loss: a scalar. This should be added into the overall + training loss of the model. The backpropagation of this loss + encourages all experts to be approximately equally used across a batch. + """ + # normalization is performed here for better convergence of the gating network + x = 2.0*(x - self.lb)/(self.ub - self.lb) - 1.0 + gates, load = self.noisy_top_k_gating(x, train) + # calculate importance loss + importance = gates.sum(0) + # + loss = self.cv_squared(importance) + self.cv_squared(load) + loss *= loss_coef + self.loss = loss + + dispatcher = SparseDispatcher(self.num_experts, gates, self.use_gpu) + expert_inputs = dispatcher.dispatch(x) + gates = dispatcher.expert_to_gates() + expert_outputs = [] + for i in range(self.num_experts): + if expert_inputs[i] is not None: + expert_outputs.append(self.experts[i](expert_inputs[i])) + y = dispatcher.combine(expert_outputs) + return y + + def cuda(self): + super(MoE, self).cuda() + self.use_gpu = True + for i in range(self.num_experts): + self.experts[i].cuda() + self.lb = self.lb.cuda() + self.ub = self.ub.cuda() + if self.non_linear: + self.gating_network.cuda() + + def cpu(self): + super(MoE, self).cpu() + self.use_gpu = False + for i in range(self.num_experts): + self.experts[i].cpu() + self.lb = self.lb.cpu() + self.ub = self.ub.cpu() + if self.non_linear: + self.gating_network.cpu() + + + diff --git a/PINNFramework/models/moe_mlp_test.py b/PINNFramework/models/moe_mlp_test.py new file mode 100644 index 0000000..3d66740 --- /dev/null +++ b/PINNFramework/models/moe_mlp_test.py @@ -0,0 +1,29 @@ +from PINNFramework.models import MoE +import numpy as np +import torch + +if __name__ == "__main__": + lb = np.array([0, 0, 0]) + ub = np.array([1, 1, 1]) + + # linear gating test + moe = MoE(3, 3, 5, 100, 2, lb, ub) + moe.cuda() + x_gpu = torch.randn(3, 3).cuda() + y_gpu = moe(x_gpu) + print(y_gpu) + moe.cpu() + x_cpu = torch.randn(3, 3) + y_cpu = moe(x_cpu) + print(y_cpu) + + # non linear gating test + moe = MoE(3, 3, 5, 100, 2, lb, ub, non_linear=True) + moe.cuda() + x_gpu = torch.randn(3, 3).cuda() + y_gpu = moe(x_gpu) + print(y_gpu) + moe.cpu() + x_cpu = torch.randn(3, 3) + y_cpu = moe(x_cpu) + print(y_cpu) diff --git a/PINNFramework/models/snake_mlp.py b/PINNFramework/models/snake_mlp.py new file mode 100644 index 0000000..5131631 --- /dev/null +++ b/PINNFramework/models/snake_mlp.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn +import math +from .mlp import MLP +from .activations.snake import Snake + +class SnakeMLP(MLP): + def __init__(self, input_size, output_size, hidden_size, num_hidden, lb, ub, frequency, normalize=True): + super(MLP, self).__init__() + self.linear_layers = nn.ModuleList() + self.activation = nn.ModuleList() + self.init_layers(input_size, output_size, hidden_size, num_hidden, frequency) + self.lb = torch.tensor(lb).float() + self.ub = torch.tensor(ub).float() + self.normalize = normalize + + + def init_layers(self, input_size, output_size, hidden_size, num_hidden, frequency): + self.linear_layers.append(nn.Linear(input_size, hidden_size)) + self.activation.append(Snake(frequency=frequency)) + for _ in range(num_hidden): + self.linear_layers.append(nn.Linear(hidden_size, hidden_size)) + self.activation.append(Snake(frequency=frequency)) + self.linear_layers.append(nn.Linear(hidden_size, output_size)) + + for m in self.linear_layers: + if isinstance(m, nn.Linear): + bound = math.sqrt(3 / m.weight.size()[0]) + torch.nn.init.uniform_(m.weight, a=-bound, b=bound) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + if self.normalize: + x = 2.0*(x - self.lb)/(self.ub - self.lb) - 1.0 + for i in range(len(self.linear_layers) - 1): + x = self.linear_layers[i](x) + x = self.activation[i](x) + x = self.linear_layers[-1](x) + return x diff --git a/benchmarks/hvd_test.py b/benchmarks/hvd_test.py new file mode 100644 index 0000000..5e55834 --- /dev/null +++ b/benchmarks/hvd_test.py @@ -0,0 +1,6 @@ +import horovod.torch as hvd + +if __name__ == "__main__": + # test if horovod is correct installed + hvd.init() + print("Hello from rank{}".format(hvd.rank())) \ No newline at end of file diff --git a/benchmarks/moe_benchmark.py b/benchmarks/moe_benchmark.py new file mode 100644 index 0000000..e33299f --- /dev/null +++ b/benchmarks/moe_benchmark.py @@ -0,0 +1,49 @@ +import torch +torch.manual_seed(0) +import time +import numpy as np +np.random.seed(0) +from argparse import ArgumentParser +import matplotlib.pyplot as plt +import sys + +sys.path.append('..') +parser = ArgumentParser() +parser.add_argument("--distributed", dest="distributed", type=int, default=0) +args = parser.parse_args() + +if args.distributed: + from PINNFramework.models.distributed_moe import MoE +else: + from PINNFramework.models.moe_mlp import MoE + +if __name__ == "__main__": + model = MoE(3, 3, 7, 300, 5, lb=[0, 0, 0], ub=[1, 1, 1], device='cuda:0', k=1).eval() + times = [] + for i in range(100000, 10100000,100000): + x = torch.randn((i, 3)).cuda() + torch.cuda.synchronize() + begin_time = time.time() + model.forward(x, train=False) + torch.cuda.synchronize() + end_time = time.time() + run_time = (end_time - begin_time) + print("For {} samples: {} sec".format(i, run_time)) + times.append([i, + run_time]) + del x + torch.cuda.empty_cache() + if args.distributed: + np.save("dist_run_time", times) + else: + np.save("non_dist_run_time", times) + + + """ + times = np.array(times) + plt.scatter(times[1:, 0],times[1:, 1], s=1) + plt.xlabel("Number of input samples") + plt.ylabel("Inference time") + plt.show() + """ +