diff --git a/mlcolvar/core/transform/descriptors/coordination_numbers.py b/mlcolvar/core/transform/descriptors/coordination_numbers.py index d95379d8..f5ba319a 100644 --- a/mlcolvar/core/transform/descriptors/coordination_numbers.py +++ b/mlcolvar/core/transform/descriptors/coordination_numbers.py @@ -2,7 +2,7 @@ import numpy as np from mlcolvar.core.transform import Transform -from mlcolvar.core.transform.descriptors.utils import compute_distances_matrix, apply_cutoff +from mlcolvar.core.transform.descriptors.utils import compute_distances_matrix, apply_cutoff, sanitize_positions_shape from typing import Union @@ -57,10 +57,11 @@ def __init__(self, # parse args self.group_A = group_A - self.group_A_size = len(group_A) + self._group_A_size = len(group_A) self.group_B = group_B - self.group_B_size = len(group_B) - self.reordering = np.concatenate([self.group_A, self.group_B]) + self._group_B_size = len(group_B) + self._reordering = np.concatenate((self.group_A, self.group_B)) + self.cutoff = cutoff @@ -74,7 +75,8 @@ def __init__(self, def compute_coordination_number(self, pos): # move the group A elements to first positions - # pos = pos[self.reordering] + pos, batch_size = sanitize_positions_shape(pos, self.n_atoms) + pos = pos[:, self._reordering, :] dist = compute_distances_matrix(pos=pos, n_atoms=self.n_atoms, PBC=self.PBC, @@ -90,11 +92,11 @@ def compute_coordination_number(self, pos): switching_function=self.switching_function) # we can throw away part of the matrix as it is repeated uselessly - contributions = contributions[:, :self.group_A_size, :] + contributions = contributions[:, :self._group_A_size, :] # and also ensure that the AxA part of the matrix is zero, we need also to preserve the gradients mask = torch.ones_like(contributions) - mask[:, :self.group_A_size, :self.group_A_size] = 0 + mask[:, :self._group_A_size, :self._group_A_size] = 0 contributions = contributions*mask # compute coordination @@ -156,10 +158,51 @@ def test_coordination_number(): out = model(pos) out.sum().backward() - print(out) - # TODO add reference value for check + # we shift by hand the 0,1 atoms with 2,3 + pos = torch.Tensor([[[-0.250341, -0.392700, -1.534535], + [-0.277187, -0.615506, -1.335904], + [-0.410219, -0.680065, -2.016121], + [-0.164329, -0.630426, -2.120843], + [-0.762276, -1.041939, -1.546581], + [-0.200766, -0.851481, -1.534129], + [ 0.051099, -0.898884, -1.628219], + [-1.257225, 1.671602, 0.166190], + [-0.486917, -0.902610, -1.554715], + [-0.020386, -0.566621, -1.597171], + [-0.507683, -0.541252, -1.540805], + [-0.527323, -0.206236, -1.532587]], + [[-0.250672, -0.389610, -1.536810], + [-0.275395, -0.612535, -1.338175], + [-0.410387, -0.677657, -2.018355], + [-0.163502, -0.626094, -2.123348], + [-0.762197, -1.037856, -1.547382], + [-0.200948, -0.847825, -1.536010], + [ 0.051170, -0.896311, -1.629396], + [-1.257530, 1.674078, 0.165089], + [-0.486894, -0.900076, -1.556366], + [-0.020235, -0.563252, -1.601229], + [-0.507242, -0.537527, -1.543025], + [-0.528576, -0.202031, -1.534733]]]) + + pos.requires_grad = True + switching_function=SwitchingFunctions(in_features=n_atoms*3, name='Rational', cutoff=cutoff, options={'n': 2, 'm' : 6, 'eps' : 1e0}) + model = CoordinationNumbers(group_A = [2, 3], + group_B = [0, 1, 4, 5, 6, 7, 8, 9, 10, 11], + cutoff= cutoff, + n_atoms=n_atoms, + PBC=True, + cell=cell, + mode='continuous', + scaled_coords=False, + switching_function=switching_function) + + out_2 = model(pos) + out_2.sum().backward() + assert(torch.allclose(out, out_2)) + + # TODO add reference value for check if __name__ == "__main__": test_coordination_number()