Moved variables to private and fixed reordering bug
EnricoTrizio committed May 7, 2024
1 parent 2ea7159 commit 204bed0
Showing 1 changed file with 52 additions and 9 deletions.
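
The "reordering bug" in the title: the line that moves group A atoms to the front of the positions tensor was commented out, and as written it indexed the first (batch) axis instead of the atom axis. A minimal standalone sketch of the difference, assuming batched positions of shape (batch, n_atoms, 3) and the group indices used in the test below:

import torch

pos = torch.rand(2, 12, 3)                         # (batch, n_atoms, 3)
reordering = [2, 3] + [0, 1] + list(range(4, 12))  # group A indices first, then group B

# Buggy (the commented-out form): this indexes the batch axis, shuffling
# frames instead of atoms, and raises an IndexError whenever the batch
# size is smaller than n_atoms.
# pos = pos[reordering]

# Fix introduced by this commit: index the atom axis explicitly.
pos = pos[:, reordering, :]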
61 changes: 52 additions & 9 deletions mlcolvar/core/transform/descriptors/coordination_numbers.py
@@ -2,7 +2,7 @@
import numpy as np

from mlcolvar.core.transform import Transform
-from mlcolvar.core.transform.descriptors.utils import compute_distances_matrix, apply_cutoff
+from mlcolvar.core.transform.descriptors.utils import compute_distances_matrix, apply_cutoff, sanitize_positions_shape

from typing import Union

@@ -57,10 +57,11 @@ def __init__(self,

# parse args
self.group_A = group_A
-self.group_A_size = len(group_A)
+self._group_A_size = len(group_A)
self.group_B = group_B
-self.group_B_size = len(group_B)
-self.reordering = np.concatenate([self.group_A, self.group_B])
+self._group_B_size = len(group_B)
+self._reordering = np.concatenate((self.group_A, self.group_B))


self.cutoff = cutoff

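The hunk below swaps the commented-out reordering for a call to the newly imported sanitize_positions_shape. Its implementation is not part of this diff; a hypothetical sketch of what it plausibly does, inferred only from the call site pos, batch_size = sanitize_positions_shape(pos, self.n_atoms):

import torch

def sanitize_positions_shape(pos, n_atoms):
    # Guessed behavior, for illustration only: coerce input of shape
    # (n_atoms * 3,), (n_atoms, 3) or (batch, n_atoms, 3) into a batched
    # (batch, n_atoms, 3) tensor and report the batch size.
    if not torch.is_tensor(pos):
        pos = torch.as_tensor(pos)
    pos = pos.reshape(-1, n_atoms, 3)
    return pos, pos.shape[0]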
@@ -74,7 +75,8 @@

def compute_coordination_number(self, pos):
# move the group A elements to first positions
-# pos = pos[self.reordering]
+pos, batch_size = sanitize_positions_shape(pos, self.n_atoms)
+pos = pos[:, self._reordering, :]
dist = compute_distances_matrix(pos=pos,
n_atoms=self.n_atoms,
PBC=self.PBC,
@@ -90,11 +92,11 @@ def compute_coordination_number(self, pos):
switching_function=self.switching_function)

# we can throw away part of the matrix as it is repeated uselessly
-contributions = contributions[:, :self.group_A_size, :]
+contributions = contributions[:, :self._group_A_size, :]

# also ensure that the AxA part of the matrix is zero, while preserving the gradients
mask = torch.ones_like(contributions)
-mask[:, :self.group_A_size, :self.group_A_size] = 0
+mask[:, :self._group_A_size, :self._group_A_size] = 0
contributions = contributions*mask

# compute coordination
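
The masking above, in isolation: multiplying by a ones_like mask zeroes the AxA block without writing in place into contributions itself, which keeps the operation differentiable. A standalone sketch with hypothetical sizes:

import torch

contributions = torch.rand(1, 2, 12, requires_grad=True)  # (batch, |A|, n_atoms)
group_A_size = 2

mask = torch.ones_like(contributions)        # plain tensor, needs no gradients
mask[:, :group_A_size, :group_A_size] = 0    # drop the A-with-A self terms
masked = contributions * mask                # differentiable, unlike zeroing
                                             # contributions in place
masked.sum().backward()                      # gradients reach the kept entries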
Expand Down Expand Up @@ -156,10 +158,51 @@ def test_coordination_number():

out = model(pos)
out.sum().backward()
+print(out)

-# TODO add reference value for check
+# we swap atoms 0,1 with atoms 2,3 by hand
+pos = torch.Tensor([[[-0.250341, -0.392700, -1.534535],
+                     [-0.277187, -0.615506, -1.335904],
+                     [-0.410219, -0.680065, -2.016121],
+                     [-0.164329, -0.630426, -2.120843],
+                     [-0.762276, -1.041939, -1.546581],
+                     [-0.200766, -0.851481, -1.534129],
+                     [ 0.051099, -0.898884, -1.628219],
+                     [-1.257225, 1.671602, 0.166190],
+                     [-0.486917, -0.902610, -1.554715],
+                     [-0.020386, -0.566621, -1.597171],
+                     [-0.507683, -0.541252, -1.540805],
+                     [-0.527323, -0.206236, -1.532587]],
+                    [[-0.250672, -0.389610, -1.536810],
+                     [-0.275395, -0.612535, -1.338175],
+                     [-0.410387, -0.677657, -2.018355],
+                     [-0.163502, -0.626094, -2.123348],
+                     [-0.762197, -1.037856, -1.547382],
+                     [-0.200948, -0.847825, -1.536010],
+                     [ 0.051170, -0.896311, -1.629396],
+                     [-1.257530, 1.674078, 0.165089],
+                     [-0.486894, -0.900076, -1.556366],
+                     [-0.020235, -0.563252, -1.601229],
+                     [-0.507242, -0.537527, -1.543025],
+                     [-0.528576, -0.202031, -1.534733]]])

+pos.requires_grad = True
+switching_function=SwitchingFunctions(in_features=n_atoms*3, name='Rational', cutoff=cutoff, options={'n': 2, 'm' : 6, 'eps' : 1e0})

+model = CoordinationNumbers(group_A = [2, 3],
+                            group_B = [0, 1, 4, 5, 6, 7, 8, 9, 10, 11],
+                            cutoff= cutoff,
+                            n_atoms=n_atoms,
+                            PBC=True,
+                            cell=cell,
+                            mode='continuous',
+                            scaled_coords=False,
+                            switching_function=switching_function)

+out_2 = model(pos)
+out_2.sum().backward()
+assert(torch.allclose(out, out_2))

+# TODO add reference value for check

if __name__ == "__main__":
test_coordination_number()
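
A note on what the new second half of the test checks: the second pos tensor holds the same two frames as the first (not shown in this hunk), but with the coordinates of atoms 0,1 and 2,3 exchanged, while group_A is moved accordingly (presumably from [0, 1] to [2, 3]). Relabeling atoms consistently in both the positions and the group definitions leaves the physical system unchanged, so the coordination numbers must coincide; torch.allclose(out, out_2) verifies exactly that, and it is this assertion that the previously commented-out reordering would have failed.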
