Skip to content

Commit

Permalink
Polishing and improving docs and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoTrizio committed May 2, 2024
1 parent 0a0063e commit 001bb0f
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 108 deletions.
94 changes: 4 additions & 90 deletions mlcolvar/core/transform/descriptors/coordination_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,82 +107,9 @@ def forward(self, pos):
return coord_numbers


# def test_coordination_number():
# from mlcolvar.core.transform.switching_functions import SwitchingFunctions

# pos = torch.Tensor([[3.057573, -0.970874, 0.010963],
# [3.303463, -0.921235, -0.093758],
# [3.143809, -0.969587, 0.243834],
# [3.144826, -1.055678, 1.234909],
# [2.817620, -0.882992, 0.701040],
# [3.523643, -0.600038, 1.250236],
# [3.047862, -0.612398, 0.731384],
# [3.029544, -0.438634, 1.131293],
# [3.058277, -1.465733, 1.020088],
# [3.267026, -1.142289, 0.492955]])
# pos.requires_grad = True

# n_atoms = 10
# cutoff=0.5

# switching_function=SwitchingFunctions(in_features=n_atoms*3, name='Fermi', cutoff=cutoff, options={'q':0.05})

# model = CoordinationNumbers(group_A = [0, 1, 2],
# group_B = [3, 4, 5, 6, 7, 8, 9],
# cutoff= cutoff,
# n_atoms=n_atoms,
# PBC=False,
# cell=10, # fake
# scaled_coords=False,
# mode='continuous',
# switching_function=switching_function)

# out = model(pos)
# out.sum().backward()
# print(out)

def test_coordination_number():
from mlcolvar.core.transform.tools.switching_functions import SwitchingFunctions

# pos = torch.Tensor([[3.057573, -0.970874, 0.010963],
# [3.303463, -0.921235, -0.093758],
# [3.143809, -0.969587, 0.243834],
# [3.144826, -1.055678, 1.234909],
# [2.817620, -0.882992, 0.701040],
# [3.523643, -0.600038, 1.250236],
# [3.047862, -0.612398, 0.731384],
# [3.029544, -0.438634, 1.131293],
# [3.058277, -1.465733, 1.020088],
# [3.267026, -1.142289, 0.492955]])


pos = torch.Tensor([[-0.410219, -0.680065, -2.016121],
[-0.164329, -0.630426, -2.120843],
[-0.250341, -0.392700, -1.534535],
[-0.277187, -0.615506, -1.335904],
[-0.762276, -1.041939, -1.546581],
[-0.200766, -0.851481, -1.534129],
[ 0.051099, -0.898884, -1.628219],
[-1.257225, 1.671602, 0.166190],
[-0.486917, -0.902610, -1.554715],
[-0.020386, -0.566621, -1.597171],
[-0.507683, -0.541252, -1.540805],
[-0.527323, -0.206236, -1.532587]])


pos = torch.Tensor([[-0.410387, -0.677657, -2.018355],
[-0.163502, -0.626094, -2.123348],
[-0.250672, -0.389610, -1.536810],
[-0.275395, -0.612535, -1.338175],
[-0.762197, -1.037856, -1.547382],
[-0.200948, -0.847825, -1.536010],
[ 0.051170, -0.896311, -1.629396],
[-1.257530, 1.674078, 0.165089],
[-0.486894, -0.900076, -1.556366],
[-0.020235, -0.563252, -1.601229],
[-0.507242, -0.537527, -1.543025],
[-0.528576, -0.202031, -1.534733]])

# simple example based on calixarene water coordination numbers
pos = torch.Tensor([[[-0.410219, -0.680065, -2.016121],
[-0.164329, -0.630426, -2.120843],
[-0.250341, -0.392700, -1.534535],
Expand All @@ -208,22 +135,6 @@ def test_coordination_number():
[-0.507242, -0.537527, -1.543025],
[-0.528576, -0.202031, -1.534733]]])

# pos = torch.Tensor([[-0.186334, -0.576215, -2.052373],
# [ 0.039420, -0.635731, -2.193494],
# [-0.035417, -0.031627, -1.686382],
# [-0.150362, -0.453559, -1.576238],
# [ 0.049959, -0.267481, -1.556228],
# [-0.401991, -0.700807, -1.459524],
# [-0.649622, -0.853509, -1.617916],
# [-0.291217, -0.244090, -1.488739],
# [-0.029036, -0.804894, -1.580041],
# [ 0.090713, -1.018815, -1.676822],
# [-1.024799, 1.815174, -0.125380],
# [-0.372569, -0.792924, -1.721617],
# [-0.673611, -0.309695, -1.654188],
# [ 0.183207, -0.407987, -1.758594],
# [-0.464715, -0.479660, -1.722147],
# [-0.389905, -0.116370, -1.718337]])
cell = 4.0273098

pos.requires_grad = True
Expand All @@ -247,6 +158,9 @@ def test_coordination_number():
out.sum().backward()
print(out)

# TODO add reference value for check


if __name__ == "__main__":
test_coordination_number()

50 changes: 37 additions & 13 deletions mlcolvar/core/transform/descriptors/pairwise_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,30 +71,54 @@ def forward(self, x: torch.Tensor):
return x

def test_pairwise_distances():
# simple test based on alanine distances
pos_abs = torch.Tensor([[ 1.4970, 1.3861, -0.0273, -1.4933, 1.5070, -0.1133, -1.4473, -1.4193,
-0.0553, 1.4940, 1.4990, -0.2403, 1.4780, -1.4173, -0.3363, -1.4243,
-1.4093, -0.4293, 1.3530, -1.4313, -0.4183, 1.3060, 1.4750, -0.4333,
1.2970, -1.3233, -0.4643, 1.1670, -1.3253, -0.5354]])
pos_abs.requires_grad = True

pos = torch.Tensor([ [ [0., 0., 0.],
[1., 1., 1.],
[1., 1., 1.1] ],
[ [0., 0., 0.],
[1., 1.1, 1.],
[1., 1., 1.] ] ]
)
cell = torch.Tensor([3.0233])

cell = torch.Tensor([1., 2., 1.])
pos_scaled = pos_abs / cell

ref_distances = torch.Tensor([[0.1521, 0.2335, 0.2412, 0.3798, 0.4733, 0.4649, 0.4575, 0.5741, 0.6815,
0.1220, 0.1323, 0.2495, 0.3407, 0.3627, 0.3919, 0.4634, 0.5885, 0.2280,
0.2976, 0.3748, 0.4262, 0.4821, 0.5043, 0.6376, 0.1447, 0.2449, 0.2454,
0.2705, 0.3597, 0.4833, 0.1528, 0.1502, 0.2370, 0.2408, 0.3805, 0.2472,
0.3243, 0.3159, 0.4527, 0.1270, 0.1301, 0.2440, 0.2273, 0.2819, 0.1482]])

# cell = torch.Tensor([1., 2., 1.])

model = PairwiseDistances(n_atoms = 3,
model = PairwiseDistances(n_atoms = 10,
PBC = True,
cell = cell,
scaled_coords = False)
out = model(pos)
assert(out.reshape(pos.shape[0], -1).shape[-1] == model.out_features)
out = model(pos_abs)
assert(out.reshape(pos_abs.shape[0], -1).shape[-1] == model.out_features)
print((out - ref_distances).max())
assert(torch.allclose(out, ref_distances, atol=1e-3))
out.sum().backward()

model = PairwiseDistances(n_atoms = 3,

model = PairwiseDistances(n_atoms = 10,
PBC = True,
cell = cell,
scaled_coords = False,
slicing_indeces=[0, 2])
out = model(pos)
out = model(pos_abs)
assert(torch.allclose(out, ref_distances[:, [0, 2]], atol=1e-3))
out.sum().backward()

model = PairwiseDistances(n_atoms = 10,
PBC = True,
cell = cell,
scaled_coords = True)
out = model(pos_scaled)
assert(out.reshape(pos_scaled.shape[0], -1).shape[-1] == model.out_features)
assert(torch.allclose(out, ref_distances, atol=1e-3))
out.sum().backward()


if __name__ == "__main__":
test_pairwise_distances()
3 changes: 3 additions & 0 deletions mlcolvar/core/transform/descriptors/torsional_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def forward(self, x):
return out[:, self.mode_idx]

def test_torsional_angle():
# simple test on alanine phi angle
pos = torch.Tensor([[[ 0.3887, -0.4169, -0.1212],
[ 0.4264, -0.4374, -0.0983],
[ 0.4574, -0.4136, -0.0931],
Expand Down Expand Up @@ -146,6 +147,8 @@ def test_torsional_angle():
print(angle)
angle.sum().backward()

# TODO add reference value for check


if __name__ == "__main__":
test_torsional_angle()
Expand Down
27 changes: 25 additions & 2 deletions mlcolvar/core/transform/tools/continuous_hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,36 @@

class ContHist(Transform):
"""
Compute continuous histogram with KDE-like method
Compute continuous histogram using Gaussian kernels
"""

def __init__(self,
in_features : int,
min : float,
max : float,
bins : int,
sigma_to_center : float) -> torch.Tensor :
sigma_to_center : float = 1.0) -> torch.Tensor :
"""Computes the continuous histogram of a quantity using Gaussian kernels
Parameters
----------
in_features : int
Number of inputs
min : float
Minimum value of the histogram
max : float
Maximum value of the histogram
bins : int
Number of bins of the histogram
sigma_to_center : float, optional
Sigma value in bin_size units, by default 1.0
Returns
-------
torch.Tensor
Values of the histogram for each bin
"""

super().__init__(in_features=in_features, out_features=bins)

Expand All @@ -42,12 +63,14 @@ def forward(self, x: torch.Tensor):

def test_continuous_histogram():
x = torch.randn((5,100))
x.requires_grad = True
hist = ContHist(in_features=100,
min=-1,
max=1,
bins=10,
sigma_to_center=1)
out = hist(x)
out.sum().backward()

if __name__ == "__main__":
test_continuous_histogram()
52 changes: 49 additions & 3 deletions mlcolvar/core/transform/tools/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import torch
import numpy as np

from typing import Union

__all__ = ["Statistics"]

Expand Down Expand Up @@ -102,10 +105,53 @@ def batch_reshape(t: torch.Tensor, size: torch.Size) -> torch.Tensor:
return t


def sym_func(x, centers, sigma):
def _gaussian_expansion(x : torch.Tensor,
centers : torch.Tensor,
sigma : torch.Tensor):
"""Computes the values in x of a set of Gaussian kernels centered on centers and with width sigma
Parameters
----------
x : torch.Tensor
Input value(s)
centers : torch.Tensor
Centers of the Gaussian kernels
sigma : torch.Tensor
Width of the Gaussian kernels
"""
return torch.exp(- torch.div(torch.pow(x-centers, 2), 2*torch.pow(sigma,2) ))

def easy_KDE(x, n_input, min_max, n, sigma_to_center, normalize=False, return_bins=False):
def easy_KDE(x : torch.Tensor,
n_input : int,
min_max : Union[list[float], np.ndarray],
n : int,
sigma_to_center : float = 1.0,
normalize : bool = False,
return_bins : bool = False) -> torch.Tensor:
"""Compute histogram using KDE with Gaussian kernels
Parameters
----------
x : torch.Tensor
Input
n_input : int
Number of inputs per batch
min_max : Union[list[float], np.ndarray]
Minimum and maximum values for the histogram
n : int
Number of Gaussian kernels
sigma_to_center : float, optional
Sigma value in bin_size units, by default 1.0
normalize : bool, optional
Switch for normalization of the histogram to sum to n_input, by default False
return_bins : bool, optional
Switch to return the bins of the histogram alongside the values, by default False
Returns
-------
torch.Tensor
Values of the histogram for each bin. The bins can be optionally returned enabling `return_bins`.
"""
if len(x.shape) == 1:
x = torch.reshape(x, (1, n_input, 1))
if x.shape[-1] != 1:
Expand All @@ -117,7 +163,7 @@ def easy_KDE(x, n_input, min_max, n, sigma_to_center, normalize=False, return_bi
bins = torch.clone(centers)
sigma = (centers[1] - centers[0]) * sigma_to_center
centers = torch.tile(centers, dims=(n_input,1))
out = torch.sum(sym_func(x, centers, sigma), dim=1)
out = torch.sum(_gaussian_expansion(x, centers, sigma), dim=1)
if normalize:
out = torch.div(out, torch.sum(out, -1, keepdim=True)) * n_input
if return_bins:
Expand Down

0 comments on commit 001bb0f

Please sign in to comment.