diff --git a/simulation/differential_mesh/differential_mesh_graph.py b/simulation/differential_mesh/differential_mesh_graph.py index 9352164..0e44e17 100644 --- a/simulation/differential_mesh/differential_mesh_graph.py +++ b/simulation/differential_mesh/differential_mesh_graph.py @@ -8,6 +8,7 @@ import itertools from collections.abc import Iterator +from enum import Enum, auto from typing import Any import matplotlib.pyplot as plt @@ -30,6 +31,12 @@ DIFFERENTIAL_MESH_GRAPH_EDGE_COLOR = "#1f78b4" +class DifferentialMeshGraphBoundaryCondition(Enum): + """Boundary condition enumeration.""" + REFERENCE_NODE = auto() + ZERO_MEAN_NODE_POTENTIAL = auto() + + class DifferentialMeshGraph: """Interface for a differential mesh graph. @@ -282,9 +289,12 @@ def output_spice_netlist(self, netlist: str, target_node: int = -1) -> None: f.write(f".print dc v({target_node})\n") f.write(".end\n") - def calculate_node_standard_errors(self, - noise: float = 1 - ) -> list[tuple[int, int]]: + def calculate_node_standard_errors( + self, + noise: float = 1, + boundary_condition: DifferentialMeshGraphBoundaryCondition = ( + DifferentialMeshGraphBoundaryCondition.REFERENCE_NODE), + ) -> list[tuple[int, int]]: """Calculates the standard error of the node potentials using eigendecomposition. @@ -294,16 +304,38 @@ def calculate_node_standard_errors(self, Returns: A list of 2-tuples, each consisting of the node label and the corresponding standard error. + + Raises: + ValueError: If the boundary condition is invalid. """ - node_to_index_map = self.get_node_to_index_map() - root_index = node_to_index_map[DIFFERENTIAL_MESH_GRAPH_ROOT_NODE] index_to_node_map = self.get_index_to_node_map() - L = self.create_laplacian_matrix() - L[root_index, :] = 0 - L[:, root_index] = 0 - Linv = np.linalg.pinv(L, hermitian=True) - squared_stderrs = np.diag(Linv) + match boundary_condition: + case DifferentialMeshGraphBoundaryCondition.REFERENCE_NODE: + node_to_index_map = self.get_node_to_index_map() + root_index = node_to_index_map[ + DIFFERENTIAL_MESH_GRAPH_ROOT_NODE] + + L[root_index, :] = 0 + L[:, root_index] = 0 + Linv = np.linalg.pinv(L, hermitian=True) + squared_stderrs = np.diag(Linv) + case DifferentialMeshGraphBoundaryCondition.ZERO_MEAN_NODE_POTENTIAL: + eigenvalues, eigenvectors = np.linalg.eigh(L) + zero_index = np.argmin(eigenvalues) + + # Remove the zero eigenvalue and its corresponding eigenvector. + eigenvectors_squared = eigenvectors**2 + eigenvectors_squared_without_zero = np.delete( + eigenvectors_squared, zero_index, axis=1) + eigenvalues_without_zero = np.delete(eigenvalues, zero_index) + + squared_stderrs = np.sum(eigenvectors_squared_without_zero / + eigenvalues_without_zero, + axis=1).T + case _: + raise ValueError("Invalid boundary condition: %s.", + boundary_condition) return [(index_to_node_map[node_index], np.sqrt(squared_stderr) * noise) for node_index, squared_stderr in enumerate(squared_stderrs)]