diff --git a/docs/api_core.rst b/docs/api_core.rst
index bf10f191..4a68cb05 100644
--- a/docs/api_core.rst
+++ b/docs/api_core.rst
@@ -50,4 +50,31 @@ These are the building blocks which are used to construct the CVs.
    :template: custom-class-template.rst
 
    Transform
+
+
+.. rubric:: Transform.descriptors
+
+.. currentmodule:: mlcolvar.core.transform.descriptors
+
+.. autosummary::
+   :toctree: autosummary
+   :template: custom-class-template.rst
+
+   PairwiseDistances
+   TorsionalAngle
+   CoordinationNumbers
+   EigsAdjMat
+   MultipleDescriptors
+
+.. rubric:: Transform.tools
+
+.. currentmodule:: mlcolvar.core.transform.tools
+
+.. autosummary::
+   :toctree: autosummary
+   :template: custom-class-template.rst
+
+   Normalization
+   ContinuousHistogram
+   SwitchingFunctions
\ No newline at end of file
diff --git a/mlcolvar/core/stats/tica.py b/mlcolvar/core/stats/tica.py
index 3a6b2061..0a7e3739 100644
--- a/mlcolvar/core/stats/tica.py
+++ b/mlcolvar/core/stats/tica.py
@@ -10,7 +10,7 @@
     compute_average,
     reduced_rank_eig,
 )
-from mlcolvar.core.transform.utils import batch_reshape
+from mlcolvar.core.transform.tools.utils import batch_reshape
 
 import warnings
diff --git a/mlcolvar/core/transform/__init__.py b/mlcolvar/core/transform/__init__.py
index 5099f761..80e88304 100644
--- a/mlcolvar/core/transform/__init__.py
+++ b/mlcolvar/core/transform/__init__.py
@@ -1,5 +1,6 @@
-__all__ = ["Transform", "Normalization", "Statistics", "Inverse"]
+__all__ = ["Transform", "Normalization", "Statistics", "SwitchingFunctions", "MultipleDescriptors", "PairwiseDistances", "CoordinationNumbers", "EigsAdjMat", "ContinuousHistogram", "Inverse", "TorsionalAngle"]
 
 from .transform import *
-from .normalization import *
 from .utils import *
+from .tools import *
+from .descriptors import *
diff --git a/mlcolvar/core/transform/descriptors/__init__.py b/mlcolvar/core/transform/descriptors/__init__.py
new file mode 100644
index 00000000..588cce97
--- /dev/null
+++ b/mlcolvar/core/transform/descriptors/__init__.py
@@ -0,0 +1,7 @@
+__all__ = ["MultipleDescriptors", "CoordinationNumbers", "EigsAdjMat", "PairwiseDistances", "TorsionalAngle"]
+
+from .coordination_numbers import *
+from .eigs_adjacency_matrix import *
+from .pairwise_distances import *
+from .torsional_angle import *
+from .multiple_descriptors import *
\ No newline at end of file
diff --git a/mlcolvar/core/transform/descriptors/coordination_numbers.py b/mlcolvar/core/transform/descriptors/coordination_numbers.py
new file mode 100644
index 00000000..8aa94b97
--- /dev/null
+++ b/mlcolvar/core/transform/descriptors/coordination_numbers.py
@@ -0,0 +1,202 @@
+import torch
+import numpy as np
+
+from mlcolvar.core.transform import Transform
+from mlcolvar.core.transform.descriptors.utils import compute_distances_matrix, apply_cutoff, sanitize_positions_shape
+
+from typing import Union
+
+__all__ = ["CoordinationNumbers"]
+
+class CoordinationNumbers(Transform):
+    """
+    Coordination number between the elements of two groups of atoms from their positions
+    """
+
+    def __init__(self,
+                 group_A: list,
+                 group_B: list,
+                 cutoff: float,
+                 n_atoms: int,
+                 PBC: bool,
+                 cell: Union[float, list],
+                 mode: str,
+                 scaled_coords: bool = False,
+                 switching_function = None) -> torch.Tensor:
+        """Initialize a coordination number object between two groups of atoms A and B.
+
+        Parameters
+        ----------
+        group_A : list
+            Zero-based indices of group A atoms
+        group_B : list
+            Zero-based indices of group B atoms
+        cutoff : float
+            Cutoff radius for coordination number evaluation
+        n_atoms : int
+            Total number of atoms in the system
+        PBC : bool
+            Switch for Periodic Boundary Conditions use
+        cell : Union[float, list]
+            Dimensions of the real cell, orthorhombic-like cells only
+        mode : str
+            Mode for cutoff application, either:
+            - 'continuous': applies a switching function to the distances, which can be specified with the switching_function keyword, has stable derivatives
+            - 'discontinuous': sets to zero everything above the cutoff and to one everything below, derivatives may be incorrect
+        scaled_coords : bool
+            Switch for coordinates scaled on cell's vectors use, by default False
+        switching_function : callable, optional
+            Switching function to be applied for the cutoff, can be either initialized as a switching_functions/SwitchingFunctions class or a simple function, by default None
+
+        Returns
+        -------
+        torch.Tensor
+            Coordination numbers of elements of group A with respect to elements of group B
+        """
+        super().__init__(in_features=int(n_atoms*3), out_features=len(group_A))
+
+        # parse args
+        self.group_A = group_A
+        self._group_A_size = len(group_A)
+        self.group_B = group_B
+        self._group_B_size = len(group_B)
+        self._reordering = np.concatenate((self.group_A, self.group_B))
+        self.cutoff = cutoff
+        self.n_atoms = n_atoms
+        self.PBC = PBC
+        self.cell = cell
+        self.scaled_coords = scaled_coords
+        self.mode = mode
+        self.switching_function = switching_function
+
+    def compute_coordination_number(self, pos):
+        # move the group A elements to the first positions
+        pos, batch_size = sanitize_positions_shape(pos, self.n_atoms)
+        pos = pos[:, self._reordering, :]
+        dist = compute_distances_matrix(pos=pos,
+                                        n_atoms=self.n_atoms,
+                                        PBC=self.PBC,
+                                        cell=self.cell,
+                                        scaled_coords=self.scaled_coords)
+
+        # apply the cutoff to the distances with the switching function
+        contributions = apply_cutoff(x=dist,
+                                     cutoff=self.cutoff,
+                                     mode=self.mode,
+                                     switching_function=self.switching_function)
+
+        # throw away the redundant part of the matrix
+        contributions = contributions[:, :self._group_A_size, :]
+
+        # ensure that the AxA block of the matrix is zero, using a mask to preserve the gradients
+        mask = torch.ones_like(contributions)
+        mask[:, :self._group_A_size, :self._group_A_size] = 0
+        contributions = contributions * mask
+
+        # compute coordination
+        coord_numbers = torch.sum(contributions, dim=-1)
+
+        return coord_numbers
+
+    def forward(self, pos):
+        coord_numbers = self.compute_coordination_number(pos)
+        return coord_numbers
+
+
+def test_coordination_number():
+    from mlcolvar.core.transform.tools.switching_functions import SwitchingFunctions
+
+    # simple example based on calixarene water coordination numbers
+    pos = torch.Tensor([[[-0.410219, -0.680065, -2.016121],
+                         [-0.164329, -0.630426, -2.120843],
+                         [-0.250341, -0.392700, -1.534535],
+                         [-0.277187, -0.615506, -1.335904],
+                         [-0.762276, -1.041939, -1.546581],
+                         [-0.200766, -0.851481, -1.534129],
+                         [ 0.051099, -0.898884, -1.628219],
+                         [-1.257225, 1.671602, 0.166190],
+                         [-0.486917, -0.902610, -1.554715],
+                         [-0.020386, -0.566621, -1.597171],
+                         [-0.507683, -0.541252, -1.540805],
+                         [-0.527323, -0.206236, -1.532587]],
+                        [[-0.410387, -0.677657, -2.018355],
+                         [-0.163502, -0.626094, -2.123348],
+                         [-0.250672, -0.389610, -1.536810],
+                         [-0.275395, -0.612535, -1.338175],
+                         [-0.762197, -1.037856, -1.547382],
+                         [-0.200948, -0.847825, -1.536010],
+                         [ 0.051170, -0.896311, -1.629396],
+                         [-1.257530, 1.674078, 0.165089],
+                         [-0.486894, -0.900076, -1.556366],
+                         [-0.020235, -0.563252, -1.601229],
+                         [-0.507242, -0.537527, -1.543025],
+                         [-0.528576, -0.202031, -1.534733]]])
+
+    cell = 4.0273098
+    pos.requires_grad = True
+
+    n_atoms = 12
+    cutoff = 0.25
+    switching_function = SwitchingFunctions(in_features=n_atoms*3, name='Rational', cutoff=cutoff, options={'n': 2, 'm': 6, 'eps': 1e0})
+
+    model = CoordinationNumbers(group_A=[0, 1],
+                                group_B=[2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+                                cutoff=cutoff,
+                                n_atoms=n_atoms,
+                                PBC=True,
+                                cell=cell,
+                                mode='continuous',
+                                scaled_coords=False,
+                                switching_function=switching_function)
+
+    out = model(pos)
+    out.sum().backward()
+
+    # we swap by hand the 0,1 atoms with 2,3
+    pos = torch.Tensor([[[-0.250341, -0.392700, -1.534535],
+                         [-0.277187, -0.615506, -1.335904],
+                         [-0.410219, -0.680065, -2.016121],
+                         [-0.164329, -0.630426, -2.120843],
+                         [-0.762276, -1.041939, -1.546581],
+                         [-0.200766, -0.851481, -1.534129],
+                         [ 0.051099, -0.898884, -1.628219],
+                         [-1.257225, 1.671602, 0.166190],
+                         [-0.486917, -0.902610, -1.554715],
+                         [-0.020386, -0.566621, -1.597171],
+                         [-0.507683, -0.541252, -1.540805],
+                         [-0.527323, -0.206236, -1.532587]],
+                        [[-0.250672, -0.389610, -1.536810],
+                         [-0.275395, -0.612535, -1.338175],
+                         [-0.410387, -0.677657, -2.018355],
+                         [-0.163502, -0.626094, -2.123348],
+                         [-0.762197, -1.037856, -1.547382],
+                         [-0.200948, -0.847825, -1.536010],
+                         [ 0.051170, -0.896311, -1.629396],
+                         [-1.257530, 1.674078, 0.165089],
+                         [-0.486894, -0.900076, -1.556366],
+                         [-0.020235, -0.563252, -1.601229],
+                         [-0.507242, -0.537527, -1.543025],
+                         [-0.528576, -0.202031, -1.534733]]])
+
+    pos.requires_grad = True
+    switching_function = SwitchingFunctions(in_features=n_atoms*3, name='Rational', cutoff=cutoff, options={'n': 2, 'm': 6, 'eps': 1e0})
+
+    model = CoordinationNumbers(group_A=[2, 3],
+                                group_B=[0, 1, 4, 5, 6, 7, 8, 9, 10, 11],
+                                cutoff=cutoff,
+                                n_atoms=n_atoms,
+                                PBC=True,
+                                cell=cell,
+                                mode='continuous',
+                                scaled_coords=False,
+                                switching_function=switching_function)
+
+    out_2 = model(pos)
+    out_2.sum().backward()
+    assert(torch.allclose(out, out_2))
+
+    # TODO add reference value for check
+
+if __name__ == "__main__":
+    test_coordination_number()
+
diff --git a/mlcolvar/core/transform/descriptors/eigs_adjacency_matrix.py b/mlcolvar/core/transform/descriptors/eigs_adjacency_matrix.py
new file mode 100644
index 00000000..bd770c1d
--- /dev/null
+++ b/mlcolvar/core/transform/descriptors/eigs_adjacency_matrix.py
@@ -0,0 +1,118 @@
+import torch
+
+from mlcolvar.core.transform import Transform
+from mlcolvar.core.transform.descriptors.utils import compute_adjacency_matrix
+
+from typing import Union
+
+__all__ = ["EigsAdjMat"]
+
+class EigsAdjMat(Transform):
+    """
+    Eigenvalues of the adjacency matrix for a set of atoms from their positions
+    """
+
+    def __init__(self,
+                 mode: str,
+                 cutoff: float,
+                 n_atoms: int,
+                 PBC: bool,
+                 cell: Union[float, list],
+                 scaled_coords: bool = False,
+                 switching_function = None) -> torch.Tensor:
+        """Initialize an object computing the eigenvalues of the adjacency matrix.
+
+        Parameters
+        ----------
+        mode : str
+            Mode for cutoff application, either:
+            - 'continuous': applies a switching function to the distances, which can be specified with the switching_function keyword, has stable derivatives
+            - 'discontinuous': sets to zero everything above the cutoff and to one everything below, derivatives may be incorrect
+        cutoff : float
+            Cutoff for the adjacency criterion
+        n_atoms : int
+            Number of atoms in the system
+        PBC : bool
+            Switch for Periodic Boundary Conditions use
+        cell : Union[float, list]
+            Dimensions of the real cell, orthorhombic-like cells only
+        scaled_coords : bool
+            Switch for coordinates scaled on cell's vectors use, by default False
+        switching_function : callable, optional
+            Switching function to be applied for the cutoff, can be either initialized as a switching_functions/SwitchingFunctions class or a simple function, by default None
+
+        Returns
+        -------
+        torch.Tensor
+            Eigenvalues of the adjacency matrix of the n_atoms according to cutoff
+        """
+        super().__init__(in_features=int(n_atoms*3), out_features=n_atoms)
+
+        # parse args
+        self.mode = mode
+        self.cutoff = cutoff
+        self.n_atoms = n_atoms
+        self.PBC = PBC
+        self.cell = cell
+        self.scaled_coords = scaled_coords
+        self.switching_function = switching_function
+
+    def compute_adjacency_matrix(self, pos):
+        pos = compute_adjacency_matrix(pos=pos,
+                                       mode=self.mode,
+                                       cutoff=self.cutoff,
+                                       n_atoms=self.n_atoms,
+                                       PBC=self.PBC,
+                                       cell=self.cell,
+                                       scaled_coords=self.scaled_coords,
+                                       switching_function=self.switching_function)
+        return pos
+
+    def get_eigenvalues(self, x):
+        eigs = torch.linalg.eigvalsh(x)
+        return eigs
+
+    def forward(self, x: torch.Tensor):
+        x = self.compute_adjacency_matrix(x)
+        eigs = self.get_eigenvalues(x)
+        return eigs
+
+def test_eigs_of_adj_matrix():
+    from mlcolvar.core.transform.tools.switching_functions import SwitchingFunctions
+
+    n_atoms = 2
+    pos = torch.Tensor([ [ [0., 0., 0.],
+                           [1., 1., 1.] ],
+                         [ [0., 0., 0.],
+                           [1., 1.1, 1.] ] ]
+                     )
+    pos.requires_grad = True
+    cell = torch.Tensor([1., 2., 1.])
+
+    cutoff = 1.8
+    switching_function = SwitchingFunctions(in_features=n_atoms*3, name='Fermi', cutoff=cutoff, options={'q': 0.05})
+
+    model = EigsAdjMat(mode='continuous',
+                       cutoff=cutoff,
+                       n_atoms=n_atoms,
+                       PBC=True,
+                       cell=cell,
+                       scaled_coords=False,
+                       switching_function=switching_function)
+    out = model(pos)
+    out.sum().backward()
+
+    pos = torch.einsum('bij,j->bij', pos, 1/cell)
+    model = EigsAdjMat(mode='continuous',
+                       cutoff=cutoff,
+                       n_atoms=n_atoms,
+                       PBC=True,
+                       cell=cell,
+                       scaled_coords=True,
+                       switching_function=switching_function)
+    out = model(pos)
+    assert(out.shape[-1] == model.out_features)
+    out.sum().backward()
+
+if __name__ == "__main__":
+    test_eigs_of_adj_matrix()
\ No newline at end of file
diff --git a/mlcolvar/core/transform/descriptors/multiple_descriptors.py b/mlcolvar/core/transform/descriptors/multiple_descriptors.py
new file mode 100644
index 00000000..63f1a604
--- /dev/null
+++ b/mlcolvar/core/transform/descriptors/multiple_descriptors.py
@@ -0,0 +1,89 @@
+import torch
+
+__all__ = ["MultipleDescriptors"]
+
+class MultipleDescriptors(torch.nn.Module):
+    """Wrapper class to combine multiple descriptor transform objects acting on the same set of atomic positions"""
+    def __init__(self,
+                 descriptors_list: list,
+                 n_atoms: int,
+                 ):
+        """Combine a list of descriptor transform objects into a single module.
+
+        Parameters
+        ----------
+        descriptors_list : list
+            List of descriptor transform objects to be combined
+        n_atoms : int
+            Number of atoms in the system
+        """
+        super().__init__()
+        self.in_features = n_atoms * 3
+        self.descriptors_list = descriptors_list
+
+        self.out_features = 0
+        for d in self.descriptors_list:
+            self.out_features += d.out_features
+
+    def forward(self, pos):
+        for i, d in enumerate(self.descriptors_list):
+            if i == 0:
+                out = d(pos)
+            else:
+                aux = d(pos)
+                out = torch.concatenate((out, aux), 1)
+        return out
+
+def test_multipledescriptors():
+    from .torsional_angle import TorsionalAngle
+    from .pairwise_distances import PairwiseDistances
+
+    # check using torsional angles and distances in alanine
+    pos = torch.Tensor([[[ 0.3887, -0.4169, -0.1212],
+                         [ 0.4264, -0.4374, -0.0983],
+                         [ 0.4574, -0.4136, -0.0931],
+                         [ 0.4273, -0.4797, -0.0871],
+                         [ 0.4684, 0.4965, -0.0692],
+                         [ 0.4478, 0.4571, -0.0441],
+                         [-0.4933, 0.4869, -0.1026],
+                         [-0.4840, 0.4488, -0.1116],
+                         [-0.4748, -0.4781, -0.1232],
+                         [-0.4407, -0.4781, -0.1569]],
+                        [[ 0.3910, -0.4103, -0.1189],
+                         [ 0.4334, -0.4329, -0.1020],
+                         [ 0.4682, -0.4145, -0.1013],
+                         [ 0.4322, -0.4739, -0.0867],
+                         [ 0.4669, -0.4992, -0.0666],
+                         [ 0.4448, 0.4670, -0.0375],
+                         [-0.4975, 0.4844, -0.0981],
+                         [-0.4849, 0.4466, -0.0991],
+                         [-0.4818, -0.4870, -0.1291],
+                         [-0.4490, -0.4933, -0.1668]]])
+    pos.requires_grad = True
+    cell = torch.Tensor([3.0233, 3.0233, 3.0233])
+
+    # model 1 and 2 for torsional angles, model 3 for distances
+    model_1 = TorsionalAngle(indices=[1,3,4,6], n_atoms=10, mode=['angle'], PBC=False, cell=cell, scaled_coords=False)
+    model_2 = TorsionalAngle(indices=[3,4,6,8], n_atoms=10, mode=['angle'], PBC=False, cell=cell, scaled_coords=False)
+    model_3 = PairwiseDistances(n_atoms=10, PBC=True, cell=cell, scaled_coords=False, slicing_pairs=[[0, 1], [0, 2]])
+
+    # compute single references
+    angle_1 = model_1(pos)
+    angle_2 = model_2(pos)
+    distances = model_3(pos)
+
+    # stack torsional angles
+    model_tot = MultipleDescriptors(descriptors_list=[model_1, model_2], n_atoms=10)
+    out = model_tot(pos)
+    out.sum().backward()
+    for i in range(len(pos)):
+        assert(torch.allclose(out[i, 0], angle_1[i]))
+        assert(torch.allclose(out[i, 1], angle_2[i]))
+
+    # stack torsional angle and two distances
+    model_tot = MultipleDescriptors(descriptors_list=[model_1, model_3], n_atoms=10)
+    out = model_tot(pos)
+    out.sum().backward()
+    for i in range(len(pos)):
+        assert(torch.allclose(out[i, 0], angle_1[i]))
+        assert(torch.allclose(out[i, 1:], distances[i]))
\ No newline at end of file
diff --git a/mlcolvar/core/transform/descriptors/pairwise_distances.py b/mlcolvar/core/transform/descriptors/pairwise_distances.py
new file mode 100644
index 00000000..eaa08d71
--- /dev/null
+++ b/mlcolvar/core/transform/descriptors/pairwise_distances.py
@@ -0,0 +1,119 @@
+import torch
+
+from mlcolvar.core.transform import Transform
+from mlcolvar.core.transform.descriptors.utils import compute_distances_matrix
+
+from typing import Union
+
+__all__ = ["PairwiseDistances"]
+
+class PairwiseDistances(Transform):
+    """
+    Non-duplicated pairwise distances for a set of atoms from their positions
+    """
+
+    def __init__(self,
+                 n_atoms: int,
+                 PBC: bool,
+                 cell: Union[float, list],
+                 scaled_coords: bool = False,
+                 slicing_pairs: list = None) -> torch.Tensor:
+        """Initialize a pairwise distances matrix object.
+
+        Parameters
+        ----------
+        n_atoms : int
+            Number of atoms in the system
+        PBC : bool
+            Switch for Periodic Boundary Conditions use
+        cell : Union[float, list]
+            Dimensions of the real cell, orthorhombic-like cells only
+        scaled_coords : bool
+            Switch for coordinates scaled on cell's vectors use, by default False
+        slicing_pairs : list
+            Indices of the subset of distances to be returned, by default None
+
+        Returns
+        -------
+        torch.Tensor
+            Non-duplicated pairwise distances between all the atoms
+        """
+        if slicing_pairs is None:
+            super().__init__(in_features=int(n_atoms*3), out_features=int(n_atoms*(n_atoms-1) / 2))
+        else:
+            super().__init__(in_features=int(n_atoms*3), out_features=len(slicing_pairs))
+
+        # parse args
+        self.n_atoms = n_atoms
+        self.PBC = PBC
+        self.cell = cell
+        self.scaled_coords = scaled_coords
+        if slicing_pairs is not None:
+            self.slicing_pairs = torch.Tensor(slicing_pairs).to(torch.long)
+        else:
+            self.slicing_pairs = slicing_pairs
+
+    def compute_pairwise_distances(self, pos):
+        dist = compute_distances_matrix(pos=pos,
+                                        n_atoms=self.n_atoms,
+                                        PBC=self.PBC,
+                                        cell=self.cell,
+                                        scaled_coords=self.scaled_coords)
+        batch_size = dist.shape[0]
+        if self.slicing_pairs is None:
+            device = pos.device
+            # mask out diagonal elements
+            aux_mask = torch.ones_like(dist, device=device) - torch.eye(dist.shape[-1], device=device)
+            # keep upper triangular part to avoid duplicates
+            unique = aux_mask.triu().nonzero(as_tuple=True)
+            pairwise_distances = dist[unique].reshape((batch_size, -1))
+            return pairwise_distances
+        else:
+            return dist[:, self.slicing_pairs[:, 0], self.slicing_pairs[:, 1]]
+
+
+    def forward(self, x: torch.Tensor):
+        x = self.compute_pairwise_distances(x)
+        return x
+
+def test_pairwise_distances():
+    # simple test based on alanine distances
+    pos_abs = torch.Tensor([[ 1.4970, 1.3861, -0.0273, -1.4933, 1.5070, -0.1133, -1.4473, -1.4193,
+                             -0.0553, 1.4940, 1.4990, -0.2403, 1.4780, -1.4173, -0.3363, -1.4243,
+                             -1.4093, -0.4293, 1.3530, -1.4313, -0.4183, 1.3060, 1.4750, -0.4333,
+                              1.2970, -1.3233, -0.4643, 1.1670, -1.3253, -0.5354]])
+    pos_abs.requires_grad = True
+
+    cell = torch.Tensor([3.0233])
+
+    pos_scaled = pos_abs / cell
+
+    ref_distances = torch.Tensor([[0.1521, 0.2335, 0.2412, 0.3798, 0.4733, 0.4649, 0.4575, 0.5741, 0.6815,
+                                   0.1220, 0.1323, 0.2495, 0.3407, 0.3627, 0.3919, 0.4634, 0.5885, 0.2280,
+                                   0.2976, 0.3748, 0.4262, 0.4821, 0.5043, 0.6376, 0.1447, 0.2449, 0.2454,
+                                   0.2705, 0.3597, 0.4833, 0.1528, 0.1502, 0.2370, 0.2408, 0.3805, 0.2472,
+                                   0.3243, 0.3159, 0.4527, 0.1270, 0.1301, 0.2440, 0.2273, 0.2819, 0.1482]])
+
+    # PBC no scaled coords
+    model = PairwiseDistances(n_atoms=10, PBC=True, cell=cell, scaled_coords=False)
+    out = model(pos_abs)
+    assert(out.reshape(pos_abs.shape[0], -1).shape[-1] == model.out_features)
+    assert(torch.allclose(out, ref_distances, atol=1e-3))
+    out.sum().backward()
+
+    # PBC no scaled coords slicing
+    model = PairwiseDistances(n_atoms=10, PBC=True, cell=cell, scaled_coords=False, slicing_pairs=[[0, 1], [0, 2]])
+    out = model(pos_abs)
+    assert(torch.allclose(out, ref_distances[:, [0, 1]], atol=1e-3))
+    out.sum().backward()
+
+    # PBC and scaled coords
+    model = PairwiseDistances(n_atoms=10, PBC=True, cell=cell, scaled_coords=True)
+    out = model(pos_scaled)
+    assert(out.reshape(pos_scaled.shape[0], -1).shape[-1] == model.out_features)
+    assert(torch.allclose(out, ref_distances, atol=1e-3))
+    out.sum().backward()
+
+
+if __name__ == "__main__":
+    test_pairwise_distances()
\ No newline at end of file
diff --git a/mlcolvar/core/transform/descriptors/torsional_angle.py b/mlcolvar/core/transform/descriptors/torsional_angle.py
new file mode 100644
index 00000000..c2414961
--- /dev/null
+++ b/mlcolvar/core/transform/descriptors/torsional_angle.py
@@ -0,0 +1,156 @@
+import torch
+import numpy as np
+
+from mlcolvar.core.transform import Transform
+from mlcolvar.core.transform.descriptors.utils import compute_distances_matrix, sanitize_positions_shape
+
+from typing import Union
+
+__all__ = ["TorsionalAngle"]
+
+class TorsionalAngle(Transform):
+    """
+    Torsional angle defined by a set of 4 atoms from their positions
+    """
+
+    MODES = ["angle", "sin", "cos"]
+
+    def __init__(self,
+                 indices: Union[list, np.ndarray, torch.Tensor],
+                 n_atoms: int,
+                 mode: Union[str, list],
+                 PBC: bool,
+                 cell: Union[float, list],
+                 scaled_coords: bool = False) -> torch.Tensor:
+        """Initialize a torsional angle object
+
+        Parameters
+        ----------
+        indices : Union[list, np.ndarray, torch.Tensor]
+            Indices of the ordered atoms defining the torsional angle
+        n_atoms : int
+            Number of atoms in the positions tensor used in the forward.
+        mode : Union[str, list]
+            Which quantities to return among 'angle', 'sin' and 'cos'
+        PBC : bool
+            Switch for Periodic Boundary Conditions use
+        cell : Union[float, list]
+            Dimensions of the real cell, orthorhombic-like cells only
+        scaled_coords : bool, optional
+            Switch for coordinates scaled on cell's vectors use, by default False
+
+        Returns
+        -------
+        torch.Tensor
+            Depending on the `mode` selection, the torsional angle in radians, its sine and its cosine.
+        """
+
+        # check mode here to get the number of out_features
+        for i in mode:
+            if i not in self.MODES:
+                raise ValueError(f'The mode {i} is not available in this class. The available modes are: {", ".join(self.MODES)}.')
+
+        mode_idx = []
+        for i, m in enumerate(self.MODES):
+            if m in mode:
+                mode_idx.append(i)
+        self.mode_idx = mode_idx
+
+        # now we can initialize the mother class
+        super().__init__(in_features=int(n_atoms*3), out_features=len(mode_idx))
+
+        # initialize class attributes
+        self.indices = indices
+        self.n_atoms = n_atoms
+        self.PBC = PBC
+        self.cell = cell
+        self.scaled_coords = scaled_coords
+
+
+    def compute_torsional_angle(self, pos):
+        tors_pos, batch_size = sanitize_positions_shape(pos, self.n_atoms)
+
+        # select relevant atoms only
+        tors_pos = tors_pos[:, self.indices, :]
+
+        dist_components = compute_distances_matrix(pos=tors_pos,
+                                                   n_atoms=4,
+                                                   PBC=self.PBC,
+                                                   cell=self.cell,
+                                                   scaled_coords=self.scaled_coords,
+                                                   vector=True)
+
+        # get AB, BC, CD distances
+        AB = dist_components[:, :, 0, 1]
+        BC = dist_components[:, :, 1, 2]
+        CD = dist_components[:, :, 2, 3]
+
+        # check that they are in the -0.5 : 0.5 range
+        AB = self._center_distances(AB)
+        BC = self._center_distances(BC)
+        CD = self._center_distances(CD)
+        # obtain normal direction
+        n1 = torch.cross(AB, BC, dim=1)
+        n2 = torch.cross(BC, CD, dim=1)
+        # obtain versors
+        n1_normalized = n1 / torch.norm(n1, dim=1, keepdim=True)
+        n2_normalized = n2 / torch.norm(n2, dim=1, keepdim=True)
+        UBC = BC / torch.norm(BC, dim=1, keepdim=True)
+
+        sin = torch.einsum('bij,bij->bj', torch.cross(n1_normalized, n2_normalized, dim=1).unsqueeze(-1), UBC.unsqueeze(-1))
+        cos = torch.einsum('bij,bij->bj', n1_normalized.unsqueeze(-1), n2_normalized.unsqueeze(-1))
+
+        angle = torch.atan2(sin, cos)
+
+        return torch.hstack([angle, sin, cos])
+
+    def _center_distances(self, dist):
+        dist[dist > 0.5] = dist[dist > 0.5] - 1
+        dist[dist < -0.5] = dist[dist < -0.5] + 1
+        return dist
+
+    def forward(self, x):
+        out = self.compute_torsional_angle(x)
+        return out[:, self.mode_idx]
+
+def test_torsional_angle():
+    # simple test on alanine phi angle
+    pos = torch.Tensor([[[ 0.3887, -0.4169, -0.1212],
+                         [ 0.4264, -0.4374, -0.0983],
+                         [ 0.4574, -0.4136, -0.0931],
+                         [ 0.4273, -0.4797, -0.0871],
+                         [ 0.4684, 0.4965, -0.0692],
+                         [ 0.4478, 0.4571, -0.0441],
+                         [-0.4933, 0.4869, -0.1026],
+                         [-0.4840, 0.4488, -0.1116],
+                         [-0.4748, -0.4781, -0.1232],
+                         [-0.4407, -0.4781, -0.1569]],
+                        [[ 0.3910, -0.4103, -0.1189],
+                         [ 0.4334, -0.4329, -0.1020],
+                         [ 0.4682, -0.4145, -0.1013],
+                         [ 0.4322, -0.4739, -0.0867],
+                         [ 0.4669, -0.4992, -0.0666],
+                         [ 0.4448, 0.4670, -0.0375],
+                         [-0.4975, 0.4844, -0.0981],
+                         [-0.4849, 0.4466, -0.0991],
+                         [-0.4818, -0.4870, -0.1291],
+                         [-0.4490, -0.4933, -0.1668]]])
+    pos.requires_grad = True
+
+    cell = torch.Tensor([3.0233, 3.0233, 3.0233])
+    model = TorsionalAngle(indices=[1,3,4,6], n_atoms=10, mode=['angle', 'sin', 'cos'], PBC=False, cell=cell, scaled_coords=False)
+    angle = model(pos)
+    print(angle)
+    angle.sum().backward()
+
+    model = TorsionalAngle([1,3,4,6], n_atoms=10, mode=['sin'], PBC=False, cell=cell, scaled_coords=False)
+    angle = model(pos)
+    print(angle)
+    angle.sum().backward()
+
+    # TODO add reference value for check
+
+if __name__ == "__main__":
+    test_torsional_angle()
+
+
diff --git a/mlcolvar/core/transform/descriptors/utils.py b/mlcolvar/core/transform/descriptors/utils.py
new file mode 100644
index 00000000..b04c85a1
--- /dev/null
+++ b/mlcolvar/core/transform/descriptors/utils.py
@@ -0,0 +1,298 @@
+import torch
+from typing import Union
+
+def sanitize_positions_shape(pos: torch.Tensor,
+                             n_atoms: int):
+    """Sanitize positions tensor to have [batch, atoms, dims=3] shape
+
+    Parameters
+    ----------
+    pos : torch.Tensor
+        Positions of the atoms, they can be given with shapes:
+        - Shape: (n_batch (optional), n_atoms * 3), i.e [ [x1,y1,z1, x2,y2,z2, .... xn,yn,zn] ]
+        - Shape: (n_batch (optional), n_atoms, 3), i.e [ [ [x1,y1,z1], [x2,y2,z2], .... [xn,yn,zn] ] ]
+    n_atoms : int
+        Number of atoms
+    """
+    # check if we have batch dimension in positions tensor
+
+    if len(pos.shape) == 3:
+        # check that index 0: batch, 1: atom, 2: coords
+        if pos.shape[1] != n_atoms:
+            raise ValueError(f"The given positions tensor has the wrong format, probably the wrong number of atoms. Expected {n_atoms} found {pos.shape[1]}")
+        if pos.shape[2] != 3:
+            raise ValueError(f"The given positions tensor has the wrong format, probably the wrong number of spatial coordinates. Expected 3 found {pos.shape[2]}")
+
+    if len(pos.shape) == 2:
+        # check that index 0: atoms, 1: coords
+        if pos.shape[0] == n_atoms and pos.shape[1] == 3:
+            pos = pos.unsqueeze(0) # add batch dimension
+        # check that is not 0: batch, 1: atom*coords
+        elif not pos.shape[1] == int(n_atoms * 3):
+            raise ValueError(f"The given positions tensor has the wrong format, found {pos.shape}, expected either {[n_atoms, 3]} or {-1, n_atoms*3}")
+
+    if len(pos.shape) == 1:
+        # check that index 0: atoms*coord
+        if len(pos) != n_atoms*3:
+            raise ValueError(f"The given positions tensor has the wrong format. It should be at least of shape {int(n_atoms*3)}, found {pos.shape[0]}")
+        # else:
+        #     pos = pos.unsqueeze(0) # add batch dimension
+
+    pos = torch.reshape(pos, (-1, n_atoms, 3))
+
+    batch_size = pos.shape[0]
+    return pos, batch_size
+
+def sanitize_cell_shape(cell: Union[float, torch.Tensor, list]):
+    # Convert cell to tensor and shape it to have 3 dims
+    if isinstance(cell, float) or isinstance(cell, int):
+        cell = torch.Tensor([cell])
+    elif isinstance(cell, list):
+        cell = torch.Tensor(cell)
+
+    if cell.shape[0] != 1 and cell.shape[0] != 3:
+        raise ValueError(f"Cell must have either shape (1) or (3). Found {cell.shape}")
+
+    if isinstance(cell, torch.Tensor):
+        # TODO assert size makes sense if you directly pass a tensor
+        if len(cell) != 3:
+            cell = torch.tile(cell, (3,))
+
+    return cell
+
+def compute_distances_matrix(pos: torch.Tensor,
+                             n_atoms: int,
+                             PBC: bool,
+                             cell: Union[float, list],
+                             vector: bool = False,
+                             scaled_coords: bool = False,
+                             ) -> torch.Tensor:
+    """Compute the pairwise distances matrix from batches of atomic coordinates.
+    The matrix is symmetric, of size (n_atoms,n_atoms), and its i,j-th element gives the distance between atoms i and j.
+    Optionally it can return the vector distances.
+
+    Parameters
+    ----------
+    pos : torch.Tensor
+        Positions of the atoms, they can be given with shapes:
+        - Shape: (n_batch (optional), n_atoms * 3), i.e [ [x1,y1,z1, x2,y2,z2, .... xn,yn,zn] ]
+        - Shape: (n_batch (optional), n_atoms, 3), i.e [ [ [x1,y1,z1], [x2,y2,z2], .... [xn,yn,zn] ] ]
+    n_atoms : int
+        Number of atoms
+    PBC : bool
+        Switch for Periodic Boundary Conditions use
+    cell : Union[float, list]
+        Dimensions of the real cell, orthorhombic-like cells only
+    vector : bool, optional
+        Switch to return vector distances, by default False
+    scaled_coords : bool, optional
+        Switch for coordinates scaled on cell's vectors use, by default False
+
+    Returns
+    -------
+    torch.Tensor
+        Matrix of the scalar pairwise distances, index map: (batch_idx, atom_i_idx, atom_j_idx)
+        Enabling `vector=True` returns the vector components of the distances, index map: (batch_idx, atom_i_idx, atom_j_idx, component_idx)
+    """
+    # sanitize the input shapes before computing the distance components
+    # ======================= CHECKS =======================
+    pos, batch_size = sanitize_positions_shape(pos, n_atoms)
+    cell = sanitize_cell_shape(cell)
+
+    # Set which cell to be used for PBC
+    if scaled_coords:
+        pbc_cell = torch.Tensor([1., 1., 1.])
+    else:
+        pbc_cell = cell
+
+    # ======================= COMPUTE =======================
+    pos = torch.reshape(pos, (batch_size, n_atoms, 3)) # this preserves the order when the pos are passed as a list
+    pos = torch.transpose(pos, 1, 2)
+    pos = pos.reshape((batch_size, 3, n_atoms))
+
+    # expand tiling the coordinates to a tensor of shape (n_batch, 3, n_atoms, n_atoms)
+    pos_expanded = torch.tile(pos, (1, 1, n_atoms)).reshape(batch_size, 3, n_atoms, n_atoms)
+
+    # compute the distances with the transpose trick
+    # This works only with orthorhombic cells
+    dist_components = pos_expanded - torch.transpose(pos_expanded, -2, -1) # transpose over the atom index dimensions
+
+    # get PBC shifts
+    if PBC:
+        shifts = torch.zeros_like(dist_components)
+        # avoid loop if cell is cubic
+        if pbc_cell[0] == pbc_cell[1] and pbc_cell[1] == pbc_cell[2]:
+            shifts = torch.div(dist_components, pbc_cell[0]/2, rounding_mode='trunc')
+            shifts = torch.div(shifts + 1*torch.sign(shifts), 2, rounding_mode='trunc')*pbc_cell[0]
+
+        else:
+            # loop over the dimensions of the pbc_cell
+            for d in range(3):
+                shifts[:, d, :, :] = torch.div(dist_components[:, d, :, :], pbc_cell[d]/2, rounding_mode='trunc')
+                shifts[:, d, :, :] = torch.div(shifts[:, d, :, :] + 1*torch.sign(shifts[:, d, :, :]), 2, rounding_mode='trunc')*pbc_cell[d]
+
+        # apply shifts
+        dist_components = dist_components - shifts
+
+    # if we used scaled coords we need to get back to real distances
+    if scaled_coords:
+        dist_components = torch.einsum('bijk,i->bijk', dist_components, cell)
+
+    if vector:
+        return dist_components
+    else:
+        # mask out the diagonal --> to keep the derivatives safe
+        mask_diag = ~torch.eye(n_atoms, dtype=bool)
+        mask_diag = torch.tile(mask_diag, (batch_size, 1, 1))
+
+        # sum the squared components and get the final distance
+        dist = torch.sum(torch.pow(dist_components, 2), 1)
+        dist[mask_diag] = torch.sqrt(dist[mask_diag])
+        return dist
+
+def apply_cutoff(x: torch.Tensor,
+                 cutoff: float,
+                 mode: str = 'continuous',
+                 switching_function = None) -> torch.Tensor:
+    """Apply a cutoff to a quantity.
+    Returns 1 below the cutoff and 0 above
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Quantity on which the cutoff has to be applied
+    cutoff : float
+        Value of the cutoff. In case of distances it must be given in real units
+    mode : str, optional
+        Application mode for the cutoff, either 'continuous' or 'discontinuous', by default 'continuous'
+        This can be either:
+        - 'continuous': applies a switching function and gives stable derivatives accordingly
+        - 'discontinuous': sets to one what is below the cutoff and to zero what is above. The derivatives may be problematic
+    switching_function : function, optional
+        Switching function to be applied if in continuous mode, by default None.
+        This can be either a user-defined and torch-based function or a method of the SwitchingFunctions class
+
+    Returns
+    -------
+    torch.Tensor
+        Quantity after the application of the cutoff
+    """
+    x_clone = torch.clone(x)
+    if mode == 'continuous' and switching_function is None:
+        raise ValueError('switching_function is required to use continuous mode! It can be either a user-defined and torch-based function or a method of the switching_functions/SwitchingFunctions class')
+
+    batch_size = x.shape[0]
+    mask_diag = ~torch.eye(x.shape[-1], dtype=bool)
+    mask_diag = torch.tile(mask_diag, (batch_size, 1, 1))
+
+    if mode == 'continuous':
+        x_clone[mask_diag] = switching_function(x_clone[mask_diag])
+
+    if mode == 'discontinuous':
+        mask_cutoff = torch.ge(x_clone, cutoff)
+        x_clone[mask_cutoff] = x_clone[mask_cutoff] * 0
+        mask = torch.logical_and(~mask_cutoff, mask_diag)
+        x_clone[mask] = x_clone[mask] ** 0
+    return x_clone
+
+
+def compute_adjacency_matrix(pos: torch.Tensor,
+                             mode: str,
+                             cutoff: float,
+                             n_atoms: int,
+                             PBC: bool,
+                             cell: Union[float, list],
+                             scaled_coords: bool = False,
+                             switching_function = None) -> torch.Tensor:
+    """Compute the adjacency matrix for a set of atoms from their positions.
+
+    Parameters
+    ----------
+    pos : torch.Tensor
+        Positions of the atoms, they can be given with shapes:
+        - Shape: (n_batch (optional), n_atoms * 3), i.e [ [x1,y1,z1, x2,y2,z2, .... xn,yn,zn] ]
+        - Shape: (n_batch (optional), n_atoms, 3), i.e [ [ [x1,y1,z1], [x2,y2,z2], .... [xn,yn,zn] ] ]
+    mode : str
+        Mode for cutoff application, either:
+        - 'continuous': applies a switching function to the distances, which can be specified with the switching_function keyword, has stable derivatives
+        - 'discontinuous': sets to zero everything above the cutoff and to one everything below, derivatives may be incorrect
+    cutoff : float
+        Cutoff for the adjacency criterion
+    n_atoms : int
+        Number of atoms in the system
+    PBC : bool
+        Switch for Periodic Boundary Conditions use
+    cell : Union[float, list]
+        Dimensions of the real cell, orthorhombic-like cells only
+    scaled_coords : bool
+        Switch for coordinates scaled on cell's vectors use, by default False
+    switching_function : callable, optional
+        Switching function to be applied for the cutoff, can be either initialized as a switching_functions/SwitchingFunctions class or a simple function, by default None
+
+    Returns
+    -------
+    torch.Tensor
+        Adjacency matrix of all the n_atoms according to cutoff
+    """
+    dist = compute_distances_matrix(pos=pos,
+                                    n_atoms=n_atoms,
+                                    PBC=PBC,
+                                    cell=cell,
+                                    scaled_coords=scaled_coords)
+    adj_matrix = apply_cutoff(x=dist,
+                              cutoff=cutoff,
+                              mode=mode,
+                              switching_function=switching_function)
+    return adj_matrix
+
+
+def test_applycutoff():
+    from mlcolvar.core.transform.tools.switching_functions import SwitchingFunctions
+
+    n_atoms = 2
+    pos = torch.Tensor([ [ [0., 0., 0.],
+                           [1., 1., 1.] ],
+                         [ [0., 0., 0.],
+                           [1., 1., 1.] ] ]
+                     )
+    cell = torch.Tensor([1., 2., 1.])
+    cutoff = 1.8
+
+    # TEST no scaled coords
+    out = compute_distances_matrix(pos=pos, n_atoms=n_atoms, PBC=True, cell=cell, scaled_coords=False)
+    switching_function = SwitchingFunctions(in_features=n_atoms**2, name='Fermi', cutoff=cutoff, options={'q': 0.01})
+    apply_cutoff(x=out, cutoff=cutoff, mode='continuous', switching_function=switching_function)
+
+    def silly_switch(x):
+        return torch.pow(x, 2)
+    switching_function = silly_switch
+    apply_cutoff(x=out, cutoff=cutoff, mode='continuous', switching_function=switching_function)
+    apply_cutoff(x=out, cutoff=cutoff, mode='discontinuous')
+
+    # TEST scaled coords
+    pos = torch.einsum('bij,j->bij', pos, 1/cell)
+    out = compute_distances_matrix(pos=pos, n_atoms=2, PBC=True, cell=cell, scaled_coords=True)
+    switching_function = SwitchingFunctions(in_features=n_atoms**2, name='Fermi', cutoff=cutoff, options={'q': 0.01})
+    apply_cutoff(x=out, cutoff=cutoff, mode='continuous', switching_function=switching_function)
+    apply_cutoff(x=out, cutoff=cutoff, mode='discontinuous')
+
+
+def test_adjacency_matrix():
+    from mlcolvar.core.transform.tools.switching_functions import SwitchingFunctions
+
+    n_atoms = 2
+    pos = torch.Tensor([ [ [0., 0., 0.],
+                           [1., 1., 1.] ],
+                         [ [0., 0., 0.],
+                           [1., 1.1, 1.] ] ]
+                     )
+
+    cell = torch.Tensor([1., 2., 1.])
+    cutoff = 1.8
+    switching_function = SwitchingFunctions(in_features=n_atoms*3, name='Fermi', cutoff=cutoff, options={'q': 0.01})
+
+    compute_adjacency_matrix(pos=pos, mode='continuous', cutoff=cutoff, n_atoms=n_atoms, PBC=True, cell=cell, scaled_coords=False, switching_function=switching_function)
+
+if __name__ == "__main__":
+    test_applycutoff()
+    test_adjacency_matrix()
\ No newline at end of file
diff --git a/mlcolvar/core/transform/tools/__init__.py b/mlcolvar/core/transform/tools/__init__.py
new file mode 100644
index 00000000..d3543495
--- /dev/null
+++ b/mlcolvar/core/transform/tools/__init__.py
@@ -0,0 +1,5 @@
+__all__ = ["ContinuousHistogram", "Normalization", "SwitchingFunctions"]
+
+from .continuous_hist import *
+from .normalization import *
+from .switching_functions import *
\ No newline at end of file
diff --git a/mlcolvar/core/transform/tools/continuous_hist.py b/mlcolvar/core/transform/tools/continuous_hist.py
new file mode 100644
index 00000000..88c0b4d5
--- /dev/null
+++ b/mlcolvar/core/transform/tools/continuous_hist.py
@@ -0,0 +1,68 @@
+import torch
+
+from mlcolvar.core.transform import Transform
+from mlcolvar.core.transform.tools.utils import easy_KDE
+
+__all__ = ["ContinuousHistogram"]
+
+class ContinuousHistogram(Transform):
+    """
+    Compute continuous histogram using Gaussian kernels
+    """
+
+    def __init__(self,
+                 in_features: int,
+                 min: float,
+                 max: float,
+                 bins: int,
+                 sigma_to_center: float = 1.0) -> torch.Tensor:
+        """Computes the continuous histogram of a quantity using Gaussian kernels
+
+        Parameters
+        ----------
+        in_features : int
+            Number of inputs
+        min : float
+            Minimum value of the histogram
+        max : float
+            Maximum value of the histogram
+        bins : int
+            Number of bins of the histogram
+        sigma_to_center : float, optional
+            Sigma value in bin_size units, by default 1.0
+
+        Returns
+        -------
+        torch.Tensor
+            Values of the histogram for each bin
+        """
+
+        super().__init__(in_features=in_features, out_features=bins)
+
+        self.min = min
+        self.max = max
+        self.bins = bins
+        self.sigma_to_center = sigma_to_center
+
+    def compute_hist(self, x):
+        hist = easy_KDE(x=x,
+                        n_input=self.in_features,
+                        min_max=[self.min, self.max],
+                        n=self.bins,
+                        sigma_to_center=self.sigma_to_center)
+        return hist
+
+    def forward(self, x: torch.Tensor):
+        x = self.compute_hist(x)
+        return x
+
+def test_continuous_histogram():
+    x = torch.randn((5, 100))
+    x.requires_grad = True
+    hist = ContinuousHistogram(in_features=100, min=-1, max=1, bins=10, sigma_to_center=1)
+    out = hist(x)
+    out.sum().backward()
+
+if __name__ == "__main__":
+    test_continuous_histogram()
\ No newline at end of file
diff --git a/mlcolvar/core/transform/normalization.py b/mlcolvar/core/transform/tools/normalization.py
similarity index 95%
rename from mlcolvar/core/transform/normalization.py
rename to mlcolvar/core/transform/tools/normalization.py
index b70b4205..59ed9c4d 100644
--- a/mlcolvar/core/transform/normalization.py
+++ b/mlcolvar/core/transform/tools/normalization.py
@@ -1,5 +1,6 @@
 import torch
-from mlcolvar.core.transform.utils import batch_reshape, Statistics
+from mlcolvar.core.transform.utils import Statistics
+from mlcolvar.core.transform.tools.utils import batch_reshape
 from mlcolvar.core.transform import Transform
 
 __all__ = ["Normalization"]
@@ -206,23 +207,20 @@ def test_normalization():
     norm = Normalization(in_features, mean=stats["mean"], range=stats["std"])
     y = norm(X)
-    # print(X.mean(0),y.mean(0))
-    # print(X.std(0),y.std(0))
 
     # test inverse
     z = norm.inverse(y)
-    # print(X.mean(0),z.mean(0))
-    # print(X.std(0),z.std(0))
+    assert(torch.allclose(X.mean(0), z.mean(0)))
+    assert(torch.allclose(X.std(0), z.std(0)))
 
     # test inverse class
     inverse = Inverse(norm)
     q = inverse(y)
-    # print(X.mean(0),q.mean(0))
-    # print(X.std(0),q.std(0))
+    assert(torch.allclose(X.mean(0), q.mean(0)))
+    assert(torch.allclose(X.std(0), q.std(0)))
 
     norm = Normalization(
         in_features, mean=stats["mean"], range=stats["std"], mode="min_max"
     )
-
 if __name__ == "__main__":
     test_normalization()
diff --git a/mlcolvar/core/transform/tools/switching_functions.py b/mlcolvar/core/transform/tools/switching_functions.py
new file mode 100644
index 00000000..04c1dd1a
--- /dev/null
+++ b/mlcolvar/core/transform/tools/switching_functions.py
@@ -0,0 +1,83 @@
+import torch
+
+from mlcolvar.core.transform import Transform
+
+
+__all__ = ["SwitchingFunctions"]
+
+class SwitchingFunctions(Transform):
+    """
+    Common switching functions
+    """
+    SWITCH_FUNCS = ['Fermi', 'Rational']
+
+    def __init__(self,
+                 in_features: int,
+                 name: str,
+                 cutoff: float,
+                 options: dict = None):
+        """Initialize a switching function object
+
+        Parameters
+        ----------
+        in_features : int
+            Number of inputs
+        name : str
+            Name of the switching function to be used, the available options are 'Fermi' and 'Rational'
+        cutoff : float
+            Cutoff for the switching functions
+        options : dict, optional
+            Dictionary with all the arguments of the switching function, by default None
+        """
+        super().__init__(in_features=in_features, out_features=in_features)
+
+        self.name = name
+        self.cutoff = cutoff
+        if options is None:
+            options = {}
+        self.options = options
+
+        if name not in self.SWITCH_FUNCS:
+            raise NotImplementedError(f'''The switching function {name} is not implemented in this class. The available options are: {",".join(self.SWITCH_FUNCS)}.
+                                      You can implement it as a method of the SwitchingFunctions class and tell us on GitHub, contributions are welcome!''')
+
+    def forward(self, x: torch.Tensor):
+        switch_function = getattr(self, f'{self.name}_switch')
+        y = switch_function(x, self.cutoff, **self.options)
+        return y
+
+    # ========================== define here switching functions ==========================
+    def Fermi_switch(self,
+                     x: torch.Tensor,
+                     cutoff: float,
+                     q: float = 0.01,
+                     prefactor_cutoff: float = 1.0):
+        y = torch.div(1, (1 + torch.exp(torch.div((x - prefactor_cutoff*cutoff), q))))
+        return y
+
+    def Rational_switch(self,
+                        x: torch.Tensor,
+                        cutoff: float,
+                        n: int = 6,
+                        m: int = 12,
+                        eps: float = 1e-8,
+                        prefactor_cutoff: float = 1.0):
+        y = torch.div((1 - torch.pow(x/(prefactor_cutoff*cutoff), n) + eps), (1 - torch.pow(x/(prefactor_cutoff*cutoff), m) + 2*eps))
+        return y
+
+
+def test_switchingfunctions():
+    x = torch.Tensor([1., 2., 3.])
+    cutoff = 2
+    switch = SwitchingFunctions(in_features=len(x), name='Fermi', cutoff=cutoff)
+    switch(x)
+
+    switch = SwitchingFunctions(in_features=len(x), name='Fermi', cutoff=cutoff, options={'q': 0.5})
+    switch(x)
+
+    switch = SwitchingFunctions(in_features=len(x), name='Rational', cutoff=cutoff, options={'n': 6, 'm': 12})
+    switch(x)
+
+
+if __name__ == "__main__":
+    test_switchingfunctions()
\ No newline at end of file
diff --git a/mlcolvar/core/transform/tools/utils.py b/mlcolvar/core/transform/tools/utils.py
new file mode 100644
index 00000000..6705ffff
--- /dev/null
+++ b/mlcolvar/core/transform/tools/utils.py
@@ -0,0 +1,96 @@
+import torch
+import numpy as np
+
+from typing import Union, List
+
+def batch_reshape(t: torch.Tensor, size: torch.Size) -> torch.Tensor:
+    """Return value reshaped according to size.
+    In case of batch unsqueeze and expand along the first dimension.
+    For single inputs just pass.
+
+    Parameters
+    ----------
+    t : torch.Tensor
+        Tensor to be reshaped
+    size : torch.Size
+        Size of the reference input data
+
+    """
+    if len(size) == 1:
+        return t
+    if len(size) == 2:
+        batch_size = size[0]
+        x_size = size[1]
+        t = t.unsqueeze(0).expand(batch_size, x_size)
+    else:
+        raise ValueError(
+            f"Input tensor must be of shape (n_features) or (n_batch,n_features), not {size} (len={len(size)})."
+        )
+    return t
+
+
+def _gaussian_expansion(x: torch.Tensor,
+                        centers: torch.Tensor,
+                        sigma: torch.Tensor):
+    """Computes the values in x of a set of Gaussian kernels centered on centers and with width sigma
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input value(s)
+    centers : torch.Tensor
+        Centers of the Gaussian kernels
+    sigma : torch.Tensor
+        Width of the Gaussian kernels
+    """
+    return torch.exp(- torch.div(torch.pow(x - centers, 2), 2*torch.pow(sigma, 2)))
+
+def easy_KDE(x: torch.Tensor,
+             n_input: int,
+             min_max: Union[List[float], np.ndarray],
+             n: int,
+             sigma_to_center: float = 1.0,
+             normalize: bool = False,
+             return_bins: bool = False) -> torch.Tensor:
+    """Compute histogram using KDE with Gaussian kernels
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input
+    n_input : int
+        Number of inputs per batch
+    min_max : Union[List[float], np.ndarray]
+        Minimum and maximum values for the histogram
+    n : int
+        Number of Gaussian kernels
+    sigma_to_center : float, optional
+        Sigma value in bin_size units, by default 1.0
+    normalize : bool, optional
+        Switch for normalization of the histogram to sum to n_input, by default False
+    return_bins : bool, optional
+        Switch to return the bins of the histogram alongside the values, by default False
+
+    Returns
+    -------
+    torch.Tensor
+        Values of the histogram for each bin. The bins can be optionally returned by enabling `return_bins`.
+    """
+    if len(x.shape) == 1:
+        x = torch.reshape(x, (1, n_input, 1))
+    if x.shape[-1] != 1:
+        x = x.unsqueeze(-1)
+    if x.shape[0] == n_input:
+        x = x.unsqueeze(0)
+
+    centers = torch.linspace(min_max[0], min_max[1], n, device=x.device)
+    bins = torch.clone(centers)
+    sigma = (centers[1] - centers[0]) * sigma_to_center
+    centers = torch.tile(centers, dims=(n_input, 1))
+    out = torch.sum(_gaussian_expansion(x, centers, sigma), dim=1)
+    if normalize:
+        out = torch.div(out, torch.sum(out, -1, keepdim=True)) * n_input
+    if return_bins:
+        return out, bins
+    else:
+        return out
\ No newline at end of file
diff --git a/mlcolvar/core/transform/transform.py b/mlcolvar/core/transform/transform.py
index 71422e8d..1cebfa44 100644
--- a/mlcolvar/core/transform/transform.py
+++ b/mlcolvar/core/transform/transform.py
@@ -31,7 +31,7 @@ def setup_from_datamodule(self, datamodule):
         pass
 
     def forward(self, X: torch.Tensor):
-        raise NotImplementedError
+        raise NotImplementedError()
 
     def teardown(self):
         pass
diff --git a/mlcolvar/core/transform/utils.py b/mlcolvar/core/transform/utils.py
index e83ea908..cdbca532 100644
--- a/mlcolvar/core/transform/utils.py
+++ b/mlcolvar/core/transform/utils.py
@@ -2,9 +2,31 @@
 from typing import Union
 from warnings import warn
 
-__all__ = ["Statistics", "Inverse"]
+__all__ = ["Inverse", "Statistics"]
 
+class Inverse(torch.nn.Module):
+    "Wrapper to return the inverse method of a module as a torch.nn.Module"
+
+    def __init__(self, module: torch.nn.Module):
+        """Return the inverse method of a module as a torch.nn.Module
+
+        Parameters
+        ----------
+        module : torch.nn.Module
+            Module to be inverted
+        """
+        super().__init__()
+        if not hasattr(module, "inverse"):
+            raise AttributeError("The given module does not have an 'inverse' method!")
+        self.module = module
+
+    def inverse(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        return self.module.inverse(*args, **kwargs)
+
 class Statistics(object):
     """
     Calculate statistics (running mean and std.dev based on Welford's algorithm, as well as min and max).
@@ -79,58 +101,35 @@ def __repr__(self):
         for prop in self.properties:
             repr += f"{prop}: {getattr(self,prop).numpy()} "
         return repr
+
+def test_inverse():
+    from mlcolvar.core.transform import Transform
+    # create a dummy model to scale the average to 0
+    class ForwardModel(Transform):
+        def __init__(self, in_features=5, out_features=5):
+            super().__init__(in_features=5, out_features=5)
+            self.mean = 0
 
-class Inverse(torch.nn.Module):
-    "Wrapper to return the inverse method of a module as a torch.nn.Module"
+        def update_mean(self, x):
+            self.mean = torch.mean(x)
+
+        def forward(self, x):
+            x = x - self.mean
+            return x
 
-    def __init__(self, module: torch.nn.Module):
-        """Return the inverse method of a module as a torch.nn.Module
-
-        Parameters
-        ----------
-        module : torch.nn.Module
-            Module to be inverted
-        """
-        super().__init__()
-        if not hasattr(module, "inverse"):
-            raise AttributeError("The given module does not have a 'inverse' method!")
-        self.module = module
-
-    def inverse(self, *args, **kwargs):
-        return self.module.inverse(*args, **kwargs)
-
-    def forward(self, *args, **kwargs):
-        return self.inverse(*args, **kwargs)
-
-
-def batch_reshape(t: torch.Tensor, size: torch.Size) -> torch.Tensor:
-    """Return value reshaped according to size.
-    In case of batch unsqueeze and expand along the first dimension.
-    For single inputs just pass.
-
-    Parameters
-    ----------
-    mean and range
-
-    """
-    if len(size) == 1:
-        return t
-    if len(size) == 2:
-        batch_size = size[0]
-        x_size = size[1]
-        t = t.unsqueeze(0).expand(batch_size, x_size)
-    else:
-        raise ValueError(
-            f"Input tensor must of shape (n_features) or (n_batch,n_features), not {size} (len={len(size)})."
-        )
-    return t
+        def inverse(self, x):
+            x = x + self.mean
+            return x
 
+    forward_model = ForwardModel()
+    inverse_model = Inverse(forward_model)
-# ================================================================================================
-# ======================================== TEST FUNCTIONS ========================================
-# ================================================================================================
+    input = torch.rand(5)
+    forward_model.update_mean(input)
+    out = forward_model(input)
+    assert(input.mean() == inverse_model(out).mean())
 
 def test_statistics():
     # create fake data
@@ -173,8 +172,9 @@
         stats[key].update(batch[key])
 
     for key in loader.keys:
-        print(key, stats[key])
+        print(key,stats[key])
 
 
 if __name__ == "__main__":
+    test_inverse()
     test_statistics()
diff --git a/mlcolvar/tests/test_core_transform_adjacencymatrix.py b/mlcolvar/tests/test_core_transform_adjacencymatrix.py
new file mode 100644
index 00000000..5e4565e0
--- /dev/null
+++ b/mlcolvar/tests/test_core_transform_adjacencymatrix.py
@@ -0,0 +1,4 @@
+from mlcolvar.core.transform.descriptors.eigs_adjacency_matrix import test_eigs_of_adj_matrix
+
+if __name__ == "__main__":
+    test_eigs_of_adj_matrix()
\ No newline at end of file
diff --git a/mlcolvar/tests/test_core_transform_continuoushistogram.py b/mlcolvar/tests/test_core_transform_continuoushistogram.py
new file mode 100644
index 00000000..c10437cb
--- /dev/null
+++ b/mlcolvar/tests/test_core_transform_continuoushistogram.py
@@ -0,0 +1,4 @@
+from mlcolvar.core.transform.tools.continuous_hist import test_continuous_histogram
+
+if __name__ == "__main__":
+    test_continuous_histogram()
\ No newline at end of file
diff --git a/mlcolvar/tests/test_core_transform_coordinationnumbers.py b/mlcolvar/tests/test_core_transform_coordinationnumbers.py
new file mode 100644
index 00000000..a34a31de
--- /dev/null
+++ b/mlcolvar/tests/test_core_transform_coordinationnumbers.py
@@ -0,0 +1,4 @@
+from mlcolvar.core.transform.descriptors.coordination_numbers import test_coordination_number
+
+if __name__ == "__main__":
+    test_coordination_number()
\ No newline at end of file
diff --git a/mlcolvar/tests/test_core_transform_descriptors_utils.py b/mlcolvar/tests/test_core_transform_descriptors_utils.py
new file mode 100644
index 00000000..df2f82bc
--- /dev/null
+++ b/mlcolvar/tests/test_core_transform_descriptors_utils.py
@@ -0,0 +1,5 @@
+from mlcolvar.core.transform.descriptors.utils import test_adjacency_matrix, test_applycutoff
+
+if __name__ == "__main__":
+    test_applycutoff()
+    test_adjacency_matrix()
\ No newline at end of file
diff --git a/mlcolvar/tests/test_core_transform_multipledescriptors.py b/mlcolvar/tests/test_core_transform_multipledescriptors.py
new file mode 100644
index 00000000..6c3e92bf
--- /dev/null
+++ b/mlcolvar/tests/test_core_transform_multipledescriptors.py
@@ -0,0 +1,4 @@
+from mlcolvar.core.transform.descriptors.multiple_descriptors import test_multipledescriptors
+
+if __name__ == "__main__":
+    test_multipledescriptors()
\ No newline at end of file
diff --git a/mlcolvar/tests/test_core_transform_normalization.py b/mlcolvar/tests/test_core_transform_normalization.py
index f9cb7211..e54bcd21 100644
--- a/mlcolvar/tests/test_core_transform_normalization.py
+++ b/mlcolvar/tests/test_core_transform_normalization.py
@@ -1,6 +1,4 @@
-import pytest
-
-from mlcolvar.core.transform.normalization import test_normalization
+from mlcolvar.core.transform.tools.normalization import test_normalization
 
 if __name__ == "__main__":
     test_normalization()
diff --git a/mlcolvar/tests/test_core_transform_pairwisedistances.py b/mlcolvar/tests/test_core_transform_pairwisedistances.py
new file mode 100644
index 00000000..ae42838e
--- /dev/null
+++ b/mlcolvar/tests/test_core_transform_pairwisedistances.py
@@ -0,0 +1,4 @@
+from mlcolvar.core.transform.descriptors.pairwise_distances import test_pairwise_distances
+
+if __name__ == "__main__":
+    test_pairwise_distances()
\ No newline at end of file
diff --git a/mlcolvar/tests/test_core_transform_switchingfunctions.py b/mlcolvar/tests/test_core_transform_switchingfunctions.py
new file mode 100644
index 00000000..a458f457
--- /dev/null
+++ b/mlcolvar/tests/test_core_transform_switchingfunctions.py
@@ -0,0 +1,4 @@
+from mlcolvar.core.transform.tools.switching_functions import test_switchingfunctions
+
+if __name__ == "__main__":
+    test_switchingfunctions()
\ No newline at end of file
diff --git a/mlcolvar/tests/test_core_transform_torsionalangle.py b/mlcolvar/tests/test_core_transform_torsionalangle.py
new file mode 100644
index 00000000..46dbe103
--- /dev/null
+++ b/mlcolvar/tests/test_core_transform_torsionalangle.py
@@ -0,0 +1,4 @@
+from mlcolvar.core.transform.descriptors.torsional_angle import test_torsional_angle
+
+if __name__ == "__main__":
+    test_torsional_angle()
\ No newline at end of file
diff --git a/mlcolvar/tests/test_core_transform_utils.py b/mlcolvar/tests/test_core_transform_utils.py
index a6da087b..f7af78d8 100644
--- a/mlcolvar/tests/test_core_transform_utils.py
+++ b/mlcolvar/tests/test_core_transform_utils.py
@@ -1,6 +1,5 @@
-import pytest
-
-from mlcolvar.core.transform.utils import test_statistics
+from mlcolvar.core.transform.utils import test_inverse, test_statistics
 
 if __name__ == "__main__":
-    test_statistics()
+    test_inverse()
+    test_statistics()
\ No newline at end of file
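
For reference, a minimal usage sketch of the reorganized mlcolvar.core.transform API introduced by this patch, combining the new descriptor and tool classes. The system size, box, cutoff, batch size, and switching parameters below are illustrative placeholders, not values taken from the tests above:

    import torch
    from mlcolvar.core.transform import (
        CoordinationNumbers,
        MultipleDescriptors,
        PairwiseDistances,
        SwitchingFunctions,
    )

    # illustrative system: a batch of 8 frames with 4 atoms in a cubic box of side 2.0
    n_atoms = 4
    cell = 2.0
    pos = torch.rand((8, n_atoms, 3), requires_grad=True)

    # continuous cutoff through the Rational switching function
    cutoff = 0.8
    switch = SwitchingFunctions(in_features=n_atoms * 3, name='Rational',
                                cutoff=cutoff, options={'n': 2, 'm': 6})

    # coordination of atom 0 with the remaining atoms (out_features = 1)
    coord = CoordinationNumbers(group_A=[0], group_B=[1, 2, 3],
                                cutoff=cutoff, n_atoms=n_atoms,
                                PBC=True, cell=cell, mode='continuous',
                                switching_function=switch)

    # all 6 non-duplicated pairwise distances (out_features = n_atoms*(n_atoms-1)/2)
    dist = PairwiseDistances(n_atoms=n_atoms, PBC=True, cell=cell)

    # stack both descriptors along the feature dimension: output has shape (8, 7)
    model = MultipleDescriptors(descriptors_list=[coord, dist], n_atoms=n_atoms)
    out = model(pos)
    out.sum().backward()  # gradients flow back to the atomic positions

As in the tests above, stacking descriptors with MultipleDescriptors keeps a single positions tensor as input while concatenating the descriptor outputs along the feature dimension.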