From 2381d618ecc078d453fbad28e8b4bc1e728ba716 Mon Sep 17 00:00:00 2001
From: Kevin Chung
Date: Thu, 10 Oct 2024 16:12:17 -0700
Subject: [PATCH] factorized autoencoder.

---
 src/lasdi/latent_space.py | 234 ++++++++++++++++++--------------------
 1 file changed, 112 insertions(+), 122 deletions(-)

diff --git a/src/lasdi/latent_space.py b/src/lasdi/latent_space.py
index 5007beb..b1ad363 100644
--- a/src/lasdi/latent_space.py
+++ b/src/lasdi/latent_space.py
@@ -1,6 +1,32 @@
 import torch
 import numpy as np
 
+# activation dict
+act_dict = {'ELU': torch.nn.ELU,
+            'hardshrink': torch.nn.Hardshrink,
+            'hardsigmoid': torch.nn.Hardsigmoid,
+            'hardtanh': torch.nn.Hardtanh,
+            'hardswish': torch.nn.Hardswish,
+            'leakyReLU': torch.nn.LeakyReLU,
+            'logsigmoid': torch.nn.LogSigmoid,
+            'multihead': torch.nn.MultiheadAttention,
+            'PReLU': torch.nn.PReLU,
+            'ReLU': torch.nn.ReLU,
+            'ReLU6': torch.nn.ReLU6,
+            'RReLU': torch.nn.RReLU,
+            'SELU': torch.nn.SELU,
+            'CELU': torch.nn.CELU,
+            'GELU': torch.nn.GELU,
+            'sigmoid': torch.nn.Sigmoid,
+            'SiLU': torch.nn.SiLU,
+            'mish': torch.nn.Mish,
+            'softplus': torch.nn.Softplus,
+            'softshrink': torch.nn.Softshrink,
+            'tanh': torch.nn.Tanh,
+            'tanhshrink': torch.nn.Tanhshrink,
+            'threshold': torch.nn.Threshold,
+            }
+
 def initial_condition_latent(param_grid, physics, autoencoder):
 
     '''
@@ -23,141 +49,114 @@ def initial_condition_latent(param_grid, physics, autoencoder):
         Z0.append(z0)
 
     return Z0
-
-class Autoencoder(torch.nn.Module):
-    # set by physics.qgrid_size
-    qgrid_size = []
-    # prod(qgrid_size)
-    space_dim = -1
-    n_z = -1
-
-    # activation dict
-    act_dict = {'ELU': torch.nn.ELU,
-                'hardshrink': torch.nn.Hardshrink,
-                'hardsigmoid': torch.nn.Hardsigmoid,
-                'hardtanh': torch.nn.Hardtanh,
-                'hardswish': torch.nn.Hardswish,
-                'leakyReLU': torch.nn.LeakyReLU,
-                'logsigmoid': torch.nn.LogSigmoid,
-                'multihead': torch.nn.MultiheadAttention,
-                'PReLU': torch.nn.PReLU,
-                'ReLU': torch.nn.ReLU,
-                'ReLU6': torch.nn.ReLU6,
-                'RReLU': torch.nn.RReLU,
-                'SELU': torch.nn.SELU,
-                'CELU': torch.nn.CELU,
-                'GELU': torch.nn.GELU,
-                'sigmoid': torch.nn.Sigmoid,
-                'SiLU': torch.nn.SiLU,
-                'mish': torch.nn.Mish,
-                'softplus': torch.nn.Softplus,
-                'softshrink': torch.nn.Softshrink,
-                'tanh': torch.nn.Tanh,
-                'tanhshrink': torch.nn.Tanhshrink,
-                'threshold': torch.nn.Threshold,
-                }
-
-    def __init__(self, physics, config):
-        super(Autoencoder, self).__init__()
-        self.qgrid_size = physics.qgrid_size
-        self.space_dim = np.prod(self.qgrid_size)
-        hidden_units = config['hidden_units']
-        n_z = config['latent_dimension']
-        self.n_z = n_z
+class MultiLayerPerceptron(torch.nn.Module):
 
-        n_layers = len(hidden_units)
-        self.n_layers = n_layers
+    def __init__(self, layer_sizes,
+                 act_type='sigmoid', reshape_index=None, reshape_shape=None,
+                 threshold=0.1, value=0.0, num_heads=1):
+        super(MultiLayerPerceptron, self).__init__()
 
-        fc1_e = torch.nn.Linear(self.space_dim, hidden_units[0])
-        torch.nn.init.xavier_uniform_(fc1_e.weight)
-        self.fc1_e = fc1_e
+        # including input, hidden, output layers
+        self.n_layers = len(layer_sizes)
+        self.layer_sizes = layer_sizes
 
-        if n_layers > 1:
-            for i in range(n_layers - 1):
-                fc_e = torch.nn.Linear(hidden_units[i], hidden_units[i + 1])
-                torch.nn.init.xavier_uniform_(fc_e.weight)
-                setattr(self, 'fc' + str(i + 2) + '_e', fc_e)
+        # Linear features between layers
+        self.fcs = []
+        for k in range(self.n_layers-1):
+            self.fcs += [torch.nn.Linear(layer_sizes[k], layer_sizes[k + 1])]
+        self.fcs = torch.nn.ModuleList(self.fcs)
 
-        fc_e = torch.nn.Linear(hidden_units[-1], n_z)
-        torch.nn.init.xavier_uniform_(fc_e.weight)
-        setattr(self, 'fc' + str(n_layers + 1) + '_e', fc_e)
+        # Reshape input or output layer
+        assert((reshape_index is None) or (reshape_index in [0, -1]))
+        assert((reshape_shape is None) or (np.prod(reshape_shape) == layer_sizes[reshape_index]))
+        self.reshape_index = reshape_index
+        self.reshape_shape = reshape_shape
 
-        act_type = config['activation'] if 'activation' in config else 'sigmoid'
+        # Initialize activation function
+        self.act_type = act_type
+        self.use_multihead = False
         if act_type == "threshold":
-            #grab relevant initialization values from config
-            threshold = config["threshold"] if "threshold" in config else 0.1
-            value = config["value"] if "value" in config else 0.0
-            self.g_e = self.act_dict[act_type](threshold, value)
+            self.act = act_dict[act_type](threshold, value)
 
         elif act_type == "multihead":
-            #grab relevant initialization values from config
-            num_heads = config['num_heads'] if 'num_heads' in config else 1
-            if n_layers > 1:
-                for i in range(n_layers):
-                    setattr(self, 'a' + str(i + 1), self.act_dict[act_type](hidden_units[i], num_heads))
-            self.g_e = torch.nn.Identity() # No additional activation
+            self.use_multihead = True
+            if (self.n_layers > 3): # if you have more than one hidden layer
+                self.act = []
+                for i in range(self.n_layers-2):
+                    self.act += [act_dict[act_type](layer_sizes[i+1], num_heads)]
+            else:
+                self.act = [torch.nn.Identity()] # No additional activation
+            self.act = torch.nn.ModuleList(self.act)
+        #all other activation functions initialized here
         else:
-            self.g_e = self.act_dict[act_type]()
-
-        fc1_d = torch.nn.Linear(n_z, hidden_units[-1])
-        torch.nn.init.xavier_uniform_(fc1_d.weight)
-        self.fc1_d = fc1_d
-
-        if n_layers > 1:
-            for i in range(n_layers - 1, 0, -1):
-                fc_d = torch.nn.Linear(hidden_units[i], hidden_units[i - 1])
-                torch.nn.init.xavier_uniform_(fc_d.weight)
-                setattr(self, 'fc' + str(n_layers - i + 1) + '_d', fc_d)
-
-        fc_d = torch.nn.Linear(hidden_units[0], self.space_dim)
-        torch.nn.init.xavier_uniform_(fc_d.weight)
-        setattr(self, 'fc' + str(n_layers + 1) + '_d', fc_d)
-
-
-
-    def encoder(self, x):
-        # make sure the input has a proper shape
-        assert(list(x.shape[-len(self.qgrid_size):]) == self.qgrid_size)
-        # we use torch.Tensor.view instead of torch.Tensor.reshape,
-        # in order to avoid data copying.
-        x = x.view(list(x.shape[:-len(self.qgrid_size)]) + [self.space_dim])
-
-        for i in range(1, self.n_layers + 1):
-            fc = getattr(self, 'fc' + str(i) + '_e')
-            x = fc(x) # apply linear layer
-            if hasattr(self, 'a1'): # test if there is at least one attention layer
+            self.act = act_dict[act_type]()
+        return
+
+    def forward(self, x):
+        if (self.reshape_index == 0):
+            # make sure the input has a proper shape
+            assert(list(x.shape[-len(self.reshape_shape):]) == self.reshape_shape)
+            # we use torch.Tensor.view instead of torch.Tensor.reshape,
+            # in order to avoid data copying.
+            x = x.view(list(x.shape[:-len(self.reshape_shape)]) + [self.layer_sizes[self.reshape_index]])
+
+        for i in range(self.n_layers-2):
+            x = self.fcs[i](x) # apply linear layer
+            if (self.use_multihead):
                 x = self.apply_attention(self, x, i)
-
-            x = self.g_e(x) # apply activation function
+            else:
+                x = self.act(x)
 
-        fc = getattr(self, 'fc' + str(self.n_layers + 1) + '_e')
-        x = fc(x)
-
-        return x
+        x = self.fcs[-1](x)
+        if (self.reshape_index == -1):
+            # we use torch.Tensor.view instead of torch.Tensor.reshape,
+            # in order to avoid data copying.
+            x = x.view(list(x.shape[:-1]) + self.reshape_shape)
 
-    def decoder(self, x):
+        return x
+
+    def apply_attention(self, x, act_idx):
+        x = x.unsqueeze(1) # Add sequence dimension for attention
+        x, _ = self.act[act_idx](x, x, x) # apply attention
+        x = x.squeeze(1) # Remove sequence dimension
+        return x
+
+    def init_weight(self):
+        # TODO(kevin): support other initializations?
+        for fc in self.fcs:
+            torch.nn.init.xavier_uniform_(fc.weight)
+        return
 
-        for i in range(1, self.n_layers + 1):
-            fc = getattr(self, 'fc' + str(i) + '_d')
-            x = fc(x) # apply linear layer
-            if hasattr(self, 'a1'): # test if there is at least one attention layer
-                x = self.apply_attention(self, x, self.n_layers - i)
+class Autoencoder(torch.nn.Module):
 
-            x = self.g_e(x) # apply activation function
+    def __init__(self, physics, config):
+        super(Autoencoder, self).__init__()
 
-        fc = getattr(self, 'fc' + str(self.n_layers + 1) + '_d')
-        x = fc(x)
+        self.qgrid_size = physics.qgrid_size
+        self.space_dim = np.prod(self.qgrid_size)
+        hidden_units = config['hidden_units']
+        n_z = config['latent_dimension']
+        self.n_z = n_z
 
-        # we use torch.Tensor.view instead of torch.Tensor.reshape,
-        # in order to avoid data copying.
-        x = x.view(list(x.shape[:-1]) + self.qgrid_size)
+        layer_sizes = [self.space_dim] + hidden_units + [n_z]
+        #grab relevant initialization values from config
+        act_type = config['activation'] if 'activation' in config else 'sigmoid'
+        threshold = config["threshold"] if "threshold" in config else 0.1
+        value = config["value"] if "value" in config else 0.0
+        num_heads = config['num_heads'] if 'num_heads' in config else 1
 
-        return x
+        self.encoder = MultiLayerPerceptron(layer_sizes, act_type,
+                                            reshape_index=0, reshape_shape=self.qgrid_size,
+                                            threshold=threshold, value=value, num_heads=num_heads)
+
+        self.decoder = MultiLayerPerceptron(layer_sizes[::-1], act_type,
+                                            reshape_index=-1, reshape_shape=self.qgrid_size,
+                                            threshold=threshold, value=value, num_heads=num_heads)
+        return
 
     def forward(self, x):
 
@@ -165,15 +164,6 @@ def forward(self, x):
         x = self.decoder(x)
 
         return x
-
-
-    def apply_attention(self, x, layer):
-        x = x.unsqueeze(1) # Add sequence dimension for attention
-        a = getattr(self, 'a' + str(layer))
-        x, _ = a(x, x, x) # apply attention
-        x = x.squeeze(1) # Remove sequence dimension
-
-        return x
 
     def export(self):
         dict_ = {'autoencoder_param': self.cpu().state_dict()}
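
Usage sketch (illustrative, not part of the patch): how the factorized modules above compose, assuming the file is importable as lasdi.latent_space. The 4x4 grid, hidden widths, and latent dimension are made-up example values; in the library they come from physics.qgrid_size and the 'hidden_units' / 'latent_dimension' config keys read in Autoencoder.__init__.

import numpy as np
import torch

from lasdi.latent_space import MultiLayerPerceptron

qgrid_size  = [4, 4]                            # stand-in for physics.qgrid_size
space_dim   = int(np.prod(qgrid_size))          # flattened snapshot dimension
layer_sizes = [space_dim] + [100, 50] + [5]     # [space_dim] + hidden_units + [latent_dimension]

# Encoder flattens the snapshot grid on input (reshape_index=0);
# the decoder mirrors the layer sizes and restores the grid on output (reshape_index=-1).
encoder = MultiLayerPerceptron(layer_sizes, act_type='tanh',
                               reshape_index=0, reshape_shape=qgrid_size)
decoder = MultiLayerPerceptron(layer_sizes[::-1], act_type='tanh',
                               reshape_index=-1, reshape_shape=qgrid_size)
encoder.init_weight()   # Xavier-uniform init, as the pre-refactor Autoencoder did inline
decoder.init_weight()

u     = torch.rand([7] + qgrid_size)            # batch of 7 snapshots on the 4x4 grid
z     = encoder(u)                              # latent coordinates, shape [7, 5]
u_hat = decoder(z)                              # reconstruction, shape [7, 4, 4]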