Skip to content

Commit

Permalink
Finished adding documentation to latent_space.py
Browse files Browse the repository at this point in the history
I added comments + doc strings to the Autoencoder class.
  • Loading branch information
Robert Stephany committed Oct 28, 2024
1 parent 046cfcf commit f341da6
Showing 1 changed file with 154 additions and 31 deletions.
185 changes: 154 additions & 31 deletions src/lasdi/latent_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def __init__( self,
This class defines a standard multi-layer network network.
-------------------------------------------------------------------------------------------
Arguments
-------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -149,7 +148,6 @@ def __init__( self,
below the threshold.
-------------------------------------------------------------------------------------------
Returns
-------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -222,7 +220,6 @@ def forward(self, x : torch.Tensor) -> torch.Tensor:
elements, then the final k elements of x's shape must match self.reshape_shape.
-------------------------------------------------------------------------------------------
Returns
-------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -278,7 +275,6 @@ def init_weight(self) -> None:
"""
This function initializes the weight matrices and bias vectors in self's layers.
-------------------------------------------------------------------------------------------
Arguments
Expand All @@ -287,7 +283,6 @@ def init_weight(self) -> None:
None!
-------------------------------------------------------------------------------------------
Returns
-------------------------------------------------------------------------------------------
Expand All @@ -310,49 +305,177 @@ def init_weight(self) -> None:
# -------------------------------------------------------------------------------------------------

class Autoencoder(torch.nn.Module):
def __init__(self, physics : Physics, config : dict):
def __init__(self, physics : Physics, config : dict) -> None:
"""
Initializes an Autoencoder object. An Autoencoder consists of two networks, an encoder,
E : \mathbb{R}^F -> \mathbb{R}^L, and a decoder, D : \mathbb{R}^L -> \marthbb{R}^F. We
assume that the dataset consists of samples of a parameterized L-manifold in
\mathbb{R}^F. The idea then is that E and D act like the inverse coordinate patch and
coordinate patch, respectively. In our case, E and D are trainable neural networks. We
try to train E and map data in \mathbb{R}^F to elements of a low dimensional latent
space (\mathbb{R}^L) which D can send back to the original data. (thus, E, and D should
act like inverses of one another).
The Autoencoder class implements this model as a trainable torch.nn.Module object.
-------------------------------------------------------------------------------------------
Arguments
-------------------------------------------------------------------------------------------
physics: A "Physics" object that holds the fom solution frames. We use this object to
determine the shape of each fom solution frame. Recall that each Physics object has a
corresponding PDE. We
config: A dictionary representing the loaded .yml configuration file. We expect it to have
the following keys/:
hidden_units: A list of integers specifying the dimension of the co-domain of each
encoder layer except for the final one. Thus, if the k'th layer maps from
\mathbb{R}^{n(k)} to \mathbb{R}^{n(k + 1)} and there are K layers (indexed 0, 1, ... ,
K - 1), then hidden_units should specify n(1), ... , n(K - 1).
latent_dimension: The dimensionality of the Autoencoder's latent space. Equivalently,
the dimensionality of the co-domain of the encoder (i.e., the dimensionality of the
co-domain of the last layer of the encoder) and the domain of the decoder (i.e., the
dimensionality of the domain of the first layer of the decoder).
-------------------------------------------------------------------------------------------
Returns
-------------------------------------------------------------------------------------------
Nothing!
"""

super(Autoencoder, self).__init__()

self.qgrid_size = physics.qgrid_size;
# A Physics object's qgrid_size is a list of integers specifying the shape of each frame of
# the fom solution. If the solution is scalar valued, then this is just a list whose i'th
# element specifies the number of grid points along the i'th spatial axis. If the solution
# is vector valued, however, we prepend the dimensionality of the vector field to the list
# from the scalar list (so the 0 element represents the dimension of the vector field at
# each point).
self.qgrid_size : list[int] = physics.qgrid_size;

# The product of the elements of qgrid_size is the number of dimensions in each fom
# solution frame. This number represents the dimensionality of the input to the encoder
# (since we pass a flattened fom frame as input).
self.space_dim : np.ndarray = np.prod(self.qgrid_size);
hidden_units : int = config['hidden_units'];

# Fetch information about the domain/co-domain of each encoder layer.
hidden_units : list[int] = config['hidden_units'];
n_z : int = config['latent_dimension'];
self.n_z : int = n_z

layer_sizes = [self.space_dim] + hidden_units + [n_z]
#grab relevant initialization values from config
act_type = config['activation'] if 'activation' in config else 'sigmoid'
threshold = config["threshold"] if "threshold" in config else 0.1
value = config["value"] if "value" in config else 0.0
num_heads = config['num_heads'] if 'num_heads' in config else 1

self.encoder = MultiLayerPerceptron(layer_sizes, act_type,
reshape_index=0, reshape_shape=self.qgrid_size,
threshold=threshold, value=value, num_heads=num_heads)

self.decoder = MultiLayerPerceptron(layer_sizes[::-1], act_type,
reshape_index=-1, reshape_shape=self.qgrid_size,
threshold=threshold, value=value, num_heads=num_heads)
self.n_z : int = n_z;

# Build the "layer_sizes" argument for the MLP class. This consists of the dimensions of
# each layers' domain + the dimension of the co-domain of the final layer.
layer_sizes = [self.space_dim] + hidden_units + [n_z];

# Use the settings to set up the activation information for the encoder.
act_type = config['activation'] if 'activation' in config else 'sigmoid'
threshold = config["threshold"] if "threshold" in config else 0.1
value = config["value"] if "value" in config else 0.0

# Now, build the encoder.
self.encoder = MultiLayerPerceptron(
layer_sizes = layer_sizes,
act_type = act_type,
reshape_index = 0, # We need to flatten the spatial dimensions of each fom frame.
reshape_shape = self.qgrid_size,
threshold = threshold,
value = value);

self.decoder = MultiLayerPerceptron(
latent_sizes = layer_sizes[::-1], # Reverses the order of the the list.
act_type = act_type,
reshape_index = -1,
reshape_shape = self.qgrid_size, # We need to reshape the network output to a fom frame.
threshold = threshold,
value = value)

# All done!
return



def forward(self, x):
def forward(self, x : torch.Tensor) -> torch.Tensor:
"""
This function defines the forward pass through self.
x = self.encoder(x)
x = self.decoder(x)
-------------------------------------------------------------------------------------------
Arguments
-------------------------------------------------------------------------------------------
return x
x: A tensor holding a batch of inputs. We pass this tensor through the encoder + decoder
and then return the result.
-------------------------------------------------------------------------------------------
Returns
-------------------------------------------------------------------------------------------
The image of x under the encoder + decoder.
"""

# Encoder the input
z : torch.Tensor = self.encoder(x)

# Now decode z.
y : torch.Tensor = self.decoder(z)

# All done! Hopefully y \approx x.
return y



def export(self):
dict_ = {'autoencoder_param': self.cpu().state_dict()}
def export(self) -> dict:
"""
This function extracts self's parameters and returns them in a dictionary.
-------------------------------------------------------------------------------------------
Arguments
-------------------------------------------------------------------------------------------
None!
-------------------------------------------------------------------------------------------
Returns
-------------------------------------------------------------------------------------------
The A dictionary housing self's state dictionary.
"""

# TO DO: deep export which includes all information needed to re-initialize self from
# scratch. This would probably require changing the initializer.

dict_ = { 'autoencoder_param' : self.cpu().state_dict()}
return dict_



def load(self, dict_):
def load(self, dict_ : dict) -> None:
"""
This function loads self's state dictionary.
-------------------------------------------------------------------------------------------
Arguments
-------------------------------------------------------------------------------------------
dict_: This should be a dictionary with the key "autoencoder_param" whose corresponding
value is the state dictionary of an autoencoder which has the same architecture (i.e.,
layer sizes) as self.
-------------------------------------------------------------------------------------------
Returns
-------------------------------------------------------------------------------------------
Nothing!
"""

self.load_state_dict(dict_['autoencoder_param'])
return

0 comments on commit f341da6

Please sign in to comment.